Removed workgroup templates on opalgobase classes

This commit is contained in:
Alejandro Saucedo 2020-11-01 16:28:48 +00:00
parent 6afe6463c2
commit 3ad5e4d3e7
5 changed files with 322 additions and 344 deletions

162
src/OpAlgoBase.cpp Normal file
View file

@ -0,0 +1,162 @@
#pragma once
#include "kompute/operations/OpAlgoBase.hpp"
namespace kp {
OpAlgoBase::OpAlgoBase()
{
SPDLOG_DEBUG("Kompute OpAlgoBase constructor base");
}
OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
KomputeWorkgroup komputeWorkgroup)
: OpBase(physicalDevice, device, commandBuffer, tensors, false)
{
SPDLOG_DEBUG("Kompute OpAlgoBase constructor with params numTensors: {}", tensors.size());
// The dispatch size is set up based on either explicitly provided template
// parameters or by default it would take the shape and size of the tensors
if (komputeWorkgroup.x > 0) {
// If at least the x value is provided we use mainly the parameters
// provided
this->mKomputeWorkgroup = {
0,
komputeWorkgroup.y > 0 ? komputeWorkgroup.y : 1,
komputeWorkgroup.z > 0 ? komputeWorkgroup.z : 1
};
} else {
this->mKomputeWorkgroup = {tensors[0]->size(), 1, 1};
}
SPDLOG_INFO("Kompute OpAlgoBase dispatch size X: {}, Y: {}, Z: {}",
this->mKomputeWorkgroup.x,
this->mKomputeWorkgroup.y,
this->mKomputeWorkgroup.z);
this->mAlgorithm = std::make_shared<Algorithm>(device, commandBuffer);
}
OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::string shaderFilePath,
KomputeWorkgroup komputeWorkgroup)
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup)
{
SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shaderfile path: {}", shaderFilePath);
this->mShaderFilePath = shaderFilePath;
}
OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
std::shared_ptr<vk::Device> device,
std::shared_ptr<vk::CommandBuffer> commandBuffer,
std::vector<std::shared_ptr<Tensor>>& tensors,
const std::vector<char>& shaderDataRaw,
KomputeWorkgroup komputeWorkgroup)
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup)
{
SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shader raw data length: {}", shaderDataRaw.size());
this->mShaderDataRaw = shaderDataRaw;
}
OpAlgoBase::~OpAlgoBase()
{
SPDLOG_DEBUG("Kompute OpAlgoBase destructor started");
}
void
OpAlgoBase::init()
{
SPDLOG_DEBUG("Kompute OpAlgoBase init called");
if (this->mTensors.size() < 1) {
throw std::runtime_error(
"Kompute OpAlgoBase called with less than 1 tensor");
}
for (std::shared_ptr<Tensor> tensor : this->mTensors) {
if(!tensor->isInit()) {
throw std::runtime_error("Kompute OpAlgoBase validation failed; all tensor parameters must be initialised.");
}
}
SPDLOG_DEBUG("Kompute OpAlgoBase fetching spirv data");
std::vector<char> shaderFileData = this->fetchSpirvBinaryData();
SPDLOG_DEBUG("Kompute OpAlgoBase Initialising algorithm component");
this->mAlgorithm->init(shaderFileData, this->mTensors);
}
void
OpAlgoBase::record()
{
SPDLOG_DEBUG("Kompute OpAlgoBase record called");
// Barrier to ensure the data is finished writing to buffer memory
for (std::shared_ptr<Tensor> tensor : this->mTensors) {
tensor->recordBufferMemoryBarrier(
this->mCommandBuffer,
vk::AccessFlagBits::eHostWrite,
vk::AccessFlagBits::eShaderRead,
vk::PipelineStageFlagBits::eHost,
vk::PipelineStageFlagBits::eComputeShader);
}
this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup.x, this->mKomputeWorkgroup.y, this->mKomputeWorkgroup.z);
}
void
OpAlgoBase::preEval()
{
SPDLOG_DEBUG("Kompute OpAlgoBase preEval called");
}
void
OpAlgoBase::postEval()
{
SPDLOG_DEBUG("Kompute OpAlgoBase postSubmit called");
}
std::vector<char> OpAlgoBase::fetchSpirvBinaryData()
{
SPDLOG_WARN(
"Kompute OpAlgoBase Running shaders directly from spirv file");
if (this->mShaderFilePath.size()) {
std::ifstream fileStream(this->mShaderFilePath,
std::ios::binary | std::ios::in | std::ios::ate);
if (!fileStream.good()) {
throw std::runtime_error("Error reading file: " + this->mShaderFilePath);
}
size_t shaderFileSize = fileStream.tellg();
fileStream.seekg(0, std::ios::beg);
char* shaderDataRaw = new char[shaderFileSize];
fileStream.read(shaderDataRaw, shaderFileSize);
fileStream.close();
SPDLOG_WARN(
"Kompute OpAlgoBase fetched {} bytes", shaderFileSize);
return std::vector<char>(shaderDataRaw,
shaderDataRaw + shaderFileSize);
}
else if (this->mShaderDataRaw.size()) {
return this->mShaderDataRaw;
}
else {
throw std::runtime_error("Kompute OpAlgoBase Error reached fetchSpirvBinaryData but neither filepath nor data provided");
}
}
}