init
This commit is contained in:
56
src/nn/vulkan/softmax-forward-f32-f32.comp
Normal file
56
src/nn/vulkan/softmax-forward-f32-f32.comp
Normal file
@@ -0,0 +1,56 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
// Number of invocations in one workgroup. The shared scratch array used by
// the reduction in main() has exactly this many slots.
#define N_THREADS 256

// One-dimensional workgroup: N_THREADS invocations along x.
layout(
    local_size_x = N_THREADS,
    local_size_y = 1,
    local_size_z = 1
) in;

// Specialization constants; the literals are the defaults used when the
// pipeline does not override them.
// N_BATCHES is the stride between consecutive z slices when indexing the
// batch-info table, and N_Z * N_BATCHES is that table's length.
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
|
||||
|
||||
// Describes the slice of the input/output buffers handled by one workgroup.
// Offsets and sizes are expressed in float elements, since they are used
// directly as indices into the float arrays x[] and y[].
struct BatchInfo {
    uint inputOffset;   // index in x[] of the slice's first element
    uint inputSizeX;    // number of elements main() processes for the slice
    uint outputOffset;  // index in y[] where the slice's results begin
    uint outputSizeX;   // not read by this shader
};
|
||||
|
||||
// Binding 0: source values; this shader only reads them.
layout(binding = 0) readonly buffer inputBuffer {
    float x[];
};

// Binding 1: destination values; this shader only writes them.
layout(binding = 1) writeonly buffer outputBuffer {
    float y[];
};

// Binding 2: one BatchInfo per (z slice, batch) pair, indexed in main() as
// zIndex * N_BATCHES + batchIndex.
layout(binding = 2) readonly uniform batchInfosBuffer {
    BatchInfo infos[N_Z * N_BATCHES];
};

// Workgroup-shared scratch space for the max reduction, one slot per
// invocation.
shared float temp[N_THREADS];
|
||||
|
||||
// Writes exp(x[i] - sliceMax) for every element of one batch slice.
//
// Each workgroup handles the BatchInfo entry selected by gl_WorkGroupID.y
// (batch) and gl_WorkGroupID.z (z slice). Its N_THREADS invocations first
// agree on the slice maximum through a reduction in the shared array `temp`,
// then every invocation exponentiates the elements it is responsible for.
// Subtracting the maximum keeps every exp() argument <= 0, so the result
// cannot overflow.
//
// NOTE(review): nothing in this shader divides by the sum of the
// exponentials; presumably normalization is done elsewhere -- confirm
// against the code that dispatches this shader.
void main() {
    const uint threadIndex = gl_LocalInvocationID.x;
    const uint batchIndex = gl_WorkGroupID.y;
    const uint zIndex = gl_WorkGroupID.z;
    const uint b = zIndex * N_BATCHES + batchIndex;

    const BatchInfo info = infos[b];
    const uint size = info.inputSizeX;
    const uint xOffset = info.inputOffset;
    const uint yOffset = info.outputOffset;

    // Seed the running maximum with the lowest finite float rather than an
    // arbitrary bound such as -1e10: a slice whose values are all below the
    // bound would otherwise report the bound as its maximum and every output
    // would underflow to 0. A finite seed (instead of -infinity) also keeps
    // a slice of all -infinity producing 0 rather than NaN.
    const float LOWEST_FLOAT = -3.402823466e+38f;

    // Each invocation scans elements threadIndex, threadIndex + N_THREADS, ...
    float m = LOWEST_FLOAT;
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        m = max(m, x[xOffset + i]);
    }

    temp[threadIndex] = m;
    barrier();

    // Pairwise reduction: the active range halves each step until temp[0]
    // holds the maximum over the whole workgroup. barrier() sits outside the
    // `if` so that every invocation reaches it on every step.
    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
        if (threadIndex < i) {
            temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
        }
        barrier();
    }

    // The barrier() that ends the final loop step already orders the last
    // write to temp[0] before this read, so no additional barrier is needed.
    const float maxVal = temp[0];

    for (uint i = threadIndex; i < size; i += N_THREADS) {
        y[yOffset + i] = exp(x[xOffset + i] - maxVal);
    }
}
|
||||
Reference in New Issue
Block a user