// Defining OPMULT_H to ensure cpp class doesn't reimport #define OPMULT_H #pragma once #include #include // SPDLOG_ACTIVE_LEVEL must be defined before spdlog.h import #if DEBUG #define SPDLOG_ACTIVE_LEVEL SPDLOG_LEVEL_DEBUG #endif #include #include "Algorithm.hpp" #include "Tensor.hpp" #include "OpBase.hpp" namespace kp { template class OpMult : public OpBase { public: OpMult(); OpMult(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr commandBuffer); ~OpMult(); void init(std::vector> tensors) override; void record() override; void postSubmit() override; private: // Always owned resources std::shared_ptr mTensorOutputStaging; // Optionally owned resources std::shared_ptr mAlgorithm; bool mFreeAlgorithm = false; // Never owned resources std::shared_ptr mTensorLHS; std::shared_ptr mTensorRHS; std::shared_ptr mTensorOutput; uint32_t mX; uint32_t mY; uint32_t mZ; }; } // End namespace kp // Including implemenation for template class #include "OpMult.cpp"