llama-cpp-turboquant/src/include/kompute/operations/OpMult.hpp
2021-02-27 14:49:12 +00:00

85 lines
2.4 KiB
C++

#pragma once
#include <fstream>
#include "kompute/Core.hpp"
#if RELEASE
#include "kompute/shaders/shaderopmult.hpp"
#endif
#include "kompute/Algorithm.hpp"
#include "kompute/Tensor.hpp"
#include "kompute/operations/OpAlgoCreate.hpp"
namespace kp {
/**
 * Operation that performs multiplication on two tensors and outputs on a third
 * tensor.
 *
 * The shader source depends on the RELEASE macro: when it evaluates to a
 * non-zero value the SPIR-V embedded in the shader header is used, otherwise
 * the shader file path is set to "shaders/glsl/opmult.comp.spv".
 */
class OpMult : public OpAlgoCreate
{
  public:
    /**
     * Base constructor, should not be used unless explicitly intended.
     */
    OpMult() {
    }

    /**
     * Default constructor with parameters that provides the bare minimum
     * requirements for the operations to be able to create and manage their
     * sub-components.
     *
     * All arguments are forwarded to the OpAlgoCreate constructor together
     * with an empty string (presumably the shader file path -- confirm
     * against OpAlgoCreate).
     *
     * @param physicalDevice Vulkan physical device used to find device queues
     * @param device Vulkan logical device for passing to Algorithm
     * @param commandBuffer Vulkan Command Buffer to record commands into
     * @param tensors Tensors that are to be used in this operation
     * @param komputeWorkgroup Optional parameter to specify the layout for processing
     */
    OpMult(std::shared_ptr<vk::PhysicalDevice> physicalDevice,
           std::shared_ptr<vk::Device> device,
           std::shared_ptr<vk::CommandBuffer> commandBuffer,
           std::vector<std::shared_ptr<Tensor>> tensors,
           const Workgroup& komputeWorkgroup = {})
      : OpAlgoCreate(physicalDevice, device, commandBuffer, tensors, "", komputeWorkgroup)
    {
        KP_LOG_DEBUG("Kompute OpMult constructor with params");

        // Test the VALUE of RELEASE (not merely whether it is defined) so this
        // branch is the exact complement of the `#if RELEASE` blocks in this
        // header. With `#ifndef`, building with RELEASE=0 would select neither
        // the embedded shader nor the shader file path.
#if !RELEASE
        this->mShaderFilePath = "shaders/glsl/opmult.comp.spv";
#endif
    }

#if RELEASE
    /**
     * If RELEASE=1 it will be using the static version of the shader which is
     * loaded using this file directly. Otherwise it should not override the function.
     *
     * @return The embedded SPIR-V data copied into a vector of 32-bit words
     */
    std::vector<uint32_t> fetchSpirvBinaryData() override
    {
        KP_LOG_WARN(
          "Kompute OpMult Running shaders directly from header");

        // The length is added to the array pointer BEFORE the cast, so it is
        // applied in units of the array's own element type rather than in
        // 32-bit words.
        return std::vector<uint32_t>(
          reinterpret_cast<const uint32_t*>(
            shader_data::shaders_glsl_opmult_comp_spv),
          reinterpret_cast<const uint32_t*>(
            shader_data::shaders_glsl_opmult_comp_spv +
            kp::shader_data::shaders_glsl_opmult_comp_spv_len));
    }
#endif

    /**
     * Destructor. The body here only emits a debug log; destruction of the
     * algorithm components is presumably handled by the base class (not
     * visible in this header -- confirm). Nothing here destroys the
     * underlying tensors.
     */
    ~OpMult() override {
        KP_LOG_DEBUG("Kompute OpMult destructor started");
    }
};
} // End namespace kp