From b68446beeb73cc5aac4e6fe2f6483bcbe3112a06 Mon Sep 17 00:00:00 2001 From: Alejandro Saucedo Date: Sun, 8 Nov 2020 16:04:05 +0000 Subject: [PATCH] Updated readme for python example --- README.md | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 43b3b8511..4facb1137 100644 --- a/README.md +++ b/README.md @@ -306,8 +306,18 @@ tensor_out = Tensor([0, 0, 0]) mgr.eval_tensor_create_def([tensor_in_a, tensor_in_b, tensor_out]) -shaderFilePath = "shaders/glsl/opmult.comp" -mgr.eval_async_algo_file_def([tensor_in_a, tensor_in_b, tensor_out], shaderFilePath) +# Define the function via PyShader or directly as glsl string or spirv bytes +@python2shader +def compute_shader_multiply(index=("input", "GlobalInvocationId", ivec3), + data1=("buffer", 0, Array(f32)), + data2=("buffer", 1, Array(f32)), + data3=("buffer", 2, Array(f32))): + i = index.x + data3[i] = data1[i] * data2[i] + +# Run shader operation synchronously +mgr.eval_algo_data_def( + [tensor_in_a, tensor_in_b, tensor_out], compute_shader_multiply.to_spirv()) # Alternatively can pass raw string/bytes: # shaderFileData = """ shader code here... """ @@ -332,13 +342,22 @@ tensor_in_a = Tensor([2, 2, 2]) tensor_in_b = Tensor([1, 2, 3]) tensor_out = Tensor([0, 0, 0]) -shaderFilePath = "../../shaders/glsl/opmult.comp" - mgr.eval_tensor_create_def([tensor_in_a, tensor_in_b, tensor_out]) seq = mgr.create_sequence("op") -mgr.eval_async_algo_file_def([tensor_in_a, tensor_in_b, tensor_out], shaderFilePath) +# Define the function via PyShader or directly as glsl string or spirv bytes +@python2shader +def compute_shader_multiply(index=("input", "GlobalInvocationId", ivec3), + data1=("buffer", 0, Array(f32)), + data2=("buffer", 1, Array(f32)), + data3=("buffer", 2, Array(f32))): + i = index.x + data3[i] = data1[i] * data2[i] + +# Run shader operation asynchronously and then await +mgr.eval_async_algo_data_def( + [tensor_in_a, tensor_in_b, tensor_out], compute_shader_multiply.to_spirv()) mgr.eval_await_def() seq.begin()