init
This commit is contained in:
69
src/nn/vulkan/moe-gate-forward-f32-f32.comp
Normal file
69
src/nn/vulkan/moe-gate-forward-f32-f32.comp
Normal file
@@ -0,0 +1,69 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||
layout(constant_id = 1) const uint N = 16;
|
||||
layout(constant_id = 2) const uint K = 2;
|
||||
|
||||
struct BatchInfo {
|
||||
uint inputOffset;
|
||||
uint inputSizeX;
|
||||
uint outputOffset;
|
||||
uint outputSizeX;
|
||||
};
|
||||
|
||||
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[K * N_BATCHES]; };
|
||||
layout(binding = 3) readonly uniform opConfigBuffer {
|
||||
uint k;
|
||||
uint normTopk;
|
||||
uint indexesBufferIndex;
|
||||
};
|
||||
layout(binding = 4) buffer indexesBuffer { float indexes[]; };
|
||||
|
||||
shared float topVals[K];
|
||||
shared uint topIdx[K];
|
||||
|
||||
void main() {
|
||||
// TODO: this impl is not optimal
|
||||
const uint batchIndex = gl_WorkGroupID.y;
|
||||
BatchInfo info = infos[batchIndex];
|
||||
|
||||
for (uint i = 0; i < K; i++) {
|
||||
topVals[i] = -1e10f;
|
||||
topIdx[i] = 0;
|
||||
}
|
||||
|
||||
for (uint i = 0; i < info.inputSizeX; i++) {
|
||||
float v = x[info.inputOffset + i];
|
||||
|
||||
for (uint k = 0; k < K; k++) {
|
||||
if (v > topVals[k]) {
|
||||
for (uint s = K - 1; s > k; s--) {
|
||||
topVals[s] = topVals[s - 1];
|
||||
topIdx[s] = topIdx[s - 1];
|
||||
}
|
||||
topVals[k] = v;
|
||||
topIdx[k] = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float sum = 1.0f;
|
||||
if (normTopk == 1) {
|
||||
sum = 0.0f;
|
||||
for (uint k = 0; k < K; k++) {
|
||||
sum += topVals[k];
|
||||
}
|
||||
}
|
||||
|
||||
for (uint k = 0; k < K; k++) {
|
||||
indexes[batchIndex * K + k] = float(topIdx[k]);
|
||||
y[infos[k * N_BATCHES + batchIndex].outputOffset] = topVals[k] / sum;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user