57 lines
1.5 KiB
Plaintext
57 lines
1.5 KiB
Plaintext
#version 450

// Enables loop attributes such as [[unroll]], which main() applies to its
// shared-memory reduction loop.
#extension GL_EXT_control_flow_attributes : enable

// Invocations per workgroup; also the length of the shared scratch array.
#define N_THREADS 256

// One-dimensional workgroup: N_THREADS invocations along x.
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;

// Specialization constants; the values below are the defaults used when the
// pipeline does not override them. Together they size the infos[] array and
// N_BATCHES is the stride used to flatten (z, batch) into one index.
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
|
|
|
|
// Describes the row processed by one workgroup. Offsets and sizes are counted
// in float elements, because main() uses them directly as indices into x[] and
// y[].
// NOTE(review): the field order presumably mirrors a host-side struct that
// fills the uniform buffer - keep both in sync.
struct BatchInfo {
    uint inputOffset;  // index in x[] of the row's first element
    uint inputSizeX;   // number of elements read from x[] and written to y[]
    uint outputOffset; // index in y[] of the row's first output element
    uint outputSizeX;  // not read by main() in this shader
};
|
|
|
|
// Source values; a row starts at BatchInfo.inputOffset.
layout(binding = 0) readonly buffer inputBuffer {
    float x[];
};

// Destination values; a row starts at BatchInfo.outputOffset.
layout(binding = 1) writeonly buffer outputBuffer {
    float y[];
};

// Row descriptors, one per (z, batch) pair, indexed as z * N_BATCHES + batch.
layout(binding = 2) readonly uniform batchInfosBuffer {
    BatchInfo infos[N_Z * N_BATCHES];
};

// Workgroup-shared scratch: one slot per invocation for the max reduction.
shared float temp[N_THREADS];
|
|
|
|
// Max-shifted exponentiation of one row per workgroup.
//
// The workgroup identified by (gl_WorkGroupID.y, gl_WorkGroupID.z) loads its
// BatchInfo, finds the maximum of x[inputOffset .. inputOffset + inputSizeX),
// and then writes y[outputOffset + i] = exp(x[inputOffset + i] - max) for every
// i in that range. No sum or division is performed here, so the output is not
// normalised by this shader.
//
// The tree reduction below halves the active range each step, so it only
// covers every slot of temp[] when N_THREADS is a power of two.
void main() {
    const uint threadIndex = gl_LocalInvocationID.x;
    const uint batchIndex = gl_WorkGroupID.y;
    const uint zIndex = gl_WorkGroupID.z;
    const uint b = zIndex * N_BATCHES + batchIndex;

    const BatchInfo info = infos[b];
    const uint size = info.inputSizeX;
    const uint xOffset = info.inputOffset;
    const uint yOffset = info.outputOffset;

    // Seed with the most negative finite float so that any finite input wins
    // the max, including rows whose values are all below -1e10. A true -inf
    // seed is avoided on purpose: for a row made only of -inf it would turn
    // (x - maxVal) into NaN instead of letting exp() return 0.
    float m = -3.402823466e+38f;

    // Each invocation scans elements threadIndex, threadIndex + N_THREADS, ...
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        m = max(m, x[xOffset + i]);
    }

    temp[threadIndex] = m;
    barrier();

    // Pairwise max reduction in shared memory; the result ends up in temp[0].
    // barrier() stays outside the `if` so every invocation reaches it.
    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
        if (threadIndex < i)
            temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
        barrier();
    }

    // The final loop iteration already ended with a barrier and temp[] is not
    // written again, so temp[0] can be read without further synchronisation.
    const float maxVal = temp[0];

    // Subtracting the row maximum keeps every exponent <= 0.
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        y[yOffset + i] = exp(x[xOffset + i] - maxVal);
    }
}
|