#pragma once #include #include "kompute/Core.hpp" #if RELEASE #include "kompute/shaders/shaderopmult.hpp" #endif #include "kompute/Algorithm.hpp" #include "kompute/Tensor.hpp" #include "kompute/operations/OpAlgoBase.hpp" namespace kp { /** * Operation that performs multiplication on two tensors and outpus on third * tensor. The template parameters specify the processing GPU layout number of * iterations for each x, y, z parameter. More specifically, this will be the * input to ".dispatch(uint32_t tX, uint32_t tY, uint32_t, tZ)" */ template class OpMult : public OpAlgoBase { public: /** * Base constructor, should not be used unless explicitly intended. */ OpMult() { } /** * Default constructor with parameters that provides the bare minimum * requirements for the operations to be able to create and manage their * sub-components. * * @param physicalDevice Vulkan physical device used to find device queues * @param device Vulkan logical device for passing to Algorithm * @param commandBuffer Vulkan Command Buffer to record commands into * @param tensors Tensors that are to be used in this operation * @param freeTensors Whether operation manages the memory of the Tensors */ OpMult(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, true) { SPDLOG_DEBUG("Kompute OpMult constructor with params"); #ifndef RELEASE this->mOptSpirvBinPath = "shaders/glsl/opmult.comp.spv"; #endif } #if RELEASE /** * If release it will be using the static version of the shader which is * loaded using this file directly. * * @param physicalDevice Vulkan physical device used to find device queues * @param device Vulkan logical device for passing to Algorithm * @param commandBuffer Vulkan Command Buffer to record commands into * @param tensors Tensors that are to be used in this operation * @param freeTensors Whether operation manages the memory of the Tensors */ std::vector fetchSpirvBinaryData() override { SPDLOG_WARN( "Kompute OpMult Running shaders directly from header"); return std::vector( shader_data::shaders_glsl_opmult_comp_spv, shader_data::shaders_glsl_opmult_comp_spv + kp::shader_data::shaders_glsl_opmult_comp_spv_len); } #endif /** * Default destructor, which is in charge of destroying the algorithm * components but does not destroy the underlying tensors */ ~OpMult() { SPDLOG_DEBUG("Kompute OpMult destructor started"); } }; } // End namespace kp