init
Some checks failed
main / Linux (amd64, ubuntu-22.04) (push) Successful in 49s
main / Linux (arm64, ubuntu-24.04-arm) (push) Has been cancelled
main / Windows (push) Has been cancelled

This commit is contained in:
2025-10-24 11:42:14 +02:00
commit 42172cbb6f
85 changed files with 40316 additions and 0 deletions

View File

@@ -0,0 +1,56 @@
#version 450
// Compute shader: for every batch row, finds the row maximum and writes
// exp(x - max) for each element of the row. No division by the sum is done
// here (NOTE(review): looks like the first pass of a softmax — confirm with
// the host code that a normalization pass follows).

// Required for the [[unroll]] attribute used on the reduction loop in main().
#extension GL_EXT_control_flow_attributes : enable
// Invocations per workgroup along x; also the length of the shared scratch
// array used for the max reduction. The halving reduction in main() relies on
// this being a power of two.
#define N_THREADS 256
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
// Specialization constants (defaults 32 and 1). Together they size the
// BatchInfo array (N_Z * N_BATCHES entries); N_BATCHES is also the stride used
// to flatten (z, batch) workgroup coordinates into an index into that array.
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
// Describes one row of work for a single workgroup. Offsets are used directly
// as indices into the float arrays x[] and y[], i.e. they count floats, not
// bytes.
struct BatchInfo {
// Index of the row's first element in the input array x[].
uint inputOffset;
// Number of elements in the row; bounds both loops in main().
uint inputSizeX;
// Index of the row's first element in the output array y[].
uint outputOffset;
// Not read by this shader (NOTE(review): presumably kept so the struct layout
// matches what the host fills in for other shaders — confirm).
uint outputSizeX;
};
// Binding 0: input values, read only. Rows are located via BatchInfo.inputOffset.
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
// Binding 1: output values, write only. Rows are located via BatchInfo.outputOffset.
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
// Binding 2: one BatchInfo per (z, batch) pair, indexed in main() as
// zIndex * N_BATCHES + batchIndex.
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
// Workgroup-shared scratch: one slot per invocation, holding partial maxima
// during the reduction; temp[0] ends up holding the row maximum.
shared float temp[N_THREADS];
// Entry point. Each workgroup processes the single BatchInfo entry selected by
// gl_WorkGroupID.y (batch) and gl_WorkGroupID.z (z slice):
//   1. every invocation scans elements threadIndex, threadIndex + N_THREADS, ...
//      of the row and keeps its own running maximum,
//   2. the per-invocation maxima are reduced to the row maximum in temp[0],
//   3. y[yOffset + i] = exp(x[xOffset + i] - rowMax) is written for each element.
// gl_WorkGroupID.x is not used, so additional workgroups along x would only
// repeat the same work. The reduction requires N_THREADS to be a power of two.
void main() {
    const uint threadIndex = gl_LocalInvocationID.x;
    const uint batchIndex = gl_WorkGroupID.y;
    const uint zIndex = gl_WorkGroupID.z;

    // Flatten (z, batch) into the index of this workgroup's BatchInfo.
    const uint b = zIndex * N_BATCHES + batchIndex;
    const BatchInfo info = infos[b];
    const uint size = info.inputSizeX;
    const uint xOffset = info.inputOffset;
    const uint yOffset = info.outputOffset;

    // Start from the lowest finite float so that any finite input replaces it.
    // The previous sentinel (-1e10) was larger than legitimately small inputs:
    // a row whose values were all below -1e10 got maxVal = -1e10 and every
    // output underflowed to 0. A finite sentinel (rather than -infinity) keeps
    // exp(-inf - maxVal) == 0 for rows consisting solely of -infinity.
    float m = -3.402823466e+38f;
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        m = max(m, x[xOffset + i]);
    }

    // Publish the per-invocation maximum, then wait until every slot is written.
    temp[threadIndex] = m;
    barrier();

    // Tree reduction: each step folds the upper half of the active range into
    // the lower half. barrier() sits outside the `if` so that every invocation
    // reaches it (it must be called in uniform control flow).
    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
        if (threadIndex < i)
            temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
        barrier();
    }

    // The final loop iteration already ended with a barrier, so temp[0] is
    // visible to all invocations here; no additional barrier is required.
    const float maxVal = temp[0];

    // Subtracting the row maximum keeps every exponent <= 0, so exp() cannot
    // overflow.
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        y[yOffset + i] = exp(x[xOffset + i] - maxVal);
    }
}