init
Some checks failed
main / Linux (amd64, ubuntu-22.04) (push) Successful in 49s
main / Linux (arm64, ubuntu-24.04-arm) (push) Has been cancelled
main / Windows (push) Has been cancelled

This commit is contained in:
2025-10-24 11:42:14 +02:00
commit 42172cbb6f
85 changed files with 40316 additions and 0 deletions

30
.env.example Normal file
View File

@@ -0,0 +1,30 @@
# Distributed Llama Docker Environment Configuration
# Copy this file to .env and customize as needed
# Model configuration
MODEL_NAME=llama3_2_3b_instruct_q40
MAX_SEQ_LEN=4096
BUFFER_FLOAT_TYPE=q80
# Thread configuration
CONTROLLER_NTHREADS=4
WORKER_NTHREADS=4
# To use a different model, change MODEL_NAME to one of:
# - llama3_1_8b_instruct_q40
# - llama3_1_405b_instruct_q40
# - llama3_2_1b_instruct_q40
# - llama3_2_3b_instruct_q40
# - llama3_3_70b_instruct_q40
# - deepseek_r1_distill_llama_8b_q40
# - qwen3_0.6b_q40
# - qwen3_1.7b_q40
# - qwen3_8b_q40
# - qwen3_14b_q40
# - qwen3_30b_a3b_q40
# Performance tuning:
# - Adjust CONTROLLER_NTHREADS and WORKER_NTHREADS based on your Pi's CPU cores
# - For Pi 4 (4 cores): use 4 threads
# - For Pi 3 (4 cores): use 2-4 threads
# - For Pi Zero 2 (4 cores): use 2 threads

BIN
.github/8raspi.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 542 KiB

BIN
.github/8raspi2.jpg vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 298 KiB

BIN
.github/cover.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 61 KiB

63
.github/workflows/main.yml vendored Normal file
View File

@@ -0,0 +1,63 @@
name: main
on:
pull_request:
branches:
- main
- feat/nn
push:
branches:
- main
- feat/nn
jobs:
build-linux:
name: Linux
runs-on: ${{matrix.runner}}
strategy:
matrix:
include:
- runner: ubuntu-22.04
arch: amd64
- runner: ubuntu-24.04-arm
arch: arm64
steps:
- name: Checkout Repo
uses: actions/checkout@v3
- name: Dependencies
id: dependencies
run: sudo apt-get update && sudo apt-get install build-essential
- name: Build
id: build
run: |
make dllama
make nn-cpu-test
make nn-cpu-ops-test
make tokenizer-test
- name: nn-cpu-test
run: ./nn-cpu-test
- name: nn-cpu-ops-test
run: ./nn-cpu-ops-test
- name: tokenizer-test
run: ./tokenizer-test
build-windows:
name: Windows
runs-on: windows-latest
steps:
- name: Checkout Repo
uses: actions/checkout@v3
- name: Dependencies
id: dependencies
run: choco install make
- name: Build
id: build
run: |
make dllama
make nn-cpu-test
make nn-cpu-ops-test
make tokenizer-test
- name: nn-cpu-test
run: ./nn-cpu-test
- name: nn-cpu-ops-test
run: ./nn-cpu-ops-test
- name: tokenizer-test
run: ./tokenizer-test

19
.gitignore vendored Normal file
View File

@@ -0,0 +1,19 @@
.vscode/settings.json
*.o
*.0
*.dSYM
*.data
*.temp
*.tmp
__pycache__
*-test
/models
main
run*.sh
server
/dllama
/dllama-*
*.exe
*.spv

17
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,17 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "main",
"type": "cppdbg",
"request": "launch",
"program": "${workspaceFolder}/main",
"args": [],
"stopAtEntry": false,
"cwd": "${workspaceFolder}",
"environment": [],
"externalConsole": false,
"MIMode": "lldb"
}
]
}

202
DOCKER_README.md Normal file
View File

@@ -0,0 +1,202 @@
# Distributed Llama Docker Setup for Raspberry Pi
This directory contains Docker configurations to run Distributed Llama on Raspberry Pi devices using containers. There are two variants:
1. **Controller** (`Dockerfile.controller`) - Downloads models and runs the API server
2. **Worker** (`Dockerfile.worker`) - Runs worker nodes that connect to the controller
## Quick Start with Docker Compose
### 1. Download a Model
First, download a model using the controller container:
```bash
# Create a models directory
mkdir -p models
# Download a model (this will take some time)
docker-compose run --rm controller --download llama3_2_3b_instruct_q40
```
### 2. Start the Distributed Setup
```bash
# Start all services (1 controller + 3 workers)
docker-compose up
```
The API will be available at `http://localhost:9999`
### 3. Test the API
```bash
# List available models
curl http://localhost:9999/v1/models
# Send a chat completion request
curl -X POST http://localhost:9999/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"max_tokens": 100
}'
```
## Manual Docker Usage
### Building the Images
```bash
# Build controller image
docker build -f Dockerfile.controller -t distributed-llama-controller .
# Build worker image
docker build -f Dockerfile.worker -t distributed-llama-worker .
```
### Running the Controller
```bash
# Download a model first
docker run -v ./models:/app/models distributed-llama-controller --download llama3_2_3b_instruct_q40
# Run API server (standalone mode, no workers)
docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \
--model llama3_2_3b_instruct_q40
# Run API server with workers
docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \
--model llama3_2_3b_instruct_q40 \
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
```
### Running Workers
```bash
# Run a worker on default port 9999
docker run -p 9999:9999 distributed-llama-worker
# Run a worker with custom settings
docker run -p 9998:9998 distributed-llama-worker --port 9998 --nthreads 2
```
## Available Models
You can download any of these models:
- `llama3_1_8b_instruct_q40`
- `llama3_1_405b_instruct_q40` (very large, 56 parts)
- `llama3_2_1b_instruct_q40`
- `llama3_2_3b_instruct_q40`
- `llama3_3_70b_instruct_q40`
- `deepseek_r1_distill_llama_8b_q40`
- `qwen3_0.6b_q40`
- `qwen3_1.7b_q40`
- `qwen3_8b_q40`
- `qwen3_14b_q40`
- `qwen3_30b_a3b_q40`
## Configuration Options
### Controller Options
- `--model <name>`: Model name to use (required)
- `--port <port>`: API server port (default: 9999)
- `--nthreads <n>`: Number of threads (default: 4)
- `--max-seq-len <n>`: Maximum sequence length (default: 4096)
- `--buffer-float-type <type>`: Buffer float type (default: q80)
- `--workers <addresses>`: Space-separated worker addresses
- `--download <model>`: Download a model and exit
### Worker Options
- `--port <port>`: Worker port (default: 9999)
- `--nthreads <n>`: Number of threads (default: 4)
## Environment Variables (Docker Compose)
You can customize the setup using environment variables:
```bash
# Set model and thread counts
MODEL_NAME=llama3_2_1b_instruct_q40 \
CONTROLLER_NTHREADS=2 \
WORKER_NTHREADS=2 \
docker-compose up
```
Available variables:
- `MODEL_NAME`: Model to use (default: llama3_2_3b_instruct_q40)
- `CONTROLLER_NTHREADS`: Controller threads (default: 4)
- `WORKER_NTHREADS`: Worker threads (default: 4)
- `MAX_SEQ_LEN`: Maximum sequence length (default: 4096)
- `BUFFER_FLOAT_TYPE`: Buffer float type (default: q80)
## Multi-Device Setup
To run across multiple Raspberry Pi devices:
### Device 1 (Controller)
```bash
# Run controller
docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \
--model llama3_2_3b_instruct_q40 \
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
```
### Devices 2-4 (Workers)
```bash
# Run worker on each device
docker run -p 9999:9999 distributed-llama-worker --nthreads 4
```
## Performance Tips
1. **Thread Count**: Set `--nthreads` to the number of CPU cores on each device
2. **Memory**: Larger models require more RAM. Monitor usage with `docker stats`
3. **Network**: Use wired Ethernet connections for better performance between devices
4. **Storage**: Use fast SD cards (Class 10 or better) or USB 3.0 storage for model files
## Troubleshooting
### Model Download Issues
```bash
# Check if model files exist
ls -la models/llama3_2_3b_instruct_q40/
# Re-download if corrupted
docker-compose run --rm controller --download llama3_2_3b_instruct_q40
```
### Worker Connection Issues
```bash
# Check worker logs
docker-compose logs worker1
# Test network connectivity
docker exec -it <controller_container> ping 172.20.0.11
```
### Resource Issues
```bash
# Monitor resource usage
docker stats
# Reduce thread count if CPU usage is too high
CONTROLLER_NTHREADS=2 WORKER_NTHREADS=2 docker-compose up
```
## Web Interface
You can use the web chat interface at [llama-ui.js.org](https://llama-ui.js.org/):
1. Open the website
2. Go to settings
3. Set base URL to: `http://your-pi-ip:9999`
4. Save and start chatting
## License
This Docker setup follows the same license as the main Distributed Llama project.

160
Dockerfile.controller Normal file
View File

@@ -0,0 +1,160 @@
# Dockerfile for Distributed Llama Controller (Raspberry Pi)
# This variant can download models and start the API server
FROM arm64v8/debian:bookworm-slim
# Install dependencies
RUN apt-get update && apt-get install -y \
build-essential \
g++ \
make \
git \
python3 \
python3-pip \
curl \
wget \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Copy source code
COPY src/ ./src/
COPY Makefile ./
COPY launch.py ./
# Build the applications
RUN make dllama && make dllama-api
# Create models directory for volume mount
RUN mkdir -p /app/models
# Create a script to download models
COPY <<EOF /app/download-model.sh
#!/bin/bash
if [ -z "\$1" ]; then
echo "Usage: download-model.sh <model_name>"
echo "Available models:"
python3 launch.py
exit 1
fi
python3 launch.py "\$1" -skip-run -skip-script -y
EOF
RUN chmod +x /app/download-model.sh
# Create entrypoint script
COPY <<EOF /app/entrypoint.sh
#!/bin/bash
# Default values
MODEL_NAME=""
API_PORT=9999
NTHREADS=4
MAX_SEQ_LEN=4096
WORKERS=""
BUFFER_FLOAT_TYPE="q80"
# Parse command line arguments
while [[ \$# -gt 0 ]]; do
case \$1 in
--model)
MODEL_NAME="\$2"
shift 2
;;
--port)
API_PORT="\$2"
shift 2
;;
--nthreads)
NTHREADS="\$2"
shift 2
;;
--max-seq-len)
MAX_SEQ_LEN="\$2"
shift 2
;;
--workers)
shift
WORKERS="\$@"
break
;;
--buffer-float-type)
BUFFER_FLOAT_TYPE="\$2"
shift 2
;;
--download)
MODEL_NAME="\$2"
echo "Downloading model: \$MODEL_NAME"
/app/download-model.sh "\$MODEL_NAME"
exit 0
;;
--help)
echo "Usage: docker run distributed-llama-controller [OPTIONS]"
echo ""
echo "Options:"
echo " --download <model> Download a model and exit"
echo " --model <model> Model name to use"
echo " --port <port> API server port (default: 9999)"
echo " --nthreads <n> Number of threads (default: 4)"
echo " --max-seq-len <n> Maximum sequence length (default: 4096)"
echo " --buffer-float-type <type> Buffer float type (default: q80)"
echo " --workers <workers> Space-separated list of worker addresses (e.g., 10.0.0.2:9999 10.0.0.3:9999)"
echo ""
echo "Examples:"
echo " # Download a model"
echo " docker run -v ./models:/app/models distributed-llama-controller --download llama3_2_3b_instruct_q40"
echo ""
echo " # Run API server with workers"
echo " docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \\"
echo " --model llama3_2_3b_instruct_q40 --workers 10.0.0.2:9999 10.0.0.3:9999"
exit 0
;;
*)
echo "Unknown option: \$1"
exit 1
;;
esac
done
if [ -z "\$MODEL_NAME" ]; then
echo "Error: --model is required"
echo "Use --help for usage information"
exit 1
fi
MODEL_PATH="/app/models/\$MODEL_NAME/dllama_model_\$MODEL_NAME.m"
TOKENIZER_PATH="/app/models/\$MODEL_NAME/dllama_tokenizer_\$MODEL_NAME.t"
if [ ! -f "\$MODEL_PATH" ] || [ ! -f "\$TOKENIZER_PATH" ]; then
echo "Error: Model files not found for \$MODEL_NAME"
echo "Model path: \$MODEL_PATH"
echo "Tokenizer path: \$TOKENIZER_PATH"
echo ""
echo "Please download the model first:"
echo "docker run -v ./models:/app/models distributed-llama-controller --download \$MODEL_NAME"
exit 1
fi
# Build the command
CMD="./dllama-api --port \$API_PORT --model \$MODEL_PATH --tokenizer \$TOKENIZER_PATH --buffer-float-type \$BUFFER_FLOAT_TYPE --nthreads \$NTHREADS --max-seq-len \$MAX_SEQ_LEN"
if [ ! -z "\$WORKERS" ]; then
CMD="\$CMD --workers \$WORKERS"
fi
echo "Starting API server with command:"
echo "\$CMD"
echo ""
exec \$CMD
EOF
RUN chmod +x /app/entrypoint.sh
# Expose the default API port
EXPOSE 9999
# Use the entrypoint script
ENTRYPOINT ["/app/entrypoint.sh"]

75
Dockerfile.worker Normal file
View File

@@ -0,0 +1,75 @@
# Dockerfile for Distributed Llama Worker (Raspberry Pi)
# This variant runs as a worker node and connects to a controller
FROM arm64v8/debian:bookworm-slim
# Install dependencies
RUN apt-get update && apt-get install -y \
build-essential \
g++ \
make \
&& rm -rf /var/lib/apt/lists/*
# Set working directory
WORKDIR /app
# Copy source code
COPY src/ ./src/
COPY Makefile ./
# Build only the worker application
RUN make dllama
# Create entrypoint script
COPY <<EOF /app/entrypoint.sh
#!/bin/bash
# Default values
PORT=9999
NTHREADS=4
# Parse command line arguments
while [[ \$# -gt 0 ]]; do
case \$1 in
--port)
PORT="\$2"
shift 2
;;
--nthreads)
NTHREADS="\$2"
shift 2
;;
--help)
echo "Usage: docker run distributed-llama-worker [OPTIONS]"
echo ""
echo "Options:"
echo " --port <port> Worker port (default: 9999)"
echo " --nthreads <n> Number of threads (default: 4)"
echo ""
echo "Example:"
echo " docker run -p 9999:9999 distributed-llama-worker --port 9999 --nthreads 4"
exit 0
;;
*)
echo "Unknown option: \$1"
exit 1
;;
esac
done
# Build the command
CMD="./dllama worker --port \$PORT --nthreads \$NTHREADS"
echo "Starting worker with command:"
echo "\$CMD"
echo ""
exec \$CMD
EOF
RUN chmod +x /app/entrypoint.sh
# Expose the default worker port
EXPOSE 9999
# Use the entrypoint script
ENTRYPOINT ["/app/entrypoint.sh"]

9
LICENSE Normal file
View File

@@ -0,0 +1,9 @@
The MIT License (MIT)
Copyright (c) 2024 Bartłomiej Tadych (b4rtaz)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

90
Makefile Normal file
View File

@@ -0,0 +1,90 @@
CXX = g++
CXXFLAGS = -std=c++11 -Werror -Wformat -Werror=format-security
ifndef TERMUX_VERSION
CXXFLAGS += -march=native -mtune=native
endif
ifdef DEBUG
CXXFLAGS += -g -fsanitize=address
else
CXXFLAGS += -O3
endif
ifdef WVLA
CXXFLAGS += -Wvla-extension
endif
ifdef DLLAMA_VULKAN
CGLSLC = glslc
ifeq ($(OS),Windows_NT)
LIBS += -L$(VK_SDK_PATH)\lib -lvulkan-1
CXXFLAGS += -DDLLAMA_VULKAN -I$(VK_SDK_PATH)\include
else
LIBS += -lvulkan
CXXFLAGS += -DDLLAMA_VULKAN
endif
DEPS += nn-vulkan.o
endif
ifeq ($(OS),Windows_NT)
LIBS += -lws2_32
DELETE_CMD = del /f
else
LIBS += -lpthread
DELETE_CMD = rm -fv
endif
.PHONY: clean dllama
clean:
$(DELETE_CMD) *.o dllama dllama-* socket-benchmark mmap-buffer-* *-test *.exe
# nn
nn-quants.o: src/nn/nn-quants.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
nn-core.o: src/nn/nn-core.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
nn-executor.o: src/nn/nn-executor.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
nn-network.o: src/nn/nn-network.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
llamafile-sgemm.o: src/nn/llamafile/sgemm.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
nn-cpu-ops.o: src/nn/nn-cpu-ops.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
nn-cpu.o: src/nn/nn-cpu.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
nn-cpu-test: src/nn/nn-cpu-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
nn-cpu-ops-test: src/nn/nn-cpu-ops-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu.o
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
nn-vulkan.o: src/nn/nn-vulkan.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
ifdef DLLAMA_VULKAN
VULKAN_SHADER_SRCS := $(wildcard src/nn/vulkan/*.comp)
VULKAN_SHADER_BINS := $(VULKAN_SHADER_SRCS:.comp=.spv)
DEPS += $(VULKAN_SHADER_BINS)
%.spv: %.comp
$(CGLSLC) -c $< -o $@ --target-env=vulkan1.2
nn-vulkan-test: src/nn/nn-vulkan-test.cpp nn-quants.o nn-core.o nn-executor.o nn-vulkan.o ${DEPS}
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
endif
# llm
tokenizer.o: src/tokenizer.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
llm.o: src/llm.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
app.o: src/app.cpp
$(CXX) $(CXXFLAGS) -c $^ -o $@
tokenizer-test: src/tokenizer-test.cpp nn-quants.o nn-core.o llamafile-sgemm.o nn-cpu-ops.o tokenizer.o
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
dllama: src/dllama.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS}
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
dllama-api: src/dllama-api.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS}
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)

142
README.md Normal file
View File

@@ -0,0 +1,142 @@
![Distributed Llama](.github/cover.png)
# Distributed Llama
[![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/b4rtaz/distributed-llama/.github%2Fworkflows%2Fmain.yml?style=flat-square)](https://github.com/b4rtaz/distributed-llama/actions) [![License: MIT](https://img.shields.io/github/license/mashape/apistatus.svg?style=flat-square)](/LICENSE) [![Discord](https://discordapp.com/api/guilds/1245814812353495070/widget.png?style=shield)](https://n4no.com/projects/distributedLlama/discord.php)
Connect home devices into a powerful cluster to accelerate LLM inference. More devices mean faster performance, leveraging tensor parallelism and high-speed synchronization over Ethernet.
Supports Linux, macOS, and Windows. Optimized for ARM and x86_64 AVX2 CPUs.
**How to Run**
- [💻 How to Run on Linux, MacOS or Windows](./docs/HOW_TO_RUN_LINUX_MACOS_WIN.md)
- [🍓 How to Run on Raspberry Pi](./docs/HOW_TO_RUN_RASPBERRYPI.md)
- [🧠 How to Run on GPU](./docs/HOW_TO_RUN_GPU.md)
**News**
- 16 Sep 2025 - Qwen 3 MoE models are now supported on Vulkan.
- 5 Sep 2025 - Qwen 3 MoE models are now supported on CPU.
- 3 Aug 2025 - Qwen 3 0.6B, 1.7B, 8B and 14B models are now supported.
- 23 Mar 2025 - [🌋 Experimental Vulkan support](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.13.0)
- 12 Feb 2025 - 🚧 Merged the [fundamental codebase refactor](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.12.0)
- 9 Jan 2025 - [🍎 Llama 3.3 70B on 4 x Mac Mini M4 Pro 24GB RAM](https://github.com/b4rtaz/distributed-llama/discussions/147)
### 🔥 Setup Root Node by Single Command
Python 3 and C++ compiler required. The command will download the model and the tokenizer.
| Model | Size | Command |
| --------------------------------- | -------- | ---------------------------------------------------- |
| Llama 3.1 8B Instruct Q40 | 6.32 GB | `python launch.py llama3_1_8b_instruct_q40` |
| Llama 3.1 405B Instruct Q40 | 238 GB | `python launch.py llama3_1_405b_instruct_q40`. |
| Llama 3.2 1B Instruct Q40 | 1.7 GB | `python launch.py llama3_2_1b_instruct_q40` |
| Llama 3.2 3B Instruct Q40 | 3.4 GB | `python launch.py llama3_2_3b_instruct_q40` |
| Llama 3.3 70B Instruct Q40 | 40 GB | `python launch.py llama3_3_70b_instruct_q40` |
| DeepSeek R1 Distill Llama 8B Q40 | 6.32 GB | `python launch.py deepseek_r1_distill_llama_8b_q40` |
| Qwen 3 0.6B Q40 | 0.9 GB | `python launch.py qwen3_0.6b_q40` |
| Qwen 3 1.7B Q40 | 2.2 GB | `python launch.py qwen3_1.7b_q40` |
| Qwen 3 8B Q40 | 6.7 GB | `python launch.py qwen3_8b_q40` |
| Qwen 3 14B Q40 | 10.9 GB | `python launch.py qwen3_14b_q40` |
| Qwen 3 30B A3B Q40 | 17.0 GB | `python launch.py qwen3_30b_a3b_q40` |
### 🛠️ Convert Model Manually
* [🤗 How to Convert Hugging Face Model](./docs/HOW_TO_CONVERT_HF_MODEL.md)
### 🚧 Known Limitations
* You can run Distributed Llama only on 1, 2, 4... 2^n nodes.
* The maximum number of nodes is equal to the number of KV heads in the model [#70](https://github.com/b4rtaz/distributed-llama/issues/70).
* Only the following quantizations are supported [#183](https://github.com/b4rtaz/distributed-llama/issues/183):
* `q40` model with `q80` `buffer-float-type`
* `f32` model with `f32` `buffer-float-type`
### 👷 Architecture
````
[🔀 SWITCH OR ROUTER]
| | | |
| | | |_______ 🔸 device1 (ROOT) 10.0.0.1
| | |_________ 🔹 device2 (WORKER 1) 10.0.0.2:9999
| |___________ 🔹 device3 (WORKER 2) 10.0.0.3:9999
|_____________ 🔹 device4 (WORKER 3) 10.0.0.4:9999
...
````
The project is split up into two parts:
* **🔸 Root node** - it's responsible for loading the model and weights and forward them to workers. Also, it synchronizes the state of the neural network. The root node is also a worker, it processes own slice of the neural network.
* **🔹 Worker node** - it processes own slice of the neural network. It doesn't require any configuration related to the model.
You always need the root node and you can add 2^n - 1 worker nodes to speed up the inference. The RAM usage of the neural network is split up across all nodes. The root node requires a bit more RAM than worker nodes.
### 🎹 Commands
* `dllama inference` - run the inference with a simple benchmark,
* `dllama chat` - run the CLI chat,
* `dllama worker` - run the worker node,
* `dllama-api` - run the API server.
<details>
<summary>🎹 Supported Arguments</summary>
<br />Inference, Chat, API
| Argument | Description | Example |
| ---------------------------- | ---------------------------------------------------------------- | -------------------------------------- |
| `--model <path>` | Path to model. | `dllama_model_meta-llama-3-8b_q40.m` |
| `--tokenizer <path>` | Tokenizer to model. | `dllama_tokenizer_llama3.t` |
| `--buffer-float-type <type>` | Float precision of synchronization. | `q80` |
| `--workers <workers>` | Addresses of workers (ip:port), separated by space. | `10.0.0.1:9999 10.0.0.2:9999` |
| `--max-seq-len <n>` | The maximum sequence length, it helps to reduce the RAM usage. | `4096` |
Inference, Chat, Worker, API
| Argument | Description | Example |
| ---------------------------- | --------------------------------------------------------------------- | ----------------------------------- |
| `--nthreads <n>` | Amount of threads. Don't set a higher value than number of CPU cores. | `4` |
Worker, API
| Argument | Description | Example |
| ---------------------------- | --------------------------------- | ----------------- |
| `--port <port>` | Binding port. | `9999` |
Inference
| Argument | Description | Example |
| ---------------------------- | ------------------------------ | ------------------ |
| `--prompt <prompt>` | Initial prompt. | `"Hello World"` |
| `--steps <steps>` | Number of tokens to generate. | `256` |
</details>
## 📊 Measurements
Please check the [discussions](https://github.com/b4rtaz/distributed-llama/discussions) section, where many measurements were published on different configurations.
## ✋ Contribution
Feel free to contribute to this project. For small changes, simply create a new merge request. For larger changes, please create an issue to discuss your plans. Please follow these guidelines when contributing:
* Make only minimal changes and avoid modifying files that are not necessary.
* Ensure the code is compatible across all supported systems and CPUs.
* This repository is maintained in English.
## 💡 License
This project is released under the MIT license.
## 📖 Citation
```
@misc{dllama,
author = {Bartłomiej Tadych},
title = {Distributed Llama},
year = {2024},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/b4rtaz/distributed-llama}},
commit = {7eb77ca93ec0d502e28d36b6fb20039b449cbea4}
}
```

4
converter/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
*.t
*.m
*.bin
*/

265
converter/convert-hf.py Normal file
View File

@@ -0,0 +1,265 @@
import gc
import json
import sys
import os
from writer import parseFloatType, writeTensor, writeHeader, FloatType
from safetensors import safe_open
class ArchType:
LLAMA = 0xABCD00
QWEN3 = 0xABCD01
QWEN3_MOE = 0xABCD02
def permute(tensor, nHeads: int, nKvHeads: int):
if nHeads != nKvHeads:
nHeads = nKvHeads
return (tensor.reshape(nHeads, 2, tensor.shape[0] // nHeads // 2, *tensor.shape[1:]).swapaxes(1, 2).reshape(tensor.shape))
class Processor:
def __init__(self, config):
self.config = config
self.archType = config['arch_type']
self.currentModelIndex = None
self.currentModel = None
self.currentModelKeys = None
self.layerMap = {}
self.plan = []
def __unloadModel(self):
if self.currentModel:
del self.currentModel
self.currentModel = None
gc.collect()
self.currentModelIndex = None
def __loadModel(self, index: int):
if (self.currentModelIndex == index):
return
self.__unloadModel()
filePath = self.config['files'][index]
fileName = os.path.basename(filePath)
print(f'💿 Loading file {fileName}...')
self.currentModel = safe_open(filePath, framework='pt', device='cpu')
self.currentModelKeys = list(self.currentModel.keys())
for key in self.currentModelKeys:
self.layerMap[key] = index
print(f'Found {len(self.currentModelKeys)} layers')
self.currentModelIndex = index
def __transformQ(self, tensor):
if self.archType == ArchType.LLAMA:
return permute(tensor, self.config['n_heads'], self.config['n_heads'])
return tensor
def __transformK(self, tensor):
if self.archType == ArchType.LLAMA:
return permute(tensor, self.config['n_heads'], self.config['n_kv_heads'])
return tensor
def __preparePlan(self):
wt = self.config['weights_float_type']
p = self.plan
p.append([FloatType.F32,
'model.embed_tokens.weight'])
for l in range(0, self.config['n_layers']):
p.append([wt, self.__transformQ,
f'model.layers.{l}.self_attn.q_proj.weight'])
p.append([wt, self.__transformK,
f'model.layers.{l}.self_attn.k_proj.weight'])
p.append([wt,
f'model.layers.{l}.self_attn.v_proj.weight'])
p.append([wt,
f'model.layers.{l}.self_attn.o_proj.weight'])
if (self.config['n_experts'] > 0):
p.append([FloatType.F32, f'model.layers.{l}.mlp.gate.weight'])
for e in range(self.config['n_experts']):
p.append([wt,
f'model.layers.{l}.mlp.experts.{e}.gate_proj.weight'])
p.append([wt,
f'model.layers.{l}.mlp.experts.{e}.down_proj.weight'])
p.append([wt,
f'model.layers.{l}.mlp.experts.{e}.up_proj.weight'])
else:
p.append([wt,
f'model.layers.{l}.mlp.gate_proj.weight'])
p.append([wt,
f'model.layers.{l}.mlp.down_proj.weight'])
p.append([wt,
f'model.layers.{l}.mlp.up_proj.weight'])
if (self.archType == ArchType.QWEN3 or self.archType == ArchType.QWEN3_MOE):
p.append([FloatType.F32,
f'model.layers.{l}.self_attn.q_norm.weight'])
p.append([FloatType.F32,
f'model.layers.{l}.self_attn.k_norm.weight'])
p.append([FloatType.F32,
f'model.layers.{l}.input_layernorm.weight'])
p.append([FloatType.F32,
f'model.layers.{l}.post_attention_layernorm.weight'])
p.append([FloatType.F32,
'model.norm.weight'])
p.append([wt,
'lm_head.weight', 'model.embed_tokens.weight'])
def write(self, outputFile: str):
self.__preparePlan()
# Loading the last model file to get the layer names
self.__loadModel(len(self.config['files']) - 1)
self.__unloadModel()
for planItem in self.plan:
lookup = planItem[1:]
transform = None
if (callable(lookup[0])):
transform = lookup[0]
lookup = lookup[1:]
if (self.currentModelIndex == None):
modelIndex = 0
else:
modelIndex = None
for layerName in lookup:
if (layerName in self.layerMap):
modelIndex = self.layerMap[layerName]
break
if (modelIndex is None):
modelIndex = self.currentModelIndex + 1
self.__loadModel(modelIndex)
tensor = None
for layerName in lookup:
if (layerName in self.currentModelKeys):
tensor = self.currentModel.get_tensor(layerName)
break
if tensor is None:
raise Exception(f'Layer {lookup[0]} not found')
print(f'🔶 Writing tensor {layerName} {tensor.shape}...')
floatType = planItem[0]
if (transform):
tensor = transform(tensor)
writeTensor(outputFile, tensor, floatType)
def parseArchType(type: str):
archType = {
'llama': ArchType.LLAMA,
'mistral': ArchType.LLAMA,
'qwen3': ArchType.QWEN3,
'qwen3_moe': ArchType.QWEN3_MOE,
}.get(type)
if (archType is None):
raise Exception(f'Unsupported arch type: {type}')
return archType
def parseHiddenAct(act: str):
hiddenAct = {
'gelu': 0,
'silu': 1
}.get(act)
if (hiddenAct is None):
raise Exception(f'Unsupported hidden act: {act}')
return hiddenAct
def parseRopeType(rt: str):
ropeType = {
'llama3': 2, # LLAMA3_1
}.get(rt)
if (ropeType is None):
raise Exception(f'Unsupported rope type: {ropeType}')
return ropeType
def parseRmsNormEpsilon(epsilon: float):
if (epsilon == 1e-05):
return 5
elif (epsilon == 1e-06):
return 6
raise Exception(f'Unsupported epsilon: {epsilon}')
def loadConfig(folderPath: str, weightsFloatType: int):
allFiles = os.listdir(folderPath)
allFiles.sort()
with open(os.path.join(folderPath, 'config.json')) as fc:
config = json.load(fc)
files = []
for fileName in allFiles:
if fileName.endswith('.safetensors') and not fileName.startswith('.'):
files.append(os.path.join(folderPath, fileName))
if (len(files) == 0):
raise Exception('Not found any model file')
result = {
'version': 0,
'arch_type': parseArchType(config['model_type']),
'hidden_act': parseHiddenAct(config['hidden_act']),
'dim': config['hidden_size'],
'hidden_dim': config['intermediate_size'],
'n_layers': config['num_hidden_layers'],
'n_heads': config['num_attention_heads'],
'n_kv_heads': config['num_key_value_heads'],
'weights_float_type': weightsFloatType,
'max_seq_len': config['max_position_embeddings'],
'vocab_size': config['vocab_size'],
'files': files,
}
nExperts = config.get('num_experts')
nActiveExperts = config.get('num_experts_per_tok')
result['n_experts'] = int(nExperts) if nExperts is not None else 0
result['n_active_experts'] = int(nActiveExperts) if nActiveExperts is not None else 0
ropeTheta = config.get('rope_theta')
if (ropeTheta is not None):
result['rope_theta'] = int(ropeTheta)
ropeScaling = config.get('rope_scaling')
if (ropeScaling is not None):
result['rope_scaling_factor'] = int(ropeScaling['factor'])
result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor'])
result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor'])
result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings'])
result['rope_type'] = parseRopeType(ropeScaling['rope_type'])
headDim = config.get('head_dim')
if (headDim is not None):
result['head_dim'] = headDim
rmsNormEps = config.get('rms_norm_eps')
if (rmsNormEps is not None):
result['norm_epsilon'] = parseRmsNormEpsilon(rmsNormEps)
moeHiddenDim = config.get('moe_intermediate_size')
if (moeHiddenDim is not None):
result['moe_hidden_dim'] = int(moeHiddenDim)
return result
def printUsage():
print('Usage: python convert-hf.py <sourceFolderPath> <weightsFloatType> <name>')
print()
print('Options:')
print(' <sourceFolderPath> The path to the folder containing the model files')
print(' <weightsFloatType> The float type of the weights (e.g. "q40")')
print(' <name> The name of the model (e.g. "llama3")')
if __name__ == '__main__':
if (len(sys.argv) < 4):
printUsage()
exit(1)
sourceFolderPath = sys.argv[1]
weightsFloatType = parseFloatType(sys.argv[2])
name = sys.argv[3]
outputFileName = f'dllama_model_{name}_{sys.argv[2]}.m'
print(f'Output file: {outputFileName}')
config = loadConfig(sourceFolderPath, weightsFloatType)
with open(outputFileName, 'wb') as outputFile:
writeHeader(outputFile, config)
processor = Processor(config)
processor.write(outputFile)
print(f'{outputFileName} created successfully')

121
converter/convert-llama.py Normal file
View File

@@ -0,0 +1,121 @@
import os
import sys
import json
import torch
import math
import numpy as np
from writer import writeTensor, writeHeader, parseFloatType, strFloatType, FloatType
from pathlib import Path
LAYER_CHUNK_SIZE = 48
def convert(modelPath, outputPath, targetFloatType):
paramsPath = os.path.join(modelPath, 'params.json')
with open(paramsPath) as f:
params = json.load(f)
if (params['vocab_size'] < 1):
raise Exception('vocab_size is invalid, please update params.json file')
if (params.get('max_seq_len') is None):
raise Exception('max_seq_len is required, please update params.json file')
params['n_kv_heads'] = params.get('n_kv_heads') or params['n_heads']
params['head_size'] = params['dim'] / params['n_heads']
params['arch_type'] = 0xABCD00
params['n_experts'] = 0
params['n_active_experts'] = 0
params['weights_float_type'] = targetFloatType
if ('rope_theta' in params):
params['rope_theta'] = int(params['rope_theta'])
modelPaths = sorted(list(Path(modelPath).glob('consolidated.*.pth')))
nSlices = len(modelPaths)
layers = []
layers.append('tok_embeddings.weight')
for layerIndex in range(0, params['n_layers']):
layers.append(f'layers.{layerIndex}.attention.wq.weight')
layers.append(f'layers.{layerIndex}.attention.wk.weight')
layers.append(f'layers.{layerIndex}.attention.wv.weight')
layers.append(f'layers.{layerIndex}.attention.wo.weight')
layers.append(f'layers.{layerIndex}.feed_forward.w1.weight')
layers.append(f'layers.{layerIndex}.feed_forward.w2.weight')
layers.append(f'layers.{layerIndex}.feed_forward.w3.weight')
layers.append(f'layers.{layerIndex}.attention_norm.weight')
layers.append(f'layers.{layerIndex}.ffn_norm.weight')
layers.append('norm.weight')
layers.append('output.weight')
isHeaderWrote = False
outFile = open(outputPath, 'wb')
nChunks = math.ceil(len(layers) / LAYER_CHUNK_SIZE)
for chunkIndex in range(0, nChunks):
chunkLayerNames = layers[LAYER_CHUNK_SIZE * chunkIndex:LAYER_CHUNK_SIZE * (chunkIndex + 1)]
models = {}
for layerName in chunkLayerNames:
models[layerName] = []
print(f'💿 Chunking model {chunkIndex + 1}/{nChunks}...')
for modelPath in modelPaths:
model = torch.load(modelPath, map_location='cpu')
for modelKey in model:
if (modelKey in chunkLayerNames):
models[modelKey].append(model[modelKey])
if not isHeaderWrote:
params['hidden_dim'] = model['layers.0.feed_forward.w1.weight'].shape[0] * nSlices
writeHeader(outFile, params)
isHeaderWrote = True
del model
for layerName in chunkLayerNames:
if layerName == 'rope.freqs':
continue
isAxis1 = (
layerName == 'tok_embeddings.weight' or
layerName.endswith('.attention.wo.weight') or
layerName.endswith('.feed_forward.w2.weight')
)
isAlwaysF32 = (
layerName == 'tok_embeddings.weight' or
layerName.endswith('.attention_norm.weight') or
layerName.endswith('.ffn_norm.weight') or
layerName == 'norm.weight'
)
floatType = FloatType.F32 if isAlwaysF32 else targetFloatType
tensors = models[layerName]
if len(tensors) == 1 or len(tensors[0].shape) == 1:
tensor = tensors[0]
else:
tensor = torch.cat(tensors, dim=(1 if isAxis1 else 0))
print(f'🔶 Exporting {layerName} {tensor.shape}...')
writeTensor(outFile, tensor, floatType)
del models
outFile.close()
def usage():
print('Usage: python convert-llama.py <modelPath> <targetFloatType>')
exit(1)
if __name__ == '__main__':
if (len(sys.argv) < 3):
usage()
modelPath = sys.argv[1]
targetFloatType = parseFloatType(sys.argv[2])
targetFloatTypeStr = strFloatType(targetFloatType)
modelName = os.path.basename(modelPath)
outputFileName = f'dllama_model_{modelName.lower()}_{targetFloatTypeStr}.m'
print(f'Model name: {modelName}')
print(f'Target float type: {targetFloatTypeStr}')
print(f'Target file: {outputFileName}')
convert(modelPath, outputFileName, targetFloatType)
print('Done!')

View File

@@ -0,0 +1,137 @@
import sys
import json
import os
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizerFast
writer = __import__('tokenizer-writer')
def openJson(path):
with open(path, 'r', encoding='utf-8') as file:
return json.load(file)
def unicodeToBytes():
# https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2 ** 8):
if b not in bs:
bs.append(b)
cs.append(2 ** 8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(cs, bs))
class TokensResolver:
def __init__(self, dirPath, tokenizerConfig):
self.dirPath = dirPath
self.tokenizerConfig = tokenizerConfig
self.bosId = None
self.eosIds = None
self.tokens = []
self.scores = []
def resolvePreTrainedTokenizerFast(self):
utb = unicodeToBytes()
tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(self.dirPath, 'tokenizer.json'))
vocabLen = len(tokenizer.get_vocab())
for i in range(vocabLen):
tokenChars = list(tokenizer.convert_ids_to_tokens([i])[0])
tokenBytes = []
for chr in tokenChars:
if (chr in utb):
tokenBytes.append(utb[chr])
else:
tokenBytes += list(chr.encode('utf-8'))
self.tokens.append(bytes(tokenBytes))
self.scores.append(-float(i))
self.bosId = tokenizer.bos_token_id
if (tokenizer.eos_token_id):
self.eosIds = [tokenizer.eos_token_id]
if (self.bosId is None or self.eosId is None):
config = openJson(os.path.join(self.dirPath, 'config.json'))
if (self.bosId is None):
self.bosId = config['bos_token_id']
if (self.eosIds is None):
self.eosIds = config['eos_token_id']
if isinstance(self.eosIds, list):
self.eosIds = self.eosIds
else:
self.eosIds = [self.eosIds]
def resolveLlamaTokenizer(self):
modelPath = os.path.join(self.dirPath, 'tokenizer.model')
processor = SentencePieceProcessor(model_file=modelPath)
assert processor.vocab_size() == processor.get_piece_size()
self.bosId = processor.bos_id()
self.eosIds = [processor.eos_id()]
vocabSize = processor.vocab_size()
for i in range(vocabSize):
t = processor.id_to_piece(i)
s = processor.get_score(i)
t = t.replace('', ' ') # sentencepiece uses this character as whitespace
# Check for byte characters
if len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
# For example, "<0x0A>"" is a newline character
b = bytearray.fromhex(t[3:-1])
else:
b = t.encode('utf-8')
self.tokens.append(b)
self.scores.append(s)
def resolve(self):
cls = self.tokenizerConfig['tokenizer_class']
if (cls == 'PreTrainedTokenizerFast' or
cls == 'LlamaTokenizerFast' or
cls == 'Qwen2Tokenizer'):
return self.resolvePreTrainedTokenizerFast()
if (cls == 'LlamaTokenizer'):
return self.resolveLlamaTokenizer()
raise Exception(f'Tokenizer {cls} is not supported')
def printUsage():
print('Usage: python convert-tokenizer-hf.py <tokenizerFolderPath> <name>')
print()
print('Options:')
print(' <tokenizerFolderPath> The path to the folder with tokenizer_config.json')
print(' <name> The name of the tokenizer (e.g. "llama3")')
if __name__ == '__main__':
if (len(sys.argv) < 2):
printUsage()
exit(1)
dirPath = sys.argv[1]
name = sys.argv[2]
tokenizerConfig = openJson(os.path.join(dirPath, 'tokenizer_config.json'))
resolver = TokensResolver(dirPath, tokenizerConfig)
resolver.resolve()
if (resolver.bosId is None or resolver.eosIds is None):
raise Exception('Cannot resolve bosId or eosIds')
print(f'bosId: {resolver.bosId} ({resolver.tokens[resolver.bosId]})')
for eosId in resolver.eosIds:
print(f'eosId: {eosId} ({resolver.tokens[eosId]})')
chatTemplate = None
if ('chat_template' in tokenizerConfig):
chatTemplate = tokenizerConfig['chat_template'].encode('utf-8')
addBos = True
if ('add_bos_token' in tokenizerConfig):
addBos = tokenizerConfig['add_bos_token']
outputFileName = f'dllama_tokenizer_{name}.t'
with open(outputFileName, 'wb') as outputFile:
writer.writeTokenizer(
outputFile,
resolver.tokens,
resolver.scores,
chatTemplate,
resolver.bosId,
addBos,
resolver.eosIds)
print(f'✅ Created {outputFileName}')

View File

@@ -0,0 +1,44 @@
import sys
import os
from sentencepiece import SentencePieceProcessor
writer = __import__('tokenizer-writer')
chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
def printUsage():
print('Usage: python convert-tokenizer-llama2.py <llama2FolderPath>')
print()
print('Options:')
print(' <llama2FolderPath> The path to the folder with llama2 folder path')
if __name__ == '__main__':
if (len(sys.argv) < 2):
printUsage()
exit(1)
dirPath = sys.argv[1]
modelPath = os.path.join(dirPath, 'tokenizer.model')
processor = SentencePieceProcessor(model_file=modelPath)
vocabSize = processor.vocab_size()
tokens = []
scores = []
for i in range(vocabSize):
t = processor.id_to_piece(i)
s = processor.get_score(i)
t = t.replace('', ' ') # sentencepiece uses this character as whitespace
b = t.encode('utf-8')
tokens.append(b)
scores.append(s)
outputFileName = 'dllama_tokenizer_llama2.t'
with open(outputFileName, 'wb') as outputFile:
writer.writeTokenizer(
outputFile,
tokens,
scores,
chatTemplate.encode('utf-8'),
processor.bos_id(),
[processor.eos_id()])
print(f'✅ Created {outputFileName}')

View File

@@ -0,0 +1,78 @@
import sys
import base64
writer = __import__('tokenizer-writer')
# Format of input file:
# ```
# IQ== 0
# Ig== 1
# Iw== 2
# ...
# ```
nSpecialTokens = 256
specialTokens = [
'<|begin_of_text|>',
'<|end_of_text|>',
'<|reserved_special_token_0|>',
'<|reserved_special_token_1|>',
'<|reserved_special_token_2|>',
'<|reserved_special_token_3|>',
'<|start_header_id|>',
'<|end_header_id|>',
'<|reserved_special_token_4|>',
'<|eot_id|>',
] + [
f'<|reserved_special_token_{i}|>'
for i in range(5, nSpecialTokens - 5)
]
bosId = 128000
eosId = 128001
chatEosId = 128009
chatTemplate = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
def printUsage():
print('Usage: python convert-tokenizer-llama3.py <tokenizerPath>')
print()
print('Options:')
print(' <tokenizerPath> The path to the Llama 3 tokenizer model (tokenizer.model)')
if __name__ == '__main__':
if (len(sys.argv) < 2):
printUsage()
exit(1)
modelPath = sys.argv[1]
outputFileName = 'dllama_tokenizer_llama3.t'
with open(modelPath, 'r') as inputFile:
with open(outputFileName, 'wb') as outputFile:
inputLines = inputFile.readlines()
nLines = len(inputLines)
tokens = []
scores = []
for line in inputLines:
s = line.split(' ')
bytes = base64.b64decode(s[0])
score = -float(s[1])
tokens.append(bytes)
scores.append(score)
specialTokenIndex = nLines
for token in specialTokens:
bytes = token.encode('utf-8')
score = -float(specialTokenIndex)
tokens.append(bytes)
scores.append(score)
specialTokenIndex += 1
writer.writeTokenizer(
outputFile,
tokens,
scores,
chatTemplate.encode('utf-8'),
bosId,
[eosId, chatEosId])
print(f'✅ Created {outputFileName}')

View File

@@ -0,0 +1,5 @@
python>=3.9
numpy==1.23.5
pytorch==2.0.1
safetensors==0.4.2
sentencepiece==0.1.99

View File

@@ -0,0 +1,57 @@
import struct
def writeTokenizer(file, tokens, scores, chatTemplate, bosId, addBos, eosTokens):
headerKeys = {
'version': 0,
'vocab_size': 1,
'max_token_length': 2,
'bos_id': 3,
'chat_template': 7,
'n_eos_tokens': 9,
'add_bos': 10,
}
header = struct.pack('i', 0x567124)
nTokens = len(tokens)
maxTokenLength = max(len(t) for t in tokens)
params = {}
params['bos_id'] = bosId
params['version'] = 1
params['vocab_size'] = nTokens
params['max_token_length'] = maxTokenLength
if (chatTemplate):
params['chat_template'] = len(chatTemplate)
params['n_eos_tokens'] = len(eosTokens)
params['add_bos'] = 1 if addBos else 0
data = b''
for key in params:
value = params[key]
if value is None:
continue
if key in headerKeys:
data += struct.pack('ii', headerKeys[key], params[key])
else:
print(f'Unknown header key: {key}')
print('⭐ Params:')
print(params)
if (chatTemplate):
print('⭐ Chat template:')
print(chatTemplate)
header += struct.pack('i', len(header) * 2 + len(data))
file.write(header)
file.write(data)
if chatTemplate:
file.write(chatTemplate)
for eosToken in eosTokens:
file.write(struct.pack('i', eosToken))
for i in range(0, nTokens):
size = len(tokens[i])
assert(size > 0)
file.write(struct.pack('fI', scores[i], size))
file.write(tokens[i])

35
converter/writer-test.py Normal file
View File

@@ -0,0 +1,35 @@
import sys
import time
import torch
from writer import writeQuantizedQ40Tensor
TEMP_FILE_NAME = 'writer-test.temp'
def readBase64FromFile(path):
with open(path, 'rb') as file:
return file.read().hex()
def testWriteQuantizedQ40Tensor():
EXPECTED_OUTPUT = '7e346345a692b89665b2c5790537876e598aaa366d988876a898b8d788a98868ce660c66f6b3a88cba5ce9a871987ba9cc5bcaaa760c1eb556a4455b747b6b9504968828ef2a8d7c1db5c6be3764799e66db6d8e76463126a30e4333cad7a4f645947c6cf97f9de086d468c8d535a6ba7dc799d3d0c657bab6799468cad8bb349eb7d7635c7c798998696bb38e4085a9eb34444ba96a7f8ba7b2b42d746a96cf9660aeb4499d8708ad5c7b9a7558947645f3bbb6b0346a656887ad9a86059baac5c596ab781c703569bb8a4356a4bd58cb78736ba09759bb0e34a6274e827b957d7a67dfa86846955660d234b6d9d78a378094a8a8708a7a774ae92f8a36b8c999a9b77a7d958a69747c807963941235379886d69a7a8767b3a6a4ac71999760'
torch.manual_seed(seed=1)
tensor = torch.randn(32, 16)
with open(TEMP_FILE_NAME, 'wb') as file:
writeQuantizedQ40Tensor(file, tensor)
contentBase64 = readBase64FromFile(TEMP_FILE_NAME)
assert contentBase64 == EXPECTED_OUTPUT, f'Received: {contentBase64}'
print('✅ writeQuantizedQ40Tensor')
def runWriteQuantizedQ40TensorBenchmark():
tensor = torch.randn(8192, 4096)
t0 = time.time()
with open(TEMP_FILE_NAME, 'wb') as file:
writeQuantizedQ40Tensor(file, tensor)
t1 = time.time()
print(f'🕐 writeQuantizedQ40Tensor: {t1 - t0:.4f}s')
if __name__ == '__main__':
testWriteQuantizedQ40Tensor()
runWriteQuantizedQ40TensorBenchmark()

148
converter/writer.py Normal file
View File

@@ -0,0 +1,148 @@
import struct
import torch
import time
import numpy as np
class FloatType:
F32 = 0
F16 = 1
Q40 = 2
Q80 = 3
floatTypeMap = {
'f32': FloatType.F32,
'f16': FloatType.F16,
'q40': FloatType.Q40,
'q80': FloatType.Q80,
}
floatTypeNames = list(floatTypeMap.keys())
def parseFloatType(type):
floatType = floatTypeMap.get(type)
if floatType is not None:
return floatType
raise Exception(f'{type} is not supported')
def strFloatType(type):
return floatTypeNames[type]
def writeQuantizedQ40Tensor(file, x):
x = x.to(torch.float32).numpy().astype(np.float32)
blockSize = 32
blockHalfSize = blockSize // 2
assert(x.shape[0] % blockSize == 0)
groups = x.reshape(-1, blockSize)
gmax = np.max(groups, axis=1)
gmin = np.min(groups, axis=1)
deltas = np.divide(np.where(-gmin > gmax, gmin, gmax), -8)
deltas16 = deltas.astype(np.float16)
ids = np.where(deltas != 0, 1.0 / deltas, 0)
groups = np.add(groups * ids[:, np.newaxis], 8.5)
groups = np.clip(groups, 0, 15).astype(int)
gLow = groups[:, :blockHalfSize] & 0xF
gHigh = (groups[:, blockHalfSize:] & 0xF) << 4
gCombined = gLow | gHigh
nBytes = 0
for groupIndex in range(0, len(groups)):
delta16 = deltas16[groupIndex]
buffer = struct.pack(f'e{blockHalfSize}B', delta16, *gCombined[groupIndex])
file.write(buffer)
nBytes += len(buffer)
return nBytes
def writeQuantizedQ80Tensor(file, x):
x = x.to(torch.float32).numpy().astype(np.float32)
blockSize = 32
assert(x.shape[0] % blockSize == 0)
groups = x.reshape(-1, blockSize)
gmax = np.max(groups, axis=1)
gmin = np.min(groups, axis=1)
gabsMax = np.where(-gmin > gmax, -gmin, gmax)
deltas = gabsMax / ((1 << 7) - 1)
deltas16 = deltas.astype(np.float16)
ids = np.where(deltas != 0, 1.0 / deltas, 0)
groups = groups * ids[:, np.newaxis]
groups8 = np.round(groups).astype(np.int8)
nBytes = 0
for groupIndex in range(0, len(groups)):
buffer = struct.pack(f'e{blockSize}b', deltas16[groupIndex], *groups8[groupIndex])
file.write(buffer)
nBytes += len(buffer)
return nBytes
def writeF32Tensor(file, d):
chunkSize = 10000
nBytes = 0
for i in range(0, len(d), chunkSize):
chunk = d[i:i+chunkSize].to(torch.float32).numpy().astype(np.float32)
b = struct.pack(f'{len(chunk)}f', *chunk)
nBytes += len(b)
file.write(b)
return nBytes
def writeF16Tensor(file, d):
d = d.to(torch.float16).numpy().astype(np.float16)
b = struct.pack(f'{len(d)}e', *d)
file.write(b)
return len(b)
def writeTensor(file, tensor, floatType):
d = tensor.detach().cpu().view(-1)
t0 = time.time()
nBytes = 0
if (floatType == FloatType.F16):
nBytes = writeF16Tensor(file, d)
elif (floatType == FloatType.F32):
nBytes = writeF32Tensor(file, d)
elif (floatType == FloatType.Q40):
nBytes = writeQuantizedQ40Tensor(file, d)
elif (floatType == FloatType.Q80):
nBytes = writeQuantizedQ80Tensor(file, d)
else:
raise Exception(f'Unknown float type')
t1 = time.time()
print(f'Saved {strFloatType(floatType)} tensor in {t1 - t0:.2f}s, {nBytes} bytes')
def writeHeader(file, params):
headerKeys = {
'version': 0,
'arch_type': 1,
'dim': 2,
'hidden_dim': 3,
'n_layers': 4,
'n_heads': 5,
'n_kv_heads': 6,
'n_experts': 7,
'n_active_experts': 8,
'vocab_size': 9,
'max_seq_len': 10,
'hidden_act': 11,
'rope_theta': 12,
'weights_float_type': 13,
'rope_scaling_factor': 14,
'rope_scaling_low_freq_factor': 15,
'rope_scaling_high_freq_factory': 16,
'rope_scaling_orig_max_seq_len': 17,
'rope_type': 18,
'head_dim': 19,
'norm_epsilon': 20,
'moe_hidden_dim': 21,
}
header = struct.pack('i', 0xA00ABCD)
data = b''
for key in params:
if key in headerKeys:
data += struct.pack('ii', headerKeys[key], params[key])
else:
print(f'Warning: Unknown header key: {key}')
header += struct.pack('i', len(header) * 2 + len(data))
file.write(header)
file.write(data)
for key in params:
print(f'🎓 {key}: {params[key]}')
print()

81
docker-compose.yml Normal file
View File

@@ -0,0 +1,81 @@
version: '3.8'
services:
# Controller service - downloads models and runs API
controller:
build:
context: .
dockerfile: Dockerfile.controller
ports:
- "9999:9999"
volumes:
- ./models:/app/models
networks:
distributed-llama:
ipv4_address: 172.20.0.10
environment:
- MODEL_NAME=${MODEL_NAME:-llama3_2_3b_instruct_q40}
- NTHREADS=${CONTROLLER_NTHREADS:-4}
- MAX_SEQ_LEN=${MAX_SEQ_LEN:-4096}
- BUFFER_FLOAT_TYPE=${BUFFER_FLOAT_TYPE:-q80}
command: >
--model ${MODEL_NAME:-llama3_2_3b_instruct_q40}
--port 9999
--nthreads ${CONTROLLER_NTHREADS:-4}
--max-seq-len ${MAX_SEQ_LEN:-4096}
--buffer-float-type ${BUFFER_FLOAT_TYPE:-q80}
--workers 172.20.0.11:9999 172.20.0.12:9999 172.20.0.13:9999
depends_on:
- worker1
- worker2
- worker3
# Worker services
worker1:
build:
context: .
dockerfile: Dockerfile.worker
networks:
distributed-llama:
ipv4_address: 172.20.0.11
environment:
- NTHREADS=${WORKER_NTHREADS:-4}
command: >
--port 9999
--nthreads ${WORKER_NTHREADS:-4}
worker2:
build:
context: .
dockerfile: Dockerfile.worker
networks:
distributed-llama:
ipv4_address: 172.20.0.12
environment:
- NTHREADS=${WORKER_NTHREADS:-4}
command: >
--port 9999
--nthreads ${WORKER_NTHREADS:-4}
worker3:
build:
context: .
dockerfile: Dockerfile.worker
networks:
distributed-llama:
ipv4_address: 172.20.0.13
environment:
- NTHREADS=${WORKER_NTHREADS:-4}
command: >
--port 9999
--nthreads ${WORKER_NTHREADS:-4}
networks:
distributed-llama:
driver: bridge
ipam:
config:
- subnet: 172.20.0.0/16
volumes:
models:

View File

@@ -0,0 +1,32 @@
# How to Convert 🤗 Hugging Face Model
Currently, Distributed Llama supports these Hugging Face models: `llama`, `mistral`, `qwen3` and `qwen3_moe`. You can try to convert any compatible Hugging Face model and run it with Distributed Llama.
> [!IMPORTANT]
> All converters are in the early stages of development. After conversion, the model may not work correctly.
1. Download a model, for example: [Mistral-7B-v0.3](https://huggingface.co/mistralai/Mistral-7B-v0.3/tree/main).
2. The downloaded model should contain `config.json`, `tokenizer.json`, `tokenizer_config.json` and `tokenizer.model` and safetensor files.
3. Run the converter of the model:
```sh
cd converter
python convert-hf.py path/to/hf/model q40 mistral-7b-0.3
```
4. Run the converter of the tokenizer:
```sh
python convert-tokenizer-hf.py path/to/hf/model mistral-7b-0.3
```
5. That's it! Now you can run the Distributed Llama.
```sh
./dllama inference \
--prompt "Hello world" \
--steps 64 \
--model dllama_model_mistral-7b-0.3_q40.m \
--tokenizer dllama_tokenizer_mistral-7b-0.3.t \
--buffer-float-type q80
```

34
docs/HOW_TO_RUN_GPU.md Normal file
View File

@@ -0,0 +1,34 @@
# How to Run Distributed Llama on 🧠 GPU
Distributed Llama can run on GPU devices using Vulkan API. This article describes how to build and run the project on GPU.
Before you start here, please check how to build and run Distributed Llama on CPU:
* [🍓 How to Run on Raspberry Pi](./HOW_TO_RUN_RASPBERRYPI.md)
* [💻 How to Run on Linux, MacOS or Windows](./HOW_TO_RUN_LINUX_MACOS_WIN.md)
To run on GPU, please follow these steps:
1. Install Vulkan SDK for your platform.
* Linux: please check [this article](https://vulkan.lunarg.com/doc/view/latest/linux/getting_started_ubuntu.html).
* MacOS: download SDK [here](https://vulkan.lunarg.com/sdk/home#mac).
2. Build Distributed Llama with GPU support:
```bash
DLLAMA_VULKAN=1 make dllama
DLLAMA_VULKAN=1 make dllama-api
```
3. Now `dllama` and `dllama-api` binaries supports arguments related to GPU usage.
```
--gpu-index <index> Use GPU device with given index (use `0` for first device)
```
4. You can run the root node or worker node on GPU by specifying the `--gpu-index` argument. Vulkan backend requires single thread, so you should also set `--nthreads 1`.
```bash
./dllama inference ... --nthreads 1 --gpu-index 0
./dllama chat ... --nthreads 1 --gpu-index 0
./dllama worker ... --nthreads 1 --gpu-index 0
./dllama-api ... --nthreads 1 --gpu-index 0
```

View File

@@ -0,0 +1,89 @@
# How to Run Distributed Llama on 💻 Linux, MacOS or Windows
This article describes how to run Distributed Llama on 4 devices, but you can also run it on 1, 2, 4, 8... devices. Please adjust the commands and topology according to your configuration.
````
[🔀 SWITCH OR ROUTER]
| | | |
| | | |_______ 🔸 device1 (ROOT) 10.0.0.1
| | |_________ 🔹 device2 (WORKER 1) 10.0.0.2:9999
| |___________ 🔹 device3 (WORKER 2) 10.0.0.3:9999
|_____________ 🔹 device4 (WORKER 3) 10.0.0.4:9999
````
1. Install Git and C++ compiler on **🔸🔹 ALL** devices:
* Linux:
```
sudo apt install git build-essential
```
* MacOS
```
brew install git
```
* Windows
Install Git and Mingw (via [Chocolatey](https://chocolatey.org/install)):
```powershell
choco install git mingw
```
2. Connect **🔸🔹 ALL** devices to your **🔀 SWITCH OR ROUTER** via Ethernet cable. If you're using only two devices, it's better to connect them directly without a switch.
3. Clone this repository and compile Distributed Llama on **🔸🔹 ALL** devices:
```sh
git clone https://github.com/b4rtaz/distributed-llama.git
cd distributed-llama
make dllama
make dllama-api
```
4. Download the model to the **🔸 ROOT** device using the `launch.py` script. You don't need to download the model on worker devices.
```sh
python3 launch.py # Prints a list of available models
python3 launch.py llama3_2_3b_instruct_q40 # Downloads the model to the root device
```
5. Start workers on all **🔹 WORKER** devices:
```sh
./dllama worker --port 9999 --nthreads 4
```
6. Run the inference to test if everything works fine on the **🔸 ROOT** device:
```sh
./dllama inference \
--prompt "Hello world" \
--steps 32 \
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
--buffer-float-type q80 \
--nthreads 4 \
--max-seq-len 4096 \
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
```
7. To run the API server, start it on the **🔸 ROOT** device:
```sh
./dllama-api \
--port 9999 \
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
--buffer-float-type q80 \
--nthreads 4 \
--max-seq-len 4096 \
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
```
Now you can connect to the API server:
```
http://10.0.0.1:9999/v1/models
```
8. When the API server is running, you can open the web chat in your browser, open [llama-ui.js.org](https://llama-ui.js.org/), go to the settings and set the base URL to: `http://10.0.0.1:9999`. Press the "save" button and start chatting!

View File

@@ -0,0 +1,96 @@
# How to Run Distributed Llama on 🍓 Raspberry Pi
This article describes how to run Distributed Llama on 4 Raspberry Pi devices, but you can also run it on 1, 2, 4, 8... devices. Please adjust the commands and topology according to your configuration.
````
[🔀 SWITCH OR ROUTER]
| | | |
| | | |_______ 🔸 raspberrypi1 (ROOT) 10.0.0.1
| | |_________ 🔹 raspberrypi2 (WORKER 1) 10.0.0.2:9999
| |___________ 🔹 raspberrypi3 (WORKER 2) 10.0.0.3:9999
|_____________ 🔹 raspberrypi4 (WORKER 3) 10.0.0.4:9999
````
1. Install `Raspberry Pi OS Lite (64 bit)` on your **🔸🔹 ALL** Raspberry Pi devices. This OS doesn't have desktop environment but you can easily connect via SSH to manage it.
2. Connect **🔸🔹 ALL** devices to your **🔀 SWITCH OR ROUTER** via Ethernet cable. If you're using only two devices, it's better to connect them directly without a switch.
3. Connect to all devices via SSH from your computer.
```
ssh user@raspberrypi1.local
ssh user@raspberrypi2.local
ssh user@raspberrypi3.local
ssh user@raspberrypi4.local
```
4. Install Git on **🔸🔹 ALL** devices:
```sh
sudo apt install git
```
5. Clone this repository and compile Distributed Llama on **🔸🔹 ALL** devices:
```sh
git clone https://github.com/b4rtaz/distributed-llama.git
cd distributed-llama
make dllama
make dllama-api
```
6. Download the model to the **🔸 ROOT** device using the `launch.py` script. You don't need to download the model on worker devices.
```sh
python3 launch.py # Prints a list of available models
python3 launch.py llama3_2_3b_instruct_q40 # Downloads the model to the root device
```
7. Assign static IP addresses on **🔸🔹 ALL** devices. Each device must have a unique IP address in the same subnet.
```sh
sudo ip addr add 10.0.0.1/24 dev eth0 # 🔸 ROOT
sudo ip addr add 10.0.0.2/24 dev eth0 # 🔹 WORKER 1
sudo ip addr add 10.0.0.3/24 dev eth0 # 🔹 WORKER 2
sudo ip addr add 10.0.0.4/24 dev eth0 # 🔹 WORKER 3
```
8. Start workers on all **🔹 WORKER** devices:
```sh
sudo nice -n -20 ./dllama worker --port 9999 --nthreads 4
```
9. Run the inference to test if everything works fine on the **🔸 ROOT** device:
```sh
sudo nice -n -20 ./dllama inference \
--prompt "Hello world" \
--steps 32 \
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
--buffer-float-type q80 \
--nthreads 4 \
--max-seq-len 4096 \
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
```
10. To run the API server, start it on the **🔸 ROOT** device:
```sh
sudo nice -n -20 ./dllama-api \
--port 9999 \
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
--buffer-float-type q80 \
--nthreads 4 \
--max-seq-len 4096 \
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
```
Now you can connect to the API server from your computer:
```
http://raspberrypi1.local:9999/v1/models
```
11. When the API server is running, you can open the web chat in your browser, open [llama-ui.js.org](https://llama-ui.js.org/), go to the settings and set the base URL to: `http://raspberrypi1.local:9999`. Press the "save" button and start chatting!

View File

@@ -0,0 +1,49 @@
// This is a simple client for dllama-api.
//
// Usage:
//
// 1. Start the server, how to do it is described in the `src/apps/dllama-api/README.md` file.
// 2. Run this script: `node examples/chat-api-client.js`
const HOST = process.env.HOST ? process.env.HOST : '127.0.0.1';
const PORT = process.env.PORT ? Number(process.env.PORT) : 9990;
async function chat(messages, maxTokens) {
const response = await fetch(`http://${HOST}:${PORT}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
messages,
temperature: 0.7,
stop: ['<|eot_id|>'],
max_tokens: maxTokens
}),
});
return await response.json();
}
async function ask(system, user, maxTokens) {
console.log(`> system: ${system}`);
console.log(`> user: ${user}`);
const response = await chat([
{
role: 'system',
content: system
},
{
role: 'user',
content: user
}
], maxTokens);
console.log(response.usage);
console.log(response.choices[0].message.content);
}
async function main() {
await ask('You are an excellent math teacher.', 'What is 1 + 2?', 128);
await ask('You are a romantic.', 'Where is Europe?', 128);
}
main();

200
examples/macbeth.sh Normal file
View File

@@ -0,0 +1,200 @@
#!/bin/bash
# This is a simple test of generating a sequence that fulfills the KV cache.
#
# Used model & tokenizer: https://huggingface.co/b4rtaz/llama-3-8b-distributed-llama
# Probably, this test will be working correctly only on MacBook Pro M1, due to differences in float multiplication on different CPUs.
cd "$(dirname "$0")"
cd ..
# Source: https://www.opensourceshakespeare.org/views/plays/play_view.php?WorkID=macbeth&Scope=entire
PROMPT="Duncan. What bloody man is that? He can report,
As seemeth by his plight, of the revolt
The newest state. 20
Malcolm. This is the sergeant
Who like a good and hardy soldier fought
'Gainst my captivity. Hail, brave friend!
Say to the king the knowledge of the broil
As thou didst leave it. 25
Sergeant. Doubtful it stood;
As two spent swimmers, that do cling together
And choke their art. The merciless Macdonwald—
Worthy to be a rebel, for to that
The multiplying villanies of nature 30
Do swarm upon him—from the western isles
Of kerns and gallowglasses is supplied;
And fortune, on his damned quarrel smiling,
Show'd like a rebel's whore: but all's too weak:
For brave Macbeth—well he deserves that name— 35
Disdaining fortune, with his brandish'd steel,
Which smoked with bloody execution,
Like valour's minion carved out his passage
Till he faced the slave;
Which ne'er shook hands, nor bade farewell to him, 40
Till he unseam'd him from the nave to the chaps,
And fix'd his head upon our battlements.
Duncan. O valiant cousin! worthy gentleman!
Sergeant. As whence the sun 'gins his reflection
Shipwrecking storms and direful thunders break, 45
So from that spring whence comfort seem'd to come
Discomfort swells. Mark, king of Scotland, mark:
No sooner justice had with valour arm'd
Compell'd these skipping kerns to trust their heels,
But the Norweyan lord surveying vantage, 50
With furbish'd arms and new supplies of men
Began a fresh assault.
Duncan. Dismay'd not this
Our captains, Macbeth and Banquo?
Sergeant. Yes; 55
As sparrows eagles, or the hare the lion.
If I say sooth, I must report they were
As cannons overcharged with double cracks, so they
Doubly redoubled strokes upon the foe:
Except they meant to bathe in reeking wounds, 60
Or memorise another Golgotha,
I cannot tell.
But I am faint, my gashes cry for help.
Duncan. So well thy words become thee as thy wounds;
They smack of honour both. Go get him surgeons. 65
[Exit Sergeant, attended]
Who comes here?"
GENERATED="Malcolm. The worthy Thane of Ross.
Duncan. What a haste looks through a duel's wounds! 70
Some must be pac'd.
[Exit Ross]
See this encounter is like to the poring
On of a beggar's story, told by one
That means to pluck upon the heart the strings
And draw the tears thriftily. 75
[Enter Lennox]
How goes the night, boy?
Lennox. The night is long that none should wake.
Duncan. You do not need to stare. The Moor
To know the man. 'Tis the Moors devices. 80
[Exit Lennox]
By the happy right of mine own hands,
Strike all that live in this poor thing of mine.
'Tis calld the Eyrie, and I am sick at heart.
As hellish-devils do the damned souls
O'their bad lives, thus ill-breveted, linger
O'er lamps and forks and other instruments
That prove the stages of the night. 90
Good sir, take note; I bid you farewell:
Come sleep, and cut short this nitty romance.
[He sleeps.]
If cravens, I bear them like the Minion of the moon,
With tiptoe foot he sneaks and starts to be a man. 95
And when he is found asleep, awake him with this armed' s address:
That sleep which th'assassin hallowed,
Scotland, awake; your king is murder'd, sleep no more. 100
*Furbish'd. Weapons polished for battle.
*Thriftily. Fastidiously, thoughtfully.
*Eyrie. Fortress; the lair of birds of prey.
*Minion. A braggart, a coward.
1.5
Macbeth. So foul and fair a day I have not seen. 5
Ross. Good morning, noble Macbeth. I come from Inverness,
And find our throne void, the arm'd rest you; 10
My Lord of Cassil has resigned his life.
Macbeth. Whate'er you owe, in time repay, fair friends.
Note you the words; I pray you do.
Ross. I am your faithful servant, and will keep
My sworn reward upon your life; my lord.
Macbeth. You shall be well rewarded; stay the press, 20
And I'll not fail. How now, good fellow?
Servant. Sir, his schoolmaster. 25
Macbeth. Well, good, though, old.
Tell me, good fellow, how goes the night? 30
Servant. There's marrygold and fire in your veins, my lord.
Macbeth. He does commend you; the weight of this old night's embargoes 35
Did one hour's waste of time lay upon him.
I know when we are too safe, 'tis dangerous to be secure;
Therefore our fearful parts do brave the danger 40
Which knows it not. I see you are a gentleman.
And a laudable one too; I am most off obliged.
Servant. I should be sorry, my good lord, to have had the labour 45
To outlive this damned hour. 50
Macbeth. What's done cannot be undone. To bed, to bed, to bed.
Servant. Will it please you to lie still? 55
Macbeth. Lord, lord, my heart is in my mouth. All's true that ends well.
Servant. I thank you, fair, and leave you to the content. 60
Macbeth. You see, my lord, it smokes, and shows no cause
Why the drone dies. 65
Servant. Grief fills the room up of one vast stair,
And downs our vaults to the inconstant man above. 70
Macbeth. Go bid thy masters and thy mistress say, 75
I have power in earth to do so much.
There's comfort yet. They are assailable. Then say I,
Thus ye may answer.
Servant. He cannot be wronged; or being wronged, 80
I cannot help him. 85
Macbeth. You know but by this; as this, 90
The Jew foole is hang'd. 95
Servant. No more today, my lord. 100
Macbeth. He does shame to tell him he loves him, but not remove him 105
From his true place; no.
Servant. That's true, and now I remember the story 110
Of that sign in Leo four diurnal courses
Returning in a constant motion were within 115
A boare that had on Taurus' back tetracted; 120
Or neuer, or but once in modulated accidence. 125
Macbeth. Thou climd'st alone, ty'd to the stag's horn.
Servant. I was a bull, for this the goodly year. 130
Come, put me in my place.
Macbeth. Now go to sleep. 135
Servant. The west neuer sett before the equinox 140
Till now; and sunnes look'd not theyr frequencie 145
Upon our lappe till now, my lord. 150
Macbeth. This game of chance you term a gong.
Servant. A gong is a scotch word for an egg. 155
Macbeth. Peace, be still. 160
Servant. I coniecture I smell the blood of an Englishman. 165
Macbeth. The faith is murthered.
Servant. That murder'd in his sleep. 170
Macbeth. And sleeping murdered. 175
Servant. In the fair queen heere in his royal court. 180
Macbeth. So great a mercy that it may last eternally.
Servant. The earth hath bubbles as the water hath, 185
And these are of them. Whate'er we will do 190
To mend the trespasses of the comming time 195
Shall be the seedes of new mischefe, and shall beget 200
The formes of the extinctnese, which we are now. 205
Macbeth. We have scorch'd the snake, not kill'd it. 210
Servant. They hunt it in the morn. Good gally, good lord! 215
It weares a gilded snout. 220
Macbeth. It is the very painting of your fear. 225
Servant. This is the worst. 230
Macbeth. A fair quater of a mile is yet to go. 235
Servant. A mile and half. 240
Macbeth. I have run fifteen miles to-day.
Servant. A calender's date.
Macbeth. A bigger patch, a bigger patch. 245
Servant. Thirteen of more. 250
Macbeth. Wast thou with him? 255
Servant. No, nor he to night. 260
Macbeth. Thou seest the moon"
echo "Generating, it can take a while..."
OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 2 --steps 2048 --model models/llama3_8b_q40/dllama_model_llama3_8b_q40.m --tokenizer models/llama3_8b_q40/dllama_tokenizer_llama3_8b_q40.t --workers 127.0.0.1:9999 127.0.0.1:9998 127.0.0.1:9997 ) 2>&1)
echo "$OUTPUT"
if [[ $OUTPUT == *"$GENERATED"* ]]; then
echo "✅ Output is same"
else
echo "❌ Output is different"
fi

52
examples/n-workers.sh Normal file
View File

@@ -0,0 +1,52 @@
#!/bin/bash
# This script starts N workers from a single command. Mainly useful for testing and debugging.
# Usage:
#
# W=7 T=2 bash n-workers.sh start
# W=7 bash n-workers.sh stop
#
# Env vars:
# W - n workers
# T - n threads per worker
cd "$(dirname "$0")"
if [ -z "$W" ]; then
W=3
fi
if [ -z "$T" ]; then
T=1
fi
if [ "$1" == "start" ]; then
for (( w = 0; w < $W ; w += 1 ));
do
PORT=$(expr 9999 - $w)
PROC_ID=$(lsof -ti:$PORT)
if [ -n "$PROC_ID" ]; then
kill -9 $PROC_ID
echo "Killed process $PROC_ID"
fi
mkdir -p dllama_worker_$w # macOs does not support -Logfile argument, so we place logs inside different directories
cd dllama_worker_$w
screen -d -L -S dllama_worker_$w -m ../../dllama worker --port $PORT --nthreads $T
cd ..
echo "Started worker $w on port $PORT"
done
sleep 2
elif [ "$1" == "stop" ]; then
for (( w = 0; w < $W ; w += 1 ));
do
screen -S dllama_worker_$w -X quit
done
echo "Stopped $W workers"
else
echo "Usage: $0 [start|stop]"
fi
echo "> screen -ls"
screen -ls

195
launch.py Normal file
View File

@@ -0,0 +1,195 @@
import os
import sys
import time
import socket
import multiprocessing
from urllib.request import urlopen
def parts(length):
result = []
for i in range(length):
a = chr(97 + (i // 26))
b = chr(97 + (i % 26))
result.append(a + b)
return result
# [['model-url-0', 'model-url-1', ...], 'tokenizer-url', 'weights-float-type', 'buffer-float-type', 'model-type']
MODELS = {
'llama3_1_8b_instruct_q40': [
['https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true'],
'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'llama3_1_405b_instruct_q40': [
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama31_405b_q40_{suffix}?download=true', parts(56))),
'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'llama3_2_1b_instruct_q40': [
['https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-1b-instruct_q40.m?download=true'],
'https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'llama3_2_3b_instruct_q40': [
['https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-3b-instruct_q40.m?download=true'],
'https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'llama3_3_70b_instruct_q40': [
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama-3.3-70b_q40{suffix}?download=true', parts(11))),
'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama-3.3-70b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'deepseek_r1_distill_llama_8b_q40': [
['https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_model_deepseek-r1-distill-llama-8b_q40.m?download=true'],
'https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_tokenizer_deepseek-r1-distill-llama-8b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'qwen3_0.6b_q40': [
['https://huggingface.co/b4rtaz/Qwen3-0.6B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_0.6b_q40.m?download=true'],
'https://huggingface.co/b4rtaz/Qwen3-0.6B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_0.6b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'qwen3_1.7b_q40': [
['https://huggingface.co/b4rtaz/Qwen3-1.7B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_1.7b_q40.m?download=true'],
'https://huggingface.co/b4rtaz/Qwen3-1.7B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_1.7b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'qwen3_8b_q40': [
['https://huggingface.co/b4rtaz/Qwen3-8B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_8b_q40.m?download=true'],
'https://huggingface.co/b4rtaz/Qwen3-8B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_8b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'qwen3_14b_q40': [
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Qwen3-14B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_14b_q40_{suffix}?download=true', parts(2))),
'https://huggingface.co/b4rtaz/Qwen3-14B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_14b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
'qwen3_30b_a3b_q40': [
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Qwen3-30B-A3B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_30b_a3b_{suffix}?download=true', parts(5))),
'https://huggingface.co/b4rtaz/Qwen3-30B-A3B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_30b_a3b.t?download=true',
'q40', 'q80', 'chat', '--max-seq-len 4096'
],
}
def confirm(message: str):
alwaysYes = sys.argv.count('-y') > 0
if alwaysYes:
return True
result = input(f'{message} ("Y" if yes): ').upper()
return result == 'Y' or result == 'YES'
def downloadFile(urls, path: str):
if os.path.isfile(path):
fileName = os.path.basename(path)
if not confirm(f'{fileName} already exists, do you want to download again?'):
return
socket.setdefaulttimeout(30)
lastSizeMb = 0
with open(path, 'wb') as file:
for url in urls:
startPosition = file.tell()
success = False
for attempt in range(8):
print(f'📄 {url} (attempt: {attempt})')
try:
with urlopen(url) as response:
while True:
chunk = response.read(4096)
if not chunk:
break
file.write(chunk)
sizeMb = file.tell() // (1024 * 1024)
if sizeMb != lastSizeMb:
sys.stdout.write("\rDownloaded %i MB" % sizeMb)
lastSizeMb = sizeMb
sys.stdout.write('\n')
success = True
break
except Exception as e:
print(f'\n❌ Error downloading {url}: {e}')
file.seek(startPosition)
file.truncate()
time.sleep(1 * attempt)
if not success:
raise Exception(f'Failed to download {url}')
sys.stdout.write('\n')
def download(modelName: str, model: list):
dirPath = os.path.join('models', modelName)
print(f'📀 Downloading {modelName} to {dirPath}...')
os.makedirs(dirPath, exist_ok=True)
modelUrls = model[0]
tokenizerUrl = model[1]
modelPath = os.path.join(dirPath, f'dllama_model_{modelName}.m')
tokenizerPath = os.path.join(dirPath, f'dllama_tokenizer_{modelName}.t')
downloadFile(modelUrls, modelPath)
downloadFile([tokenizerUrl], tokenizerPath)
print('📀 All files are downloaded')
return (modelPath, tokenizerPath)
def writeRunFile(modelName: str, command: str):
filePath = f'run_{modelName}.sh'
with open(filePath, 'w') as file:
file.write('#!/bin/sh\n')
file.write('\n')
file.write(f'{command}\n')
return filePath
def printUsage():
print('Usage: python download-model.py <model>')
print()
print('Options:')
print(' <model> The name of the model to download')
print(' -skip-run Do not run the model after download')
print(' -skip-script Do not create a script to run the model')
print(' -y Skip confirmation prompts')
print()
print('Available models:')
for model in MODELS:
print(f' {model}')
if __name__ == '__main__':
if (len(sys.argv) < 2):
printUsage()
exit(1)
os.chdir(os.path.dirname(__file__))
modelName = sys.argv[1].replace('-', '_')
if modelName not in MODELS:
print(f'Model is not supported: {modelName}')
exit(1)
model = MODELS[modelName]
(modelPath, tokenizerPath) = download(modelName, model)
nThreads = multiprocessing.cpu_count()
if (model[4] == 'chat'):
command = './dllama chat'
else:
command = './dllama inference --steps 64 --prompt "Hello world"'
command += f' --model {modelPath} --tokenizer {tokenizerPath} --buffer-float-type {model[3]} --nthreads {nThreads}'
if (len(model) > 5):
command += f' {model[5]}'
print('To run Distributed Llama you need to execute:')
print('--- copy start ---')
print()
print('\033[96m' + command + '\033[0m')
print()
print('--- copy end -----')
skipRun = sys.argv.count('-skip-run') > 0
skipScript = sys.argv.count('-skip-script') > 0
if (not skipScript):
runFilePath = writeRunFile(modelName, command)
print(f'🌻 Created {runFilePath} script to easy run')
if (not skipRun):
if (confirm('Do you want to run Distributed Llama?')):
if (not os.path.isfile('dllama')):
os.system('make dllama')
os.system(command)

BIN
report/report.pdf Normal file

Binary file not shown.

179
src/api-types.hpp Executable file
View File

@@ -0,0 +1,179 @@
#ifndef API_TYPES_HPP
#define API_TYPES_HPP
#include <string>
#include "json.hpp"
using json = nlohmann::json;
struct ChatMessageDelta {
std::string role;
std::string content;
ChatMessageDelta() : role(""), content("") {}
ChatMessageDelta(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};
struct ChatMessage {
std::string role;
std::string content;
ChatMessage() : role(""), content("") {}
ChatMessage(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};
struct ChunkChoice {
int index;
ChatMessageDelta delta;
std::string finish_reason;
ChunkChoice() : index(0) {}
};
struct Choice {
int index;
ChatMessage message;
std::string finish_reason;
Choice() : finish_reason("") {}
Choice(ChatMessage &message_) : message(message_), finish_reason("") {}
Choice(const std::string &reason_) : finish_reason(reason_) {}
};
struct ChatCompletionChunk {
std::string id;
std::string object;
long long created;
std::string model;
std::vector<ChunkChoice> choices;
ChatCompletionChunk(ChunkChoice &choice_)
: id("cmpl-c0"), object("chat.completion"), model("Distributed Model") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
};
// Struct to represent the usage object
struct ChatUsage {
int prompt_tokens;
int completion_tokens;
int total_tokens;
ChatUsage() : prompt_tokens(0), completion_tokens(0), total_tokens(0) {}
ChatUsage(int pt, int ct, int tt) : prompt_tokens(pt), completion_tokens(ct), total_tokens(tt) {}
};
struct ChatCompletion {
std::string id;
std::string object;
long long created; // Unix timestamp
std::string model;
std::vector<Choice> choices;
ChatUsage usage;
ChatCompletion() : id(), object(), model() {}
ChatCompletion(const Choice &choice_, const ChatUsage& usage_)
: id("cmpl-j0"), object("chat.completion"), model("Distributed Model"), usage(usage_) {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
};
struct Model {
std::string id;
std::string object;
long long created;
std::string owned_by;
Model() : id(), object(), created(0), owned_by() {}
Model(const std::string &id_) : id(id_), object("model"), created(0), owned_by("user") {}
};
struct ModelList {
std::string object;
std::vector<Model> data;
ModelList(): object("list") {}
ModelList(const Model &model_) : object("list") {
data.push_back(model_);
}
};
struct InferenceParams {
std::vector<ChatMessage> messages;
int max_tokens;
float temperature;
float top_p;
std::vector<std::string> stop;
bool stream;
unsigned long long seed;
};
// Define to_json for Delta struct
void to_json(json& j, const ChatMessageDelta& msg) {
j = json{{"role", msg.role}, {"content", msg.content}};
}
void to_json(json& j, const ChatMessage& msg) {
j = json{{"role", msg.role}, {"content", msg.content}};
}
void to_json(json& j, const ChunkChoice& choice) {
j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}};
}
void to_json(json& j, const Choice& choice) {
j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}};
}
void to_json(json& j, const ChatCompletionChunk& completion) {
j = json{{"id", completion.id},
{"object", completion.object},
{"created", completion.created},
{"model", completion.model},
{"choices", completion.choices}};
}
void to_json(json& j, const ChatUsage& usage) {
j = json{{"completion_tokens", usage.completion_tokens},
{"prompt_tokens", usage.prompt_tokens},
{"total_tokens", usage.total_tokens}};
}
void to_json(json& j, const ChatCompletion& completion) {
j = json{{"id", completion.id},
{"object", completion.object},
{"created", completion.created},
{"model", completion.model},
{"usage", completion.usage},
{"choices", completion.choices}};
}
void to_json(json& j, const Model& model) {
j = json{{"id", model.id},
{"object", model.object},
{"created", model.created},
{"owned_by", model.owned_by}};
}
void to_json(json& j, const ModelList& models) {
j = json{{"object", models.object},
{"data", models.data}};
}
std::vector<ChatMessage> parseChatMessages(json &json){
std::vector<ChatMessage> messages;
messages.reserve(json.size());
for (const auto& item : json) {
messages.emplace_back(
item["role"].template get<std::string>(),
item["content"].template get<std::string>()
);
}
return messages;
}
#endif

358
src/app.cpp Normal file
View File

@@ -0,0 +1,358 @@
#include "app.hpp"
#include <cassert>
#include <cstring>
#include <stdexcept>
#if defined(DLLAMA_VULKAN)
#include "nn/nn-vulkan.hpp"
#endif
static NnFloatType parseFloatType(char *val) {
if (std::strcmp(val, "f32") == 0) return F_32;
if (std::strcmp(val, "f16") == 0) return F_16;
if (std::strcmp(val, "q40") == 0) return F_Q40;
if (std::strcmp(val, "q80") == 0) return F_Q80;
throw std::runtime_error("Invalid float type: " + std::string(val));
}
static ChatTemplateType parseChatTemplateType(char *val) {
if (std::strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2;
if (std::strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3;
if (std::strcmp(val, "deepSeek3") == 0) return TEMPLATE_DEEP_SEEK3;
throw std::runtime_error("Invalid chat template type: " + std::string(val));
}
AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
AppCliArgs args;
args.help = false;
args.mode = nullptr;
args.nBatches = 32;
args.nThreads = 1;
args.modelPath = nullptr;
args.tokenizerPath = nullptr;
args.prompt = nullptr;
args.syncType = F_32;
args.nWorkers = 0;
args.workerHosts = nullptr;
args.workerPorts = nullptr;
args.port = 9990;
args.temperature = 0.8f;
args.topp = 0.9f;
args.steps = 0;
args.seed = (unsigned long long)time(nullptr);
args.chatTemplateType = TEMPLATE_UNKNOWN;
args.maxSeqLen = 0;
args.netTurbo = true;
args.gpuIndex = -1;
args.gpuSegmentFrom = -1;
args.gpuSegmentTo = -1;
int i = 1;
if (requireMode && argc > 1) {
args.mode = argv[1];
i++;
}
// First see if any of the args are asking for help/usage and fail fast
for (int x = 0; x < argc; x++) {
if ((std::strcmp(argv[x], "--usage") == 0) ||
(std::strcmp(argv[x], "--help") == 0) ||
(std::strcmp(argv[x], "-h") == 0)) {
args.help = true;
return args;
}
}
for (; i + 1 < argc; i += 2) {
char *name = argv[i];
char *value = argv[i + 1];
if (std::strcmp(name, "--model") == 0) {
args.modelPath = value;
} else if (std::strcmp(name, "--tokenizer") == 0) {
args.tokenizerPath = value;
} else if (std::strcmp(name, "--prompt") == 0) {
args.prompt = value;
} else if (std::strcmp(name, "--buffer-float-type") == 0) {
args.syncType = parseFloatType(value);
} else if (std::strcmp(name, "--workers") == 0) {
int j = i + 1;
for (; j < argc && argv[j][0] != '-'; j++);
int count = j - i - 1;
args.nWorkers = count;
args.workerHosts = new char*[count];
args.workerPorts = new NnUint[count];
for (int s = 0; s < count; s++) {
char *v = argv[i + 1 + s];
char *separator = std::strstr(v, ":");
if (separator == NULL) {
throw std::runtime_error("Invalid worker address: " + std::string(v));
}
int hostLen = separator - v;
args.workerHosts[s] = new char[hostLen + 1];
std::memcpy(args.workerHosts[s], v, hostLen);
args.workerHosts[s][hostLen] = '\0';
args.workerPorts[s] = std::atoi(separator + 1);
}
i += count - 1;
} else if (std::strcmp(name, "--port") == 0) {
args.port = atoi(value);
} else if (std::strcmp(name, "--nthreads") == 0) {
args.nThreads = atoi(value);
} else if (std::strcmp(name, "--steps") == 0) {
args.steps = atoi(value);
} else if (std::strcmp(name, "--temperature") == 0) {
args.temperature = atof(value);
} else if (std::strcmp(name, "--topp") == 0) {
args.topp = atof(value);
} else if (std::strcmp(name, "--seed") == 0) {
args.seed = atoll(value);
} else if (std::strcmp(name, "--chat-template") == 0) {
args.chatTemplateType = parseChatTemplateType(value);
} else if (std::strcmp(name, "--max-seq-len") == 0) {
args.maxSeqLen = (unsigned int)atoi(value);
} else if (std::strcmp(name, "--gpu-index") == 0) {
args.gpuIndex = atoi(value);
} else if (std::strcmp(name, "--gpu-segments") == 0) {
char *separator = std::strstr(value, ":");
if (separator == NULL)
throw std::runtime_error("GPU segments expected in the format <from>:<to>");
args.gpuSegmentFrom = atoi(value);
args.gpuSegmentTo = atoi(separator + 1);
} else if (std::strcmp(name, "--net-turbo") == 0) {
args.netTurbo = atoi(value) == 1;
} else {
throw std::runtime_error("Unknown option: " + std::string(name));
}
}
if (args.nThreads < 1)
throw std::runtime_error("Number of threads must be at least 1");
return args;
}
AppCliArgs::~AppCliArgs() {
if (workerHosts != nullptr) {
for (NnUint i = 0; i < nWorkers; i++)
delete[] workerHosts[i];
delete[] workerHosts;
}
if (workerPorts != nullptr)
delete[] workerPorts;
}
static std::vector<NnExecutorDevice> resolveDevices(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
std::vector<NnExecutorDevice> devices;
if (args->gpuIndex >= 0) {
#if defined(DLLAMA_VULKAN)
devices.push_back(NnExecutorDevice(
new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution),
args->gpuSegmentFrom,
args->gpuSegmentTo
));
#else
throw std::runtime_error("This build does not support GPU");
#endif
}
if (args->gpuIndex < 0 || (args->gpuSegmentFrom >= 0 && args->gpuSegmentTo >= 0)) {
devices.push_back(NnExecutorDevice(new NnCpuDevice(netConfig, nodeConfig, netExecution), -1, -1));
}
return devices;
}
RootLlmInference::RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
this->header = net->header;
this->tokenPipe = (float *)execution->pipes[net->tokenPipeIndex];
this->positionPipe = (float *)execution->pipes[net->positionPipeIndex];
this->logitsPipe = (float *)execution->pipes[net->logitsPipeIndex];
this->execution = execution;
this->executor = executor;
this->network = network; // May be nullptr!
}
void RootLlmInference::setBatchSize(NnUint batchSize) {
execution->setBatchSize(batchSize);
controlPacket.batchSize = batchSize;
}
void RootLlmInference::setPosition(NnUint position) {
assert(position >= 0);
assert(position + execution->batchSize - 1 < header->seqLen);
controlPacket.position = position;
for (NnUint i = 0; i < execution->batchSize; i++)
positionPipe[i] = (float)(position + i);
}
void RootLlmInference::setToken(NnUint batchIndex, NnUint token) {
assert(batchIndex >= 0 && batchIndex < execution->batchSize);
tokenPipe[batchIndex] = (float)token;
}
void RootLlmInference::forward() {
if (network != nullptr)
network->writeAll(&controlPacket, sizeof(LlmControlPacket));
executor->forward();
}
void RootLlmInference::finish() {
if (network != nullptr) {
controlPacket.batchSize = 0;
network->writeAll(&controlPacket, sizeof(LlmControlPacket));
}
}
WorkerLlmInference::WorkerLlmInference(NnNetExecution *execution, NnNetwork *network) {
this->isFinished = false;
this->execution = execution;
this->network = network;
this->positionPipe = (float *)execution->pipes[0];
}
bool WorkerLlmInference::tryReadControlPacket() {
const unsigned long maxAttempts = 10000;
if (!network->tryReadWithMaxAttempts(ROOT_SOCKET_INDEX, &controlPacket, sizeof(LlmControlPacket), maxAttempts))
return false;
if (controlPacket.batchSize == 0) {
printf("🛑 Stop signal\n");
isFinished = true;
return true;
}
for (NnUint i = 0; i < controlPacket.batchSize; i++)
positionPipe[i] = (float)(controlPacket.position + i);
execution->setBatchSize(controlPacket.batchSize);
return true;
}
void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context)) {
NnUint nNodes = args->nWorkers + 1;
LlmHeader header = loadLlmHeader(args->modelPath, args->maxSeqLen, args->syncType);
if (nNodes > header.nKvHeads)
// TODO: https://github.com/b4rtaz/distributed-llama/issues/70
throw std::runtime_error("This version does not support more nodes than the number of KV heads in the model");
if (header.weightType == F_Q40 && header.syncType != F_Q80)
throw std::runtime_error("This version supports only Q40 weights with Q80 sync type");
Tokenizer tokenizer(args->tokenizerPath);
if (tokenizer.vocabSize != header.vocabSize)
printf("Tokenizer vocab size (%d) does not match the model vocab size (%d)\n", tokenizer.vocabSize, header.vocabSize);
Sampler sampler(tokenizer.vocabSize, args->temperature, args->topp, args->seed);
LlmNet net = buildLlmNet(&header, nNodes, args->nBatches);
std::unique_ptr<LlmNet, void(*)(LlmNet *)> netPtr(&net, releaseLlmNet);
NnNodeConfig *rootNodeConfig = &net.nodeConfigs[0];
printLlmHeader(&header);
printNodeRequiredMemory(&net.netConfig, rootNodeConfig);
NnNetExecution execution(args->nThreads, &net.netConfig);
std::unique_ptr<NnNodeSynchronizer> synchronizer(nullptr);
std::unique_ptr<NnNetwork> networkPtr(nullptr);
NnNetwork *network = nullptr;
if (nNodes == 1) {
synchronizer.reset(new NnFakeNodeSynchronizer());
} else {
networkPtr = NnNetwork::connect(args->nWorkers, args->workerHosts, args->workerPorts);
network = networkPtr.get();
synchronizer.reset(new NnNetworkNodeSynchronizer(network, &execution, &net.netConfig, rootNodeConfig));
NnRootConfigWriter configWriter(network);
configWriter.writeToWorkers(&net.netConfig, net.nodeConfigs);
}
std::vector<NnExecutorDevice> devices = resolveDevices(args, &net.netConfig, rootNodeConfig, &execution);
NnExecutor executor(&net.netConfig, rootNodeConfig, &devices, &execution, synchronizer.get(), args->benchmark);
NnRootWeightLoader weightLoader(&executor, network, nNodes);
loadLlmNetWeight(args->modelPath, &net, &weightLoader);
RootLlmInference inference(&net, &execution, &executor, network);
if (network != nullptr) {
network->resetStats();
if (args->netTurbo) {
network->setTurbo(true);
printf("🚁 Network is in non-blocking mode\n");
}
}
AppInferenceContext context;
context.args = args;
context.header = &header;
context.inference = &inference;
context.sampler = &sampler;
context.tokenizer = &tokenizer;
context.network = network;
context.executor = &executor;
handler(&context);
inference.finish();
}
void runWorkerApp(AppCliArgs *args) {
while (true) {
std::unique_ptr<NnNetwork> networkPtr = NnNetwork::serve(args->port);
NnNetwork *network = networkPtr.get();
NnWorkerConfigReader configReader(network);
NnNetConfig netConfig = configReader.readNet();
NnNodeConfig nodeConfig = configReader.readNode();
std::unique_ptr<NnNetConfig, void(*)(NnNetConfig *)> netConfigPtr(&netConfig, releaseNetConfig);
std::unique_ptr<NnNodeConfig, void(*)(NnNodeConfig *)> nodeConfigPtr(&nodeConfig, releaseNodeConfig);
printNodeRequiredMemory(&netConfig, &nodeConfig);
NnNetExecution execution(args->nThreads, &netConfig);
std::vector<NnExecutorDevice> devices = resolveDevices(args, &netConfig, &nodeConfig, &execution);
NnNetworkNodeSynchronizer synchronizer(network, &execution, &netConfig, &nodeConfig);
NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
NnWorkerWeightReader weightReader(&executor, network);
weightReader.read();
WorkerLlmInference inference(&execution, network);
bool isFirstAttempt = true;
bool isTurboEnabled = false;
clock_t startTime;
while (true) {
try {
if (isFirstAttempt)
startTime = clock();
if (!inference.tryReadControlPacket()) {
if (isTurboEnabled && !isFirstAttempt && clock() - startTime > CLOCKS_PER_SEC) {
network->setTurbo(false);
isTurboEnabled = false;
printf("🚁 Network is in blocking mode\n");
}
isFirstAttempt = false;
continue;
}
if (inference.isFinished)
break;
if (args->netTurbo && !isTurboEnabled) {
network->setTurbo(true);
isTurboEnabled = true;
printf("🚁 Network is in non-blocking mode\n");
}
executor.forward();
isFirstAttempt = true;
} catch (const NnReadNetworkException &e) {
printf("Read network exception: %s\n", e.message);
break;
} catch (const NnWriteNetworkException &e) {
printf("Write network exception: %s\n", e.message);
break;
}
}
}
}

95
src/app.hpp Normal file
View File

@@ -0,0 +1,95 @@
#ifndef APP_HPP
#define APP_HPP
#include <chrono>
#include "nn/nn-core.hpp"
#include "nn/nn-cpu.hpp"
#include "tokenizer.hpp"
#include "llm.hpp"
class AppCliArgs {
public:
char *mode;
NnUint nThreads;
NnUint nBatches;
bool help;
// inference
char *modelPath;
char *tokenizerPath;
char *prompt;
NnFloatType syncType;
NnUint nWorkers;
char **workerHosts;
NnUint *workerPorts;
float temperature;
float topp;
NnUint steps;
bool benchmark;
unsigned long long seed;
ChatTemplateType chatTemplateType;
NnUint maxSeqLen;
bool netTurbo;
int gpuIndex;
int gpuSegmentFrom;
int gpuSegmentTo;
// worker
NnUint port;
static AppCliArgs parse(int argc, char **argv, bool hasMode);
~AppCliArgs();
};
typedef struct {
NnUint position;
NnUint batchSize; // 0 = stop signal
} LlmControlPacket;
class RootLlmInference {
public:
float *logitsPipe;
private:
float *tokenPipe;
float *positionPipe;
LlmHeader *header;
NnNetExecution *execution;
NnExecutor *executor;
NnNetwork *network;
LlmControlPacket controlPacket;
public:
RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network);
void setBatchSize(NnUint batchSize);
void setPosition(NnUint position);
void setToken(NnUint batchIndex, NnUint token);
void forward();
void finish();
};
class WorkerLlmInference {
public:
bool isFinished;
private:
float *positionPipe;
NnNetExecution *execution;
NnNetwork *network;
LlmControlPacket controlPacket;
public:
WorkerLlmInference(NnNetExecution *execution, NnNetwork *network);
bool tryReadControlPacket();
};
typedef struct {
AppCliArgs *args;
LlmHeader *header;
RootLlmInference *inference;
Tokenizer *tokenizer;
Sampler *sampler;
NnNetwork *network;
NnExecutor *executor;
} AppInferenceContext;
void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context));
void runWorkerApp(AppCliArgs *args);
#endif

622
src/dllama-api.cpp Normal file
View File

@@ -0,0 +1,622 @@
#include <cstring>
#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cassert>
#include <sstream>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <csignal>
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <unistd.h>
#endif
#include "tokenizer.hpp"
#include "app.hpp"
#include "json.hpp"
#include "api-types.hpp"
#include "nn/nn-network.hpp"
typedef unsigned int pos_t;
using json = nlohmann::json;
enum class HttpMethod {
METHOD_GET = 0,
METHOD_POST = 1,
METHOD_PUT = 2,
METHOD_DELETE = 3,
METHOD_OPTIONS = 4,
METHOD_UNKNOWN = 5
};
class HttpRequest {
public:
static HttpRequest read(int serverSocket) {
HttpRequest req(serverSocket);
std::vector<char> httpRequest = req.readHttpRequest();
// Parse the HTTP request
std::string data = std::string(httpRequest.begin(), httpRequest.end());
// Split request into lines
std::istringstream iss(data);
std::string line;
std::getline(iss, line);
// Parse request line
std::istringstream lineStream(line);
std::string methodStr, path;
lineStream >> methodStr >> path;
req.method = parseMethod(methodStr);
req.path = path;
// Parse headers
while (std::getline(iss, line) && line != "\r") {
size_t pos = line.find(':');
if (pos != std::string::npos) {
std::string key = line.substr(0, pos);
std::string value = line.substr(pos + 2); // Skip ': ' after key
// Trim whitespace and non-printable characters from header value
value.erase(std::remove_if(value.begin(), value.end(), [](unsigned char c) {
return std::isspace(c) || !std::isprint(c);
}), value.end());
req.headers[key] = value;
}
}
// Parse body
std::getline(iss, req.body, '\0');
if (req.body.size() > 0) {
// printf("body: %s\n", req.body.c_str());
req.parsedJson = json::parse(req.body);
}
return req;
}
static HttpMethod parseMethod(const std::string& method) {
if (method == "GET") return HttpMethod::METHOD_GET;
if (method == "POST") return HttpMethod::METHOD_POST;
if (method == "PUT") return HttpMethod::METHOD_PUT;
if (method == "DELETE") return HttpMethod::METHOD_DELETE;
if (method == "OPTIONS") return HttpMethod::METHOD_OPTIONS;
return HttpMethod::METHOD_UNKNOWN;
}
private:
int serverSocket;
public:
std::string path;
std::unordered_map<std::string, std::string> headers;
std::string body;
json parsedJson;
HttpMethod method;
HttpRequest(int serverSocket) {
this->serverSocket = serverSocket;
}
std::vector<char> readHttpRequest() {
std::string httpRequest;
char buffer[1024 * 64];
ssize_t bytesRead;
// First, read all headers
std::string headerData;
size_t headerEnd;
bool headerDone = false;
std::string extraReadPastHeader;
while (!headerDone) {
bytesRead = recv(serverSocket, buffer, sizeof(buffer) - 1, 0);
if (bytesRead <= 0) {
throw std::runtime_error("Error while reading headers from socket");
}
buffer[bytesRead] = '\0';
headerData.append(buffer);
// Check for end of headers (http header says "\r\n\r\n")
headerEnd = headerData.find("\r\n\r\n");
if (headerEnd != std::string::npos) {
headerDone = true;
if (headerEnd < headerData.size()-4) {
// We read something past the header
extraReadPastHeader = headerData.substr(headerEnd+4);
}
}
}
httpRequest.append(headerData);
// Next, find Content-Length header for body length
std::istringstream headerStream(headerData);
std::string line;
ssize_t contentLength = 0;
while (std::getline(headerStream, line) && line != "\r") {
size_t pos = line.find(':');
if (pos != std::string::npos) {
std::string key = line.substr(0, pos);
std::string value = line.substr(pos + 2); // Skip ': ' after key
if (key == "Content-Length") {
try {
contentLength = std::stoi(value); // stoi ignores any whitespace
} catch (const std::invalid_argument& e) {
throw std::runtime_error("Bad Content-Length header - not a number");
}
break;
}
}
}
// Now read the full content body
if (contentLength > 0) {
// If we read any extra past the header before, read that much less now
// But first, sanity check to make sure Content-Length isn't lying and there is actually more
if (extraReadPastHeader.size() > static_cast<size_t>(contentLength)) {
throw std::runtime_error("Received more body data than Content-Length header said");
}
contentLength -= extraReadPastHeader.size();
std::vector<char> body(contentLength);
ssize_t totalRead = 0;
while (totalRead < contentLength) {
bytesRead = recv(serverSocket, body.data() + totalRead, contentLength - totalRead, 0);
if (bytesRead <= 0) {
throw std::runtime_error("Error while reading body from socket");
}
totalRead += bytesRead;
}
if (body.size() > 0) {
httpRequest.append(body.data(), contentLength);
}
}
return std::vector<char>(httpRequest.begin(), httpRequest.end());
}
std::string getMethod() {
if (method == HttpMethod::METHOD_GET) return "GET";
if (method == HttpMethod::METHOD_POST) return "POST";
if (method == HttpMethod::METHOD_PUT) return "PUT";
if (method == HttpMethod::METHOD_DELETE) return "DELETE";
if (method == HttpMethod::METHOD_OPTIONS) return "OPTIONS";
return "UNKNOWN";
}
void writeCors() {
std::ostringstream buffer;
buffer << "HTTP/1.1 204 No Content\r\n"
<< "Access-Control-Allow-Origin: *\r\n"
<< "Access-Control-Allow-Methods: GET, POST, PUT, DELETE\r\n"
<< "Access-Control-Allow-Headers: Content-Type, Authorization\r\n"
<< "Connection: close\r\n"
<< "\r\n";
std::string data = buffer.str();
writeSocket(serverSocket, data.c_str(), data.size());
}
void writeNotFound() {
std::ostringstream buffer;
buffer << "HTTP/1.1 404 Not Found\r\n"
<< "Connection: close\r\n"
<< "Content-Length: 9\r\n"
<< "\r\n"
<< "Not Found";
std::string data = buffer.str();
writeSocket(serverSocket, data.c_str(), data.size());
}
void writeJson(std::string json) {
std::ostringstream buffer;
buffer << "HTTP/1.1 200 OK\r\n"
<< "Access-Control-Allow-Origin: *\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Connection: close\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
std::string data = buffer.str();
writeSocket(serverSocket, data.c_str(), data.size());
}
void writeStreamStartChunk() {
std::ostringstream buffer;
buffer << "HTTP/1.1 200 OK\r\n"
<< "Access-Control-Allow-Origin: *\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";
std::string data = buffer.str();
writeSocket(serverSocket, data.c_str(), data.size());
}
void writeStreamChunk(const std::string data) {
std::ostringstream buffer;
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
std::string d = buffer.str();
writeSocket(serverSocket, d.c_str(), d.size());
}
void writeStreamEndChunk() {
const char *endChunk = "0000\r\n\r\n";
writeSocket(serverSocket, endChunk, strlen(endChunk));
}
};
struct Route {
std::string path;
HttpMethod method;
std::function<void(HttpRequest&)> handler;
};
class Router {
public:
static void resolve(HttpRequest& request, std::vector<Route>& routes) {
if (request.method == HttpMethod::METHOD_OPTIONS) {
request.writeCors();
return;
}
for (const auto& route : routes) {
if (request.method == route.method && request.path == route.path) {
route.handler(request);
return;
}
}
request.writeNotFound();
}
};
void writeChatCompletionChunk(HttpRequest &request, const std::string &delta, const bool stop){
ChunkChoice choice;
if (stop) {
choice.finish_reason = "stop";
} else {
choice.delta = ChatMessageDelta("assistant", delta);
}
ChatCompletionChunk chunk = ChatCompletionChunk(choice);
std::ostringstream buffer;
buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n";
request.writeStreamChunk(buffer.str());
if (stop) {
request.writeStreamChunk("data: [DONE]");
request.writeStreamEndChunk();
}
}
class NaiveCacheItem {
public:
pos_t endPos;
ChatMessage message;
NaiveCacheItem(pos_t endPos, ChatMessage message) {
this->endPos = endPos;
this->message = message;
}
};
class NaiveCache {
private:
std::vector<NaiveCacheItem> cache;
public:
void push(NaiveCacheItem item) {
cache.push_back(item);
}
void clear() {
cache.clear();
}
bool resolveDeltaPrompt(std::vector<ChatMessage>& messages, pos_t& startPos) {
size_t cacheSize = cache.size();
if (cacheSize == 0)
return false;
if (messages.size() > cacheSize) {
size_t i = 0;
while (i < cacheSize) {
if (
cache[i].message.role != messages[i].role ||
cache[i].message.content != messages[i].content
) break;
i++;
}
if (i == cacheSize) {
startPos = cache[i - 1].endPos;
messages.erase(messages.begin(), messages.begin() + i);
printf("🐤 Found naive cache for %zu messages, pos=%d\n", i, startPos);
return true;
}
}
cache.clear();
return false;
}
};
class ApiServer {
private:
RootLlmInference *inference;
Tokenizer *tokenizer;
Sampler *sampler;
AppCliArgs *args;
LlmHeader *header;
EosDetector *eosDetector;
ChatTemplateGenerator *templateGenerator;
NaiveCache naiveCache;
public:
ApiServer(RootLlmInference *inference, Tokenizer *tokenizer, Sampler *sampler, AppCliArgs *args, LlmHeader *header, EosDetector *eosDetector, ChatTemplateGenerator *templateGenerator) {
this->inference = inference;
this->tokenizer = tokenizer;
this->sampler = sampler;
this->args = args;
this->header = header;
this->eosDetector = eosDetector;
this->templateGenerator = templateGenerator;
}
void complete(HttpRequest& request) {
InferenceParams params = parseRequest(request);
pos_t startPos = 0;
std::vector<ChatMessage> deltaPrompt = params.messages;
naiveCache.resolveDeltaPrompt(deltaPrompt, startPos);
size_t nInputItems = deltaPrompt.size();
std::unique_ptr<ChatItem[]> inputItemsPtr(new ChatItem[nInputItems]);
ChatItem *inputItems = inputItemsPtr.get();
for (size_t i = 0; i < nInputItems; i++) {
inputItems[i].role = deltaPrompt[i].role;
inputItems[i].message = deltaPrompt[i].content;
}
GeneratedChat inputPrompt = templateGenerator->generate(nInputItems, inputItems, true);
printf("🔹%s🔸", inputPrompt.content);
int nPromptTokens;
std::unique_ptr<int[]> promptTokensPtr(new int[inputPrompt.length + 2]);
int *promptTokens = promptTokensPtr.get();
bool isStart = startPos == 0;
tokenizer->encode((char*)inputPrompt.content, promptTokens, &nPromptTokens, isStart, true);
pos_t promptEndPos = startPos + nPromptTokens - 1;
if (promptEndPos > header->seqLen)
promptEndPos = header->seqLen;
pos_t maxPredPos = params.max_tokens > 0 ? (promptEndPos + params.max_tokens) : header->seqLen;
if (maxPredPos > header->seqLen)
maxPredPos = header->seqLen;
for (size_t j = 0; j < deltaPrompt.size(); j++) {
naiveCache.push(NaiveCacheItem(promptEndPos, deltaPrompt[j]));
}
std::string buffer;
if (params.stream)
request.writeStreamStartChunk();
if (inputPrompt.publicPrompt != nullptr) {
if (params.stream)
writeChatCompletionChunk(request, inputPrompt.publicPrompt, false);
buffer += inputPrompt.publicPrompt;
}
NnUint pos = startPos;
int token;
for (NnUint i = 0; ;) {
long remainingTokens = promptEndPos - pos;
if (remainingTokens <= 0)
break;
NnUint batchSize = remainingTokens < args->nBatches
? remainingTokens
: args->nBatches;
inference->setBatchSize(batchSize);
inference->setPosition(pos);
for (NnUint j = 0; j < batchSize; j++)
inference->setToken(j, promptTokens[i + j]);
inference->forward();
i += batchSize;
pos += batchSize;
token = promptTokens[i + 1];
}
inference->setBatchSize(1);
tokenizer->resetDecoder();
eosDetector->reset();
for (; pos < maxPredPos;) {
inference->setPosition(pos);
inference->setToken(0, token);
inference->forward();
token = sampler->sample(inference->logitsPipe);
char *piece = tokenizer->decode(token);
EosDetectorType eosType = eosDetector->append(token, piece);
if (piece != nullptr) {
printf("%s", piece);
fflush(stdout);
}
if (eosType == NOT_EOS || eosType == EOS) {
char *delta = eosDetector->getDelta();
if (delta != nullptr) {
std::string deltaStr(delta);
if (params.stream)
writeChatCompletionChunk(request, deltaStr, false);
buffer += deltaStr;
}
eosDetector->reset();
}
pos++;
if (eosType == EOS) break;
}
ChatMessage chatMessage("assistant", buffer);
if (pos == header->seqLen) {
naiveCache.clear();
} else {
naiveCache.push(NaiveCacheItem(pos, chatMessage));
}
if (params.stream) {
writeChatCompletionChunk(request, "", true);
} else {
int nCompletionTokens = pos - promptEndPos;
ChatUsage usage(nPromptTokens, nCompletionTokens, nPromptTokens + nCompletionTokens);
Choice choice(chatMessage);
ChatCompletion completion(choice, usage);
std::string chatJson = ((json)completion).dump();
request.writeJson(chatJson);
}
printf("🔶\n");
fflush(stdout);
}
private:
InferenceParams parseRequest(HttpRequest& request) {
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
params.seed = args->seed;
params.stream = false;
params.messages = parseChatMessages(request.parsedJson["messages"]);
params.max_tokens = -1;
if (request.parsedJson.contains("stream")) {
params.stream = request.parsedJson["stream"].get<bool>();
}
if (request.parsedJson.contains("temperature")) {
params.temperature = request.parsedJson["temperature"].template get<float>();
}
if (request.parsedJson.contains("seed")) {
params.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(params.seed);
}
if (request.parsedJson.contains("max_tokens")) {
params.max_tokens = request.parsedJson["max_tokens"].template get<int>();
}
if (request.parsedJson.contains("stop")) {
params.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
} else {
const std::string defaultStop = "<|eot_id|>";
params.stop = std::vector<std::string>{defaultStop};
}
return params;
}
};
void handleCompletionsRequest(HttpRequest& request, ApiServer *api) {
api->complete(request);
}
void handleModelsRequest(HttpRequest& request, const char* modelPath) {
std::string path(modelPath);
size_t pos = path.find_last_of("/\\");
std::string modelName = (pos == std::string::npos) ? path : path.substr(pos + 1);
Model model(modelName);
ModelList list(model);
std::string response = ((json)list).dump();
request.writeJson(response);
}
static void server(AppInferenceContext *context) {
int serverSocket = createServerSocket(context->args->port);
TokenizerChatStops stops(context->tokenizer);
ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
ApiServer api(context->inference, context->tokenizer, context->sampler, context->args, context->header, &eosDetector, &templateGenerator);
printf("Server URL: http://127.0.0.1:%d/v1/\n", context->args->port);
std::vector<Route> routes = {
{
"/v1/chat/completions",
HttpMethod::METHOD_POST,
std::bind(&handleCompletionsRequest, std::placeholders::_1, &api)
},
{
"/v1/models",
HttpMethod::METHOD_GET,
std::bind(&handleModelsRequest, std::placeholders::_1, context->args->modelPath)
}
};
while (true) {
try {
int clientSocket = acceptSocket(serverSocket);
HttpRequest request = HttpRequest::read(clientSocket);
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
Router::resolve(request, routes);
destroySocket(clientSocket);
} catch (NnReadNetworkException& ex) {
printf("Read socket error: %d %s\n", ex.code, ex.message);
} catch (NnWriteNetworkException& ex) {
printf("Write socket error: %d %s\n", ex.code, ex.message);
}
}
destroySocket(serverSocket);
}
#ifdef _WIN32
#define EXECUTABLE_NAME "dllama-api.exe"
#else
#define EXECUTABLE_NAME "dllama-api"
#endif
void usage() {
fprintf(stderr, "Usage: %s {--model <path>} {--tokenizer <path>} [--port <p>]\n", EXECUTABLE_NAME);
fprintf(stderr, " [--buffer-float-type {f32|f16|q40|q80}]\n");
fprintf(stderr, " [--weights-float-type {f32|f16|q40|q80}]\n");
fprintf(stderr, " [--max-seq-len <max>]\n");
fprintf(stderr, " [--nthreads <n>]\n");
fprintf(stderr, " [--workers <ip:port> ...]\n");
fprintf(stderr, " [--temperature <temp>]\n");
fprintf(stderr, " [--topp <t>]\n");
fprintf(stderr, " [--seed <s>]\n");
fprintf(stderr, "Example:\n");
fprintf(stderr, " sudo nice -n -20 ./dllama-api --port 9990 --nthreads 4 \\\n");
fprintf(stderr, " --model dllama_model_llama3_2_3b_instruct_q40.m \\\n");
fprintf(stderr, " --tokenizer dllama_tokenizer_llama3_2_3b_instruct_q40.t \\\n");
fprintf(stderr, " --buffer-float-type q80 --max-seq-len 8192 \\\n");
fprintf(stderr, " --workers 10.0.0.2:9998 10.0.0.3:9998 10.0.0.4:9998\n");
fflush(stderr);
}
int main(int argc, char *argv[]) {
#ifdef SIGPIPE
std::signal(SIGPIPE, SIG_IGN);
#endif
initQuants();
initSockets();
int returnCode = EXIT_SUCCESS;
try {
AppCliArgs args = AppCliArgs::parse(argc, argv, false);
if (args.help) {
usage();
} else {
runInferenceApp(&args, server);
}
} catch (std::exception &e) {
printf("🚨 Critical error: %s\n", e.what());
returnCode = EXIT_FAILURE;
}
cleanupSockets();
return returnCode;
}

285
src/dllama.cpp Normal file
View File

@@ -0,0 +1,285 @@
#include "nn/nn-core.hpp"
#include "nn/nn-config-builder.hpp"
#include "nn/nn-cpu.hpp"
#include "nn/nn-cpu-ops.hpp"
#include "nn/nn-network.hpp"
#include "nn/nn-executor.hpp"
#include "llm.hpp"
#include "tokenizer.hpp"
#include "app.hpp"
#include <stdexcept>
#include <cmath>
static void inference(AppInferenceContext *context) {
if (context->args->prompt == nullptr)
throw std::runtime_error("Prompt is required");
if (context->args->steps == 0)
throw std::runtime_error("Number of steps is required");
std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
int *inputTokens = inputTokensVec.data();
NnUint pos = 0;
int nInputTokens;
context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, true);
if (nInputTokens > context->header->seqLen)
throw std::runtime_error("The number of prompt tokens is greater than the sequence length");
if (nInputTokens > context->args->steps)
throw std::runtime_error("The number of prompt tokens is greater than the number of steps");
NnSize sentBytes = 0;
NnSize recvBytes = 0;
NnUint evalTotalTime = 0;
NnUint predTotalTime = 0;
int token = inputTokens[pos];
printf("%s\n", context->args->prompt);
for (;;) {
long remainingTokens = nInputTokens - 1 - (long)pos;
if (remainingTokens <= 0)
break;
NnUint batchSize = remainingTokens < context->args->nBatches
? remainingTokens
: context->args->nBatches;
context->inference->setBatchSize(batchSize);
context->inference->setPosition(pos);
for (NnUint i = 0; i < batchSize; i++)
context->inference->setToken(i, inputTokens[pos + i]);
context->inference->forward();
pos += batchSize;
token = inputTokens[pos + 1];
if (context->network != nullptr)
context->network->getStats(&sentBytes, &recvBytes);
NnUint evalTime = context->executor->getTotalTime(STEP_EXECUTE_OP);
NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
printf("🔷️ Eval%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | (%d tokens)\n",
evalTime / 1000,
syncTime / 1000,
sentBytes / 1024,
recvBytes / 1024,
batchSize);
evalTotalTime += evalTime + syncTime;
}
fflush(stdout);
context->inference->setBatchSize(1);
context->tokenizer->resetDecoder();
const NnUint maxPos = std::min(context->header->seqLen, context->args->steps);
for (; pos < maxPos; pos++) {
context->inference->setPosition(pos);
context->inference->setToken(0, token);
context->inference->forward();
token = context->sampler->sample(context->inference->logitsPipe);
char *piece = context->tokenizer->decode(token);
if (context->network != nullptr)
context->network->getStats(&sentBytes, &recvBytes);
NnUint predTime = context->executor->getTotalTime(STEP_EXECUTE_OP);
NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
printf("🔶 Pred%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | %s\n",
predTime / 1000,
syncTime / 1000,
sentBytes / 1024,
recvBytes / 1024,
piece == nullptr ? "~" : piece);
fflush(stdout);
predTotalTime += predTime + syncTime;
}
NnUint nEvalTokens = nInputTokens - 1;
NnUint nPredTokens = pos - nEvalTokens;
float evalTotalTimeMs = evalTotalTime / 1000.0;
float predTotalTimeMs = predTotalTime / 1000.0;
printf("\n");
printf("Evaluation\n");
printf(" nBatches: %d\n", context->args->nBatches);
printf(" nTokens: %d\n", nEvalTokens);
printf(" tokens/s: %3.2f (%3.2f ms/tok)\n",
(nEvalTokens * 1000) / evalTotalTimeMs,
evalTotalTimeMs / ((float) nEvalTokens));
printf("Prediction\n");
printf(" nTokens: %d\n", nPredTokens);
printf(" tokens/s: %3.2f (%3.2f ms/tok)\n",
(nPredTokens * 1000) / predTotalTimeMs,
predTotalTimeMs / ((float) nPredTokens));
}
static NnUint readStdin(const char *guide, char *buffer, NnUint size) {
std::fflush(stdin);
std::printf("%s", guide);
if (std::fgets(buffer, size, stdin) != NULL) {
NnUint length = std::strlen(buffer);
if (length > 0 && buffer[length - 1] == '\n') {
buffer[length - 1] = '\0';
length--;
}
return length;
}
return 0;
}
static void perplexity(AppInferenceContext *context) {
if (context->args->prompt == nullptr)
throw std::runtime_error("Prompt is required");
std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
int *inputTokens = inputTokensVec.data();
int nInputTokens;
context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, true);
printf("Evaluating %d tokens...\n", nInputTokens);
float totalLogProb = 0.0f;
NnUint pos = 0;
context->inference->setBatchSize(1);
for (pos = 0; pos < nInputTokens - 1; pos++) {
context->inference->setPosition(pos);
context->inference->setToken(0, inputTokens[pos]);
context->inference->forward();
float *logits = context->inference->logitsPipe;
softmax_F32(logits, context->header->vocabSize);
int targetToken = inputTokens[pos + 1];
float prob = logits[targetToken];
totalLogProb += std::log(std::max(prob, 1e-30f));
printf("%5d / %d, prob=%f\n", pos + 1, nInputTokens - 1, prob);
}
float avgLogProb = totalLogProb / (float)(nInputTokens - 1);
float perplexity = expf(-avgLogProb);
printf("\n");
printf("Results\n");
printf(" perplexity: %f (lower = better)\n", perplexity);
printf(" avgLogProb: %f\n", avgLogProb);
printf(" bitPerToken: %f\n", -avgLogProb / std::log(2.0));
}
static void chat(AppInferenceContext *context) {
const NnUint seqLen = context->header->seqLen;
char prompt[2048];
TokenizerChatStops stops(context->tokenizer);
ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
std::vector<ChatItem> deltaItems;
if (sysPromptLength > 0)
deltaItems.push_back(ChatItem{"system", prompt});
NnUint pos = 0;
NnUint userPromptLength;
int token;
int nInputTokens;
do {
do {
userPromptLength = readStdin("\n👱 User\n> ", prompt, sizeof(prompt));
} while (userPromptLength == 0);
deltaItems.push_back(ChatItem{"user", prompt});
GeneratedChat inputPrompt = templateGenerator.generate(deltaItems.size(), deltaItems.data(), true);
std::unique_ptr<int[]> inputTokensPtr(new int[inputPrompt.length + 2]);
int *inputTokens = inputTokensPtr.get();
bool isStart = pos == 0;
context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, isStart, true);
NnUint userPromptEndPos = (NnUint)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
for (NnUint i = 0; ;) {
int remainingTokens = userPromptEndPos - pos;
if (remainingTokens <= 0)
break;
NnUint batchSize = remainingTokens < context->args->nBatches
? remainingTokens
: context->args->nBatches;
context->inference->setBatchSize(batchSize);
context->inference->setPosition(pos);
for (NnUint j = 0; j < batchSize; j++)
context->inference->setToken(j, inputTokens[i + j]);
context->inference->forward();
i += batchSize;
pos += batchSize;
token = inputTokens[i + 1];
}
context->inference->setBatchSize(1);
context->tokenizer->resetDecoder();
printf("\n🤖 Assistant\n");
if (inputPrompt.publicPrompt != nullptr)
printf("%s", inputPrompt.publicPrompt);
while (pos < seqLen) {
context->inference->setPosition(pos);
context->inference->setToken(0, token);
context->inference->forward();
token = context->sampler->sample(context->inference->logitsPipe);
char *piece = context->tokenizer->decode(token);
EosDetectorType eosType = eosDetector.append(token, piece);
if (eosType == NOT_EOS || eosType == EOS) {
char *delta = eosDetector.getDelta();
if (delta != nullptr) {
printf("%s", delta);
fflush(stdout);
}
eosDetector.reset();
}
pos++;
if (eosType == EOS) break;
}
deltaItems.clear();
} while (pos < seqLen);
printf("(end of context)\n");
}
int main(int argc, char **argv) {
initQuants();
initSockets();
int returnCode = EXIT_SUCCESS;
try {
AppCliArgs args = AppCliArgs::parse(argc, argv, true);
if (std::strcmp(args.mode, "inference") == 0) {
args.benchmark = true;
runInferenceApp(&args, &inference);
} else if (std::strcmp(args.mode, "perplexity") == 0)
runInferenceApp(&args, &perplexity);
else if (std::strcmp(args.mode, "chat") == 0)
runInferenceApp(&args, &chat);
else if (std::strcmp(args.mode, "worker") == 0)
runWorkerApp(&args);
else
throw std::runtime_error("Unsupported mode");
} catch (std::exception &e) {
printf("🚨 Critical error: %s\n", e.what());
returnCode = EXIT_FAILURE;
}
cleanupSockets();
return returnCode;
}

24765
src/json.hpp Normal file

File diff suppressed because it is too large Load Diff

669
src/llm.cpp Normal file
View File

@@ -0,0 +1,669 @@
#include "nn/nn-core.hpp"
#include "nn/nn-config-builder.hpp"
#include "nn/nn-cpu.hpp"
#include "nn/nn-network.hpp"
#include "mmap.hpp"
#include "llm.hpp"
#include <cerrno>
#include <stdexcept>
static const char *hiddenActToString(LlmHiddenAct act) {
if (act == HIDDEN_ACT_GELU) return "Gelu";
if (act == HIDDEN_ACT_SILU) return "Silu";
throw std::runtime_error("Unsupported hidden act");
}
static const char *ropeTypeToString(NnRopeType type) {
if (type == ROPE_LLAMA) return "Llama";
if (type == ROPE_LLAMA3_1) return "Llama3.1";
if (type == ROPE_FALCON) return "Falcon";
throw std::runtime_error("Unsupported rope type");
}
static const char *archTypeToString(LlmArchType type) {
if (type == LLAMA) return "Llama";
if (type == QWEN3) return "Qwen3";
if (type == QWEN3_MOE) return "Qwen3 MoE";
throw std::runtime_error("Unsupported architecture");
}
static float convertNormEpsilon(int value) {
if (value == 5) return 1e-05f;
if (value == 6) return 1e-06f;
throw std::runtime_error("Unsupported norm epsilon");
}
LlmHeader loadLlmHeader(const char *path, const NnUint maxSeqLen, NnFloatType syncType) {
LlmHeader header;
std::memset(&header, 0, sizeof(LlmHeader));
header.weightType = F_UNK;
header.hiddenAct = HIDDEN_ACT_SILU;
header.ropeType = ROPE_LLAMA;
header.ropeTheta = 10000.0f;
header.ropeScalingFactor = 1.0f;
header.normEpsilon = 1e-5f;
header.moeHiddenDim = 0u;
std::unique_ptr<FILE, int(*)(FILE *)> fdPtr(fopen(path, "rb"), fclose);
FILE *fd = fdPtr.get();
if (fd == NULL)
throw std::runtime_error(std::string("Cannot open model file (") + path + std::string("): ") + std::strerror(errno));
int magic;
if (fread(&magic, sizeof(int), 1, fd) != 1)
throw std::runtime_error("Cannot read magic value");
if (magic == 0xABCD00 || magic == 0xABCD01)
throw std::runtime_error("Old model format is not supported");
if (magic != 0xA00ABCD)
throw std::runtime_error("Unsupported magic number");
if (fread(&header.headerSize, sizeof(int), 1, fd) != 1)
throw std::runtime_error("Cannot read header size");
std::vector<int> bufferPtr(header.headerSize);
int *buffer = &bufferPtr[0];
if (fread(buffer, header.headerSize, 1, fd) != 1)
throw std::runtime_error("Cannot read header values");
int nKv = (header.headerSize - 2 * sizeof(int)) / sizeof(int);
for (int i = 0; i < nKv; i += 2) {
int key = buffer[i];
int value = buffer[i + 1];
if (key == VERSION) header.version = value;
else if (key == ARCH_TYPE) header.archType = (LlmArchType)value;
else if (key == DIM) header.dim = value;
else if (key == HIDDEN_DIM) header.hiddenDim = value;
else if (key == N_LAYERS) header.nLayers = value;
else if (key == N_HEADS) header.nHeads = value;
else if (key == N_KV_HEADS) header.nKvHeads = value;
else if (key == N_EXPERTS) header.nExperts = value;
else if (key == N_ACTIVE_EXPERTS) header.nActiveExperts = value;
else if (key == VOCAB_SIZE) header.vocabSize = value;
else if (key == SEQ_LEN) header.seqLen = value;
else if (key == HIDDEN_ACT) header.hiddenAct = (LlmHiddenAct)value;
else if (key == ROPE_THETA) header.ropeTheta = (float)value;
else if (key == WEIGHT_FLOAT_TYPE) header.weightType = (NnFloatType)value;
else if (key == ROPE_SCALING_FACTOR) header.ropeScalingFactor = (float)value;
else if (key == ROPE_SCALING_LOW_FREQ_FACTOR) header.ropeScalingLowFreqFactor = (float)value;
else if (key == ROPE_SCALING_HIGH_FREQ_FACTORY) header.ropeScalingHighFreqFactory = (float)value;
else if (key == ROPE_SCALING_ORIG_MAX_SEQ_LEN) header.ropeScalingOrigMaxSeqLen = value;
else if (key == ROPE_TYPE) header.ropeType = (NnRopeType)value;
else if (key == HEAD_DIM) header.headDim = value;
else if (key == NORM_EPSILON) header.normEpsilon = convertNormEpsilon(value);
else if (key == MOE_HIDDEN_DIM) header.moeHiddenDim = value;
else throw std::runtime_error("Unsupported header key");
}
if (header.weightType == F_UNK)
throw std::runtime_error("Model does not specify weight type");
header.origSeqLen = header.seqLen;
if (maxSeqLen > 0 && header.seqLen > maxSeqLen)
header.seqLen = maxSeqLen;
if (header.headDim == 0)
header.headDim = header.dim / header.nHeads;
header.qDim = header.headDim * header.nHeads;
header.kvDim = header.headDim * header.nKvHeads;
header.syncType = syncType;
header.fileSize = (NnSize)seekToEnd(fd);
if (header.archType == QWEN3 || header.archType == QWEN3_MOE)
header.ropeType = ROPE_FALCON;
return header;
}
void printLlmHeader(LlmHeader *header) {
printf("💡 Arch: %s\n", archTypeToString(header->archType));
printf("💡 HiddenAct: %s\n", hiddenActToString(header->hiddenAct));
printf("💡 Dim: %u\n", header->dim);
printf("💡 HeadDim: %u\n", header->headDim);
printf("💡 QDim: %u\n", header->qDim);
printf("💡 KvDim: %u\n", header->kvDim);
printf("💡 HiddenDim: %u\n", header->hiddenDim);
printf("💡 VocabSize: %u\n", header->vocabSize);
printf("💡 nLayers: %u\n", header->nLayers);
printf("💡 nHeads: %u\n", header->nHeads);
printf("💡 nKvHeads: %u\n", header->nKvHeads);
if (header->seqLen != header->origSeqLen) {
printf("💡 OrigSeqLen: %u\n", header->origSeqLen);
}
if (header->nExperts > 0) {
printf("💡 nExperts: %u\n", header->nExperts);
printf("💡 nActiveExperts: %u\n", header->nActiveExperts);
printf("💡 MoeHiddenDim: %u\n", header->moeHiddenDim);
}
printf("💡 SeqLen: %u\n", header->seqLen);
printf("💡 NormEpsilon: %f\n", header->normEpsilon);
printf("💡 RopeType: %s\n", ropeTypeToString(header->ropeType));
printf("💡 RopeTheta: %.0f\n", header->ropeTheta);
if (header->ropeType == ROPE_LLAMA3_1) {
printf("💡 RopeScaling: f=%.1f, l=%.1f, h=%.1f, o=%d\n",
header->ropeScalingFactor,
header->ropeScalingLowFreqFactor,
header->ropeScalingHighFreqFactory,
header->ropeScalingOrigMaxSeqLen);
}
}
LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches) {
NnUint nExpertsOr1 = std::max(h->nExperts, 1u);
NnUint nActiveExpertsOr1 = std::max(h->nActiveExperts, 1u);
NnUint ffDim = h->hiddenDim;
if (h->archType == QWEN3_MOE)
ffDim = h->moeHiddenDim;
LlmNet n;
n.tokenEmbeddingSize = size2D(F_32, h->vocabSize, h->dim);
n.rmsNormSize = size1D(F_32, h->dim);
n.qkRmsNormSize = size1D(F_32, h->headDim);
n.moeGateSize = size2D(F_32, h->dim, h->nExperts);
NnKvCacheSlice kvCacheSlice = sliceKvCache(h->kvDim, h->seqLen, nNodes);
NnMultiHeadAttSlice multiHeadAttSlice = sliceMultiHeadAtt(h->nHeads, h->seqLen, nNodes, nBatches);
n.qSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->qDim);
n.kSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->kvDim);
n.vSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->kvDim);
n.woSlice = sliceColMatmul(h->weightType, nNodes, h->qDim, h->dim);
n.w1Slice = sliceRowMatmul(h->weightType, nNodes, h->dim, ffDim);
n.w2Slice = sliceColMatmul(h->weightType, nNodes, ffDim, h->dim);
n.w3Slice = sliceRowMatmul(h->weightType, nNodes, h->dim, ffDim);
n.wclsSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->vocabSize);
NnUint nQNormColumns = 1;
NnUint nKNormColumns = 1;
NnUint nInvBufferColumns = 1;
if (h->archType == QWEN3 || h->archType == QWEN3_MOE) {
ASSERT_EQ(n.qSlice.d0 % h->headDim, 0);
ASSERT_EQ(n.kSlice.d0 % h->headDim, 0);
nQNormColumns = n.qSlice.d0 / h->headDim;
nKNormColumns = n.kSlice.d0 / h->headDim;
nInvBufferColumns = std::max(nQNormColumns, nKNormColumns);
}
NnNetConfigBuilder netBuilder(nNodes, nBatches);
n.positionPipeIndex = netBuilder.addPipe("POS", size2D(F_32, nBatches, 1));
n.tokenPipeIndex = netBuilder.addPipe("TOK", size2D(F_32, nBatches, 1));
n.xPipeIndex = netBuilder.addPipe("X", size2D(F_32, nBatches, h->dim));
n.logitsPipeIndex = netBuilder.addPipe("LG", size2D(F_32, nBatches, h->vocabSize));
const NnUint zqPipeIndex = netBuilder.addPipe("ZQ", size2D(h->syncType, nBatches, h->dim * nNodes));
netBuilder.addPreSync(n.positionPipeIndex);
n.header = h;
n.netConfig = netBuilder.build();
n.nodeConfigs = new NnNodeConfig[nNodes];
for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) {
NnRopeSlice ropeSlice = sliceRope(h->ropeType, h->qDim, h->kvDim, h->nKvHeads, nNodes, h->seqLen, h->headDim, h->ropeTheta, nodeIndex);
NnNodeConfigBuilder nodeBuilder(nodeIndex);
const NnUint xBufferIndex = nodeBuilder.addBuffer("x", size2D(F_32, nBatches, h->dim));
const NnUint yBufferIndex = nodeBuilder.addBuffer("y", size2D(F_32, nBatches, h->dim));
const NnUint yqBufferIndex = h->syncType == F_32
? yBufferIndex
: nodeBuilder.addBuffer("q_y", size2D(h->syncType, nBatches, h->dim));
const NnUint zBufferIndex = nodeBuilder.addBuffer("z", size2D(F_32, nBatches, h->qDim));
const NnUint zqSliceBufferIndex = nodeBuilder.addBuffer("q_z_slice", size2D(h->syncType, nBatches, h->qDim / nNodes));
const NnUint qBufferIndex = nodeBuilder.addBuffer("q", size2D(F_32, nBatches, n.qSlice.d0));
const NnUint kTempBufferIndex = nodeBuilder.addBuffer("k_temp", size2D(F_32, nBatches, n.kSlice.d0));
const NnUint vTempBufferIndex = nodeBuilder.addBuffer("v_temp", size2D(F_32, nBatches, n.vSlice.d0));
const NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, nBatches, nInvBufferColumns));
const NnUint ropeCacheBufferIndex = nodeBuilder.addBuffer("rope_cache", ropeSlice.cacheSize);
const NnUint attBufferIndex = nodeBuilder.addBuffer("att", multiHeadAttSlice.attSize);
const NnUint logitsSliceBufferIndex = nodeBuilder.addBuffer("lg", size2D(F_32, nBatches, h->vocabSize / nNodes));
// not moe
const NnUint dBufferIndex = nodeBuilder.addBuffer("d", size2D(F_32, nBatches, n.w1Slice.d0));
const NnUint dqBufferIndex = h->syncType == F_32
? dBufferIndex
: nodeBuilder.addBuffer("q_d", size2D(h->syncType, nBatches, n.w1Slice.d0));
const NnUint lBufferIndex = nodeBuilder.addBuffer("l", size2D(F_32, nBatches, n.w3Slice.d0));
// moe
const NnUint moeGtBufferIndex = nodeBuilder.addBuffer("gt", size2D(F_32, nBatches, nExpertsOr1));
const NnUint moeExpertIndexesBufferIndex = nodeBuilder.addBuffer("act_exp_ix", size2D(F_32, nBatches, nActiveExpertsOr1));
const NnUint moeYBufferIndex = nodeBuilder.addBuffer("moe_y", size3D(F_32, nActiveExpertsOr1, nBatches, h->dim));
const NnUint moeYqBufferIndex = h->syncType == F_32
? moeYBufferIndex
: nodeBuilder.addBuffer("q_moe_y", size3D(h->syncType, nActiveExpertsOr1, nBatches, h->dim));
const NnUint moeDBufferIndex = nodeBuilder.addBuffer("moe_d", size3D(F_32, nActiveExpertsOr1, nBatches, n.w1Slice.d0));
const NnUint moeDQBufferIndex = h->syncType == F_32
? moeDBufferIndex
: nodeBuilder.addBuffer("q_moe_d", size3D(h->syncType, nActiveExpertsOr1, nBatches, n.w1Slice.d0));
const NnUint moeLBufferIndex = nodeBuilder.addBuffer("moe_l", size3D(F_32, nActiveExpertsOr1, nBatches, n.w3Slice.d0));
const NnUint moeSBufferIndex = nodeBuilder.addBuffer("moe_s", size3D(F_32, nActiveExpertsOr1, nBatches, 1));
NnSegmentConfigBuilder start;
if (nodeIndex == 0) {
start.addOp(
OP_EMBEDDING, "embedding", 0,
pointerBatchConfig(SRC_PIPE, n.tokenPipeIndex),
pointerBatchConfig(SRC_PIPE, n.xPipeIndex),
n.tokenEmbeddingSize,
NnEmbeddingOpConfig{});
}
start.addSync(n.xPipeIndex, SYNC_WITH_ROOT);
nodeBuilder.addSegment(start.build());
for (NnUint layerIndex = 0; layerIndex < h->nLayers; layerIndex++) {
const NnUint kBufferIndex = nodeBuilder.addBuffer("k", kvCacheSlice.keySize);
const NnUint vBufferIndex = nodeBuilder.addBuffer("v", kvCacheSlice.valueSize);
NnSegmentConfigBuilder att;
NnSegmentConfigBuilder ff;
// att
if (layerIndex == 0) {
att.addOp(
OP_CAST, "block_cast_x", layerIndex,
pointerBatchConfig(SRC_PIPE, n.xPipeIndex),
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
size0(),
NnCastOpCodeConfig{});
} else {
att.addOp(
OP_MERGE_ADD, "block_merge_add", layerIndex,
pointerBatchConfig(SRC_PIPE, zqPipeIndex),
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
size0(),
NnMergeAddOpCodeConfig{});
}
att.addOp(
OP_INV_RMS, "block_norm_pre_0", layerIndex,
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{h->normEpsilon, 1});
att.addOp(
OP_RMS_NORM, "block_norm_0", layerIndex,
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
n.rmsNormSize,
NnRmsNormOpConfig{invRmsBufferIndex, 1});
if (yBufferIndex != yqBufferIndex) {
att.addOp(
OP_CAST, "block_cast_y", layerIndex,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
size0(),
NnCastOpCodeConfig{});
}
att.addOp(
OP_MATMUL, "block_matmul_q", layerIndex,
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
size2D(h->weightType, n.qSlice.n, n.qSlice.d0),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
att.addOp(
OP_MATMUL, "block_matmul_k", layerIndex,
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
size2D(h->weightType, n.kSlice.n, n.kSlice.d0),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
att.addOp(
OP_MATMUL, "block_matmul_v", layerIndex,
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
pointerBatchConfig(SRC_BUFFER, vTempBufferIndex),
size2D(h->weightType, n.vSlice.n, n.vSlice.d0),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
if (h->archType == QWEN3 || h->archType == QWEN3_MOE) {
att.addOp(OP_INV_RMS, "block_norm_pre_q", layerIndex,
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{h->normEpsilon, nQNormColumns});
att.addOp(
OP_RMS_NORM, "block_norm_q", layerIndex,
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
size2D(F_32, 1, n.header->headDim),
NnRmsNormOpConfig{invRmsBufferIndex, nQNormColumns});
att.addOp(OP_INV_RMS, "block_norm_pre_k", layerIndex,
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{h->normEpsilon, nKNormColumns});
att.addOp(
OP_RMS_NORM, "block_norm_k", layerIndex,
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
size2D(F_32, 1, n.header->headDim),
NnRmsNormOpConfig{invRmsBufferIndex, nKNormColumns});
}
att.addOp(
OP_ROPE, "block_rope_q", layerIndex,
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
size0(),
NnRopeOpConfig{n.header->ropeType, 1, n.positionPipeIndex, ropeCacheBufferIndex,
h->ropeScalingFactor, h->ropeScalingLowFreqFactor, h->ropeScalingHighFreqFactory, h->ropeScalingOrigMaxSeqLen,
ropeSlice});
att.addOp(
OP_ROPE, "block_rope_k", layerIndex,
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
size0(),
NnRopeOpConfig{n.header->ropeType, 0, n.positionPipeIndex, ropeCacheBufferIndex,
h->ropeScalingFactor, h->ropeScalingLowFreqFactor, h->ropeScalingHighFreqFactory, h->ropeScalingOrigMaxSeqLen,
ropeSlice});
att.addOp(
OP_SHIFT, "block_shift_k", layerIndex,
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
pointerRawConfig(SRC_BUFFER, kBufferIndex),
size0(),
NnShiftOpCodeConfig{n.positionPipeIndex});
att.addOp(
OP_SHIFT, "block_shift_v", layerIndex,
pointerBatchConfig(SRC_BUFFER, vTempBufferIndex),
pointerRawConfig(SRC_BUFFER, vBufferIndex),
size0(),
NnShiftOpCodeConfig{n.positionPipeIndex});
att.addOp(
OP_MULTIHEAD_ATT, "block_multihead_att", layerIndex,
pointerBatchedSliceConfig(SRC_BUFFER, zBufferIndex),
pointerBatchedSliceConfig(SRC_BUFFER, zBufferIndex),
size0(),
NnMultiHeadAttOpConfig{
multiHeadAttSlice.nHeads, multiHeadAttSlice.nHeads0,
h->nKvHeads, h->headDim, h->seqLen, n.qSlice.d0, kvCacheSlice.kvDim0,
n.positionPipeIndex, qBufferIndex, kBufferIndex, vBufferIndex, attBufferIndex});
att.addOp(
OP_CAST, "block_cast_y2", layerIndex,
pointerBatchedSliceConfig(SRC_BUFFER, zBufferIndex),
pointerBatchConfig(SRC_BUFFER, zqSliceBufferIndex),
size0(),
NnCastOpCodeConfig{});
att.addOp(
OP_MATMUL, "block_matmul_wo", layerIndex,
pointerBatchConfig(SRC_BUFFER, zqSliceBufferIndex),
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
size2D(h->weightType, n.woSlice.n0, n.woSlice.d),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
att.addOp(
OP_CAST, "block_cast_d", layerIndex,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchedSliceConfig(SRC_PIPE, zqPipeIndex),
size0(),
NnCastOpCodeConfig{});
att.addSync(zqPipeIndex, SYNC_NODE_SLICES);
// ff
ff.addOp(
OP_MERGE_ADD, "block_merge_add2", layerIndex,
pointerBatchConfig(SRC_PIPE, zqPipeIndex),
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
size0(),
NnMergeAddOpCodeConfig{});
ff.addOp(
OP_INV_RMS, "block_norm_pre_1", layerIndex,
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{h->normEpsilon, 1});
ff.addOp(
OP_RMS_NORM, "block_norm_1", layerIndex,
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
n.rmsNormSize,
NnRmsNormOpConfig{invRmsBufferIndex, 1});
if (h->archType == QWEN3_MOE) {
ff.addOp(
OP_REPEAT_Z, "block_moe_y_repeat", layerIndex,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeYqBufferIndex),
size0(),
NnRepeatZOpCodeConfig{});
ff.addOp(
OP_MATMUL, "block_moe_gate", layerIndex,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
n.moeGateSize,
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
ff.addOp(
OP_SOFTMAX, "block_moe_softmax", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
size0(),
NnSoftmaxOpCodeConfig{});
ff.addOp(
OP_MOE_GATE, "block_moe_gate2", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeSBufferIndex),
size0(),
NnMoeGateOpCodeConfig{h->nActiveExperts, 1u, moeExpertIndexesBufferIndex});
ff.addOp(
OP_MATMUL, "block_matmul_w1", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeYqBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
size3D(h->weightType, h->nExperts, n.w1Slice.n, n.w1Slice.d0),
NnMatmulOpConfig{h->nExperts, h->nActiveExperts, moeExpertIndexesBufferIndex});
ff.addOp(
OP_MATMUL, "block_matmul_w3", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeYqBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeLBufferIndex),
size3D(h->weightType, h->nExperts, n.w3Slice.n, n.w3Slice.d0),
NnMatmulOpConfig{h->nExperts, h->nActiveExperts, moeExpertIndexesBufferIndex});
ff.addOp(
OP_SILU, "block_act", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
size0(),
NnSiluOpCodeConfig{});
ff.addOp(
OP_MUL, "block_mul", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
size0(),
NnMulOpCodeConfig{moeLBufferIndex});
if (moeDBufferIndex != moeDQBufferIndex) {
ff.addOp(
OP_CAST, "block_cast_d2", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeDQBufferIndex),
size0(),
NnCastOpCodeConfig{});
}
ff.addOp(
OP_MATMUL, "block_matmul_w2", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeDQBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
size3D(h->weightType, h->nExperts, n.w2Slice.n0, n.w2Slice.d),
NnMatmulOpConfig{h->nExperts, h->nActiveExperts, moeExpertIndexesBufferIndex});
ff.addOp(
OP_SCALE, "block_moe_scale", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
size0(),
NnScaleOpCodeConfig{moeSBufferIndex});
ff.addOp(
OP_MERGE_SUM, "block_moe_merge_sum", layerIndex,
pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
size0(),
NnMergeSumOpCodeConfig{});
} else {
if (yBufferIndex != yqBufferIndex) {
ff.addOp(
OP_CAST, "block_cast_y3", layerIndex,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
size0(),
NnCastOpCodeConfig{});
}
ff.addOp(
OP_MATMUL, "block_matmul_w1", layerIndex,
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
pointerBatchConfig(SRC_BUFFER, dBufferIndex),
size2D(h->weightType, n.w1Slice.n, n.w1Slice.d0),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
ff.addOp(
OP_MATMUL, "block_matmul_w3", layerIndex,
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
pointerBatchConfig(SRC_BUFFER, lBufferIndex),
size2D(h->weightType, n.w3Slice.n, n.w3Slice.d0),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
ff.addOp(
OP_SILU, "block_act", layerIndex,
pointerBatchConfig(SRC_BUFFER, dBufferIndex),
pointerBatchConfig(SRC_BUFFER, dBufferIndex),
size0(),
NnSiluOpCodeConfig{});
ff.addOp(
OP_MUL, "block_mul", layerIndex,
pointerBatchConfig(SRC_BUFFER, dBufferIndex),
pointerBatchConfig(SRC_BUFFER, dBufferIndex),
size0(),
NnMulOpCodeConfig{lBufferIndex});
if (dBufferIndex != dqBufferIndex) {
ff.addOp(
OP_CAST, "block_cast_d2", layerIndex,
pointerBatchConfig(SRC_BUFFER, dBufferIndex),
pointerBatchConfig(SRC_BUFFER, dqBufferIndex),
size0(),
NnCastOpCodeConfig{});
}
ff.addOp(
OP_MATMUL, "block_matmul_w2", layerIndex,
pointerBatchConfig(SRC_BUFFER, dqBufferIndex),
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
size2D(h->weightType, n.w2Slice.n0, n.w2Slice.d),
NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
}
ff.addOp(
OP_CAST, "block_cast_d3", layerIndex,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchedSliceConfig(SRC_PIPE, zqPipeIndex),
size0(),
NnCastOpCodeConfig{});
ff.addSync(zqPipeIndex, SYNC_NODE_SLICES);
nodeBuilder.addSegment(att.build());
nodeBuilder.addSegment(ff.build());
}
NnSegmentConfigBuilder end;
end.addOp(
OP_MERGE_ADD, "final_merge_add", 0,
pointerBatchConfig(SRC_PIPE, zqPipeIndex),
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
size0(),
NnMergeAddOpCodeConfig{});
end.addOp(
OP_INV_RMS, "final_norm_pre", 0,
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{h->normEpsilon, 1});
end.addOp(
OP_RMS_NORM, "final_norm", 0,
pointerBatchConfig(SRC_BUFFER, xBufferIndex),
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
n.rmsNormSize,
NnRmsNormOpConfig{invRmsBufferIndex, 1});
if (yBufferIndex != yqBufferIndex) {
end.addOp(
OP_CAST, "final_cast_y", 0,
pointerBatchConfig(SRC_BUFFER, yBufferIndex),
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
size0(),
NnCastOpCodeConfig{});
}
end.addOp(
OP_MATMUL, "final_matmul_logits", 0,
pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
pointerBatchConfig(SRC_BUFFER, logitsSliceBufferIndex),
size2D(h->weightType, n.wclsSlice.n, n.wclsSlice.d0),
NnMatmulOpConfig{});
end.addOp(
OP_CAST, "final_cast_logits", 0,
pointerBatchConfig(SRC_BUFFER, logitsSliceBufferIndex),
pointerBatchedSliceConfig(SRC_PIPE, n.logitsPipeIndex),
size0(),
NnCastOpCodeConfig{});
end.addSync(n.logitsPipeIndex, SYNC_NODE_SLICES_EXCEPT_ROOT);
nodeBuilder.addSegment(end.build());
n.nodeConfigs[nodeIndex] = nodeBuilder.build();
}
return n;
}
void releaseLlmNet(LlmNet *net) {
for (NnUint nodeIndex = 0u; nodeIndex < net->netConfig.nNodes; nodeIndex++)
releaseNodeConfig(&net->nodeConfigs[nodeIndex]);
releaseNetConfig(&net->netConfig);
delete[] net->nodeConfigs;
}
void loadLlmNetWeight(const char *path, LlmNet *net, NnRootWeightLoader *loader) {
MmapFile file;
openMmapFile(&file, path, net->header->fileSize);
#if DEBUG_USE_MMAP_FOR_WEIGHTS
assert(net->netConfig.nNodes == 1u);
#else
std::unique_ptr<MmapFile, void(*)(MmapFile *)> fdPtr(&file, closeMmapFile);
printf("💿 Loading weights...\n");
#endif
Timer timer;
NnByte *data = (NnByte *)file.data;
NnByte *b = &data[net->header->headerSize];
b += loader->loadRoot("embedding", 0, net->tokenEmbeddingSize.nBytes, b);
for (NnUint layerIndex = 0u; layerIndex < net->header->nLayers; layerIndex++) {
b += loader->loadRowMatmulSlices("block_matmul_q", layerIndex, 0u, &net->qSlice, b);
b += loader->loadRowMatmulSlices("block_matmul_k", layerIndex, 0u, &net->kSlice, b);
b += loader->loadRowMatmulSlices("block_matmul_v", layerIndex, 0u, &net->vSlice, b);
b += loader->loadColMatmulSlices("block_matmul_wo", layerIndex, 0u, &net->woSlice, b);
if (net->header->nExperts > 0u) {
b += loader->loadAll("block_moe_gate", layerIndex, net->moeGateSize.nBytes, b);
for (NnUint expertIndex = 0u; expertIndex < net->header->nExperts; expertIndex++) {
b += loader->loadRowMatmulSlices("block_matmul_w1", layerIndex, expertIndex, &net->w1Slice, b);
b += loader->loadColMatmulSlices("block_matmul_w2", layerIndex, expertIndex, &net->w2Slice, b);
b += loader->loadRowMatmulSlices("block_matmul_w3", layerIndex, expertIndex, &net->w3Slice, b);
}
} else {
b += loader->loadRowMatmulSlices("block_matmul_w1", layerIndex, 0u, &net->w1Slice, b);
b += loader->loadColMatmulSlices("block_matmul_w2", layerIndex, 0u, &net->w2Slice, b);
b += loader->loadRowMatmulSlices("block_matmul_w3", layerIndex, 0u, &net->w3Slice, b);
}
if (net->header->archType == QWEN3 || net->header->archType == QWEN3_MOE) {
b += loader->loadAll("block_norm_q", layerIndex, net->qkRmsNormSize.nBytes, b);
b += loader->loadAll("block_norm_k", layerIndex, net->qkRmsNormSize.nBytes, b);
}
b += loader->loadAll("block_norm_0", layerIndex, net->rmsNormSize.nBytes, b);
b += loader->loadAll("block_norm_1", layerIndex, net->rmsNormSize.nBytes, b);
if (timer.elapsedMiliseconds() > 10000)
printf("💿 Loaded %u/%u\n", layerIndex + 1, net->header->nLayers);
}
b += loader->loadAll("final_norm", 0u, net->rmsNormSize.nBytes, b);
b += loader->loadRowMatmulSlices("final_matmul_logits", 0u, 0u, &net->wclsSlice, b);
long long missingBytes = (long long)(b - data) - net->header->fileSize;
if (missingBytes != 0u)
throw std::runtime_error("Missing bytes in weight file: " + std::to_string(missingBytes));
printf("💿 Weights loaded\n");
loader->finish();
}

104
src/llm.hpp Normal file
View File

@@ -0,0 +1,104 @@
#ifndef LLM_HPP
#define LLM_HPP
#include "nn/nn-core.hpp"
#include "nn/nn-executor.hpp"
#include "nn/nn-network.hpp"
enum LlmHeaderKey {
VERSION = 0,
ARCH_TYPE = 1,
DIM = 2,
HIDDEN_DIM = 3,
N_LAYERS = 4,
N_HEADS = 5,
N_KV_HEADS = 6,
N_EXPERTS = 7,
N_ACTIVE_EXPERTS = 8,
VOCAB_SIZE = 9,
SEQ_LEN = 10,
HIDDEN_ACT = 11,
ROPE_THETA = 12,
WEIGHT_FLOAT_TYPE = 13,
ROPE_SCALING_FACTOR = 14,
ROPE_SCALING_LOW_FREQ_FACTOR = 15,
ROPE_SCALING_HIGH_FREQ_FACTORY = 16,
ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17,
ROPE_TYPE = 18,
HEAD_DIM = 19,
NORM_EPSILON = 20,
MOE_HIDDEN_DIM = 21,
};
enum LlmHiddenAct {
HIDDEN_ACT_GELU,
HIDDEN_ACT_SILU,
};
enum LlmArchType {
LLAMA = 0xABCD00,
QWEN3 = 0xABCD01,
QWEN3_MOE = 0xABCD02,
};
typedef struct {
NnSize headerSize;
NnSize fileSize;
int version;
LlmArchType archType;
NnUint dim;
NnUint nLayers;
NnUint nHeads;
NnUint headDim;
NnUint nKvHeads;
NnUint nExperts;
NnUint nActiveExperts;
NnUint origSeqLen; // Original model context length
NnUint seqLen; // Limited context length by the `--max-seq-len` argument
NnUint hiddenDim;
NnUint moeHiddenDim;
LlmHiddenAct hiddenAct;
NnUint qDim;
NnUint kvDim;
NnUint vocabSize;
float ropeTheta;
NnRopeType ropeType;
float ropeScalingFactor;
float ropeScalingLowFreqFactor;
float ropeScalingHighFreqFactory;
NnUint ropeScalingOrigMaxSeqLen;
float normEpsilon;
NnFloatType weightType;
NnFloatType syncType;
} LlmHeader;
typedef struct {
LlmHeader *header;
NnNetConfig netConfig;
NnNodeConfig *nodeConfigs;
NnRowMatmulSlice qSlice;
NnRowMatmulSlice kSlice;
NnRowMatmulSlice vSlice;
NnColMatmulSlice woSlice;
NnRowMatmulSlice w1Slice;
NnColMatmulSlice w2Slice;
NnRowMatmulSlice w3Slice;
NnRowMatmulSlice wclsSlice;
NnUint positionPipeIndex;
NnUint tokenPipeIndex;
NnUint xPipeIndex;
NnUint logitsPipeIndex;
NnSize3D tokenEmbeddingSize;
NnSize3D rmsNormSize;
NnSize3D qkRmsNormSize;
NnSize3D moeGateSize;
} LlmNet;
LlmHeader loadLlmHeader(const char* path, const unsigned int maxSeqLen, NnFloatType syncType);
void printLlmHeader(LlmHeader *header);
LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches);
void releaseLlmNet(LlmNet *net);
void loadLlmNetWeight(const char* path, LlmNet *net, NnRootWeightLoader *loader);
#endif

83
src/mmap.hpp Normal file
View File

@@ -0,0 +1,83 @@
#ifndef MMAP_HPP
#define MMAP_HPP
#include <cstdio>
#include <stdexcept>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
#endif
struct MmapFile {
void* data;
size_t size;
#ifdef _WIN32
HANDLE hFile;
HANDLE hMapping;
#else
int fd;
#endif
};
long seekToEnd(FILE* file) {
#ifdef _WIN32
_fseeki64(file, 0, SEEK_END);
return _ftelli64(file);
#else
fseek(file, 0, SEEK_END);
return ftell(file);
#endif
}
void openMmapFile(MmapFile *file, const char *path, size_t size) {
file->size = size;
#ifdef _WIN32
file->hFile = CreateFileA(path, GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
if (file->hFile == INVALID_HANDLE_VALUE) {
printf("Cannot open file %s\n", path);
exit(EXIT_FAILURE);
}
file->hMapping = CreateFileMappingA(file->hFile, NULL, PAGE_READONLY, 0, 0, NULL);
if (file->hMapping == NULL) {
printf("CreateFileMappingA failed, error: %lu\n", GetLastError());
CloseHandle(file->hFile);
exit(EXIT_FAILURE);
}
file->data = (void *)MapViewOfFile(file->hMapping, FILE_MAP_READ, 0, 0, 0);
if (file->data == NULL) {
printf("MapViewOfFile failed!\n");
CloseHandle(file->hMapping);
CloseHandle(file->hFile);
exit(EXIT_FAILURE);
}
#else
file->fd = open(path, O_RDONLY);
if (file->fd == -1) {
throw std::runtime_error("Cannot open file");
}
file->data = mmap(NULL, size, PROT_READ, MAP_PRIVATE, file->fd, 0);
if (file->data == MAP_FAILED) {
close(file->fd);
throw std::runtime_error("Mmap failed");
}
#endif
}
void closeMmapFile(MmapFile *file) {
#ifdef _WIN32
UnmapViewOfFile(file->data);
CloseHandle(file->hMapping);
CloseHandle(file->hFile);
#else
munmap(file->data, file->size);
close(file->fd);
#endif
}
#endif

1010
src/nn/llamafile/sgemm.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,9 @@
#ifndef LLAMAFILE_SGEMM_H
#define LLAMAFILE_SGEMM_H
#include <cstdint>
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype);
#endif

View File

@@ -0,0 +1,137 @@
#ifndef NN_CONFIG_BUILDER_H
#define NN_CONFIG_BUILDER_H
#include "nn-core.hpp"
#include <cassert>
#include <cstring>
static char *cloneString(const char *str) {
NnUint len = std::strlen(str);
char *copy = new char[len + 1];
std::memcpy(copy, str, len + 1);
return copy;
}
class NnNetConfigBuilder {
public:
NnUint nNodes;
NnUint nBatches;
std::list<NnPipeConfig> pipes;
std::list<NnPreSyncConfig> preSyncs;
NnNetConfigBuilder(NnUint nNodes, NnUint nBatches) {
this->nNodes = nNodes;
this->nBatches = nBatches;
}
NnUint addPipe(const char *name, NnSize3D size) {
NnUint pipeIndex = pipes.size();
pipes.push_back({ cloneString(name), size });
return pipeIndex;
}
void addPreSync(NnUint pipeIndex) {
preSyncs.push_back({ pipeIndex });
}
NnNetConfig build() {
NnNetConfig config;
config.nNodes = nNodes;
config.nBatches = nBatches;
config.nPipes = pipes.size();
config.pipes = new NnPipeConfig[config.nPipes];
std::copy(pipes.begin(), pipes.end(), config.pipes);
config.nPreSyncs = preSyncs.size();
if (config.nPreSyncs > 0) {
config.preSyncs = new NnPreSyncConfig[config.nPreSyncs];
std::copy(preSyncs.begin(), preSyncs.end(), config.preSyncs);
} else {
config.preSyncs = nullptr;
}
return config;
}
};
class NnNodeConfigBuilder {
public:
NnUint nodeIndex;
std::list<NnBufferConfig> buffers;
std::list<NnSegmentConfig> segments;
NnNodeConfigBuilder(NnUint nodeIndex) {
this->nodeIndex = nodeIndex;
}
NnUint addBuffer(const char *name, NnSize3D size) {
NnUint bufferIndex = buffers.size();
buffers.push_back({ cloneString(name), size });
return bufferIndex;
}
void addSegment(NnSegmentConfig segment) {
segments.push_back(segment);
}
NnNodeConfig build() {
NnNodeConfig config;
config.nodeIndex = nodeIndex;
config.nBuffers = buffers.size();
if (config.nBuffers > 0) {
config.buffers = new NnBufferConfig[config.nBuffers];
std::copy(buffers.begin(), buffers.end(), config.buffers);
} else {
config.buffers = nullptr;
}
config.nSegments = segments.size();
assert(config.nSegments > 0);
config.segments = new NnSegmentConfig[config.nSegments];
std::copy(segments.begin(), segments.end(), config.segments);
return config;
}
};
class NnSegmentConfigBuilder {
private:
std::list<NnOpConfig> ops;
std::list<NnSyncConfig> syncs;
public:
template <typename T>
void addOp(NnOpCode code, const char *name, NnUint index, NnPointerConfig input, NnPointerConfig output, NnSize3D weightSize, T config) {
NnUint configSize = sizeof(T);
NnByte *configCopy = new NnByte[configSize];
std::memcpy(configCopy, &config, configSize);
ops.push_back({
code,
cloneString(name),
index,
input,
output,
weightSize,
configCopy,
configSize
});
};
void addSync(NnUint pipeIndex, NnSyncType syncType) {
syncs.push_back({ pipeIndex, syncType });
}
NnSegmentConfig build() {
NnSegmentConfig segment;
segment.nOps = ops.size();
if (segment.nOps > 0) {
segment.ops = new NnOpConfig[segment.nOps];
std::copy(ops.begin(), ops.end(), segment.ops);
}
segment.nSyncs = syncs.size();
if (segment.nSyncs > 0) {
segment.syncs = new NnSyncConfig[segment.nSyncs];
std::copy(syncs.begin(), syncs.end(), segment.syncs);
}
return segment;
}
};
#endif

381
src/nn/nn-core.cpp Normal file
View File

@@ -0,0 +1,381 @@
#ifdef _WIN32
#define _USE_MATH_DEFINES
#endif
#include "nn-core.hpp"
#include <cassert>
#include <cstring>
#include <cmath>
#include <stdexcept>
// utility functions
NnSize getBytes(NnFloatType floatType, NnSize n) {
if (floatType == F_32)
return n * sizeof(float);
if (floatType == F_16)
return n * (sizeof(float) / 2);
if (floatType == F_Q40) {
assert(n % Q40_BLOCK_SIZE == 0);
return (n / Q40_BLOCK_SIZE) * sizeof(NnBlockQ40);
}
if (floatType == F_Q80) {
assert(n % Q80_BLOCK_SIZE == 0);
return (n / Q80_BLOCK_SIZE) * sizeof(NnBlockQ80);
}
throw std::invalid_argument("Unsupported float type: " + std::to_string(floatType));
}
NnSize getBlockSize(NnFloatType floatType) {
if (floatType == F_32)
return 1;
if (floatType == F_16)
return 1;
if (floatType == F_Q40)
return Q40_BLOCK_SIZE;
if (floatType == F_Q80)
return Q80_BLOCK_SIZE;
throw std::invalid_argument("Unsupported float type");
}
NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output) {
// If weight=F_UNK, then returned enum should be <input>_<input>_<output>
if (input == F_32 && output == F_32) {
if (weight == F_UNK || weight == F_32)
return F32_F32_F32;
if (weight == F_Q40)
return F32_Q40_F32;
}
if (input == F_32 && output == F_Q80) {
if (weight == F_UNK || weight == F_32)
return F32_F32_Q80;
if (weight == F_Q40)
return F32_Q40_Q80;
}
if (input == F_Q80 && output == F_32) {
if (weight == F_UNK || weight == F_Q80)
return Q80_Q80_F32;
if (weight == F_32)
return Q80_F32_F32;
if (weight == F_Q40)
return Q80_Q40_F32;
}
if (input == F_Q80 && output == F_Q80) {
if (weight == F_UNK || weight == F_Q80)
return Q80_Q80_Q80;
}
throw std::invalid_argument("Unsupported op quant: " +
std::string(floatTypeToString(input)) + "/" +
std::string(floatTypeToString(weight)) + "/" +
std::string(floatTypeToString(output)));
}
const char *opCodeToString(NnOpCode code) {
if (code == OP_MERGE_ADD) return "MERGE_ADD";
if (code == OP_MERGE_SUM) return "MERGE_SUM";
if (code == OP_EMBEDDING) return "EMBEDDING";
if (code == OP_INV_RMS) return "INV_RMS";
if (code == OP_RMS_NORM) return "RMS_NORM";
if (code == OP_MATMUL) return "MATMUL";
if (code == OP_ROPE) return "ROPE";
if (code == OP_MULTIHEAD_ATT) return "MULTIHEAD_ATT";
if (code == OP_GELU) return "GELU";
if (code == OP_SILU) return "SILU";
if (code == OP_MUL) return "MUL";
if (code == OP_SCALE) return "SCALE";
if (code == OP_CAST) return "CAST";
if (code == OP_REPEAT_Z) return "REPEAT_Z";
if (code == OP_SHIFT) return "SHIFT";
if (code == OP_SOFTMAX) return "SOFTMAX";
if (code == OP_MOE_GATE) return "MOE_GATE";
throw std::invalid_argument("Unknown op code: " + std::to_string(code));
}
const char *opQuantTypeToString(NnOpQuantType type) {
if (type == F32_F32_F32) return "F32_F32_F32";
if (type == F32_Q40_F32) return "F32_Q40_F32";
if (type == F32_Q40_Q80) return "F32_Q40_Q80";
if (type == F32_F32_Q80) return "F32_F32_Q80";
if (type == Q80_Q80_Q80) return "Q80_Q80_Q80";
if (type == Q80_Q80_F32) return "Q80_Q80_F32";
if (type == Q80_Q40_F32) return "Q80_Q40_F32";
if (type == Q80_F32_F32) return "Q80_F32_F32";
throw std::invalid_argument("Unknown op quant type");
}
NnSize3D size0() {
return { F_UNK, 0, 0, 0, 0, 0 };
}
NnSize3D size1D(NnFloatType floatType, NnUint x) {
return size3D(floatType, 1, 1, x);
}
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x) {
return size3D(floatType, 1, y, x);
}
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x) {
NnSize len = z * y * x;
NnSize lenXY = y * x;
return { floatType, z, y, x, len, getBytes(floatType, len), getBytes(floatType, lenXY) };
}
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index) {
return { source, index, PNTR_BATCH };
}
NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index) {
return { source, index, PNTR_BATCHED_SLICE };
}
NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index) {
return { source, index, PNTR_RAW };
}
bool hasPointerContinuousMemory(NnPointerConfig *config) {
if (config->type == PNTR_RAW)
return true;
if (config->type == PNTR_BATCH)
return true;
return false;
}
void releaseNetConfig(NnNetConfig *netConfig) {
for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) {
delete[] netConfig->pipes[pipeIndex].name;
}
delete[] netConfig->pipes;
}
void releaseNodeConfig(NnNodeConfig *nodeConfig) {
for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex];
if (segment->nOps > 0) {
for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) {
NnOpConfig *op = &segment->ops[opIndex];
delete[] op->name;
delete[] op->config;
}
delete[] segment->ops;
}
if (segment->nSyncs > 0)
delete[] segment->syncs;
}
if (nodeConfig->nBuffers > 0) {
for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++)
delete[] nodeConfig->buffers[bufferIndex].name;
delete[] nodeConfig->buffers;
}
delete[] nodeConfig->segments;
}
void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
unsigned long total = 0;
for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++)
total += netConfig->pipes[pipeIndex].size.nBytes;
for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++)
total += nodeConfig->buffers[bufferIndex].size.nBytes;
for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex];
for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) {
total += segment->ops[opIndex].weightSize.nBytes;
total += segment->ops[opIndex].configSize;
}
}
printf("📀 RequiredMemory: %lu MB\n", total / (1024 * 1024));
}
Timer::Timer() {
reset();
}
void Timer::reset() {
startTime = std::chrono::high_resolution_clock::now();
}
NnUint Timer::elapsedMiliseconds() {
auto endTime = std::chrono::high_resolution_clock::now();
return (NnUint)std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
}
NnUint Timer::elapsedMicroseconds() {
auto endTime = std::chrono::high_resolution_clock::now();
return (NnUint)std::chrono::duration_cast<std::chrono::microseconds>(endTime - startTime).count();
}
// slicers
NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes) {
NnKvCacheSlice s;
assert(kvDim % nNodes == 0);
s.kvDim0 = kvDim / nNodes;
s.keySize = size2D(F_32, seqLen, s.kvDim0);
s.valueSize = size2D(F_32, seqLen, s.kvDim0);
return s;
}
NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) {
NnRowMatmulSlice s;
assert(d % nNodes == 0);
s.type = type;
s.nNodes = nNodes;
s.d0 = d / nNodes;
s.n = n;
s.size = size2D(type, s.n, d);
s.sliceSize = size2D(type, s.n, s.d0);
return s;
}
NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) {
NnColMatmulSlice s;
assert(n % nNodes == 0);
s.type = type;
s.nNodes = nNodes;
s.n = n;
s.n0 = n / nNodes;
s.d = d;
s.size = size2D(type, n, d);
s.sliceSize = size2D(type, s.n0, d);
return s;
}
NnRopeSlice sliceRope(NnRopeType type, NnUint qDim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headDim, float ropeTheta, NnUint nodeIndex) {
NnRopeSlice s;
assert(qDim >= kvDim);
assert(qDim % nNodes == 0);
assert(kvDim % nNodes == 0);
s.kvDim = kvDim;
s.nKvHeads = nKvHeads;
s.seqLen = seqLen;
s.headDim = headDim;
s.ropeTheta = ropeTheta;
s.qDim0 = qDim / nNodes;
s.kvDim0 = kvDim / nNodes;
assert(s.qDim0 % 2 == 0);
assert(s.kvDim0 % 2 == 0);
if (type == ROPE_LLAMA || type == ROPE_LLAMA3_1) {
s.kvDimStart = s.kvDim0 * nodeIndex;
s.qDimStart = s.qDim0 * nodeIndex;
s.qDimEnd = s.qDimStart + s.qDim0;
s.qShift = s.qDimStart - s.kvDimStart;
s.sliceDim = s.qDimEnd - s.kvDimStart;
assert(s.sliceDim % 2 == 0);
s.cacheSize = size2D(F_32, seqLen, s.sliceDim);
} else if (type == ROPE_FALCON) {
s.cacheSize = size2D(F_32, seqLen, headDim);
} else {
throw std::invalid_argument("Unsupported rope type");
}
return s;
}
NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches) {
NnMultiHeadAttSlice s;
assert(nHeads % nNodes == 0);
s.nHeads = nHeads;
s.nHeads0 = nHeads / nNodes;
s.attSize = size2D(F_32, nBatches, s.nHeads0 * seqLen);
return s;
}
// splitters
NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) {
NnSize blockSize = getBlockSize(slice->type);
NnSize batchBytes = getBytes(slice->type, blockSize);
assert(slice->n % blockSize == 0);
NnSize n = slice->n / blockSize;
NnSize offset = slice->d0 * nodeIndex * n * batchBytes;
NnSize copiedBytes = 0;
for (NnUint d = 0; d < slice->d0; d++) {
for (NnUint j = 0; j < n; j++) {
NnSize o = (d * n + j) * batchBytes;
std::memcpy(weight0 + o, weight + offset + o, batchBytes);
copiedBytes += batchBytes;
}
}
return copiedBytes;
}
NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) {
NnSize blockSize = getBlockSize(slice->type);
NnSize batchBytes = getBytes(slice->type, blockSize);
assert(slice->n0 % blockSize == 0);
NnSize n = slice->n / blockSize;
NnSize rowBytes = n * batchBytes;
NnSize row0Bytes = (slice->n0 / blockSize) * batchBytes;
NnSize rowOffsetBytes = nodeIndex * row0Bytes;
NnSize copiedBytes = 0;
for (NnUint d = 0; d < slice->d; d++) {
std::memcpy(&weight0[row0Bytes * d], &weight[rowBytes * d + rowOffsetBytes], row0Bytes);
copiedBytes += row0Bytes;
}
return copiedBytes;
}
// helper
static inline float scaleFrequencyLlama3(const float freq, const NnRopeOpConfig *config) {
// https://github.com/meta-llama/llama-models/blob/4269717b2ea587627903bacbb75ccce1427ad914/models/llama3/reference_impl/model.py#L55
const float waveLen = 2.0f * M_PI / freq;
const float highFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingHighFreqFactor;
if (waveLen < highFreqWavelen) {
return freq;
}
const float lowFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingLowFreqFactor;
if (waveLen > lowFreqWavelen) {
return freq / config->ropeScalingFactor;
}
const float smooth = (config->ropeScalingOrigMaxSeqLen / waveLen - config->ropeScalingLowFreqFactor) /
(config->ropeScalingHighFreqFactor - config->ropeScalingLowFreqFactor);
return (1 - smooth) * freq / config->ropeScalingFactor + smooth * freq;
}
static inline void fullfillRopeLlamaCache(const NnRopeOpConfig *config, float *cache) {
assert((config->slice.qDimEnd - config->slice.kvDimStart) % 2 == 0);
const bool applyScaling = config->ropeScalingFactor != 1.0f;
for (NnUint pos = 0; pos < config->slice.seqLen; pos++) {
for (NnUint i = config->slice.kvDimStart; i < config->slice.qDimEnd; i += 2) {
const NnUint h = i % config->slice.headDim;
float freq = 1.0f / powf(config->slice.ropeTheta, h / (float)config->slice.headDim);
if (applyScaling)
freq = scaleFrequencyLlama3(freq, config);
const float val = pos * freq;
const float fcr = cosf(val);
const float fci = sinf(val);
cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart)] = fcr;
cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart) + 1] = fci;
}
}
}
static inline void fullfillRopeFalconCache(const NnRopeOpConfig *config, float *cache) {
const float hs = (float)config->slice.headDim;
for (NnUint pos = 0; pos < config->slice.seqLen; pos++) {
for (NnUint j = 0; j < config->slice.headDim / 2; j++) {
const float freq = 1.0f / powf(config->slice.ropeTheta, 2.0f * (float)(j / hs));
const float val = pos * freq;
const float fcr = cosf(val);
const float fci = sinf(val);
cache[pos * config->slice.headDim + j] = fcr;
cache[pos * config->slice.headDim + j + config->slice.headDim / 2] = fci;
}
}
}
void fullfillRopeCache(const NnRopeOpConfig *config, float *cache) {
if (config->type == ROPE_LLAMA || config->type == ROPE_LLAMA3_1)
fullfillRopeLlamaCache(config, cache);
else if (config->type == ROPE_FALCON)
fullfillRopeFalconCache(config, cache);
else
throw std::invalid_argument("Unsupported rope type");
}

333
src/nn/nn-core.hpp Normal file
View File

@@ -0,0 +1,333 @@
#ifndef NN_CORE_H
#define NN_CORE_H
#include <chrono>
#include <list>
#include <memory>
#include <cstdint>
#include "nn-quants.hpp"
// primitives
typedef struct {
NnFloatType floatType;
NnUint z;
NnUint y;
NnUint x;
NnSize length;
NnSize nBytes;
NnSize nBytesXY;
} NnSize3D;
// slices
typedef struct {
NnUint kvDim0;
NnSize3D keySize;
NnSize3D valueSize;
} NnKvCacheSlice;
typedef struct {
NnFloatType type;
NnUint nNodes;
NnUint d0;
NnUint n;
NnSize3D size;
NnSize3D sliceSize;
} NnRowMatmulSlice;
typedef struct {
NnFloatType type;
NnUint nNodes;
NnUint n;
NnUint n0;
NnUint d;
NnSize3D size;
NnSize3D sliceSize;
} NnColMatmulSlice;
typedef struct {
NnUint qDim0;
NnUint qDimStart;
NnUint qDimEnd;
NnUint qShift;
NnUint kvDim;
NnUint kvDim0;
NnUint kvDimStart;
NnUint sliceDim;
NnUint seqLen;
NnUint headDim;
NnUint nKvHeads;
float ropeTheta;
NnSize3D cacheSize;
} NnRopeSlice;
typedef struct {
NnUint nHeads;
NnUint nHeads0;
NnSize3D attSize;
} NnMultiHeadAttSlice;
// base enums
enum NnOpCode {
OP_MERGE_ADD,
OP_MERGE_SUM,
OP_EMBEDDING,
OP_INV_RMS,
OP_RMS_NORM,
OP_MATMUL,
OP_ROPE,
OP_MULTIHEAD_ATT,
OP_GELU,
OP_SILU,
OP_MUL,
OP_SCALE,
OP_CAST,
OP_REPEAT_Z,
OP_SHIFT,
OP_SOFTMAX,
OP_MOE_GATE,
};
enum NnOpQuantType {
// <input>_<weight>_<output>
F32_F32_F32,
F32_Q40_F32,
F32_Q40_Q80,
F32_F32_Q80,
Q80_Q80_Q80,
Q80_Q80_F32,
Q80_Q40_F32,
Q80_F32_F32,
};
#define N_OP_CODES (OP_SHIFT + 1)
#define N_OP_QUANTS (Q80_F32_F32 + 1)
enum NnPointerSource {
SRC_PIPE,
SRC_BUFFER,
};
enum NnPointerType {
PNTR_RAW,
PNTR_BATCH,
PNTR_BATCHED_SLICE
};
enum NnSyncType {
SYNC_WITH_ROOT, // whole pipe to all nodes
SYNC_NODE_SLICES, // my slice of pipe to all nodes
SYNC_NODE_SLICES_EXCEPT_ROOT, // only workers send slices to root, root does not send
};
enum NnRopeType {
ROPE_LLAMA = 0,
ROPE_FALCON = 1,
ROPE_LLAMA3_1 = 2,
};
// base configs
typedef struct {
char *name;
NnSize3D size;
} NnPipeConfig;
typedef struct {
char *name;
NnSize3D size;
} NnBufferConfig;
typedef struct {
NnPointerSource source;
NnUint pointerIndex;
NnPointerType type;
} NnPointerConfig;
typedef struct {
NnOpCode code;
char *name;
NnUint index;
NnPointerConfig input;
NnPointerConfig output;
NnSize3D weightSize;
NnByte *config;
NnUint configSize;
} NnOpConfig;
typedef struct {
NnUint pipeIndex;
} NnPreSyncConfig;
typedef struct {
NnUint pipeIndex;
NnSyncType syncType;
} NnSyncConfig;
typedef struct {
NnUint nOps;
NnOpConfig *ops;
NnUint nSyncs;
NnSyncConfig *syncs;
} NnSegmentConfig;
typedef struct {
NnUint nBatches;
NnUint nNodes;
NnUint nPipes;
NnPipeConfig *pipes;
NnUint nPreSyncs;
NnPreSyncConfig *preSyncs;
} NnNetConfig;
typedef struct {
NnUint nodeIndex;
NnUint nBuffers;
NnBufferConfig *buffers;
NnUint nSegments;
NnSegmentConfig *segments;
} NnNodeConfig;
// op configs
typedef struct {
// empty
} NnEmbeddingOpConfig;
typedef struct {
float epsilon;
NnUint nColumns;
} NnInvRmsOpConfig;
typedef struct {
NnUint invRmsBufferIndex;
NnUint nColumns;
} NnRmsNormOpConfig;
typedef struct {
NnUint nExperts;
NnUint nActiveExperts;
NnUint activeExpertIndexesBufferIndex;
} NnMatmulOpConfig;
typedef struct {
NnRopeType type;
NnUint isQ; // Cannot use `bool` here due to GPU memory alignment
NnUint positionPipeIndex;
NnUint ropeCacheBufferIndex;
float ropeScalingFactor;
float ropeScalingLowFreqFactor;
float ropeScalingHighFreqFactor;
NnUint ropeScalingOrigMaxSeqLen;
NnRopeSlice slice;
} NnRopeOpConfig;
typedef struct {
NnUint nHeads;
NnUint nHeads0;
NnUint nKvHeads;
NnUint headDim;
NnUint seqLen;
NnUint qSliceD0;
NnUint kvDim0;
NnUint positionPipeIndex;
NnUint queryBufferIndex;
NnUint keyCacheBufferIndex;
NnUint valueCacheBufferIndex;
NnUint attBufferIndex;
} NnMultiHeadAttOpConfig;
typedef struct {
// empty
} NnMergeAddOpCodeConfig;
typedef struct {
// empty
} NnMergeSumOpCodeConfig;
typedef struct {
// empty
} NnSiluOpCodeConfig;
typedef struct {
NnUint multiplierBufferIndex;
} NnMulOpCodeConfig;
typedef struct {
NnUint scaleBufferIndex;
} NnScaleOpCodeConfig;
typedef struct {
// empty
} NnCastOpCodeConfig;
typedef struct {
// empty
} NnRepeatZOpCodeConfig;
typedef struct {
NnUint indexPipeIndex;
} NnShiftOpCodeConfig;
typedef struct {
// empty
} NnSoftmaxOpCodeConfig;
typedef struct {
NnUint k;
NnUint normTopk;
NnUint indexesBufferIndex;
} NnMoeGateOpCodeConfig;
// utility functions
const char *opCodeToString(NnOpCode code);
const char *opQuantTypeToString(NnOpQuantType type);
NnSize getBytes(NnFloatType floatType, NnSize n);
NnSize getBlockSize(NnFloatType floatType);
NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output);
NnSize3D size0();
NnSize3D size1D(NnFloatType floatType, NnUint x);
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x);
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x);
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index);
NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index);
NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index);
bool hasPointerContinuousMemory(NnPointerConfig *config);
void releaseNetConfig(NnNetConfig *netConfig);
void releaseNodeConfig(NnNodeConfig *nodeConfig);
void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
class Timer {
private:
std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
public:
Timer();
void reset();
NnUint elapsedMiliseconds();
NnUint elapsedMicroseconds();
};
// slicers
NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes);
NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d);
NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d);
NnRopeSlice sliceRope(NnRopeType type, NnUint qDim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headDim, float ropeTheta, NnUint nodeIndex);
NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches);
// splitters
NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0);
NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0);
// rope
void fullfillRopeCache(const NnRopeOpConfig *config, float *cache);
#endif

374
src/nn/nn-cpu-ops-test.cpp Normal file
View File

@@ -0,0 +1,374 @@
#include "nn-cpu-ops.cpp"
#include <vector>
// framework
void printPassed(const char *name) {
printf("✅ %24s passed\n", name);
fflush(stdout);
}
void rand(float *o, const NnUint n, const NnUint seed) {
srand(seed + 123456);
for (NnUint i = 0; i < n; i++) {
float v = (float)(rand() / RAND_MAX);
o[i] = v * 2.0f - 1.0f;
}
}
void compare_F32(const char *name, const float *a, const float *b, const NnUint n, const float epsilon) {
for (NnUint i = 0; i < n; i++) {
float error = fabs(a[i] - b[i]);
if (error > epsilon) {
printf("❌ %s failed\n", name);
for (NnUint j = i; j < i + 16 && j < n; j++)
printf(" [%3d] %f != %f\n", j, a[j], b[j]);
exit(1);
}
}
printPassed(name);
}
// tests
void testSplitThreads() {
// <0; 32> across 3 threads
{
SPLIT_THREADS(a0Start, a0End, 32, 3, 0); // thread 0
assert(a0Start == 0);
assert(a0End == 11);
}
{
SPLIT_THREADS(a1Start, a1End, 32, 3, 1); // thread 1
assert(a1Start == 11);
assert(a1End == 22);
}
{
SPLIT_THREADS(a2Start, a2End, 32, 3, 2); // thread 2
assert(a2Start == 22);
assert(a2End == 32);
}
// <0; 4> across 8 threads
{
SPLIT_THREADS(b0Start, b0End, 4, 8, 0); // thread 0
assert(b0Start == 0);
assert(b0End == 1);
}
{
SPLIT_THREADS(b0Start, b0End, 4, 8, 3); // thread 3
assert(b0Start == 3);
assert(b0End == 4);
}
{
SPLIT_THREADS(b0Start, b0End, 4, 8, 4); // thread 4
assert(b0Start == 4);
assert(b0End == 4);
}
{
SPLIT_THREADS(b0Start, b0End, 4, 8, 7); // thread 7
assert(b0Start == 4);
assert(b0End == 4);
}
printPassed("splitThreads");
}
void testConvertF32toF16() {
float x[] = {0.0f, 0.25f, 0.3456f, 1.0f};
for (NnUint i = 0; i < sizeof(x) / sizeof(float); i++) {
NnFp16 f16 = CONVERT_F32_TO_F16(x[i]);
float f32 = CONVERT_F16_TO_F32(f16);
compare_F32("convertF32toF16", &x[i], &f32, 1, 0.0005);
}
}
// quantization
void testQuantization(const NnUint m) {
std::vector<float> a(m * Q40_BLOCK_SIZE);
std::vector<float> aTemp(m * Q40_BLOCK_SIZE);
std::vector<NnBlockQ40> aQ40(m);
std::vector<NnBlockQ80> aQ80(m);
rand(a.data(), m * Q40_BLOCK_SIZE, m);
quantizeF32toQ40(a.data(), aQ40.data(), m * Q40_BLOCK_SIZE, 1, 0);
dequantizeQ40toF32(aQ40.data(), aTemp.data(), m * Q40_BLOCK_SIZE, 1, 0);
compare_F32("testQuantization_Q40", a.data(), aTemp.data(), m * Q40_BLOCK_SIZE, 0.13);
quantizeF32toQ80(a.data(), aQ80.data(), m * Q80_BLOCK_SIZE, 1, 0);
dequantizeQ80toF32(aQ80.data(), aTemp.data(), m * Q80_BLOCK_SIZE, 1, 0);
compare_F32("testQuantization_Q80", a.data(), aTemp.data(), m * Q80_BLOCK_SIZE, 0.01);
}
// invRms
void testInvRms() {
const float epsilon = 0.00001;
std::vector<float> x(8);
x[0] = 0.1f;
x[1] = 0.3f;
x[2] = 0.2f;
x[3] = 0.4f;
x[4] = 0.6f;
x[5] = 0.5f;
x[6] = 0.0f;
x[7] = 0.8f;
const float y0 = invRms_F32(x.data(), 8, epsilon);
float ev0 = 1.0f / 0.4402f;
compare_F32("rms_F32", &y0, &ev0, 1, 0.001f);
}
// rmsNorm
void testRmsNorm(const NnUint m) {
std::vector<float> x(m);
std::vector<NnBlockQ80> xQ80(m / Q80_BLOCK_SIZE);
std::vector<float> w(m);
std::vector<float> y(m);
std::vector<float> yTemp(m);
rand(x.data(), m, m);
rand(w.data(), m, m * m);
quantizeF32toQ80(x.data(), xQ80.data(), m, 1, 0);
const float rms = invRms_F32(x.data(), m, 1e-5f);
rmsNorm_F32(y.data(), x.data(), rms, w.data(), m, 1, 0);
rmsNorm_Q80_F32_F32(yTemp.data(), xQ80.data(), rms, w.data(), m, 1, 0);
compare_F32("rmsNorm_Q80_F32_F32", y.data(), yTemp.data(), m, 0.01);
}
// a *= b
void testMul(const NnUint m) {
const NnUint n = Q80_BLOCK_SIZE * m;
std::vector<float> a0(n);
std::vector<float> b0(n);
std::vector<float> aQ(n);
std::vector<NnBlockQ80> b1(n / Q80_BLOCK_SIZE);
rand(a0.data(), n, m);
rand(aQ.data(), n, m);
rand(b0.data(), n, m);
quantizeF32toQ80(b0.data(), b1.data(), n, 1, 0);
mul_F32(a0.data(), a0.data(), b0.data(), n, 1, 0);
mul_Q80_F32(aQ.data(), aQ.data(), b1.data(), n, 1, 0);
compare_F32("mul_Q80_F32", a0.data(), aQ.data(), n, 0.005);
}
// y += x
void testAdd(const NnUint m) {
const NnUint n = Q80_BLOCK_SIZE * m;
std::vector<float> y(n);
std::vector<float> yTemp(n);
std::vector<float> x(n);
std::vector<NnBlockQ80> xQ80(n / Q80_BLOCK_SIZE);
rand(y.data(), n, m);
rand(yTemp.data(), n, m);
rand(x.data(), n, m);
quantizeF32toQ80(x.data(), xQ80.data(), n, 1, 0);
add_F32(y.data(), x.data(), n, 1, 0);
add_Q80_F32(yTemp.data(), xQ80.data(), n, 1, 0);
compare_F32("add_Q80_F32", y.data(), yTemp.data(), n, 0.01);
}
void testMergeSum() {
float inp[] = {
/* [z0, y0] */ 0.1f, 0.2f,
/* [z0, y1] */ 0.3f, 0.4f,
/* [z1, y0] */ 0.5f, 0.6f,
/* [z1, y1] */ 0.7f, 0.8f,
};
float out[4];
float *i[4] = {
&inp[0],
&inp[2],
&inp[4],
&inp[6],
};
float *o[2] = {
&out[0],
&out[2]
};
mergeSum_F32(o, i, 2u, 2u, 2u, 2u, 1u, 0u);
const float expectedOutput[4] = {
0.6f,
0.8f,
1.0f,
1.2f,
};
compare_F32("mergeSum_F32", out, expectedOutput, 4u, 0.00000001f);
}
void testSoftmax() {
std::vector<float> y(8);
for (NnUint i = 0; i < 8; i++)
y[i] = i / 8.0f;
softmax_F32(y.data(), 8);
float expectedOutput[8] = {
0.077399f,
0.087780f,
0.099500f,
0.112761f,
0.127778f,
0.144793f,
0.164072f,
0.185917f
};
compare_F32("softmax_F32", y.data(), expectedOutput, 8, 0.001);
}
void testSilu() {
std::vector<float> y(8);
for (NnUint i = 0; i < 8; i++)
y[i] = i / 8.0f;
silu_F32(y.data(), 8, 1, 0);
float expectedOutput[8] = {
0.000000f,
0.066401f,
0.140544f,
0.222250f,
0.311233f,
0.407116f,
0.509461f,
0.617802f
};
compare_F32("silu_F32", y.data(), expectedOutput, 8, 0.001);
}
// matmul
void testMatmul_F32_Q40_F32(const NnUint m = 2) {
const NnUint n = Q80_BLOCK_SIZE * m;
const NnUint d = Q80_BLOCK_SIZE * m;
std::vector<float> x(n);
std::vector<float> w(n * d);
std::vector<float> o(d);
std::vector<float> oTemp(d);
std::vector<NnBlockQ80> xQ80(n / Q80_BLOCK_SIZE);
std::vector<NnBlockQ40> wQ40((n * d) / Q40_BLOCK_SIZE);
rand(x.data(), n, m);
rand(w.data(), n * d, m);
quantizeF32toQ40(w.data(), wQ40.data(), n * d, 1, 0);
quantizeF32toQ80(x.data(), xQ80.data(), n, 1, 0);
matmul_F32_F32_F32(o.data(), x.data(), w.data(), n, d, 1, 0);
matmul_Q80_Q40_F32(oTemp.data(), xQ80.data(), wQ40.data(), n, d, 1, 0);
compare_F32("matmul_Q80_Q40_F32", o.data(), oTemp.data(), d, 4.0f);
}
void testLlamafileSgemm() {
const NnUint batchSize = 8;
const NnUint n = 256;
const NnUint d = 128;
std::vector<float> x(n * batchSize);
std::vector<NnBlockQ80> xQ((n * batchSize) / Q80_BLOCK_SIZE);
std::vector<float> w(n * d);
std::vector<NnBlockQ40> wQ((n * d) / Q40_BLOCK_SIZE);
std::vector<float> o(d * batchSize);
std::vector<float> oTemp(d * batchSize);
rand(x.data(), n * batchSize, 12345);
rand(w.data(), n * d, 23456);
quantizeF32toQ80(x.data(), xQ.data(), n * batchSize, 1, 0);
quantizeF32toQ40(w.data(), wQ.data(), n * d, 1, 0);
// f32
for (NnUint i = 0; i < batchSize; i++) {
matmul_F32_F32_F32(o.data() + i * d, x.data() + i * n, w.data(), n, d, 1, 0);
}
assert(llamafile_sgemm(
d, batchSize, n,
w.data(), n,
x.data(), n,
oTemp.data(), d,
0, 1, 0,
F_32, F_32, F_32
));
compare_F32("llamafileSgemm_F32", o.data(), oTemp.data(), d * batchSize, 0.01f);
#if __ARM_FEATURE_DOTPROD
// q40ᵀ * q80
assert(llamafile_sgemm(
d, batchSize, n / Q80_BLOCK_SIZE,
wQ.data(), n / Q80_BLOCK_SIZE,
xQ.data(), n / Q80_BLOCK_SIZE,
oTemp.data(), d,
0, 1, 0,
F_Q40, F_Q80, F_32
));
compare_F32("llamafileSgemm_Q80_Q40", o.data(), oTemp.data(), d * batchSize, 1.5f);
#endif
}
void testScale() {
float i[] = {1.0f, 2.0f, 3.0f, 4.0f};
float o[4];
scale_F32(i, o, 0.5f, 4u, 1u, 0u);
float expectedOutput[] = {0.5f, 1.0f, 1.5f, 2.0f};
compare_F32("scale_F32", o, expectedOutput, 4u, 0.00001f);
}
void testTopk() {
float x[] = {1.0f, 4.0f, 2.0f, 3.0f};
std::vector<NnUint> topk(2);
topk_F32(x, topk.data(), 4u, 2u);
assert(topk[0] == 1u);
assert(topk[1] == 3u);
printPassed("testTopk");
}
int main() {
initQuants();
printCpuInstructionSet();
testSplitThreads();
testConvertF32toF16();
testQuantization(32);
testQuantization(2);
testQuantization(1);
testInvRms();
testRmsNorm(128);
testMul(32);
testMul(2);
testMul(1);
testAdd(32);
testAdd(2);
testAdd(1);
testMergeSum();
testSoftmax();
testSilu();
testMatmul_F32_Q40_F32(32);
testMatmul_F32_Q40_F32(2);
testMatmul_F32_Q40_F32(1);
testLlamafileSgemm();
testScale();
testTopk();
return 0;
}

1600
src/nn/nn-cpu-ops.cpp Normal file

File diff suppressed because it is too large Load Diff

43
src/nn/nn-cpu-ops.hpp Normal file
View File

@@ -0,0 +1,43 @@
#ifndef NN_CPU_OPS_H
#define NN_CPU_OPS_H
#include "nn-core.hpp"
#define ASSERT_EQ(a, b) \
if (a != b) { \
printf("Assertion failed: %d != %d (%s:%d)\n", a, b, __FILE__, __LINE__); \
exit(-1); \
}
typedef struct {
const char *name;
NnByte nBatches;
NnByte *bufferFlags;
NnByte **buffers;
NnBufferConfig *bufferConfigs;
NnByte **pipes;
NnPipeConfig *pipeConfigs;
void *opConfig;
NnByte **input;
NnSize3D inputSize;
bool hasInputContinuousMemory;
NnByte **output;
NnSize3D outputSize;
bool hasOutputContinuousMemory;
NnByte *weight;
NnSize3D weightSize;
} NnCpuOpContext;
typedef void (*NnCpuOpForwardInit)(NnCpuOpContext *context);
typedef void (*NnCpuOpForward)(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context);
void printCpuInstructionSet();
NnCpuOpForwardInit getCpuOpForwardInit(NnOpCode code, NnOpQuantType quantType);
NnCpuOpForward getCpuOpForward(NnOpCode code, NnOpQuantType quantType);
void softmax_F32(float *x, const NnUint size);
#endif

85
src/nn/nn-cpu-test.cpp Normal file
View File

@@ -0,0 +1,85 @@
#include "nn-core.hpp"
#include "nn-config-builder.hpp"
#include "nn-cpu.hpp"
#include <cstdio>
#define DIM 32
#define N_BATCHES 2
void buildConfig(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
NnUint nNodes = 1;
NnNetConfigBuilder netBuilder(nNodes, N_BATCHES);
NnUint xPipeIndex = netBuilder.addPipe("X", size2D(F_32, N_BATCHES, DIM));
NnNodeConfigBuilder nodeBuilder(0);
NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1));
NnSegmentConfigBuilder segmentBuilder;
segmentBuilder.addSync(xPipeIndex, SYNC_NODE_SLICES_EXCEPT_ROOT);
segmentBuilder.addOp(OP_INV_RMS, "inv_rms", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{1e-5f, 1});
segmentBuilder.addOp(OP_RMS_NORM, "rms_norm", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size1D(F_32, DIM),
NnRmsNormOpConfig{invRmsBufferIndex, 1});
nodeBuilder.addSegment(segmentBuilder.build());
*netConfig = netBuilder.build();
*nodeConfig = nodeBuilder.build();
}
void print2D(const char *name, NnUint x, NnUint y, float *w) {
for (NnUint i = 0; i < y; i++) {
printf("%s[%d] = ", name, i);
for (NnUint j = 0; j < x; j++)
printf("%f ", w[i * x + j]);
printf("\n");
}
}
int main() {
initQuants();
NnUint nThreads = 2;
NnNetConfig netConfig;
NnNodeConfig nodeConfig;
buildConfig(&netConfig, &nodeConfig);
NnNetExecution execution(nThreads, &netConfig);
float *x = (float *)execution.pipes[0];
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint i = 0; i < DIM; i++)
x[b * DIM + i] = i / (float)DIM + (float)b;
}
print2D("x", DIM, N_BATCHES, x);
float rmsNormWeight[DIM];
for (NnUint i = 0; i < DIM; i++)
rmsNormWeight[i] = 0.5 + i / (float)DIM;
NnCpuDevice *device = new NnCpuDevice(&netConfig, &nodeConfig, &execution);
std::vector<NnExecutorDevice> devices;
devices.push_back(NnExecutorDevice(device, -1, -1));
NnFakeNodeSynchronizer synchronizer;
float *rms = (float *)device->buffers[0];
NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
executor.loadWeight("rms_norm", 0u, 0u, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight);
execution.setBatchSize(2);
executor.forward();
print2D("rms", N_BATCHES, 1, rms);
print2D("x", DIM, N_BATCHES, x);
releaseNetConfig(&netConfig);
releaseNodeConfig(&nodeConfig);
return 0;
}

232
src/nn/nn-cpu.cpp Normal file
View File

@@ -0,0 +1,232 @@
#include "nn-cpu.hpp"
#include "nn-cpu-ops.hpp"
#include <cassert>
#include <cstring>
#include <stdexcept>
#include <thread>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
#endif
#define DEBUG_CPU_OP_QUANTS false
#define BUFFER_ALIGNMENT 64
static NnByte *allocAlignedBuffer(NnSize size) {
NnByte *buffer;
#ifdef _WIN32
buffer = (NnByte *)_aligned_malloc(size, BUFFER_ALIGNMENT);
if (buffer == NULL)
throw std::runtime_error("_aligned_malloc failed");
#else
if (posix_memalign((void **)&buffer, BUFFER_ALIGNMENT, size) != 0)
throw std::runtime_error("posix_memalign failed");
mlock(buffer, size);
#endif
return buffer;
}
static void releaseAlignedBuffer(NnByte *buffer) {
#ifdef _WIN32
_aligned_free(buffer);
#else
free(buffer);
#endif
}
NnCpuDevice::NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
this->netConfig = netConfig;
this->nodeConfig = nodeConfig;
this->netExecution = netExecution;
printCpuInstructionSet();
nBuffers = nodeConfig->nBuffers;
buffers = new NnByte *[nBuffers];
for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) {
NnBufferConfig *config = &nodeConfig->buffers[bufferIndex];
NnByte *buffer = allocAlignedBuffer(config->size.nBytes);
buffers[bufferIndex] = buffer;
}
bufferFlags = new NnByte[nBuffers];
std::memset(bufferFlags, 0, nBuffers * sizeof(NnByte));
}
NnCpuDevice::~NnCpuDevice() {
for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++)
releaseAlignedBuffer(buffers[bufferIndex]);
delete[] buffers;
delete[] bufferFlags;
}
NnUint NnCpuDevice::maxNThreads() {
return std::thread::hardware_concurrency();
}
NnDeviceSegment *NnCpuDevice::createSegment(NnUint segmentIndex) {
NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
assert(segmentConfig->nOps > 0);
std::vector<NnOpQuantType> opQuants(segmentConfig->nOps);
std::vector<NnCpuOpForward> opForwardLocal(segmentConfig->nOps);
std::vector<NnSize3D> inputSizes(segmentConfig->nOps);
std::vector<NnSize3D> outputSizes(segmentConfig->nOps);
std::vector<std::vector<NnByte *>> inputsPtr(segmentConfig->nOps);
std::vector<std::vector<NnByte *>> outputsPtr(segmentConfig->nOps);
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
NnSize3D inputSize;
NnSize3D outputSize;
inputsPtr[opIndex] = resolvePointer(&inputSize, &opConfig->input);
outputsPtr[opIndex] = resolvePointer(&outputSize, &opConfig->output);
NnOpQuantType opQuant = getOpQuantType(
inputSize.floatType,
opConfig->weightSize.floatType,
outputSize.floatType);
#if DEBUG_CPU_OP_QUANTS
printf("%20s %2d: %s\n", opConfig->name, opConfig->index, opQuantTypeToString(opQuant));
#endif
NnCpuOpForward forward = getCpuOpForward(opConfig->code, opQuant);
if (forward == nullptr) {
throw std::invalid_argument(
std::string("Unsupported CPU op code: ") + opCodeToString(opConfig->code) +
", quant: " + opQuantTypeToString(opQuant) +
", op name: " + opConfig->name);
}
inputSizes[opIndex] = inputSize;
outputSizes[opIndex] = outputSize;
opQuants[opIndex] = opQuant;
opForwardLocal[opIndex] = forward;
}
NnCpuOpForward *opForward = new NnCpuOpForward[segmentConfig->nOps];
NnCpuOpContext *opContexts = new NnCpuOpContext[segmentConfig->nOps];
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
NnCpuOpContext *opContext = &opContexts[opIndex];
NnCpuOpForwardInit opInit = getCpuOpForwardInit(opConfig->code, opQuants[opIndex]);
opContext->name = opConfig->name;
opContext->opConfig = opConfig->config;
opContext->weightSize = opConfig->weightSize;
opContext->nBatches = netConfig->nBatches;
opContext->pipes = netExecution->pipes;
opContext->pipeConfigs = netConfig->pipes;
opContext->buffers = buffers;
opContext->bufferConfigs = nodeConfig->buffers;
opContext->bufferFlags = bufferFlags;
opContext->input = new NnByte *[inputsPtr[opIndex].size()];
opContext->inputSize = inputSizes[opIndex];
opContext->hasInputContinuousMemory = hasPointerContinuousMemory(&opConfig->input);
std::memcpy(opContext->input, inputsPtr[opIndex].data(), inputsPtr[opIndex].size() * sizeof(NnByte *));
opContext->output = new NnByte *[outputsPtr[opIndex].size()];
opContext->outputSize = outputSizes[opIndex];
opContext->hasOutputContinuousMemory = hasPointerContinuousMemory(&opConfig->output);
std::memcpy(opContext->output, outputsPtr[opIndex].data(), outputsPtr[opIndex].size() * sizeof(NnByte *));
#if not(DEBUG_USE_MMAP_FOR_WEIGHTS)
if (opContext->weightSize.nBytes > 0)
opContext->weight = allocAlignedBuffer(opContext->weightSize.nBytes);
else
opContext->weight = nullptr;
#endif
if (opInit != nullptr)
opInit(opContext);
opForward[opIndex] = opForwardLocal[opIndex];
}
return new NnCpuDeviceSegment(opForward, opContexts, segmentConfig->nOps);
}
NnCpuDeviceSegment::~NnCpuDeviceSegment() {
for (NnUint opIndex = 0; opIndex < nOps; opIndex++) {
NnCpuOpContext *context = &opContexts[opIndex];
delete[] context->input;
delete[] context->output;
#if not(DEBUG_USE_MMAP_FOR_WEIGHTS)
if (context->weightSize.nBytes > 0)
releaseAlignedBuffer(context->weight);
#endif
}
delete[] opForward;
delete[] opContexts;
}
std::vector<NnByte *> NnCpuDevice::resolvePointer(NnSize3D *pntrSize, NnPointerConfig *pointerConfig) {
NnByte *source;
NnSize3D *sourceSize;
switch (pointerConfig->source) {
case SRC_BUFFER:
source = buffers[pointerConfig->pointerIndex];
sourceSize = &nodeConfig->buffers[pointerConfig->pointerIndex].size;
break;
case SRC_PIPE:
source = netExecution->pipes[pointerConfig->pointerIndex];
sourceSize = &netConfig->pipes[pointerConfig->pointerIndex].size;
break;
default:
throw std::invalid_argument("Unsupported pointer type");
}
switch (pointerConfig->type) {
case PNTR_RAW: {
*pntrSize = size1D(sourceSize->floatType, sourceSize->length);
return std::vector<NnByte *>{source};
}
case PNTR_BATCH:
case PNTR_BATCHED_SLICE: {
ASSERT_EQ(sourceSize->y, netConfig->nBatches);
std::vector<NnByte *> pntr(sourceSize->z * sourceSize->y);
NnSize batchBytes = getBytes(sourceSize->floatType, sourceSize->x);
for (NnUint z = 0u; z < sourceSize->z; z++) {
for (NnUint y = 0u; y < sourceSize->y; y++)
pntr[z * sourceSize->y + y] = &source[(z * sourceSize->y + y) * batchBytes];
}
*pntrSize = *sourceSize;
if (pointerConfig->type == PNTR_BATCHED_SLICE) {
assert(sourceSize->x % netConfig->nNodes == 0);
NnUint xSlice = sourceSize->x / netConfig->nNodes;
NnSize xSliceBytes = getBytes(sourceSize->floatType, xSlice);
for (NnUint z = 0; z < sourceSize->z; z++) {
for (NnUint y = 0; y < sourceSize->y; y++)
pntr[z * sourceSize->y + y] = &pntr[z * sourceSize->y + y][xSliceBytes * nodeConfig->nodeIndex];
}
*pntrSize = size3D(sourceSize->floatType, sourceSize->z, sourceSize->y, xSlice);
}
return pntr;
}
default:
throw std::invalid_argument("Unsupported pointer config");
}
}
void NnCpuDeviceSegment::loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) {
assert(opIndex >= 0u);
assert(opIndex < nOps);
NnCpuOpContext *context = &opContexts[opIndex];
assert(offset + nBytes <= context->weightSize.nBytes);
#if DEBUG_USE_MMAP_FOR_WEIGHTS
assert(offset == 0u);
context->weight = weight;
#else
std::memcpy(&context->weight[offset], weight, nBytes);
#endif
}
void NnCpuDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) {
NnCpuOpContext *context = &opContexts[opIndex];
// printf("forward: %d %s (%d/%d)\n", opIndex, context->name, threadIndex + 1, nThreads); fflush(stdout);
opForward[opIndex](nThreads, threadIndex, batchSize, context);
}

39
src/nn/nn-cpu.hpp Normal file
View File

@@ -0,0 +1,39 @@
#ifndef NN_CPU_H
#define NN_CPU_H
#include <vector>
#include "nn-executor.hpp"
#include "nn-cpu-ops.hpp"
#define DEBUG_USE_MMAP_FOR_WEIGHTS false
class NnCpuDevice : public NnDevice {
public:
NnByte **buffers;
private:
NnNetConfig *netConfig;
NnNodeConfig *nodeConfig;
NnNetExecution *netExecution;
NnUint nBuffers;
NnByte *bufferFlags;
public:
NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution);
~NnCpuDevice() override;
NnUint maxNThreads() override;
NnDeviceSegment *createSegment(NnUint segmentIndex) override;
std::vector<NnByte *> resolvePointer(NnSize3D *pntrSize, NnPointerConfig *pointerConfig);
};
class NnCpuDeviceSegment : public NnDeviceSegment {
public:
NnUint nOps;
NnCpuOpForward *opForward;
NnCpuOpContext *opContexts;
NnCpuDeviceSegment(NnCpuOpForward *opForward, NnCpuOpContext *opContexts, NnUint nOps)
: opForward(opForward), opContexts(opContexts), nOps(nOps) {}
~NnCpuDeviceSegment() override;
void loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) override;
void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override;
};
#endif

192
src/nn/nn-executor.cpp Normal file
View File

@@ -0,0 +1,192 @@
#include <cassert>
#include <cstring>
#include <stdexcept>
#include "nn-executor.hpp"
void NnFakeNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
// Nothing
}
NnNetExecution::NnNetExecution(NnUint nThreads, NnNetConfig *netConfig) {
this->nThreads = nThreads;
this->nBatches = netConfig->nBatches;
this->nPipes = netConfig->nPipes;
this->batchSize = 0; // This value must be overwritten before calling forward
pipes = new NnByte *[netConfig->nPipes];
for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) {
NnPipeConfig *pipeConfig = &netConfig->pipes[pipeIndex];
NnByte *pipe = new NnByte[pipeConfig->size.nBytes];
std::memset(pipe, 0, pipeConfig->size.nBytes);
pipes[pipeIndex] = pipe;
}
}
NnNetExecution::~NnNetExecution() {
for (NnUint pipeIndex = 0; pipeIndex < nPipes; pipeIndex++)
delete[] pipes[pipeIndex];
delete[] pipes;
}
void NnNetExecution::setBatchSize(NnUint batchSize) {
assert(batchSize <= nBatches);
this->batchSize = batchSize;
}
NnExecutorDevice::NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo) {
this->device = std::unique_ptr<NnDevice>(device);
this->segmentFrom = segmentFrom;
this->segmentTo = segmentTo;
}
NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *devices, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
: segments(nodeConfig->nSegments), steps()
{
NnUint maxNThreads = 0;
for (NnExecutorDevice &d : *devices) {
if (d.device->maxNThreads() > maxNThreads)
maxNThreads = d.device->maxNThreads();
}
if (netExecution->nThreads > maxNThreads)
throw std::invalid_argument("This configuration supports max " + std::to_string(maxNThreads) + " threads");
this->netExecution = netExecution;
this->nodeConfig = nodeConfig;
bool useSynchronizer = netConfig->nNodes > 1;
for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
NnDevice *device = nullptr;
for (NnExecutorDevice &d : *devices) {
if (
(d.segmentFrom == -1 && d.segmentTo == -1) ||
(segmentIndex >= d.segmentFrom && segmentIndex <= d.segmentTo)
) {
device = d.device.get();
break;
}
}
if (device == nullptr)
throw std::invalid_argument("Cannot locate device for segment " + std::to_string(segmentIndex));
NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
if (segmentConfig->nOps > 0) {
NnDeviceSegment *segment = device->createSegment(segmentIndex);
segments[segmentIndex] = std::unique_ptr<NnDeviceSegment>(segment);
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++)
steps.push_back(NnExecutorStep{ STEP_EXECUTE_OP, segment, opIndex, &segmentConfig->ops[opIndex] });
}
if (useSynchronizer && segmentConfig->nSyncs > 0)
steps.push_back(NnExecutorStep{ STEP_SYNC_NODES, nullptr, segmentIndex, nullptr });
}
steps.shrink_to_fit();
context.nThreads = netExecution->nThreads;
context.synchronizer = synchronizer;
context.nSteps = (NnUint)steps.size();
context.steps = steps.data();
if (benchmark)
context.timer = new Timer();
else
context.timer = nullptr;
threads = new NnExecutorThread[netExecution->nThreads];
for (NnUint threadIndex = 0; threadIndex < netExecution->nThreads; threadIndex++) {
NnExecutorThread *thread = &threads[threadIndex];
thread->threadIndex = threadIndex;
thread->context = &context;
}
}
NnExecutor::~NnExecutor() {
if (context.timer != nullptr)
delete context.timer;
delete[] threads;
}
void NnExecutor::loadWeight(const char *name, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) {
for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
for (NnUint i = 0; i < segmentConfig->nOps; i++) {
NnOpConfig *opConfig = &segmentConfig->ops[i];
if (opConfig->index == opIndex && std::strcmp(opConfig->name, name) == 0) {
NnDeviceSegment *segment = segments[segmentIndex].get();
assert(segment != nullptr);
segment->loadWeight(i, offset, nBytes, weight);
return;
}
}
}
throw std::invalid_argument("Cannot locate op by name: " + std::string(name));
}
inline void executeStep(NnExecutorStep *step, NnUint nThreads, NnExecutorThread *thread, NnExecutorContext *context) {
if (step->type == STEP_EXECUTE_OP) {
step->segment->forward(step->arg0, nThreads, thread->threadIndex, context->batchSize);
} else if (step->type == STEP_SYNC_NODES) {
context->synchronizer->sync(step->arg0, nThreads, thread->threadIndex);
} else {
throw std::invalid_argument("Unsupported step type");
}
}
static inline void *executorThreadHandler(void *arg) {
NnExecutorThread *thread = (NnExecutorThread *)arg;
NnExecutorContext *context = thread->context;
NnUint nThreads = context->nThreads;
NnUint doneCount = nThreads - 1;
while (true) {
const unsigned int currentStepIndex = context->currentStepIndex.load();
if (currentStepIndex == context->nSteps)
break;
NnExecutorStep *step = &context->steps[currentStepIndex];
executeStep(step, nThreads, thread, context);
NnUint currentCount = context->doneThreadCount.fetch_add(1);
if (currentCount == doneCount) {
if (context->timer != nullptr) {
NnUint time = context->timer->elapsedMicroseconds();
context->totalTime[step->type] += time;
context->timer->reset();
}
context->doneThreadCount.store(0);
context->currentStepIndex.fetch_add(1);
} else {
while (context->currentStepIndex.load() == currentStepIndex);
}
}
return nullptr;
}
void NnExecutor::forward() {
assert(netExecution->batchSize > 0);
NnUint nThreads = netExecution->nThreads;
context.currentStepIndex.exchange(0);
context.doneThreadCount.exchange(0);
context.batchSize = netExecution->batchSize;
if (context.timer != nullptr) {
std::memset(context.totalTime, 0, sizeof(context.totalTime));
context.timer->reset();
}
NnUint threadIndex;
for (threadIndex = 1; threadIndex < nThreads; threadIndex++) {
int result = pthread_create(&threads[threadIndex].handler, NULL, (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]);
if (result != 0)
throw std::runtime_error("Failed to create thread");
}
executorThreadHandler((void *)&threads[0]);
for (threadIndex = 1; threadIndex < nThreads; threadIndex++)
pthread_join(threads[threadIndex].handler, NULL);
}
NnUint NnExecutor::getTotalTime(NnExecutorStepType type) {
assert((NnUint)type < N_STEP_TYPES);
return context.totalTime[type];
}

103
src/nn/nn-executor.hpp Normal file
View File

@@ -0,0 +1,103 @@
#ifndef NN_EXECUTOR_H
#define NN_EXECUTOR_H
#include "nn-core.hpp"
#include <atomic>
#include <vector>
#include "pthread.h"
class NnDeviceSegment {
public:
virtual ~NnDeviceSegment() {};
virtual void loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) = 0;
virtual void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) = 0;
};
class NnDevice {
public:
virtual NnUint maxNThreads() = 0;
virtual ~NnDevice() {}
virtual NnDeviceSegment *createSegment(NnUint segmentIndex) = 0;
};
class NnNodeSynchronizer {
public:
virtual ~NnNodeSynchronizer() {};
virtual void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) = 0;
};
class NnFakeNodeSynchronizer : public NnNodeSynchronizer {
public:
~NnFakeNodeSynchronizer() override {};
void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override;
};
class NnNetExecution {
public:
NnUint nThreads;
NnUint nPipes;
NnByte **pipes;
NnUint batchSize;
NnUint nBatches;
NnNetExecution(NnUint nThreads, NnNetConfig *netConfig);
~NnNetExecution();
void setBatchSize(NnUint batchSize);
};
enum NnExecutorStepType {
STEP_EXECUTE_OP,
STEP_SYNC_NODES,
};
#define N_STEP_TYPES STEP_SYNC_NODES + 1
class NnExecutorDevice {
public:
std::unique_ptr<NnDevice> device;
int segmentFrom;
int segmentTo;
NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo);
};
typedef struct {
NnExecutorStepType type;
NnDeviceSegment *segment;
NnUint arg0;
NnOpConfig *opConfig;
} NnExecutorStep;
typedef struct {
NnUint nThreads;
NnUint nSteps;
NnExecutorStep *steps;
NnNodeSynchronizer *synchronizer;
std::atomic_uint currentStepIndex;
std::atomic_uint doneThreadCount;
NnUint batchSize;
Timer *timer;
NnUint totalTime[N_STEP_TYPES];
} NnExecutorContext;
typedef struct {
NnUint threadIndex;
NnExecutorContext *context;
PthreadHandler handler;
} NnExecutorThread;
class NnExecutor {
private:
NnNetExecution *netExecution;
NnNodeConfig *nodeConfig;
std::vector<std::unique_ptr<NnDeviceSegment>> segments;
std::vector<NnExecutorStep> steps;
NnExecutorThread *threads;
NnExecutorContext context;
public:
NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark);
~NnExecutor();
void loadWeight(const char *name, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight);
void forward();
NnUint getTotalTime(NnExecutorStepType type);
};
#endif

907
src/nn/nn-network.cpp Normal file
View File

@@ -0,0 +1,907 @@
#ifdef _WIN32
#include <winsock2.h>
#include <ws2tcpip.h> // For inet_addr and other functions
#include <windows.h> // For SSIZE_T
typedef SSIZE_T ssize_t;
#else
#include <sys/socket.h>
#include <netinet/tcp.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <netdb.h> // for getaddrinfo
#endif
#include "nn-network.hpp"
#include <cassert>
#include <cstring>
#include <stdexcept>
#include <vector>
#include <fcntl.h>
#define SOCKET_LAST_ERRCODE errno
#define SOCKET_LAST_ERROR strerror(errno)
#define ACK 23571114
#define MAX_CHUNK_SIZE 4096
static inline bool isEagainError() {
#ifdef _WIN32
return WSAGetLastError() == WSAEWOULDBLOCK;
#else
return SOCKET_LAST_ERRCODE == EAGAIN;
#endif
}
static inline void setNonBlocking(int socket, bool enabled) {
#ifdef _WIN32
u_long mode = enabled ? 1 : 0;
if (ioctlsocket(socket, FIONBIO, &mode) != 0) {
throw std::runtime_error("Error setting socket to non-blocking");
}
#else
int flags = fcntl(socket, F_GETFL, 0);
if (enabled) {
flags |= O_NONBLOCK;
} else {
flags = flags & (~O_NONBLOCK);
}
if (fcntl(socket, F_SETFL, flags) < 0)
throw std::runtime_error("Error setting socket to non-blocking");
#endif
}
static inline void setNoDelay(int socket) {
int flag = 1;
if (setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int)) < 0)
throw std::runtime_error("Error setting socket to no-delay");
}
static inline void setQuickAck(int socket) {
#ifndef _WIN32
#ifdef TCP_QUICKACK
int value = 1;
if (setsockopt(socket, IPPROTO_TCP, TCP_QUICKACK, (char*)&value, sizeof(int)) < 0)
throw std::runtime_error("Error setting quick ack");
#endif
#endif
}
void setReuseAddr(int socket) {
int opt = 1;
#ifdef _WIN32
int iresult = setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt));
if (iresult == SOCKET_ERROR) {
closesocket(socket);
throw std::runtime_error("setsockopt failed: " + std::to_string(WSAGetLastError()));
}
#else
if (setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
close(socket);
throw std::runtime_error("setsockopt failed: " + std::string(strerror(errno)));
}
#endif
}
void writeSocket(int socket, const void *data, NnSize size) {
while (size > 0) {
ssize_t s = send(socket, (const char*)data, size, 0);
if (s < 0) {
if (isEagainError()) {
continue;
}
throw NnWriteNetworkException(0, "Error writing to socket");
} else if (s == 0) {
throw NnWriteNetworkException(0, "Socket closed");
}
size -= s;
data = (const char*)data + s;
}
}
static inline bool tryReadSocket(int socket, void *data, NnSize size, unsigned long maxAttempts) {
// maxAttempts = 0 means infinite attempts
NnSize s = size;
while (s > 0) {
ssize_t r = recv(socket, (char*)data, s, 0);
if (r < 0) {
if (isEagainError()) {
if (s == size && maxAttempts > 0) {
maxAttempts--;
if (maxAttempts == 0) {
return false;
}
}
continue;
}
throw NnReadNetworkException(0, "Error reading from socket");
} else if (r == 0) {
throw NnReadNetworkException(0, "Socket closed");
}
data = (char*)data + r;
s -= r;
}
return true;
}
void readSocket(int socket, void *data, NnSize size) {
if (!tryReadSocket(socket, data, size, 0)) {
throw std::runtime_error("Error reading from socket");
}
}
static void readAckPacket(int socket) {
NnUint packet;
readSocket(socket, &packet, sizeof(packet));
if (packet != ACK)
throw std::runtime_error("Invalid ack packet");
}
static void writeAckPacket(int socket) {
NnUint packet = ACK;
writeSocket(socket, &packet, sizeof(packet));
}
static inline int connectSocket(char *host, int port) {
struct addrinfo hints;
struct addrinfo *addr = NULL;
std::memset(&hints, 0, sizeof(hints));
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_protocol = IPPROTO_TCP;
char portStr[11];
snprintf(portStr, sizeof(portStr), "%d", port);
int addrinfoError = getaddrinfo(host, portStr, &hints, &addr);
if (addrinfoError != 0 || addr == NULL) {
printf("Cannot resolve target %s (%s)\n", host, gai_strerror(addrinfoError));
throw std::runtime_error("Cannot resolve address");
}
int sock = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (sock < 0)
throw std::runtime_error("Cannot create socket");
int connectResult = ::connect(sock, addr->ai_addr, addr->ai_addrlen);
if (connectResult != 0) {
printf("Cannot connect to %s:%d (%s)\n", host, port, SOCKET_LAST_ERROR);
throw std::runtime_error("Cannot connect");
}
setNoDelay(sock);
setQuickAck(sock);
return sock;
}
int createServerSocket(int port) {
const char *host = "0.0.0.0";
struct sockaddr_in serverAddr;
int serverSocket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (serverSocket < 0)
throw std::runtime_error("Cannot create socket");
setReuseAddr(serverSocket);
memset(&serverAddr, 0, sizeof(serverAddr));
serverAddr.sin_family = AF_INET;
serverAddr.sin_port = htons(port);
serverAddr.sin_addr.s_addr = inet_addr(host);
int bindResult;
#ifdef _WIN32
bindResult = bind(serverSocket, (SOCKADDR*)&serverAddr, sizeof(serverAddr));
if (bindResult == SOCKET_ERROR) {
int error = WSAGetLastError();
closesocket(serverSocket);
throw std::runtime_error("Cannot bind port: " + std::to_string(error));
}
#else
bindResult = bind(serverSocket, (struct sockaddr*)&serverAddr, sizeof(serverAddr));
if (bindResult < 0) {
close(serverSocket);
throw std::runtime_error("Cannot bind port: " + std::string(strerror(errno)));
}
#endif
int listenResult = listen(serverSocket, SOMAXCONN);
if (listenResult != 0) {
#ifdef _WIN32
closesocket(serverSocket);
throw std::runtime_error("Cannot listen on port: " + std::to_string(WSAGetLastError()));
#else
close(serverSocket);
throw std::runtime_error("Cannot listen on port: " + std::string(strerror(errno)));
#endif
}
printf("Listening on %s:%d...\n", host, port);
setNoDelay(serverSocket);
setQuickAck(serverSocket);
return serverSocket;
}
void destroySocket(int serverSocket) {
shutdown(serverSocket, 2);
#ifdef _WIN32
closesocket(serverSocket);
#else
close(serverSocket);
#endif
}
int acceptSocket(int serverSocket) {
struct sockaddr_in clientAddr;
socklen_t clientAddrSize = sizeof(clientAddr);
int clientSocket = ::accept(serverSocket, (struct sockaddr*)&clientAddr, &clientAddrSize);
if (clientSocket < 0)
throw std::runtime_error("Error accepting connection");
setNoDelay(clientSocket);
setQuickAck(clientSocket);
return clientSocket;
}
void initSockets() {
#ifdef _WIN32
WSADATA wsaData;
if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError()));
}
#endif
}
void cleanupSockets() {
#ifdef _WIN32
WSACleanup();
#endif
}
NnReadNetworkException::NnReadNetworkException(int code, const char *message) {
this->code = code;
this->message = message;
}
NnWriteNetworkException::NnWriteNetworkException(int code, const char *message) {
this->code = code;
this->message = message;
}
std::unique_ptr<NnNetwork> NnNetwork::serve(int port) {
int serverSocket = createServerSocket(port);
NnUint nSockets;
NnUint nodeIndex;
int rootSocket = acceptSocket(serverSocket);
printf("⭕ The root node has connected\n");
readSocket(rootSocket, &nSockets, sizeof(nSockets));
NnUint nNodes = nSockets - 1; // nSockets - 1 root node
printf("⭕ nNodes: %d\n", nNodes);
readSocket(rootSocket, &nodeIndex, sizeof(nodeIndex));
printf("⭕ NodeIndex: %d\n", nodeIndex);
int *sockets = new int[nSockets];
sockets[0] = rootSocket;
char* *hosts = new char*[nNodes];
int *ports = new int[nNodes];
printf("⭕ Socket[0]: accepted root node\n");
NnUint hostLen;
for (NnUint i = 0; i < nNodes; i++) {
readSocket(rootSocket, &hostLen, sizeof(hostLen));
hosts[i] = new char[hostLen];
readSocket(rootSocket, hosts[i], hostLen);
readSocket(rootSocket, &ports[i], sizeof(ports[i]));
}
writeAckPacket(rootSocket);
// We need to wait here until the root node will send a "root is ready" packet
readAckPacket(rootSocket);
for (NnUint i = 0; i < nNodes; i++) {
NnUint socketIndex = i + 1;
if (i >= nodeIndex) {
printf("⭕ Socket[%d]: connecting to %s:%d worker\n", socketIndex, hosts[i], ports[i]);
sockets[socketIndex] = connectSocket(hosts[i], ports[i]);
printf("⭕ Socket[%d]: connected\n", socketIndex);
} else {
printf("⭕ Socket[%d]: wait for %s:%d worker\n", socketIndex, hosts[i], ports[i]);
sockets[socketIndex] = acceptSocket(serverSocket);
printf("⭕ Socket[%d]: accepted\n", socketIndex);
}
}
for (NnUint i = 0; i < nNodes; i++)
delete[] hosts[i];
delete[] hosts;
delete[] ports;
destroySocket(serverSocket);
printf("⭕ Network is initialized\n");
return std::unique_ptr<NnNetwork>(new NnNetwork(nSockets, sockets));
}
std::unique_ptr<NnNetwork> NnNetwork::connect(NnUint nSockets, char **hosts, NnUint *ports) {
assert(nSockets > 0);
int *sockets = new int[nSockets];
struct sockaddr_in addr;
for (NnUint i = 0; i < nSockets; i++) {
printf("⭕ Socket[%d]: connecting to %s:%d worker\n", i, hosts[i], ports[i]);
int socket = connectSocket(hosts[i], ports[i]);
sockets[i] = socket;
writeSocket(socket, &nSockets, sizeof(nSockets));
writeSocket(socket, &i, sizeof(i)); // send node index
for (NnUint j = 0; j < nSockets; j++) {
if (j == i)
continue;
NnUint hostLen = strlen(hosts[j]) + 1;
writeSocket(socket, &hostLen, sizeof(hostLen));
writeSocket(socket, hosts[j], hostLen);
writeSocket(socket, &ports[j], sizeof(ports[j]));
}
readAckPacket(socket);
printf("⭕ Socket[%d]: connected\n", i);
}
for (NnUint i = 0; i < nSockets; i++) {
writeAckPacket(sockets[i]);
}
printf("⭕ Network is initialized\n");
return std::unique_ptr<NnNetwork>(new NnNetwork(nSockets, sockets));
}
NnNetwork::NnNetwork(NnUint nSockets, int *sockets) {
this->nSockets = nSockets;
this->sockets = sockets;
this->sentBytes = new NnSize[nSockets];
this->recvBytes = new NnSize[nSockets];
}
NnNetwork::~NnNetwork() {
delete[] sentBytes;
delete[] recvBytes;
for (NnUint i = 0; i < nSockets; i++)
destroySocket(sockets[i]);
delete[] sockets;
printf("⭕ Network is closed\n");
}
void NnNetwork::setTurbo(bool enabled) {
for (NnUint i = 0; i < nSockets; i++) {
::setNonBlocking(sockets[i], enabled);
}
}
void NnNetwork::write(const NnUint socketIndex, const void *data, const NnSize size) {
assert(socketIndex < nSockets);
NnByte *current = (NnByte *)data;
int s = sockets[socketIndex];
for (NnSize chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
NnSize chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
writeSocket(s, current, chunkSize);
current += chunkSize;
}
sentBytes[socketIndex] += size;
}
void NnNetwork::read(const NnUint socketIndex, void *data, const NnSize size) {
assert(socketIndex < nSockets);
NnByte *current = (NnByte *)data;
int s = sockets[socketIndex];
for (NnSize chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
NnSize chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
readSocket(s, current, chunkSize);
current += chunkSize;
}
recvBytes[socketIndex] += size;
}
void NnNetwork::writeAck(const NnUint socketIndex) {
assert(socketIndex >= 0 && socketIndex < nSockets);
writeAckPacket(sockets[socketIndex]);
}
void NnNetwork::readAck(const NnUint socketIndex) {
assert(socketIndex >= 0 && socketIndex < nSockets);
readAckPacket(sockets[socketIndex]);
}
bool NnNetwork::tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts) {
assert(socketIndex >= 0 && socketIndex < nSockets);
if (tryReadSocket(sockets[socketIndex], data, size, maxAttempts)) {
recvBytes[socketIndex] += size;
return true;
}
return false;
}
void NnNetwork::writeMany(NnUint n, NnSocketIo *ios) {
bool isWriting;
NnSize nBytes = 0;
for (NnUint i = 0; i < n; i++) {
NnSocketIo *io = &ios[i];
assert(io->socketIndex < nSockets);
sentBytes[io->socketIndex] += io->size;
}
do {
isWriting = false;
for (NnUint i = 0; i < n; i++) {
NnSocketIo *io = &ios[i];
if (io->size > 0) {
isWriting = true;
int socket = sockets[io->socketIndex];
ssize_t chunkSize = io->size > MAX_CHUNK_SIZE ? MAX_CHUNK_SIZE : io->size;
ssize_t s = send(socket, (const char*)io->data, chunkSize, 0);
if (s < 0) {
if (isEagainError()) {
continue;
}
throw NnWriteNetworkException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR);
} else if (s == 0) {
throw NnWriteNetworkException(0, "Socket closed");
}
io->size -= s;
io->data = (char*)io->data + s;
}
}
} while (isWriting);
}
void NnNetwork::writeAll(void *data, NnSize size) {
std::vector<NnSocketIo> ios(nSockets);
for (NnUint i = 0; i < nSockets; i++) {
NnSocketIo *io = &ios[i];
io->socketIndex = i;
io->data = data;
io->size = size;
}
writeMany(nSockets, &ios[0]);
}
void NnNetwork::readMany(NnUint n, NnSocketIo *ios) {
bool isReading;
NnSize nBytes = 0;
for (NnUint i = 0; i < n; i++) {
NnSocketIo *io = &ios[i];
assert(io->socketIndex < nSockets);
recvBytes[io->socketIndex] += io->size;
}
do {
isReading = false;
for (NnUint i = 0; i < n; i++) {
NnSocketIo *io = &ios[i];
if (io->size > 0) {
isReading = true;
int socket = sockets[io->socketIndex];
ssize_t r = recv(socket, (char*)io->data, io->size, 0);
if (r < 0) {
if (isEagainError()) {
continue;
}
throw NnReadNetworkException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR);
} else if (r == 0) {
throw NnReadNetworkException(0, "Socket closed");
}
io->size -= r;
io->data = (char*)io->data + r;
}
}
} while (isReading);
}
void NnNetwork::getStats(NnSize *sentBytes, NnSize *recvBytes) {
*sentBytes = 0;
*recvBytes = 0;
for (NnUint i = 0; i < nSockets; i++) {
*sentBytes += this->sentBytes[i];
*recvBytes += this->recvBytes[i];
}
resetStats();
}
void NnNetwork::resetStats() {
for (NnUint i = 0; i < nSockets; i++) {
sentBytes[i] = 0;
recvBytes[i] = 0;
}
}
static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnUint nThreads, NnUint threadIndex) {
if (nodeIndex == 0) {
// root
NnUint nSocketsPerThread = network->nSockets / nThreads + (network->nSockets % nThreads > threadIndex ? 1 : 0);
if (nSocketsPerThread == 0) return;
std::vector<NnSocketIo> ios(nSocketsPerThread);
for (NnUint i = 0; i < nSocketsPerThread; i++) {
ios[i].socketIndex = threadIndex + i * nThreads;
ios[i].data = buffer;
ios[i].size = nBytes;
}
network->writeMany(nSocketsPerThread, &ios[0]);
} else {
// worker
if (threadIndex != 0) return;
NnSocketIo ios;
ios.data = buffer;
ios.size = nBytes;
ios.socketIndex = 0; // root
network->readMany(1, &ios);
}
}
static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnUint nodeIndex, NnUint nNodes, NnByte *buffer, NnSize nBytes, NnUint nThreads, NnUint threadIndex) {
bool isWorker = nodeIndex != 0;
NnUint nSockets = onlyFromWorkerToRoot && isWorker ? 1 : network->nSockets;
NnUint nSocketsPerThread = nSockets / nThreads + (nSockets % nThreads > threadIndex ? 1 : 0);
if (nSocketsPerThread == 0) return;
NnSize sliceBytes = nBytes / nNodes;
std::vector<NnSocketIo> ios(nSocketsPerThread);
if (!onlyFromWorkerToRoot || isWorker) {
NnByte *mySliceData = &buffer[sliceBytes * nodeIndex];
for (NnUint i = 0; i < nSocketsPerThread; i++) {
NnUint socketIndex = threadIndex + i * nThreads;
ios[i].socketIndex = socketIndex;
ios[i].data = mySliceData;
ios[i].size = sliceBytes;
}
network->writeMany(nSocketsPerThread, &ios[0]);
}
if (!onlyFromWorkerToRoot || !isWorker) {
for (NnUint i = 0; i < nSocketsPerThread; i++) {
NnUint socketIndex = threadIndex + i * nThreads;
NnUint sliceIndex = socketIndex >= nodeIndex ? socketIndex + 1 : socketIndex;
NnByte *sliceData = &buffer[sliceBytes * sliceIndex];
ios[i].socketIndex = socketIndex;
ios[i].data = sliceData;
ios[i].size = sliceBytes;
}
network->readMany(nSocketsPerThread, &ios[0]);
}
}
NnNetworkNodeSynchronizer::NnNetworkNodeSynchronizer(NnNetwork *network, NnNetExecution *execution, NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
this->network = network;
this->execution = execution;
this->netConfig = netConfig;
this->nodeConfig = nodeConfig;
}
void NnNetworkNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) {
NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex];
NnByte *pipe = execution->pipes[syncConfig->pipeIndex];
NnPipeConfig *pipeConfig = &netConfig->pipes[syncConfig->pipeIndex];
NnSize batchBytes = getBytes(pipeConfig->size.floatType, pipeConfig->size.x);
for (NnUint batchIndex = 0; batchIndex < execution->batchSize; batchIndex++) {
NnByte *pipeBatch = &pipe[batchIndex * batchBytes];
if (syncConfig->syncType == SYNC_WITH_ROOT) {
syncWithRoot(network, nodeConfig->nodeIndex, pipeBatch, batchBytes, nThreads, threadIndex);
} else if (syncConfig->syncType == SYNC_NODE_SLICES) {
syncNodeSlices(false, network, nodeConfig->nodeIndex, netConfig->nNodes, pipeBatch, batchBytes, nThreads, threadIndex);
} else if (syncConfig->syncType == SYNC_NODE_SLICES_EXCEPT_ROOT) {
syncNodeSlices(true, network, nodeConfig->nodeIndex, netConfig->nNodes, pipeBatch, batchBytes, nThreads, threadIndex);
} else {
throw std::invalid_argument("Unknown sync type");
}
}
}
}
static void writeString(NnNetwork *network, NnUint socketIndex, char *str) {
NnUint bytes = std::strlen(str) + 1;
network->write(socketIndex, &bytes, sizeof(NnUint));
network->write(socketIndex, str, bytes);
}
static char *readString(NnNetwork *network, NnUint socketIndex) {
NnUint bytes;
network->read(socketIndex, &bytes, sizeof(NnUint));
char *str = new char[bytes];
network->read(socketIndex, str, bytes);
return str;
}
NnRootConfigWriter::NnRootConfigWriter(NnNetwork *network) {
this->network = network;
}
void NnRootConfigWriter::writeNet(NnUint socketIndex, NnNetConfig *config) {
network->writeAck(socketIndex);
network->write(socketIndex, &config->nBatches, sizeof(config->nBatches));
network->write(socketIndex, &config->nNodes, sizeof(config->nNodes));
network->write(socketIndex, &config->nPipes, sizeof(config->nPipes));
for (NnUint pipeIndex = 0; pipeIndex < config->nPipes; pipeIndex++) {
NnPipeConfig *pipeConfig = &config->pipes[pipeIndex];
network->write(socketIndex, &pipeConfig->size, sizeof(pipeConfig->size));
writeString(network, socketIndex, pipeConfig->name);
}
network->write(socketIndex, &config->nPreSyncs, sizeof(config->nPreSyncs));
for (NnUint preSyncIndex = 0; preSyncIndex < config->nPreSyncs; preSyncIndex++) {
NnPreSyncConfig *preSyncConfig = &config->preSyncs[preSyncIndex];
network->write(socketIndex, &preSyncConfig->pipeIndex, sizeof(preSyncConfig->pipeIndex));
}
network->readAck(socketIndex);
}
void NnRootConfigWriter::writeNode(NnUint socketIndex, NnNodeConfig *config) {
network->writeAck(socketIndex);
network->write(socketIndex, &config->nodeIndex, sizeof(config->nodeIndex));
network->write(socketIndex, &config->nBuffers, sizeof(config->nBuffers));
network->write(socketIndex, &config->nSegments, sizeof(config->nSegments));
for (NnUint bufferIndex = 0; bufferIndex < config->nBuffers; bufferIndex++) {
NnBufferConfig *bufferConfig = &config->buffers[bufferIndex];
network->write(socketIndex, &bufferConfig->size, sizeof(bufferConfig->size));
writeString(network, socketIndex, bufferConfig->name);
}
for (NnUint segmentIndex = 0; segmentIndex < config->nSegments; segmentIndex++) {
NnSegmentConfig *segmentConfig = &config->segments[segmentIndex];
network->write(socketIndex, &segmentConfig->nSyncs, sizeof(segmentConfig->nSyncs));
network->write(socketIndex, &segmentConfig->nOps, sizeof(segmentConfig->nOps));
for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) {
NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex];
network->write(socketIndex, &syncConfig->pipeIndex, sizeof(syncConfig->pipeIndex));
network->write(socketIndex, &syncConfig->syncType, sizeof(syncConfig->syncType));
}
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
network->write(socketIndex, &opConfig->code, sizeof(opConfig->code));
network->write(socketIndex, &opConfig->index, sizeof(opConfig->index));
network->write(socketIndex, &opConfig->weightSize, sizeof(opConfig->weightSize));
network->write(socketIndex, &opConfig->configSize, sizeof(opConfig->configSize));
writeString(network, socketIndex, opConfig->name);
network->write(socketIndex, &opConfig->input, sizeof(opConfig->input));
network->write(socketIndex, &opConfig->output, sizeof(opConfig->output));
if (opConfig->configSize > 0)
network->write(socketIndex, opConfig->config, opConfig->configSize);
}
}
network->readAck(socketIndex);
}
void NnRootConfigWriter::writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs) {
for (NnUint nodeIndex = 1; nodeIndex < netConfig->nNodes; nodeIndex++) {
NnUint socketIndex = nodeIndex - 1;
writeNet(socketIndex, netConfig);
writeNode(socketIndex, &nodeConfigs[nodeIndex]);
}
}
NnWorkerConfigReader::NnWorkerConfigReader(NnNetwork *network) {
this->network = network;
}
NnNetConfig NnWorkerConfigReader::readNet() {
network->readAck(ROOT_SOCKET_INDEX);
NnNetConfig config;
network->read(ROOT_SOCKET_INDEX, &config.nBatches, sizeof(config.nBatches));
network->read(ROOT_SOCKET_INDEX, &config.nNodes, sizeof(config.nNodes));
network->read(ROOT_SOCKET_INDEX, &config.nPipes, sizeof(config.nPipes));
config.pipes = new NnPipeConfig[config.nPipes];
for (NnUint pipeIndex = 0; pipeIndex < config.nPipes; pipeIndex++) {
NnPipeConfig *pipeConfig = &config.pipes[pipeIndex];
network->read(ROOT_SOCKET_INDEX, &pipeConfig->size, sizeof(pipeConfig->size));
pipeConfig->name = readString(network, ROOT_SOCKET_INDEX);
}
network->read(ROOT_SOCKET_INDEX, &config.nPreSyncs, sizeof(config.nPreSyncs));
config.preSyncs = new NnPreSyncConfig[config.nPreSyncs];
for (NnUint preSyncIndex = 0; preSyncIndex < config.nPreSyncs; preSyncIndex++) {
NnPreSyncConfig *preSyncConfig = &config.preSyncs[preSyncIndex];
network->read(ROOT_SOCKET_INDEX, &preSyncConfig->pipeIndex, sizeof(preSyncConfig->pipeIndex));
}
network->writeAck(ROOT_SOCKET_INDEX);
return config;
}
NnNodeConfig NnWorkerConfigReader::readNode() {
network->readAck(ROOT_SOCKET_INDEX);
NnNodeConfig config;
network->read(ROOT_SOCKET_INDEX, &config.nodeIndex, sizeof(config.nodeIndex));
network->read(ROOT_SOCKET_INDEX, &config.nBuffers, sizeof(config.nBuffers));
network->read(ROOT_SOCKET_INDEX, &config.nSegments, sizeof(config.nSegments));
config.buffers = new NnBufferConfig[config.nBuffers];
config.segments = new NnSegmentConfig[config.nSegments];
for (NnUint bufferIndex = 0; bufferIndex < config.nBuffers; bufferIndex++) {
NnBufferConfig *bufferConfig = &config.buffers[bufferIndex];
network->read(ROOT_SOCKET_INDEX, &bufferConfig->size, sizeof(bufferConfig->size));
bufferConfig->name = readString(network, ROOT_SOCKET_INDEX);
}
for (NnUint segmentIndex = 0; segmentIndex < config.nSegments; segmentIndex++) {
NnSegmentConfig *segmentConfig = &config.segments[segmentIndex];
network->read(ROOT_SOCKET_INDEX, &segmentConfig->nSyncs, sizeof(segmentConfig->nSyncs));
network->read(ROOT_SOCKET_INDEX, &segmentConfig->nOps, sizeof(segmentConfig->nOps));
if (segmentConfig->nSyncs > 0) {
segmentConfig->syncs = new NnSyncConfig[segmentConfig->nSyncs];
for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) {
NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex];
network->read(ROOT_SOCKET_INDEX, &syncConfig->pipeIndex, sizeof(syncConfig->pipeIndex));
network->read(ROOT_SOCKET_INDEX, &syncConfig->syncType, sizeof(syncConfig->syncType));
}
}
if (segmentConfig->nOps > 0) {
segmentConfig->ops = new NnOpConfig[segmentConfig->nOps];
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
network->read(ROOT_SOCKET_INDEX, &opConfig->code, sizeof(opConfig->code));
network->read(ROOT_SOCKET_INDEX, &opConfig->index, sizeof(opConfig->index));
network->read(ROOT_SOCKET_INDEX, &opConfig->weightSize, sizeof(opConfig->weightSize));
network->read(ROOT_SOCKET_INDEX, &opConfig->configSize, sizeof(opConfig->configSize));
opConfig->name = readString(network, ROOT_SOCKET_INDEX);
network->read(ROOT_SOCKET_INDEX, &opConfig->input, sizeof(opConfig->input));
network->read(ROOT_SOCKET_INDEX, &opConfig->output, sizeof(opConfig->output));
if (opConfig->configSize > 0) {
opConfig->config = new NnByte[opConfig->configSize];
network->read(ROOT_SOCKET_INDEX, opConfig->config, opConfig->configSize);
}
}
}
}
network->writeAck(ROOT_SOCKET_INDEX);
return config;
}
NnRootWeightLoader::NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes) {
this->executor = executor;
this->network = network;
this->nNodes = nNodes;
this->tempSize = 0;
}
NnRootWeightLoader::~NnRootWeightLoader() {
if (tempSize > 0)
delete[] temp;
}
void NnRootWeightLoader::finish() {
NnUint zeroSize = 0;
for (NnUint socketIndex = 0; socketIndex < nNodes - 1; socketIndex++) {
network->write(socketIndex, &zeroSize, sizeof(zeroSize));
network->readAck(socketIndex);
}
if (tempSize > 0) {
delete[] temp;
tempSize = 0;
}
}
void NnRootWeightLoader::allocate(NnSize size) {
if (tempSize < size) {
if (tempSize > 0)
delete[] temp;
tempSize = size;
temp = new NnByte[size];
}
}
void NnRootWeightLoader::writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) {
NnUint nameSize = std::strlen(opName) + 1;
NnUint socketIndex = nodeIndex - 1;
network->write(socketIndex, &nameSize, sizeof(nameSize));
network->write(socketIndex, opName, nameSize);
network->write(socketIndex, &opIndex, sizeof(opIndex));
network->write(socketIndex, &offset, sizeof(offset));
network->write(socketIndex, &nBytes, sizeof(nBytes));
network->write(socketIndex, weight, nBytes);
}
NnSize NnRootWeightLoader::loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) {
executor->loadWeight(opName, opIndex, 0u, nBytes, weight);
return nBytes;
}
NnSize NnRootWeightLoader::loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) {
executor->loadWeight(opName, opIndex, 0u, nBytes, weight);
if (nNodes > 1u) {
for (NnUint nodeIndex = 1u; nodeIndex < nNodes; nodeIndex++)
writeWeight(nodeIndex, opName, opIndex, 0u, nBytes, weight);
}
return nBytes;
}
NnSize NnRootWeightLoader::loadRowMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnRowMatmulSlice *slice, NnByte *weight) {
const NnUint offset = expertIndex * slice->sliceSize.nBytes;
if (nNodes == 1u) {
executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, weight);
} else {
allocate(slice->sliceSize.nBytes);
for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) {
splitRowMatmulWeight(slice, nodeIndex, weight, temp);
if (nodeIndex == 0u)
executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, temp);
else
writeWeight(nodeIndex, opName, opIndex, offset, slice->sliceSize.nBytes, temp);
}
}
return slice->size.nBytes;
}
NnSize NnRootWeightLoader::loadColMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnColMatmulSlice *slice, NnByte *weight) {
const NnUint offset = expertIndex * slice->sliceSize.nBytes;
if (nNodes == 1) {
executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, weight);
} else {
allocate(slice->sliceSize.nBytes);
for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) {
splitColMatmulWeight(slice, nodeIndex, weight, temp);
if (nodeIndex == 0)
executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, temp);
else
writeWeight(nodeIndex, opName, opIndex, offset, slice->sliceSize.nBytes, temp);
}
}
return slice->size.nBytes;
}
NnWorkerWeightReader::NnWorkerWeightReader(NnExecutor *executor, NnNetwork *network) {
this->executor = executor;
this->network = network;
this->tempSize = 0;
}
NnWorkerWeightReader::~NnWorkerWeightReader() {
if (tempSize > 0)
delete[] temp;
}
void NnWorkerWeightReader::allocate(NnUint size) {
if (tempSize < size) {
if (tempSize > 0)
delete[] temp;
tempSize = size;
temp = new NnByte[size];
}
}
void NnWorkerWeightReader::read() {
NnUint nameSize;
NnUint opIndex;
NnSize offset;
NnSize nBytes;
while (true) {
network->read(0, &nameSize, sizeof(nameSize));
if (nameSize == 0) {
network->writeAck(ROOT_SOCKET_INDEX);
if (tempSize > 0) {
delete temp;
tempSize = 0;
}
break;
}
std::unique_ptr<char[]> opNamePtr(new char[nameSize]);
char *opName = opNamePtr.get();
network->read(ROOT_SOCKET_INDEX, opName, nameSize);
network->read(ROOT_SOCKET_INDEX, &opIndex, sizeof(opIndex));
network->read(ROOT_SOCKET_INDEX, &offset, sizeof(offset));
network->read(ROOT_SOCKET_INDEX, &nBytes, sizeof(nBytes));
allocate(nBytes);
network->read(0, temp, nBytes);
executor->loadWeight(opName, opIndex, offset, nBytes, temp);
printf("💿 Loaded %22s %3d, %12zu kB\n", opName, opIndex, nBytes / 1024);
}
printf("💿 Weights loaded\n");
}

129
src/nn/nn-network.hpp Normal file
View File

@@ -0,0 +1,129 @@
#ifndef NN_NETWORK_H
#define NN_NETWORK_H
#include "nn-executor.hpp"
#define ROOT_SOCKET_INDEX 0
void initSockets();
void cleanupSockets();
int acceptSocket(int serverSocket);
void setReuseAddr(int socket);
void writeSocket(int socket, const void* data, NnSize size);
void readSocket(int socket, void* data, NnSize size);
int createServerSocket(int port);
void destroySocket(int serverSocket);
class NnReadNetworkException : public std::exception {
public:
int code;
const char *message;
NnReadNetworkException(int code, const char *message);
};
class NnWriteNetworkException : public std::exception {
public:
int code;
const char *message;
NnWriteNetworkException(int code, const char *message);
};
struct NnSocketIo {
NnUint socketIndex;
const void *data;
NnSize size;
};
class NnNetwork {
private:
int *sockets;
NnSize *sentBytes;
NnSize *recvBytes;
public:
static std::unique_ptr<NnNetwork> serve(int port);
static std::unique_ptr<NnNetwork> connect(NnUint nSockets, char **hosts, NnUint *ports);
NnUint nSockets;
NnNetwork(NnUint nSockets, int *sockets);
~NnNetwork();
void setTurbo(bool enabled);
void write(const NnUint socketIndex, const void *data, const NnSize size);
void read(const NnUint socketIndex, void *data, const NnSize size);
void writeAck(const NnUint socketIndex);
void readAck(const NnUint socketIndex);
bool tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts);
void writeMany(NnUint n, NnSocketIo *ios);
void writeAll(void *data, NnSize size);
void readMany(NnUint n, NnSocketIo *ios);
void getStats(NnSize *sentBytes, NnSize *recvBytes);
void resetStats();
};
class NnNetworkNodeSynchronizer : public NnNodeSynchronizer {
private:
NnNetwork *network;
NnNetExecution *execution;
NnNetConfig *netConfig;
NnNodeConfig *nodeConfig;
public:
NnNetworkNodeSynchronizer(NnNetwork *network, NnNetExecution *execution, NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
~NnNetworkNodeSynchronizer() override {};
void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override;
};
class NnRootConfigWriter {
private:
NnNetwork *network;
public:
NnRootConfigWriter(NnNetwork *network);
void writeNet(NnUint socketIndex, NnNetConfig *config);
void writeNode(NnUint socketIndex, NnNodeConfig *config);
void writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs);
};
class NnWorkerConfigReader {
private:
NnNetwork *network;
public:
NnWorkerConfigReader(NnNetwork *network);
NnNetConfig readNet();
NnNodeConfig readNode();
};
class NnRootWeightLoader {
private:
NnExecutor *executor;
NnNetwork *network;
NnUint nNodes;
NnByte *temp;
NnSize tempSize;
public:
NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes);
~NnRootWeightLoader();
void writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight);
NnSize loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight);
NnSize loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight);
NnSize loadRowMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnRowMatmulSlice *slice, NnByte *weight);
NnSize loadColMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnColMatmulSlice *slice, NnByte *weight);
void finish();
private:
void allocate(NnSize size);};
class NnWorkerWeightReader {
private:
NnExecutor *executor;
NnNetwork *network;
NnByte *temp;
NnUint tempSize;
public:
NnWorkerWeightReader(NnExecutor *executor, NnNetwork *network);
~NnWorkerWeightReader();
void read();
private:
void allocate(NnUint size);
};
#endif

255
src/nn/nn-quants.cpp Normal file
View File

@@ -0,0 +1,255 @@
#include "nn-quants.hpp"
#include <cassert>
#include <cstring>
#include <cmath>
#include <stdexcept>
#include <cstdio>
#if defined(CONVERT_F16_TO_F32_LOOKUP)
float f16ToF32Lookup[65536];
#endif
void initQuants() {
#if defined(CONVERT_F16_TO_F32_LOOKUP)
for (NnUint i = 0; i < 65536; i++)
f16ToF32Lookup[i] = convertF16toF32Impl((NnFp16)i);
#endif
}
float convertF16toF32Impl(const NnFp16 value) {
union Fl32 {
uint32_t u;
float f;
};
const Fl32 magic = { (254U - 15U) << 23 };
const Fl32 infNan = { (127U + 16U) << 23 };
Fl32 result;
result.u = (value & 0x7FFFU) << 13;
result.f *= magic.f;
if (result.f >= infNan.f)
result.u |= 255U << 23;
result.u |= (value & 0x8000U) << 16;
return result.f;
}
NnFp16 convertF32ToF16Impl(const float x) {
int i = *(int *)&x;
int s = (i >> 16) & 0x00008000;
int e = ((i >> 23) & 0x000000ff) - (127 - 15);
int m = i & 0x007fffff;
if (e <= 0) {
if (e < -10) {
return s;
}
m = m | 0x00800000;
int t = 14 - e;
int a = (1 << (t - 1)) - 1;
int b = (m >> t) & 1;
m = (m + a + b) >> t;
return s | m;
}
if (e == 0xff - (127 - 15)) {
if (m == 0) {
return s | 0x7c00;
}
m >>= 13;
return s | 0x7c00 | m | (m == 0);
}
m = m + 0x00000fff + ((m >> 13) & 1);
if (m & 0x00800000) {
m = 0;
e += 1;
}
assert(e <= 30);
return s | (e << 10) | (m >> 13);
}
void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) {
assert(n % Q80_BLOCK_SIZE == 0);
const NnUint nBlocks = n / Q80_BLOCK_SIZE;
SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex);
#if defined(__ARM_NEON)
for (NnUint i = start; i < end; i++) {
const float *x = &input[i * Q80_BLOCK_SIZE];
NnBlockQ80 *y = &output[i];
float32x4_t amaxVec = vdupq_n_f32(0.0f);
for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) {
const float32x4_t vec = vld1q_f32(&x[j]);
const float32x4_t abs_vec = vabsq_f32(vec);
amaxVec = vmaxq_f32(amaxVec, abs_vec);
}
float amax = vmaxvq_f32(amaxVec);
const float d = amax / 127.0f;
const float id = d != 0.0f ? 1.0f / d : 0.0f;
y->d = CONVERT_F32_TO_F16(d);
const float32x4_t vid_vec = vdupq_n_f32(id);
for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) {
float32x4_t vec = vld1q_f32(&x[j]);
vec = vmulq_f32(vec, vid_vec);
const uint32x4_t sign_mask = vcgeq_f32(vec, vdupq_n_f32(0.0f));
const float32x4_t half = vbslq_f32(sign_mask, vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f));
vec = vaddq_f32(vec, half);
const int32x4_t vec_i32 = vcvtq_s32_f32(vec);
const int16x4_t vec_i16 = vqmovn_s32(vec_i32);
const int8x8_t vec_i8 = vqmovn_s16(vcombine_s16(vec_i16, vec_i16));
vst1_lane_s32((int32_t *)(y->qs + j), vreinterpret_s32_s8(vec_i8), 0);
}
}
#elif defined(__AVX2__)
for (NnUint i = start; i < end; ++i) {
const float *x = input + i * Q80_BLOCK_SIZE;
NnBlockQ80 *y = output + i;
__m256 max_abs = _mm256_setzero_ps();
for (int j = 0; j < Q80_BLOCK_SIZE; j += 8) {
__m256 vec = _mm256_loadu_ps(x + j);
__m256 abs_vec = _mm256_and_ps(vec, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
max_abs = _mm256_max_ps(max_abs, abs_vec);
}
__m128 max_hi = _mm256_extractf128_ps(max_abs, 1);
__m128 max_lo = _mm256_castps256_ps128(max_abs);
__m128 max_128 = _mm_max_ps(max_hi, max_lo);
max_128 = _mm_max_ps(max_128, _mm_movehl_ps(max_128, max_128));
max_128 = _mm_max_ss(max_128, _mm_shuffle_ps(max_128, max_128, _MM_SHUFFLE(1, 1, 1, 1)));
float amax = _mm_cvtss_f32(max_128);
const float d = amax / 127.0f;
const float id = (d != 0.0f) ? 1.0f / d : 0.0f;
y->d = CONVERT_F32_TO_F16(d);
const __m256 id_vec = _mm256_set1_ps(id);
const __m128i shuffle_mask = _mm_set_epi8(
-1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, 12, 8, 4, 0
);
for (int j = 0; j < Q80_BLOCK_SIZE; j += 8) {
__m256 vec = _mm256_loadu_ps(x + j);
__m256 scaled = _mm256_mul_ps(vec, id_vec);
__m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
__m256i integers = _mm256_cvtps_epi32(rounded);
__m128i low = _mm256_extracti128_si256(integers, 0);
__m128i high = _mm256_extracti128_si256(integers, 1);
__m128i low_bytes = _mm_shuffle_epi8(low, shuffle_mask);
__m128i high_bytes = _mm_shuffle_epi8(high, shuffle_mask);
uint32_t low_part = _mm_extract_epi32(low_bytes, 0);
uint32_t high_part = _mm_extract_epi32(high_bytes, 0);
uint64_t packed = (static_cast<uint64_t>(high_part) << 32) | low_part;
std::memcpy(y->qs + j, &packed, sizeof(packed));
}
}
#else
for (NnUint i = start; i < end; i++) {
const float *x = &input[i * Q80_BLOCK_SIZE];
NnBlockQ80 *y = &output[i];
float amax = 0.0f;
for (NnUint j = 0; j < Q80_BLOCK_SIZE; j++) {
const float v = fabsf(x[j]);
amax = amax > v ? amax : v;
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f / d : 0.0f;
y->d = CONVERT_F32_TO_F16(d);
for (NnUint j = 0; j < Q80_BLOCK_SIZE; ++j) {
y->qs[j] = roundf(x[j] * id);
}
}
#endif
}
void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex) {
assert(k % Q80_BLOCK_SIZE == 0);
const int nBlocks = k / Q80_BLOCK_SIZE;
const int blocksPerThread = nBlocks / nThreads;
const int sk = blocksPerThread * Q80_BLOCK_SIZE;
const int currentThreadBlocks = blocksPerThread + (threadIndex == nThreads - 1 ? nBlocks % nThreads : 0);
const NnBlockQ80 *x = &input[blocksPerThread * threadIndex];
float* y = &output[sk * threadIndex];
for (int i = 0; i < currentThreadBlocks; i++) {
const float d = CONVERT_F16_TO_F32(x[i].d);
for (int j = 0; j < Q80_BLOCK_SIZE; j++) {
y[i * Q80_BLOCK_SIZE + j] = x[i].qs[j] * d;
}
}
}
void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) {
assert(n % Q40_BLOCK_SIZE == 0);
const NnUint nBlocks = n / Q40_BLOCK_SIZE;
const NnUint halfSize = Q40_BLOCK_SIZE / 2;
SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex);
for (NnUint i = start; i < end; i++) {
float amax = 0.0f;
float max = 0.0f;
for (NnUint j = 0; j < Q40_BLOCK_SIZE; j++) {
float v = x[i * Q40_BLOCK_SIZE + j];
if (amax < fabsf(v)) {
amax = fabsf(v);
max = v;
}
}
const float d = max / -8.0f;
const float id = d ? 1.0f / d : 0.0f;
NnBlockQ40 *o = &output[i];
o->d = CONVERT_F32_TO_F16(d);
for (NnUint j = 0; j < halfSize; j++) {
const float x0 = x[i * Q40_BLOCK_SIZE + j] * id;
const float x1 = x[i * Q40_BLOCK_SIZE + halfSize + j] * id;
uint8_t xi0 = (int8_t)(x0 + 8.5f);
uint8_t xi1 = (int8_t)(x1 + 8.5f);
if (xi0 > 15) xi0 = 15;
if (xi1 > 15) xi1 = 15;
o->qs[j] = xi0 | (xi1 << 4);
}
}
}
void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) {
assert(n % Q40_BLOCK_SIZE == 0);
const NnUint nBlocks = n / Q40_BLOCK_SIZE;
SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex);
for (NnUint i = start; i < end; i++) {
const NnBlockQ40 *b = &x[i];
const float d = CONVERT_F16_TO_F32(b->d);
for (int j = 0; j < Q40_BLOCK_SIZE / 2; ++j) {
const int x0 = (b->qs[j] & 0x0F) - 8;
const int x1 = (b->qs[j] >> 4) - 8;
output[i * Q40_BLOCK_SIZE + j] = x0 * d;
output[i * Q40_BLOCK_SIZE + j + Q40_BLOCK_SIZE / 2] = x1 * d;
}
}
}
const char *floatTypeToString(NnFloatType type) {
if (type == F_UNK) return "F_UNK";
if (type == F_32) return "F_32";
if (type == F_16) return "F_16";
if (type == F_Q40) return "F_Q40";
if (type == F_Q80) return "F_Q80";
throw std::invalid_argument("Unknown float type");
}

88
src/nn/nn-quants.hpp Normal file
View File

@@ -0,0 +1,88 @@
#ifndef NN_QUANTS_H
#define NN_QUANTS_H
#include <cstdint>
#include <cstring>
#if defined(__ARM_NEON)
#include <arm_neon.h>
#elif defined(__AVX2__) || defined(__F16C__)
#include <immintrin.h>
#endif
typedef std::uint8_t NnByte;
typedef std::uint32_t NnUint;
typedef std::size_t NnSize;
typedef std::uint16_t NnFp16;
float convertF16toF32Impl(const NnFp16 value);
NnFp16 convertF32ToF16Impl(const float x);
#if defined(__ARM_NEON) && defined(__ARM_FP16_FORMAT_IEEE)
inline float convertF16ToF32Neon(const NnFp16 value) {
__fp16 fp;
std::memcpy(&fp, &value, sizeof(fp));
return (float)fp;
}
inline NnFp16 convertF32ToF16Neon(const float x) {
__fp16 h = x;
return *(NnFp16 *)&h;
}
#define CONVERT_F16_TO_F32(value) convertF16ToF32Neon(value)
#define CONVERT_F32_TO_F16(value) convertF32ToF16Neon(value)
#elif defined(__F16C__)
#define CONVERT_F32_TO_F16(v) _cvtss_sh((v), _MM_FROUND_TO_NEAREST_INT)
#endif
#if !defined(CONVERT_F16_TO_F32)
extern float f16ToF32Lookup[65536];
inline static float convertF16ToF32Lookup(const NnFp16 value) {
return f16ToF32Lookup[value];
}
#define CONVERT_F16_TO_F32_LOOKUP
#define CONVERT_F16_TO_F32(value) convertF16ToF32Lookup(value)
#endif
#if !defined(CONVERT_F32_TO_F16)
#define CONVERT_F32_TO_F16(value) convertF32ToF16Impl(value)
#endif
#define Q40_BLOCK_SIZE 32
#define Q80_BLOCK_SIZE 32
enum NnFloatType {
F_UNK = -1,
F_32 = 0,
F_16 = 1,
F_Q40 = 2,
F_Q80 = 3,
};
typedef struct {
std::uint16_t d;
std::uint8_t qs[Q40_BLOCK_SIZE / 2];
} NnBlockQ40;
typedef struct {
std::uint16_t d;
std::int8_t qs[Q80_BLOCK_SIZE];
} NnBlockQ80;
void initQuants();
void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint k, const NnUint nThreads, const NnUint threadIndex);
void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex);
void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex);
void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex);
const char *floatTypeToString(NnFloatType type);
#define SPLIT_THREADS(varStart, varEnd, rangeLen, nThreads, threadIndex) \
const NnUint rangeSlice = rangeLen / nThreads; \
const NnUint rangeRest = rangeLen % nThreads; \
const NnUint varStart = threadIndex * rangeSlice + (threadIndex < rangeRest ? threadIndex : rangeRest); \
const NnUint varEnd = varStart + rangeSlice + (threadIndex < rangeRest ? 1 : 0);
#endif

989
src/nn/nn-vulkan-test.cpp Normal file
View File

@@ -0,0 +1,989 @@
#include <cstdio>
#include <cmath>
#include "nn-config-builder.hpp"
#include "nn-quants.hpp"
#include "nn-vulkan.hpp"
#define N_BATCHES 4
void printOk(const char *name) {
printf("✅ %24s passed\n", name);
}
void assertFloat(NnUint position, const float value, const float expectedValue, const float tolerance) {
float diff = fabs(expectedValue - value);
if (diff > tolerance) {
printf("❌ [%d] failed: value=%f, expectedValue=%f, diff=%f\n", position, value, expectedValue, diff);
exit(1);
}
}
void execute(
void (*build)(NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder),
void (*execute)(NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device)
) {
NnUint nNodes = 1;
NnNetConfigBuilder netBuilder(nNodes, N_BATCHES);
NnNodeConfigBuilder nodeBuilder(0);
NnSegmentConfigBuilder segmentBuilder;
build(&netBuilder, &nodeBuilder, &segmentBuilder);
nodeBuilder.addSegment(segmentBuilder.build());
NnNetConfig netConfig = netBuilder.build();
NnNodeConfig nodeConfig = nodeBuilder.build();
std::unique_ptr<NnNetConfig, void(*)(NnNetConfig *)> netConfigPtr(&netConfig, releaseNetConfig);
std::unique_ptr<NnNodeConfig, void(*)(NnNodeConfig *)> nodeConfigPtr(&nodeConfig, releaseNodeConfig);
NnNetExecution execution(1, &netConfig);
NnUint gpuIndex = 0;
std::vector<NnExecutorDevice> devices;
NnVulkanDevice *device = new NnVulkanDevice(gpuIndex, &netConfig, &nodeConfig, &execution);
devices.push_back(NnExecutorDevice(device, -1, -1));
NnFakeNodeSynchronizer synchronizer;
NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
execute(&executor, &execution, device);
}
template <NnUint dim>
void testRmsNorm_F32_F32_F32() {
#define TEST_RMS_NORM_EPS 1e-5f
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
NnUint invRmsBufferIndex = nodeBuilder->addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1));
segmentBuilder->addOp(OP_INV_RMS, "inv_rms", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
size0(),
NnInvRmsOpConfig{TEST_RMS_NORM_EPS, 1});
segmentBuilder->addOp(OP_RMS_NORM, "rms_norm", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size1D(F_32, dim),
NnRmsNormOpConfig{invRmsBufferIndex, 1});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
const NnUint batchSize = 2;
execution->setBatchSize(batchSize);
std::vector<float> normWeight(dim);
for (NnUint i = 0; i < dim; i++)
normWeight[i] = (0.25f + (float)i) / (float)dim;
executor->loadWeight("rms_norm", 0u, 0u, normWeight.size() * sizeof(float), (NnByte *)normWeight.data());
float *xPipe = (float *)execution->pipes[0];
float expectedS[batchSize];
for (NnUint b = 0; b < batchSize; b++) {
float *xBatchPipe = &xPipe[b * dim];
float s = 0.0f;
for (NnUint i = 0; i < dim; i++) {
float u = (float)(dim - i + b) / (float)(dim / 2);
xBatchPipe[i] = u;
s += u * u;
}
s /= (float)dim;
expectedS[b] = 1.0f / sqrtf(s + TEST_RMS_NORM_EPS);
}
// act
executor->forward();
// assert
float invRmsBuffer[N_BATCHES];
device->data.buffers[0].get()->read((NnByte *)invRmsBuffer);
for (NnUint b = 0; b < batchSize; b++) {
float *xBatchPipe = &xPipe[b * dim];
const float t = 0.0000019f;
assertFloat(b, invRmsBuffer[b], expectedS[b], t);
const float s = invRmsBuffer[b];
for (NnUint i = 0; i < dim; i++) {
float u = (float)(dim - i + b) / (float)(dim / 2);
assertFloat(b * dim + i, xBatchPipe[i], (u * s) * normWeight[i], t);
}
}
printOk("testRmsNorm_F32_F32_F32");
});
}
template <NnUint dim>
void testSilu_F32_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
segmentBuilder->addOp(OP_SILU, "silu", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnSiluOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float expectedOutput[dim * N_BATCHES];
float *xPipe = (float *)execution->pipes[0];
for (NnUint b = 0; b < N_BATCHES; b++) {
const NnUint offset = b * dim;
for (NnUint i = 0; i < dim; i++) {
const float v = i / (float)dim + (float)(b + 1);
xPipe[offset + i] = v;
expectedOutput[offset + i] = v / (1.0 + expf(-v));
}
}
// act
executor->forward();
// assert
float t = 0.00001f;
for (NnUint b = 0; b < N_BATCHES; b++) {
const NnUint offset = b * dim;
for (NnUint i = 0; i < dim; i++)
assertFloat(offset + i, xPipe[offset + i], expectedOutput[offset + i], t);
}
printOk("testSilu_F32_F32");
});
}
template <NnUint dim, NnUint nZ>
void testMul_F32_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
NnUint sBufferIndex = nodeBuilder->addBuffer("s", size3D(F_32, nZ, N_BATCHES, dim));
segmentBuilder->addOp(OP_MUL, "mul", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnMulOpCodeConfig{sBufferIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
float sBuffer[nZ * N_BATCHES * dim];
for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++) {
xPipe[i] = (float)i;
sBuffer[i] = (i % 8) / 10.0f;
}
device->data.buffers[0].get()->write((NnByte *)sBuffer);
// act
executor->forward();
// assert
for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
assertFloat(i, xPipe[i], i * ((i % 8) / 10.0f), 0.000001f);
printOk("testMul_F32_F32");
});
}
void testMergeAdd_F32_F32() {
#define MERGE_ADD_F32_NODES 2
#define MERGE_ADD_F32_DIM 64
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_32, N_BATCHES, MERGE_ADD_F32_DIM * MERGE_ADD_F32_NODES));
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MERGE_ADD_F32_DIM));
segmentBuilder->addOp(OP_MERGE_ADD, "mergeAdd", 0,
pointerBatchConfig(SRC_PIPE, zPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnMergeAddOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *zPipe = (float *)execution->pipes[0];
float *xPipe = (float *)execution->pipes[1];
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint n = 0; n < MERGE_ADD_F32_NODES; n++) {
for (NnUint i = 0; i < MERGE_ADD_F32_DIM; i++)
zPipe[b * MERGE_ADD_F32_NODES * MERGE_ADD_F32_DIM + n * MERGE_ADD_F32_DIM + i] = (float)(b + 1);
}
}
// act
executor->forward();
// assert
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint i = 0; i < MERGE_ADD_F32_DIM; i++) {
NnUint j = b * MERGE_ADD_F32_DIM + i;
assertFloat(j, xPipe[j], (float)(2 * b + 2), 0.00001f);
}
}
printOk("testMergeAdd_F32_F32");
});
}
void testMergeSum_F32_F32() {
#define MERGE_SUM_F32_N_Z 2
#define MERGE_SUM_F32_DIM 4
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, MERGE_SUM_F32_N_Z, N_BATCHES, MERGE_SUM_F32_DIM));
NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_32, 1u, N_BATCHES, MERGE_SUM_F32_DIM));
segmentBuilder->addOp(OP_MERGE_SUM, "merge_sum", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, yPipeIndex),
size0(),
NnMergeSumOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(2);
float *xPipe = (float *)execution->pipes[0];
float *xPipeZ0 = &xPipe[0];
float *xPipeZ1 = &xPipe[N_BATCHES * MERGE_SUM_F32_DIM];
float *yPipe = (float *)execution->pipes[1];
xPipeZ0[0] = 1.0f;
xPipeZ0[1] = 1.1f;
xPipeZ0[2] = 1.2f;
xPipeZ0[3] = 1.3f;
xPipeZ0[4] = 2.0f;
xPipeZ0[5] = 2.1f;
xPipeZ0[6] = 2.2f;
xPipeZ0[7] = 2.3f;
xPipeZ1[0] = 0.5f;
xPipeZ1[1] = 0.1f;
xPipeZ1[2] = 0.2f;
xPipeZ1[3] = 0.3f;
xPipeZ1[4] = 0.4f;
xPipeZ1[5] = 0.3f;
xPipeZ1[6] = 0.2f;
xPipeZ1[7] = 0.1f;
// act
executor->forward();
const float t = 0.00001f;
assertFloat(0, yPipe[0], 1.5f, t);
assertFloat(1, yPipe[1], 1.2f, t);
assertFloat(2, yPipe[2], 1.4f, t);
assertFloat(3, yPipe[3], 1.6f, t);
assertFloat(4, yPipe[4], 2.4f, t);
assertFloat(5, yPipe[5], 2.4f, t);
assertFloat(6, yPipe[6], 2.4f, t);
assertFloat(7, yPipe[7], 2.4f, t);
printOk("testMergeSum_F32_F32");
});
}
template <NnUint nNodes, NnUint dim>
static void testMergeAdd_Q80_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
const NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_Q80, N_BATCHES, dim * nNodes));
const NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
segmentBuilder->addOp(OP_MERGE_ADD, "mergeAdd", 0,
pointerBatchConfig(SRC_PIPE, zPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnMergeAddOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float z[N_BATCHES * dim * nNodes];
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint n = 0; n < nNodes; n++) {
for (NnUint i = 0; i < dim; i++)
z[b * nNodes * dim + n * dim + i] = (float)(b + 1);
}
}
NnBlockQ80 *zPipe = (NnBlockQ80 *)execution->pipes[0];
const float *xPipe = (float *)execution->pipes[1];
quantizeF32toQ80(z, zPipe, N_BATCHES * dim * nNodes, 1, 0);
// act
executor->forward();
// assert
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint i = 0; i < dim; i++) {
float expectedValue = (float)((b + 1) * nNodes);
NnUint j = b * dim + i;
assertFloat(j, xPipe[j], expectedValue, 0.001f);
}
}
printOk("testMergeAdd_Q80_F32");
});
}
void testEmbedding_F32_F32() {
#define EMBEDDING_DIM 16
#define EMBEDDING_LEN 8
assert(EMBEDDING_LEN > N_BATCHES);
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, EMBEDDING_DIM));
segmentBuilder->addOp(OP_EMBEDDING, "embedding", 0,
pointerBatchConfig(SRC_PIPE, posPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size2D(F_32, EMBEDDING_LEN, EMBEDDING_DIM),
NnEmbeddingOpConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float embedding[EMBEDDING_DIM * EMBEDDING_LEN];
for (NnUint l = 0; l < EMBEDDING_LEN; l++) {
for (NnUint i = 0; i < EMBEDDING_DIM; i++)
embedding[l * EMBEDDING_DIM + i] = (float)(l + 4);
}
float *posPipe = (float *)execution->pipes[0];
for (NnUint b = 0; b < N_BATCHES; b++)
posPipe[b] = (float)b;
executor->loadWeight("embedding", 0u, 0u, EMBEDDING_DIM * EMBEDDING_LEN * sizeof(float), (NnByte *)embedding);
// act
executor->forward();
// assert
float *xPipe = (float *)execution->pipes[1];
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint i = 0; i < EMBEDDING_DIM; i++) {
NnUint j = b * EMBEDDING_DIM + i;
assertFloat(j, xPipe[j], (float)(b + 4), 0.00001f);
}
}
printOk("testEmbedding_F32_F32");
});
}
template <NnUint dim>
void testShift_F32_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, 1, N_BATCHES * dim));
segmentBuilder->addOp(
OP_SHIFT, "shift", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerRawConfig(SRC_PIPE, yPipeIndex),
size0(),
NnShiftOpCodeConfig{posPipeIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[1];
float *yPipe = (float *)execution->pipes[2];
float pos[N_BATCHES];
for (NnUint b = 0; b < N_BATCHES; b++) {
pos[b] = (float)b;
for (NnUint i = 0; i < dim; i++)
xPipe[b * dim + i] = (float)(b * 100 + i);
}
device->data.pipes[0].get()->write((NnByte *)pos);
// act
executor->forward();
// assert
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint i = 0; i < dim; i++) {
NnUint j = b * dim + i;
assertFloat(j, yPipe[j], (float)(b * 100 + i), 0.00001f);
}
}
printOk("testShift_F32_F32");
});
}
template <NnUint dim, NnUint nZ>
void testCast_F32_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_32, nZ, N_BATCHES, dim));
segmentBuilder->addOp(
OP_CAST, "cast", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, yPipeIndex),
size0(),
NnCastOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
float *yPipe = (float *)execution->pipes[1];
for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
xPipe[i] = (float)(i + 1);
// act
executor->forward();
// assert
for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
assertFloat(i, yPipe[i], (float)(i + 1), 0.00001f);
printOk("testCast_F32_F32");
});
}
template <NnUint dim, NnUint nZ>
void testCast_F32_Q80() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_Q80, nZ, N_BATCHES, dim));
segmentBuilder->addOp(
OP_CAST, "cast", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, yPipeIndex),
size0(),
NnCastOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
NnBlockQ80 *yPipe = (NnBlockQ80 *)execution->pipes[1];
for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
xPipe[i] = (float)(i + 1);
// act
executor->forward();
float yF32[nZ * N_BATCHES * dim];
dequantizeQ80toF32(yPipe, yF32, nZ * N_BATCHES * dim, 1, 0);
for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++) {
const float expectedV = (float)(i + 1);
const float change = (yF32[i] - expectedV) / expectedV;
assertFloat(i, change, 0.0, 0.009f);
}
printOk("testCast_F32_Q80");
});
}
template <NnRopeType ropeType, void (*assertOutput)(float *x0, float *x1)>
void testRope_F32_F32() {
#define ROPE_DIM 2048
#define ROPE_KV_DIM 512
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
const NnUint nHeads = 32;
const NnUint seqLen = 4096;
const NnRopeSlice slice = sliceRope(ropeType, ROPE_DIM, ROPE_KV_DIM, 8, 1, seqLen, ROPE_DIM / nHeads, 500000.0f, 0);
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, ROPE_DIM));
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
NnUint ropeCacheBufferIndex = nodeBuilder->addBuffer("ropeCache", slice.cacheSize);
NnUint isQ = 1;
segmentBuilder->addOp(
OP_ROPE, "rope_llama", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnRopeOpConfig{ropeType, isQ, posPipeIndex, ropeCacheBufferIndex, 32.0f, 1.0f, 4.0f, 8192, slice});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(2);
float *xPipe = (float *)execution->pipes[0];
float pos[N_BATCHES];
pos[0] = (float)6;
pos[1] = (float)31;
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint i = 0; i < ROPE_DIM; i++)
xPipe[b * ROPE_DIM + i] = 1.0f;
}
device->data.pipes[1].get()->write((NnByte *)pos);
// act
executor->forward();
// assert
float *x0 = &xPipe[0 * ROPE_DIM];
float *x1 = &xPipe[1 * ROPE_DIM];
assertOutput(x0, x1);
});
}
void assertRopeLlama_F32_F32(float *x0, float *x1) {
const float t = 0.000001f;
assertFloat(0, x0[0], 1.239586f, t);
assertFloat(1, x0[1], 0.680755f, t);
assertFloat(2, x0[2], 0.077202f, t);
assertFloat(3, x0[3], -1.412105f, t);
assertFloat(1988, x0[1988], -1.356766f, t);
assertFloat(2022, x0[2022], 0.999923f, t);
assertFloat(2023, x0[2023], 1.000077f, t);
assertFloat(0, x1[0], 1.318780f, t);
assertFloat(1, x1[1], 0.510705f, t);
assertFloat(1078, x1[1078], 0.999985f, t);
assertFloat(1078, x1[1079], 1.000015f, t);
}
void assertRopeFalcon_F32_F32(float *x0, float *x1) {
const float t = 0.000001f;
assertFloat(0, x0[0], 1.239586f, t);
assertFloat(1, x0[1], 0.077202f, t);
assertFloat(2, x0[2], -1.356766f, t);
assertFloat(3, x0[3], -1.164938f, t);
assertFloat(1988, x0[1988], -0.522115f, t);
assertFloat(1988, x0[1989], 0.018772f, t);
assertFloat(2022, x0[2022], 1.361834f, t);
assertFloat(2023, x0[2023], 1.276253f, t);
assertFloat(0, x1[0], 1.318780f, t);
assertFloat(1, x1[1], -1.139289f, t);
assertFloat(1, x1[2], -0.417384f, t);
assertFloat(1, x1[3], -1.291486f, t);
assertFloat(1078, x1[1078], 1.003737f, t);
assertFloat(1078, x1[1079], 1.002481f, t);
}
void testRopeLlama_F32_F32() {
testRope_F32_F32<NnRopeType::ROPE_LLAMA, assertRopeLlama_F32_F32>();
printOk("testRopeLlama_F32_F32");
}
void testRopeFalcon_F32_F32() {
testRope_F32_F32<NnRopeType::ROPE_FALCON, assertRopeFalcon_F32_F32>();
printOk("testRopeFalcon_F32_F32");
}
template <NnUint N, NnUint D>
void testMatmul_F32_F32_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, N));
NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, D));
NnUint nullBufferIndex = nodeBuilder->addBuffer("null", size1D(F_32, 1u));
segmentBuilder->addOp(
OP_MATMUL, "matmul", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, yPipeIndex),
size2D(F_32, N, D),
NnMatmulOpConfig{0u, 0u, nullBufferIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
float *yPipe = (float *)execution->pipes[1];
float weight[N * D];
for (NnUint i = 0; i < N_BATCHES * N; i++)
xPipe[i] = i * 0.0001f;
for (NnUint i = 0; i < N * D; i++)
weight[i] = i * 0.000001f;
executor->loadWeight("matmul", 0u, 0u, N * D * sizeof(float), (NnByte *)weight);
// act
executor->forward();
// assert
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint d = 0; d < D; d++) {
float sum = 0.0f;
for (NnUint n = 0; n < N; n++)
sum += xPipe[b * N + n] * weight[d * N + n];
const NnUint p = b * D + d;
assertFloat(p, yPipe[p], sum, 0.0002f);
}
}
printOk("testMatmul_F32_F32_F32");
});
}
void testMatmul_F32_F32_F32_expert() {
#define MATMUL_F32_N 4
#define MATMUL_F32_D 1
#define MATMUL_F32_E 4
#define MATMUL_F32_A 2
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, MATMUL_F32_A, N_BATCHES, MATMUL_F32_N));
NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_32, MATMUL_F32_A, N_BATCHES, MATMUL_F32_D));
NnUint activeExpertIndexesIndex = nodeBuilder->addBuffer("indexes", size2D(F_32, N_BATCHES, MATMUL_F32_A));
segmentBuilder->addOp(
OP_MATMUL, "matmul", 0u,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, yPipeIndex),
size3D(F_32, MATMUL_F32_E, MATMUL_F32_N, MATMUL_F32_D),
NnMatmulOpConfig{MATMUL_F32_E, MATMUL_F32_A, activeExpertIndexesIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0]; // (A, N_BATCHES, N)
float *yPipe = (float *)execution->pipes[1]; // (A, N_BATCHES, D)
constexpr NnUint wSize = MATMUL_F32_N * MATMUL_F32_D;
constexpr NnUint wSizeBytes = wSize * sizeof(float);
float weight[wSize];
float indexes[N_BATCHES * MATMUL_F32_A];
for (NnUint e = 0u; e < MATMUL_F32_E; e++) {
for (NnUint i = 0u; i < wSize; i++)
weight[i] = 0.1f * (float)(e + 1);
executor->loadWeight("matmul", 0u, wSizeBytes * e, wSizeBytes, (NnByte *)weight);
}
for (NnUint i = 0u; i < MATMUL_F32_A * N_BATCHES; i++)
indexes[i] = (float)(i % MATMUL_F32_E); // 0, 1, 2, 3, 0, 1, 2, 3, ...
for (NnUint i = 0; i < MATMUL_F32_A * N_BATCHES * MATMUL_F32_N; i++)
xPipe[i] = (float)(i / MATMUL_F32_N + 1); // 1.0, 1.0, ... 2.0, 2.0, ...
device->data.buffers[0].get()->write((NnByte *)indexes);
executor->forward();
float t = 0.0001f;
assertFloat(0, yPipe[0], 0.1f /* index=0, e=0 */ * (4 * 1.0f), t);
assertFloat(1, yPipe[1], 0.3f /* index=2, e=2 */ * (4 * 2.0f), t);
assertFloat(2, yPipe[2], 0.1f /* index=4, e=0 */ * (4 * 3.0f), t);
assertFloat(3, yPipe[3], 0.3f /* index=6, e=2 */ * (4 * 4.0f), t);
assertFloat(4, yPipe[4], 0.2f /* index=1, e=1 */ * (4 * 5.0f), t);
assertFloat(5, yPipe[5], 0.4f /* index=3, e=3 */ * (4 * 6.0f), t);
assertFloat(6, yPipe[6], 0.2f /* index=5, e=1 */ * (4 * 7.0f), t);
assertFloat(7, yPipe[7], 0.4f /* index=7, e=3 */ * (4 * 8.0f), t);
printOk("testMatmul_F32_F32_F32_expert");
});
}
template <NnUint N, NnUint D>
void testMatmul_Q80_Q40_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_Q80, N_BATCHES, N));
NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, D));
NnUint nullBufferIndex = nodeBuilder->addBuffer("null", size1D(F_32, 1u));
segmentBuilder->addOp(
OP_MATMUL, "matmul", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, yPipeIndex),
size2D(F_Q40, N, D),
NnMatmulOpConfig{0u, 0u, nullBufferIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
NnBlockQ80 *xPipe = (NnBlockQ80 *)execution->pipes[0];
float *yPipe = (float *)execution->pipes[1];
constexpr NnUint xSize = N_BATCHES * N;
constexpr NnUint weightSize = N * D;
constexpr NnUint weightBlocks = weightSize / Q40_BLOCK_SIZE;
std::unique_ptr<float[]> x(new float[xSize]);
std::unique_ptr<float[]> weight(new float[weightSize]);
std::unique_ptr<NnBlockQ40[]> weightQ40(new NnBlockQ40[weightBlocks]);
for (NnUint i = 0; i < xSize; i++)
x[i] = 0.1f + (i / (float)N - 0.5f) * 0.0005f;
for (NnUint i = 0; i < weightSize; i++)
weight[i] = 0.1f + (i / (float)D - 0.5f) * 0.0005f;
quantizeF32toQ80(x.get(), xPipe, xSize, 1, 0);
quantizeF32toQ40(weight.get(), weightQ40.get(), weightSize, 1, 0);
executor->loadWeight("matmul", 0u, 0u, weightBlocks * sizeof(NnBlockQ40), (NnByte *)weightQ40.get());
// act
executor->forward();
// assert
for (NnUint b = 0; b < N_BATCHES; b++) {
for (NnUint d = 0; d < D; d++) {
float sum = 0.0f;
for (NnUint n = 0; n < N; n++)
sum += x[b * N + n] * weight[d * N + n];
const NnUint p = b * D + d;
const float err = sum == 0.0 ? (yPipe[p] - sum) : (yPipe[p] - sum) / sum;
// printf("[%d] %f %f (%f)\n", b, yPipe[p], sum, err);
assertFloat(p, err, 0.0f, 0.009f);
}
}
printOk("testMatmul_Q80_Q40_F32");
});
}
void testMultiheadAtt_F32_F32() {
#define MULTIHEAD_ATT_DIM 128
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
const NnUint nHeads = 32;
const NnUint nKvHeads = 8;
const NnUint headDim = MULTIHEAD_ATT_DIM / nHeads;
const NnUint seqLen = 4096;
const NnUint qSliceD0 = 2048;
const NnUint kvDim0 = 512;
const NnKvCacheSlice kvCacheSlice = sliceKvCache(kvDim0, seqLen, 1);
const NnMultiHeadAttSlice multiHeadAttSlice = sliceMultiHeadAtt(nHeads, seqLen, 1, N_BATCHES);
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MULTIHEAD_ATT_DIM));
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
NnUint qBufferIndex = nodeBuilder->addBuffer("POS", size2D(F_32, N_BATCHES, qSliceD0));
NnUint kCacheBufferIndex = nodeBuilder->addBuffer("kCache", kvCacheSlice.keySize);
NnUint vCacheBufferIndex = nodeBuilder->addBuffer("vCache", kvCacheSlice.valueSize);
NnUint attCacheBufferIndex = nodeBuilder->addBuffer("vCache", multiHeadAttSlice.attSize);
segmentBuilder->addOp(
OP_MULTIHEAD_ATT, "multihead_att", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnMultiHeadAttOpConfig{nHeads, nHeads, nKvHeads, headDim, seqLen, qSliceD0, kvDim0,
posPipeIndex, qBufferIndex, kCacheBufferIndex, vCacheBufferIndex, attCacheBufferIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// TODO: for now this is a smoke test
execution->setBatchSize(N_BATCHES);
executor->forward();
printOk("testMultiheadAtt_F32_F32");
});
}
template <NnUint dim, NnUint nZ>
void testSoftmax_F32_F32() {
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
segmentBuilder->addOp(OP_SOFTMAX, "softmax", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnSoftmaxOpCodeConfig{});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
for (NnUint z = 0u; z < nZ; z++) {
float *xPipeZ = &xPipe[z * N_BATCHES * dim];
for (NnUint b = 0; b < N_BATCHES; b++) {
const NnUint offset = b * dim;
for (NnUint i = 0u; i < dim; i++)
xPipeZ[offset + i] = i / (float)dim + (float)b;
}
}
// act
executor->forward();
// assert
float t = 0.00001f;
for (NnUint z = 0u; z < nZ; z++) {
float *xPipeZ = &xPipe[z * N_BATCHES * dim];
for (NnUint b = 0; b < N_BATCHES; b++) {
const NnUint offset = b * dim;
float max = ((dim - 1) / (float)dim) + (float)b;
for (NnUint i = 0u; i < dim; i++) {
const float v = i / (float)dim + (float)b;
const float expectedOutput = expf(v - max);
assertFloat(offset + i, xPipeZ[offset + i], expectedOutput, t);
}
}
}
printOk("testSoftmax_F32_F32");
});
}
void testScale_F32_F32() {
#define SCALE_F32_N_Z 4
#define SCALE_F32_DIM 64
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, SCALE_F32_N_Z, N_BATCHES, SCALE_F32_DIM));
NnUint scaleBufferIndex = nodeBuilder->addBuffer("scale", size3D(F_32, SCALE_F32_N_Z, SCALE_F32_DIM, 1u));
segmentBuilder->addOp(OP_SCALE, "scale", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, xPipeIndex),
size0(),
NnScaleOpCodeConfig{scaleBufferIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
float scale[SCALE_F32_N_Z * SCALE_F32_DIM];
for (NnUint z = 0u; z < SCALE_F32_N_Z; z++) {
const NnUint zOffset = z * N_BATCHES * SCALE_F32_DIM;
for (NnUint y = 0u; y < N_BATCHES; y++) {
scale[z * N_BATCHES + y] = 0.5f * (float)(z * 100 + y);
for (NnUint i = 0; i < SCALE_F32_DIM; i++)
xPipe[zOffset + y * SCALE_F32_DIM + i] = (float)(z * 10000 + y * 100 + i);
}
}
device->data.buffers[0].get()->write((NnByte *)scale);
// act
executor->forward();
// assert
for (NnUint z = 0u; z < SCALE_F32_N_Z; z++) {
const NnUint zOffset = z * N_BATCHES * SCALE_F32_DIM;
for (NnUint y = 0u; y < N_BATCHES; y++) {
for (NnUint i = 0; i < SCALE_F32_DIM; i++) {
const NnUint p = zOffset + y * SCALE_F32_DIM + i;
const float v = xPipe[p];
const float expectedOutput = (float)(z * 10000 + y * 100 + i) * 0.5f * (float)(z * 100 + y);
assertFloat(p, v, expectedOutput, 0.00001f);
}
}
}
printOk("testScale_F32_F32");
});
}
void testMoeGate_F32_F32() {
#define MOE_GATE_F32_DIM 8
#define MOE_GATE_F32_K 4
execute(
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MOE_GATE_F32_DIM));
NnUint gPipeIndex = netBuilder->addPipe("g", size3D(F_32, MOE_GATE_F32_K, N_BATCHES, 1u));
NnUint indexesBufferIndex = nodeBuilder->addBuffer("i", size2D(F_32, N_BATCHES, MOE_GATE_F32_K));
segmentBuilder->addOp(OP_MOE_GATE, "moe_gate", 0,
pointerBatchConfig(SRC_PIPE, xPipeIndex),
pointerBatchConfig(SRC_PIPE, gPipeIndex),
size0(),
NnMoeGateOpCodeConfig{MOE_GATE_F32_K, 0u, indexesBufferIndex});
},
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
// arrange
execution->setBatchSize(N_BATCHES);
float *xPipe = (float *)execution->pipes[0];
float *gPipe = (float *)execution->pipes[1];
xPipe[0] = 3.0f;
xPipe[1] = 1.0f;
xPipe[2] = 6.0f;
xPipe[3] = 5.0f;
xPipe[4] = 8.0f;
xPipe[5] = 4.0f;
xPipe[6] = 2.0f;
xPipe[7] = 7.0f;
// act
executor->forward();
float pos[N_BATCHES * MOE_GATE_F32_K];
device->data.buffers[0].get()->read((NnByte *)pos);
// assert
const float t = 0.00001f;
assertFloat(0, gPipe[0 * N_BATCHES], 8.0f, t);
assertFloat(1, gPipe[1 * N_BATCHES], 7.0f, t);
assertFloat(2, gPipe[2 * N_BATCHES], 6.0f, t);
assertFloat(3, gPipe[3 * N_BATCHES], 5.0f, t);
assertFloat(100, pos[0], 4.0f, t);
assertFloat(101, pos[1], 7.0f, t);
assertFloat(102, pos[2], 2.0f, t);
assertFloat(103, pos[3], 3.0f, t);
printOk("testMoeGate_F32_F32");
});
}
int main() {
initQuants();
testRmsNorm_F32_F32_F32<4>();
testRmsNorm_F32_F32_F32<1024>();
testRmsNorm_F32_F32_F32<3196>();
testSilu_F32_F32<4>();
testSilu_F32_F32<32>();
testSilu_F32_F32<104>();
testMul_F32_F32<32, 1>();
testMul_F32_F32<48, 4>();
testMergeAdd_F32_F32();
testMergeSum_F32_F32();
testMergeAdd_Q80_F32<2, 64>();
testMergeAdd_Q80_F32<4, 128>();
testMergeAdd_Q80_F32<4, 160>();
testEmbedding_F32_F32();
testShift_F32_F32<32>();
testShift_F32_F32<9>();
testCast_F32_F32<128, 1>();
testCast_F32_F32<32, 2>();
testCast_F32_F32<8, 4>();
testCast_F32_Q80<256, 1>();
testCast_F32_Q80<64, 4>();
testRopeLlama_F32_F32();
testRopeFalcon_F32_F32();
testMatmul_F32_F32_F32<64, 96>();
testMatmul_F32_F32_F32<3191, 109>();
testMatmul_F32_F32_F32_expert();
testMatmul_Q80_Q40_F32<14336, 4096>();
testMatmul_Q80_Q40_F32<4096, 14336>();
testMatmul_Q80_Q40_F32<4096, 4096>();
testMatmul_Q80_Q40_F32<64, 48>();
testMatmul_Q80_Q40_F32<64, 64>();
testMatmul_Q80_Q40_F32<192, 16>();
testMultiheadAtt_F32_F32();
testSoftmax_F32_F32<256, 1>();
testSoftmax_F32_F32<512, 4>();
testScale_F32_F32();
testMoeGate_F32_F32();
return 0;
}

1163
src/nn/nn-vulkan.cpp Normal file

File diff suppressed because it is too large Load Diff

173
src/nn/nn-vulkan.hpp Normal file
View File

@@ -0,0 +1,173 @@
#ifndef NN_VULKAN_HPP
#define NN_VULKAN_HPP
#include <vulkan/vulkan.hpp>
#include <vector>
#include "nn-executor.hpp"
#include "nn-cpu-ops.hpp"
class NnVulkanContext {
public:
vk::Instance instance;
vk::PhysicalDevice physicalDevice;
vk::Device device;
uint32_t queueFamilyIndex;
vk::CommandPool commandPool;
vk::Queue queue;
NnSize nonCoherentAtomSize;
NnVulkanContext(const NnUint gpuIndex);
~NnVulkanContext();
std::pair<vk::Buffer, vk::DeviceMemory> createRawBuffer(const uint32_t memoryTypeIndex, const vk::DeviceSize bufferSize, const vk::BufferUsageFlags usageFlags);
};
enum NnVulkanStagingCopierDirection {
COPY_TO_DEVICE,
COPY_FROM_DEVICE
};
class NnVulkanStagingCopier {
private:
NnVulkanContext *context;
uint32_t memoryTypeIndex;
vk::DeviceSize allocatedSize;
vk::Buffer hostBuffer;
vk::DeviceMemory hostMemory;
void *hostPointer;
public:
NnVulkanStagingCopier(NnVulkanContext *context);
~NnVulkanStagingCopier();
void allocate(const NnSize size);
void copy(NnByte *data, const NnSize size, const NnVulkanStagingCopierDirection direction);
void executeCopyCommand(vk::Buffer& target, const NnSize offset, const NnSize size, const NnVulkanStagingCopierDirection direction);
void addCopyCommand(vk::CommandBuffer& commandBuffer, vk::Buffer& target, const NnSize offset, const NnSize size, const NnVulkanStagingCopierDirection direction);
void tryRelease();
};
class NnVulkanBuffer {
private:
bool isHostVisible;
NnVulkanContext *context;
NnVulkanStagingCopier *copier;
vk::DeviceMemory deviceMemory;
NnByte *hostPointer;
public:
const char *name;
NnSize bufferSize;
bool isSliceable;
vk::Buffer deviceBuffer;
vk::BufferUsageFlags usageFlags;
NnVulkanBuffer(NnVulkanContext *context, NnVulkanStagingCopier *copier, const char *name, const NnSize bufferSize, const bool isSliceable, vk::BufferUsageFlags usageFlags, bool fastAccess);
~NnVulkanBuffer();
void write(const NnByte *data);
void write(const NnByte *data, const NnSize offset, const NnSize size);
void read(NnByte *data);
void read(NnByte *data, const NnSize offset, const NnSize size);
NnSize calcSliceSize(const NnSize nominator, const NnSize denominator);
};
class NnVulkanBufferFactory {
private:
NnVulkanContext *context;
NnVulkanStagingCopier *copier;
public:
NnVulkanBufferFactory(NnVulkanContext *context, NnVulkanStagingCopier *copier);
std::unique_ptr<NnVulkanBuffer> create(const char *name, const NnSize bufferSize, const bool isSliceable, vk::BufferUsageFlags usageFlags, bool fastAccess);
};
typedef struct {
NnUint inputOffset;
NnUint inputSizeX;
NnUint outputOffset;
NnUint outputSizeX;
} NnVulkanBatchInfo;
class NnVulkanDeviceData {
private:
NnNetConfig *netConfig;
NnNodeConfig *nodeConfig;
public:
std::vector<std::unique_ptr<NnVulkanBuffer>> pipes;
std::vector<std::unique_ptr<NnVulkanBuffer>> buffers;
std::vector<std::unique_ptr<NnVulkanBuffer>> internalBuffers;
NnVulkanDeviceData(NnVulkanBufferFactory *bufferFactory, NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
~NnVulkanDeviceData();
NnSize3D resolveBufferSize(NnPointerConfig *config);
NnVulkanBuffer *resolvePointerVulkanBuffer(NnPointerConfig *config);
NnUint resolveBufferBatchOffset(NnPointerConfig *config, NnUint batchIndex, NnUint zIndex);
NnUint resolveBufferBatchWidth(NnPointerConfig *config);
NnVulkanBuffer *resolvePipeByIndex(NnUint pipeIndex);
NnVulkanBuffer *resolveBufferByIndex(NnUint bufferIndex);
};
class NnVulkanDevice : public NnDevice {
private:
NnVulkanContext context;
NnVulkanStagingCopier copier;
NnVulkanBufferFactory bufferFactory;
NnNetConfig *netConfig;
NnNodeConfig *nodeConfig;
NnNetExecution *netExecution;
public:
NnVulkanDeviceData data;
NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution);
~NnVulkanDevice() override;
NnUint maxNThreads() override;
NnDeviceSegment *createSegment(NnUint segmentIndex) override;
};
class NnVulkanDeviceSegmentData {
private:
NnVulkanDeviceData *data;
std::vector<NnUint> batchInfoBufferIndex;
std::vector<NnUint> weightBufferIndex;
std::vector<NnUint> configBufferIndex;
public:
NnVulkanDeviceSegmentData(NnVulkanBufferFactory *bufferFactory, NnVulkanDeviceData *data, NnSegmentConfig *segmentConfig, NnUint nBatches);
NnVulkanBuffer *resolveOpBatchInfoVulkanBuffer(NnUint opIndex);
NnVulkanBuffer *resolveOpWeightVulkanBuffer(NnUint opIndex);
NnVulkanBuffer *resolveOpConfigVulkanBuffer(NnUint opIndex);
};
enum NnOpBufferAccessType {
ACCESS_IMMUTABLE,
ACCESS_READONLY,
ACCESS_READ_WRITE,
};
typedef struct {
NnOpBufferAccessType type;
NnVulkanBuffer *buffer;
} NnOpBufferAccess;
class NnVulkanDeviceSegment : public NnDeviceSegment {
private:
NnVulkanContext *context;
NnVulkanStagingCopier *copier;
NnVulkanDeviceData *data;
NnNetConfig *netConfig;
NnUint segmentIndex;
NnSegmentConfig *segmentConfig;
NnNetExecution *netExecution;
std::unique_ptr<NnVulkanDeviceSegmentData> segmentData;
std::vector<vk::ShaderModule> shaderModules;
std::vector<vk::DescriptorSetLayout> descriptorSetLayouts;
vk::DescriptorPool descriptorPool;
std::vector<vk::DescriptorSet> descriptorSets;
vk::Fence fence;
std::vector<vk::PipelineLayout> pipelineLayouts;
std::vector<vk::Pipeline> pipelines;
vk::PipelineCache pipelineCache;
vk::CommandBuffer commandBuffer;
std::vector<std::vector<NnVulkanBuffer *>> buffersToSync;
NnUint lastBatchSize;
public:
NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanStagingCopier *copier, NnVulkanBufferFactory *bufferFactory, NnVulkanDeviceData *data, NnNetConfig *netConfig, NnUint segmentIndex, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution);
~NnVulkanDeviceSegment() override;
void loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) override;
void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override;
};
#endif

39
src/nn/pthread.h Normal file
View File

@@ -0,0 +1,39 @@
#ifndef PTHREAD_WRAPPER
#define PTHREAD_WRAPPER
#ifdef _WIN32
#include <windows.h>
typedef HANDLE PthreadHandler;
typedef DWORD PthreadResult;
typedef DWORD (WINAPI *PthreadFunc)(void *);
static int pthread_create(PthreadHandler *out, void *unused, PthreadFunc func, void *arg) {
(void) unused;
PthreadHandler handle = CreateThread(NULL, 0, func, arg, 0, NULL);
if (handle == NULL) {
return EAGAIN;
}
*out = handle;
return 0;
}
static int pthread_join(PthreadHandler thread, void *unused) {
(void) unused;
DWORD ret = WaitForSingleObject(thread, INFINITE);
if (ret == WAIT_FAILED) {
return -1;
}
CloseHandle(thread);
return 0;
}
#else
#include <pthread.h>
typedef pthread_t PthreadHandler;
typedef void* PthreadResult;
typedef void* (*PthreadFunc)(void *);
#endif
#endif

View File

@@ -0,0 +1,37 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const uint chunkIndex = gl_WorkGroupID.x;
const BatchInfo info = infos[b];
const uint offset = chunkIndex * CHUNK_SIZE;
const uint xOffset = info.inputOffset + offset;
const uint yOffset = info.outputOffset + offset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
y[yOffset + i] = x[xOffset + i];
}
}

View File

@@ -0,0 +1,56 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_8bit_storage : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#define Q80_BLOCK_SIZE 32
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset; // number of Q80 blocks
uint outputSizeX; // number of Q80 blocks
};
struct BlockQ80 {
float16_t d;
int8_t qs[Q80_BLOCK_SIZE];
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { BlockQ80 y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const uint d = gl_WorkGroupID.x;
const BatchInfo info = infos[b];
const uint xiOffset = info.inputOffset + d * Q80_BLOCK_SIZE;
const uint yiOffset = info.outputOffset + d;
float amax = 0.0;
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
amax = max(amax, abs(x[xiOffset + j]));
}
const float dd = amax / 127.0f;
const float id = dd != 0.0f ? 1.0f / dd : 0.0f;
y[yiOffset].d = float16_t(dd);
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
const float v = x[xiOffset + j];
y[yiOffset].qs[j] = int8_t(clamp(round(v * id), -127.0f, 127.0f));
}
}

View File

@@ -0,0 +1,38 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
shared uint sharedPosition;
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint chunkIndex = gl_WorkGroupID.x;
const uint position = uint(x[batchIndex]);
const BatchInfo info = infos[batchIndex];
const uint offset = chunkIndex * CHUNK_SIZE;
const uint yOffset = info.outputOffset + offset;
const uint wOffset = position * info.outputSizeX + offset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
y[yOffset + i] = weight[wOffset + i];
}
}

View File

@@ -0,0 +1,55 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define N_THREADS 256
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly uniform opConfigBuffer {
float epsilon;
uint nColumns;
};
shared float sums[N_THREADS];
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint colIndex = gl_WorkGroupID.x;
const BatchInfo info = infos[batchIndex];
const uint dim = info.inputSizeX / nColumns;
const uint offset = info.inputOffset + dim * colIndex;
float sum = 0.0f;
for (uint i = threadIndex; i < dim; i += N_THREADS) {
float v = x[offset + i];
sum += v * v;
}
sums[threadIndex] = sum;
barrier();
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
if (threadIndex < i)
sums[threadIndex] += sums[threadIndex + i];
barrier();
}
if (threadIndex == 0) {
y[batchIndex * nColumns + colIndex] = inversesqrt((sums[0] / float(dim)) + epsilon);
}
}

View File

@@ -0,0 +1,60 @@
#version 450
#define N_THREADS 128
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
layout(binding = 4) readonly uniform opConfigBuffer {
uint nExperts;
uint nActiveExperts;
uint activeExpertIndexesBufferIndex;
};
layout(binding = 5) readonly buffer activeExpertIndexesBuffer { float activeExpertIndexes[]; };
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint nWorkGroups = gl_NumWorkGroups.x;
const uint workGroupIndex = gl_WorkGroupID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint zIndex = gl_WorkGroupID.z;
BatchInfo info = infos[zIndex * N_BATCHES + batchIndex];
const uint slice = info.outputSizeX / nWorkGroups;
const uint rest = info.outputSizeX % nWorkGroups;
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
const uint offset = workGroupIndex * slice + min(workGroupIndex, rest);
const uint expertIndex = nExperts == 0
? 0
: uint(activeExpertIndexes[batchIndex * nActiveExperts + zIndex]);
const uint expertOffset = expertIndex * info.inputSizeX * info.outputSizeX;
const uint inputSizeX = info.inputSizeX;
const uint xOffset = info.inputOffset;
const uint yOffset = info.outputOffset;
for (uint i = threadIndex; i < dim; i += N_THREADS) {
const uint d = offset + i;
const uint wOffset = expertOffset + d * inputSizeX;
float sum = 0.0f;
for (uint j = 0; j < inputSizeX; j++) {
sum += x[xOffset + j] * weight[wOffset + j];
}
y[yOffset + d] = sum;
}
}

View File

@@ -0,0 +1,125 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_8bit_storage : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#define TILE_SIZE_X 2
#define TILE_SIZE_D 8
#define Q80_Q40_BLOCK_SIZE 32
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
layout(constant_id = 2) const uint N_THREADS = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
struct BlockQ80 {
float16_t d;
int8_t qs[Q80_Q40_BLOCK_SIZE];
};
struct BlockQ40 {
float16_t d;
uint8_t qs[Q80_Q40_BLOCK_SIZE / 2];
};
layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
layout(binding = 4) readonly uniform opConfigBuffer {
uint nExperts;
uint nActiveExperts;
uint activeExpertIndexesBufferIndex;
};
layout(binding = 5) readonly buffer activeExpertIndexesBuffer { float activeExpertIndexes[]; };
shared float sums[N_THREADS * TILE_SIZE_D];
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint workGroupIndex = gl_WorkGroupID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const BatchInfo info = infos[b];
const uint expertIndex = nExperts == 0
? 0
: uint(activeExpertIndexes[batchIndex * nActiveExperts + zIndex]);
const uint expertOffset = expertIndex * info.inputSizeX * info.outputSizeX;
const uint inputOffset = info.inputOffset;
const uint inputSizeX = info.inputSizeX;
const uint d = TILE_SIZE_D * workGroupIndex;
vec4 xTemp[Q80_Q40_BLOCK_SIZE / 4];
for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
sums[threadIndex * TILE_SIZE_D + dt] = 0.0f;
}
[[unroll]] for (uint it = 0; it < TILE_SIZE_X; it++) {
const uint xi = inputOffset + threadIndex + it * N_THREADS;
const float xScale = float(x[xi].d);
[[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
xTemp[j] = vec4(
x[xi].qs[j * 2],
x[xi].qs[j * 2 + Q80_Q40_BLOCK_SIZE / 2],
x[xi].qs[j * 2 + 1],
x[xi].qs[j * 2 + 1 + Q80_Q40_BLOCK_SIZE / 2]
);
}
[[unroll]] for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
const uint wi = expertOffset + (d + dt) * inputSizeX + threadIndex + it * N_THREADS;
const BlockQ40 wBlock = weight[wi];
float s = 0.0f;
[[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
uint w0 = wBlock.qs[j * 2];
uint w1 = wBlock.qs[j * 2 + 1];
s += dot(xTemp[j], vec4(
int(w0 & 0xFu) - 8,
int(w0 >> 4) - 8,
int(w1 & 0xFu) - 8,
int(w1 >> 4) - 8
));
}
sums[threadIndex * TILE_SIZE_D + dt] += s * xScale * wBlock.d;
}
}
barrier();
const uint outputOffset = infos[b].outputOffset; // Hoisting fix for Raspberry PI
uint i = N_THREADS;
while (i % 2 == 0) {
i >>= 1;
for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
if (threadIndex < i) {
sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt];
}
}
barrier();
}
for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) {
float s = 0.0;
for (uint j = 1; j <= i; j++) {
s += sums[(j - 1) * TILE_SIZE_D + dt];
}
y[outputOffset + d + dt] = float(s);
}
}

View File

@@ -0,0 +1,46 @@
#version 450
#define N_THREADS 256
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint nWorkGroups = gl_NumWorkGroups.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint workGroupIndex = gl_WorkGroupID.x;
const BatchInfo info = infos[batchIndex];
const uint slice = info.outputSizeX / nWorkGroups;
const uint rest = info.outputSizeX % nWorkGroups;
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
const uint outputSizeX = info.outputSizeX;
const uint parts = info.inputSizeX / info.outputSizeX;
const uint xOffset = info.inputOffset + offset;
const uint yOffset = info.outputOffset + offset;
for (uint i = threadIndex; i < dim; i += N_THREADS) {
float sum = 0.0;
const uint iOffset = xOffset + i;
const uint oOffset = yOffset + i;
for (uint n = 0; n < parts; n++) {
sum += x[n * outputSizeX + iOffset];
}
y[oOffset] += sum;
}
}

View File

@@ -0,0 +1,57 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_8bit_storage : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#define Q80_BLOCK_SIZE 32
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset; // number of Q80 blocks
uint inputSizeX; // number of Q80 blocks
uint outputOffset;
uint outputSizeX;
};
struct BlockQ80 {
float16_t d;
int8_t qs[Q80_BLOCK_SIZE];
};
layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
layout(binding = 1) buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint i = gl_WorkGroupID.x;
const BatchInfo info = infos[batchIndex];
const uint nBlocks = info.outputSizeX / Q80_BLOCK_SIZE; // 128
const uint nSlices = info.inputSizeX / nBlocks;
float16_t sums[Q80_BLOCK_SIZE];
const uint xiOffset = info.inputOffset + i;
const uint yiOffset = info.outputOffset + i * Q80_BLOCK_SIZE;
[[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
sums[k] = float16_t(0.0f);
}
for (uint n = 0; n < nSlices; n++) {
const BlockQ80 b = x[xiOffset + n * nBlocks];
const float16_t d = b.d;
[[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
sums[k] += float16_t(b.qs[k]) * d;
}
}
[[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
y[yiOffset + k] += float(sums[k]);
}
}

View File

@@ -0,0 +1,41 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
void main() {
const uint workGroupIndex = gl_WorkGroupID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint chunkOffset = workGroupIndex * CHUNK_SIZE;
const BatchInfo outputInfo = infos[batchIndex];
const uint yOffset = outputInfo.outputOffset + chunkOffset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
float sum = 0.0f;
for (uint z = 0; z < N_Z; z++) {
const BatchInfo inputInfo = infos[batchIndex + z * N_BATCHES];
const uint xOffset = inputInfo.inputOffset + chunkOffset;
sum += x[xOffset + i];
}
y[yOffset + i] = sum;
}
}

View File

@@ -0,0 +1,69 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N = 16;
layout(constant_id = 2) const uint K = 2;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[K * N_BATCHES]; };
layout(binding = 3) readonly uniform opConfigBuffer {
uint k;
uint normTopk;
uint indexesBufferIndex;
};
layout(binding = 4) buffer indexesBuffer { float indexes[]; };
shared float topVals[K];
shared uint topIdx[K];
void main() {
// TODO: this impl is not optimal
const uint batchIndex = gl_WorkGroupID.y;
BatchInfo info = infos[batchIndex];
for (uint i = 0; i < K; i++) {
topVals[i] = -1e10f;
topIdx[i] = 0;
}
for (uint i = 0; i < info.inputSizeX; i++) {
float v = x[info.inputOffset + i];
for (uint k = 0; k < K; k++) {
if (v > topVals[k]) {
for (uint s = K - 1; s > k; s--) {
topVals[s] = topVals[s - 1];
topIdx[s] = topIdx[s - 1];
}
topVals[k] = v;
topIdx[k] = i;
break;
}
}
}
float sum = 1.0f;
if (normTopk == 1) {
sum = 0.0f;
for (uint k = 0; k < K; k++) {
sum += topVals[k];
}
}
for (uint k = 0; k < K; k++) {
indexes[batchIndex * K + k] = float(topIdx[k]);
y[infos[k * N_BATCHES + batchIndex].outputOffset] = topVals[k] / sum;
}
}

View File

@@ -0,0 +1,41 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
layout(binding = 3) readonly uniform configBuffer {
uint multiplierBufferIndex;
};
layout(binding = 4) readonly buffer multiplierBuffer { float m[]; };
void main() {
const uint chunkIndex = gl_WorkGroupID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const BatchInfo info = infos[b];
const uint chunkOffset = chunkIndex * CHUNK_SIZE;
const uint xyOffset = info.inputOffset + chunkOffset;
const uint mOffset = info.inputSizeX * b + chunkOffset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
y[xyOffset + i] = x[xyOffset + i] * m[mOffset + i];
}
}

View File

@@ -0,0 +1,127 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define N_THREADS 256
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly uniform configBuffer {
uint nHeads;
uint nHeads0;
uint nKvHeads;
uint headDim;
uint seqLen;
uint qSliceD0;
uint kvDim0;
uint positionPipeIndex;
uint queryBufferIndex;
uint keyCacheBufferIndex;
uint valueCacheBufferIndex;
uint attBufferIndex;
};
layout(binding = 4) readonly buffer positionsBuffer { float positions[]; };
layout(binding = 5) readonly buffer queryBuffer { float query[]; };
layout(binding = 6) readonly buffer keyCacheBuffer { float keyCache[]; };
layout(binding = 7) readonly buffer valueCacheBuffer { float valueCache[]; };
layout(binding = 8) buffer attBufferBuffer { float att[]; };
shared uint sharedPosition;
shared float sharedMaxScore;
shared float temp[N_THREADS];
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint h = gl_WorkGroupID.x;
if (threadIndex == 0) {
sharedPosition = uint(positions[batchIndex]);
}
barrier();
const uint kvMul = nHeads / nKvHeads;
const uint headIndex = h / kvMul;
const float invHeadDimRoot = 1.0f / sqrt(float(headDim));
const BatchInfo info = infos[batchIndex];
const uint position = sharedPosition;
const uint attOffset = batchIndex * nHeads0 * seqLen + h * seqLen;
const uint qOffset = batchIndex * qSliceD0 + h * headDim;
const uint kvOffset = headIndex * headDim;
const uint yOffset = info.outputOffset + h * headDim;
float ms = -1e10f;
for (uint p = threadIndex; p <= position; p += N_THREADS) {
const uint kOffset = kvOffset + p * kvDim0;
float score = 0.0f;
for (uint i = 0; i < headDim; i++) {
score += query[qOffset + i] * keyCache[kOffset + i];
}
score *= invHeadDimRoot;
ms = max(ms, score);
att[attOffset + p] = score;
}
temp[threadIndex] = ms;
barrier();
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
if (threadIndex < i)
temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
barrier();
}
if (threadIndex == 0) {
sharedMaxScore = temp[0];
}
barrier();
const float maxScore = sharedMaxScore;
float s = 0.0f;
for (uint p = threadIndex; p <= position; p += N_THREADS) {
float v = exp(att[attOffset + p] - maxScore);
att[attOffset + p] = v;
s += v;
}
temp[threadIndex] = s;
barrier();
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
if (threadIndex < i)
temp[threadIndex] += temp[threadIndex + i];
barrier();
}
const float yScale = 1.0f / temp[0];
for (uint i = threadIndex; i < headDim; i += N_THREADS) {
float sum = 0.0f;
const uint vOffset = kvOffset + i;
for (uint p = 0; p <= position; p += 1) {
const float a = att[attOffset + p];
const float v = valueCache[vOffset + p * kvDim0];
sum += v * a;
}
y[yOffset + i] = sum * yScale;
}
}

View File

@@ -0,0 +1,60 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_8bit_storage : enable
#extension GL_EXT_shader_16bit_storage : enable
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#define Q80_BLOCK_SIZE 32
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
struct BlockQ80 {
float16_t d;
int8_t qs[Q80_BLOCK_SIZE];
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { BlockQ80 y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint d = gl_WorkGroupID.x;
const BatchInfo info = infos[batchIndex];
const uint xiOffset = info.inputOffset + d * Q80_BLOCK_SIZE;
float amax = 0.0;
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
amax = max(amax, abs(x[xiOffset + j]));
}
const float dd = amax / 127.0f;
const float id = dd != 0.0f ? 1.0f / dd : 0.0f;
const float16_t dd16 = float16_t(dd);
for (uint z = 0; z < N_Z; z++) {
const uint yiOffset = infos[batchIndex + z * N_BATCHES].outputOffset + d;
y[yiOffset].d = dd16;
}
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
const float v = x[xiOffset + j];
const int8_t v8 = int8_t(clamp(round(v * id), -127.0f, 127.0f));
for (uint z = 0; z < N_Z; z++) {
const uint yiOffset = infos[batchIndex + z * N_BATCHES].outputOffset + d;
y[yiOffset].qs[j] = v8;
}
}
}

View File

@@ -0,0 +1,44 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
layout(binding = 4) readonly uniform configBuffer {
uint invRmsBufferIndex; // not used
uint nColumns;
};
layout(binding = 5) readonly buffer invRmsBuffer { float invRms[]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint chunkIndex = gl_WorkGroupID.x;
const BatchInfo info = infos[batchIndex];
const uint dim = info.inputSizeX / nColumns;
const uint offset = chunkIndex * CHUNK_SIZE;
const uint colIndex = offset / dim;
const float s = invRms[batchIndex * nColumns + colIndex];
const uint xOffset = info.inputOffset + offset;
const uint yOffset = info.outputOffset + offset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
y[yOffset + i] = (x[xOffset + i] * s) * weight[(offset + i) % dim];
}
}

View File

@@ -0,0 +1,115 @@
#version 450
#define N_THREADS 256
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
struct RopeSlice {
uint qDim0;
uint qDimStart;
uint qDimEnd;
uint qShift;
uint kvDim;
uint kvDim0;
uint kvDimStart;
uint sliceDim;
uint seqLen;
uint headDim;
uint nKvHeads;
float ropeTheta;
// NnSize2D cacheSize;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly uniform configBuffer {
uint ropeType;
uint isQ;
uint positionPipeIndex;
uint ropeCacheBufferIndex;
float ropeScalingFactor;
float ropeScalingLowFreqFactor;
float ropeScalingHighFreqFactor;
uint ropeScalingOrigMaxSeqLen;
RopeSlice slice;
};
layout(binding = 4) readonly buffer positionsBuffer { float positions[]; };
layout(binding = 5) readonly buffer ropeCacheBuffer { float ropeCache[]; };
shared uint sharedPosition;
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint batchIndex = gl_WorkGroupID.y;
if (threadIndex == 0) {
sharedPosition = uint(positions[batchIndex]);
}
barrier();
const uint position = sharedPosition;
const BatchInfo info = infos[batchIndex];
const uint xOffset = info.inputOffset;
const uint yOffset = info.outputOffset;
const uint dim0 = isQ == 1 ? slice.qDim0 : slice.kvDim0;
if (ropeType == 0 || ropeType == 2 /* Llama */) {
uint posOffset = position * slice.sliceDim;
if (isQ == 1) {
posOffset += slice.qShift;
}
const uint dim0Half = dim0 / 2;
for (uint i = threadIndex; i < dim0Half; i += N_THREADS) {
const uint j = i * 2;
const uint c = posOffset + j;
float fcr = ropeCache[c];
float fci = ropeCache[c + 1];
float v0 = x[xOffset + j];
float v1 = x[xOffset + j + 1];
const float x0 = fma(-v1, fci, v0 * fcr);
const float x1 = fma( v0, fci, v1 * fcr);
y[yOffset + j] = x0;
y[yOffset + j + 1] = x1;
}
} else if (ropeType == 1 /* Falcon */) {
const uint posOffset = position * slice.headDim;
const uint headDim = slice.headDim;
const uint headDimHalf = headDim / 2;
const uint nHeads0 = dim0 / headDim;
for (uint h = 0; h < nHeads0; h++) {
const uint o = h * headDim;
for (uint i = threadIndex; i < headDimHalf; i += N_THREADS) {
const uint c = posOffset + i;
float fcr = ropeCache[c];
float fci = ropeCache[c + headDimHalf];
float v0 = x[xOffset + o + i];
float v1 = x[xOffset + o + i + headDimHalf];
float x0 = v0 * fcr - v1 * fci;
float x1 = v0 * fci + v1 * fcr;
y[yOffset + o + i] = x0;
y[yOffset + o + i + headDimHalf] = x1;
}
}
}
}

View File

@@ -0,0 +1,43 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
layout(binding = 3) readonly buffer configBuffer {
uint scaleBufferIndex;
};
layout(binding = 4) readonly buffer scaleBuffer { float scale[]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint workGroupIndex = gl_WorkGroupID.x;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const float s = scale[b];
const BatchInfo info = infos[b];
const uint offset = workGroupIndex * CHUNK_SIZE;
const uint xOffset = info.inputOffset + offset;
const uint yOffset = info.outputOffset + offset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
y[yOffset + i] = x[xOffset + i] * s;
}
}

View File

@@ -0,0 +1,40 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly uniform configBuffer {
uint indexPipeIndex;
};
layout(binding = 4) readonly buffer indexBuffer { float indexes[]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint chunkIndex = gl_WorkGroupID.x;
const uint index = uint(indexes[batchIndex]);
const BatchInfo info = infos[batchIndex];
const uint offset = chunkIndex * CHUNK_SIZE;
const uint xOffset = info.inputOffset + offset;;
const uint yOffset = index * info.inputSizeX + offset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
y[yOffset + i] = x[xOffset + i];
}
}

View File

@@ -0,0 +1,38 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define CHUNK_SIZE 4
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
void main() {
const uint batchIndex = gl_WorkGroupID.y;
const uint chunkIndex = gl_WorkGroupID.x;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const BatchInfo info = infos[b];
const uint offset = chunkIndex * CHUNK_SIZE;
const uint xOffset = info.inputOffset + offset;
const uint yOffset = info.outputOffset + offset;
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
float v = x[xOffset + i];
y[yOffset + i] = v / (1.0f + exp(-v));
}
}

View File

@@ -0,0 +1,56 @@
#version 450
#extension GL_EXT_control_flow_attributes : enable
#define N_THREADS 256
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;
struct BatchInfo {
uint inputOffset;
uint inputSizeX;
uint outputOffset;
uint outputSizeX;
};
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
shared float temp[N_THREADS];
void main() {
const uint threadIndex = gl_LocalInvocationID.x;
const uint batchIndex = gl_WorkGroupID.y;
const uint zIndex = gl_WorkGroupID.z;
const uint b = zIndex * N_BATCHES + batchIndex;
const BatchInfo info = infos[b];
const uint size = info.inputSizeX;
const uint xOffset = info.inputOffset;
const uint yOffset = info.outputOffset;
float m = -1e10f;
for (uint i = threadIndex; i < size; i += N_THREADS) {
m = max(m, x[xOffset + i]);
}
temp[threadIndex] = m;
barrier();
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
if (threadIndex < i)
temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
barrier();
}
barrier();
const float maxVal = temp[0];
for (uint i = threadIndex; i < size; i += N_THREADS) {
y[yOffset + i] = exp(x[xOffset + i] - maxVal);
}
}

319
src/tokenizer-test.cpp Normal file
View File

@@ -0,0 +1,319 @@
#include <cassert>
#include <cstring>
#include "tokenizer.hpp"
#define DEV_TESTS false
#define ASSERT_EQ(a, b) \
if (a != b) { \
printf("Assertion failed: %d != %d (%s:%d)\n", a, b, __FILE__, __LINE__); \
exit(-1); \
}
#define TEST_EOS_ID 10000
void printOk(const char *name) {
printf("✅ %24s passed\n", name);
}
void compare(const char *name, const int *a, const int *b, const unsigned int aN, const int bN) {
bool passed = true;
if (aN != bN) {
passed = false;
} else {
for (unsigned int i = 0; i < bN; i++) {
if (a[i] != b[i]) {
passed = false;
break;
}
}
}
if (!passed) {
printf("❌ %24s failed\na: ", name);
for (unsigned int j = 0; j < aN; j++)
printf("%5d ", a[j]);
printf("\nb: ");
for (unsigned int j = 0; j < bN; j++)
printf("%5d ", b[j]);
printf("\n");
exit(1);
}
printOk(name);
}
void dev_testEncode(Tokenizer *tokenizer) {
int tokens[1024];
int nTokens;
{
const char *text = "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
const int expectedTokens[] = {128000, 128006, 882, 128007, 271, 15339, 128009, 128006, 78191, 128007, 271};
tokenizer->encode((char *)text, tokens, &nTokens, true, true);
compare("case0", expectedTokens, tokens, 11, nTokens);
}
{
const char *text = "!!&&@(*x)^^!";
const int expectedTokens[] = {128000, 3001, 7827, 31, 4163, 87, 8, 22634, 0};
tokenizer->encode((char *)text, tokens, &nTokens, true, true);
compare("case1", expectedTokens, tokens, 9, nTokens);
}
{
const char *text = "😃!😇x";
const int expectedTokens[] = {128000, 76460, 225, 0, 76460, 229, 87};
tokenizer->encode((char *)text, tokens, &nTokens, true, true);
compare("case2", expectedTokens, tokens, 7, nTokens);
}
}
void dev_testDecoderEmojiStreamRecover(Tokenizer *tokenizer) {
char *x0 = tokenizer->decode(128000);
assert(x0 == nullptr);
char *x1 = tokenizer->decode(76460);
assert(x1 == nullptr);
char *x2 = tokenizer->decode(76460);
assert(x2 == nullptr);
char *x3 = tokenizer->decode(225);
assert(strcmp(x3, "<EFBFBD>😃") == 0);
printOk("testDecoderEmojiStreamRecover");
}
void dev_testDecoderEmoji(Tokenizer *tokenizer) {
char *x0 = tokenizer->decode(128000);
assert(x0 == nullptr);
char *x1 = tokenizer->decode(76460);
assert(x1 == nullptr);
char *x2 = tokenizer->decode(225);
assert(strcmp(x2, "😃") == 0);
char *x3 = tokenizer->decode(0);
assert(strcmp(x3, "!") == 0);
char *x4 = tokenizer->decode(56);
assert(strcmp(x4, "Y") == 0);
printOk("testDecoderEmoji");
}
void dev_testDecoderEmojiWithEos(Tokenizer *tokenizer) {
char *x0 = tokenizer->decode(128000);
assert(x0 == nullptr);
char *x1 = tokenizer->decode(76460);
assert(x1 == nullptr);
char *x2 = tokenizer->decode(225);
assert(strcmp(x2, "😃") == 0);
char *x3 = tokenizer->decode(128001);
assert(x3 == nullptr); // piece should not contain <|end_of_text|>
printOk("decoderEmojiWithEos");
}
void testChatTemplateDetection() {
ChatTemplateGenerator t0(TEMPLATE_UNKNOWN, "{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "<eos>");
assert(t0.type == TEMPLATE_LLAMA3);
printOk("chatTemplateDetection");
}
void testEosDetectorWithPadding() {
const int tokens[2] = {TEST_EOS_ID, TEST_EOS_ID + 1};
const char *pieces[2] = { "<eos>", "<stop>" };
EosDetector detector(2, tokens, pieces, 1, 1);
// "<eos>"
{
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "s>"), EOS);
assert(detector.getDelta() == nullptr);
}
// "<stop> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "stop"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "> "), EOS);
assert(detector.getDelta() == nullptr);
}
// " "
detector.reset();
{
ASSERT_EQ(detector.append(1, " "), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, " ") == 0);
}
// "!<eos> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "!<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "eos"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "> "), EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "!") == 0);
}
// "!<eos> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "<eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "s>XY"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "<eos>XY") == 0);
}
// "<eo" + EOS
detector.reset();
{
ASSERT_EQ(detector.append(1, "<eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "<eo") == 0);
}
// EOS
detector.reset();
{
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
assert(detector.getDelta() == nullptr);
}
// after reset it's expected to return nullptr delta if to the append() passed nullptr piece
detector.reset();
{
ASSERT_EQ(detector.append(1, "x"), NOT_EOS);
char *delta0 = detector.getDelta();
assert(std::strcmp(delta0, "x") == 0);
detector.reset();
ASSERT_EQ(detector.append(2, nullptr), NOT_EOS);
char *delta1 = detector.getDelta();
assert(delta1 == nullptr);
}
printOk("eosDetectorWithPadding");
}
void testEosDetectorWithLongPadding() {
const int tokens[1] = {TEST_EOS_ID};
const char *pieces[1] = { "|end|" };
EosDetector detector(1, tokens, pieces, 5, 5);
// "lipsum"
{
ASSERT_EQ(detector.append(1, "lipsum"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "lipsum") == 0);
}
// "lorem"
detector.reset();
{
ASSERT_EQ(detector.append(1, "lorem"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "lorem") == 0);
}
// "lorem|enQ"
detector.reset();
{
ASSERT_EQ(detector.append(1, "lorem|"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "enQ"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "lorem|enQ") == 0);
}
printOk("eosDetectorWithLongPadding");
}
void testEosDetectorWithoutPadding() {
const int tokens[1] = {TEST_EOS_ID};
const char *pieces[1] = { "<eos>" };
EosDetector detector(1, tokens, pieces, 0, 0);
// "<eos>"
{
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "s>"), EOS);
assert(detector.getDelta() == nullptr);
}
// " <"
detector.reset();
{
ASSERT_EQ(detector.append(1, " <"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, " <") == 0);
}
// "<eos> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "<eos"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "> "), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "<eos> ") == 0);
}
// EOS
detector.reset();
{
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
assert(detector.getDelta() == nullptr);
}
// emoji
detector.reset();
{
ASSERT_EQ(detector.append(TEST_EOS_ID, "😃"), EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "😃") == 0);
}
printOk("eosDetectorWithLongPadding");
}
int main() {
#if DEV_TESTS
Tokenizer tokenizer("models/llama3_2_1b_instruct_q40/dllama_tokenizer_llama3_2_1b_instruct_q40.t");
dev_testEncode(&tokenizer);
dev_testDecoderEmoji(&tokenizer);
dev_testDecoderEmojiWithEos(&tokenizer);
dev_testDecoderEmojiStreamRecover(&tokenizer);
#endif
testChatTemplateDetection();
testEosDetectorWithPadding();
testEosDetectorWithLongPadding();
testEosDetectorWithoutPadding();
return 0;
}

722
src/tokenizer.cpp Normal file
View File

@@ -0,0 +1,722 @@
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <cstdint>
#include <fcntl.h>
#include <ctype.h>
#include <ctime>
#include <cassert>
#include <stdexcept>
#include <sstream>
#include <vector>
#include "nn/nn-core.hpp"
#include "nn/nn-cpu-ops.hpp"
#include "tokenizer.hpp"
#if defined(__ARM_NEON)
#include <arm_neon.h>
#endif
#define DEBUG_TOKENIZER_ENCODER false
#define DEBUG_TOKENIZER_BENCHMARK false
#define DEBUG_TEMPLATE_GENERATOR false
#define DEBUG_SAMPLER_BENCHMARK false
unsigned int randomU32(unsigned long long *state) {
// xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
*state ^= *state >> 12;
*state ^= *state << 25;
*state ^= *state >> 27;
return (*state * 0x2545F4914F6CDD1Dull) >> 32;
}
float randomF32(unsigned long long *state) {
// random float32 in <0,1)
return (randomU32(state) >> 8) / 16777216.0f;
}
int compareTokens(const void *a, const void *b) {
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
}
Tokenizer::Tokenizer(const char* tokenizerPath)
: eosTokenIds() {
bosId = -1;
chatTemplate = nullptr;
maxTokenLength = 0;
// read in the file
FILE *file = fopen(tokenizerPath, "rb");
if (!file)
throw std::runtime_error("Failed to open tokenizer file");
int magic;
if (fread(&magic, sizeof(int), 1, file) != 1)
throw std::runtime_error("Cannot read tokenizer magic number");
if (magic == 0x567123) {
TokenizerOldHeader header;
if (fread(&header, sizeof(TokenizerOldHeader), 1, file) != 1)
throw std::runtime_error("Cannot read tokenizer header");
maxTokenLength = header.maxTokenLength;
vocabSize = header.vocabSize;
bosId = header.bosId;
eosTokenIds.push_back(header.eosId);
} else if (magic == 0x567124) {
int headerSize;
if (fread(&headerSize, sizeof(int), 1, file) != 1)
throw std::runtime_error("Cannot read tokenizer header size");
int nKv = (headerSize - 2 * sizeof(int)) / sizeof(int);
std::vector<int> buffer(nKv);
if (fread(&buffer[0], nKv * sizeof(int), 1, file) != 1) {
throw std::runtime_error("Cannot read header values");
}
int version = -1;
int chatTemplateLength = -1;
int nEosTokens = 0;
for (int i = 0; i < nKv; i += 2) {
int key = buffer[i];
int value = buffer[i + 1];
if (key == TOK_VERSION) version = value;
else if (key == TOK_VOCAB_SIZE) vocabSize = value;
else if (key == MAX_TOKEN_LENGTH) maxTokenLength = (unsigned int)value;
else if (key == BOS_ID) bosId = value;
else if (key == EOS_ID) eosTokenIds.push_back(value); // Backward compatibility
else if (key == CHAT_EOS_ID) eosTokenIds.push_back(value); // Backward compatibility
else if (key == CHAT_TEMPLATE) chatTemplateLength = value;
else if (key == CHAT_STOP) fseek(file, value, SEEK_CUR); // Ignored
else if (key == PAD_ID) {} // Ignored
else if (key == N_EOS_TOKENS) nEosTokens = value;
else if (key == ADD_BOS) addBos = value == 1;
else {
throw std::runtime_error("Invalid tokenizer header key:" + std::to_string(key));
}
}
if (version != 1)
throw std::runtime_error("Old tokenizer version, please regenerate your tokenizer");
if (chatTemplateLength > 0) {
chatTemplate = new char[chatTemplateLength + 1];
if (fread(chatTemplate, chatTemplateLength, 1, file) != 1)
throw std::runtime_error("Cannot read chat template from tokenizer file");
chatTemplate[chatTemplateLength] = '\0';
}
if (nEosTokens > 0) {
int eosTokenId;
for (int i = 0; i < nEosTokens; i++) {
if (fread(&eosTokenId, sizeof(int), 1, file) != 1)
throw std::runtime_error("Cannot read eos token id from tokenizer file");
eosTokenIds.push_back(eosTokenId);
}
}
} else {
throw std::runtime_error("Invalid tokenizer file");
}
if (maxTokenLength < 1)
throw std::runtime_error("Invalid tokenizer max token length");
// malloc space to hold the scores and the strings
vocab = new char*[vocabSize];
vocabLength = new unsigned int[vocabSize];
vocabScores = new float[vocabSize];
int length;
for (int i = 0; i < vocabSize; i++) {
if (fread(vocabScores + i, sizeof(float), 1, file) != 1)
throw std::runtime_error("Cannot read size from tokenizer file");
if (fread(&length, sizeof(int), 1, file) != 1)
throw std::runtime_error("Cannot read length from tokenizer file");
vocab[i] = new char[length + 1];
if (fread(vocab[i], length, 1, file) != 1)
throw std::runtime_error("Cannot read word from tokenizer file");
vocab[i][length] = '\0'; // add the string terminating token
vocabLength[i] = length;
}
// TODO: this is very unstable assumption that bosId splits regular and special vocab
regularVocabSize = bosId;
specialVocabSize = vocabSize - regularVocabSize;
regularVocab = new TokenIndex[regularVocabSize];
for (int i = 0; i < regularVocabSize; i++) {
regularVocab[i].str = vocab[i];
regularVocab[i].id = i;
}
qsort(regularVocab, regularVocabSize, sizeof(TokenIndex), compareTokens);
specialVocab = new TokenIndex[specialVocabSize];
for (int i = 0; i < specialVocabSize; i++) {
specialVocab[i].str = vocab[i + regularVocabSize];
specialVocab[i].id = i + regularVocabSize;
}
strBufferSize = maxTokenLength * 2;
if (strBufferSize < (4 * 2)) { // ensure place for 2 utf-8 multi-byte sequence
strBufferSize = 4 * 2;
}
strBufferSize += 1 + 2;
strBuffer = new char[strBufferSize];
utf8Buffer = new char[strBufferSize];
if (bosId >= 0) {
printf("📄 AddBos: %d\n", addBos ? 1 : 0);
printf("📄 BosId: %d (%s)\n", bosId, vocab[bosId]);
}
if (eosTokenIds.size() > 0) {
printf("📄 EosId: ");
for (unsigned int i = 0; i < eosTokenIds.size(); i++) {
printf("%d (%s) ", eosTokenIds[i], vocab[eosTokenIds[i]]);
}
printf("\n");
}
printf("📄 RegularVocabSize: %d\n", regularVocabSize);
printf("📄 SpecialVocabSize: %d\n", specialVocabSize);
fclose(file);
}
Tokenizer::~Tokenizer() {
if (chatTemplate != NULL) delete[] chatTemplate;
for (int i = 0; i < vocabSize; i++)
delete[] vocab[i];
delete[] vocab;
delete[] vocabLength;
delete[] vocabScores;
delete[] regularVocab;
delete[] specialVocab;
delete[] strBuffer;
delete[] utf8Buffer;
}
int Tokenizer::findSpecialTokenStartWith(char *piece) {
for (unsigned int i = 0; i < specialVocabSize; i++) {
unsigned int tokenId = specialVocab[i].id;
unsigned int length = vocabLength[tokenId];
if (std::strncmp(vocab[tokenId], piece, length) == 0)
return tokenId;
}
return -1;
}
int Tokenizer::findRegularToken(char *piece) {
TokenIndex tok = { .str = piece };
TokenIndex *res = (TokenIndex*)bsearch(&tok, regularVocab, regularVocabSize, sizeof(TokenIndex), compareTokens);
return res != NULL ? res->id : -1;
}
bool Tokenizer::isEos(int token) {
for (unsigned int i = 0; i < eosTokenIds.size(); i++) {
if (token == eosTokenIds[i])
return true;
}
return false;
}
void Tokenizer::resetDecoder() {
strBufferPos = 0;
}
char *Tokenizer::detokUtf8() {
char* src = strBuffer;
char* dst = utf8Buffer;
char* checkpoint_src = src;
char* checkpoint = dst;
unsigned expect_continuation = 0;
while (unsigned char c = *src) {
bool need_recovery = false;
if (expect_continuation) {
if ((c & 0xc0) == 0x80) {
*dst++ = *src++;
expect_continuation--;
} else {
need_recovery = true;
}
} else if (c <= 0x7f) {
*dst++ = *src++;
} else if (c >= 0xc0 && c <= 0xdf) {
*dst++ = *src++;
expect_continuation = 1;
} else if (c >= 0xe0 && c <= 0xef) {
*dst++ = *src++;
expect_continuation = 2;
} else if (c >= 0xf0 && c <= 0xf7) {
*dst++ = *src++;
expect_continuation = 3;
} else {
need_recovery = true;
}
if (!need_recovery) {
if (!expect_continuation) {
checkpoint = dst;
checkpoint_src = src;
}
} else {
// perform stream recover
if (expect_continuation) {
expect_continuation = 0;
} else {
++src;
}
dst = checkpoint;
// emit 0xfffd
*dst++ = 0xef;
*dst++ = 0xbf;
*dst++ = 0xbd;
fprintf(stderr, "Tokenizer: decoded invalid utf8 -- attempting stream recover\n");
}
}
if (src > checkpoint_src) {
memmove(strBuffer, checkpoint_src, src - checkpoint_src + 1);
strBufferPos = src - checkpoint_src;
} else {
strBufferPos = 0;
}
*checkpoint = '\0';
if (checkpoint > utf8Buffer) {
return utf8Buffer;
} else {
return nullptr;
}
}
char *Tokenizer::decode(int token) {
if (token == bosId)
return nullptr;
if (isEos(token)) {
if (strBufferPos > 0)
return strBuffer;
return nullptr;
}
char *piece = vocab[token];
int pieceLen = vocabLength[token];
assert(strBufferPos + pieceLen + 1 < strBufferSize);
std::memcpy(&strBuffer[strBufferPos], piece, pieceLen * sizeof(char));
strBufferPos += pieceLen;
strBuffer[strBufferPos] = '\0';
return detokUtf8();
}
void Tokenizer::encode(char *text, int *tokens, int *nTokens, bool isStart, bool addSpecialTokens) {
#if DEBUG_TOKENIZER_BENCHMARK
Timer startTime;
#endif
if (text == nullptr)
throw std::runtime_error("Input text is null");
size_t strLen = 0;
*nTokens = 0;
if (isStart && addBos && bosId >= 0)
tokens[(*nTokens)++] = bosId;
for (char *c = text; *c != '\0'; c++) {
if (addSpecialTokens) {
int specialTokenId = findSpecialTokenStartWith(c);
if (specialTokenId >= 0) {
tokens[(*nTokens)++] = specialTokenId;
c += vocabLength[specialTokenId] - 1;
continue;
}
}
strBuffer[strLen] = *c;
strLen++;
assert(strLen < strBufferSize);
strBuffer[strLen] = '\0';
int id = findRegularToken(strBuffer);
if (id != -1) {
tokens[(*nTokens)++] = id;
strLen = 0;
}
}
assert(strLen == 0);
// merge the best consecutive pair each iteration, according the scores in vocab_scores
while (1) {
float best_score = -1e10;
int best_id = -1;
int best_idx = -1;
for (int i=0; i < (*nTokens-1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
snprintf(strBuffer, strBufferSize, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
int id = findRegularToken(strBuffer);
if (id != -1 && vocabScores[id] > best_score) {
// this merge pair exists in vocab! record its score and position
best_score = vocabScores[id];
best_id = id;
best_idx = i;
}
}
if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
}
// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
for (int i = best_idx+1; i < (*nTokens-1); i++) {
tokens[i] = tokens[i+1];
}
(*nTokens)--; // token length decreased
}
#if DEBUG_TOKENIZER_BENCHMARK
NnUint duration = startTime.elapsedMicroseconds();
printf("🕒 [%22s] %u μs\n", "ENCODER", duration);
#endif
#if DEBUG_TOKENIZER_ENCODER
printf("\033[1;33m[");
for (unsigned int i = 0; i < *nTokens; i++)
printf("%d,", tokens[i]);
printf("]\033[0m");
#endif
}
int sample_argmax(float* probabilities, int n) {
// return the index that has the highest probability
int max_i = 0;
float max_p = probabilities[0];
for (int i = 1; i < n; i++) {
if (probabilities[i] > max_p) {
max_i = i;
max_p = probabilities[i];
}
}
return max_i;
}
int sample_mult(float* probabilities, int n, float coin) {
// sample index from probabilities (they must sum to 1!)
// coin is a random number in [0, 1), usually from random_f32()
float cdf = 0.0f;
for (int i = 0; i < n; i++) {
cdf += probabilities[i];
if (coin < cdf) {
return i;
}
}
return n - 1; // in case of rounding errors
}
int compare(const void* a, const void* b) {
ProbIndex* a_ = (ProbIndex*) a;
ProbIndex* b_ = (ProbIndex*) b;
if (a_->prob > b_->prob) return -1;
if (a_->prob < b_->prob) return 1;
return 0;
}
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
// top-p sampling (or "nucleus sampling") samples from the smallest set of
// tokens that exceed probability topp. This way we never sample tokens that
// have very low probabilities and are less likely to go "off the rails".
// coin is a random number in [0, 1), usually from random_f32()
int n0 = 0;
// quicksort indices in descending order of probabilities
// values smaller than (1 - topp) / (n - 1) cannot be part of the result
// so for efficiency we crop these out as candidates before sorting
const float cutoff = (1.0f - topp) / (n - 1);
for (int i = 0; i < n; i++) {
if (probabilities[i] >= cutoff) {
probindex[n0].index = i;
probindex[n0].prob = probabilities[i];
n0++;
}
}
qsort(probindex, n0, sizeof(ProbIndex), compare);
// truncate the list where cumulative probability exceeds topp
float cumulative_prob = 0.0f;
int last_idx = n0 - 1; // in case of rounding errors consider all elements
for (int i = 0; i < n0; i++) {
cumulative_prob += probindex[i].prob;
if (cumulative_prob > topp) {
last_idx = i;
break; // we've exceeded topp by including last_idx
}
}
// sample from the truncated list
float r = coin * cumulative_prob;
float cdf = 0.0f;
for (int i = 0; i <= last_idx; i++) {
cdf += probindex[i].prob;
if (r < cdf) {
return probindex[i].index;
}
}
return probindex[last_idx].index; // in case of rounding errors
}
Sampler::Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed) {
this->vocab_size = vocab_size;
this->temperature = temperature;
this->topp = topp;
this->rngState = rngSeed;
// buffer only used with nucleus sampling; may not need but it's ~small
probindex = new ProbIndex[vocab_size];
}
Sampler::~Sampler() {
delete[] probindex;
}
int Sampler::sample(float* logits) {
#if DEBUG_SAMPLER_BENCHMARK
Timer startTime;
#endif
// sample the token given the logits and some hyperparameters
int next;
if (temperature == 0.0f) {
// greedy argmax sampling: take the token with the highest probability
next = sample_argmax(logits, vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q < vocab_size; q++) { logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax_F32(logits, vocab_size);
// flip a (float) coin (this is our source of entropy for sampling)
float coin = randomF32(&rngState);
// we sample from this distribution to get the next token
if (topp <= 0 || topp >= 1) {
// simply sample from the predicted probability distribution
next = sample_mult(logits, vocab_size, coin);
} else {
// top-p (nucleus) sampling, clamping the least likely tokens to zero
next = sample_topp(logits, vocab_size, topp, probindex, coin);
}
}
#if DEBUG_SAMPLER_BENCHMARK
NnUint duration = startTime.elapsedMicroseconds();
printf("🕒 [%22s] %u μs\n", "SAMPLER", duration);
#endif
return next;
}
void Sampler::setTemp(float temp) {
this->temperature = temp;
}
void Sampler::setSeed(unsigned long long seed) {
this->rngState = seed;
}
TokenizerChatStops::TokenizerChatStops(Tokenizer* tokenizer) {
nStops = tokenizer->eosTokenIds.size();
char** s = new char*[nStops];
for (unsigned int i = 0; i < nStops; i++) {
s[i] = tokenizer->vocab[tokenizer->eosTokenIds[i]];
}
maxStopLength = 0;
for (size_t i = 0; i < nStops; i++) {
size_t len = strlen(s[i]);
if (len > maxStopLength) maxStopLength = len;
}
stops = (const char**)s;
}
TokenizerChatStops::~TokenizerChatStops() {
delete[] stops;
}
static const char *chatTemplateTypeToString(const ChatTemplateType type) {
if (type == TEMPLATE_LLAMA2) return "llama2";
if (type == TEMPLATE_LLAMA3) return "llama3";
if (type == TEMPLATE_DEEP_SEEK3) return "deepSeek3";
if (type == TEMPLATE_CHATML) return "chatml";
return "unknown";
}
ChatTemplateGenerator::ChatTemplateGenerator(const ChatTemplateType type, const char* chatTemplate, const char* eos)
: buffer()
{
if (type == TEMPLATE_UNKNOWN) {
if (chatTemplate == NULL)
throw std::runtime_error("The tokenizer does not include chat template");
if (strstr(chatTemplate, "[INST]") != NULL) {
this->type = TEMPLATE_LLAMA2;
} else if (strstr(chatTemplate, "<|start_header_id|>") != NULL) {
this->type = TEMPLATE_LLAMA3;
} else if (strstr(chatTemplate, "<Assistant>") != NULL) {
this->type = TEMPLATE_DEEP_SEEK3;
} else if (strstr(chatTemplate, "<|im_start|>") != NULL) {
this->type = TEMPLATE_CHATML;
} else {
throw std::runtime_error("Not supported chat template");
}
} else {
this->type = type;
}
this->eos = eos;
printf("⭐ Chat template: %s\n", chatTemplateTypeToString(this->type));
}
GeneratedChat ChatTemplateGenerator::generate(unsigned int nItems, ChatItem* items, bool appendGenerationPrompt) {
buffer.clear();
size_t publicPromptSize = 0;
if (type == TEMPLATE_LLAMA2) {
unsigned int i = 0;
if (nItems >= 2 && items[0].role == "system" && items[1].role == "user") {
buffer += "[INST] <<SYS>>\n" + items[0].message + "\n<</SYS>>\n\n" + items[1].message + " [/INST]" + eos;
i += 2;
}
for (; i < nItems; i++) {
if (items[i].role == "assistant") {
buffer += items[i].message + eos;
} else if (items[i].role == "user") {
buffer += "[INST] " + items[i].message + " [/INST]" + eos;
}
}
} else if (type == TEMPLATE_LLAMA3) {
for (unsigned int i = 0; i < nItems; i++)
buffer += "<|start_header_id|>" + items[i].role + "<|end_header_id|>\n\n" + items[i].message + eos;
if (appendGenerationPrompt)
buffer += "<|start_header_id|>assistant<|end_header_id|>\n\n";
} else if (type == TEMPLATE_DEEP_SEEK3) {
unsigned int i = 0;
if (nItems > 0 && items[0].role == "system") {
buffer += items[0].message;
i++;
}
for (; i < nItems; i++) {
if (items[i].role == "user") {
buffer += "<User>" + items[i].message;
} else if (items[i].role == "assistant") {
buffer += "<Assistant>" + items[i].message;
}
}
if (appendGenerationPrompt) {
buffer += "<Assistant><think>\n";
publicPromptSize = 8;
}
} else if (type == TEMPLATE_CHATML) {
for (unsigned int i = 0; i < nItems; i++) {
if (items[i].role == "system") {
buffer += "<|im_start|>system\n" + items[i].message + "<|im_end|>\n";
} else if (items[i].role == "user") {
buffer += "<|im_start|>user\n" + items[i].message + "<|im_end|>\n";
} else if (items[i].role == "assistant") {
buffer += "<|im_start|>assistant\n" + items[i].message + "<|im_end|>\n";
}
if (appendGenerationPrompt)
buffer += "<|im_start|>assistant\n";
}
}
const char *content = buffer.c_str();
size_t length = buffer.size();
const char *publicPrompt = publicPromptSize > 0
? &content[length - publicPromptSize]
: nullptr;
#if DEBUG_TEMPLATE_GENERATOR
printf("\033[1;31m[%s]\033[0m", content);
#endif
return {content, length, publicPrompt};
}
EosDetector::EosDetector(size_t nTokens, const int *tokens, const char** pieces, int paddingLeft, int paddingRight) {
this->nTokens = nTokens;
this->tokens = tokens;
this->pieces = pieces;
this->pieceSizes = new size_t[nTokens];
for (size_t s = 0; s < nTokens; s++) {
pieceSizes[s] = strlen(pieces[s]);
printf("🛑 Stop: %s\n", pieces[s]);
}
this->bufferPos = 0;
this->bufferSize = 0;
this->paddingLeft = paddingLeft;
this->paddingRight = paddingRight;
}
EosDetector::~EosDetector() {
if (bufferSize > 0)
delete[] buffer;
delete[] pieceSizes;
}
bool EosDetector::isEos(int tokenId) {
for (size_t i = 0; i < nTokens; i++) {
if (tokenId == tokens[i])
return true;
}
return false;
}
EosDetectorType EosDetector::append(int tokenId, const char *piece) {
if (piece != nullptr) {
int pieceLength = std::strlen(piece);
int newSize = bufferPos + pieceLength + 1;
if (newSize > bufferSize) {
char* newBuffer = new char[newSize];
if (bufferPos > 0)
std::memcpy(newBuffer, buffer, bufferPos);
if (bufferSize > 0)
delete[] buffer;
buffer = newBuffer;
bufferSize = newSize;
}
std::memcpy(&buffer[bufferPos], piece, pieceLength);
bufferPos += pieceLength;
buffer[bufferPos] = '\0';
}
// detection
if (isEos(tokenId)) {
eosPos = bufferPos;
return EOS;
}
eosPos = -1;
for (size_t s = 0; s < nTokens; s++) {
size_t pieceSize = pieceSizes[s];
if (bufferPos > pieceSize + paddingLeft + paddingRight) continue;
for (int lo = 0; lo <= paddingLeft; lo++) {
int n = bufferPos - lo;
if (n == 0 || n > pieceSize + paddingRight) continue;
if (n > pieceSize) n = pieceSize;
if (strncmp(buffer + lo, pieces[s], n) == 0) {
if (n == pieceSize) {
eosPos = lo;
buffer[eosPos] = '\0';
return EOS;
}
return MAYBE_EOS;
}
}
}
return NOT_EOS;
}
char* EosDetector::getDelta() {
if (bufferPos == 0) return nullptr;
if (eosPos == -1) return buffer;
if (eosPos == 0) return nullptr;
return buffer;
}
void EosDetector::reset() {
bufferPos = 0;
}

157
src/tokenizer.hpp Normal file
View File

@@ -0,0 +1,157 @@
#ifndef TOKENIZER_HPP
#define TOKENIZER_HPP
#include <cstdio>
#include <string>
#include <vector>
typedef struct {
char *str;
unsigned int id;
} TokenIndex;
struct TokenizerOldHeader {
unsigned int vocabSize;
unsigned int maxTokenLength;
int bosId;
int eosId;
int padId;
};
enum TokenizerHeaderKey {
TOK_VERSION = 0,
TOK_VOCAB_SIZE = 1,
MAX_TOKEN_LENGTH = 2,
BOS_ID = 3,
EOS_ID = 4, // Backward compatibility
PAD_ID = 5, // Ignored
CHAT_EOS_ID = 6, // Backward compatibility
CHAT_TEMPLATE = 7,
CHAT_STOP = 8, // Ignored
N_EOS_TOKENS = 9,
ADD_BOS = 10,
};
class Tokenizer {
private:
unsigned int maxTokenLength;
unsigned int regularVocabSize;
unsigned int specialVocabSize;
float *vocabScores;
unsigned int *vocabLength;
TokenIndex *regularVocab;
TokenIndex *specialVocab;
size_t strBufferSize;
char *strBuffer;
char *utf8Buffer;
size_t strBufferPos;
public:
std::vector<int> eosTokenIds;
unsigned int vocabSize;
char **vocab;
int bosId;
bool addBos;
char *chatTemplate;
Tokenizer(const char *tokenizer_path);
~Tokenizer();
int findSpecialTokenStartWith(char *piece);
int findRegularToken(char *piece);
void encode(char *text, int *tokens, int *nTokens, bool isStart, bool addSpecialTokens);
bool isEos(int token);
char *decode(int token);
void resetDecoder();
private:
char *detokUtf8();
};
typedef struct {
float prob;
int index;
} ProbIndex;
class Sampler {
private:
int vocab_size;
ProbIndex *probindex;
float temperature;
float topp;
unsigned long long rngState;
public:
Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed);
~Sampler();
int sample(float *logits);
void setTemp(float temp);
void setSeed(unsigned long long rngSeed);
};
class TokenizerChatStops {
public:
const char **stops;
size_t nStops;
size_t maxStopLength;
TokenizerChatStops(Tokenizer *tokenizer);
~TokenizerChatStops();
};
enum ChatTemplateType {
TEMPLATE_UNKNOWN = 0,
TEMPLATE_LLAMA2 = 1,
TEMPLATE_LLAMA3 = 2,
TEMPLATE_DEEP_SEEK3 = 3,
TEMPLATE_CHATML = 4,
};
struct ChatItem {
std::string role;
std::string message;
};
struct GeneratedChat {
const char *content;
size_t length;
const char *publicPrompt;
};
class ChatTemplateGenerator {
public:
const char *eos;
ChatTemplateType type;
std::string buffer;
ChatTemplateGenerator(const ChatTemplateType type, const char *chatTemplate, const char *eos);
GeneratedChat generate(unsigned int nItems, ChatItem *items, bool appendGenerationPrompt);
};
enum EosDetectorType {
MAYBE_EOS = 0,
EOS = 1,
NOT_EOS = 2,
};
class EosDetector {
private:
size_t nTokens;
const int *tokens;
const char **pieces;
size_t *pieceSizes;
size_t bufferPos;
size_t bufferSize;
int eosPos;
int paddingLeft;
int paddingRight;
public:
char *buffer;
EosDetector(size_t nTokens, const int *tokens, const char* *pieces, int paddingLeft, int paddingRight);
~EosDetector();
EosDetectorType append(int tokenId, const char *piece);
bool isEos(int tokenId);
char *getDelta();
void reset();
};
#endif