#version 450
#extension GL_EXT_control_flow_attributes : enable

// Compute shader: for each row of x[] described by a BatchInfo entry, find the
// row maximum and write y[i] = exp(x[i] - max) for every element of the row.
// The result is NOT divided by the row sum here; presumably a later pass does
// that normalization — TODO confirm against the host code.

// Workgroup size (invocations along x). The shared-memory tree reduction in
// main() halves the active range each step, so it only folds every slot of
// temp[] into temp[0] when this is a power of two.
#define N_THREADS 256

#if (N_THREADS <= 0) || ((N_THREADS & (N_THREADS - 1)) != 0)
#error "N_THREADS must be a positive power of two for the max reduction"
#endif

layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;

// Specialization constants sizing the BatchInfo table. The workgroup with
// gl_WorkGroupID = (?, y, z) uses entry z * N_BATCHES + y.
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;

// Location of one row in the input and output buffers.
struct BatchInfo {
    uint inputOffset;   // index of the row's first element in x[]
    uint inputSizeX;    // number of elements read from x[] (and written to y[])
    uint outputOffset;  // index of the row's first element in y[]
    uint outputSizeX;   // NOTE(review): not read by this shader; the number of
                        // outputs written is inputSizeX — confirm it is never smaller
};

layout(binding = 0) readonly buffer inputBuffer {
    float x[];
};

layout(binding = 1) writeonly buffer outputBuffer {
    float y[];
};

// Uniform blocks are inherently read-only; GLSL memory qualifiers such as
// `readonly` apply to buffer/image variables, so none is used here.
layout(binding = 2) uniform batchInfosBuffer {
    BatchInfo infos[N_Z * N_BATCHES];
};

// One partial maximum per invocation, reduced in place to temp[0].
shared float temp[N_THREADS];

void main()
{
    const uint threadIndex = gl_LocalInvocationID.x;
    const uint batchIndex = gl_WorkGroupID.y;
    const uint zIndex = gl_WorkGroupID.z;

    // NOTE(review): no bounds check; assumes the dispatch uses at most
    // N_BATCHES groups along y and N_Z groups along z.
    const BatchInfo info = infos[zIndex * N_BATCHES + batchIndex];

    const uint size = info.inputSizeX;
    const uint xOffset = info.inputOffset;
    const uint yOffset = info.outputOffset;

    // Phase 1: each invocation takes the max over indices
    // threadIndex, threadIndex + N_THREADS, ...
    // Start from -FLT_MAX instead of a finite guess like -1e10, so rows whose
    // values are all below that guess still get their true maximum. -FLT_MAX
    // is used rather than -inf so that an all -inf row still yields exp(-inf) = 0
    // instead of NaN.
    float m = uintBitsToFloat(0xFF7FFFFFu); // -FLT_MAX
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        m = max(m, x[xOffset + i]);
    }
    temp[threadIndex] = m;
    barrier();

    // Phase 2: tree reduction over temp[]. barrier() stays outside the `if` so
    // every invocation reaches it (uniform control flow).
    [[unroll]] for (uint stride = N_THREADS / 2; stride > 0; stride >>= 1) {
        if (threadIndex < stride) {
            temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + stride]);
        }
        barrier();
    }

    // The last reduction step already ended with barrier(), so temp[0] is
    // visible to all invocations without another barrier.
    const float maxVal = temp[0];

    // Phase 3: write the shifted exponentials using the same strided pattern.
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        y[yOffset + i] = exp(x[xOffset + i] - maxVal);
    }
}