#pragma once #include #include // SPDLOG_ACTIVE_LEVEL must be defined before spdlog.h import #if DEBUG #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG #endif #include #include "kompute/OpBase.hpp" namespace kp { class Sequence { public: Sequence(); Sequence(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr computeQueue, uint32_t queueIndex); ~Sequence(); // Record command functions void begin(); void end(); void eval(); // TODO: Explore design without template using just top level class template void record(std::vector> tensors) { static_assert(std::is_base_of::value, "Template only valid with OpBase derived classes"); SPDLOG_DEBUG("Kompute Sequence record function started"); SPDLOG_DEBUG("Kompute Sequence creating OpBase derived class instance"); T* op = new T(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer); OpBase* baseOp = dynamic_cast(op); std::unique_ptr baseOpPtr{ baseOp }; SPDLOG_DEBUG("Kompute Sequence running init on OpBase derived class instance"); baseOpPtr->init(tensors); SPDLOG_DEBUG("Kompute Sequence running record on OpBase derived class instance"); baseOpPtr->record(); mOperations.push_back(std::move(baseOpPtr)); } private: std::shared_ptr mPhysicalDevice = nullptr; std::shared_ptr mDevice = nullptr; std::shared_ptr mComputeQueue = nullptr; uint32_t mQueueIndex = -1; std::shared_ptr mCommandPool = nullptr; bool mFreeCommandPool = false; std::shared_ptr mCommandBuffer = nullptr; bool mFreeCommandBuffer = false; // Base op objects std::vector> mOperations; // Record state bool mRecording = false; // Create functions void createCommandPool(); void createCommandBuffer(); }; } // End namespace kp