#ifdef _WIN32 #define _USE_MATH_DEFINES #endif #include "nn-core.hpp" #include #include #include #include // utility functions NnSize getBytes(NnFloatType floatType, NnSize n) { if (floatType == F_32) return n * sizeof(float); if (floatType == F_16) return n * (sizeof(float) / 2); if (floatType == F_Q40) { assert(n % Q40_BLOCK_SIZE == 0); return (n / Q40_BLOCK_SIZE) * sizeof(NnBlockQ40); } if (floatType == F_Q80) { assert(n % Q80_BLOCK_SIZE == 0); return (n / Q80_BLOCK_SIZE) * sizeof(NnBlockQ80); } throw std::invalid_argument("Unsupported float type: " + std::to_string(floatType)); } NnSize getBlockSize(NnFloatType floatType) { if (floatType == F_32) return 1; if (floatType == F_16) return 1; if (floatType == F_Q40) return Q40_BLOCK_SIZE; if (floatType == F_Q80) return Q80_BLOCK_SIZE; throw std::invalid_argument("Unsupported float type"); } NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output) { // If weight=F_UNK, then returned enum should be __ if (input == F_32 && output == F_32) { if (weight == F_UNK || weight == F_32) return F32_F32_F32; if (weight == F_Q40) return F32_Q40_F32; } if (input == F_32 && output == F_Q80) { if (weight == F_UNK || weight == F_32) return F32_F32_Q80; if (weight == F_Q40) return F32_Q40_Q80; } if (input == F_Q80 && output == F_32) { if (weight == F_UNK || weight == F_Q80) return Q80_Q80_F32; if (weight == F_32) return Q80_F32_F32; if (weight == F_Q40) return Q80_Q40_F32; } if (input == F_Q80 && output == F_Q80) { if (weight == F_UNK || weight == F_Q80) return Q80_Q80_Q80; } throw std::invalid_argument("Unsupported op quant: " + std::string(floatTypeToString(input)) + "/" + std::string(floatTypeToString(weight)) + "/" + std::string(floatTypeToString(output))); } const char *opCodeToString(NnOpCode code) { if (code == OP_MERGE_ADD) return "MERGE_ADD"; if (code == OP_MERGE_SUM) return "MERGE_SUM"; if (code == OP_EMBEDDING) return "EMBEDDING"; if (code == OP_INV_RMS) return "INV_RMS"; if (code == OP_RMS_NORM) return "RMS_NORM"; if (code == OP_MATMUL) return "MATMUL"; if (code == OP_ROPE) return "ROPE"; if (code == OP_MULTIHEAD_ATT) return "MULTIHEAD_ATT"; if (code == OP_GELU) return "GELU"; if (code == OP_SILU) return "SILU"; if (code == OP_MUL) return "MUL"; if (code == OP_SCALE) return "SCALE"; if (code == OP_CAST) return "CAST"; if (code == OP_REPEAT_Z) return "REPEAT_Z"; if (code == OP_SHIFT) return "SHIFT"; if (code == OP_SOFTMAX) return "SOFTMAX"; if (code == OP_MOE_GATE) return "MOE_GATE"; throw std::invalid_argument("Unknown op code: " + std::to_string(code)); } const char *opQuantTypeToString(NnOpQuantType type) { if (type == F32_F32_F32) return "F32_F32_F32"; if (type == F32_Q40_F32) return "F32_Q40_F32"; if (type == F32_Q40_Q80) return "F32_Q40_Q80"; if (type == F32_F32_Q80) return "F32_F32_Q80"; if (type == Q80_Q80_Q80) return "Q80_Q80_Q80"; if (type == Q80_Q80_F32) return "Q80_Q80_F32"; if (type == Q80_Q40_F32) return "Q80_Q40_F32"; if (type == Q80_F32_F32) return "Q80_F32_F32"; throw std::invalid_argument("Unknown op quant type"); } NnSize3D size0() { return { F_UNK, 0, 0, 0, 0, 0 }; } NnSize3D size1D(NnFloatType floatType, NnUint x) { return size3D(floatType, 1, 1, x); } NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x) { return size3D(floatType, 1, y, x); } NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x) { NnSize len = z * y * x; NnSize lenXY = y * x; return { floatType, z, y, x, len, getBytes(floatType, len), getBytes(floatType, lenXY) }; } NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index) { return { source, index, PNTR_BATCH }; } NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index) { return { source, index, PNTR_BATCHED_SLICE }; } NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index) { return { source, index, PNTR_RAW }; } bool hasPointerContinuousMemory(NnPointerConfig *config) { if (config->type == PNTR_RAW) return true; if (config->type == PNTR_BATCH) return true; return false; } void releaseNetConfig(NnNetConfig *netConfig) { for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) { delete[] netConfig->pipes[pipeIndex].name; } delete[] netConfig->pipes; } void releaseNodeConfig(NnNodeConfig *nodeConfig) { for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex]; if (segment->nOps > 0) { for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) { NnOpConfig *op = &segment->ops[opIndex]; delete[] op->name; delete[] op->config; } delete[] segment->ops; } if (segment->nSyncs > 0) delete[] segment->syncs; } if (nodeConfig->nBuffers > 0) { for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++) delete[] nodeConfig->buffers[bufferIndex].name; delete[] nodeConfig->buffers; } delete[] nodeConfig->segments; } void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) { unsigned long total = 0; for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) total += netConfig->pipes[pipeIndex].size.nBytes; for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++) total += nodeConfig->buffers[bufferIndex].size.nBytes; for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) { NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex]; for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) { total += segment->ops[opIndex].weightSize.nBytes; total += segment->ops[opIndex].configSize; } } printf("📀 RequiredMemory: %lu MB\n", total / (1024 * 1024)); } Timer::Timer() { reset(); } void Timer::reset() { startTime = std::chrono::high_resolution_clock::now(); } NnUint Timer::elapsedMiliseconds() { auto endTime = std::chrono::high_resolution_clock::now(); return (NnUint)std::chrono::duration_cast(endTime - startTime).count(); } NnUint Timer::elapsedMicroseconds() { auto endTime = std::chrono::high_resolution_clock::now(); return (NnUint)std::chrono::duration_cast(endTime - startTime).count(); } // slicers NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes) { NnKvCacheSlice s; assert(kvDim % nNodes == 0); s.kvDim0 = kvDim / nNodes; s.keySize = size2D(F_32, seqLen, s.kvDim0); s.valueSize = size2D(F_32, seqLen, s.kvDim0); return s; } NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) { NnRowMatmulSlice s; assert(d % nNodes == 0); s.type = type; s.nNodes = nNodes; s.d0 = d / nNodes; s.n = n; s.size = size2D(type, s.n, d); s.sliceSize = size2D(type, s.n, s.d0); return s; } NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) { NnColMatmulSlice s; assert(n % nNodes == 0); s.type = type; s.nNodes = nNodes; s.n = n; s.n0 = n / nNodes; s.d = d; s.size = size2D(type, n, d); s.sliceSize = size2D(type, s.n0, d); return s; } NnRopeSlice sliceRope(NnRopeType type, NnUint qDim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headDim, float ropeTheta, NnUint nodeIndex) { NnRopeSlice s; assert(qDim >= kvDim); assert(qDim % nNodes == 0); assert(kvDim % nNodes == 0); s.kvDim = kvDim; s.nKvHeads = nKvHeads; s.seqLen = seqLen; s.headDim = headDim; s.ropeTheta = ropeTheta; s.qDim0 = qDim / nNodes; s.kvDim0 = kvDim / nNodes; assert(s.qDim0 % 2 == 0); assert(s.kvDim0 % 2 == 0); if (type == ROPE_LLAMA || type == ROPE_LLAMA3_1) { s.kvDimStart = s.kvDim0 * nodeIndex; s.qDimStart = s.qDim0 * nodeIndex; s.qDimEnd = s.qDimStart + s.qDim0; s.qShift = s.qDimStart - s.kvDimStart; s.sliceDim = s.qDimEnd - s.kvDimStart; assert(s.sliceDim % 2 == 0); s.cacheSize = size2D(F_32, seqLen, s.sliceDim); } else if (type == ROPE_FALCON) { s.cacheSize = size2D(F_32, seqLen, headDim); } else { throw std::invalid_argument("Unsupported rope type"); } return s; } NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches) { NnMultiHeadAttSlice s; assert(nHeads % nNodes == 0); s.nHeads = nHeads; s.nHeads0 = nHeads / nNodes; s.attSize = size2D(F_32, nBatches, s.nHeads0 * seqLen); return s; } // splitters NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) { NnSize blockSize = getBlockSize(slice->type); NnSize batchBytes = getBytes(slice->type, blockSize); assert(slice->n % blockSize == 0); NnSize n = slice->n / blockSize; NnSize offset = slice->d0 * nodeIndex * n * batchBytes; NnSize copiedBytes = 0; for (NnUint d = 0; d < slice->d0; d++) { for (NnUint j = 0; j < n; j++) { NnSize o = (d * n + j) * batchBytes; std::memcpy(weight0 + o, weight + offset + o, batchBytes); copiedBytes += batchBytes; } } return copiedBytes; } NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) { NnSize blockSize = getBlockSize(slice->type); NnSize batchBytes = getBytes(slice->type, blockSize); assert(slice->n0 % blockSize == 0); NnSize n = slice->n / blockSize; NnSize rowBytes = n * batchBytes; NnSize row0Bytes = (slice->n0 / blockSize) * batchBytes; NnSize rowOffsetBytes = nodeIndex * row0Bytes; NnSize copiedBytes = 0; for (NnUint d = 0; d < slice->d; d++) { std::memcpy(&weight0[row0Bytes * d], &weight[rowBytes * d + rowOffsetBytes], row0Bytes); copiedBytes += row0Bytes; } return copiedBytes; } // helper static inline float scaleFrequencyLlama3(const float freq, const NnRopeOpConfig *config) { // https://github.com/meta-llama/llama-models/blob/4269717b2ea587627903bacbb75ccce1427ad914/models/llama3/reference_impl/model.py#L55 const float waveLen = 2.0f * M_PI / freq; const float highFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingHighFreqFactor; if (waveLen < highFreqWavelen) { return freq; } const float lowFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingLowFreqFactor; if (waveLen > lowFreqWavelen) { return freq / config->ropeScalingFactor; } const float smooth = (config->ropeScalingOrigMaxSeqLen / waveLen - config->ropeScalingLowFreqFactor) / (config->ropeScalingHighFreqFactor - config->ropeScalingLowFreqFactor); return (1 - smooth) * freq / config->ropeScalingFactor + smooth * freq; } static inline void fullfillRopeLlamaCache(const NnRopeOpConfig *config, float *cache) { assert((config->slice.qDimEnd - config->slice.kvDimStart) % 2 == 0); const bool applyScaling = config->ropeScalingFactor != 1.0f; for (NnUint pos = 0; pos < config->slice.seqLen; pos++) { for (NnUint i = config->slice.kvDimStart; i < config->slice.qDimEnd; i += 2) { const NnUint h = i % config->slice.headDim; float freq = 1.0f / powf(config->slice.ropeTheta, h / (float)config->slice.headDim); if (applyScaling) freq = scaleFrequencyLlama3(freq, config); const float val = pos * freq; const float fcr = cosf(val); const float fci = sinf(val); cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart)] = fcr; cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart) + 1] = fci; } } } static inline void fullfillRopeFalconCache(const NnRopeOpConfig *config, float *cache) { const float hs = (float)config->slice.headDim; for (NnUint pos = 0; pos < config->slice.seqLen; pos++) { for (NnUint j = 0; j < config->slice.headDim / 2; j++) { const float freq = 1.0f / powf(config->slice.ropeTheta, 2.0f * (float)(j / hs)); const float val = pos * freq; const float fcr = cosf(val); const float fci = sinf(val); cache[pos * config->slice.headDim + j] = fcr; cache[pos * config->slice.headDim + j + config->slice.headDim / 2] = fci; } } } void fullfillRopeCache(const NnRopeOpConfig *config, float *cache) { if (config->type == ROPE_LLAMA || config->type == ROPE_LLAMA3_1) fullfillRopeLlamaCache(config, cache); else if (config->type == ROPE_FALCON) fullfillRopeFalconCache(config, cache); else throw std::invalid_argument("Unsupported rope type"); }