// Listing metadata: 70 lines, 1.8 KiB, plaintext.
#version 450

// Makes control-flow attributes such as [[unroll]] available; no attribute is
// applied anywhere in the visible code.
#extension GL_EXT_control_flow_attributes : enable

// A workgroup consists of a single invocation.
layout(
    local_size_x = 1,
    local_size_y = 1,
    local_size_z = 1
) in;
// Specialization constants. The literals are the defaults that apply when the
// pipeline does not supply its own values.
layout(constant_id = 0) const uint N_BATCHES = 32; // stride between per-rank groups inside infos[]
layout(constant_id = 1) const uint N = 16;         // not referenced by the visible code
layout(constant_id = 2) const uint K = 2;          // number of entries selected per batch
// Addressing record stored in the infos[] uniform array.
struct BatchInfo {
    uint inputOffset;  // index in x[] of the first element main() scans
    uint inputSizeX;   // how many consecutive elements main() scans
    uint outputOffset; // index in y[] that main() writes a result to
    uint outputSizeX;  // not read by the visible code
};
// binding 0: values scanned by main().
layout(binding = 0) readonly buffer inputBuffer {
    float x[];
};

// binding 1: receives the selected values (divided by their sum when normTopk == 1).
layout(binding = 1) writeonly buffer outputBuffer {
    float y[];
};

// binding 2: K groups of N_BATCHES records; main() addresses them as
// infos[rank * N_BATCHES + batch] and reads the input range from infos[batch].
layout(binding = 2) readonly uniform batchInfosBuffer {
    BatchInfo infos[K * N_BATCHES];
};

// binding 3: runtime configuration. Only normTopk is read by the visible code.
layout(binding = 3) readonly uniform opConfigBuffer {
    uint k;
    uint normTopk; // 1 => outputs are divided by the sum of the selected values
    uint indexesBufferIndex;
};

// binding 4: positions of the selected elements, stored as floats at
// indexes[batch * K + rank].
layout(binding = 4) buffer indexesBuffer {
    float indexes[];
};
// Scratch for the running selection, maintained by main() in descending value
// order: topVals[r] is the r-th largest value accepted so far and topIdx[r] is
// that value's position relative to the start of the scanned range.
shared uint topIdx[K];
shared float topVals[K];
// For the batch selected by gl_WorkGroupID.y, scans inputSizeX values starting
// at x[inputOffset] and keeps the K largest in descending order. Each winner's
// position (relative to inputOffset) is written as a float to
// indexes[batch * K + rank], and its value is written to
// y[infos[rank * N_BATCHES + batch].outputOffset]. When normTopk == 1 every
// written value is first divided by the sum of the K kept values.
//
// Slots start at a -1e10 sentinel with index 0, so inputs that are not strictly
// greater than -1e10 are never selected, and slots left unfilled keep the
// sentinel value.
void main() {
    // TODO: this impl is not optimal
    const uint batch = gl_WorkGroupID.y;
    const uint inputBase = infos[batch].inputOffset;
    const uint inputCount = infos[batch].inputSizeX;

    // Reset the scratch arrays to the sentinel state.
    for (uint slot = 0; slot < K; ++slot) {
        topIdx[slot] = 0;
        topVals[slot] = -1e10f;
    }

    for (uint pos = 0; pos < inputCount; ++pos) {
        const float candidate = x[inputBase + pos];

        // Locate the first slot the candidate strictly exceeds. Ties keep the
        // earlier element, and the negated '>' (rather than '<=') means a
        // candidate that fails every comparison, such as NaN, is skipped.
        uint slot = 0;
        while (slot < K && !(candidate > topVals[slot])) {
            ++slot;
        }
        if (slot == K) {
            continue; // not larger than any kept value
        }

        // Make room: move lower-ranked entries down one place, dropping the last.
        for (uint dst = K - 1; dst > slot; --dst) {
            topVals[dst] = topVals[dst - 1];
            topIdx[dst] = topIdx[dst - 1];
        }
        topVals[slot] = candidate;
        topIdx[slot] = pos;
    }

    // Divisor is 1 unless normalization is requested, in which case it is the
    // sum of the kept values accumulated from rank 0 upward.
    const bool normalize = (normTopk == 1);
    float denom = normalize ? 0.0f : 1.0f;
    if (normalize) {
        for (uint rank = 0; rank < K; ++rank) {
            denom += topVals[rank];
        }
    }

    for (uint rank = 0; rank < K; ++rank) {
        indexes[batch * K + rank] = float(topIdx[rank]);
        y[infos[rank * N_BATCHES + batch].outputOffset] = topVals[rank] / denom;
    }
}