All python tests pass
This commit is contained in:
parent
4c4d073b90
commit
91d3b9a223
11 changed files with 158 additions and 169 deletions
|
|
@ -6,10 +6,10 @@
|
|||
#include "kompute/Tensor.hpp"
|
||||
#include "kompute/Algorithm.hpp"
|
||||
#include "kompute/operations/OpBase.hpp"
|
||||
#include "kompute/operations/OpMult.hpp"
|
||||
#include "kompute/operations/OpTensorCopy.hpp"
|
||||
#include "kompute/operations/OpTensorSyncDevice.hpp"
|
||||
#include "kompute/operations/OpTensorSyncLocal.hpp"
|
||||
#include "kompute/operations/OpAlgoDispatch.hpp"
|
||||
#include "kompute/operations/OpMult.hpp"
|
||||
#include "kompute/Sequence.hpp"
|
||||
#include "kompute/Manager.hpp"
|
||||
|
|
|
|||
|
|
@ -1247,106 +1247,6 @@ class OpBase
|
|||
|
||||
} // End namespace kp
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace kp {
|
||||
|
||||
/**
|
||||
* Operation that provides a general abstraction that simplifies the use of
|
||||
* algorithm and parameter components which can be used with shaders.
|
||||
* By default it enables the user to provide a dynamic number of tensors
|
||||
* which are then passed as inputs.
|
||||
*/
|
||||
class OpAlgoDispatch : public OpBase
|
||||
{
|
||||
public:
|
||||
|
||||
OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm, bool skipAlgoCheck = false);
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
* components but does not destroy the underlying tensors
|
||||
*/
|
||||
virtual ~OpAlgoDispatch() override;
|
||||
|
||||
/**
|
||||
* This records the commands that are to be sent to the GPU. This includes
|
||||
* the barriers that ensure the memory has been copied before going in and
|
||||
* out of the shader, as well as the dispatch operation that sends the
|
||||
* shader processing to the gpu. This function also records the GPU memory
|
||||
* copy of the output data for the staging buffer so it can be read by the
|
||||
* host.
|
||||
*/
|
||||
virtual void record(std::shared_ptr<vk::CommandBuffer> commandBuffer) override;
|
||||
|
||||
/**
|
||||
* Does not perform any preEval commands.
|
||||
*/
|
||||
virtual void preEval() override;
|
||||
|
||||
/**
|
||||
* Executes after the recorded commands are submitted, and performs a copy
|
||||
* of the GPU Device memory into the staging buffer so the output data can
|
||||
* be retrieved.
|
||||
*/
|
||||
virtual void postEval() override;
|
||||
|
||||
private:
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::shared_ptr<Algorithm> mAlgorithm;
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
|
||||
namespace kp {
|
||||
|
||||
/**
|
||||
* Operation that performs multiplication on two tensors and outpus on third
|
||||
* tensor.
|
||||
*/
|
||||
class OpMult : public OpAlgoDispatch
|
||||
{
|
||||
public:
|
||||
|
||||
/**
|
||||
* Default constructor with parameters that provides the bare minimum
|
||||
* requirements for the operations to be able to create and manage their
|
||||
* sub-components.
|
||||
*
|
||||
* @param physicalDevice Vulkan physical device used to find device queues
|
||||
* @param device Vulkan logical device for passing to Algorithm
|
||||
* @param commandBuffer Vulkan Command Buffer to record commands into
|
||||
* @param tensors Tensors that are to be used in this operation
|
||||
* @param komputeWorkgroup Optional parameter to specify the layout for processing
|
||||
*/
|
||||
OpMult(std::vector<std::shared_ptr<Tensor>> tensors, std::shared_ptr<Algorithm> algorithm)
|
||||
: OpAlgoDispatch(algorithm, true)
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute OpMult constructor with params");
|
||||
|
||||
if (tensors.size() != 3) {
|
||||
throw std::runtime_error("Kompute OpMult expected 3 tensors but got " + tensors.size());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> spirv(
|
||||
(uint32_t*)shader_data::shaders_glsl_opmult_comp_spv,
|
||||
(uint32_t*)(shader_data::shaders_glsl_opmult_comp_spv +
|
||||
kp::shader_data::shaders_glsl_opmult_comp_spv_len));
|
||||
|
||||
algorithm->rebuild(tensors, spirv);
|
||||
}
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
* components but does not destroy the underlying tensors
|
||||
*/
|
||||
virtual ~OpMult() override {
|
||||
KP_LOG_DEBUG("Kompute OpMult destructor started");
|
||||
}
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
|
||||
namespace kp {
|
||||
|
||||
/**
|
||||
|
|
@ -1484,6 +1384,106 @@ class OpTensorSyncLocal : public OpBase
|
|||
|
||||
namespace kp {
|
||||
|
||||
/**
|
||||
* Operation that provides a general abstraction that simplifies the use of
|
||||
* algorithm and parameter components which can be used with shaders.
|
||||
* By default it enables the user to provide a dynamic number of tensors
|
||||
* which are then passed as inputs.
|
||||
*/
|
||||
class OpAlgoDispatch : public OpBase
|
||||
{
|
||||
public:
|
||||
|
||||
OpAlgoDispatch(const std::shared_ptr<kp::Algorithm>& algorithm);
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
* components but does not destroy the underlying tensors
|
||||
*/
|
||||
virtual ~OpAlgoDispatch() override;
|
||||
|
||||
/**
|
||||
* This records the commands that are to be sent to the GPU. This includes
|
||||
* the barriers that ensure the memory has been copied before going in and
|
||||
* out of the shader, as well as the dispatch operation that sends the
|
||||
* shader processing to the gpu. This function also records the GPU memory
|
||||
* copy of the output data for the staging buffer so it can be read by the
|
||||
* host.
|
||||
*/
|
||||
virtual void record(std::shared_ptr<vk::CommandBuffer> commandBuffer) override;
|
||||
|
||||
/**
|
||||
* Does not perform any preEval commands.
|
||||
*/
|
||||
virtual void preEval() override;
|
||||
|
||||
/**
|
||||
* Executes after the recorded commands are submitted, and performs a copy
|
||||
* of the GPU Device memory into the staging buffer so the output data can
|
||||
* be retrieved.
|
||||
*/
|
||||
virtual void postEval() override;
|
||||
|
||||
private:
|
||||
// -------------- ALWAYS OWNED RESOURCES
|
||||
std::shared_ptr<Algorithm> mAlgorithm;
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
|
||||
#include <fstream>
|
||||
|
||||
namespace kp {
|
||||
|
||||
/**
|
||||
* Operation that performs multiplication on two tensors and outpus on third
|
||||
* tensor.
|
||||
*/
|
||||
class OpMult : public OpAlgoDispatch
|
||||
{
|
||||
public:
|
||||
|
||||
/**
|
||||
* Default constructor with parameters that provides the bare minimum
|
||||
* requirements for the operations to be able to create and manage their
|
||||
* sub-components.
|
||||
*
|
||||
* @param physicalDevice Vulkan physical device used to find device queues
|
||||
* @param device Vulkan logical device for passing to Algorithm
|
||||
* @param commandBuffer Vulkan Command Buffer to record commands into
|
||||
* @param tensors Tensors that are to be used in this operation
|
||||
* @param komputeWorkgroup Optional parameter to specify the layout for processing
|
||||
*/
|
||||
OpMult(std::vector<std::shared_ptr<Tensor>> tensors, std::shared_ptr<Algorithm> algorithm)
|
||||
: OpAlgoDispatch(algorithm)
|
||||
{
|
||||
KP_LOG_DEBUG("Kompute OpMult constructor with params");
|
||||
|
||||
if (tensors.size() != 3) {
|
||||
throw std::runtime_error("Kompute OpMult expected 3 tensors but got " + tensors.size());
|
||||
}
|
||||
|
||||
std::vector<uint32_t> spirv(
|
||||
(uint32_t*)shader_data::shaders_glsl_opmult_comp_spv,
|
||||
(uint32_t*)(shader_data::shaders_glsl_opmult_comp_spv +
|
||||
kp::shader_data::shaders_glsl_opmult_comp_spv_len));
|
||||
|
||||
algorithm->rebuild(tensors, spirv);
|
||||
}
|
||||
|
||||
/**
|
||||
* Default destructor, which is in charge of destroying the algorithm
|
||||
* components but does not destroy the underlying tensors
|
||||
*/
|
||||
virtual ~OpMult() override {
|
||||
KP_LOG_DEBUG("Kompute OpMult destructor started");
|
||||
}
|
||||
};
|
||||
|
||||
} // End namespace kp
|
||||
|
||||
namespace kp {
|
||||
|
||||
/**
|
||||
* Container of operations that can be sent to GPU as batch
|
||||
*/
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue