Updated examples

This commit is contained in:
Alejandro Saucedo 2021-02-28 17:07:17 +00:00
parent 63e220a8a4
commit 4fddf74ca7
11 changed files with 408 additions and 405 deletions

View file

@ -647,12 +647,19 @@ extern py::object kp_debug, kp_info, kp_warning, kp_error;
#define KP_LOG_DEBUG(...)
#else
#if defined(VK_USE_PLATFORM_ANDROID_KHR)
#define KP_LOG_DEBUG(...) \
((void)__android_log_print(ANDROID_LOG_DEBUG, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__)))
#define KP_LOG_DEBUG(...) \
((void)__android_log_write( \
ANDROID_LOG_DEBUG, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str()))
#elif defined(KOMPUTE_BUILD_PYTHON)
#define KP_LOG_DEBUG(...) kp_debug(fmt::format(__VA_ARGS__))
#else
#define KP_LOG_DEBUG(...) fmt::print("[{} {}] [debug] [{}:{}] {}\n", __DATE__, __TIME__, __FILE__, __LINE__, fmt::format(__VA_ARGS__))
#define KP_LOG_DEBUG(...) \
fmt::print("[{} {}] [debug] [{}:{}] {}\n", \
__DATE__, \
__TIME__, \
__FILE__, \
__LINE__, \
fmt::format(__VA_ARGS__))
#endif // VK_USE_PLATFORM_ANDROID_KHR
#endif // SPDLOG_ACTIVE_LEVEL > 1
@ -660,12 +667,19 @@ extern py::object kp_debug, kp_info, kp_warning, kp_error;
#define KP_LOG_INFO(...)
#else
#if defined(VK_USE_PLATFORM_ANDROID_KHR)
#define KP_LOG_INFO(...) \
((void)__android_log_print(ANDROID_LOG_INFO, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__)))
#define KP_LOG_INFO(...) \
((void)__android_log_write( \
ANDROID_LOG_INFO, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str()))
#elif defined(KOMPUTE_BUILD_PYTHON)
#define KP_LOG_INFO(...) kp_info(fmt::format(__VA_ARGS__))
#else
#define KP_LOG_INFO(...) fmt::print("[{} {}] [debug] [{}:{}] {}\n", __DATE__, __TIME__, __FILE__, __LINE__, fmt::format(__VA_ARGS__))
// Plain-fmt fallback for info-level messages, used when neither the Android
// nor the Python logging branch above is selected. The line is tagged
// "[info]" so it can be told apart from debug output. Note that __DATE__ and
// __TIME__ expand to the compilation timestamp, not the time of the call.
#define KP_LOG_INFO(...)                                                       \
    fmt::print("[{} {}] [info] [{}:{}] {}\n",                                  \
               __DATE__,                                                       \
               __TIME__,                                                       \
               __FILE__,                                                       \
               __LINE__,                                                       \
               fmt::format(__VA_ARGS__))
#endif // VK_USE_PLATFORM_ANDROID_KHR
#endif // SPDLOG_ACTIVE_LEVEL > 2
@ -673,12 +687,19 @@ extern py::object kp_debug, kp_info, kp_warning, kp_error;
#define KP_LOG_WARN(...)
#else
#if defined(VK_USE_PLATFORM_ANDROID_KHR)
#define KP_LOG_WARN(...) \
((void)__android_log_print(ANDROID_LOG_WARN, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__)))
#define KP_LOG_WARN(...) \
((void)__android_log_write( \
ANDROID_LOG_WARN, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str()))
#elif defined(KOMPUTE_BUILD_PYTHON)
#define KP_LOG_WARN(...) kp_warning(fmt::format(__VA_ARGS__))
#else
#define KP_LOG_WARN(...) fmt::print("[{} {}] [debug] [{}:{}] {}\n", __DATE__, __TIME__, __FILE__, __LINE__, fmt::format(__VA_ARGS__))
// Plain-fmt fallback for warning-level messages, used when neither the
// Android nor the Python logging branch above is selected. The line is tagged
// "[warn]" so warnings are not mislabelled as debug output. Note that
// __DATE__ and __TIME__ expand to the compilation timestamp, not the time of
// the call.
#define KP_LOG_WARN(...)                                                       \
    fmt::print("[{} {}] [warn] [{}:{}] {}\n",                                  \
               __DATE__,                                                       \
               __TIME__,                                                       \
               __FILE__,                                                       \
               __LINE__,                                                       \
               fmt::format(__VA_ARGS__))
#endif // VK_USE_PLATFORM_ANDROID_KHR
#endif // SPDLOG_ACTIVE_LEVEL > 3
@ -686,12 +707,19 @@ extern py::object kp_debug, kp_info, kp_warning, kp_error;
#define KP_LOG_ERROR(...)
#else
#if defined(VK_USE_PLATFORM_ANDROID_KHR)
#define KP_LOG_ERROR(...) \
((void)__android_log_print(ANDROID_LOG_ERROR, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__)))
#define KP_LOG_ERROR(...) \
((void)__android_log_write( \
ANDROID_LOG_ERROR, KOMPUTE_LOG_TAG, fmt::format(__VA_ARGS__).c_str()))
#elif defined(KOMPUTE_BUILD_PYTHON)
#define KP_LOG_ERROR(...) kp_error(fmt::format(__VA_ARGS__))
#else
#define KP_LOG_ERROR(...) fmt::print("[{} {}] [debug] [{}:{}] {}\n", __DATE__, __TIME__, __FILE__, __LINE__, fmt::format(__VA_ARGS__))
// Plain-fmt fallback for error-level messages, used when neither the Android
// nor the Python logging branch above is selected. The line is tagged
// "[error]" so errors are not mislabelled as debug output. Note that
// __DATE__ and __TIME__ expand to the compilation timestamp, not the time of
// the call.
#define KP_LOG_ERROR(...)                                                      \
    fmt::print("[{} {}] [error] [{}:{}] {}\n",                                 \
               __DATE__,                                                       \
               __TIME__,                                                       \
               __FILE__,                                                       \
               __LINE__,                                                       \
               fmt::format(__VA_ARGS__))
#endif // VK_USE_PLATFORM_ANDROID_KHR
#endif // SPDLOG_ACTIVE_LEVEL > 4
#endif // KOMPUTE_SPDLOG_ENABLED
@ -701,9 +729,9 @@ extern py::object kp_debug, kp_info, kp_warning, kp_error;
#include <iostream>
#include <vector>
#include <SPIRV/GlslangToSpv.h>
#include <glslang/Include/ResourceLimits.h>
#include <glslang/Public/ShaderLang.h>
#include <SPIRV/GlslangToSpv.h>
namespace kp {
@ -711,157 +739,161 @@ namespace kp {
// Has been adopted from:
// https://github.com/KhronosGroup/glslang/blob/master/StandAlone/ResourceLimits.cpp
const TBuiltInResource defaultResource = {
/* .MaxLights = */ 0,
/* .MaxClipPlanes = */ 0,
/* .MaxTextureUnits = */ 0,
/* .MaxTextureCoords = */ 0,
/* .MaxVertexAttribs = */ 64,
/* .MaxVertexUniformComponents = */ 4096,
/* .MaxVaryingFloats = */ 64,
/* .MaxVertexTextureImageUnits = */ 0,
/* .MaxCombinedTextureImageUnits = */ 0,
/* .MaxTextureImageUnits = */ 0,
/* .MaxFragmentUniformComponents = */ 0,
/* .MaxDrawBuffers = */ 0,
/* .MaxVertexUniformVectors = */ 128,
/* .MaxVaryingVectors = */ 8,
/* .MaxFragmentUniformVectors = */ 0,
/* .MaxVertexOutputVectors = */ 16,
/* .MaxFragmentInputVectors = */ 0,
/* .MinProgramTexelOffset = */ -8,
/* .MaxProgramTexelOffset = */ 7,
/* .MaxClipDistances = */ 8,
/* .MaxComputeWorkGroupCountX = */ 65535,
/* .MaxComputeWorkGroupCountY = */ 65535,
/* .MaxComputeWorkGroupCountZ = */ 65535,
/* .MaxComputeWorkGroupSizeX = */ 1024,
/* .MaxComputeWorkGroupSizeY = */ 1024,
/* .MaxComputeWorkGroupSizeZ = */ 64,
/* .MaxComputeUniformComponents = */ 1024,
/* .MaxComputeTextureImageUnits = */ 16,
/* .MaxComputeImageUniforms = */ 8,
/* .MaxComputeAtomicCounters = */ 8,
/* .MaxComputeAtomicCounterBuffers = */ 1,
/* .MaxVaryingComponents = */ 60,
/* .MaxVertexOutputComponents = */ 64,
/* .MaxGeometryInputComponents = */ 64,
/* .MaxGeometryOutputComponents = */ 128,
/* .MaxFragmentInputComponents = */ 0,
/* .MaxImageUnits = */ 0,
/* .MaxCombinedImageUnitsAndFragmentOutputs = */ 0,
/* .MaxCombinedShaderOutputResources = */ 8,
/* .MaxImageSamples = */ 0,
/* .MaxVertexImageUniforms = */ 0,
/* .MaxTessControlImageUniforms = */ 0,
/* .MaxTessEvaluationImageUniforms = */ 0,
/* .MaxGeometryImageUniforms = */ 0,
/* .MaxFragmentImageUniforms = */ 0,
/* .MaxCombinedImageUniforms = */ 0,
/* .MaxGeometryTextureImageUnits = */ 0,
/* .MaxGeometryOutputVertices = */ 256,
/* .MaxGeometryTotalOutputComponents = */ 1024,
/* .MaxGeometryUniformComponents = */ 1024,
/* .MaxGeometryVaryingComponents = */ 64,
/* .MaxTessControlInputComponents = */ 128,
/* .MaxTessControlOutputComponents = */ 128,
/* .MaxTessControlTextureImageUnits = */ 0,
/* .MaxTessControlUniformComponents = */ 1024,
/* .MaxTessControlTotalOutputComponents = */ 4096,
/* .MaxTessEvaluationInputComponents = */ 128,
/* .MaxTessEvaluationOutputComponents = */ 128,
/* .MaxTessEvaluationTextureImageUnits = */ 16,
/* .MaxTessEvaluationUniformComponents = */ 1024,
/* .MaxTessPatchComponents = */ 120,
/* .MaxPatchVertices = */ 32,
/* .MaxTessGenLevel = */ 64,
/* .MaxViewports = */ 16,
/* .MaxVertexAtomicCounters = */ 0,
/* .MaxTessControlAtomicCounters = */ 0,
/* .MaxTessEvaluationAtomicCounters = */ 0,
/* .MaxGeometryAtomicCounters = */ 0,
/* .MaxFragmentAtomicCounters = */ 0,
/* .MaxCombinedAtomicCounters = */ 8,
/* .MaxAtomicCounterBindings = */ 1,
/* .MaxVertexAtomicCounterBuffers = */ 0,
/* .MaxTessControlAtomicCounterBuffers = */ 0,
/* .MaxTessEvaluationAtomicCounterBuffers = */ 0,
/* .MaxGeometryAtomicCounterBuffers = */ 0,
/* .MaxFragmentAtomicCounterBuffers = */ 0,
/* .MaxCombinedAtomicCounterBuffers = */ 1,
/* .MaxAtomicCounterBufferSize = */ 16384,
/* .MaxTransformFeedbackBuffers = */ 4,
/* .MaxTransformFeedbackInterleavedComponents = */ 64,
/* .MaxCullDistances = */ 8,
/* .MaxCombinedClipAndCullDistances = */ 8,
/* .MaxSamples = */ 4,
/* .maxMeshOutputVerticesNV = */ 256,
/* .maxMeshOutputPrimitivesNV = */ 512,
/* .maxMeshWorkGroupSizeX_NV = */ 32,
/* .maxMeshWorkGroupSizeY_NV = */ 1,
/* .maxMeshWorkGroupSizeZ_NV = */ 1,
/* .maxTaskWorkGroupSizeX_NV = */ 32,
/* .maxTaskWorkGroupSizeY_NV = */ 1,
/* .maxTaskWorkGroupSizeZ_NV = */ 1,
/* .maxMeshViewCountNV = */ 4,
/* .maxDualSourceDrawBuffersEXT = */ 1,
/* .MaxLights = */ 0,
/* .MaxClipPlanes = */ 0,
/* .MaxTextureUnits = */ 0,
/* .MaxTextureCoords = */ 0,
/* .MaxVertexAttribs = */ 64,
/* .MaxVertexUniformComponents = */ 4096,
/* .MaxVaryingFloats = */ 64,
/* .MaxVertexTextureImageUnits = */ 0,
/* .MaxCombinedTextureImageUnits = */ 0,
/* .MaxTextureImageUnits = */ 0,
/* .MaxFragmentUniformComponents = */ 0,
/* .MaxDrawBuffers = */ 0,
/* .MaxVertexUniformVectors = */ 128,
/* .MaxVaryingVectors = */ 8,
/* .MaxFragmentUniformVectors = */ 0,
/* .MaxVertexOutputVectors = */ 16,
/* .MaxFragmentInputVectors = */ 0,
/* .MinProgramTexelOffset = */ -8,
/* .MaxProgramTexelOffset = */ 7,
/* .MaxClipDistances = */ 8,
/* .MaxComputeWorkGroupCountX = */ 65535,
/* .MaxComputeWorkGroupCountY = */ 65535,
/* .MaxComputeWorkGroupCountZ = */ 65535,
/* .MaxComputeWorkGroupSizeX = */ 1024,
/* .MaxComputeWorkGroupSizeY = */ 1024,
/* .MaxComputeWorkGroupSizeZ = */ 64,
/* .MaxComputeUniformComponents = */ 1024,
/* .MaxComputeTextureImageUnits = */ 16,
/* .MaxComputeImageUniforms = */ 8,
/* .MaxComputeAtomicCounters = */ 8,
/* .MaxComputeAtomicCounterBuffers = */ 1,
/* .MaxVaryingComponents = */ 60,
/* .MaxVertexOutputComponents = */ 64,
/* .MaxGeometryInputComponents = */ 64,
/* .MaxGeometryOutputComponents = */ 128,
/* .MaxFragmentInputComponents = */ 0,
/* .MaxImageUnits = */ 0,
/* .MaxCombinedImageUnitsAndFragmentOutputs = */ 0,
/* .MaxCombinedShaderOutputResources = */ 8,
/* .MaxImageSamples = */ 0,
/* .MaxVertexImageUniforms = */ 0,
/* .MaxTessControlImageUniforms = */ 0,
/* .MaxTessEvaluationImageUniforms = */ 0,
/* .MaxGeometryImageUniforms = */ 0,
/* .MaxFragmentImageUniforms = */ 0,
/* .MaxCombinedImageUniforms = */ 0,
/* .MaxGeometryTextureImageUnits = */ 0,
/* .MaxGeometryOutputVertices = */ 256,
/* .MaxGeometryTotalOutputComponents = */ 1024,
/* .MaxGeometryUniformComponents = */ 1024,
/* .MaxGeometryVaryingComponents = */ 64,
/* .MaxTessControlInputComponents = */ 128,
/* .MaxTessControlOutputComponents = */ 128,
/* .MaxTessControlTextureImageUnits = */ 0,
/* .MaxTessControlUniformComponents = */ 1024,
/* .MaxTessControlTotalOutputComponents = */ 4096,
/* .MaxTessEvaluationInputComponents = */ 128,
/* .MaxTessEvaluationOutputComponents = */ 128,
/* .MaxTessEvaluationTextureImageUnits = */ 16,
/* .MaxTessEvaluationUniformComponents = */ 1024,
/* .MaxTessPatchComponents = */ 120,
/* .MaxPatchVertices = */ 32,
/* .MaxTessGenLevel = */ 64,
/* .MaxViewports = */ 16,
/* .MaxVertexAtomicCounters = */ 0,
/* .MaxTessControlAtomicCounters = */ 0,
/* .MaxTessEvaluationAtomicCounters = */ 0,
/* .MaxGeometryAtomicCounters = */ 0,
/* .MaxFragmentAtomicCounters = */ 0,
/* .MaxCombinedAtomicCounters = */ 8,
/* .MaxAtomicCounterBindings = */ 1,
/* .MaxVertexAtomicCounterBuffers = */ 0,
/* .MaxTessControlAtomicCounterBuffers = */ 0,
/* .MaxTessEvaluationAtomicCounterBuffers = */ 0,
/* .MaxGeometryAtomicCounterBuffers = */ 0,
/* .MaxFragmentAtomicCounterBuffers = */ 0,
/* .MaxCombinedAtomicCounterBuffers = */ 1,
/* .MaxAtomicCounterBufferSize = */ 16384,
/* .MaxTransformFeedbackBuffers = */ 4,
/* .MaxTransformFeedbackInterleavedComponents = */ 64,
/* .MaxCullDistances = */ 8,
/* .MaxCombinedClipAndCullDistances = */ 8,
/* .MaxSamples = */ 4,
/* .maxMeshOutputVerticesNV = */ 256,
/* .maxMeshOutputPrimitivesNV = */ 512,
/* .maxMeshWorkGroupSizeX_NV = */ 32,
/* .maxMeshWorkGroupSizeY_NV = */ 1,
/* .maxMeshWorkGroupSizeZ_NV = */ 1,
/* .maxTaskWorkGroupSizeX_NV = */ 32,
/* .maxTaskWorkGroupSizeY_NV = */ 1,
/* .maxTaskWorkGroupSizeZ_NV = */ 1,
/* .maxMeshViewCountNV = */ 4,
/* .maxDualSourceDrawBuffersEXT = */ 1,
/* .limits = */
{
/* .nonInductiveForLoops = */ 1,
/* .whileLoops = */ 1,
/* .doWhileLoops = */ 1,
/* .generalUniformIndexing = */ 1,
/* .generalAttributeMatrixVectorIndexing = */ 1,
/* .generalVaryingIndexing = */ 1,
/* .generalSamplerIndexing = */ 1,
/* .generalVariableIndexing = */ 1,
/* .generalConstantMatrixVectorIndexing = */ 1,
}
};
/* .limits = */ {
/* .nonInductiveForLoops = */ 1,
/* .whileLoops = */ 1,
/* .doWhileLoops = */ 1,
/* .generalUniformIndexing = */ 1,
/* .generalAttributeMatrixVectorIndexing = */ 1,
/* .generalVaryingIndexing = */ 1,
/* .generalSamplerIndexing = */ 1,
/* .generalVariableIndexing = */ 1,
/* .generalConstantMatrixVectorIndexing = */ 1,
}};
/**
Shader utility class with functions to compile and process glsl files.
*/
class Shader {
public:
class Shader
{
public:
/**
* Compile multiple sources with optional filenames. Currently this function
* uses the glslang C++ interface which is not thread safe so this function
* should not be called from multiple threads concurrently. If you have an
* online shader processing multithreading use-case that can't use offline
* online shader processing multithreading use-case that can't use offline
* compilation please open an issue.
*
* @param sources A list of raw glsl shaders in string format
* @param files A list of file names respective to each of the sources
* @param entryPoint The function name to use as entry point
* @param definitions List of pairs containing key value definitions
* @param resourcesLimit A list that contains the resource limits for the GLSL compiler
* @param resources A structure that contains the resource limits for the
* GLSL compiler
* @return The compiled SPIR-V binary in unsigned int32 format
*/
static std::vector<uint32_t> compile_sources(
const std::vector<std::string>& sources,
const std::vector<std::string>& files = {},
const std::string& entryPoint = "main",
std::vector<std::pair<std::string,std::string>> definitions = {},
const TBuiltInResource& resources = defaultResource);
const std::vector<std::string>& sources,
const std::vector<std::string>& files = {},
const std::string& entryPoint = "main",
std::vector<std::pair<std::string, std::string>> definitions = {},
const TBuiltInResource& resources = defaultResource);
/**
* Compile a single glslang source from string value. Currently this function
* uses the glslang C++ interface which is not thread safe so this funciton
* should not be called from multiple threads concurrently. If you have a
* online shader processing multithreading use-case that can't use offline
* compilation please open an issue.
* Compile a single glslang source from string value. Currently this
* function uses the glslang C++ interface which is not thread safe so this
* function should not be called from multiple threads concurrently. If you
* have an online shader processing multithreading use-case that can't use
* offline compilation please open an issue.
*
* @param source An individual raw glsl shader in string format
* @param entryPoint The function name to use as entry point
* @param definitions List of pairs containing key value definitions
* @param resourcesLimit A list that contains the resource limits for the GLSL compiler
* @param resources A structure that contains the resource limits for the
* GLSL compiler
* @return The compiled SPIR-V binary in unsigned int32 format
*/
static std::vector<uint32_t> compile_source(
const std::string& source,
const std::string& entryPoint = "main",
std::vector<std::pair<std::string,std::string>> definitions = {},
const TBuiltInResource& resources = defaultResource);
const std::string& source,
const std::string& entryPoint = "main",
std::vector<std::pair<std::string, std::string>> definitions = {},
const TBuiltInResource& resources = defaultResource);
};
}
@ -919,7 +951,7 @@ class Tensor
* otherwise there is no need to copy from host memory.
*/
void rebuild(const std::vector<float>& data,
TensorTypes tensorType = TensorTypes::eDevice);
TensorTypes tensorType = TensorTypes::eDevice);
/**
* Destroys and frees the GPU resources which include the buffer and memory.
@ -990,9 +1022,8 @@ class Tensor
* @param createBarrier Whether to create a barrier that ensures the data is
* copied before further operations. Default is true.
*/
void recordCopyFromStagingToDevice(
const vk::CommandBuffer& commandBuffer,
bool createBarrier);
void recordCopyFromStagingToDevice(const vk::CommandBuffer& commandBuffer,
bool createBarrier);
/**
* Records a copy from the internal device memory to the staging memory
@ -1003,9 +1034,8 @@ class Tensor
* @param createBarrier Whether to create a barrier that ensures the data is
* copied before further operations. Default is true.
*/
void recordCopyFromDeviceToStaging(
const vk::CommandBuffer& commandBuffer,
bool createBarrier);
void recordCopyFromDeviceToStaging(const vk::CommandBuffer& commandBuffer,
bool createBarrier);
/**
* Records the buffer memory barrier into the command buffer which
@ -1017,12 +1047,11 @@ class Tensor
* @param srcStageMask Pipeline stage flags for source stage mask
* @param dstStageMask Pipeline stage flags for destination stage mask
*/
void recordBufferMemoryBarrier(
const vk::CommandBuffer& commandBuffer,
vk::AccessFlagBits srcAccessMask,
vk::AccessFlagBits dstAccessMask,
vk::PipelineStageFlagBits srcStageMask,
vk::PipelineStageFlagBits dstStageMask);
void recordBufferMemoryBarrier(const vk::CommandBuffer& commandBuffer,
vk::AccessFlagBits srcAccessMask,
vk::AccessFlagBits dstAccessMask,
vk::PipelineStageFlagBits srcStageMask,
vk::PipelineStageFlagBits dstStageMask);
/**
* Constructs a vulkan descriptor buffer info which can be used to specify
@ -1070,11 +1099,11 @@ class Tensor
std::shared_ptr<vk::DeviceMemory> memory,
vk::MemoryPropertyFlags memoryPropertyFlags);
void recordCopyBuffer(const vk::CommandBuffer& commandBuffer,
std::shared_ptr<vk::Buffer> bufferFrom,
std::shared_ptr<vk::Buffer> bufferTo,
vk::DeviceSize bufferSize,
vk::BufferCopy copyRegion,
bool createBarrier);
std::shared_ptr<vk::Buffer> bufferFrom,
std::shared_ptr<vk::Buffer> bufferTo,
vk::DeviceSize bufferSize,
vk::BufferCopy copyRegion,
bool createBarrier);
// Private util functions
vk::BufferUsageFlags getPrimaryBufferUsageFlags();
@ -1094,8 +1123,7 @@ namespace kp {
*/
class Algorithm
{
public:
public:
/**
* Default constructor for Algorithm
*
@ -1103,12 +1131,11 @@ public:
* @param commandBuffer The vulkan command buffer to bind the pipeline and
* shaders
*/
Algorithm(
std::shared_ptr<vk::Device> device,
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
const std::vector<uint32_t>& spirv = {},
const Workgroup& workgroup = {},
const Constants& specializationConstants = {});
Algorithm(std::shared_ptr<vk::Device> device,
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
const std::vector<uint32_t>& spirv = {},
const Workgroup& workgroup = {},
const Constants& specializationConstants = {});
/**
* Initialiser for the shader data provided to the algorithm as well as
@ -1116,14 +1143,13 @@ public:
*
* @param shaderFileData The bytes in spir-v format of the shader
* @tensorParams The Tensors to be used in the Algorithm / shader for
* @specalizationInstalces The specialization parameters to pass to the function
* processing
* @param specializationConstants The specialization parameters to pass to the
* function processing
*/
void rebuild(
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
const std::vector<uint32_t>& spirv = {},
const Workgroup& workgroup = {},
const Constants& specializationConstants = {});
void rebuild(const std::vector<std::shared_ptr<Tensor>>& tensors = {},
const std::vector<uint32_t>& spirv = {},
const Workgroup& workgroup = {},
const Constants& specializationConstants = {});
/**
* Destructor for Algorithm which is responsible for freeing and destroying
@ -1143,7 +1169,8 @@ public:
void bindCore(const vk::CommandBuffer& commandBuffer);
void bindPush(const vk::CommandBuffer& commandBuffer, const Constants& pushConstants);
void bindPush(const vk::CommandBuffer& commandBuffer,
const Constants& pushConstants);
bool isInit();
@ -1155,7 +1182,7 @@ public:
void destroy();
private:
private:
// -------------- NEVER OWNED RESOURCES
std::shared_ptr<vk::Device> mDevice;
std::vector<std::shared_ptr<Tensor>> mTensors;
@ -1489,7 +1516,7 @@ namespace kp {
/**
* Container of operations that can be sent to GPU as batch
*/
class Sequence: public std::enable_shared_from_this<Sequence>
class Sequence : public std::enable_shared_from_this<Sequence>
{
public:
/**
@ -1526,8 +1553,9 @@ class Sequence: public std::enable_shared_from_this<Sequence>
* which allows for extensible configurations on initialisation.
*/
template<typename T, typename... TArgs>
std::shared_ptr<Sequence>
record(std::vector<std::shared_ptr<Tensor>> tensors, TArgs&&... params)
std::shared_ptr<Sequence> record(
std::vector<std::shared_ptr<Tensor>> tensors,
TArgs&&... params)
{
KP_LOG_DEBUG("Kompute Sequence record function started");
@ -1536,14 +1564,13 @@ class Sequence: public std::enable_shared_from_this<Sequence>
"OpBase derived classes");
KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance");
std::shared_ptr<T> op{
new T(tensors, std::forward<TArgs>(params)...) };
std::shared_ptr<T> op{ new T(tensors, std::forward<TArgs>(params)...) };
return this->record(op);
}
template<typename T, typename... TArgs>
std::shared_ptr<Sequence>
record(std::shared_ptr<Algorithm> algorithm, TArgs&&... params)
std::shared_ptr<Sequence> record(std::shared_ptr<Algorithm> algorithm,
TArgs&&... params)
{
KP_LOG_DEBUG("Kompute Sequence record function started");
@ -1552,8 +1579,8 @@ class Sequence: public std::enable_shared_from_this<Sequence>
"OpBase derived classes");
KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance");
std::shared_ptr<T> op{
new T(algorithm, std::forward<TArgs>(params)...) };
std::shared_ptr<T> op{ new T(algorithm,
std::forward<TArgs>(params)...) };
return this->record(op);
}
@ -1576,8 +1603,8 @@ class Sequence: public std::enable_shared_from_this<Sequence>
*/
// TODO: Aim to have only a single function with tensors/algorithm
template<typename T, typename... TArgs>
std::shared_ptr<Sequence>
eval(std::vector<std::shared_ptr<Tensor>> tensors, TArgs&&... params)
std::shared_ptr<Sequence> eval(std::vector<std::shared_ptr<Tensor>> tensors,
TArgs&&... params)
{
KP_LOG_DEBUG("Kompute Sequence record function started");
@ -1586,16 +1613,16 @@ class Sequence: public std::enable_shared_from_this<Sequence>
"OpBase derived classes");
KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance");
std::shared_ptr<T> op{
new T(tensors, std::forward<TArgs>(params)...) };
std::shared_ptr<T> op{ new T(tensors, std::forward<TArgs>(params)...) };
// TODO: Aim to be able to handle errors when returning without throw except
// TODO: Aim to be able to handle errors when returning without throw
// except
return this->eval(op);
}
// Needed as otherwise can't use initialiser list
template<typename T, typename... TArgs>
std::shared_ptr<Sequence>
eval(std::shared_ptr<Algorithm> algorithm, TArgs&&... params)
std::shared_ptr<Sequence> eval(std::shared_ptr<Algorithm> algorithm,
TArgs&&... params)
{
KP_LOG_DEBUG("Kompute Sequence record function started");
@ -1604,8 +1631,8 @@ class Sequence: public std::enable_shared_from_this<Sequence>
"OpBase derived classes");
KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance");
std::shared_ptr<T> op{
new T(algorithm, std::forward<TArgs>(params)...) };
std::shared_ptr<T> op{ new T(algorithm,
std::forward<TArgs>(params)...) };
return this->eval(op);
}
@ -1627,8 +1654,9 @@ class Sequence: public std::enable_shared_from_this<Sequence>
* @return shared_ptr<Sequence> of the Sequence class itself
*/
template<typename T, typename... TArgs>
std::shared_ptr<Sequence>
evalAsync(std::vector<std::shared_ptr<Tensor>> tensors, TArgs&&... params)
std::shared_ptr<Sequence> evalAsync(
std::vector<std::shared_ptr<Tensor>> tensors,
TArgs&&... params)
{
KP_LOG_DEBUG("Kompute Sequence record function started");
@ -1637,15 +1665,14 @@ class Sequence: public std::enable_shared_from_this<Sequence>
"OpBase derived classes");
KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance");
std::shared_ptr<T> op{
new T(tensors, std::forward<TArgs>(params)...) };
std::shared_ptr<T> op{ new T(tensors, std::forward<TArgs>(params)...) };
return this->evalAsync(op);
}
// Needed as otherwise it's not possible to use initializer lists
template<typename T, typename... TArgs>
std::shared_ptr<Sequence>
evalAsync(std::shared_ptr<Algorithm> algorithm, TArgs&&... params)
std::shared_ptr<Sequence> evalAsync(std::shared_ptr<Algorithm> algorithm,
TArgs&&... params)
{
KP_LOG_DEBUG("Kompute Sequence record function started");
@ -1654,8 +1681,8 @@ class Sequence: public std::enable_shared_from_this<Sequence>
"OpBase derived classes");
KP_LOG_DEBUG("Kompute Sequence creating OpBase derived class instance");
std::shared_ptr<T> op{
new T(algorithm, std::forward<TArgs>(params)...) };
std::shared_ptr<T> op{ new T(algorithm,
std::forward<TArgs>(params)...) };
return this->evalAsync(op);
}
@ -1670,7 +1697,8 @@ class Sequence: public std::enable_shared_from_this<Sequence>
std::shared_ptr<Sequence> evalAwait(uint64_t waitFor = UINT64_MAX);
/**
* Clear function clears all operations currently recorded and starts recording again.
* Clear function clears all operations currently recorded and starts
* recording again.
*/
void clear();
@ -1821,10 +1849,10 @@ class Manager
Tensor::TensorTypes tensorType = Tensor::TensorTypes::eDevice);
std::shared_ptr<Algorithm> algorithm(
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
const std::vector<uint32_t>& spirv = {},
const Workgroup& workgroup = {},
const Constants& specializationConstants = {});
const std::vector<std::shared_ptr<Tensor>>& tensors = {},
const std::vector<uint32_t>& spirv = {},
const Workgroup& workgroup = {},
const Constants& specializationConstants = {});
void destroy();
void clear();
@ -1856,7 +1884,8 @@ class Manager
// Create functions
void createInstance();
void createDevice(const std::vector<uint32_t>& familyQueueIndices = {}, uint32_t hysicalDeviceIndex = 0);
// Create-step declaration. Takes the queue family indices to use (empty by
// default) and the index of the physical device to select (0 by default).
// NOTE(review): the parameter was spelled "hysicalDeviceIndex"; names in a
// declaration are not part of the call interface, but confirm the
// out-of-line definition uses the corrected spelling too.
void createDevice(const std::vector<uint32_t>& familyQueueIndices = {},
                  uint32_t physicalDeviceIndex = 0);
};
} // End namespace kp