init
Some checks failed
main / Linux (amd64, ubuntu-22.04) (push) Successful in 49s
main / Linux (arm64, ubuntu-24.04-arm) (push) Has been cancelled
main / Windows (push) Has been cancelled

This commit is contained in:
2025-10-24 11:42:14 +02:00
commit 42172cbb6f
85 changed files with 40316 additions and 0 deletions

View File

@@ -0,0 +1,69 @@
#version 450
// Enables control-flow attributes (e.g. [[unroll]]); none is used in the visible code.
#extension GL_EXT_control_flow_attributes : enable
// Each workgroup consists of a single invocation (1 x 1 x 1).
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
// Specialization constants: the values written here are defaults that the
// host may override when the pipeline is created.
// N_BATCHES: together with K it sizes the infos[] array (K * N_BATCHES entries).
layout(constant_id = 0) const uint N_BATCHES = 32;
// N: not referenced anywhere in the visible part of this shader.
layout(constant_id = 1) const uint N = 16;
// K: number of top entries kept per batch; also sizes topVals[] / topIdx[].
layout(constant_id = 2) const uint K = 2;
// Per-batch addressing record read from the infos[] uniform array.
// Offsets are used directly as element indices into the float arrays x[] and y[].
struct BatchInfo {
// Index of the first input element of this batch inside x[].
uint inputOffset;
// Number of consecutive input elements scanned for this batch.
uint inputSizeX;
// Index inside y[] where a result value is written.
uint outputOffset;
// Not read by the visible code.
uint outputSizeX;
};
// Input values that are scanned for the K largest entries (read-only).
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
// Output values: one selected (optionally normalized) value per written slot (write-only).
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
// Addressing records. main() reads infos[batch] for the input range and
// infos[rank * N_BATCHES + batch].outputOffset for the destination of each rank.
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[K * N_BATCHES]; };
// Operation configuration supplied by the host.
layout(binding = 3) readonly uniform opConfigBuffer {
// Not read by main(); the loop counters named k inside main() shadow this member.
uint k;
// When equal to 1, the selected values are divided by their sum before being written.
uint normTopk;
// Not read by the visible code.
uint indexesBufferIndex;
};
// Receives the positions of the selected entries, stored as floats,
// K consecutive entries per batch.
layout(binding = 4) buffer indexesBuffer { float indexes[]; };
// Scratch tables holding the running top-K values (kept in descending order)
// and the positions, relative to inputOffset, at which they were found.
shared float topVals[K];
shared uint topIdx[K];
// Top-K selection for one batch entry.
//
// The batch is chosen by gl_WorkGroupID.y. The shader scans inputSizeX values
// of x[] starting at inputOffset and keeps the K largest in descending order.
// Only a strictly greater value displaces a stored one, so among equal values
// the one seen first ranks first. The winning positions go to indexes[] (as
// floats) and the winning values go to y[]; when normTopk == 1 the values are
// divided by their sum.
void main() {
    // TODO: this impl is not optimal
    const uint batch = gl_WorkGroupID.y;
    BatchInfo batchInfo = infos[batch];

    // Reset the running table: very low sentinel values, index 0.
    for (uint slot = 0; slot < K; ++slot) {
        topIdx[slot] = 0;
        topVals[slot] = -1e10f;
    }

    for (uint col = 0; col < batchInfo.inputSizeX; ++col) {
        const float candidate = x[batchInfo.inputOffset + col];

        // Find the first slot whose stored value the candidate strictly beats.
        // Written as !(a > b) rather than (a <= b) so that comparisons
        // involving NaN keep skipping slots, exactly like a failed "a > b".
        uint pos = 0;
        while (pos < K && !(candidate > topVals[pos])) {
            ++pos;
        }

        if (pos < K) {
            // Shift the lower-ranked entries down by one, dropping the last.
            for (uint dst = K - 1; dst > pos; --dst) {
                topVals[dst] = topVals[dst - 1];
                topIdx[dst] = topIdx[dst - 1];
            }
            topVals[pos] = candidate;
            topIdx[pos] = col;
        }
    }

    // Divisor applied to every output: 1 unless normalization is requested,
    // in which case it is the sum of the selected values.
    float denom = 1.0f;
    if (normTopk == 1) {
        denom = 0.0f;
        for (uint rank = 0; rank < K; ++rank) {
            denom += topVals[rank];
        }
    }

    for (uint rank = 0; rank < K; ++rank) {
        indexes[batch * K + rank] = float(topIdx[rank]);
        // Each rank has its own BatchInfo record that supplies the output position.
        const uint outPos = infos[rank * N_BATCHES + batch].outputOffset;
        y[outPos] = topVals[rank] / denom;
    }
}