init
This commit is contained in:
37
src/nn/vulkan/cast-forward-f32-f32.comp
Normal file
37
src/nn/vulkan/cast-forward-f32-f32.comp
Normal file
@@ -0,0 +1,37 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
#define CHUNK_SIZE 4
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||
layout(constant_id = 1) const uint N_Z = 1;
|
||||
|
||||
struct BatchInfo {
|
||||
uint inputOffset;
|
||||
uint inputSizeX;
|
||||
uint outputOffset;
|
||||
uint outputSizeX;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||
|
||||
void main() {
|
||||
const uint batchIndex = gl_WorkGroupID.y;
|
||||
const uint zIndex = gl_WorkGroupID.z;
|
||||
const uint b = zIndex * N_BATCHES + batchIndex;
|
||||
|
||||
const uint chunkIndex = gl_WorkGroupID.x;
|
||||
const BatchInfo info = infos[b];
|
||||
const uint offset = chunkIndex * CHUNK_SIZE;
|
||||
const uint xOffset = info.inputOffset + offset;
|
||||
const uint yOffset = info.outputOffset + offset;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
|
||||
y[yOffset + i] = x[xOffset + i];
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user