Reformatted
This commit is contained in:
parent
181efc954b
commit
5bb9046b49
11 changed files with 148 additions and 107 deletions
|
|
@ -46,23 +46,36 @@ OpMult::init(std::vector<std::shared_ptr<Tensor>> tensors)
|
|||
this->mTensorOutput = tensors[2];
|
||||
|
||||
// TODO: Explore adding a validate function
|
||||
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() && this->mTensorOutput->isInit())) {
|
||||
throw std::runtime_error("Kompute OpMult all tensor parameters must be initialised. LHS: " + std::to_string(this->mTensorLHS->isInit()) + " RHS: " + std::to_string(this->mTensorRHS->isInit()) + " Output: " + std::to_string(this->mTensorOutput->isInit()));
|
||||
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
|
||||
this->mTensorOutput->isInit())) {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpMult all tensor parameters must be initialised. LHS: " +
|
||||
std::to_string(this->mTensorLHS->isInit()) +
|
||||
" RHS: " + std::to_string(this->mTensorRHS->isInit()) +
|
||||
" Output: " + std::to_string(this->mTensorOutput->isInit()));
|
||||
}
|
||||
|
||||
// TODO: Explore use-cases where tensors shouldn't be the same size, and how to deal with those situations
|
||||
if (!(this->mTensorLHS->size() == this->mTensorRHS->size() && this->mTensorRHS->size() == this->mTensorOutput->size())) {
|
||||
throw std::runtime_error("Kompute OpMult all tensor parameters must be the same size LHS: " + std::to_string(this->mTensorLHS->size()) + " RHS: " + std::to_string(this->mTensorRHS->size()) + " Output: " + std::to_string(this->mTensorOutput->size()));
|
||||
// TODO: Explore use-cases where tensors shouldn't be the same size, and how
|
||||
// to deal with those situations
|
||||
if (!(this->mTensorLHS->size() == this->mTensorRHS->size() &&
|
||||
this->mTensorRHS->size() == this->mTensorOutput->size())) {
|
||||
throw std::runtime_error(
|
||||
"Kompute OpMult all tensor parameters must be the same size LHS: " +
|
||||
std::to_string(this->mTensorLHS->size()) +
|
||||
" RHS: " + std::to_string(this->mTensorRHS->size()) +
|
||||
" Output: " + std::to_string(this->mTensorOutput->size()));
|
||||
}
|
||||
|
||||
this->mTensorOutputStaging = std::make_shared<Tensor>(
|
||||
this->mTensorOutput->data(), Tensor::TensorTypes::eStaging);
|
||||
|
||||
this->mTensorOutputStaging->init(this->mPhysicalDevice, this->mDevice, this->mCommandBuffer, this->mTensorOutput->data());
|
||||
this->mTensorOutputStaging->init(this->mPhysicalDevice,
|
||||
this->mDevice,
|
||||
this->mCommandBuffer,
|
||||
this->mTensorOutput->data());
|
||||
|
||||
// TODO: Make this path configurable
|
||||
this->mAlgorithm->init(
|
||||
"shaders/glsl/opmult.comp.spv", tensors);
|
||||
this->mAlgorithm->init("shaders/glsl/opmult.comp.spv", tensors);
|
||||
}
|
||||
|
||||
void
|
||||
|
|
@ -75,7 +88,8 @@ OpMult::record()
|
|||
this->mTensorOutputStaging->recordCopyFrom(this->mTensorOutput);
|
||||
}
|
||||
|
||||
void OpMult::postSubmit()
|
||||
void
|
||||
OpMult::postSubmit()
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute OpCreateTensor postSubmit called");
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue