#version 450

#extension GL_EXT_control_flow_attributes : enable

// Top-K selection compute shader.
//
// Each workgroup (a single invocation, local size 1x1x1) handles the batch
// selected by gl_WorkGroupID.y. It scans that batch's input row, keeps the K
// largest values together with their positions in the row, and then writes:
//   * the positions, converted to float, to indexes[batch * K + rank];
//   * the values to y[] at the offset taken from infos[rank * N_BATCHES + batch],
//     divided by the sum of the K selected values when normTopk == 1.
// Results are ordered from largest to smallest; on equal values the element
// that appears earlier in the row keeps the better rank (strict '>' compare).

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N = 16; // not referenced by this shader's code
layout(constant_id = 2) const uint K = 2;  // number of entries kept per batch

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
// Entry [batch] supplies the input row location; entry [rank * N_BATCHES + batch]
// supplies the output offset for the value of the given rank.
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[K * N_BATCHES]; };
// Only normTopk is read below; k and indexesBufferIndex are unused in this shader.
layout(binding = 3) readonly uniform opConfigBuffer {
    uint k;
    uint normTopk;
    uint indexesBufferIndex;
};
layout(binding = 4) buffer indexesBuffer { float indexes[]; };

// Running top-K list, kept sorted in descending order of value.
shared float topVals[K];
shared uint topIdx[K];

void main() {
    // TODO: this impl is not optimal
    const uint batch = gl_WorkGroupID.y;
    const uint rowStart = infos[batch].inputOffset;
    const uint rowLength = infos[batch].inputSizeX;

    // Seed every slot with a very low sentinel value and index 0. Inputs that
    // are not greater than the sentinel are never selected, so a slot keeps
    // these seed values if too few inputs qualify.
    for (uint slot = 0; slot < K; slot++) {
        topVals[slot] = -1e10f;
        topIdx[slot] = 0;
    }

    for (uint col = 0; col < rowLength; col++) {
        const float candidate = x[rowStart + col];

        // Locate the first slot whose value the candidate strictly exceeds.
        // Written as a negated '>' so that a NaN comparison never matches.
        uint slot = 0;
        while (slot < K && !(candidate > topVals[slot])) {
            slot++;
        }
        if (slot == K) {
            continue; // candidate does not beat any kept value
        }

        // Shift the lower-ranked entries down by one (the last one drops out),
        // then place the candidate in the freed slot.
        for (uint dst = K - 1; dst > slot; dst--) {
            topVals[dst] = topVals[dst - 1];
            topIdx[dst] = topIdx[dst - 1];
        }
        topVals[slot] = candidate;
        topIdx[slot] = col;
    }

    // Divisor is 1 unless normalisation is requested, in which case it is the
    // sum of the selected values.
    float divisor = 1.0f;
    if (normTopk == 1) {
        divisor = 0.0f;
        for (uint rank = 0; rank < K; rank++) {
            divisor += topVals[rank];
        }
    }

    for (uint rank = 0; rank < K; rank++) {
        indexes[batch * K + rank] = float(topIdx[rank]);
        const uint dstOffset = infos[rank * N_BATCHES + batch].outputOffset;
        y[dstOffset] = topVals[rank] / divisor;
    }
}