Reformatted
This commit is contained in:
parent
3036cbd95f
commit
1f614a87e4
11 changed files with 125 additions and 96 deletions
|
|
@ -10,13 +10,14 @@ OpAlgoBase::OpAlgoBase()
|
|||
}
|
||||
|
||||
OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
: OpBase(physicalDevice, device, commandBuffer, tensors, false)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpAlgoBase constructor with params numTensors: {}", tensors.size());
|
||||
SPDLOG_DEBUG("Kompute OpAlgoBase constructor with params numTensors: {}",
|
||||
tensors.size());
|
||||
|
||||
// The dispatch size is set up based on either explicitly provided template
|
||||
// parameters or by default it would take the shape and size of the tensors
|
||||
|
|
@ -29,38 +30,42 @@ OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
|||
komputeWorkgroup.z > 0 ? komputeWorkgroup.z : 1
|
||||
};
|
||||
} else {
|
||||
this->mKomputeWorkgroup = {tensors[0]->size(), 1, 1};
|
||||
this->mKomputeWorkgroup = { tensors[0]->size(), 1, 1 };
|
||||
}
|
||||
SPDLOG_INFO("Kompute OpAlgoBase dispatch size X: {}, Y: {}, Z: {}",
|
||||
this->mKomputeWorkgroup.x,
|
||||
this->mKomputeWorkgroup.y,
|
||||
this->mKomputeWorkgroup.z);
|
||||
this->mKomputeWorkgroup.x,
|
||||
this->mKomputeWorkgroup.y,
|
||||
this->mKomputeWorkgroup.z);
|
||||
|
||||
this->mAlgorithm = std::make_shared<Algorithm>(device, commandBuffer);
|
||||
}
|
||||
|
||||
OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
std::string shaderFilePath,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
std::string shaderFilePath,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shaderfile path: {}", shaderFilePath);
|
||||
SPDLOG_DEBUG(
|
||||
"Kompute OpAlgoBase shaderFilePath constructo with shaderfile path: {}",
|
||||
shaderFilePath);
|
||||
|
||||
this->mShaderFilePath = shaderFilePath;
|
||||
}
|
||||
|
||||
OpAlgoBase::OpAlgoBase(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<char>& shaderDataRaw,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
std::shared_ptr<vk::Device> device,
|
||||
std::shared_ptr<vk::CommandBuffer> commandBuffer,
|
||||
std::vector<std::shared_ptr<Tensor>>& tensors,
|
||||
const std::vector<char>& shaderDataRaw,
|
||||
KomputeWorkgroup komputeWorkgroup)
|
||||
: OpAlgoBase(physicalDevice, device, commandBuffer, tensors, komputeWorkgroup)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shader raw data length: {}", shaderDataRaw.size());
|
||||
SPDLOG_DEBUG("Kompute OpAlgoBase shaderFilePath constructo with shader raw "
|
||||
"data length: {}",
|
||||
shaderDataRaw.size());
|
||||
|
||||
this->mShaderDataRaw = shaderDataRaw;
|
||||
}
|
||||
|
|
@ -78,11 +83,13 @@ OpAlgoBase::init()
|
|||
if (this->mTensors.size() < 1) {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpAlgoBase called with less than 1 tensor");
|
||||
}
|
||||
}
|
||||
|
||||
for (std::shared_ptr<Tensor> tensor : this->mTensors) {
|
||||
if(!tensor->isInit()) {
|
||||
throw std::runtime_error("Kompute OpAlgoBase validation failed; all tensor parameters must be initialised.");
|
||||
if (!tensor->isInit()) {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpAlgoBase validation failed; all tensor parameters "
|
||||
"must be initialised.");
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -110,7 +117,9 @@ OpAlgoBase::record()
|
|||
vk::PipelineStageFlagBits::eComputeShader);
|
||||
}
|
||||
|
||||
this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup.x, this->mKomputeWorkgroup.y, this->mKomputeWorkgroup.z);
|
||||
this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup.x,
|
||||
this->mKomputeWorkgroup.y,
|
||||
this->mKomputeWorkgroup.z);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -125,17 +134,19 @@ OpAlgoBase::postEval()
|
|||
SPDLOG_DEBUG("Kompute OpAlgoBase postSubmit called");
|
||||
}
|
||||
|
||||
std::vector<char> OpAlgoBase::fetchSpirvBinaryData()
|
||||
std::vector<char>
|
||||
OpAlgoBase::fetchSpirvBinaryData()
|
||||
{
|
||||
SPDLOG_WARN(
|
||||
"Kompute OpAlgoBase Running shaders directly from spirv file");
|
||||
SPDLOG_WARN("Kompute OpAlgoBase Running shaders directly from spirv file");
|
||||
|
||||
if (this->mShaderFilePath.size()) {
|
||||
std::ifstream fileStream(this->mShaderFilePath,
|
||||
std::ios::binary | std::ios::in | std::ios::ate);
|
||||
std::ios::binary | std::ios::in |
|
||||
std::ios::ate);
|
||||
|
||||
if (!fileStream.good()) {
|
||||
throw std::runtime_error("Error reading file: " + this->mShaderFilePath);
|
||||
throw std::runtime_error("Error reading file: " +
|
||||
this->mShaderFilePath);
|
||||
}
|
||||
|
||||
size_t shaderFileSize = fileStream.tellg();
|
||||
|
|
@ -144,19 +155,16 @@ std::vector<char> OpAlgoBase::fetchSpirvBinaryData()
|
|||
fileStream.read(shaderDataRaw, shaderFileSize);
|
||||
fileStream.close();
|
||||
|
||||
SPDLOG_WARN(
|
||||
"Kompute OpAlgoBase fetched {} bytes", shaderFileSize);
|
||||
SPDLOG_WARN("Kompute OpAlgoBase fetched {} bytes", shaderFileSize);
|
||||
|
||||
return std::vector<char>(shaderDataRaw,
|
||||
shaderDataRaw + shaderFileSize);
|
||||
}
|
||||
else if (this->mShaderDataRaw.size()) {
|
||||
return std::vector<char>(shaderDataRaw, shaderDataRaw + shaderFileSize);
|
||||
} else if (this->mShaderDataRaw.size()) {
|
||||
return this->mShaderDataRaw;
|
||||
}
|
||||
else {
|
||||
throw std::runtime_error("Kompute OpAlgoBase Error reached fetchSpirvBinaryData but neither filepath nor data provided");
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpAlgoBase Error reached fetchSpirvBinaryData but neither "
|
||||
"filepath nor data provided");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue