// dllama/src/nn/nn-core.cpp
#ifdef _WIN32
#define _USE_MATH_DEFINES
#endif
#include "nn-core.hpp"
#include <cassert>
#include <cstring>
#include <cmath>
#include <stdexcept>
// utility functions
NnSize getBytes(NnFloatType floatType, NnSize n) {
if (floatType == F_32)
return n * sizeof(float);
    if (floatType == F_16)
        return n * (sizeof(float) / 2); // 2 bytes per half-precision value
if (floatType == F_Q40) {
assert(n % Q40_BLOCK_SIZE == 0);
return (n / Q40_BLOCK_SIZE) * sizeof(NnBlockQ40);
}
if (floatType == F_Q80) {
assert(n % Q80_BLOCK_SIZE == 0);
return (n / Q80_BLOCK_SIZE) * sizeof(NnBlockQ80);
}
throw std::invalid_argument("Unsupported float type: " + std::to_string(floatType));
}
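// Usage sketch (hypothetical, unused helper; not part of the original file): how getBytes
// maps an element count to storage bytes for each float type. The quantized sizes follow
// the block structs declared in nn-core.hpp.
static void exampleGetBytes() {
    const NnSize f32Bytes = getBytes(F_32, 4096);  // 4096 * 4 bytes
    const NnSize f16Bytes = getBytes(F_16, 4096);  // 4096 * 2 bytes
    const NnSize q40Bytes = getBytes(F_Q40, 4096); // (4096 / Q40_BLOCK_SIZE) * sizeof(NnBlockQ40)
    assert(f32Bytes == 2 * f16Bytes);
    (void)q40Bytes;
}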
NnSize getBlockSize(NnFloatType floatType) {
if (floatType == F_32)
return 1;
if (floatType == F_16)
return 1;
if (floatType == F_Q40)
return Q40_BLOCK_SIZE;
if (floatType == F_Q80)
return Q80_BLOCK_SIZE;
throw std::invalid_argument("Unsupported float type");
}
NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output) {
    // If weight == F_UNK (no weight), the returned enum is <input>_<input>_<output>
if (input == F_32 && output == F_32) {
if (weight == F_UNK || weight == F_32)
return F32_F32_F32;
if (weight == F_Q40)
return F32_Q40_F32;
}
if (input == F_32 && output == F_Q80) {
if (weight == F_UNK || weight == F_32)
return F32_F32_Q80;
if (weight == F_Q40)
return F32_Q40_Q80;
}
if (input == F_Q80 && output == F_32) {
if (weight == F_UNK || weight == F_Q80)
return Q80_Q80_F32;
if (weight == F_32)
return Q80_F32_F32;
if (weight == F_Q40)
return Q80_Q40_F32;
}
if (input == F_Q80 && output == F_Q80) {
if (weight == F_UNK || weight == F_Q80)
return Q80_Q80_Q80;
}
throw std::invalid_argument("Unsupported op quant: " +
std::string(floatTypeToString(input)) + "/" +
std::string(floatTypeToString(weight)) + "/" +
std::string(floatTypeToString(output)));
}
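// Usage sketch (hypothetical, unused helper): a typical distributed-llama matmul quantizes
// activations to Q80 and keeps weights in Q40, which resolves to Q80_Q40_F32; weightless
// ops pass F_UNK, so only the input/output types decide the variant.
static void exampleOpQuantTypes() {
    assert(getOpQuantType(F_Q80, F_Q40, F_32) == Q80_Q40_F32);
    assert(getOpQuantType(F_32, F_UNK, F_32) == F32_F32_F32);
}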
const char *opCodeToString(NnOpCode code) {
if (code == OP_MERGE_ADD) return "MERGE_ADD";
if (code == OP_MERGE_SUM) return "MERGE_SUM";
if (code == OP_EMBEDDING) return "EMBEDDING";
if (code == OP_INV_RMS) return "INV_RMS";
if (code == OP_RMS_NORM) return "RMS_NORM";
if (code == OP_MATMUL) return "MATMUL";
if (code == OP_ROPE) return "ROPE";
if (code == OP_MULTIHEAD_ATT) return "MULTIHEAD_ATT";
if (code == OP_GELU) return "GELU";
if (code == OP_SILU) return "SILU";
if (code == OP_MUL) return "MUL";
if (code == OP_SCALE) return "SCALE";
if (code == OP_CAST) return "CAST";
if (code == OP_REPEAT_Z) return "REPEAT_Z";
if (code == OP_SHIFT) return "SHIFT";
if (code == OP_SOFTMAX) return "SOFTMAX";
if (code == OP_MOE_GATE) return "MOE_GATE";
throw std::invalid_argument("Unknown op code: " + std::to_string(code));
}
const char *opQuantTypeToString(NnOpQuantType type) {
if (type == F32_F32_F32) return "F32_F32_F32";
if (type == F32_Q40_F32) return "F32_Q40_F32";
if (type == F32_Q40_Q80) return "F32_Q40_Q80";
if (type == F32_F32_Q80) return "F32_F32_Q80";
if (type == Q80_Q80_Q80) return "Q80_Q80_Q80";
if (type == Q80_Q80_F32) return "Q80_Q80_F32";
if (type == Q80_Q40_F32) return "Q80_Q40_F32";
if (type == Q80_F32_F32) return "Q80_F32_F32";
throw std::invalid_argument("Unknown op quant type");
}
NnSize3D size0() {
return { F_UNK, 0, 0, 0, 0, 0 };
}
NnSize3D size1D(NnFloatType floatType, NnUint x) {
return size3D(floatType, 1, 1, x);
}
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x) {
return size3D(floatType, 1, y, x);
}
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x) {
NnSize len = z * y * x;
NnSize lenXY = y * x;
return { floatType, z, y, x, len, getBytes(floatType, len), getBytes(floatType, lenXY) };
}
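// Usage sketch (hypothetical, unused helper): size1D/size2D are thin wrappers over size3D
// with z = 1 (and y = 1), so nBytes always equals getBytes() of the total element count.
static void exampleSizes() {
    const NnSize3D matrix = size2D(F_32, 8, 256);
    assert(matrix.nBytes == getBytes(F_32, 8 * 256));
    const NnSize3D row = size1D(F_32, 256);
    assert(row.nBytes == getBytes(F_32, 256));
}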
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index) {
return { source, index, PNTR_BATCH };
}
NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index) {
return { source, index, PNTR_BATCHED_SLICE };
}
NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index) {
return { source, index, PNTR_RAW };
}
bool hasPointerContinuousMemory(NnPointerConfig *config) {
if (config->type == PNTR_RAW)
return true;
if (config->type == PNTR_BATCH)
return true;
return false;
}
void releaseNetConfig(NnNetConfig *netConfig) {
for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) {
delete[] netConfig->pipes[pipeIndex].name;
}
delete[] netConfig->pipes;
}
void releaseNodeConfig(NnNodeConfig *nodeConfig) {
for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex];
if (segment->nOps > 0) {
for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) {
NnOpConfig *op = &segment->ops[opIndex];
delete[] op->name;
delete[] op->config;
}
delete[] segment->ops;
}
if (segment->nSyncs > 0)
delete[] segment->syncs;
}
if (nodeConfig->nBuffers > 0) {
for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++)
delete[] nodeConfig->buffers[bufferIndex].name;
delete[] nodeConfig->buffers;
}
delete[] nodeConfig->segments;
}
void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
unsigned long total = 0;
for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++)
total += netConfig->pipes[pipeIndex].size.nBytes;
for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++)
total += nodeConfig->buffers[bufferIndex].size.nBytes;
for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex];
for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) {
total += segment->ops[opIndex].weightSize.nBytes;
total += segment->ops[opIndex].configSize;
}
}
printf("📀 RequiredMemory: %lu MB\n", total / (1024 * 1024));
}
Timer::Timer() {
reset();
}
void Timer::reset() {
startTime = std::chrono::high_resolution_clock::now();
}
NnUint Timer::elapsedMiliseconds() {
auto endTime = std::chrono::high_resolution_clock::now();
return (NnUint)std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
}
NnUint Timer::elapsedMicroseconds() {
auto endTime = std::chrono::high_resolution_clock::now();
return (NnUint)std::chrono::duration_cast<std::chrono::microseconds>(endTime - startTime).count();
}
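// Usage sketch (hypothetical, unused helper): Timer starts measuring at construction and
// reset() restarts it, so elapsed*() returns the time since the most recent of the two.
static void exampleTimerUsage() {
    Timer timer;
    // ... some work to measure ...
    const NnUint workUs = timer.elapsedMicroseconds();
    timer.reset(); // start a new, independent measurement
    const NnUint laterMs = timer.elapsedMiliseconds();
    (void)workUs;
    (void)laterMs;
}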
// slicers
NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes) {
NnKvCacheSlice s;
assert(kvDim % nNodes == 0);
s.kvDim0 = kvDim / nNodes;
s.keySize = size2D(F_32, seqLen, s.kvDim0);
s.valueSize = size2D(F_32, seqLen, s.kvDim0);
return s;
}
NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) {
NnRowMatmulSlice s;
assert(d % nNodes == 0);
s.type = type;
s.nNodes = nNodes;
s.d0 = d / nNodes;
s.n = n;
s.size = size2D(type, s.n, d);
s.sliceSize = size2D(type, s.n, s.d0);
return s;
}
NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) {
NnColMatmulSlice s;
assert(n % nNodes == 0);
s.type = type;
s.nNodes = nNodes;
s.n = n;
s.n0 = n / nNodes;
s.d = d;
s.size = size2D(type, n, d);
s.sliceSize = size2D(type, s.n0, d);
return s;
}
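// Usage sketch (hypothetical, unused helper; assumes n is the matmul input dimension and
// d the output dimension): a row slice divides d across nodes, a column slice divides n,
// so with 4 nodes each node stores a quarter of either weight matrix.
static void exampleMatmulSlices() {
    const NnRowMatmulSlice rowSlice = sliceRowMatmul(F_Q40, 4, /*n=*/4096, /*d=*/11008);
    assert(rowSlice.d0 == 11008 / 4);
    assert(rowSlice.sliceSize.nBytes * 4 == rowSlice.size.nBytes);
    const NnColMatmulSlice colSlice = sliceColMatmul(F_Q40, 4, /*n=*/11008, /*d=*/4096);
    assert(colSlice.n0 == 11008 / 4);
    assert(colSlice.sliceSize.nBytes * 4 == colSlice.size.nBytes);
}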
NnRopeSlice sliceRope(NnRopeType type, NnUint qDim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headDim, float ropeTheta, NnUint nodeIndex) {
NnRopeSlice s;
assert(qDim >= kvDim);
assert(qDim % nNodes == 0);
assert(kvDim % nNodes == 0);
s.kvDim = kvDim;
s.nKvHeads = nKvHeads;
s.seqLen = seqLen;
s.headDim = headDim;
s.ropeTheta = ropeTheta;
s.qDim0 = qDim / nNodes;
s.kvDim0 = kvDim / nNodes;
assert(s.qDim0 % 2 == 0);
assert(s.kvDim0 % 2 == 0);
if (type == ROPE_LLAMA || type == ROPE_LLAMA3_1) {
s.kvDimStart = s.kvDim0 * nodeIndex;
s.qDimStart = s.qDim0 * nodeIndex;
s.qDimEnd = s.qDimStart + s.qDim0;
s.qShift = s.qDimStart - s.kvDimStart;
s.sliceDim = s.qDimEnd - s.kvDimStart;
assert(s.sliceDim % 2 == 0);
s.cacheSize = size2D(F_32, seqLen, s.sliceDim);
} else if (type == ROPE_FALCON) {
s.cacheSize = size2D(F_32, seqLen, headDim);
} else {
throw std::invalid_argument("Unsupported rope type");
}
return s;
}
NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches) {
NnMultiHeadAttSlice s;
assert(nHeads % nNodes == 0);
s.nHeads = nHeads;
s.nHeads0 = nHeads / nNodes;
s.attSize = size2D(F_32, nBatches, s.nHeads0 * seqLen);
return s;
}
// splitters
// Copies the contiguous block of d0 rows (each n values long) assigned to nodeIndex from
// the full weight into weight0; returns the number of copied bytes.
NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) {
NnSize blockSize = getBlockSize(slice->type);
NnSize batchBytes = getBytes(slice->type, blockSize);
assert(slice->n % blockSize == 0);
NnSize n = slice->n / blockSize;
NnSize offset = slice->d0 * nodeIndex * n * batchBytes;
NnSize copiedBytes = 0;
for (NnUint d = 0; d < slice->d0; d++) {
for (NnUint j = 0; j < n; j++) {
NnSize o = (d * n + j) * batchBytes;
std::memcpy(weight0 + o, weight + offset + o, batchBytes);
copiedBytes += batchBytes;
}
}
return copiedBytes;
}
// Copies, for each of the d rows, the n0-wide strip of columns assigned to nodeIndex from
// the full weight into weight0; returns the number of copied bytes.
NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) {
NnSize blockSize = getBlockSize(slice->type);
NnSize batchBytes = getBytes(slice->type, blockSize);
assert(slice->n0 % blockSize == 0);
NnSize n = slice->n / blockSize;
NnSize rowBytes = n * batchBytes;
NnSize row0Bytes = (slice->n0 / blockSize) * batchBytes;
NnSize rowOffsetBytes = nodeIndex * row0Bytes;
NnSize copiedBytes = 0;
for (NnUint d = 0; d < slice->d; d++) {
std::memcpy(&weight0[row0Bytes * d], &weight[rowBytes * d + rowOffsetBytes], row0Bytes);
copiedBytes += row0Bytes;
}
return copiedBytes;
}
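// Usage sketch (hypothetical helper and sizes, unused): splitting a full Q40 weight matrix
// by rows across two nodes. fullWeight is assumed to hold the complete matrix in the layout
// expected by splitRowMatmulWeight; nodeWeight must provide sliceSize.nBytes bytes.
static void exampleSplitRowWeight(const NnUint nodeIndex, NnByte *fullWeight, NnByte *nodeWeight) {
    NnRowMatmulSlice slice = sliceRowMatmul(F_Q40, /*nNodes=*/2, /*n=*/4096, /*d=*/4096);
    const NnUint copiedBytes = splitRowMatmulWeight(&slice, nodeIndex, fullWeight, nodeWeight);
    assert(copiedBytes == slice.sliceSize.nBytes);
}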
// RoPE cache helpers
static inline float scaleFrequencyLlama3(const float freq, const NnRopeOpConfig *config) {
// https://github.com/meta-llama/llama-models/blob/4269717b2ea587627903bacbb75ccce1427ad914/models/llama3/reference_impl/model.py#L55
const float waveLen = 2.0f * M_PI / freq;
const float highFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingHighFreqFactor;
if (waveLen < highFreqWavelen) {
return freq;
}
const float lowFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingLowFreqFactor;
if (waveLen > lowFreqWavelen) {
return freq / config->ropeScalingFactor;
}
const float smooth = (config->ropeScalingOrigMaxSeqLen / waveLen - config->ropeScalingLowFreqFactor) /
(config->ropeScalingHighFreqFactor - config->ropeScalingLowFreqFactor);
return (1 - smooth) * freq / config->ropeScalingFactor + smooth * freq;
}
static inline void fullfillRopeLlamaCache(const NnRopeOpConfig *config, float *cache) {
assert((config->slice.qDimEnd - config->slice.kvDimStart) % 2 == 0);
const bool applyScaling = config->ropeScalingFactor != 1.0f;
for (NnUint pos = 0; pos < config->slice.seqLen; pos++) {
for (NnUint i = config->slice.kvDimStart; i < config->slice.qDimEnd; i += 2) {
const NnUint h = i % config->slice.headDim;
float freq = 1.0f / powf(config->slice.ropeTheta, h / (float)config->slice.headDim);
if (applyScaling)
freq = scaleFrequencyLlama3(freq, config);
const float val = pos * freq;
const float fcr = cosf(val);
const float fci = sinf(val);
cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart)] = fcr;
cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart) + 1] = fci;
}
}
}
static inline void fullfillRopeFalconCache(const NnRopeOpConfig *config, float *cache) {
const float hs = (float)config->slice.headDim;
for (NnUint pos = 0; pos < config->slice.seqLen; pos++) {
for (NnUint j = 0; j < config->slice.headDim / 2; j++) {
const float freq = 1.0f / powf(config->slice.ropeTheta, 2.0f * (float)(j / hs));
const float val = pos * freq;
const float fcr = cosf(val);
const float fci = sinf(val);
cache[pos * config->slice.headDim + j] = fcr;
cache[pos * config->slice.headDim + j + config->slice.headDim / 2] = fci;
}
}
}
void fullfillRopeCache(const NnRopeOpConfig *config, float *cache) {
if (config->type == ROPE_LLAMA || config->type == ROPE_LLAMA3_1)
fullfillRopeLlamaCache(config, cache);
else if (config->type == ROPE_FALCON)
fullfillRopeFalconCache(config, cache);
else
throw std::invalid_argument("Unsupported rope type");
}
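// Usage sketch (hypothetical, unused helper; assumes NnRopeOpConfig is an aggregate whose
// remaining fields may stay zero-initialized): building the cos/sin cache for a single-node
// LLaMA-style RoPE. With ropeScalingFactor == 1.0f the LLaMA 3.1 frequency scaling is skipped.
static float *exampleBuildRopeCache() {
    NnRopeOpConfig config = {};
    config.type = ROPE_LLAMA;
    config.ropeScalingFactor = 1.0f;
    config.slice = sliceRope(ROPE_LLAMA, /*qDim=*/128, /*kvDim=*/128, /*nKvHeads=*/2,
        /*nNodes=*/1, /*seqLen=*/64, /*headDim=*/64, /*ropeTheta=*/10000.0f, /*nodeIndex=*/0);
    float *cache = new float[config.slice.seqLen * config.slice.sliceDim];
    fullfillRopeCache(&config, cache);
    return cache; // caller owns the buffer (delete[])
}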