#pragma once #include "kompute/operations/OpAlgoBase.hpp" namespace kp { OpAlgoBase::OpAlgoBase() { SPDLOG_DEBUG("Kompute OpAlgoBase constructor base"); } OpAlgoBase::OpAlgoBase(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, KomputeWorkgroup komputeWorkgroup) : OpBase(physicalDevice, device, commandBuffer, tensors, false) { SPDLOG_DEBUG("Kompute OpAlgoBase constructor with params numTensors: {}", tensors.size()); // The dispatch size is set up based on either explicitly provided template // parameters or by default it would take the shape and size of the tensors if (komputeWorkgroup.x > 0) { // If at least the x value is provided we use mainly the parameters // provided this->mKomputeWorkgroup = { 0, komputeWorkgroup.y > 0 ? komputeWorkgroup.y : 1, komputeWorkgroup.z > 0 ? komputeWorkgroup.z : 1 }; } else { this->mKomputeWorkgroup = {tensors[0]->size(), 1, 1}; } SPDLOG_INFO("Kompute OpAlgoBase dispatch size X: {}, Y: {}, Z: {}", this->mKomputeWorkgroup.x, this->mKomputeWorkgroup.y, this->mKomputeWorkgroup.z); this->mAlgorithm = std::make_shared(device, commandBuffer); } OpAlgoBase::OpAlgoBase(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, std::string shaderFilePath, KomputeWorkgroup komputeWorkgroup) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup) { SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shaderfile path: {}", shaderFilePath); this->mShaderFilePath = shaderFilePath; } OpAlgoBase::OpAlgoBase(std::shared_ptr physicalDevice, std::shared_ptr device, std::shared_ptr commandBuffer, std::vector>& tensors, const std::vector& shaderDataRaw, KomputeWorkgroup komputeWorkgroup) : OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup) { SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shader raw data length: {}", shaderDataRaw.size()); this->mShaderDataRaw = shaderDataRaw; } OpAlgoBase::~OpAlgoBase() { SPDLOG_DEBUG("Kompute OpAlgoBase destructor started"); } void OpAlgoBase::init() { SPDLOG_DEBUG("Kompute OpAlgoBase init called"); if (this->mTensors.size() < 1) { throw std::runtime_error( "Kompute OpAlgoBase called with less than 1 tensor"); } for (std::shared_ptr tensor : this->mTensors) { if(!tensor->isInit()) { throw std::runtime_error("Kompute OpAlgoBase validation failed; all tensor parameters must be initialised."); } } SPDLOG_DEBUG("Kompute OpAlgoBase fetching spirv data"); std::vector shaderFileData = this->fetchSpirvBinaryData(); SPDLOG_DEBUG("Kompute OpAlgoBase Initialising algorithm component"); this->mAlgorithm->init(shaderFileData, this->mTensors); } void OpAlgoBase::record() { SPDLOG_DEBUG("Kompute OpAlgoBase record called"); // Barrier to ensure the data is finished writing to buffer memory for (std::shared_ptr tensor : this->mTensors) { tensor->recordBufferMemoryBarrier( this->mCommandBuffer, vk::AccessFlagBits::eHostWrite, vk::AccessFlagBits::eShaderRead, vk::PipelineStageFlagBits::eHost, vk::PipelineStageFlagBits::eComputeShader); } this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup.x, this->mKomputeWorkgroup.y, this->mKomputeWorkgroup.z); } void OpAlgoBase::preEval() { SPDLOG_DEBUG("Kompute OpAlgoBase preEval called"); } void OpAlgoBase::postEval() { SPDLOG_DEBUG("Kompute OpAlgoBase postSubmit called"); } std::vector OpAlgoBase::fetchSpirvBinaryData() { SPDLOG_WARN( "Kompute OpAlgoBase Running shaders directly from spirv file"); if (this->mShaderFilePath.size()) { std::ifstream fileStream(this->mShaderFilePath, std::ios::binary | std::ios::in | std::ios::ate); if (!fileStream.good()) { throw std::runtime_error("Error reading file: " + this->mShaderFilePath); } size_t shaderFileSize = fileStream.tellg(); fileStream.seekg(0, std::ios::beg); char* shaderDataRaw = new char[shaderFileSize]; fileStream.read(shaderDataRaw, shaderFileSize); fileStream.close(); SPDLOG_WARN( "Kompute OpAlgoBase fetched {} bytes", shaderFileSize); return std::vector(shaderDataRaw, shaderDataRaw + shaderFileSize); } else if (this->mShaderDataRaw.size()) { return this->mShaderDataRaw; } else { throw std::runtime_error("Kompute OpAlgoBase Error reached fetchSpirvBinaryData but neither filepath nor data provided"); } } }