init
This commit is contained in:
30
.env.example
Normal file
30
.env.example
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Distributed Llama Docker Environment Configuration
|
||||||
|
# Copy this file to .env and customize as needed
|
||||||
|
|
||||||
|
# Model configuration
|
||||||
|
MODEL_NAME=llama3_2_3b_instruct_q40
|
||||||
|
MAX_SEQ_LEN=4096
|
||||||
|
BUFFER_FLOAT_TYPE=q80
|
||||||
|
|
||||||
|
# Thread configuration
|
||||||
|
CONTROLLER_NTHREADS=4
|
||||||
|
WORKER_NTHREADS=4
|
||||||
|
|
||||||
|
# To use a different model, change MODEL_NAME to one of:
|
||||||
|
# - llama3_1_8b_instruct_q40
|
||||||
|
# - llama3_1_405b_instruct_q40
|
||||||
|
# - llama3_2_1b_instruct_q40
|
||||||
|
# - llama3_2_3b_instruct_q40
|
||||||
|
# - llama3_3_70b_instruct_q40
|
||||||
|
# - deepseek_r1_distill_llama_8b_q40
|
||||||
|
# - qwen3_0.6b_q40
|
||||||
|
# - qwen3_1.7b_q40
|
||||||
|
# - qwen3_8b_q40
|
||||||
|
# - qwen3_14b_q40
|
||||||
|
# - qwen3_30b_a3b_q40
|
||||||
|
|
||||||
|
# Performance tuning:
|
||||||
|
# - Adjust CONTROLLER_NTHREADS and WORKER_NTHREADS based on your Pi's CPU cores
|
||||||
|
# - For Pi 4 (4 cores): use 4 threads
|
||||||
|
# - For Pi 3 (4 cores): use 2-4 threads
|
||||||
|
# - For Pi Zero 2 (4 cores): use 2 threads
|
||||||
BIN
.github/8raspi.jpg
vendored
Normal file
BIN
.github/8raspi.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 542 KiB |
BIN
.github/8raspi2.jpg
vendored
Normal file
BIN
.github/8raspi2.jpg
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 298 KiB |
BIN
.github/cover.png
vendored
Normal file
BIN
.github/cover.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 61 KiB |
63
.github/workflows/main.yml
vendored
Normal file
63
.github/workflows/main.yml
vendored
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
name: main
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- feat/nn
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
- feat/nn
|
||||||
|
jobs:
|
||||||
|
build-linux:
|
||||||
|
name: Linux
|
||||||
|
runs-on: ${{matrix.runner}}
|
||||||
|
strategy:
|
||||||
|
matrix:
|
||||||
|
include:
|
||||||
|
- runner: ubuntu-22.04
|
||||||
|
arch: amd64
|
||||||
|
- runner: ubuntu-24.04-arm
|
||||||
|
arch: arm64
|
||||||
|
steps:
|
||||||
|
- name: Checkout Repo
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Dependencies
|
||||||
|
id: dependencies
|
||||||
|
run: sudo apt-get update && sudo apt-get install build-essential
|
||||||
|
- name: Build
|
||||||
|
id: build
|
||||||
|
run: |
|
||||||
|
make dllama
|
||||||
|
make nn-cpu-test
|
||||||
|
make nn-cpu-ops-test
|
||||||
|
make tokenizer-test
|
||||||
|
- name: nn-cpu-test
|
||||||
|
run: ./nn-cpu-test
|
||||||
|
- name: nn-cpu-ops-test
|
||||||
|
run: ./nn-cpu-ops-test
|
||||||
|
- name: tokenizer-test
|
||||||
|
run: ./tokenizer-test
|
||||||
|
|
||||||
|
build-windows:
|
||||||
|
name: Windows
|
||||||
|
runs-on: windows-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout Repo
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
- name: Dependencies
|
||||||
|
id: dependencies
|
||||||
|
run: choco install make
|
||||||
|
- name: Build
|
||||||
|
id: build
|
||||||
|
run: |
|
||||||
|
make dllama
|
||||||
|
make nn-cpu-test
|
||||||
|
make nn-cpu-ops-test
|
||||||
|
make tokenizer-test
|
||||||
|
- name: nn-cpu-test
|
||||||
|
run: ./nn-cpu-test
|
||||||
|
- name: nn-cpu-ops-test
|
||||||
|
run: ./nn-cpu-ops-test
|
||||||
|
- name: tokenizer-test
|
||||||
|
run: ./tokenizer-test
|
||||||
19
.gitignore
vendored
Normal file
19
.gitignore
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
.vscode/settings.json
|
||||||
|
|
||||||
|
*.o
|
||||||
|
*.0
|
||||||
|
*.dSYM
|
||||||
|
*.data
|
||||||
|
*.temp
|
||||||
|
*.tmp
|
||||||
|
__pycache__
|
||||||
|
|
||||||
|
*-test
|
||||||
|
/models
|
||||||
|
main
|
||||||
|
run*.sh
|
||||||
|
server
|
||||||
|
/dllama
|
||||||
|
/dllama-*
|
||||||
|
*.exe
|
||||||
|
*.spv
|
||||||
17
.vscode/launch.json
vendored
Normal file
17
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "main",
|
||||||
|
"type": "cppdbg",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/main",
|
||||||
|
"args": [],
|
||||||
|
"stopAtEntry": false,
|
||||||
|
"cwd": "${workspaceFolder}",
|
||||||
|
"environment": [],
|
||||||
|
"externalConsole": false,
|
||||||
|
"MIMode": "lldb"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
202
DOCKER_README.md
Normal file
202
DOCKER_README.md
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# Distributed Llama Docker Setup for Raspberry Pi
|
||||||
|
|
||||||
|
This directory contains Docker configurations to run Distributed Llama on Raspberry Pi devices using containers. There are two variants:
|
||||||
|
|
||||||
|
1. **Controller** (`Dockerfile.controller`) - Downloads models and runs the API server
|
||||||
|
2. **Worker** (`Dockerfile.worker`) - Runs worker nodes that connect to the controller
|
||||||
|
|
||||||
|
## Quick Start with Docker Compose
|
||||||
|
|
||||||
|
### 1. Download a Model
|
||||||
|
|
||||||
|
First, download a model using the controller container:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Create a models directory
|
||||||
|
mkdir -p models
|
||||||
|
|
||||||
|
# Download a model (this will take some time)
|
||||||
|
docker-compose run --rm controller --download llama3_2_3b_instruct_q40
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Start the Distributed Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start all services (1 controller + 3 workers)
|
||||||
|
docker-compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
The API will be available at `http://localhost:9999`
|
||||||
|
|
||||||
|
### 3. Test the API
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# List available models
|
||||||
|
curl http://localhost:9999/v1/models
|
||||||
|
|
||||||
|
# Send a chat completion request
|
||||||
|
curl -X POST http://localhost:9999/v1/chat/completions \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{
|
||||||
|
"model": "llama",
|
||||||
|
"messages": [{"role": "user", "content": "Hello, how are you?"}],
|
||||||
|
"max_tokens": 100
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manual Docker Usage
|
||||||
|
|
||||||
|
### Building the Images
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build controller image
|
||||||
|
docker build -f Dockerfile.controller -t distributed-llama-controller .
|
||||||
|
|
||||||
|
# Build worker image
|
||||||
|
docker build -f Dockerfile.worker -t distributed-llama-worker .
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running the Controller
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Download a model first
|
||||||
|
docker run -v ./models:/app/models distributed-llama-controller --download llama3_2_3b_instruct_q40
|
||||||
|
|
||||||
|
# Run API server (standalone mode, no workers)
|
||||||
|
docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \
|
||||||
|
--model llama3_2_3b_instruct_q40
|
||||||
|
|
||||||
|
# Run API server with workers
|
||||||
|
docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \
|
||||||
|
--model llama3_2_3b_instruct_q40 \
|
||||||
|
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running Workers
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run a worker on default port 9999
|
||||||
|
docker run -p 9999:9999 distributed-llama-worker
|
||||||
|
|
||||||
|
# Run a worker with custom settings
|
||||||
|
docker run -p 9998:9998 distributed-llama-worker --port 9998 --nthreads 2
|
||||||
|
```
|
||||||
|
|
||||||
|
## Available Models
|
||||||
|
|
||||||
|
You can download any of these models:
|
||||||
|
|
||||||
|
- `llama3_1_8b_instruct_q40`
|
||||||
|
- `llama3_1_405b_instruct_q40` (very large, 56 parts)
|
||||||
|
- `llama3_2_1b_instruct_q40`
|
||||||
|
- `llama3_2_3b_instruct_q40`
|
||||||
|
- `llama3_3_70b_instruct_q40`
|
||||||
|
- `deepseek_r1_distill_llama_8b_q40`
|
||||||
|
- `qwen3_0.6b_q40`
|
||||||
|
- `qwen3_1.7b_q40`
|
||||||
|
- `qwen3_8b_q40`
|
||||||
|
- `qwen3_14b_q40`
|
||||||
|
- `qwen3_30b_a3b_q40`
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
### Controller Options
|
||||||
|
|
||||||
|
- `--model <name>`: Model name to use (required)
|
||||||
|
- `--port <port>`: API server port (default: 9999)
|
||||||
|
- `--nthreads <n>`: Number of threads (default: 4)
|
||||||
|
- `--max-seq-len <n>`: Maximum sequence length (default: 4096)
|
||||||
|
- `--buffer-float-type <type>`: Buffer float type (default: q80)
|
||||||
|
- `--workers <addresses>`: Space-separated worker addresses
|
||||||
|
- `--download <model>`: Download a model and exit
|
||||||
|
|
||||||
|
### Worker Options
|
||||||
|
|
||||||
|
- `--port <port>`: Worker port (default: 9999)
|
||||||
|
- `--nthreads <n>`: Number of threads (default: 4)
|
||||||
|
|
||||||
|
## Environment Variables (Docker Compose)
|
||||||
|
|
||||||
|
You can customize the setup using environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Set model and thread counts
|
||||||
|
MODEL_NAME=llama3_2_1b_instruct_q40 \
|
||||||
|
CONTROLLER_NTHREADS=2 \
|
||||||
|
WORKER_NTHREADS=2 \
|
||||||
|
docker-compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
Available variables:
|
||||||
|
- `MODEL_NAME`: Model to use (default: llama3_2_3b_instruct_q40)
|
||||||
|
- `CONTROLLER_NTHREADS`: Controller threads (default: 4)
|
||||||
|
- `WORKER_NTHREADS`: Worker threads (default: 4)
|
||||||
|
- `MAX_SEQ_LEN`: Maximum sequence length (default: 4096)
|
||||||
|
- `BUFFER_FLOAT_TYPE`: Buffer float type (default: q80)
|
||||||
|
|
||||||
|
## Multi-Device Setup
|
||||||
|
|
||||||
|
To run across multiple Raspberry Pi devices:
|
||||||
|
|
||||||
|
### Device 1 (Controller)
|
||||||
|
```bash
|
||||||
|
# Run controller
|
||||||
|
docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \
|
||||||
|
--model llama3_2_3b_instruct_q40 \
|
||||||
|
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
|
||||||
|
```
|
||||||
|
|
||||||
|
### Devices 2-4 (Workers)
|
||||||
|
```bash
|
||||||
|
# Run worker on each device
|
||||||
|
docker run -p 9999:9999 distributed-llama-worker --nthreads 4
|
||||||
|
```
|
||||||
|
|
||||||
|
## Performance Tips
|
||||||
|
|
||||||
|
1. **Thread Count**: Set `--nthreads` to the number of CPU cores on each device
|
||||||
|
2. **Memory**: Larger models require more RAM. Monitor usage with `docker stats`
|
||||||
|
3. **Network**: Use wired Ethernet connections for better performance between devices
|
||||||
|
4. **Storage**: Use fast SD cards (Class 10 or better) or USB 3.0 storage for model files
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Model Download Issues
|
||||||
|
```bash
|
||||||
|
# Check if model files exist
|
||||||
|
ls -la models/llama3_2_3b_instruct_q40/
|
||||||
|
|
||||||
|
# Re-download if corrupted
|
||||||
|
docker-compose run --rm controller --download llama3_2_3b_instruct_q40
|
||||||
|
```
|
||||||
|
|
||||||
|
### Worker Connection Issues
|
||||||
|
```bash
|
||||||
|
# Check worker logs
|
||||||
|
docker-compose logs worker1
|
||||||
|
|
||||||
|
# Test network connectivity
|
||||||
|
docker exec -it <controller_container> ping 172.20.0.11
|
||||||
|
```
|
||||||
|
|
||||||
|
### Resource Issues
|
||||||
|
```bash
|
||||||
|
# Monitor resource usage
|
||||||
|
docker stats
|
||||||
|
|
||||||
|
# Reduce thread count if CPU usage is too high
|
||||||
|
CONTROLLER_NTHREADS=2 WORKER_NTHREADS=2 docker-compose up
|
||||||
|
```
|
||||||
|
|
||||||
|
## Web Interface
|
||||||
|
|
||||||
|
You can use the web chat interface at [llama-ui.js.org](https://llama-ui.js.org/):
|
||||||
|
|
||||||
|
1. Open the website
|
||||||
|
2. Go to settings
|
||||||
|
3. Set base URL to: `http://your-pi-ip:9999`
|
||||||
|
4. Save and start chatting
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
This Docker setup follows the same license as the main Distributed Llama project.
|
||||||
160
Dockerfile.controller
Normal file
160
Dockerfile.controller
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# Dockerfile for Distributed Llama Controller (Raspberry Pi)
|
||||||
|
# This variant can download models and start the API server
|
||||||
|
FROM arm64v8/debian:bookworm-slim
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
g++ \
|
||||||
|
make \
|
||||||
|
git \
|
||||||
|
python3 \
|
||||||
|
python3-pip \
|
||||||
|
curl \
|
||||||
|
wget \
|
||||||
|
ca-certificates \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy source code
|
||||||
|
COPY src/ ./src/
|
||||||
|
COPY Makefile ./
|
||||||
|
COPY launch.py ./
|
||||||
|
|
||||||
|
# Build the applications
|
||||||
|
RUN make dllama && make dllama-api
|
||||||
|
|
||||||
|
# Create models directory for volume mount
|
||||||
|
RUN mkdir -p /app/models
|
||||||
|
|
||||||
|
# Create a script to download models
|
||||||
|
COPY <<EOF /app/download-model.sh
|
||||||
|
#!/bin/bash
|
||||||
|
if [ -z "\$1" ]; then
|
||||||
|
echo "Usage: download-model.sh <model_name>"
|
||||||
|
echo "Available models:"
|
||||||
|
python3 launch.py
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
python3 launch.py "\$1" -skip-run -skip-script -y
|
||||||
|
EOF
|
||||||
|
|
||||||
|
RUN chmod +x /app/download-model.sh
|
||||||
|
|
||||||
|
# Create entrypoint script
|
||||||
|
COPY <<EOF /app/entrypoint.sh
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Default values
|
||||||
|
MODEL_NAME=""
|
||||||
|
API_PORT=9999
|
||||||
|
NTHREADS=4
|
||||||
|
MAX_SEQ_LEN=4096
|
||||||
|
WORKERS=""
|
||||||
|
BUFFER_FLOAT_TYPE="q80"
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
while [[ \$# -gt 0 ]]; do
|
||||||
|
case \$1 in
|
||||||
|
--model)
|
||||||
|
MODEL_NAME="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--port)
|
||||||
|
API_PORT="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--nthreads)
|
||||||
|
NTHREADS="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--max-seq-len)
|
||||||
|
MAX_SEQ_LEN="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--workers)
|
||||||
|
shift
|
||||||
|
WORKERS="\$@"
|
||||||
|
break
|
||||||
|
;;
|
||||||
|
--buffer-float-type)
|
||||||
|
BUFFER_FLOAT_TYPE="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--download)
|
||||||
|
MODEL_NAME="\$2"
|
||||||
|
echo "Downloading model: \$MODEL_NAME"
|
||||||
|
/app/download-model.sh "\$MODEL_NAME"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
--help)
|
||||||
|
echo "Usage: docker run distributed-llama-controller [OPTIONS]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --download <model> Download a model and exit"
|
||||||
|
echo " --model <model> Model name to use"
|
||||||
|
echo " --port <port> API server port (default: 9999)"
|
||||||
|
echo " --nthreads <n> Number of threads (default: 4)"
|
||||||
|
echo " --max-seq-len <n> Maximum sequence length (default: 4096)"
|
||||||
|
echo " --buffer-float-type <type> Buffer float type (default: q80)"
|
||||||
|
echo " --workers <workers> Space-separated list of worker addresses (e.g., 10.0.0.2:9999 10.0.0.3:9999)"
|
||||||
|
echo ""
|
||||||
|
echo "Examples:"
|
||||||
|
echo " # Download a model"
|
||||||
|
echo " docker run -v ./models:/app/models distributed-llama-controller --download llama3_2_3b_instruct_q40"
|
||||||
|
echo ""
|
||||||
|
echo " # Run API server with workers"
|
||||||
|
echo " docker run -p 9999:9999 -v ./models:/app/models distributed-llama-controller \\"
|
||||||
|
echo " --model llama3_2_3b_instruct_q40 --workers 10.0.0.2:9999 10.0.0.3:9999"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: \$1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
if [ -z "\$MODEL_NAME" ]; then
|
||||||
|
echo "Error: --model is required"
|
||||||
|
echo "Use --help for usage information"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
MODEL_PATH="/app/models/\$MODEL_NAME/dllama_model_\$MODEL_NAME.m"
|
||||||
|
TOKENIZER_PATH="/app/models/\$MODEL_NAME/dllama_tokenizer_\$MODEL_NAME.t"
|
||||||
|
|
||||||
|
if [ ! -f "\$MODEL_PATH" ] || [ ! -f "\$TOKENIZER_PATH" ]; then
|
||||||
|
echo "Error: Model files not found for \$MODEL_NAME"
|
||||||
|
echo "Model path: \$MODEL_PATH"
|
||||||
|
echo "Tokenizer path: \$TOKENIZER_PATH"
|
||||||
|
echo ""
|
||||||
|
echo "Please download the model first:"
|
||||||
|
echo "docker run -v ./models:/app/models distributed-llama-controller --download \$MODEL_NAME"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build the command
|
||||||
|
CMD="./dllama-api --port \$API_PORT --model \$MODEL_PATH --tokenizer \$TOKENIZER_PATH --buffer-float-type \$BUFFER_FLOAT_TYPE --nthreads \$NTHREADS --max-seq-len \$MAX_SEQ_LEN"
|
||||||
|
|
||||||
|
if [ ! -z "\$WORKERS" ]; then
|
||||||
|
CMD="\$CMD --workers \$WORKERS"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Starting API server with command:"
|
||||||
|
echo "\$CMD"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
exec \$CMD
|
||||||
|
EOF
|
||||||
|
|
||||||
|
RUN chmod +x /app/entrypoint.sh
|
||||||
|
|
||||||
|
# Expose the default API port
|
||||||
|
EXPOSE 9999
|
||||||
|
|
||||||
|
# Use the entrypoint script
|
||||||
|
ENTRYPOINT ["/app/entrypoint.sh"]
|
||||||
75
Dockerfile.worker
Normal file
75
Dockerfile.worker
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# Dockerfile for Distributed Llama Worker (Raspberry Pi)
|
||||||
|
# This variant runs as a worker node and connects to a controller
|
||||||
|
FROM arm64v8/debian:bookworm-slim
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
build-essential \
|
||||||
|
g++ \
|
||||||
|
make \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Set working directory
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy source code
|
||||||
|
COPY src/ ./src/
|
||||||
|
COPY Makefile ./
|
||||||
|
|
||||||
|
# Build only the worker application
|
||||||
|
RUN make dllama
|
||||||
|
|
||||||
|
# Create entrypoint script
|
||||||
|
COPY <<EOF /app/entrypoint.sh
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Default values
|
||||||
|
PORT=9999
|
||||||
|
NTHREADS=4
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
while [[ \$# -gt 0 ]]; do
|
||||||
|
case \$1 in
|
||||||
|
--port)
|
||||||
|
PORT="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--nthreads)
|
||||||
|
NTHREADS="\$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--help)
|
||||||
|
echo "Usage: docker run distributed-llama-worker [OPTIONS]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --port <port> Worker port (default: 9999)"
|
||||||
|
echo " --nthreads <n> Number of threads (default: 4)"
|
||||||
|
echo ""
|
||||||
|
echo "Example:"
|
||||||
|
echo " docker run -p 9999:9999 distributed-llama-worker --port 9999 --nthreads 4"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: \$1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Build the command
|
||||||
|
CMD="./dllama worker --port \$PORT --nthreads \$NTHREADS"
|
||||||
|
|
||||||
|
echo "Starting worker with command:"
|
||||||
|
echo "\$CMD"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
exec \$CMD
|
||||||
|
EOF
|
||||||
|
|
||||||
|
RUN chmod +x /app/entrypoint.sh
|
||||||
|
|
||||||
|
# Expose the default worker port
|
||||||
|
EXPOSE 9999
|
||||||
|
|
||||||
|
# Use the entrypoint script
|
||||||
|
ENTRYPOINT ["/app/entrypoint.sh"]
|
||||||
9
LICENSE
Normal file
9
LICENSE
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2024 Bartłomiej Tadych (b4rtaz)
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
90
Makefile
Normal file
90
Makefile
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
CXX = g++
|
||||||
|
CXXFLAGS = -std=c++11 -Werror -Wformat -Werror=format-security
|
||||||
|
|
||||||
|
ifndef TERMUX_VERSION
|
||||||
|
CXXFLAGS += -march=native -mtune=native
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifdef DEBUG
|
||||||
|
CXXFLAGS += -g -fsanitize=address
|
||||||
|
else
|
||||||
|
CXXFLAGS += -O3
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifdef WVLA
|
||||||
|
CXXFLAGS += -Wvla-extension
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifdef DLLAMA_VULKAN
|
||||||
|
CGLSLC = glslc
|
||||||
|
|
||||||
|
ifeq ($(OS),Windows_NT)
|
||||||
|
LIBS += -L$(VK_SDK_PATH)\lib -lvulkan-1
|
||||||
|
CXXFLAGS += -DDLLAMA_VULKAN -I$(VK_SDK_PATH)\include
|
||||||
|
else
|
||||||
|
LIBS += -lvulkan
|
||||||
|
CXXFLAGS += -DDLLAMA_VULKAN
|
||||||
|
endif
|
||||||
|
|
||||||
|
DEPS += nn-vulkan.o
|
||||||
|
endif
|
||||||
|
|
||||||
|
ifeq ($(OS),Windows_NT)
|
||||||
|
LIBS += -lws2_32
|
||||||
|
DELETE_CMD = del /f
|
||||||
|
else
|
||||||
|
LIBS += -lpthread
|
||||||
|
DELETE_CMD = rm -fv
|
||||||
|
endif
|
||||||
|
|
||||||
|
.PHONY: clean dllama
|
||||||
|
|
||||||
|
clean:
|
||||||
|
$(DELETE_CMD) *.o dllama dllama-* socket-benchmark mmap-buffer-* *-test *.exe
|
||||||
|
|
||||||
|
# nn
|
||||||
|
nn-quants.o: src/nn/nn-quants.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
nn-core.o: src/nn/nn-core.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
nn-executor.o: src/nn/nn-executor.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
nn-network.o: src/nn/nn-network.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
llamafile-sgemm.o: src/nn/llamafile/sgemm.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
nn-cpu-ops.o: src/nn/nn-cpu-ops.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
nn-cpu.o: src/nn/nn-cpu.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
nn-cpu-test: src/nn/nn-cpu-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o
|
||||||
|
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
|
||||||
|
nn-cpu-ops-test: src/nn/nn-cpu-ops-test.cpp nn-quants.o nn-core.o nn-executor.o llamafile-sgemm.o nn-cpu.o
|
||||||
|
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
|
||||||
|
nn-vulkan.o: src/nn/nn-vulkan.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
|
||||||
|
ifdef DLLAMA_VULKAN
|
||||||
|
VULKAN_SHADER_SRCS := $(wildcard src/nn/vulkan/*.comp)
|
||||||
|
VULKAN_SHADER_BINS := $(VULKAN_SHADER_SRCS:.comp=.spv)
|
||||||
|
DEPS += $(VULKAN_SHADER_BINS)
|
||||||
|
|
||||||
|
%.spv: %.comp
|
||||||
|
$(CGLSLC) -c $< -o $@ --target-env=vulkan1.2
|
||||||
|
nn-vulkan-test: src/nn/nn-vulkan-test.cpp nn-quants.o nn-core.o nn-executor.o nn-vulkan.o ${DEPS}
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
|
||||||
|
endif
|
||||||
|
|
||||||
|
# llm
|
||||||
|
tokenizer.o: src/tokenizer.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
llm.o: src/llm.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
app.o: src/app.cpp
|
||||||
|
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
||||||
|
tokenizer-test: src/tokenizer-test.cpp nn-quants.o nn-core.o llamafile-sgemm.o nn-cpu-ops.o tokenizer.o
|
||||||
|
$(CXX) $(CXXFLAGS) $^ -o $@ $(LIBS)
|
||||||
|
dllama: src/dllama.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS}
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
|
||||||
|
dllama-api: src/dllama-api.cpp nn-quants.o nn-core.o nn-executor.o nn-network.o llamafile-sgemm.o nn-cpu-ops.o nn-cpu.o tokenizer.o llm.o app.o ${DEPS}
|
||||||
|
$(CXX) $(CXXFLAGS) $(filter-out %.spv, $^) -o $@ $(LIBS)
|
||||||
142
README.md
Normal file
142
README.md
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|

|
||||||
|
|
||||||
|
# Distributed Llama
|
||||||
|
|
||||||
|
[](https://github.com/b4rtaz/distributed-llama/actions) [](/LICENSE) [](https://n4no.com/projects/distributedLlama/discord.php)
|
||||||
|
|
||||||
|
Connect home devices into a powerful cluster to accelerate LLM inference. More devices mean faster performance, leveraging tensor parallelism and high-speed synchronization over Ethernet.
|
||||||
|
|
||||||
|
Supports Linux, macOS, and Windows. Optimized for ARM and x86_64 AVX2 CPUs.
|
||||||
|
|
||||||
|
**How to Run**
|
||||||
|
- [💻 How to Run on Linux, MacOS or Windows](./docs/HOW_TO_RUN_LINUX_MACOS_WIN.md)
|
||||||
|
- [🍓 How to Run on Raspberry Pi](./docs/HOW_TO_RUN_RASPBERRYPI.md)
|
||||||
|
- [🧠 How to Run on GPU](./docs/HOW_TO_RUN_GPU.md)
|
||||||
|
|
||||||
|
**News**
|
||||||
|
- 16 Sep 2025 - Qwen 3 MoE models are now supported on Vulkan.
|
||||||
|
- 5 Sep 2025 - Qwen 3 MoE models are now supported on CPU.
|
||||||
|
- 3 Aug 2025 - Qwen 3 0.6B, 1.7B, 8B and 14B models are now supported.
|
||||||
|
- 23 Mar 2025 - [🌋 Experimental Vulkan support](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.13.0)
|
||||||
|
- 12 Feb 2025 - 🚧 Merged the [fundamental codebase refactor](https://github.com/b4rtaz/distributed-llama/releases/tag/v0.12.0)
|
||||||
|
- 9 Jan 2025 - [🍎 Llama 3.3 70B on 4 x Mac Mini M4 Pro 24GB RAM](https://github.com/b4rtaz/distributed-llama/discussions/147)
|
||||||
|
|
||||||
|
### 🔥 Setup Root Node by Single Command
|
||||||
|
|
||||||
|
Python 3 and C++ compiler required. The command will download the model and the tokenizer.
|
||||||
|
|
||||||
|
| Model | Size | Command |
|
||||||
|
| --------------------------------- | -------- | ---------------------------------------------------- |
|
||||||
|
| Llama 3.1 8B Instruct Q40 | 6.32 GB | `python launch.py llama3_1_8b_instruct_q40` |
|
||||||
|
| Llama 3.1 405B Instruct Q40 | 238 GB | `python launch.py llama3_1_405b_instruct_q40`. |
|
||||||
|
| Llama 3.2 1B Instruct Q40 | 1.7 GB | `python launch.py llama3_2_1b_instruct_q40` |
|
||||||
|
| Llama 3.2 3B Instruct Q40 | 3.4 GB | `python launch.py llama3_2_3b_instruct_q40` |
|
||||||
|
| Llama 3.3 70B Instruct Q40 | 40 GB | `python launch.py llama3_3_70b_instruct_q40` |
|
||||||
|
| DeepSeek R1 Distill Llama 8B Q40 | 6.32 GB | `python launch.py deepseek_r1_distill_llama_8b_q40` |
|
||||||
|
| Qwen 3 0.6B Q40 | 0.9 GB | `python launch.py qwen3_0.6b_q40` |
|
||||||
|
| Qwen 3 1.7B Q40 | 2.2 GB | `python launch.py qwen3_1.7b_q40` |
|
||||||
|
| Qwen 3 8B Q40 | 6.7 GB | `python launch.py qwen3_8b_q40` |
|
||||||
|
| Qwen 3 14B Q40 | 10.9 GB | `python launch.py qwen3_14b_q40` |
|
||||||
|
| Qwen 3 30B A3B Q40 | 17.0 GB | `python launch.py qwen3_30b_a3b_q40` |
|
||||||
|
|
||||||
|
### 🛠️ Convert Model Manually
|
||||||
|
|
||||||
|
* [🤗 How to Convert Hugging Face Model](./docs/HOW_TO_CONVERT_HF_MODEL.md)
|
||||||
|
|
||||||
|
### 🚧 Known Limitations
|
||||||
|
|
||||||
|
* You can run Distributed Llama only on 1, 2, 4... 2^n nodes.
|
||||||
|
* The maximum number of nodes is equal to the number of KV heads in the model [#70](https://github.com/b4rtaz/distributed-llama/issues/70).
|
||||||
|
* Only the following quantizations are supported [#183](https://github.com/b4rtaz/distributed-llama/issues/183):
|
||||||
|
* `q40` model with `q80` `buffer-float-type`
|
||||||
|
* `f32` model with `f32` `buffer-float-type`
|
||||||
|
|
||||||
|
### 👷 Architecture
|
||||||
|
|
||||||
|
````
|
||||||
|
[🔀 SWITCH OR ROUTER]
|
||||||
|
| | | |
|
||||||
|
| | | |_______ 🔸 device1 (ROOT) 10.0.0.1
|
||||||
|
| | |_________ 🔹 device2 (WORKER 1) 10.0.0.2:9999
|
||||||
|
| |___________ 🔹 device3 (WORKER 2) 10.0.0.3:9999
|
||||||
|
|_____________ 🔹 device4 (WORKER 3) 10.0.0.4:9999
|
||||||
|
...
|
||||||
|
````
|
||||||
|
|
||||||
|
The project is split up into two parts:
|
||||||
|
* **🔸 Root node** - it's responsible for loading the model and weights and forward them to workers. Also, it synchronizes the state of the neural network. The root node is also a worker, it processes own slice of the neural network.
|
||||||
|
* **🔹 Worker node** - it processes own slice of the neural network. It doesn't require any configuration related to the model.
|
||||||
|
|
||||||
|
You always need the root node and you can add 2^n - 1 worker nodes to speed up the inference. The RAM usage of the neural network is split up across all nodes. The root node requires a bit more RAM than worker nodes.
|
||||||
|
|
||||||
|
### 🎹 Commands
|
||||||
|
|
||||||
|
* `dllama inference` - run the inference with a simple benchmark,
|
||||||
|
* `dllama chat` - run the CLI chat,
|
||||||
|
* `dllama worker` - run the worker node,
|
||||||
|
* `dllama-api` - run the API server.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>🎹 Supported Arguments</summary>
|
||||||
|
|
||||||
|
<br />Inference, Chat, API
|
||||||
|
|
||||||
|
| Argument | Description | Example |
|
||||||
|
| ---------------------------- | ---------------------------------------------------------------- | -------------------------------------- |
|
||||||
|
| `--model <path>` | Path to model. | `dllama_model_meta-llama-3-8b_q40.m` |
|
||||||
|
| `--tokenizer <path>` | Tokenizer to model. | `dllama_tokenizer_llama3.t` |
|
||||||
|
| `--buffer-float-type <type>` | Float precision of synchronization. | `q80` |
|
||||||
|
| `--workers <workers>` | Addresses of workers (ip:port), separated by space. | `10.0.0.1:9999 10.0.0.2:9999` |
|
||||||
|
| `--max-seq-len <n>` | The maximum sequence length, it helps to reduce the RAM usage. | `4096` |
|
||||||
|
|
||||||
|
Inference, Chat, Worker, API
|
||||||
|
|
||||||
|
| Argument | Description | Example |
|
||||||
|
| ---------------------------- | --------------------------------------------------------------------- | ----------------------------------- |
|
||||||
|
| `--nthreads <n>` | Amount of threads. Don't set a higher value than number of CPU cores. | `4` |
|
||||||
|
|
||||||
|
Worker, API
|
||||||
|
|
||||||
|
| Argument | Description | Example |
|
||||||
|
| ---------------------------- | --------------------------------- | ----------------- |
|
||||||
|
| `--port <port>` | Binding port. | `9999` |
|
||||||
|
|
||||||
|
Inference
|
||||||
|
|
||||||
|
| Argument | Description | Example |
|
||||||
|
| ---------------------------- | ------------------------------ | ------------------ |
|
||||||
|
| `--prompt <prompt>` | Initial prompt. | `"Hello World"` |
|
||||||
|
| `--steps <steps>` | Number of tokens to generate. | `256` |
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 📊 Measurements
|
||||||
|
|
||||||
|
Please check the [discussions](https://github.com/b4rtaz/distributed-llama/discussions) section, where many measurements were published on different configurations.
|
||||||
|
|
||||||
|
## ✋ Contribution
|
||||||
|
|
||||||
|
Feel free to contribute to this project. For small changes, simply create a new merge request. For larger changes, please create an issue to discuss your plans. Please follow these guidelines when contributing:
|
||||||
|
|
||||||
|
* Make only minimal changes and avoid modifying files that are not necessary.
|
||||||
|
* Ensure the code is compatible across all supported systems and CPUs.
|
||||||
|
* This repository is maintained in English.
|
||||||
|
|
||||||
|
## 💡 License
|
||||||
|
|
||||||
|
This project is released under the MIT license.
|
||||||
|
|
||||||
|
## 📖 Citation
|
||||||
|
|
||||||
|
```
|
||||||
|
@misc{dllama,
|
||||||
|
author = {Bartłomiej Tadych},
|
||||||
|
title = {Distributed Llama},
|
||||||
|
year = {2024},
|
||||||
|
publisher = {GitHub},
|
||||||
|
journal = {GitHub repository},
|
||||||
|
howpublished = {\url{https://github.com/b4rtaz/distributed-llama}},
|
||||||
|
commit = {7eb77ca93ec0d502e28d36b6fb20039b449cbea4}
|
||||||
|
}
|
||||||
|
```
|
||||||
4
converter/.gitignore
vendored
Normal file
4
converter/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
*.t
|
||||||
|
*.m
|
||||||
|
*.bin
|
||||||
|
*/
|
||||||
265
converter/convert-hf.py
Normal file
265
converter/convert-hf.py
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
import gc
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from writer import parseFloatType, writeTensor, writeHeader, FloatType
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
class ArchType:
|
||||||
|
LLAMA = 0xABCD00
|
||||||
|
QWEN3 = 0xABCD01
|
||||||
|
QWEN3_MOE = 0xABCD02
|
||||||
|
|
||||||
|
def permute(tensor, nHeads: int, nKvHeads: int):
    # Reorders projection-weight rows from the interleaved Hugging Face layout
    # into the layout expected by the runtime. The effective head count is the
    # KV head count whenever it differs from nHeads (grouped-query attention).
    heads = nKvHeads if nHeads != nKvHeads else nHeads
    originalShape = tensor.shape
    rowsPerHalf = originalShape[0] // heads // 2
    interleaved = tensor.reshape(heads, 2, rowsPerHalf, *originalShape[1:])
    return interleaved.swapaxes(1, 2).reshape(originalShape)
|
||||||
|
|
||||||
|
class Processor:
    # Streams tensors out of a multi-file safetensors checkpoint in the fixed
    # order required by the distributed-llama model format, keeping at most one
    # shard file open at a time to bound memory usage.

    def __init__(self, config):
        self.config = config
        self.archType = config['arch_type']
        self.currentModelIndex = None   # index into config['files'] of the open shard
        self.currentModel = None        # safe_open handle of the current shard
        self.currentModelKeys = None    # layer names present in the current shard
        self.layerMap = {}              # layer name -> shard index (filled as shards are opened)
        self.plan = []                  # ordered items: [floatType, (optional transform,) name(s)]

    def __unloadModel(self):
        # Drops the current shard handle and forces a GC pass so its tensors
        # are released before the next shard is opened.
        if self.currentModel:
            del self.currentModel
            self.currentModel = None
            gc.collect()
        self.currentModelIndex = None

    def __loadModel(self, index: int):
        # Opens shard `index` (no-op if already open) and records which layer
        # names it contains in layerMap.
        if (self.currentModelIndex == index):
            return
        self.__unloadModel()
        filePath = self.config['files'][index]
        fileName = os.path.basename(filePath)
        print(f'💿 Loading file {fileName}...')
        self.currentModel = safe_open(filePath, framework='pt', device='cpu')
        self.currentModelKeys = list(self.currentModel.keys())
        for key in self.currentModelKeys:
            self.layerMap[key] = index
        print(f'Found {len(self.currentModelKeys)} layers')
        self.currentModelIndex = index

    def __transformQ(self, tensor):
        # Q projection needs un-interleaving only for the llama layout.
        if self.archType == ArchType.LLAMA:
            return permute(tensor, self.config['n_heads'], self.config['n_heads'])
        return tensor

    def __transformK(self, tensor):
        # K projection uses the KV head count (grouped-query attention).
        if self.archType == ArchType.LLAMA:
            return permute(tensor, self.config['n_heads'], self.config['n_kv_heads'])
        return tensor

    def __preparePlan(self):
        # Builds the fixed tensor emission order: embeddings, then per-layer
        # attention + MLP (dense or MoE) + norms, then final norm and lm_head.
        # A plan item may list several candidate layer names; the first one
        # found wins (e.g. lm_head falls back to tied embeddings).
        wt = self.config['weights_float_type']
        p = self.plan
        p.append([FloatType.F32,
            'model.embed_tokens.weight'])
        for l in range(0, self.config['n_layers']):
            p.append([wt, self.__transformQ,
                f'model.layers.{l}.self_attn.q_proj.weight'])
            p.append([wt, self.__transformK,
                f'model.layers.{l}.self_attn.k_proj.weight'])
            p.append([wt,
                f'model.layers.{l}.self_attn.v_proj.weight'])
            p.append([wt,
                f'model.layers.{l}.self_attn.o_proj.weight'])

            if (self.config['n_experts'] > 0):
                # MoE: router gate stays in F32, expert weights are quantized.
                p.append([FloatType.F32, f'model.layers.{l}.mlp.gate.weight'])
                for e in range(self.config['n_experts']):
                    p.append([wt,
                        f'model.layers.{l}.mlp.experts.{e}.gate_proj.weight'])
                    p.append([wt,
                        f'model.layers.{l}.mlp.experts.{e}.down_proj.weight'])
                    p.append([wt,
                        f'model.layers.{l}.mlp.experts.{e}.up_proj.weight'])
            else:
                p.append([wt,
                    f'model.layers.{l}.mlp.gate_proj.weight'])
                p.append([wt,
                    f'model.layers.{l}.mlp.down_proj.weight'])
                p.append([wt,
                    f'model.layers.{l}.mlp.up_proj.weight'])

            if (self.archType == ArchType.QWEN3 or self.archType == ArchType.QWEN3_MOE):
                # Qwen3 adds per-head Q/K RMS norms.
                p.append([FloatType.F32,
                    f'model.layers.{l}.self_attn.q_norm.weight'])
                p.append([FloatType.F32,
                    f'model.layers.{l}.self_attn.k_norm.weight'])

            p.append([FloatType.F32,
                f'model.layers.{l}.input_layernorm.weight'])
            p.append([FloatType.F32,
                f'model.layers.{l}.post_attention_layernorm.weight'])
        p.append([FloatType.F32,
            'model.norm.weight'])
        # lm_head may be tied to the embedding matrix; try both names.
        p.append([wt,
            'lm_head.weight', 'model.embed_tokens.weight'])

    def write(self, outputFile: str):
        # Executes the plan: locates each tensor's shard, applies an optional
        # transform, and appends it to outputFile via writeTensor.
        self.__preparePlan()

        # Loading the last model file to get the layer names
        # (shards are scanned lazily; this pre-seeds layerMap with the final
        # shard, where lm_head typically lives).
        self.__loadModel(len(self.config['files']) - 1)
        self.__unloadModel()

        for planItem in self.plan:
            lookup = planItem[1:]
            transform = None
            if (callable(lookup[0])):
                transform = lookup[0]
                lookup = lookup[1:]

            if (self.currentModelIndex == None):
                modelIndex = 0
            else:
                modelIndex = None
                for layerName in lookup:
                    if (layerName in self.layerMap):
                        modelIndex = self.layerMap[layerName]
                        break
                if (modelIndex is None):
                    # Not seen in any scanned shard yet: assume the tensor
                    # lives in the next shard (checkpoints are ordered).
                    modelIndex = self.currentModelIndex + 1
            self.__loadModel(modelIndex)

            tensor = None
            for layerName in lookup:
                if (layerName in self.currentModelKeys):
                    tensor = self.currentModel.get_tensor(layerName)
                    break
            if tensor is None:
                raise Exception(f'Layer {lookup[0]} not found')
            print(f'🔶 Writing tensor {layerName} {tensor.shape}...')

            floatType = planItem[0]
            if (transform):
                tensor = transform(tensor)
            writeTensor(outputFile, tensor, floatType)
|
||||||
|
|
||||||
|
def parseArchType(type: str):
    # Translates a Hugging Face `model_type` string into the converter's
    # ArchType id (mistral shares the llama layout).
    knownArchTypes = {
        'llama': ArchType.LLAMA,
        'mistral': ArchType.LLAMA,
        'qwen3': ArchType.QWEN3,
        'qwen3_moe': ArchType.QWEN3_MOE,
    }
    if type not in knownArchTypes:
        raise Exception(f'Unsupported arch type: {type}')
    return knownArchTypes[type]
|
||||||
|
|
||||||
|
def parseHiddenAct(act: str):
    # Maps a Hugging Face activation name onto the numeric id stored in the
    # model header (0 = gelu, 1 = silu).
    if act == 'gelu':
        return 0
    if act == 'silu':
        return 1
    raise Exception(f'Unsupported hidden act: {act}')
|
||||||
|
|
||||||
|
def parseRopeType(rt: str):
    """Maps a HF `rope_scaling.rope_type` string onto the numeric id used
    in the model header.

    Raises:
        Exception: if the rope type is not supported.
    """
    ropeType = {
        'llama3': 2, # LLAMA3_1
    }.get(rt)
    if (ropeType is None):
        # Bug fix: the message previously interpolated `ropeType` (which is
        # None here) instead of the offending input value `rt`.
        raise Exception(f'Unsupported rope type: {rt}')
    return ropeType
|
||||||
|
|
||||||
|
def parseRmsNormEpsilon(epsilon: float):
    # Encodes the RMS-norm epsilon as its negative decimal exponent
    # (1e-5 -> 5, 1e-6 -> 6); only these two values are supported.
    supported = {1e-05: 5, 1e-06: 6}
    exponent = supported.get(epsilon)
    if exponent is None:
        raise Exception(f'Unsupported epsilon: {epsilon}')
    return exponent
|
||||||
|
|
||||||
|
def loadConfig(folderPath: str, weightsFloatType: int):
    # Builds the distributed-llama header dict from a Hugging Face checkpoint
    # folder: parses config.json and collects all *.safetensors shard paths.
    allFiles = os.listdir(folderPath)
    allFiles.sort()  # deterministic shard order
    with open(os.path.join(folderPath, 'config.json')) as fc:
        config = json.load(fc)
    files = []
    for fileName in allFiles:
        if fileName.endswith('.safetensors') and not fileName.startswith('.'):
            files.append(os.path.join(folderPath, fileName))
    if (len(files) == 0):
        raise Exception('Not found any model file')

    # Required header fields; missing keys raise KeyError by design.
    result = {
        'version': 0,
        'arch_type': parseArchType(config['model_type']),
        'hidden_act': parseHiddenAct(config['hidden_act']),
        'dim': config['hidden_size'],
        'hidden_dim': config['intermediate_size'],
        'n_layers': config['num_hidden_layers'],
        'n_heads': config['num_attention_heads'],
        'n_kv_heads': config['num_key_value_heads'],
        'weights_float_type': weightsFloatType,
        'max_seq_len': config['max_position_embeddings'],
        'vocab_size': config['vocab_size'],
        'files': files,
    }

    # MoE settings default to 0 for dense models.
    nExperts = config.get('num_experts')
    nActiveExperts = config.get('num_experts_per_tok')
    result['n_experts'] = int(nExperts) if nExperts is not None else 0
    result['n_active_experts'] = int(nActiveExperts) if nActiveExperts is not None else 0

    ropeTheta = config.get('rope_theta')
    if (ropeTheta is not None):
        result['rope_theta'] = int(ropeTheta)

    ropeScaling = config.get('rope_scaling')
    if (ropeScaling is not None):
        result['rope_scaling_factor'] = int(ropeScaling['factor'])
        result['rope_scaling_low_freq_factor'] = int(ropeScaling['low_freq_factor'])
        # NOTE(review): 'factory' looks like a typo for 'factor' — confirm
        # what key the header writer expects before renaming it.
        result['rope_scaling_high_freq_factory'] = int(ropeScaling['high_freq_factor'])
        result['rope_scaling_orig_max_seq_len'] = int(ropeScaling['original_max_position_embeddings'])
        result['rope_type'] = parseRopeType(ropeScaling['rope_type'])

    headDim = config.get('head_dim')
    if (headDim is not None):
        result['head_dim'] = headDim

    rmsNormEps = config.get('rms_norm_eps')
    if (rmsNormEps is not None):
        result['norm_epsilon'] = parseRmsNormEpsilon(rmsNormEps)

    moeHiddenDim = config.get('moe_intermediate_size')
    if (moeHiddenDim is not None):
        result['moe_hidden_dim'] = int(moeHiddenDim)
    return result
|
||||||
|
|
||||||
|
def printUsage():
    # Prints command-line usage help for convert-hf.py.
    usageLines = (
        'Usage: python convert-hf.py <sourceFolderPath> <weightsFloatType> <name>',
        '',
        'Options:',
        ' <sourceFolderPath> The path to the folder containing the model files',
        ' <weightsFloatType> The float type of the weights (e.g. "q40")',
        ' <name> The name of the model (e.g. "llama3")',
    )
    for line in usageLines:
        print(line)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
if (len(sys.argv) < 4):
|
||||||
|
printUsage()
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
sourceFolderPath = sys.argv[1]
|
||||||
|
weightsFloatType = parseFloatType(sys.argv[2])
|
||||||
|
name = sys.argv[3]
|
||||||
|
outputFileName = f'dllama_model_{name}_{sys.argv[2]}.m'
|
||||||
|
|
||||||
|
print(f'Output file: {outputFileName}')
|
||||||
|
|
||||||
|
config = loadConfig(sourceFolderPath, weightsFloatType)
|
||||||
|
|
||||||
|
with open(outputFileName, 'wb') as outputFile:
|
||||||
|
writeHeader(outputFile, config)
|
||||||
|
processor = Processor(config)
|
||||||
|
processor.write(outputFile)
|
||||||
|
|
||||||
|
print(f'✅ {outputFileName} created successfully')
|
||||||
121
converter/convert-llama.py
Normal file
121
converter/convert-llama.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import numpy as np
|
||||||
|
from writer import writeTensor, writeHeader, parseFloatType, strFloatType, FloatType
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
LAYER_CHUNK_SIZE = 48
|
||||||
|
|
||||||
|
def convert(modelPath, outputPath, targetFloatType):
|
||||||
|
paramsPath = os.path.join(modelPath, 'params.json')
|
||||||
|
with open(paramsPath) as f:
|
||||||
|
params = json.load(f)
|
||||||
|
if (params['vocab_size'] < 1):
|
||||||
|
raise Exception('vocab_size is invalid, please update params.json file')
|
||||||
|
if (params.get('max_seq_len') is None):
|
||||||
|
raise Exception('max_seq_len is required, please update params.json file')
|
||||||
|
params['n_kv_heads'] = params.get('n_kv_heads') or params['n_heads']
|
||||||
|
params['head_size'] = params['dim'] / params['n_heads']
|
||||||
|
params['arch_type'] = 0xABCD00
|
||||||
|
params['n_experts'] = 0
|
||||||
|
params['n_active_experts'] = 0
|
||||||
|
params['weights_float_type'] = targetFloatType
|
||||||
|
if ('rope_theta' in params):
|
||||||
|
params['rope_theta'] = int(params['rope_theta'])
|
||||||
|
|
||||||
|
modelPaths = sorted(list(Path(modelPath).glob('consolidated.*.pth')))
|
||||||
|
nSlices = len(modelPaths)
|
||||||
|
|
||||||
|
layers = []
|
||||||
|
layers.append('tok_embeddings.weight')
|
||||||
|
for layerIndex in range(0, params['n_layers']):
|
||||||
|
layers.append(f'layers.{layerIndex}.attention.wq.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.attention.wk.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.attention.wv.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.attention.wo.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.feed_forward.w1.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.feed_forward.w2.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.feed_forward.w3.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.attention_norm.weight')
|
||||||
|
layers.append(f'layers.{layerIndex}.ffn_norm.weight')
|
||||||
|
layers.append('norm.weight')
|
||||||
|
layers.append('output.weight')
|
||||||
|
|
||||||
|
isHeaderWrote = False
|
||||||
|
outFile = open(outputPath, 'wb')
|
||||||
|
|
||||||
|
nChunks = math.ceil(len(layers) / LAYER_CHUNK_SIZE)
|
||||||
|
for chunkIndex in range(0, nChunks):
|
||||||
|
chunkLayerNames = layers[LAYER_CHUNK_SIZE * chunkIndex:LAYER_CHUNK_SIZE * (chunkIndex + 1)]
|
||||||
|
models = {}
|
||||||
|
for layerName in chunkLayerNames:
|
||||||
|
models[layerName] = []
|
||||||
|
|
||||||
|
print(f'💿 Chunking model {chunkIndex + 1}/{nChunks}...')
|
||||||
|
|
||||||
|
for modelPath in modelPaths:
|
||||||
|
model = torch.load(modelPath, map_location='cpu')
|
||||||
|
for modelKey in model:
|
||||||
|
if (modelKey in chunkLayerNames):
|
||||||
|
models[modelKey].append(model[modelKey])
|
||||||
|
if not isHeaderWrote:
|
||||||
|
params['hidden_dim'] = model['layers.0.feed_forward.w1.weight'].shape[0] * nSlices
|
||||||
|
writeHeader(outFile, params)
|
||||||
|
isHeaderWrote = True
|
||||||
|
del model
|
||||||
|
|
||||||
|
for layerName in chunkLayerNames:
|
||||||
|
if layerName == 'rope.freqs':
|
||||||
|
continue
|
||||||
|
|
||||||
|
isAxis1 = (
|
||||||
|
layerName == 'tok_embeddings.weight' or
|
||||||
|
layerName.endswith('.attention.wo.weight') or
|
||||||
|
layerName.endswith('.feed_forward.w2.weight')
|
||||||
|
)
|
||||||
|
isAlwaysF32 = (
|
||||||
|
layerName == 'tok_embeddings.weight' or
|
||||||
|
layerName.endswith('.attention_norm.weight') or
|
||||||
|
layerName.endswith('.ffn_norm.weight') or
|
||||||
|
layerName == 'norm.weight'
|
||||||
|
)
|
||||||
|
floatType = FloatType.F32 if isAlwaysF32 else targetFloatType
|
||||||
|
|
||||||
|
tensors = models[layerName]
|
||||||
|
if len(tensors) == 1 or len(tensors[0].shape) == 1:
|
||||||
|
tensor = tensors[0]
|
||||||
|
else:
|
||||||
|
tensor = torch.cat(tensors, dim=(1 if isAxis1 else 0))
|
||||||
|
|
||||||
|
print(f'🔶 Exporting {layerName} {tensor.shape}...')
|
||||||
|
writeTensor(outFile, tensor, floatType)
|
||||||
|
|
||||||
|
del models
|
||||||
|
|
||||||
|
outFile.close()
|
||||||
|
|
||||||
|
def usage():
|
||||||
|
print('Usage: python convert-llama.py <modelPath> <targetFloatType>')
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
if (len(sys.argv) < 3):
|
||||||
|
usage()
|
||||||
|
|
||||||
|
modelPath = sys.argv[1]
|
||||||
|
targetFloatType = parseFloatType(sys.argv[2])
|
||||||
|
targetFloatTypeStr = strFloatType(targetFloatType)
|
||||||
|
|
||||||
|
modelName = os.path.basename(modelPath)
|
||||||
|
outputFileName = f'dllama_model_{modelName.lower()}_{targetFloatTypeStr}.m'
|
||||||
|
|
||||||
|
print(f'Model name: {modelName}')
|
||||||
|
print(f'Target float type: {targetFloatTypeStr}')
|
||||||
|
print(f'Target file: {outputFileName}')
|
||||||
|
|
||||||
|
convert(modelPath, outputFileName, targetFloatType)
|
||||||
|
|
||||||
|
print('Done!')
|
||||||
137
converter/convert-tokenizer-hf.py
Normal file
137
converter/convert-tokenizer-hf.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
from transformers import PreTrainedTokenizerFast
|
||||||
|
writer = __import__('tokenizer-writer')
|
||||||
|
|
||||||
|
def openJson(path):
    # Reads and parses a UTF-8 encoded JSON file.
    with open(path, 'r', encoding='utf-8') as handle:
        content = handle.read()
    return json.loads(content)
|
||||||
|
|
||||||
|
def unicodeToBytes():
    # Inverse of GPT-2's bytes_to_unicode table: maps each printable stand-in
    # character back to the raw byte value it represents.
    # https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/encoder.py#L9
    printable = (
        list(range(ord("!"), ord("~") + 1)) +
        list(range(ord("¡"), ord("¬") + 1)) +
        list(range(ord("®"), ord("ÿ") + 1))
    )
    byteValues = list(printable)
    charCodes = list(printable)
    nextCode = 2 ** 8
    for b in range(2 ** 8):
        if b not in byteValues:
            # Non-printable bytes are represented by codepoints >= 256.
            byteValues.append(b)
            charCodes.append(nextCode)
            nextCode += 1
    return {chr(code): value for code, value in zip(charCodes, byteValues)}
|
||||||
|
|
||||||
|
class TokensResolver:
    """Extracts the vocabulary (token bytes + scores) and the special token
    ids (bos/eos) from a Hugging Face tokenizer folder. Supports both fast
    tokenizers (tokenizer.json) and SentencePiece models (tokenizer.model)."""

    def __init__(self, dirPath, tokenizerConfig):
        self.dirPath = dirPath
        self.tokenizerConfig = tokenizerConfig
        self.bosId = None
        self.eosIds = None  # list of end-of-sequence token ids
        self.tokens = []    # token byte strings, indexed by token id
        self.scores = []    # per-token scores (merge priority / sp score)

    def resolvePreTrainedTokenizerFast(self):
        # Decodes the byte-level vocabulary of a fast tokenizer back into
        # raw bytes using the GPT-2 unicode->byte table.
        utb = unicodeToBytes()
        tokenizer = PreTrainedTokenizerFast(tokenizer_file = os.path.join(self.dirPath, 'tokenizer.json'))
        vocabLen = len(tokenizer.get_vocab())
        for i in range(vocabLen):
            tokenChars = list(tokenizer.convert_ids_to_tokens([i])[0])
            tokenBytes = []
            for ch in tokenChars:
                if (ch in utb):
                    tokenBytes.append(utb[ch])
                else:
                    # Special tokens are not byte-encoded; keep their UTF-8.
                    tokenBytes += list(ch.encode('utf-8'))
            self.tokens.append(bytes(tokenBytes))
            self.scores.append(-float(i))  # lower id -> higher priority

        self.bosId = tokenizer.bos_token_id
        if (tokenizer.eos_token_id):
            self.eosIds = [tokenizer.eos_token_id]
        # Bug fix: this condition previously read `self.eosId`, an attribute
        # that does not exist (the field is `eosIds`), raising AttributeError
        # whenever bosId was already resolved.
        if (self.bosId is None or self.eosIds is None):
            # Fall back to the model's config.json for missing special ids.
            config = openJson(os.path.join(self.dirPath, 'config.json'))
            if (self.bosId is None):
                self.bosId = config['bos_token_id']
            if (self.eosIds is None):
                self.eosIds = config['eos_token_id']
                # config.json may hold a single id or a list; normalize to list.
                if not isinstance(self.eosIds, list):
                    self.eosIds = [self.eosIds]

    def resolveLlamaTokenizer(self):
        # Reads a SentencePiece model and converts pieces to raw bytes.
        modelPath = os.path.join(self.dirPath, 'tokenizer.model')
        processor = SentencePieceProcessor(model_file=modelPath)

        assert processor.vocab_size() == processor.get_piece_size()
        self.bosId = processor.bos_id()
        self.eosIds = [processor.eos_id()]
        vocabSize = processor.vocab_size()
        for i in range(vocabSize):
            t = processor.id_to_piece(i)
            s = processor.get_score(i)
            t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
            # Check for byte characters
            if len(t) == 6 and t.startswith('<0x') and t.endswith('>'):
                # For example, "<0x0A>"" is a newline character
                b = bytearray.fromhex(t[3:-1])
            else:
                b = t.encode('utf-8')
            self.tokens.append(b)
            self.scores.append(s)

    def resolve(self):
        # Dispatches to the right resolver based on tokenizer_class.
        cls = self.tokenizerConfig['tokenizer_class']
        if (cls == 'PreTrainedTokenizerFast' or
            cls == 'LlamaTokenizerFast' or
            cls == 'Qwen2Tokenizer'):
            return self.resolvePreTrainedTokenizerFast()
        if (cls == 'LlamaTokenizer'):
            return self.resolveLlamaTokenizer()
        raise Exception(f'Tokenizer {cls} is not supported')
|
||||||
|
|
||||||
|
def printUsage():
|
||||||
|
print('Usage: python convert-tokenizer-hf.py <tokenizerFolderPath> <name>')
|
||||||
|
print()
|
||||||
|
print('Options:')
|
||||||
|
print(' <tokenizerFolderPath> The path to the folder with tokenizer_config.json')
|
||||||
|
print(' <name> The name of the tokenizer (e.g. "llama3")')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Entry point: converts a HF tokenizer folder into a .t tokenizer file.
    # Bug fix: the argument check previously allowed a single argument
    # (`< 2`) even though sys.argv[2] is read below, crashing with IndexError.
    if (len(sys.argv) < 3):
        printUsage()
        exit(1)

    dirPath = sys.argv[1]
    name = sys.argv[2]
    tokenizerConfig = openJson(os.path.join(dirPath, 'tokenizer_config.json'))

    resolver = TokensResolver(dirPath, tokenizerConfig)
    resolver.resolve()

    if (resolver.bosId is None or resolver.eosIds is None):
        raise Exception('Cannot resolve bosId or eosIds')
    print(f'bosId: {resolver.bosId} ({resolver.tokens[resolver.bosId]})')
    for eosId in resolver.eosIds:
        print(f'eosId: {eosId} ({resolver.tokens[eosId]})')

    # Optional chat template, stored as UTF-8 bytes in the output file.
    chatTemplate = None
    if ('chat_template' in tokenizerConfig):
        chatTemplate = tokenizerConfig['chat_template'].encode('utf-8')

    # Whether a BOS token should be prepended at inference time.
    addBos = True
    if ('add_bos_token' in tokenizerConfig):
        addBos = tokenizerConfig['add_bos_token']

    outputFileName = f'dllama_tokenizer_{name}.t'
    with open(outputFileName, 'wb') as outputFile:
        writer.writeTokenizer(
            outputFile,
            resolver.tokens,
            resolver.scores,
            chatTemplate,
            resolver.bosId,
            addBos,
            resolver.eosIds)
    print(f'✅ Created {outputFileName}')
|
||||||
44
converter/convert-tokenizer-llama2.py
Normal file
44
converter/convert-tokenizer-llama2.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
writer = __import__('tokenizer-writer')
|
||||||
|
|
||||||
|
chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
|
||||||
|
|
||||||
|
def printUsage():
|
||||||
|
print('Usage: python convert-tokenizer-llama2.py <llama2FolderPath>')
|
||||||
|
print()
|
||||||
|
print('Options:')
|
||||||
|
print(' <llama2FolderPath> The path to the folder with llama2 folder path')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Entry point: converts a Llama 2 SentencePiece tokenizer into a .t file.
    if (len(sys.argv) < 2):
        printUsage()
        exit(1)

    dirPath = sys.argv[1]
    modelPath = os.path.join(dirPath, 'tokenizer.model')
    processor = SentencePieceProcessor(model_file=modelPath)

    vocabSize = processor.vocab_size()
    tokens = []
    scores = []
    for i in range(vocabSize):
        t = processor.id_to_piece(i)
        s = processor.get_score(i)
        t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
        b = t.encode('utf-8')
        tokens.append(b)
        scores.append(s)

    outputFileName = 'dllama_tokenizer_llama2.t'
    with open(outputFileName, 'wb') as outputFile:
        # Bug fix: writeTokenizer takes (file, tokens, scores, chatTemplate,
        # bosId, addBos, eosTokens); the addBos argument was missing, which
        # made this call raise a TypeError. Llama 2 prepends BOS by default.
        writer.writeTokenizer(
            outputFile,
            tokens,
            scores,
            chatTemplate.encode('utf-8'),
            processor.bos_id(),
            True,
            [processor.eos_id()])

    print(f'✅ Created {outputFileName}')
|
||||||
78
converter/convert-tokenizer-llama3.py
Normal file
78
converter/convert-tokenizer-llama3.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
import sys
|
||||||
|
import base64
|
||||||
|
writer = __import__('tokenizer-writer')
|
||||||
|
|
||||||
|
# Format of input file:
|
||||||
|
# ```
|
||||||
|
# IQ== 0
|
||||||
|
# Ig== 1
|
||||||
|
# Iw== 2
|
||||||
|
# ...
|
||||||
|
# ```
|
||||||
|
|
||||||
|
nSpecialTokens = 256
|
||||||
|
specialTokens = [
|
||||||
|
'<|begin_of_text|>',
|
||||||
|
'<|end_of_text|>',
|
||||||
|
'<|reserved_special_token_0|>',
|
||||||
|
'<|reserved_special_token_1|>',
|
||||||
|
'<|reserved_special_token_2|>',
|
||||||
|
'<|reserved_special_token_3|>',
|
||||||
|
'<|start_header_id|>',
|
||||||
|
'<|end_header_id|>',
|
||||||
|
'<|reserved_special_token_4|>',
|
||||||
|
'<|eot_id|>',
|
||||||
|
] + [
|
||||||
|
f'<|reserved_special_token_{i}|>'
|
||||||
|
for i in range(5, nSpecialTokens - 5)
|
||||||
|
]
|
||||||
|
bosId = 128000
|
||||||
|
eosId = 128001
|
||||||
|
chatEosId = 128009
|
||||||
|
chatTemplate = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"
|
||||||
|
|
||||||
|
def printUsage():
|
||||||
|
print('Usage: python convert-tokenizer-llama3.py <tokenizerPath>')
|
||||||
|
print()
|
||||||
|
print('Options:')
|
||||||
|
print(' <tokenizerPath> The path to the Llama 3 tokenizer model (tokenizer.model)')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Entry point: converts the Llama 3 tokenizer.model (base64 token + rank
    # per line) into a .t file, appending the 256 special tokens at the end.
    if (len(sys.argv) < 2):
        printUsage()
        exit(1)

    modelPath = sys.argv[1]
    outputFileName = 'dllama_tokenizer_llama3.t'

    with open(modelPath, 'r') as inputFile:
        with open(outputFileName, 'wb') as outputFile:
            inputLines = inputFile.readlines()
            nLines = len(inputLines)

            tokens = []
            scores = []
            for line in inputLines:
                s = line.split(' ')
                # Avoid shadowing the `bytes` builtin (the original did).
                tokenBytes = base64.b64decode(s[0])
                score = -float(s[1])
                tokens.append(tokenBytes)
                scores.append(score)

            # Special tokens get ids right after the regular vocabulary.
            specialTokenIndex = nLines
            for token in specialTokens:
                tokens.append(token.encode('utf-8'))
                scores.append(-float(specialTokenIndex))
                specialTokenIndex += 1

            # Bug fix: writeTokenizer takes (file, tokens, scores,
            # chatTemplate, bosId, addBos, eosTokens); the addBos argument
            # was missing, which made this call raise a TypeError. Llama 3
            # prepends BOS by default.
            writer.writeTokenizer(
                outputFile,
                tokens,
                scores,
                chatTemplate.encode('utf-8'),
                bosId,
                True,
                [eosId, chatEosId])

    print(f'✅ Created {outputFileName}')
|
||||||
5
converter/requirements.txt
Normal file
5
converter/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Requires Python >= 3.9 ("python" is not an installable pip package)
|
||||||
|
numpy==1.23.5
|
||||||
|
torch==2.0.1
|
||||||
|
safetensors==0.4.2
|
||||||
|
sentencepiece==0.1.99
|
||||||
57
converter/tokenizer-writer.py
Normal file
57
converter/tokenizer-writer.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import struct
|
||||||
|
|
||||||
|
def writeTokenizer(file, tokens, scores, chatTemplate, bosId, addBos, eosTokens):
    # Serializes a tokenizer into the distributed-llama binary format:
    #   magic (int32), total header size (int32), key/value int32 pairs,
    #   chat template bytes, eos token ids (int32 each), then per-token
    #   records of (float32 score, uint32 byte length, raw token bytes).
    headerKeys = {
        'version': 0,
        'vocab_size': 1,
        'max_token_length': 2,
        'bos_id': 3,
        'chat_template': 7,
        'n_eos_tokens': 9,
        'add_bos': 10,
    }
    header = struct.pack('i', 0x567124)  # format magic number

    nTokens = len(tokens)
    maxTokenLength = max(len(t) for t in tokens)

    params = {}
    params['bos_id'] = bosId
    params['version'] = 1
    params['vocab_size'] = nTokens
    params['max_token_length'] = maxTokenLength
    if (chatTemplate):
        # Stored as the template's byte length; the bytes follow the header.
        params['chat_template'] = len(chatTemplate)
    params['n_eos_tokens'] = len(eosTokens)
    params['add_bos'] = 1 if addBos else 0

    data = b''
    for key in params:
        value = params[key]
        if value is None:
            continue  # e.g. bos_id may be unresolved; skip the pair entirely
        if key in headerKeys:
            data += struct.pack('ii', headerKeys[key], params[key])
        else:
            print(f'Unknown header key: {key}')

    print('⭐ Params:')
    print(params)
    if (chatTemplate):
        print('⭐ Chat template:')
        print(chatTemplate)

    # Total header size = magic (4B) + this size field (4B) + key/value data.
    # `len(header) * 2` equals 8 only because header currently holds just the
    # 4-byte magic — fragile if anything is packed before this line.
    header += struct.pack('i', len(header) * 2 + len(data))
    file.write(header)
    file.write(data)
    if chatTemplate:
        file.write(chatTemplate)

    for eosToken in eosTokens:
        file.write(struct.pack('i', eosToken))

    for i in range(0, nTokens):
        size = len(tokens[i])
        assert(size > 0)
        # per-token record: float32 score, uint32 length, raw bytes
        file.write(struct.pack('fI', scores[i], size))
        file.write(tokens[i])
|
||||||
35
converter/writer-test.py
Normal file
35
converter/writer-test.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
from writer import writeQuantizedQ40Tensor
|
||||||
|
|
||||||
|
TEMP_FILE_NAME = 'writer-test.temp'
|
||||||
|
|
||||||
|
def readBase64FromFile(path):
    # NOTE: despite the name, this returns the file content as a HEX string
    # (bytes.hex()), not base64 — the golden value in the test is hex.
    with open(path, 'rb') as file:
        rawContent = file.read()
    return rawContent.hex()
|
||||||
|
|
||||||
|
def testWriteQuantizedQ40Tensor():
    # Regression test: quantizing a fixed seeded random tensor must produce
    # exactly these bytes (hex-encoded).
    EXPECTED_OUTPUT = '7e346345a692b89665b2c5790537876e598aaa366d988876a898b8d788a98868ce660c66f6b3a88cba5ce9a871987ba9cc5bcaaa760c1eb556a4455b747b6b9504968828ef2a8d7c1db5c6be3764799e66db6d8e76463126a30e4333cad7a4f645947c6cf97f9de086d468c8d535a6ba7dc799d3d0c657bab6799468cad8bb349eb7d7635c7c798998696bb38e4085a9eb34444ba96a7f8ba7b2b42d746a96cf9660aeb4499d8708ad5c7b9a7558947645f3bbb6b0346a656887ad9a86059baac5c596ab781c703569bb8a4356a4bd58cb78736ba09759bb0e34a6274e827b957d7a67dfa86846955660d234b6d9d78a378094a8a8708a7a774ae92f8a36b8c999a9b77a7d958a69747c807963941235379886d69a7a8767b3a6a4ac71999760'

    torch.manual_seed(seed=1)  # fixed seed -> deterministic input tensor
    tensor = torch.randn(32, 16)

    with open(TEMP_FILE_NAME, 'wb') as file:
        writeQuantizedQ40Tensor(file, tensor)

    # readBase64FromFile actually returns a hex string, despite its name.
    contentBase64 = readBase64FromFile(TEMP_FILE_NAME)
    assert contentBase64 == EXPECTED_OUTPUT, f'Received: {contentBase64}'
    print('✅ writeQuantizedQ40Tensor')
|
||||||
|
|
||||||
|
def runWriteQuantizedQ40TensorBenchmark():
    # Rough wall-clock timing of the Q4_0 quantizer on an 8192x4096 tensor.
    tensor = torch.randn(8192, 4096)
    t0 = time.time()
    with open(TEMP_FILE_NAME, 'wb') as file:
        writeQuantizedQ40Tensor(file, tensor)
    t1 = time.time()
    print(f'🕐 writeQuantizedQ40Tensor: {t1 - t0:.4f}s')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
testWriteQuantizedQ40Tensor()
|
||||||
|
runWriteQuantizedQ40TensorBenchmark()
|
||||||
148
converter/writer.py
Normal file
148
converter/writer.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
import struct
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class FloatType:
|
||||||
|
F32 = 0
|
||||||
|
F16 = 1
|
||||||
|
Q40 = 2
|
||||||
|
Q80 = 3
|
||||||
|
|
||||||
|
floatTypeMap = {
|
||||||
|
'f32': FloatType.F32,
|
||||||
|
'f16': FloatType.F16,
|
||||||
|
'q40': FloatType.Q40,
|
||||||
|
'q80': FloatType.Q80,
|
||||||
|
}
|
||||||
|
floatTypeNames = list(floatTypeMap.keys())
|
||||||
|
|
||||||
|
def parseFloatType(type):
    """Resolve a float-type name (e.g. 'q40') to its FloatType identifier.

    Raises Exception for a name that is not in floatTypeMap.
    """
    if type in floatTypeMap:
        return floatTypeMap[type]
    raise Exception(f'{type} is not supported')
|
||||||
|
|
||||||
|
def strFloatType(type):
    """Return the name (e.g. 'q40') for a FloatType identifier."""
    # Reverse of parseFloatType; key order in floatTypeMap matches the ids 0..3.
    return list(floatTypeMap.keys())[type]
|
||||||
|
|
||||||
|
def writeQuantizedQ40Tensor(file, x):
    """Quantize a flat tensor to Q40 and write it to `file`.

    Each block of 32 values is stored as one float16 scale followed by
    16 bytes holding 32 packed 4-bit values. Returns the number of bytes
    written. The tensor length must be a multiple of 32.
    """
    values = x.to(torch.float32).numpy().astype(np.float32)
    BLOCK_SIZE = 32
    HALF_BLOCK = BLOCK_SIZE // 2
    assert values.shape[0] % BLOCK_SIZE == 0
    blocks = values.reshape(-1, BLOCK_SIZE)
    # Scale derives from the largest-magnitude value per block; its sign is
    # preserved so that dividing maps that extreme value onto -8.
    maxs = np.max(blocks, axis=1)
    mins = np.min(blocks, axis=1)
    scales = np.divide(np.where(-mins > maxs, mins, maxs), -8)
    scales16 = scales.astype(np.float16)
    invScales = np.where(scales != 0, 1.0 / scales, 0)
    # Shift by 8.5 so that clip + integer truncation rounds into [0, 15].
    quantized = np.add(blocks * invScales[:, np.newaxis], 8.5)
    quantized = np.clip(quantized, 0, 15).astype(int)

    # Pack two 4-bit values per byte: value i in the low nibble,
    # value i + 16 in the high nibble.
    low = quantized[:, :HALF_BLOCK] & 0xF
    high = (quantized[:, HALF_BLOCK:] & 0xF) << 4
    packed = low | high

    totalBytes = 0
    for i, row in enumerate(packed):
        chunk = struct.pack(f'e{HALF_BLOCK}B', scales16[i], *row)
        file.write(chunk)
        totalBytes += len(chunk)
    return totalBytes
|
||||||
|
|
||||||
|
def writeQuantizedQ80Tensor(file, x):
    """Quantize a flat tensor to Q80 and write it to `file`.

    Each block of 32 values is stored as one float16 scale followed by
    32 signed bytes. Returns the number of bytes written. The tensor
    length must be a multiple of 32.
    """
    values = x.to(torch.float32).numpy().astype(np.float32)
    BLOCK_SIZE = 32
    assert values.shape[0] % BLOCK_SIZE == 0
    blocks = values.reshape(-1, BLOCK_SIZE)
    maxs = np.max(blocks, axis=1)
    mins = np.min(blocks, axis=1)
    # Scale so the largest-magnitude value in each block maps to 127.
    absMax = np.where(-mins > maxs, -mins, maxs)
    scales = absMax / ((1 << 7) - 1)
    scales16 = scales.astype(np.float16)
    invScales = np.where(scales != 0, 1.0 / scales, 0)
    quantized8 = np.round(blocks * invScales[:, np.newaxis]).astype(np.int8)

    totalBytes = 0
    for i, row in enumerate(quantized8):
        chunk = struct.pack(f'e{BLOCK_SIZE}b', scales16[i], *row)
        file.write(chunk)
        totalBytes += len(chunk)
    return totalBytes
|
||||||
|
|
||||||
|
def writeF32Tensor(file, d):
    """Write a flat tensor to `file` as raw float32 values.

    Converts and packs in chunks of 10000 elements to bound peak memory
    use. Returns the number of bytes written.
    """
    CHUNK = 10000
    total = 0
    offset = 0
    while offset < len(d):
        part = d[offset:offset + CHUNK].to(torch.float32).numpy().astype(np.float32)
        packed = struct.pack(f'{len(part)}f', *part)
        file.write(packed)
        total += len(packed)
        offset += CHUNK
    return total
|
||||||
|
|
||||||
|
def writeF16Tensor(file, d):
    """Write a flat tensor to `file` as raw float16 values.

    Fix: the previous implementation converted the whole tensor and passed
    every element as an argument to a single struct.pack call, which
    materializes O(len(d)) temporary Python objects at once — prohibitive
    for multi-billion-element weight tensors and inconsistent with the
    chunked writeF32Tensor. The output bytes are unchanged.

    Returns the number of bytes written.
    """
    chunkSize = 10000
    nBytes = 0
    for i in range(0, len(d), chunkSize):
        chunk = d[i:i + chunkSize].to(torch.float16).numpy().astype(np.float16)
        b = struct.pack(f'{len(chunk)}e', *chunk)
        nBytes += len(b)
        file.write(b)
    return nBytes
|
||||||
|
|
||||||
|
def writeTensor(file, tensor, floatType):
    """Serialize `tensor` to `file` in the given FloatType format and log timing.

    Args:
        file: binary file-like object opened for writing.
        tensor: torch tensor; it is detached, moved to CPU and flattened first.
        floatType: one of the FloatType constants (F32, F16, Q40, Q80).

    Raises:
        Exception: if floatType is not a known FloatType value.
    """
    d = tensor.detach().cpu().view(-1)
    t0 = time.time()
    if floatType == FloatType.F16:
        nBytes = writeF16Tensor(file, d)
    elif floatType == FloatType.F32:
        nBytes = writeF32Tensor(file, d)
    elif floatType == FloatType.Q40:
        nBytes = writeQuantizedQ40Tensor(file, d)
    elif floatType == FloatType.Q80:
        nBytes = writeQuantizedQ80Tensor(file, d)
    else:
        # Fix: the original f-string had no placeholder, so the offending
        # value never appeared in the error message.
        raise Exception(f'Unknown float type: {floatType}')
    t1 = time.time()
    print(f'Saved {strFloatType(floatType)} tensor in {t1 - t0:.2f}s, {nBytes} bytes')
|
||||||
|
|
||||||
|
def writeHeader(file, params):
    """Write the dllama model header to `file`.

    Layout: int32 magic number, int32 total header size (magic + size field
    + payload), then one (key-id, value) int32 pair per recognized param.
    Unrecognized keys are skipped with a warning. All params are echoed to
    stdout afterwards.
    """
    headerKeys = {
        'version': 0,
        'arch_type': 1,
        'dim': 2,
        'hidden_dim': 3,
        'n_layers': 4,
        'n_heads': 5,
        'n_kv_heads': 6,
        'n_experts': 7,
        'n_active_experts': 8,
        'vocab_size': 9,
        'max_seq_len': 10,
        'hidden_act': 11,
        'rope_theta': 12,
        'weights_float_type': 13,
        'rope_scaling_factor': 14,
        'rope_scaling_low_freq_factor': 15,
        # NOTE(review): 'factory' looks like a typo for 'factor', but this key
        # must match what the converters put in `params`, so it is kept as-is.
        'rope_scaling_high_freq_factory': 16,
        'rope_scaling_orig_max_seq_len': 17,
        'rope_type': 18,
        'head_dim': 19,
        'norm_epsilon': 20,
        'moe_hidden_dim': 21,
    }
    pairs = b''
    for name, value in params.items():
        keyId = headerKeys.get(name)
        if keyId is None:
            print(f'Warning: Unknown header key: {name}')
        else:
            pairs += struct.pack('ii', keyId, value)

    magic = struct.pack('i', 0xA00ABCD)
    # Total size = magic (4 bytes) + the size field itself (4 bytes) + pairs.
    sizeField = struct.pack('i', len(magic) * 2 + len(pairs))
    file.write(magic + sizeField)
    file.write(pairs)

    for name, value in params.items():
        print(f'🎓 {name}: {value}')
    print()
|
||||||
81
docker-compose.yml
Normal file
81
docker-compose.yml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
# Controller service - downloads models and runs API
|
||||||
|
controller:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.controller
|
||||||
|
ports:
|
||||||
|
- "9999:9999"
|
||||||
|
volumes:
|
||||||
|
- ./models:/app/models
|
||||||
|
networks:
|
||||||
|
distributed-llama:
|
||||||
|
ipv4_address: 172.20.0.10
|
||||||
|
environment:
|
||||||
|
- MODEL_NAME=${MODEL_NAME:-llama3_2_3b_instruct_q40}
|
||||||
|
- NTHREADS=${CONTROLLER_NTHREADS:-4}
|
||||||
|
- MAX_SEQ_LEN=${MAX_SEQ_LEN:-4096}
|
||||||
|
- BUFFER_FLOAT_TYPE=${BUFFER_FLOAT_TYPE:-q80}
|
||||||
|
command: >
|
||||||
|
--model ${MODEL_NAME:-llama3_2_3b_instruct_q40}
|
||||||
|
--port 9999
|
||||||
|
--nthreads ${CONTROLLER_NTHREADS:-4}
|
||||||
|
--max-seq-len ${MAX_SEQ_LEN:-4096}
|
||||||
|
--buffer-float-type ${BUFFER_FLOAT_TYPE:-q80}
|
||||||
|
--workers 172.20.0.11:9999 172.20.0.12:9999 172.20.0.13:9999
|
||||||
|
depends_on:
|
||||||
|
- worker1
|
||||||
|
- worker2
|
||||||
|
- worker3
|
||||||
|
|
||||||
|
# Worker services
|
||||||
|
worker1:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.worker
|
||||||
|
networks:
|
||||||
|
distributed-llama:
|
||||||
|
ipv4_address: 172.20.0.11
|
||||||
|
environment:
|
||||||
|
- NTHREADS=${WORKER_NTHREADS:-4}
|
||||||
|
command: >
|
||||||
|
--port 9999
|
||||||
|
--nthreads ${WORKER_NTHREADS:-4}
|
||||||
|
|
||||||
|
worker2:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.worker
|
||||||
|
networks:
|
||||||
|
distributed-llama:
|
||||||
|
ipv4_address: 172.20.0.12
|
||||||
|
environment:
|
||||||
|
- NTHREADS=${WORKER_NTHREADS:-4}
|
||||||
|
command: >
|
||||||
|
--port 9999
|
||||||
|
--nthreads ${WORKER_NTHREADS:-4}
|
||||||
|
|
||||||
|
worker3:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile.worker
|
||||||
|
networks:
|
||||||
|
distributed-llama:
|
||||||
|
ipv4_address: 172.20.0.13
|
||||||
|
environment:
|
||||||
|
- NTHREADS=${WORKER_NTHREADS:-4}
|
||||||
|
command: >
|
||||||
|
--port 9999
|
||||||
|
--nthreads ${WORKER_NTHREADS:-4}
|
||||||
|
|
||||||
|
networks:
|
||||||
|
distributed-llama:
|
||||||
|
driver: bridge
|
||||||
|
ipam:
|
||||||
|
config:
|
||||||
|
- subnet: 172.20.0.0/16
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
models:
|
||||||
32
docs/HOW_TO_CONVERT_HF_MODEL.md
Normal file
32
docs/HOW_TO_CONVERT_HF_MODEL.md
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
# How to Convert 🤗 Hugging Face Model
|
||||||
|
|
||||||
|
Currently, Distributed Llama supports these Hugging Face models: `llama`, `mistral`, `qwen3` and `qwen3_moe`. You can try to convert any compatible Hugging Face model and run it with Distributed Llama.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> All converters are in the early stages of development. After conversion, the model may not work correctly.
|
||||||
|
|
||||||
|
1. Download a model, for example: [Mistral-7B-v0.3](https://huggingface.co/mistralai/Mistral-7B-v0.3/tree/main).
|
||||||
|
2. The downloaded model should contain `config.json`, `tokenizer.json`, `tokenizer_config.json`, `tokenizer.model`, and the safetensors files.
|
||||||
|
3. Run the converter of the model:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd converter
|
||||||
|
python convert-hf.py path/to/hf/model q40 mistral-7b-0.3
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run the converter of the tokenizer:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python convert-tokenizer-hf.py path/to/hf/model mistral-7b-0.3
|
||||||
|
```
|
||||||
|
|
||||||
|
5. That's it! Now you can run the Distributed Llama.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./dllama inference \
|
||||||
|
--prompt "Hello world" \
|
||||||
|
--steps 64 \
|
||||||
|
--model dllama_model_mistral-7b-0.3_q40.m \
|
||||||
|
--tokenizer dllama_tokenizer_mistral-7b-0.3.t \
|
||||||
|
--buffer-float-type q80
|
||||||
|
```
|
||||||
34
docs/HOW_TO_RUN_GPU.md
Normal file
34
docs/HOW_TO_RUN_GPU.md
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
# How to Run Distributed Llama on 🧠 GPU
|
||||||
|
|
||||||
|
Distributed Llama can run on GPU devices using Vulkan API. This article describes how to build and run the project on GPU.
|
||||||
|
|
||||||
|
Before you start here, please check how to build and run Distributed Llama on CPU:
|
||||||
|
* [🍓 How to Run on Raspberry Pi](./HOW_TO_RUN_RASPBERRYPI.md)
|
||||||
|
* [💻 How to Run on Linux, MacOS or Windows](./HOW_TO_RUN_LINUX_MACOS_WIN.md)
|
||||||
|
|
||||||
|
To run on GPU, please follow these steps:
|
||||||
|
|
||||||
|
1. Install Vulkan SDK for your platform.
|
||||||
|
* Linux: please check [this article](https://vulkan.lunarg.com/doc/view/latest/linux/getting_started_ubuntu.html).
|
||||||
|
* MacOS: download SDK [here](https://vulkan.lunarg.com/sdk/home#mac).
|
||||||
|
2. Build Distributed Llama with GPU support:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
DLLAMA_VULKAN=1 make dllama
|
||||||
|
DLLAMA_VULKAN=1 make dllama-api
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Now the `dllama` and `dllama-api` binaries support arguments related to GPU usage.
|
||||||
|
|
||||||
|
```
|
||||||
|
--gpu-index <index> Use GPU device with given index (use `0` for first device)
|
||||||
|
```
|
||||||
|
|
||||||
|
4. You can run the root node or worker node on GPU by specifying the `--gpu-index` argument. Vulkan backend requires single thread, so you should also set `--nthreads 1`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./dllama inference ... --nthreads 1 --gpu-index 0
|
||||||
|
./dllama chat ... --nthreads 1 --gpu-index 0
|
||||||
|
./dllama worker ... --nthreads 1 --gpu-index 0
|
||||||
|
./dllama-api ... --nthreads 1 --gpu-index 0
|
||||||
|
```
|
||||||
89
docs/HOW_TO_RUN_LINUX_MACOS_WIN.md
Normal file
89
docs/HOW_TO_RUN_LINUX_MACOS_WIN.md
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# How to Run Distributed Llama on 💻 Linux, MacOS or Windows
|
||||||
|
|
||||||
|
This article describes how to run Distributed Llama on 4 devices, but you can also run it on 1, 2, 4, 8... devices. Please adjust the commands and topology according to your configuration.
|
||||||
|
|
||||||
|
````
|
||||||
|
[🔀 SWITCH OR ROUTER]
|
||||||
|
| | | |
|
||||||
|
| | | |_______ 🔸 device1 (ROOT) 10.0.0.1
|
||||||
|
| | |_________ 🔹 device2 (WORKER 1) 10.0.0.2:9999
|
||||||
|
| |___________ 🔹 device3 (WORKER 2) 10.0.0.3:9999
|
||||||
|
|_____________ 🔹 device4 (WORKER 3) 10.0.0.4:9999
|
||||||
|
````
|
||||||
|
|
||||||
|
1. Install Git and C++ compiler on **🔸🔹 ALL** devices:
|
||||||
|
|
||||||
|
* Linux:
|
||||||
|
```
|
||||||
|
sudo apt install git build-essential
|
||||||
|
```
|
||||||
|
* MacOS
|
||||||
|
```
|
||||||
|
brew install git
|
||||||
|
```
|
||||||
|
* Windows
|
||||||
|
|
||||||
|
Install Git and Mingw (via [Chocolatey](https://chocolatey.org/install)):
|
||||||
|
```powershell
|
||||||
|
choco install git mingw
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Connect **🔸🔹 ALL** devices to your **🔀 SWITCH OR ROUTER** via Ethernet cable. If you're using only two devices, it's better to connect them directly without a switch.
|
||||||
|
|
||||||
|
3. Clone this repository and compile Distributed Llama on **🔸🔹 ALL** devices:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone https://github.com/b4rtaz/distributed-llama.git
|
||||||
|
cd distributed-llama
|
||||||
|
make dllama
|
||||||
|
make dllama-api
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Download the model to the **🔸 ROOT** device using the `launch.py` script. You don't need to download the model on worker devices.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python3 launch.py # Prints a list of available models
|
||||||
|
|
||||||
|
python3 launch.py llama3_2_3b_instruct_q40 # Downloads the model to the root device
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Start workers on all **🔹 WORKER** devices:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./dllama worker --port 9999 --nthreads 4
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Run the inference to test if everything works fine on the **🔸 ROOT** device:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./dllama inference \
|
||||||
|
--prompt "Hello world" \
|
||||||
|
--steps 32 \
|
||||||
|
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
|
||||||
|
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
|
||||||
|
--buffer-float-type q80 \
|
||||||
|
--nthreads 4 \
|
||||||
|
--max-seq-len 4096 \
|
||||||
|
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
|
||||||
|
```
|
||||||
|
|
||||||
|
7. To run the API server, start it on the **🔸 ROOT** device:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
./dllama-api \
|
||||||
|
--port 9999 \
|
||||||
|
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
|
||||||
|
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
|
||||||
|
--buffer-float-type q80 \
|
||||||
|
--nthreads 4 \
|
||||||
|
--max-seq-len 4096 \
|
||||||
|
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can connect to the API server:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://10.0.0.1:9999/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
8. When the API server is running, you can open the web chat in your browser, open [llama-ui.js.org](https://llama-ui.js.org/), go to the settings and set the base URL to: `http://10.0.0.1:9999`. Press the "save" button and start chatting!
|
||||||
96
docs/HOW_TO_RUN_RASPBERRYPI.md
Normal file
96
docs/HOW_TO_RUN_RASPBERRYPI.md
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# How to Run Distributed Llama on 🍓 Raspberry Pi
|
||||||
|
|
||||||
|
This article describes how to run Distributed Llama on 4 Raspberry Pi devices, but you can also run it on 1, 2, 4, 8... devices. Please adjust the commands and topology according to your configuration.
|
||||||
|
|
||||||
|
````
|
||||||
|
[🔀 SWITCH OR ROUTER]
|
||||||
|
| | | |
|
||||||
|
| | | |_______ 🔸 raspberrypi1 (ROOT) 10.0.0.1
|
||||||
|
| | |_________ 🔹 raspberrypi2 (WORKER 1) 10.0.0.2:9999
|
||||||
|
| |___________ 🔹 raspberrypi3 (WORKER 2) 10.0.0.3:9999
|
||||||
|
|_____________ 🔹 raspberrypi4 (WORKER 3) 10.0.0.4:9999
|
||||||
|
````
|
||||||
|
|
||||||
|
1. Install `Raspberry Pi OS Lite (64 bit)` on your **🔸🔹 ALL** Raspberry Pi devices. This OS doesn't have desktop environment but you can easily connect via SSH to manage it.
|
||||||
|
2. Connect **🔸🔹 ALL** devices to your **🔀 SWITCH OR ROUTER** via Ethernet cable. If you're using only two devices, it's better to connect them directly without a switch.
|
||||||
|
3. Connect to all devices via SSH from your computer.
|
||||||
|
|
||||||
|
```
|
||||||
|
ssh user@raspberrypi1.local
|
||||||
|
ssh user@raspberrypi2.local
|
||||||
|
ssh user@raspberrypi3.local
|
||||||
|
ssh user@raspberrypi4.local
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Install Git on **🔸🔹 ALL** devices:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo apt install git
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Clone this repository and compile Distributed Llama on **🔸🔹 ALL** devices:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
git clone https://github.com/b4rtaz/distributed-llama.git
|
||||||
|
cd distributed-llama
|
||||||
|
make dllama
|
||||||
|
make dllama-api
|
||||||
|
```
|
||||||
|
|
||||||
|
6. Download the model to the **🔸 ROOT** device using the `launch.py` script. You don't need to download the model on worker devices.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
python3 launch.py # Prints a list of available models
|
||||||
|
|
||||||
|
python3 launch.py llama3_2_3b_instruct_q40 # Downloads the model to the root device
|
||||||
|
```
|
||||||
|
|
||||||
|
7. Assign static IP addresses on **🔸🔹 ALL** devices. Each device must have a unique IP address in the same subnet.
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo ip addr add 10.0.0.1/24 dev eth0 # 🔸 ROOT
|
||||||
|
sudo ip addr add 10.0.0.2/24 dev eth0 # 🔹 WORKER 1
|
||||||
|
sudo ip addr add 10.0.0.3/24 dev eth0 # 🔹 WORKER 2
|
||||||
|
sudo ip addr add 10.0.0.4/24 dev eth0 # 🔹 WORKER 3
|
||||||
|
```
|
||||||
|
|
||||||
|
8. Start workers on all **🔹 WORKER** devices:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo nice -n -20 ./dllama worker --port 9999 --nthreads 4
|
||||||
|
```
|
||||||
|
|
||||||
|
9. Run the inference to test if everything works fine on the **🔸 ROOT** device:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo nice -n -20 ./dllama inference \
|
||||||
|
--prompt "Hello world" \
|
||||||
|
--steps 32 \
|
||||||
|
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
|
||||||
|
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
|
||||||
|
--buffer-float-type q80 \
|
||||||
|
--nthreads 4 \
|
||||||
|
--max-seq-len 4096 \
|
||||||
|
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
|
||||||
|
```
|
||||||
|
|
||||||
|
10. To run the API server, start it on the **🔸 ROOT** device:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
sudo nice -n -20 ./dllama-api \
|
||||||
|
--port 9999 \
|
||||||
|
--model models/llama3_2_3b_instruct_q40/dllama_model_llama3_2_3b_instruct_q40.m \
|
||||||
|
--tokenizer models/llama3_2_3b_instruct_q40/dllama_tokenizer_llama3_2_3b_instruct_q40.t \
|
||||||
|
--buffer-float-type q80 \
|
||||||
|
--nthreads 4 \
|
||||||
|
--max-seq-len 4096 \
|
||||||
|
--workers 10.0.0.2:9999 10.0.0.3:9999 10.0.0.4:9999
|
||||||
|
```
|
||||||
|
|
||||||
|
Now you can connect to the API server from your computer:
|
||||||
|
|
||||||
|
```
|
||||||
|
http://raspberrypi1.local:9999/v1/models
|
||||||
|
```
|
||||||
|
|
||||||
|
11. When the API server is running, you can open the web chat in your browser, open [llama-ui.js.org](https://llama-ui.js.org/), go to the settings and set the base URL to: `http://raspberrypi1.local:9999`. Press the "save" button and start chatting!
|
||||||
49
examples/chat-api-client.js
Normal file
49
examples/chat-api-client.js
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// This is a simple client for dllama-api.
|
||||||
|
//
|
||||||
|
// Usage:
|
||||||
|
//
|
||||||
|
// 1. Start the server, how to do it is described in the `src/apps/dllama-api/README.md` file.
|
||||||
|
// 2. Run this script: `node examples/chat-api-client.js`
|
||||||
|
|
||||||
|
// Address of the dllama-api server; override via the HOST and PORT env vars.
const HOST = process.env.HOST ? process.env.HOST : '127.0.0.1';
const PORT = process.env.PORT ? Number(process.env.PORT) : 9990;
|
||||||
|
|
||||||
|
async function chat(messages, maxTokens) {
    // POST a chat-completion request to the dllama-api server and
    // return the parsed JSON response.
    const url = `http://${HOST}:${PORT}/v1/chat/completions`;
    const payload = {
        messages,
        temperature: 0.7,
        stop: ['<|eot_id|>'],
        max_tokens: maxTokens
    };
    const response = await fetch(url, {
        method: 'POST',
        headers: {
            'Content-Type': 'application/json',
        },
        body: JSON.stringify(payload),
    });
    return await response.json();
}
|
||||||
|
|
||||||
|
async function ask(system, user, maxTokens) {
    // Print the prompt, send it as a system + user conversation,
    // then print the token usage and the assistant's reply.
    console.log(`> system: ${system}`);
    console.log(`> user: ${user}`);
    const messages = [
        { role: 'system', content: system },
        { role: 'user', content: user }
    ];
    const response = await chat(messages, maxTokens);
    console.log(response.usage);
    console.log(response.choices[0].message.content);
}
|
||||||
|
|
||||||
|
// Demo entry point: ask two sample questions sequentially.
async function main() {
    await ask('You are an excellent math teacher.', 'What is 1 + 2?', 128);
    await ask('You are a romantic.', 'Where is Europe?', 128);
}

main();
|
||||||
200
examples/macbeth.sh
Normal file
200
examples/macbeth.sh
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This is a simple test of generating a sequence that fulfills the KV cache.
|
||||||
|
#
|
||||||
|
# Used model & tokenizer: https://huggingface.co/b4rtaz/llama-3-8b-distributed-llama
|
||||||
|
# Probably, this test will be working correctly only on MacBook Pro M1, due to differences in float multiplication on different CPUs.
|
||||||
|
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
# Source: https://www.opensourceshakespeare.org/views/plays/play_view.php?WorkID=macbeth&Scope=entire
|
||||||
|
PROMPT="Duncan. What bloody man is that? He can report,
|
||||||
|
As seemeth by his plight, of the revolt
|
||||||
|
The newest state. 20
|
||||||
|
|
||||||
|
Malcolm. This is the sergeant
|
||||||
|
Who like a good and hardy soldier fought
|
||||||
|
'Gainst my captivity. Hail, brave friend!
|
||||||
|
Say to the king the knowledge of the broil
|
||||||
|
As thou didst leave it. 25
|
||||||
|
|
||||||
|
Sergeant. Doubtful it stood;
|
||||||
|
As two spent swimmers, that do cling together
|
||||||
|
And choke their art. The merciless Macdonwald—
|
||||||
|
Worthy to be a rebel, for to that
|
||||||
|
The multiplying villanies of nature 30
|
||||||
|
Do swarm upon him—from the western isles
|
||||||
|
Of kerns and gallowglasses is supplied;
|
||||||
|
And fortune, on his damned quarrel smiling,
|
||||||
|
Show'd like a rebel's whore: but all's too weak:
|
||||||
|
For brave Macbeth—well he deserves that name— 35
|
||||||
|
Disdaining fortune, with his brandish'd steel,
|
||||||
|
Which smoked with bloody execution,
|
||||||
|
Like valour's minion carved out his passage
|
||||||
|
Till he faced the slave;
|
||||||
|
Which ne'er shook hands, nor bade farewell to him, 40
|
||||||
|
Till he unseam'd him from the nave to the chaps,
|
||||||
|
And fix'd his head upon our battlements.
|
||||||
|
|
||||||
|
Duncan. O valiant cousin! worthy gentleman!
|
||||||
|
|
||||||
|
Sergeant. As whence the sun 'gins his reflection
|
||||||
|
Shipwrecking storms and direful thunders break, 45
|
||||||
|
So from that spring whence comfort seem'd to come
|
||||||
|
Discomfort swells. Mark, king of Scotland, mark:
|
||||||
|
No sooner justice had with valour arm'd
|
||||||
|
Compell'd these skipping kerns to trust their heels,
|
||||||
|
But the Norweyan lord surveying vantage, 50
|
||||||
|
With furbish'd arms and new supplies of men
|
||||||
|
Began a fresh assault.
|
||||||
|
|
||||||
|
Duncan. Dismay'd not this
|
||||||
|
Our captains, Macbeth and Banquo?
|
||||||
|
|
||||||
|
Sergeant. Yes; 55
|
||||||
|
As sparrows eagles, or the hare the lion.
|
||||||
|
If I say sooth, I must report they were
|
||||||
|
As cannons overcharged with double cracks, so they
|
||||||
|
Doubly redoubled strokes upon the foe:
|
||||||
|
Except they meant to bathe in reeking wounds, 60
|
||||||
|
Or memorise another Golgotha,
|
||||||
|
I cannot tell.
|
||||||
|
But I am faint, my gashes cry for help.
|
||||||
|
|
||||||
|
Duncan. So well thy words become thee as thy wounds;
|
||||||
|
They smack of honour both. Go get him surgeons. 65
|
||||||
|
[Exit Sergeant, attended]
|
||||||
|
Who comes here?"
|
||||||
|
|
||||||
|
GENERATED="Malcolm. The worthy Thane of Ross.
|
||||||
|
Duncan. What a haste looks through a duel's wounds! 70
|
||||||
|
Some must be pac'd.
|
||||||
|
[Exit Ross]
|
||||||
|
See this encounter is like to the poring
|
||||||
|
On of a beggar's story, told by one
|
||||||
|
That means to pluck upon the heart the strings
|
||||||
|
And draw the tears thriftily. 75
|
||||||
|
[Enter Lennox]
|
||||||
|
How goes the night, boy?
|
||||||
|
|
||||||
|
Lennox. The night is long that none should wake.
|
||||||
|
|
||||||
|
Duncan. You do not need to stare. The Moor
|
||||||
|
To know the man. 'Tis the Moors devices. 80
|
||||||
|
[Exit Lennox]
|
||||||
|
By the happy right of mine own hands,
|
||||||
|
Strike all that live in this poor thing of mine.
|
||||||
|
'Tis calld the Eyrie, and I am sick at heart.
|
||||||
|
As hellish-devils do the damned souls
|
||||||
|
O'their bad lives, thus ill-breveted, linger
|
||||||
|
O'er lamps and forks and other instruments
|
||||||
|
That prove the stages of the night. 90
|
||||||
|
Good sir, take note; I bid you farewell:
|
||||||
|
Come sleep, and cut short this nitty romance.
|
||||||
|
[He sleeps.]
|
||||||
|
If cravens, I bear them like the Minion of the moon,
|
||||||
|
With tiptoe foot he sneaks and starts to be a man. 95
|
||||||
|
And when he is found asleep, awake him with this armed' s address:
|
||||||
|
That sleep which th'assassin hallowed,
|
||||||
|
Scotland, awake; your king is murder'd, sleep no more. 100
|
||||||
|
*Furbish'd. Weapons polished for battle.
|
||||||
|
*Thriftily. Fastidiously, thoughtfully.
|
||||||
|
*Eyrie. Fortress; the lair of birds of prey.
|
||||||
|
*Minion. A braggart, a coward.
|
||||||
|
|
||||||
|
1.5
|
||||||
|
|
||||||
|
Macbeth. So foul and fair a day I have not seen. 5
|
||||||
|
Ross. Good morning, noble Macbeth. I come from Inverness,
|
||||||
|
And find our throne void, the arm'd rest you; 10
|
||||||
|
My Lord of Cassil has resigned his life.
|
||||||
|
Macbeth. Whate'er you owe, in time repay, fair friends.
|
||||||
|
Note you the words; I pray you do.
|
||||||
|
Ross. I am your faithful servant, and will keep
|
||||||
|
My sworn reward upon your life; my lord.
|
||||||
|
Macbeth. You shall be well rewarded; stay the press, 20
|
||||||
|
And I'll not fail. How now, good fellow?
|
||||||
|
Servant. Sir, his schoolmaster. 25
|
||||||
|
Macbeth. Well, good, though, old.
|
||||||
|
Tell me, good fellow, how goes the night? 30
|
||||||
|
Servant. There's marrygold and fire in your veins, my lord.
|
||||||
|
Macbeth. He does commend you; the weight of this old night's embargoes 35
|
||||||
|
Did one hour's waste of time lay upon him.
|
||||||
|
I know when we are too safe, 'tis dangerous to be secure;
|
||||||
|
Therefore our fearful parts do brave the danger 40
|
||||||
|
Which knows it not. I see you are a gentleman.
|
||||||
|
And a laudable one too; I am most off obliged.
|
||||||
|
Servant. I should be sorry, my good lord, to have had the labour 45
|
||||||
|
To outlive this damned hour. 50
|
||||||
|
Macbeth. What's done cannot be undone. To bed, to bed, to bed.
|
||||||
|
Servant. Will it please you to lie still? 55
|
||||||
|
Macbeth. Lord, lord, my heart is in my mouth. All's true that ends well.
|
||||||
|
Servant. I thank you, fair, and leave you to the content. 60
|
||||||
|
Macbeth. You see, my lord, it smokes, and shows no cause
|
||||||
|
Why the drone dies. 65
|
||||||
|
Servant. Grief fills the room up of one vast stair,
|
||||||
|
And downs our vaults to the inconstant man above. 70
|
||||||
|
Macbeth. Go bid thy masters and thy mistress say, 75
|
||||||
|
I have power in earth to do so much.
|
||||||
|
There's comfort yet. They are assailable. Then say I,
|
||||||
|
Thus ye may answer.
|
||||||
|
Servant. He cannot be wronged; or being wronged, 80
|
||||||
|
I cannot help him. 85
|
||||||
|
Macbeth. You know but by this; as this, 90
|
||||||
|
The Jew foole is hang'd. 95
|
||||||
|
Servant. No more today, my lord. 100
|
||||||
|
Macbeth. He does shame to tell him he loves him, but not remove him 105
|
||||||
|
From his true place; no.
|
||||||
|
Servant. That's true, and now I remember the story 110
|
||||||
|
Of that sign in Leo four diurnal courses
|
||||||
|
Returning in a constant motion were within 115
|
||||||
|
A boare that had on Taurus' back tetracted; 120
|
||||||
|
Or neuer, or but once in modulated accidence. 125
|
||||||
|
Macbeth. Thou climd'st alone, ty'd to the stag's horn.
|
||||||
|
Servant. I was a bull, for this the goodly year. 130
|
||||||
|
Come, put me in my place.
|
||||||
|
Macbeth. Now go to sleep. 135
|
||||||
|
Servant. The west neuer sett before the equinox 140
|
||||||
|
Till now; and sunnes look'd not theyr frequencie 145
|
||||||
|
Upon our lappe till now, my lord. 150
|
||||||
|
Macbeth. This game of chance you term a gong.
|
||||||
|
Servant. A gong is a scotch word for an egg. 155
|
||||||
|
Macbeth. Peace, be still. 160
|
||||||
|
Servant. I coniecture I smell the blood of an Englishman. 165
|
||||||
|
Macbeth. The faith is murthered.
|
||||||
|
Servant. That murder'd in his sleep. 170
|
||||||
|
Macbeth. And sleeping murdered. 175
|
||||||
|
Servant. In the fair queen heere in his royal court. 180
|
||||||
|
Macbeth. So great a mercy that it may last eternally.
|
||||||
|
Servant. The earth hath bubbles as the water hath, 185
|
||||||
|
And these are of them. Whate'er we will do 190
|
||||||
|
To mend the trespasses of the comming time 195
|
||||||
|
Shall be the seedes of new mischefe, and shall beget 200
|
||||||
|
The formes of the extinctnese, which we are now. 205
|
||||||
|
Macbeth. We have scorch'd the snake, not kill'd it. 210
|
||||||
|
Servant. They hunt it in the morn. Good gally, good lord! 215
|
||||||
|
It weares a gilded snout. 220
|
||||||
|
Macbeth. It is the very painting of your fear. 225
|
||||||
|
Servant. This is the worst. 230
|
||||||
|
Macbeth. A fair quater of a mile is yet to go. 235
|
||||||
|
Servant. A mile and half. 240
|
||||||
|
Macbeth. I have run fifteen miles to-day.
|
||||||
|
Servant. A calender's date.
|
||||||
|
Macbeth. A bigger patch, a bigger patch. 245
|
||||||
|
Servant. Thirteen of more. 250
|
||||||
|
Macbeth. Wast thou with him? 255
|
||||||
|
Servant. No, nor he to night. 260
|
||||||
|
Macbeth. Thou seest the moon"
|
||||||
|
|
||||||
|
echo "Generating, it can take a while..."
|
||||||
|
|
||||||
|
OUTPUT=$(( ./dllama generate --seed 12345 --temperature 0.9 --topp 0.9 --prompt "$PROMPT" --weights-float-type q40 --buffer-float-type f32 --nthreads 2 --steps 2048 --model models/llama3_8b_q40/dllama_model_llama3_8b_q40.m --tokenizer models/llama3_8b_q40/dllama_tokenizer_llama3_8b_q40.t --workers 127.0.0.1:9999 127.0.0.1:9998 127.0.0.1:9997 ) 2>&1)
|
||||||
|
|
||||||
|
echo "$OUTPUT"
|
||||||
|
|
||||||
|
if [[ $OUTPUT == *"$GENERATED"* ]]; then
|
||||||
|
echo "✅ Output is same"
|
||||||
|
else
|
||||||
|
echo "❌ Output is different"
|
||||||
|
fi
|
||||||
52
examples/n-workers.sh
Normal file
52
examples/n-workers.sh
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
#!/bin/bash

# This script starts N workers from a single command. Mainly useful for testing and debugging.
# Usage:
#
# W=7 T=2 bash n-workers.sh start
# W=7 bash n-workers.sh stop
#
# Env vars:
# W - n workers
# T - n threads per worker

# Run relative to the script's own directory so the per-worker log
# directories are created next to it.
cd "$(dirname "$0")"

# Defaults: 3 workers, 1 thread per worker.
if [ -z "$W" ]; then
    W=3
fi
if [ -z "$T" ]; then
    T=1
fi

if [ "$1" == "start" ]; then
    for (( w = 0; w < $W ; w += 1 ));
    do
        # Workers listen on ports counting down from 9999 (9999, 9998, ...).
        PORT=$(expr 9999 - $w)
        # Kill any stale process still bound to the port before reusing it.
        PROC_ID=$(lsof -ti:$PORT)
        if [ -n "$PROC_ID" ]; then
            kill -9 $PROC_ID
            echo "Killed process $PROC_ID"
        fi

        mkdir -p dllama_worker_$w # macOS does not support the -Logfile argument, so we place logs inside different directories
        cd dllama_worker_$w
        # Detached screen session per worker; -L logs output into this directory.
        screen -d -L -S dllama_worker_$w -m ../../dllama worker --port $PORT --nthreads $T
        cd ..
        echo "Started worker $w on port $PORT"
    done

    sleep 2
elif [ "$1" == "stop" ]; then
    for (( w = 0; w < $W ; w += 1 ));
    do
        screen -S dllama_worker_$w -X quit
    done

    echo "Stopped $W workers"
else
    echo "Usage: $0 [start|stop]"
fi

# Show the remaining screen sessions for quick verification.
echo "> screen -ls"
screen -ls
|
||||||
195
launch.py
Normal file
195
launch.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import socket
|
||||||
|
import multiprocessing
|
||||||
|
from urllib.request import urlopen
|
||||||
|
|
||||||
|
def parts(length):
    """Return `length` two-letter suffixes: 'aa', 'ab', ..., 'az', 'ba', ...

    Used to enumerate the file-name suffixes of models that are split
    into multiple downloadable chunks.
    """
    return [chr(97 + i // 26) + chr(97 + i % 26) for i in range(length)]
||||||
|
|
||||||
|
# [['model-url-0', 'model-url-1', ...], 'tokenizer-url', 'weights-float-type', 'buffer-float-type', 'model-type']
|
||||||
|
MODELS = {
|
||||||
|
'llama3_1_8b_instruct_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.1_instruct_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/Llama-3_1-8B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'llama3_1_405b_instruct_q40': [
|
||||||
|
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama31_405b_q40_{suffix}?download=true', parts(56))),
|
||||||
|
'https://huggingface.co/b4rtaz/Llama-3_1-405B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama_3_1.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'llama3_2_1b_instruct_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-1b-instruct_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/Llama-3_2-1B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'llama3_2_3b_instruct_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama3.2-3b-instruct_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/Llama-3_2-3B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama3_2.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'llama3_3_70b_instruct_q40': [
|
||||||
|
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_model_llama-3.3-70b_q40{suffix}?download=true', parts(11))),
|
||||||
|
'https://huggingface.co/b4rtaz/Llama-3_3-70B-Q40-Instruct-Distributed-Llama/resolve/main/dllama_tokenizer_llama-3.3-70b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'deepseek_r1_distill_llama_8b_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_model_deepseek-r1-distill-llama-8b_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/DeepSeek-R1-Distill-Llama-8B-Distributed-Llama/resolve/main/dllama_tokenizer_deepseek-r1-distill-llama-8b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'qwen3_0.6b_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/Qwen3-0.6B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_0.6b_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/Qwen3-0.6B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_0.6b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'qwen3_1.7b_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/Qwen3-1.7B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_1.7b_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/Qwen3-1.7B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_1.7b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'qwen3_8b_q40': [
|
||||||
|
['https://huggingface.co/b4rtaz/Qwen3-8B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_8b_q40.m?download=true'],
|
||||||
|
'https://huggingface.co/b4rtaz/Qwen3-8B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_8b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'qwen3_14b_q40': [
|
||||||
|
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Qwen3-14B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_14b_q40_{suffix}?download=true', parts(2))),
|
||||||
|
'https://huggingface.co/b4rtaz/Qwen3-14B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_14b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
'qwen3_30b_a3b_q40': [
|
||||||
|
list(map(lambda suffix : f'https://huggingface.co/b4rtaz/Qwen3-30B-A3B-Q40-Distributed-Llama/resolve/main/dllama_model_qwen3_30b_a3b_{suffix}?download=true', parts(5))),
|
||||||
|
'https://huggingface.co/b4rtaz/Qwen3-30B-A3B-Q40-Distributed-Llama/resolve/main/dllama_tokenizer_qwen3_30b_a3b.t?download=true',
|
||||||
|
'q40', 'q80', 'chat', '--max-seq-len 4096'
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def confirm(message: str):
    """Ask the user a yes/no question; a `-y` CLI flag skips the prompt."""
    if '-y' in sys.argv:
        return True
    answer = input(f'❓ {message} ("Y" if yes): ').upper()
    return answer in ('Y', 'YES')
||||||
|
|
||||||
|
def downloadFile(urls, path: str):
    """Download a sequence of URL chunks, appended in order, into one file at `path`.

    If `path` already exists the user is asked before re-downloading
    (the `-y` flag auto-confirms via confirm()). Each chunk is retried
    up to 8 times with a linear backoff; on failure the file is rewound
    and truncated back to the chunk's start position so a retry appends
    cleanly. Raises if a chunk still fails after all attempts.
    """
    if os.path.isfile(path):
        fileName = os.path.basename(path)
        if not confirm(f'{fileName} already exists, do you want to download again?'):
            return

    # Avoid hanging forever on a stalled connection.
    socket.setdefaulttimeout(30)
    lastSizeMb = 0
    with open(path, 'wb') as file:
        for url in urls:
            # Remember where this chunk starts so a failed attempt can be undone.
            startPosition = file.tell()
            success = False
            for attempt in range(8):
                print(f'📄 {url} (attempt: {attempt})')
                try:
                    with urlopen(url) as response:
                        while True:
                            chunk = response.read(4096)
                            if not chunk:
                                break
                            file.write(chunk)
                            # Progress indicator, updated once per megabyte.
                            sizeMb = file.tell() // (1024 * 1024)
                            if sizeMb != lastSizeMb:
                                sys.stdout.write("\rDownloaded %i MB" % sizeMb)
                                lastSizeMb = sizeMb
                    sys.stdout.write('\n')
                    success = True
                    break
                except Exception as e:
                    print(f'\n❌ Error downloading {url}: {e}')
                    # Roll the file back to the chunk boundary before retrying.
                    file.seek(startPosition)
                    file.truncate()
                    time.sleep(1 * attempt)
            if not success:
                raise Exception(f'Failed to download {url}')
    sys.stdout.write(' ✅\n')
||||||
|
|
||||||
|
def download(modelName: str, model: list):
    """Fetch the model chunks and tokenizer into models/<modelName>/.

    `model` is a MODELS entry: [modelUrls, tokenizerUrl, ...].
    Returns a (modelPath, tokenizerPath) tuple.
    """
    dirPath = os.path.join('models', modelName)
    print(f'📀 Downloading {modelName} to {dirPath}...')
    os.makedirs(dirPath, exist_ok=True)

    modelUrls, tokenizerUrl = model[0], model[1]
    modelPath = os.path.join(dirPath, f'dllama_model_{modelName}.m')
    tokenizerPath = os.path.join(dirPath, f'dllama_tokenizer_{modelName}.t')

    downloadFile(modelUrls, modelPath)
    downloadFile([tokenizerUrl], tokenizerPath)
    print('📀 All files are downloaded')
    return (modelPath, tokenizerPath)
||||||
|
|
||||||
|
def writeRunFile(modelName: str, command: str):
    """Write run_<modelName>.sh containing `command` and return its path.

    The script is marked executable so it can be run directly
    (previously it was created with plain write permissions only).
    """
    filePath = f'run_{modelName}.sh'
    with open(filePath, 'w') as file:
        file.write('#!/bin/sh\n')
        file.write('\n')
        file.write(f'{command}\n')
    # Add the exec bits on top of whatever the umask produced.
    os.chmod(filePath, os.stat(filePath).st_mode | 0o111)
    return filePath
||||||
|
|
||||||
|
def printUsage():
    """Print CLI usage, the supported options, and all downloadable models."""
    print('Usage: python download-model.py <model>')
    print()
    print('Options:')
    print(' <model> The name of the model to download')
    print(' -skip-run Do not run the model after download')
    print(' -skip-script Do not create a script to run the model')
    print(' -y Skip confirmation prompts')
    print()
    print('Available models:')
    # MODELS is the module-level registry mapping model names to URLs/settings.
    for model in MODELS:
        print(f' {model}')
||||||
|
|
||||||
|
if __name__ == '__main__':
    # A model name is the one required positional argument.
    if len(sys.argv) < 2:
        printUsage()
        exit(1)

    # Work relative to the script's own directory.
    os.chdir(os.path.dirname(__file__))

    # Accept dashes as an alias for underscores in model names.
    modelName = sys.argv[1].replace('-', '_')
    if modelName not in MODELS:
        print(f'Model is not supported: {modelName}')
        exit(1)

    model = MODELS[modelName]
    (modelPath, tokenizerPath) = download(modelName, model)

    # Build the dllama command line from the MODELS entry.
    nThreads = multiprocessing.cpu_count()
    if model[4] == 'chat':
        command = './dllama chat'
    else:
        command = './dllama inference --steps 64 --prompt "Hello world"'
    command += f' --model {modelPath} --tokenizer {tokenizerPath} --buffer-float-type {model[3]} --nthreads {nThreads}'
    if len(model) > 5:
        command += f' {model[5]}'

    print('To run Distributed Llama you need to execute:')
    print('--- copy start ---')
    print()
    print('\033[96m' + command + '\033[0m')
    print()
    print('--- copy end -----')

    skipRun = '-skip-run' in sys.argv
    skipScript = '-skip-script' in sys.argv

    if not skipScript:
        runFilePath = writeRunFile(modelName, command)
        print(f'🌻 Created {runFilePath} script to easy run')

    if not skipRun:
        if confirm('Do you want to run Distributed Llama?'):
            # Build the binary first if it is not present yet.
            if not os.path.isfile('dllama'):
                os.system('make dllama')
            os.system(command)
||||||
BIN
report/report.pdf
Normal file
BIN
report/report.pdf
Normal file
Binary file not shown.
179
src/api-types.hpp
Executable file
179
src/api-types.hpp
Executable file
@@ -0,0 +1,179 @@
|
|||||||
|
#ifndef API_TYPES_HPP
#define API_TYPES_HPP

#include <ctime>
#include <string>
#include <vector>

#include "json.hpp"

using json = nlohmann::json;

// OpenAI-compatible request/response types for the HTTP API.
//
// Fixes over the previous revision:
// * <vector> and <ctime> are now included explicitly (std::vector and
//   std::time were used without them).
// * All free functions are `inline`: this header *defines* them, so
//   without `inline` every translation unit including it would emit its
//   own definition and violate the one-definition rule at link time.

// Incremental message fragment used in streaming chunks.
struct ChatMessageDelta {
    std::string role;
    std::string content;

    ChatMessageDelta() : role(""), content("") {}
    ChatMessageDelta(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};

// A complete chat message (request input or final response output).
struct ChatMessage {
    std::string role;
    std::string content;

    ChatMessage() : role(""), content("") {}
    ChatMessage(const std::string& role_, const std::string& content_) : role(role_), content(content_) {}
};

// Single choice within a streaming completion chunk.
struct ChunkChoice {
    int index;
    ChatMessageDelta delta;
    std::string finish_reason;

    ChunkChoice() : index(0) {}
};

// Single choice within a non-streaming completion.
struct Choice {
    int index;
    ChatMessage message;
    std::string finish_reason;

    Choice() : finish_reason("") {}
    Choice(ChatMessage &message_) : message(message_), finish_reason("") {}
    Choice(const std::string &reason_) : finish_reason(reason_) {}
};

// One streamed chunk of a chat completion.
struct ChatCompletionChunk {
    std::string id;
    std::string object;
    long long created;
    std::string model;
    std::vector<ChunkChoice> choices;

    ChatCompletionChunk(ChunkChoice &choice_)
        : id("cmpl-c0"), object("chat.completion"), model("Distributed Model") {
        created = std::time(nullptr); // Set created to current Unix timestamp
        choices.push_back(choice_);
    }
};

// Struct to represent the usage object
struct ChatUsage {
    int prompt_tokens;
    int completion_tokens;
    int total_tokens;

    ChatUsage() : prompt_tokens(0), completion_tokens(0), total_tokens(0) {}
    ChatUsage(int pt, int ct, int tt) : prompt_tokens(pt), completion_tokens(ct), total_tokens(tt) {}
};

// Full (non-streaming) chat completion response.
struct ChatCompletion {
    std::string id;
    std::string object;
    long long created; // Unix timestamp
    std::string model;
    std::vector<Choice> choices;
    ChatUsage usage;

    ChatCompletion() : id(), object(), model() {}
    ChatCompletion(const Choice &choice_, const ChatUsage& usage_)
        : id("cmpl-j0"), object("chat.completion"), model("Distributed Model"), usage(usage_) {
        created = std::time(nullptr); // Set created to current Unix timestamp
        choices.push_back(choice_);
    }
};

// Entry of the /models listing.
struct Model {
    std::string id;
    std::string object;
    long long created;
    std::string owned_by;

    Model() : id(), object(), created(0), owned_by() {}
    Model(const std::string &id_) : id(id_), object("model"), created(0), owned_by("user") {}
};

struct ModelList {
    std::string object;
    std::vector<Model> data;
    ModelList(): object("list") {}
    ModelList(const Model &model_) : object("list") {
        data.push_back(model_);
    }
};

// Parsed body of a chat-completion request.
struct InferenceParams {
    std::vector<ChatMessage> messages;
    int max_tokens;
    float temperature;
    float top_p;
    std::vector<std::string> stop;
    bool stream;
    unsigned long long seed;
};

// nlohmann::json serializers; found via ADL when these types are
// assigned into a json value.

inline void to_json(json& j, const ChatMessageDelta& msg) {
    j = json{{"role", msg.role}, {"content", msg.content}};
}

inline void to_json(json& j, const ChatMessage& msg) {
    j = json{{"role", msg.role}, {"content", msg.content}};
}

inline void to_json(json& j, const ChunkChoice& choice) {
    j = json{{"index", choice.index}, {"delta", choice.delta}, {"finish_reason", choice.finish_reason}};
}

inline void to_json(json& j, const Choice& choice) {
    j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}};
}

inline void to_json(json& j, const ChatCompletionChunk& completion) {
    j = json{{"id", completion.id},
        {"object", completion.object},
        {"created", completion.created},
        {"model", completion.model},
        {"choices", completion.choices}};
}

inline void to_json(json& j, const ChatUsage& usage) {
    j = json{{"completion_tokens", usage.completion_tokens},
        {"prompt_tokens", usage.prompt_tokens},
        {"total_tokens", usage.total_tokens}};
}

inline void to_json(json& j, const ChatCompletion& completion) {
    j = json{{"id", completion.id},
        {"object", completion.object},
        {"created", completion.created},
        {"model", completion.model},
        {"usage", completion.usage},
        {"choices", completion.choices}};
}

inline void to_json(json& j, const Model& model) {
    j = json{{"id", model.id},
        {"object", model.object},
        {"created", model.created},
        {"owned_by", model.owned_by}};
}

inline void to_json(json& j, const ModelList& models) {
    j = json{{"object", models.object},
        {"data", models.data}};
}

// Converts a JSON array of {role, content} objects into ChatMessage values.
// (The parameter was renamed from `json` to `j` — it used to shadow the
// `json` type alias inside the function body.)
inline std::vector<ChatMessage> parseChatMessages(json &j) {
    std::vector<ChatMessage> messages;
    messages.reserve(j.size());

    for (const auto& item : j) {
        messages.emplace_back(
            item["role"].template get<std::string>(),
            item["content"].template get<std::string>()
        );
    }
    return messages;
}

#endif
||||||
358
src/app.cpp
Normal file
358
src/app.cpp
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
#include "app.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <stdexcept>
|
||||||
|
#if defined(DLLAMA_VULKAN)
|
||||||
|
#include "nn/nn-vulkan.hpp"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Maps a CLI string such as "q40" to the corresponding NnFloatType.
// Throws std::runtime_error for an unrecognized name.
static NnFloatType parseFloatType(char *val) {
    static const struct { const char *name; NnFloatType type; } mapping[] = {
        {"f32", F_32},
        {"f16", F_16},
        {"q40", F_Q40},
        {"q80", F_Q80},
    };
    for (const auto &entry : mapping) {
        if (std::strcmp(val, entry.name) == 0)
            return entry.type;
    }
    throw std::runtime_error("Invalid float type: " + std::string(val));
}
||||||
|
|
||||||
|
// Maps a CLI string to the corresponding chat template constant.
// Throws std::runtime_error for an unrecognized name.
static ChatTemplateType parseChatTemplateType(char *val) {
    static const struct { const char *name; ChatTemplateType type; } mapping[] = {
        {"llama2", TEMPLATE_LLAMA2},
        {"llama3", TEMPLATE_LLAMA3},
        {"deepSeek3", TEMPLATE_DEEP_SEEK3},
    };
    for (const auto &entry : mapping) {
        if (std::strcmp(val, entry.name) == 0)
            return entry.type;
    }
    throw std::runtime_error("Invalid chat template type: " + std::string(val));
}
||||||
|
|
||||||
|
// Parses the command line into an AppCliArgs value.
// When `requireMode` is true, argv[1] is consumed as the mode
// ("chat", "worker", ...). Remaining arguments come in "--name value"
// pairs, except --workers which consumes every following host:port token
// up to the next option. Throws std::runtime_error on malformed input.
AppCliArgs AppCliArgs::parse(int argc, char* *argv, bool requireMode) {
    AppCliArgs args;
    args.help = false;
    args.mode = nullptr;
    args.nBatches = 32;
    args.nThreads = 1;
    args.modelPath = nullptr;
    args.tokenizerPath = nullptr;
    args.prompt = nullptr;
    args.syncType = F_32;
    args.nWorkers = 0;
    args.workerHosts = nullptr;
    args.workerPorts = nullptr;
    args.port = 9990;
    args.temperature = 0.8f;
    args.topp = 0.9f;
    args.steps = 0;
    // Fix: `benchmark` was never initialized here but is read later
    // (runInferenceApp passes args->benchmark to NnExecutor).
    args.benchmark = false;
    args.seed = (unsigned long long)time(nullptr);
    args.chatTemplateType = TEMPLATE_UNKNOWN;
    args.maxSeqLen = 0;
    args.netTurbo = true;
    args.gpuIndex = -1;
    args.gpuSegmentFrom = -1;
    args.gpuSegmentTo = -1;

    int i = 1;
    if (requireMode && argc > 1) {
        args.mode = argv[1];
        i++;
    }
    // First see if any of the args are asking for help/usage and fail fast
    for (int x = 0; x < argc; x++) {
        if ((std::strcmp(argv[x], "--usage") == 0) ||
            (std::strcmp(argv[x], "--help") == 0) ||
            (std::strcmp(argv[x], "-h") == 0)) {
            args.help = true;
            return args;
        }
    }
    for (; i + 1 < argc; i += 2) {
        char *name = argv[i];
        char *value = argv[i + 1];
        if (std::strcmp(name, "--model") == 0) {
            args.modelPath = value;
        } else if (std::strcmp(name, "--tokenizer") == 0) {
            args.tokenizerPath = value;
        } else if (std::strcmp(name, "--prompt") == 0) {
            args.prompt = value;
        } else if (std::strcmp(name, "--buffer-float-type") == 0) {
            args.syncType = parseFloatType(value);
        } else if (std::strcmp(name, "--workers") == 0) {
            // Consume every following "host:port" token up to the next option.
            int j = i + 1;
            for (; j < argc && argv[j][0] != '-'; j++);
            int count = j - i - 1;

            args.nWorkers = count;
            args.workerHosts = new char*[count];
            args.workerPorts = new NnUint[count];

            for (int s = 0; s < count; s++) {
                char *v = argv[i + 1 + s];
                char *separator = std::strstr(v, ":");
                if (separator == NULL) {
                    throw std::runtime_error("Invalid worker address: " + std::string(v));
                }
                int hostLen = separator - v;
                args.workerHosts[s] = new char[hostLen + 1];
                std::memcpy(args.workerHosts[s], v, hostLen);
                args.workerHosts[s][hostLen] = '\0';
                args.workerPorts[s] = std::atoi(separator + 1);
            }

            // The loop increment adds 2; together this skips all worker tokens.
            i += count - 1;
        } else if (std::strcmp(name, "--port") == 0) {
            args.port = atoi(value);
        } else if (std::strcmp(name, "--nthreads") == 0) {
            args.nThreads = atoi(value);
        } else if (std::strcmp(name, "--steps") == 0) {
            args.steps = atoi(value);
        } else if (std::strcmp(name, "--temperature") == 0) {
            args.temperature = atof(value);
        } else if (std::strcmp(name, "--topp") == 0) {
            args.topp = atof(value);
        } else if (std::strcmp(name, "--seed") == 0) {
            args.seed = atoll(value);
        } else if (std::strcmp(name, "--chat-template") == 0) {
            args.chatTemplateType = parseChatTemplateType(value);
        } else if (std::strcmp(name, "--max-seq-len") == 0) {
            args.maxSeqLen = (unsigned int)atoi(value);
        } else if (std::strcmp(name, "--gpu-index") == 0) {
            args.gpuIndex = atoi(value);
        } else if (std::strcmp(name, "--gpu-segments") == 0) {
            char *separator = std::strstr(value, ":");
            if (separator == NULL)
                throw std::runtime_error("GPU segments expected in the format <from>:<to>");
            args.gpuSegmentFrom = atoi(value);
            args.gpuSegmentTo = atoi(separator + 1);
        } else if (std::strcmp(name, "--net-turbo") == 0) {
            args.netTurbo = atoi(value) == 1;
        } else {
            throw std::runtime_error("Unknown option: " + std::string(name));
        }
    }

    if (args.nThreads < 1)
        throw std::runtime_error("Number of threads must be at least 1");
    return args;
}
||||||
|
|
||||||
|
// Frees the worker host strings and port array allocated by parse().
AppCliArgs::~AppCliArgs() {
    if (workerHosts != nullptr) {
        for (NnUint i = 0; i < nWorkers; i++)
            delete[] workerHosts[i];
        delete[] workerHosts;
    }
    if (workerPorts != nullptr)
        delete[] workerPorts;
}
||||||
|
|
||||||
|
// Builds the executor device list from the CLI options.
// --gpu-index adds a Vulkan device (optionally restricted to the
// --gpu-segments range); the CPU device is added when no GPU is
// requested, or alongside the GPU when an explicit segment range
// leaves the remaining segments to the CPU.
static std::vector<NnExecutorDevice> resolveDevices(AppCliArgs *args, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
    std::vector<NnExecutorDevice> result;

    const bool wantsGpu = args->gpuIndex >= 0;
    if (wantsGpu) {
#if defined(DLLAMA_VULKAN)
        result.push_back(NnExecutorDevice(
            new NnVulkanDevice(args->gpuIndex, netConfig, nodeConfig, netExecution),
            args->gpuSegmentFrom,
            args->gpuSegmentTo
        ));
#else
        throw std::runtime_error("This build does not support GPU");
#endif
    }

    if (!wantsGpu || (args->gpuSegmentFrom >= 0 && args->gpuSegmentTo >= 0))
        result.push_back(NnExecutorDevice(new NnCpuDevice(netConfig, nodeConfig, netExecution), -1, -1));
    return result;
}
||||||
|
|
||||||
|
// Wires the root inference facade to the pipes of the built network
// execution. `network` may be nullptr (single-node run).
RootLlmInference::RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network) {
    header = net->header;
    tokenPipe = (float *)execution->pipes[net->tokenPipeIndex];
    positionPipe = (float *)execution->pipes[net->positionPipeIndex];
    logitsPipe = (float *)execution->pipes[net->logitsPipeIndex];
    this->execution = execution;
    this->executor = executor;
    this->network = network; // May be nullptr!
}
||||||
|
|
||||||
|
// Applies the batch size to the local execution and records it in the
// control packet that forward() broadcasts to workers.
void RootLlmInference::setBatchSize(NnUint batchSize) {
    controlPacket.batchSize = batchSize;
    execution->setBatchSize(batchSize);
}
||||||
|
|
||||||
|
// Sets the absolute sequence position of the current batch and mirrors
// it into the position pipe (one slot per batch entry).
void RootLlmInference::setPosition(NnUint position) {
    assert(position >= 0);
    // The whole batch must fit within the model's sequence length.
    assert(position + execution->batchSize - 1 < header->seqLen);

    controlPacket.position = position;
    for (NnUint i = 0; i < execution->batchSize; i++)
        positionPipe[i] = (float)(position + i);
}
||||||
|
|
||||||
|
// Places `token` into the token pipe slot for `batchIndex`.
void RootLlmInference::setToken(NnUint batchIndex, NnUint token) {
    assert(batchIndex >= 0 && batchIndex < execution->batchSize);
    tokenPipe[batchIndex] = (float)token;
}
||||||
|
|
||||||
|
// Runs one forward pass. In multi-node mode the control packet
// (position + batch size) is broadcast to workers first.
void RootLlmInference::forward() {
    if (network != nullptr)
        network->writeAll(&controlPacket, sizeof(LlmControlPacket));
    executor->forward();
}
||||||
|
|
||||||
|
// Signals workers to stop by broadcasting a control packet with
// batchSize == 0 (the agreed stop signal). No-op in single-node mode.
void RootLlmInference::finish() {
    if (network != nullptr) {
        controlPacket.batchSize = 0;
        network->writeAll(&controlPacket, sizeof(LlmControlPacket));
    }
}
||||||
|
|
||||||
|
// Binds the worker-side inference loop to its execution.
// On worker nodes pipe 0 holds the per-batch positions.
WorkerLlmInference::WorkerLlmInference(NnNetExecution *execution, NnNetwork *network) {
    this->isFinished = false;
    this->execution = execution;
    this->network = network;
    this->positionPipe = (float *)execution->pipes[0];
}
||||||
|
|
||||||
|
// Polls the root connection for the next control packet.
// Returns false when no packet arrived within maxAttempts read attempts.
// Returns true after applying the packet — either flagging isFinished on
// the stop signal (batchSize == 0), or filling the position pipe and
// adjusting the batch size for the coming forward pass.
bool WorkerLlmInference::tryReadControlPacket() {
    const unsigned long maxAttempts = 10000;
    if (!network->tryReadWithMaxAttempts(ROOT_SOCKET_INDEX, &controlPacket, sizeof(LlmControlPacket), maxAttempts))
        return false;
    if (controlPacket.batchSize == 0) {
        printf("🛑 Stop signal\n");
        isFinished = true;
        return true;
    }
    for (NnUint i = 0; i < controlPacket.batchSize; i++)
        positionPipe[i] = (float)(controlPacket.position + i);
    execution->setBatchSize(controlPacket.batchSize);
    return true;
}
||||||
|
|
||||||
|
// Root-node entry point: loads the model header and tokenizer, connects
// to the workers (if any), builds the network/executor, loads weights,
// then hands an AppInferenceContext to `handler` (chat loop, one-shot
// inference, API server, ...). Broadcasts the stop signal on return.
void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context)) {
    NnUint nNodes = args->nWorkers + 1;

    LlmHeader header = loadLlmHeader(args->modelPath, args->maxSeqLen, args->syncType);
    if (nNodes > header.nKvHeads)
        // TODO: https://github.com/b4rtaz/distributed-llama/issues/70
        throw std::runtime_error("This version does not support more nodes than the number of KV heads in the model");
    if (header.weightType == F_Q40 && header.syncType != F_Q80)
        throw std::runtime_error("This version supports only Q40 weights with Q80 sync type");

    Tokenizer tokenizer(args->tokenizerPath);
    // A vocab mismatch is reported but not fatal.
    if (tokenizer.vocabSize != header.vocabSize)
        printf("Tokenizer vocab size (%d) does not match the model vocab size (%d)\n", tokenizer.vocabSize, header.vocabSize);

    Sampler sampler(tokenizer.vocabSize, args->temperature, args->topp, args->seed);

    // The net owns node configs and pipes; released via the unique_ptr deleter.
    LlmNet net = buildLlmNet(&header, nNodes, args->nBatches);
    std::unique_ptr<LlmNet, void(*)(LlmNet *)> netPtr(&net, releaseLlmNet);

    // The root node is always node 0.
    NnNodeConfig *rootNodeConfig = &net.nodeConfigs[0];

    printLlmHeader(&header);
    printNodeRequiredMemory(&net.netConfig, rootNodeConfig);

    NnNetExecution execution(args->nThreads, &net.netConfig);

    std::unique_ptr<NnNodeSynchronizer> synchronizer(nullptr);
    std::unique_ptr<NnNetwork> networkPtr(nullptr);
    NnNetwork *network = nullptr;

    if (nNodes == 1) {
        // Single node: nothing to synchronize.
        synchronizer.reset(new NnFakeNodeSynchronizer());
    } else {
        // Connect to all workers and push them their node configs.
        networkPtr = NnNetwork::connect(args->nWorkers, args->workerHosts, args->workerPorts);
        network = networkPtr.get();
        synchronizer.reset(new NnNetworkNodeSynchronizer(network, &execution, &net.netConfig, rootNodeConfig));

        NnRootConfigWriter configWriter(network);
        configWriter.writeToWorkers(&net.netConfig, net.nodeConfigs);
    }

    std::vector<NnExecutorDevice> devices = resolveDevices(args, &net.netConfig, rootNodeConfig, &execution);
    NnExecutor executor(&net.netConfig, rootNodeConfig, &devices, &execution, synchronizer.get(), args->benchmark);

    // Weights are streamed to the local executor and to workers.
    NnRootWeightLoader weightLoader(&executor, network, nNodes);
    loadLlmNetWeight(args->modelPath, &net, &weightLoader);

    RootLlmInference inference(&net, &execution, &executor, network);

    if (network != nullptr) {
        network->resetStats();
        if (args->netTurbo) {
            network->setTurbo(true);
            printf("🚁 Network is in non-blocking mode\n");
        }
    }

    AppInferenceContext context;
    context.args = args;
    context.header = &header;
    context.inference = &inference;
    context.sampler = &sampler;
    context.tokenizer = &tokenizer;
    context.network = network;
    context.executor = &executor;

    handler(&context);

    // Tell workers to stop waiting for control packets.
    inference.finish();
}
||||||
|
|
||||||
|
// Worker entry point: serves forever. Each outer iteration accepts one
// root connection, receives configs and weights, and then loops running
// forward passes until a stop signal or a network error, after which it
// goes back to listening for the next root.
void runWorkerApp(AppCliArgs *args) {
    while (true) {
        std::unique_ptr<NnNetwork> networkPtr = NnNetwork::serve(args->port);
        NnNetwork *network = networkPtr.get();

        // Receive the net/node configuration pushed by the root.
        NnWorkerConfigReader configReader(network);
        NnNetConfig netConfig = configReader.readNet();
        NnNodeConfig nodeConfig = configReader.readNode();
        std::unique_ptr<NnNetConfig, void(*)(NnNetConfig *)> netConfigPtr(&netConfig, releaseNetConfig);
        std::unique_ptr<NnNodeConfig, void(*)(NnNodeConfig *)> nodeConfigPtr(&nodeConfig, releaseNodeConfig);

        printNodeRequiredMemory(&netConfig, &nodeConfig);

        NnNetExecution execution(args->nThreads, &netConfig);

        std::vector<NnExecutorDevice> devices = resolveDevices(args, &netConfig, &nodeConfig, &execution);
        NnNetworkNodeSynchronizer synchronizer(network, &execution, &netConfig, &nodeConfig);
        NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);

        // Receive the weights for this node.
        NnWorkerWeightReader weightReader(&executor, network);
        weightReader.read();

        WorkerLlmInference inference(&execution, network);
        bool isFirstAttempt = true;
        bool isTurboEnabled = false;
        clock_t startTime;
        while (true) {
            try {
                if (isFirstAttempt)
                    startTime = clock();

                if (!inference.tryReadControlPacket()) {
                    // Idle for more than ~1s: drop back to blocking mode
                    // so the worker does not spin while the root is quiet.
                    if (isTurboEnabled && !isFirstAttempt && clock() - startTime > CLOCKS_PER_SEC) {
                        network->setTurbo(false);
                        isTurboEnabled = false;
                        printf("🚁 Network is in blocking mode\n");
                    }
                    isFirstAttempt = false;
                    continue;
                }
                if (inference.isFinished)
                    break;

                // Work arrived: switch to non-blocking (turbo) mode if allowed.
                if (args->netTurbo && !isTurboEnabled) {
                    network->setTurbo(true);
                    isTurboEnabled = true;
                    printf("🚁 Network is in non-blocking mode\n");
                }
                executor.forward();
                isFirstAttempt = true;
            } catch (const NnReadNetworkException &e) {
                printf("Read network exception: %s\n", e.message);
                break;
            } catch (const NnWriteNetworkException &e) {
                printf("Write network exception: %s\n", e.message);
                break;
            }
        }
    }
}
||||||
95
src/app.hpp
Normal file
95
src/app.hpp
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
#ifndef APP_HPP
|
||||||
|
#define APP_HPP
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include "nn/nn-core.hpp"
|
||||||
|
#include "nn/nn-cpu.hpp"
|
||||||
|
#include "tokenizer.hpp"
|
||||||
|
#include "llm.hpp"
|
||||||
|
|
||||||
|
// Command-line arguments shared by the root (inference/API) and worker apps.
// Populated by AppCliArgs::parse(); see usage() in dllama-api.cpp for the flags.
class AppCliArgs {
public:
    char *mode;          // first positional argument (only when parse() is called with hasMode=true)
    NnUint nThreads;     // CPU threads used by the executor
    NnUint nBatches;     // max tokens evaluated per forward() call during prompt processing
    bool help;           // true when help was requested; callers print usage and exit

    // inference
    char *modelPath;
    char *tokenizerPath;
    char *prompt;
    NnFloatType syncType;    // float type of inter-node synchronization buffers (--buffer-float-type)
    NnUint nWorkers;         // number of remote workers; workerHosts/workerPorts have this length
    char **workerHosts;
    NnUint *workerPorts;
    float temperature;       // default sampling temperature (API requests may override)
    float topp;              // default nucleus-sampling threshold
    NnUint steps;            // max positions to generate in CLI inference mode
    bool benchmark;
    unsigned long long seed; // default sampler seed (API requests may override)
    ChatTemplateType chatTemplateType;
    NnUint maxSeqLen;        // optional cap on the model's sequence length (--max-seq-len)
    bool netTurbo;           // enable non-blocking network mode (see worker loop's setTurbo calls)
    int gpuIndex;
    int gpuSegmentFrom;
    int gpuSegmentTo;

    // worker
    NnUint port;             // TCP port: worker listen port, or API server port in dllama-api

    // Parses argv; hasMode selects whether the first positional arg is a mode name.
    static AppCliArgs parse(int argc, char **argv, bool hasMode);
    ~AppCliArgs();
};
|
||||||
|
|
||||||
|
// Control message the root sends to workers before each forward pass
// (see WorkerLlmInference::tryReadControlPacket and RootLlmInference::finish).
typedef struct {
    NnUint position;  // sequence position of the batch being evaluated
    NnUint batchSize; // 0 = stop signal
} LlmControlPacket;
|
||||||
|
|
||||||
|
// Root-node inference driver: callers set the batch size, position and input
// tokens, call forward(), then read logitsPipe. finish() tells workers to stop.
// (Declarations only; implemented elsewhere.)
class RootLlmInference {
public:
    float *logitsPipe;   // output logits of the most recent forward() call
private:
    float *tokenPipe;    // input token ids, stored as floats in the execution pipes
    float *positionPipe; // input positions, stored as floats
    LlmHeader *header;
    NnNetExecution *execution;
    NnExecutor *executor;
    NnNetwork *network;  // assumed nullptr in single-node mode — confirm in implementation
    LlmControlPacket controlPacket;
public:
    RootLlmInference(LlmNet *net, NnNetExecution *execution, NnExecutor *executor, NnNetwork *network);
    void setBatchSize(NnUint batchSize);
    void setPosition(NnUint position);
    void setToken(NnUint batchIndex, NnUint token);
    void forward();
    // Presumably broadcasts a batchSize==0 control packet so workers exit — confirm in implementation.
    void finish();
};
|
||||||
|
|
||||||
|
// Worker-node inference state. The worker main loop polls
// tryReadControlPacket() and runs the executor until isFinished is set
// (root sent a batchSize==0 stop signal).
class WorkerLlmInference {
public:
    bool isFinished; // true once the root's stop signal has been received
private:
    float *positionPipe;
    NnNetExecution *execution;
    NnNetwork *network;
    LlmControlPacket controlPacket;
public:
    WorkerLlmInference(NnNetExecution *execution, NnNetwork *network);
    // Returns false when no control packet was available yet (the worker loop
    // retries; in turbo mode the network is non-blocking).
    bool tryReadControlPacket();
};
|
||||||
|
|
||||||
|
// Everything an inference handler needs, assembled by runInferenceApp().
// network and executor may be consulted for transfer/time statistics.
typedef struct {
    AppCliArgs *args;
    LlmHeader *header;
    RootLlmInference *inference;
    Tokenizer *tokenizer;
    Sampler *sampler;
    NnNetwork *network;   // nullptr when running without workers (see null checks in dllama.cpp)
    NnExecutor *executor;
} AppInferenceContext;
|
||||||
|
|
||||||
|
// Initializes the root node and invokes `handler` with a ready AppInferenceContext.
void runInferenceApp(AppCliArgs *args, void (*handler)(AppInferenceContext *context));
// Runs the worker main loop (serves a root node over TCP) until the root signals stop.
void runWorkerApp(AppCliArgs *args);
|
||||||
|
|
||||||
|
#endif
|
||||||
622
src/dllama-api.cpp
Normal file
622
src/dllama-api.cpp
Normal file
@@ -0,0 +1,622 @@
|
|||||||
|
#include <cstring>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <cassert>
|
||||||
|
#include <sstream>
|
||||||
|
#include <iostream>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <csignal>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <winsock2.h>
|
||||||
|
#include <ws2tcpip.h>
|
||||||
|
#else
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <netinet/in.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "tokenizer.hpp"
|
||||||
|
#include "app.hpp"
|
||||||
|
#include "json.hpp"
|
||||||
|
#include "api-types.hpp"
|
||||||
|
#include "nn/nn-network.hpp"
|
||||||
|
|
||||||
|
typedef unsigned int pos_t;
|
||||||
|
|
||||||
|
using json = nlohmann::json;
|
||||||
|
|
||||||
|
enum class HttpMethod {
|
||||||
|
METHOD_GET = 0,
|
||||||
|
METHOD_POST = 1,
|
||||||
|
METHOD_PUT = 2,
|
||||||
|
METHOD_DELETE = 3,
|
||||||
|
METHOD_OPTIONS = 4,
|
||||||
|
METHOD_UNKNOWN = 5
|
||||||
|
};
|
||||||
|
|
||||||
|
// Minimal blocking HTTP/1.1 request parser and response writer bound to one
// accepted client socket. The whole request (headers, then Content-Length
// body bytes) is read eagerly; helpers emit JSON and chunked SSE responses.
class HttpRequest {
public:
    // Reads and parses one request from `serverSocket`.
    // NOTE(review): a non-empty body is parsed as JSON unconditionally;
    // json::parse throws on malformed input, so callers must handle exceptions.
    static HttpRequest read(int serverSocket) {
        HttpRequest req(serverSocket);

        std::vector<char> httpRequest = req.readHttpRequest();
        // Parse the HTTP request
        std::string data = std::string(httpRequest.begin(), httpRequest.end());

        // Split request into lines
        std::istringstream iss(data);
        std::string line;
        std::getline(iss, line);

        // Parse request line
        std::istringstream lineStream(line);
        std::string methodStr, path;
        lineStream >> methodStr >> path;
        req.method = parseMethod(methodStr);
        req.path = path;

        // Parse headers (assumes the "Key: value" form; pos + 2 skips ": ")
        while (std::getline(iss, line) && line != "\r") {
            size_t pos = line.find(':');
            if (pos != std::string::npos) {
                std::string key = line.substr(0, pos);
                std::string value = line.substr(pos + 2); // Skip ': ' after key
                // Trim whitespace and non-printable characters from header value
                value.erase(std::remove_if(value.begin(), value.end(), [](unsigned char c) {
                    return std::isspace(c) || !std::isprint(c);
                }), value.end());
                req.headers[key] = value;
            }
        }

        // Parse body (everything after the blank line)
        std::getline(iss, req.body, '\0');

        if (req.body.size() > 0) {
            // printf("body: %s\n", req.body.c_str());
            req.parsedJson = json::parse(req.body);
        }
        return req;
    }

    // Maps a method name to the HttpMethod enum; unknown names → METHOD_UNKNOWN.
    static HttpMethod parseMethod(const std::string& method) {
        if (method == "GET") return HttpMethod::METHOD_GET;
        if (method == "POST") return HttpMethod::METHOD_POST;
        if (method == "PUT") return HttpMethod::METHOD_PUT;
        if (method == "DELETE") return HttpMethod::METHOD_DELETE;
        if (method == "OPTIONS") return HttpMethod::METHOD_OPTIONS;
        return HttpMethod::METHOD_UNKNOWN;
    }

private:
    int serverSocket; // the accepted client connection (despite the name)
public:
    std::string path;
    std::unordered_map<std::string, std::string> headers;
    std::string body;
    json parsedJson;  // populated only when the request carried a non-empty body
    HttpMethod method;

    HttpRequest(int serverSocket) {
        this->serverSocket = serverSocket;
    }

    // Reads the raw request bytes: loops on recv() until "\r\n\r\n" is seen,
    // then reads exactly Content-Length body bytes, minus whatever was already
    // received past the header terminator.
    std::vector<char> readHttpRequest() {
        std::string httpRequest;
        char buffer[1024 * 64];
        ssize_t bytesRead;

        // First, read all headers
        std::string headerData;
        size_t headerEnd;
        bool headerDone = false;
        std::string extraReadPastHeader;
        while (!headerDone) {
            bytesRead = recv(serverSocket, buffer, sizeof(buffer) - 1, 0);
            if (bytesRead <= 0) {
                throw std::runtime_error("Error while reading headers from socket");
            }
            // NOTE(review): append(buffer) stops at the first NUL byte, so
            // binary header data would be truncated — acceptable for HTTP text.
            buffer[bytesRead] = '\0';
            headerData.append(buffer);

            // Check for end of headers (http header says "\r\n\r\n")
            headerEnd = headerData.find("\r\n\r\n");
            if (headerEnd != std::string::npos) {
                headerDone = true;
                if (headerEnd < headerData.size()-4) {
                    // We read something past the header
                    extraReadPastHeader = headerData.substr(headerEnd+4);
                }
            }
        }

        httpRequest.append(headerData);

        // Next, find Content-Length header for body length
        std::istringstream headerStream(headerData);
        std::string line;
        ssize_t contentLength = 0;
        while (std::getline(headerStream, line) && line != "\r") {
            size_t pos = line.find(':');
            if (pos != std::string::npos) {
                std::string key = line.substr(0, pos);
                std::string value = line.substr(pos + 2); // Skip ': ' after key
                if (key == "Content-Length") {
                    try {
                        contentLength = std::stoi(value); // stoi ignores any whitespace
                    } catch (const std::invalid_argument& e) {
                        throw std::runtime_error("Bad Content-Length header - not a number");
                    }
                    break;
                }
            }
        }

        // Now read the full content body
        if (contentLength > 0) {
            // If we read any extra past the header before, read that much less now
            // But first, sanity check to make sure Content-Length isn't lying and there is actually more
            if (extraReadPastHeader.size() > static_cast<size_t>(contentLength)) {
                throw std::runtime_error("Received more body data than Content-Length header said");
            }
            contentLength -= extraReadPastHeader.size();

            std::vector<char> body(contentLength);
            ssize_t totalRead = 0;
            while (totalRead < contentLength) {
                bytesRead = recv(serverSocket, body.data() + totalRead, contentLength - totalRead, 0);
                if (bytesRead <= 0) {
                    throw std::runtime_error("Error while reading body from socket");
                }
                totalRead += bytesRead;
            }
            if (body.size() > 0) {
                httpRequest.append(body.data(), contentLength);
            }
        }

        return std::vector<char>(httpRequest.begin(), httpRequest.end());
    }

    // Human-readable method name; used for request logging.
    std::string getMethod() {
        if (method == HttpMethod::METHOD_GET) return "GET";
        if (method == HttpMethod::METHOD_POST) return "POST";
        if (method == HttpMethod::METHOD_PUT) return "PUT";
        if (method == HttpMethod::METHOD_DELETE) return "DELETE";
        if (method == HttpMethod::METHOD_OPTIONS) return "OPTIONS";
        return "UNKNOWN";
    }

    // 204 CORS-preflight response allowing any origin.
    void writeCors() {
        std::ostringstream buffer;
        buffer << "HTTP/1.1 204 No Content\r\n"
               << "Access-Control-Allow-Origin: *\r\n"
               << "Access-Control-Allow-Methods: GET, POST, PUT, DELETE\r\n"
               << "Access-Control-Allow-Headers: Content-Type, Authorization\r\n"
               << "Connection: close\r\n"
               << "\r\n";
        std::string data = buffer.str();
        writeSocket(serverSocket, data.c_str(), data.size());
    }

    // Plain 404 response.
    void writeNotFound() {
        std::ostringstream buffer;
        buffer << "HTTP/1.1 404 Not Found\r\n"
               << "Connection: close\r\n"
               << "Content-Length: 9\r\n"
               << "\r\n"
               << "Not Found";
        std::string data = buffer.str();
        writeSocket(serverSocket, data.c_str(), data.size());
    }

    // 200 response carrying a JSON payload.
    void writeJson(std::string json) {
        std::ostringstream buffer;
        buffer << "HTTP/1.1 200 OK\r\n"
               << "Access-Control-Allow-Origin: *\r\n"
               << "Content-Type: application/json; charset=utf-8\r\n"
               << "Connection: close\r\n"
               << "Content-Length: " << json.length() << "\r\n\r\n" << json;
        std::string data = buffer.str();
        writeSocket(serverSocket, data.c_str(), data.size());
    }

    // Response head for a chunked SSE stream; follow with writeStreamChunk().
    void writeStreamStartChunk() {
        std::ostringstream buffer;
        buffer << "HTTP/1.1 200 OK\r\n"
               << "Access-Control-Allow-Origin: *\r\n"
               << "Content-Type: text/event-stream; charset=utf-8\r\n"
               << "Connection: close\r\n"
               << "Transfer-Encoding: chunked\r\n\r\n";
        std::string data = buffer.str();
        writeSocket(serverSocket, data.c_str(), data.size());
    }

    // One HTTP/1.1 chunk: "<hex size>\r\n<data>\r\n".
    void writeStreamChunk(const std::string data) {
        std::ostringstream buffer;
        buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
        std::string d = buffer.str();
        writeSocket(serverSocket, d.c_str(), d.size());
    }

    // Zero-length chunk terminating the chunked stream.
    void writeStreamEndChunk() {
        const char *endChunk = "0000\r\n\r\n";
        writeSocket(serverSocket, endChunk, strlen(endChunk));
    }
};
|
||||||
|
|
||||||
|
// One HTTP route: `handler` is invoked when both method and path match exactly.
struct Route {
    std::string path;
    HttpMethod method;
    std::function<void(HttpRequest&)> handler;
};
|
||||||
|
|
||||||
|
class Router {
|
||||||
|
public:
|
||||||
|
static void resolve(HttpRequest& request, std::vector<Route>& routes) {
|
||||||
|
if (request.method == HttpMethod::METHOD_OPTIONS) {
|
||||||
|
request.writeCors();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const auto& route : routes) {
|
||||||
|
if (request.method == route.method && request.path == route.path) {
|
||||||
|
route.handler(request);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request.writeNotFound();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Writes one SSE event of an OpenAI-style chat.completion.chunk response.
// When `stop` is true the chunk carries finish_reason="stop" (no delta) and is
// followed by the "data: [DONE]" sentinel plus the chunked-encoding terminator.
void writeChatCompletionChunk(HttpRequest &request, const std::string &delta, const bool stop){
    ChunkChoice choice;
    if (stop) {
        choice.finish_reason = "stop";
    } else {
        choice.delta = ChatMessageDelta("assistant", delta);
    }
    ChatCompletionChunk chunk = ChatCompletionChunk(choice);

    std::ostringstream buffer;
    buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n";
    request.writeStreamChunk(buffer.str());

    if (stop) {
        request.writeStreamChunk("data: [DONE]");
        request.writeStreamEndChunk();
    }
}
|
||||||
|
|
||||||
|
class NaiveCacheItem {
|
||||||
|
public:
|
||||||
|
pos_t endPos;
|
||||||
|
ChatMessage message;
|
||||||
|
NaiveCacheItem(pos_t endPos, ChatMessage message) {
|
||||||
|
this->endPos = endPos;
|
||||||
|
this->message = message;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// "Naive" conversation cache: remembers the messages of previous turns and the
// sequence position they end at, so that a follow-up request which extends the
// same conversation only needs to evaluate the new (delta) messages.
class NaiveCache {
private:
    std::vector<NaiveCacheItem> cache;
public:
    void push(NaiveCacheItem item) {
        cache.push_back(item);
    }

    void clear() {
        cache.clear();
    }

    // If `messages` strictly extends the cached conversation (every cached
    // message matches by role and content, and there is at least one new
    // message), erases the cached prefix from `messages`, sets `startPos` to
    // the position where the cached prefix ended, and returns true.
    // On any mismatch (or when nothing new was sent) the cache is invalidated.
    bool resolveDeltaPrompt(std::vector<ChatMessage>& messages, pos_t& startPos) {
        size_t cacheSize = cache.size();
        if (cacheSize == 0)
            return false;
        if (messages.size() > cacheSize) {
            size_t i = 0;
            while (i < cacheSize) {
                if (
                    cache[i].message.role != messages[i].role ||
                    cache[i].message.content != messages[i].content
                ) break;
                i++;
            }
            if (i == cacheSize) {
                // Full prefix match: resume after the last cached message.
                startPos = cache[i - 1].endPos;
                messages.erase(messages.begin(), messages.begin() + i);
                printf("🐤 Found naive cache for %zu messages, pos=%d\n", i, startPos);
                return true;
            }
        }
        cache.clear();
        return false;
    }
};
|
||||||
|
|
||||||
|
class ApiServer {
|
||||||
|
private:
|
||||||
|
RootLlmInference *inference;
|
||||||
|
Tokenizer *tokenizer;
|
||||||
|
Sampler *sampler;
|
||||||
|
AppCliArgs *args;
|
||||||
|
LlmHeader *header;
|
||||||
|
EosDetector *eosDetector;
|
||||||
|
ChatTemplateGenerator *templateGenerator;
|
||||||
|
NaiveCache naiveCache;
|
||||||
|
|
||||||
|
public:
|
||||||
|
ApiServer(RootLlmInference *inference, Tokenizer *tokenizer, Sampler *sampler, AppCliArgs *args, LlmHeader *header, EosDetector *eosDetector, ChatTemplateGenerator *templateGenerator) {
|
||||||
|
this->inference = inference;
|
||||||
|
this->tokenizer = tokenizer;
|
||||||
|
this->sampler = sampler;
|
||||||
|
this->args = args;
|
||||||
|
this->header = header;
|
||||||
|
this->eosDetector = eosDetector;
|
||||||
|
this->templateGenerator = templateGenerator;
|
||||||
|
}
|
||||||
|
|
||||||
|
void complete(HttpRequest& request) {
|
||||||
|
InferenceParams params = parseRequest(request);
|
||||||
|
|
||||||
|
pos_t startPos = 0;
|
||||||
|
std::vector<ChatMessage> deltaPrompt = params.messages;
|
||||||
|
naiveCache.resolveDeltaPrompt(deltaPrompt, startPos);
|
||||||
|
|
||||||
|
size_t nInputItems = deltaPrompt.size();
|
||||||
|
std::unique_ptr<ChatItem[]> inputItemsPtr(new ChatItem[nInputItems]);
|
||||||
|
ChatItem *inputItems = inputItemsPtr.get();
|
||||||
|
for (size_t i = 0; i < nInputItems; i++) {
|
||||||
|
inputItems[i].role = deltaPrompt[i].role;
|
||||||
|
inputItems[i].message = deltaPrompt[i].content;
|
||||||
|
}
|
||||||
|
|
||||||
|
GeneratedChat inputPrompt = templateGenerator->generate(nInputItems, inputItems, true);
|
||||||
|
printf("🔹%s🔸", inputPrompt.content);
|
||||||
|
|
||||||
|
int nPromptTokens;
|
||||||
|
std::unique_ptr<int[]> promptTokensPtr(new int[inputPrompt.length + 2]);
|
||||||
|
int *promptTokens = promptTokensPtr.get();
|
||||||
|
bool isStart = startPos == 0;
|
||||||
|
tokenizer->encode((char*)inputPrompt.content, promptTokens, &nPromptTokens, isStart, true);
|
||||||
|
|
||||||
|
pos_t promptEndPos = startPos + nPromptTokens - 1;
|
||||||
|
if (promptEndPos > header->seqLen)
|
||||||
|
promptEndPos = header->seqLen;
|
||||||
|
|
||||||
|
pos_t maxPredPos = params.max_tokens > 0 ? (promptEndPos + params.max_tokens) : header->seqLen;
|
||||||
|
if (maxPredPos > header->seqLen)
|
||||||
|
maxPredPos = header->seqLen;
|
||||||
|
|
||||||
|
for (size_t j = 0; j < deltaPrompt.size(); j++) {
|
||||||
|
naiveCache.push(NaiveCacheItem(promptEndPos, deltaPrompt[j]));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string buffer;
|
||||||
|
|
||||||
|
if (params.stream)
|
||||||
|
request.writeStreamStartChunk();
|
||||||
|
if (inputPrompt.publicPrompt != nullptr) {
|
||||||
|
if (params.stream)
|
||||||
|
writeChatCompletionChunk(request, inputPrompt.publicPrompt, false);
|
||||||
|
buffer += inputPrompt.publicPrompt;
|
||||||
|
}
|
||||||
|
|
||||||
|
NnUint pos = startPos;
|
||||||
|
int token;
|
||||||
|
for (NnUint i = 0; ;) {
|
||||||
|
long remainingTokens = promptEndPos - pos;
|
||||||
|
if (remainingTokens <= 0)
|
||||||
|
break;
|
||||||
|
|
||||||
|
NnUint batchSize = remainingTokens < args->nBatches
|
||||||
|
? remainingTokens
|
||||||
|
: args->nBatches;
|
||||||
|
|
||||||
|
inference->setBatchSize(batchSize);
|
||||||
|
inference->setPosition(pos);
|
||||||
|
for (NnUint j = 0; j < batchSize; j++)
|
||||||
|
inference->setToken(j, promptTokens[i + j]);
|
||||||
|
|
||||||
|
inference->forward();
|
||||||
|
|
||||||
|
i += batchSize;
|
||||||
|
pos += batchSize;
|
||||||
|
token = promptTokens[i + 1];
|
||||||
|
}
|
||||||
|
|
||||||
|
inference->setBatchSize(1);
|
||||||
|
tokenizer->resetDecoder();
|
||||||
|
eosDetector->reset();
|
||||||
|
|
||||||
|
for (; pos < maxPredPos;) {
|
||||||
|
inference->setPosition(pos);
|
||||||
|
inference->setToken(0, token);
|
||||||
|
inference->forward();
|
||||||
|
|
||||||
|
token = sampler->sample(inference->logitsPipe);
|
||||||
|
|
||||||
|
char *piece = tokenizer->decode(token);
|
||||||
|
EosDetectorType eosType = eosDetector->append(token, piece);
|
||||||
|
|
||||||
|
if (piece != nullptr) {
|
||||||
|
printf("%s", piece);
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (eosType == NOT_EOS || eosType == EOS) {
|
||||||
|
char *delta = eosDetector->getDelta();
|
||||||
|
if (delta != nullptr) {
|
||||||
|
std::string deltaStr(delta);
|
||||||
|
if (params.stream)
|
||||||
|
writeChatCompletionChunk(request, deltaStr, false);
|
||||||
|
buffer += deltaStr;
|
||||||
|
}
|
||||||
|
eosDetector->reset();
|
||||||
|
}
|
||||||
|
pos++;
|
||||||
|
if (eosType == EOS) break;
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatMessage chatMessage("assistant", buffer);
|
||||||
|
if (pos == header->seqLen) {
|
||||||
|
naiveCache.clear();
|
||||||
|
} else {
|
||||||
|
naiveCache.push(NaiveCacheItem(pos, chatMessage));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.stream) {
|
||||||
|
writeChatCompletionChunk(request, "", true);
|
||||||
|
} else {
|
||||||
|
int nCompletionTokens = pos - promptEndPos;
|
||||||
|
ChatUsage usage(nPromptTokens, nCompletionTokens, nPromptTokens + nCompletionTokens);
|
||||||
|
Choice choice(chatMessage);
|
||||||
|
ChatCompletion completion(choice, usage);
|
||||||
|
std::string chatJson = ((json)completion).dump();
|
||||||
|
request.writeJson(chatJson);
|
||||||
|
}
|
||||||
|
printf("🔶\n");
|
||||||
|
fflush(stdout);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
InferenceParams parseRequest(HttpRequest& request) {
|
||||||
|
InferenceParams params;
|
||||||
|
params.temperature = args->temperature;
|
||||||
|
params.top_p = args->topp;
|
||||||
|
params.seed = args->seed;
|
||||||
|
params.stream = false;
|
||||||
|
params.messages = parseChatMessages(request.parsedJson["messages"]);
|
||||||
|
params.max_tokens = -1;
|
||||||
|
|
||||||
|
if (request.parsedJson.contains("stream")) {
|
||||||
|
params.stream = request.parsedJson["stream"].get<bool>();
|
||||||
|
}
|
||||||
|
if (request.parsedJson.contains("temperature")) {
|
||||||
|
params.temperature = request.parsedJson["temperature"].template get<float>();
|
||||||
|
}
|
||||||
|
if (request.parsedJson.contains("seed")) {
|
||||||
|
params.seed = request.parsedJson["seed"].template get<unsigned long long>();
|
||||||
|
sampler->setSeed(params.seed);
|
||||||
|
}
|
||||||
|
if (request.parsedJson.contains("max_tokens")) {
|
||||||
|
params.max_tokens = request.parsedJson["max_tokens"].template get<int>();
|
||||||
|
}
|
||||||
|
if (request.parsedJson.contains("stop")) {
|
||||||
|
params.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
|
||||||
|
} else {
|
||||||
|
const std::string defaultStop = "<|eot_id|>";
|
||||||
|
params.stop = std::vector<std::string>{defaultStop};
|
||||||
|
}
|
||||||
|
return params;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Route handler: delegates POST /v1/chat/completions to the ApiServer instance.
void handleCompletionsRequest(HttpRequest& request, ApiServer *api) {
    api->complete(request);
}
|
||||||
|
|
||||||
|
// Route handler for GET /v1/models: reports the loaded model's file name
// (the path with any directory prefix stripped) as the only available model.
void handleModelsRequest(HttpRequest& request, const char* modelPath) {
    const std::string fullPath(modelPath);
    const size_t lastSeparator = fullPath.find_last_of("/\\");
    std::string modelName;
    if (lastSeparator == std::string::npos)
        modelName = fullPath;
    else
        modelName = fullPath.substr(lastSeparator + 1);

    Model model(modelName);
    ModelList list(model);
    request.writeJson(((json)list).dump());
}
|
||||||
|
|
||||||
|
// HTTP accept loop for the OpenAI-compatible API. Serves one request per
// connection; each connection is handled synchronously on this thread.
static void server(AppInferenceContext *context) {
    int serverSocket = createServerSocket(context->args->port);

    TokenizerChatStops stops(context->tokenizer);
    ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
    EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);
    ApiServer api(context->inference, context->tokenizer, context->sampler, context->args, context->header, &eosDetector, &templateGenerator);

    printf("Server URL: http://127.0.0.1:%d/v1/\n", context->args->port);

    std::vector<Route> routes = {
        {
            "/v1/chat/completions",
            HttpMethod::METHOD_POST,
            std::bind(&handleCompletionsRequest, std::placeholders::_1, &api)
        },
        {
            "/v1/models",
            HttpMethod::METHOD_GET,
            std::bind(&handleModelsRequest, std::placeholders::_1, context->args->modelPath)
        }
    };

    while (true) {
        int clientSocket = -1; // -1 = not yet accepted
        try {
            clientSocket = acceptSocket(serverSocket);
            HttpRequest request = HttpRequest::read(clientSocket);
            printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
            Router::resolve(request, routes);
            destroySocket(clientSocket);
        } catch (NnReadNetworkException& ex) {
            printf("Read socket error: %d %s\n", ex.code, ex.message);
            if (clientSocket >= 0) destroySocket(clientSocket);
        } catch (NnWriteNetworkException& ex) {
            printf("Write socket error: %d %s\n", ex.code, ex.message);
            if (clientSocket >= 0) destroySocket(clientSocket);
        } catch (std::exception& ex) {
            // BUG FIX: HttpRequest::read() throws std::runtime_error on a
            // malformed request and json::parse() throws on invalid JSON.
            // Previously these exceptions escaped the loop — killing the whole
            // server and leaking the client socket — and the socket also
            // leaked on the network-exception paths above.
            printf("Request error: %s\n", ex.what());
            if (clientSocket >= 0) destroySocket(clientSocket);
        }
    }

    destroySocket(serverSocket);
}
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define EXECUTABLE_NAME "dllama-api.exe"
|
||||||
|
#else
|
||||||
|
#define EXECUTABLE_NAME "dllama-api"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Prints CLI usage for dllama-api to stderr and flushes it so the text is
// visible even if the process exits immediately afterwards.
void usage() {
    fprintf(stderr, "Usage: %s {--model <path>} {--tokenizer <path>} [--port <p>]\n", EXECUTABLE_NAME);
    fprintf(stderr, " [--buffer-float-type {f32|f16|q40|q80}]\n");
    fprintf(stderr, " [--weights-float-type {f32|f16|q40|q80}]\n");
    fprintf(stderr, " [--max-seq-len <max>]\n");
    fprintf(stderr, " [--nthreads <n>]\n");
    fprintf(stderr, " [--workers <ip:port> ...]\n");
    fprintf(stderr, " [--temperature <temp>]\n");
    fprintf(stderr, " [--topp <t>]\n");
    fprintf(stderr, " [--seed <s>]\n");
    fprintf(stderr, "Example:\n");
    fprintf(stderr, " sudo nice -n -20 ./dllama-api --port 9990 --nthreads 4 \\\n");
    fprintf(stderr, " --model dllama_model_llama3_2_3b_instruct_q40.m \\\n");
    fprintf(stderr, " --tokenizer dllama_tokenizer_llama3_2_3b_instruct_q40.t \\\n");
    fprintf(stderr, " --buffer-float-type q80 --max-seq-len 8192 \\\n");
    fprintf(stderr, " --workers 10.0.0.2:9998 10.0.0.3:9998 10.0.0.4:9998\n");
    fflush(stderr);
}
|
||||||
|
|
||||||
|
int main(int argc, char *argv[]) {
|
||||||
|
#ifdef SIGPIPE
|
||||||
|
std::signal(SIGPIPE, SIG_IGN);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
initQuants();
|
||||||
|
initSockets();
|
||||||
|
|
||||||
|
int returnCode = EXIT_SUCCESS;
|
||||||
|
try {
|
||||||
|
AppCliArgs args = AppCliArgs::parse(argc, argv, false);
|
||||||
|
if (args.help) {
|
||||||
|
usage();
|
||||||
|
} else {
|
||||||
|
runInferenceApp(&args, server);
|
||||||
|
}
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
printf("🚨 Critical error: %s\n", e.what());
|
||||||
|
returnCode = EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupSockets();
|
||||||
|
return returnCode;
|
||||||
|
}
|
||||||
285
src/dllama.cpp
Normal file
285
src/dllama.cpp
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
#include "nn/nn-core.hpp"
|
||||||
|
#include "nn/nn-config-builder.hpp"
|
||||||
|
#include "nn/nn-cpu.hpp"
|
||||||
|
#include "nn/nn-cpu-ops.hpp"
|
||||||
|
#include "nn/nn-network.hpp"
|
||||||
|
#include "nn/nn-executor.hpp"
|
||||||
|
#include "llm.hpp"
|
||||||
|
#include "tokenizer.hpp"
|
||||||
|
#include "app.hpp"
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
// CLI inference mode: evaluates the prompt in batches, then generates tokens
// one by one, printing per-step timing/network stats and a final summary.
static void inference(AppInferenceContext *context) {
    if (context->args->prompt == nullptr)
        throw std::runtime_error("Prompt is required");
    if (context->args->steps == 0)
        throw std::runtime_error("Number of steps is required");

    // +3 gives headroom for BOS/EOS markers added by the tokenizer.
    std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
    int *inputTokens = inputTokensVec.data();

    NnUint pos = 0;
    int nInputTokens;
    context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, true);

    if (nInputTokens > context->header->seqLen)
        throw std::runtime_error("The number of prompt tokens is greater than the sequence length");
    if (nInputTokens > context->args->steps)
        throw std::runtime_error("The number of prompt tokens is greater than the number of steps");

    NnSize sentBytes = 0;
    NnSize recvBytes = 0;
    NnUint evalTotalTime = 0;
    NnUint predTotalTime = 0;

    int token = inputTokens[pos];
    printf("%s\n", context->args->prompt);

    // Prompt evaluation: feed tokens 0..nInputTokens-2 in batches; the last
    // prompt token is fed at the start of the prediction loop below.
    for (;;) {
        long remainingTokens = nInputTokens - 1 - (long)pos;
        if (remainingTokens <= 0)
            break;
        NnUint batchSize = remainingTokens < context->args->nBatches
            ? remainingTokens
            : context->args->nBatches;

        context->inference->setBatchSize(batchSize);
        context->inference->setPosition(pos);
        for (NnUint i = 0; i < batchSize; i++)
            context->inference->setToken(i, inputTokens[pos + i]);

        context->inference->forward();

        pos += batchSize;
        // BUG FIX: the next token to feed is inputTokens[pos] (the first not
        // yet evaluated). The original read inputTokens[pos + 1], which on the
        // final iteration indexed one past the last prompt token and started
        // generation from an unused slack element of the buffer.
        token = inputTokens[pos];

        if (context->network != nullptr)
            context->network->getStats(&sentBytes, &recvBytes);

        NnUint evalTime = context->executor->getTotalTime(STEP_EXECUTE_OP);
        NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
        printf("🔷️ Eval%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | (%d tokens)\n",
            evalTime / 1000,
            syncTime / 1000,
            sentBytes / 1024,
            recvBytes / 1024,
            batchSize);
        evalTotalTime += evalTime + syncTime;
    }

    fflush(stdout);

    context->inference->setBatchSize(1);
    context->tokenizer->resetDecoder();

    // Prediction: one token per position until steps or seqLen is reached.
    const NnUint maxPos = std::min(context->header->seqLen, context->args->steps);
    for (; pos < maxPos; pos++) {
        context->inference->setPosition(pos);
        context->inference->setToken(0, token);
        context->inference->forward();

        token = context->sampler->sample(context->inference->logitsPipe);

        char *piece = context->tokenizer->decode(token);

        if (context->network != nullptr)
            context->network->getStats(&sentBytes, &recvBytes);

        NnUint predTime = context->executor->getTotalTime(STEP_EXECUTE_OP);
        NnUint syncTime = context->executor->getTotalTime(STEP_SYNC_NODES);
        printf("🔶 Pred%5u ms Sync%5u ms | Sent%6zu kB Recv%6zu kB | %s\n",
            predTime / 1000,
            syncTime / 1000,
            sentBytes / 1024,
            recvBytes / 1024,
            piece == nullptr ? "~" : piece);
        fflush(stdout);
        predTotalTime += predTime + syncTime;
    }

    NnUint nEvalTokens = nInputTokens - 1;
    NnUint nPredTokens = pos - nEvalTokens;
    float evalTotalTimeMs = evalTotalTime / 1000.0;
    float predTotalTimeMs = predTotalTime / 1000.0;
    printf("\n");
    printf("Evaluation\n");
    printf("   nBatches: %d\n", context->args->nBatches);
    printf("    nTokens: %d\n", nEvalTokens);
    printf("   tokens/s: %3.2f (%3.2f ms/tok)\n",
        (nEvalTokens * 1000) / evalTotalTimeMs,
        evalTotalTimeMs / ((float) nEvalTokens));
    printf("Prediction\n");
    printf("    nTokens: %d\n", nPredTokens);
    printf("   tokens/s: %3.2f (%3.2f ms/tok)\n",
        (nPredTokens * 1000) / predTotalTimeMs,
        predTotalTimeMs / ((float) nPredTokens));
}
|
||||||
|
|
||||||
|
// Prints `guide` as a prompt, reads one line from stdin into `buffer`
// (at most size-1 chars), strips the trailing newline, and returns the
// resulting length (0 on EOF/read failure).
static NnUint readStdin(const char *guide, char *buffer, NnUint size) {
    // NOTE(review): fflush() on an INPUT stream is undefined behavior per the
    // C standard; it discards pending input on some platforms (Windows, glibc)
    // but is a no-op or UB elsewhere — confirm this is the intended effect.
    std::fflush(stdin);
    std::printf("%s", guide);
    if (std::fgets(buffer, size, stdin) != NULL) {
        NnUint length = std::strlen(buffer);
        if (length > 0 && buffer[length - 1] == '\n') {
            // Drop the newline fgets keeps.
            buffer[length - 1] = '\0';
            length--;
        }
        return length;
    }
    return 0;
}
|
||||||
|
|
||||||
|
// Perplexity mode: feeds the prompt token by token and, at each step, measures
// the probability the model assigns to the actual next token. Reports
// perplexity = exp(-mean log prob), the mean log prob, and bits per token.
static void perplexity(AppInferenceContext *context) {
    if (context->args->prompt == nullptr)
        throw std::runtime_error("Prompt is required");

    // +3 gives headroom for BOS/EOS markers added by the tokenizer.
    std::vector<int> inputTokensVec(std::strlen(context->args->prompt) + 3);
    int *inputTokens = inputTokensVec.data();

    int nInputTokens;
    context->tokenizer->encode(context->args->prompt, inputTokens, &nInputTokens, true, true);

    // BUG FIX: with fewer than 2 tokens there is no "next token" to score;
    // previously `pos < nInputTokens - 1` mixed the unsigned NnUint `pos`
    // with a possibly non-positive int bound, and the final division by
    // (nInputTokens - 1) could divide by zero.
    if (nInputTokens < 2)
        throw std::runtime_error("The prompt must encode to at least two tokens");

    printf("Evaluating %d tokens...\n", nInputTokens);

    float totalLogProb = 0.0f;
    NnUint pos = 0;

    context->inference->setBatchSize(1);

    for (pos = 0; pos < nInputTokens - 1; pos++) {
        context->inference->setPosition(pos);
        context->inference->setToken(0, inputTokens[pos]);
        context->inference->forward();

        // Normalize logits into probabilities in place.
        float *logits = context->inference->logitsPipe;
        softmax_F32(logits, context->header->vocabSize);

        int targetToken = inputTokens[pos + 1];
        float prob = logits[targetToken];

        // Clamp to avoid log(0) for tokens the model deems impossible.
        totalLogProb += std::log(std::max(prob, 1e-30f));
        printf("%5d / %d, prob=%f\n", pos + 1, nInputTokens - 1, prob);
    }

    float avgLogProb = totalLogProb / (float)(nInputTokens - 1);
    float perplexity = expf(-avgLogProb);

    printf("\n");
    printf("Results\n");
    printf("   perplexity: %f (lower = better)\n", perplexity);
    printf("   avgLogProb: %f\n", avgLogProb);
    printf("  bitPerToken: %f\n", -avgLogProb / std::log(2.0));
}
|
||||||
|
|
||||||
|
// Interactive chat loop: reads an optional system prompt and then alternates
// between reading a user message, feeding the templated prompt through the
// model in batches, and sampling/streaming the assistant reply until an EOS
// stop sequence fires or the context window (seqLen) is exhausted.
static void chat(AppInferenceContext *context) {
    const NnUint seqLen = context->header->seqLen;
    char prompt[2048];

    TokenizerChatStops stops(context->tokenizer);
    ChatTemplateGenerator templateGenerator(context->args->chatTemplateType, context->tokenizer->chatTemplate, stops.stops[0]);
    EosDetector eosDetector(stops.nStops, context->tokenizer->eosTokenIds.data(), stops.stops, stops.maxStopLength, stops.maxStopLength);

    const NnUint sysPromptLength = readStdin("💻 System prompt (optional): ", prompt, sizeof(prompt));
    // Items added since the last template generation; cleared each round.
    std::vector<ChatItem> deltaItems;
    if (sysPromptLength > 0)
        deltaItems.push_back(ChatItem{"system", prompt});

    NnUint pos = 0;          // absolute position in the KV cache / context
    NnUint userPromptLength;
    int token;               // next token to feed during generation
    int nInputTokens;
    do {
        // Re-prompt until the user types something non-empty.
        do {
            userPromptLength = readStdin("\n👱 User\n> ", prompt, sizeof(prompt));
        } while (userPromptLength == 0);

        deltaItems.push_back(ChatItem{"user", prompt});

        GeneratedChat inputPrompt = templateGenerator.generate(deltaItems.size(), deltaItems.data(), true);
        // +2 headroom for tokens the tokenizer may append beyond the raw text.
        std::unique_ptr<int[]> inputTokensPtr(new int[inputPrompt.length + 2]);
        int *inputTokens = inputTokensPtr.get();

        // BOS is only added on the very first round of the conversation.
        bool isStart = pos == 0;
        context->tokenizer->encode((char*)inputPrompt.content, inputTokens, &nInputTokens, isStart, true);

        // Feed all prompt tokens except the last one; the last token is the
        // seed for generation below. Capped at seqLen.
        NnUint userPromptEndPos = (NnUint)std::min<unsigned int>(seqLen, pos + nInputTokens - 1);
        for (NnUint i = 0; ;) {
            int remainingTokens = userPromptEndPos - pos;
            if (remainingTokens <= 0)
                break;
            NnUint batchSize = remainingTokens < context->args->nBatches
                ? remainingTokens
                : context->args->nBatches;

            context->inference->setBatchSize(batchSize);
            context->inference->setPosition(pos);
            for (NnUint j = 0; j < batchSize; j++)
                context->inference->setToken(j, inputTokens[i + j]);

            context->inference->forward();

            i += batchSize;
            pos += batchSize;
            // NOTE(review): after the final batch i == nInputTokens - 1, so this
            // reads inputTokens[nInputTokens] — one past the encoded tokens
            // (still inside the +2 allocation). The un-forwarded seed token
            // would be inputTokens[i]; verify this isn't an off-by-one.
            token = inputTokens[i + 1];
        }

        // Generation runs one token at a time.
        context->inference->setBatchSize(1);
        context->tokenizer->resetDecoder();

        printf("\n🤖 Assistant\n");
        if (inputPrompt.publicPrompt != nullptr)
            printf("%s", inputPrompt.publicPrompt);

        while (pos < seqLen) {
            context->inference->setPosition(pos);
            context->inference->setToken(0, token);
            context->inference->forward();

            token = context->sampler->sample(context->inference->logitsPipe);

            char *piece = context->tokenizer->decode(token);
            // The detector buffers pieces that might be a prefix of a stop
            // sequence (MAYBE_EOS); only flush text on a definite verdict.
            EosDetectorType eosType = eosDetector.append(token, piece);
            if (eosType == NOT_EOS || eosType == EOS) {
                char *delta = eosDetector.getDelta();
                if (delta != nullptr) {
                    printf("%s", delta);
                    fflush(stdout);
                }
                eosDetector.reset();
            }
            pos++;
            if (eosType == EOS) break;
        }

        // The template generator is stateful across rounds; only the new
        // items are passed each time, so clear them after use.
        deltaItems.clear();
    } while (pos < seqLen);

    printf("(end of context)\n");
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv) {
|
||||||
|
initQuants();
|
||||||
|
initSockets();
|
||||||
|
|
||||||
|
int returnCode = EXIT_SUCCESS;
|
||||||
|
try {
|
||||||
|
AppCliArgs args = AppCliArgs::parse(argc, argv, true);
|
||||||
|
if (std::strcmp(args.mode, "inference") == 0) {
|
||||||
|
args.benchmark = true;
|
||||||
|
runInferenceApp(&args, &inference);
|
||||||
|
} else if (std::strcmp(args.mode, "perplexity") == 0)
|
||||||
|
runInferenceApp(&args, &perplexity);
|
||||||
|
else if (std::strcmp(args.mode, "chat") == 0)
|
||||||
|
runInferenceApp(&args, &chat);
|
||||||
|
else if (std::strcmp(args.mode, "worker") == 0)
|
||||||
|
runWorkerApp(&args);
|
||||||
|
else
|
||||||
|
throw std::runtime_error("Unsupported mode");
|
||||||
|
} catch (std::exception &e) {
|
||||||
|
printf("🚨 Critical error: %s\n", e.what());
|
||||||
|
returnCode = EXIT_FAILURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
cleanupSockets();
|
||||||
|
return returnCode;
|
||||||
|
}
|
||||||
24765
src/json.hpp
Normal file
24765
src/json.hpp
Normal file
File diff suppressed because it is too large
Load Diff
669
src/llm.cpp
Normal file
669
src/llm.cpp
Normal file
@@ -0,0 +1,669 @@
|
|||||||
|
#include "nn/nn-core.hpp"
|
||||||
|
#include "nn/nn-config-builder.hpp"
|
||||||
|
#include "nn/nn-cpu.hpp"
|
||||||
|
#include "nn/nn-network.hpp"
|
||||||
|
#include "mmap.hpp"
|
||||||
|
#include "llm.hpp"
|
||||||
|
#include <cerrno>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
// Maps a hidden-activation enum value to its display name.
// Throws std::runtime_error for values this build does not support.
static const char *hiddenActToString(LlmHiddenAct act) {
    switch (act) {
        case HIDDEN_ACT_GELU: return "Gelu";
        case HIDDEN_ACT_SILU: return "Silu";
        default: throw std::runtime_error("Unsupported hidden act");
    }
}
|
||||||
|
|
||||||
|
// Maps a RoPE variant enum value to its display name.
// Throws std::runtime_error for values this build does not support.
static const char *ropeTypeToString(NnRopeType type) {
    switch (type) {
        case ROPE_LLAMA: return "Llama";
        case ROPE_LLAMA3_1: return "Llama3.1";
        case ROPE_FALCON: return "Falcon";
        default: throw std::runtime_error("Unsupported rope type");
    }
}
|
||||||
|
|
||||||
|
// Maps a model architecture enum value to its display name.
// Throws std::runtime_error for values this build does not support.
static const char *archTypeToString(LlmArchType type) {
    switch (type) {
        case LLAMA: return "Llama";
        case QWEN3: return "Qwen3";
        case QWEN3_MOE: return "Qwen3 MoE";
        default: throw std::runtime_error("Unsupported architecture");
    }
}
|
||||||
|
|
||||||
|
// Converts the norm-epsilon exponent stored in the model header (the header
// stores only the negated decimal exponent) into its float value.
// Throws std::runtime_error for exponents this build does not support.
static float convertNormEpsilon(int value) {
    switch (value) {
        case 5: return 1e-05f;
        case 6: return 1e-06f;
        default: throw std::runtime_error("Unsupported norm epsilon");
    }
}
|
||||||
|
|
||||||
|
// Reads and validates the model-file header: checks the magic number,
// parses the key/value table into an LlmHeader, applies the maxSeqLen cap,
// and derives headDim/qDim/kvDim and the file size. `syncType` is the float
// type used for inter-node synchronization and is stored on the header.
// Throws std::runtime_error on any open/read/format error.
LlmHeader loadLlmHeader(const char *path, const NnUint maxSeqLen, NnFloatType syncType) {
    LlmHeader header;
    std::memset(&header, 0, sizeof(LlmHeader));
    // Defaults for keys older model files may omit.
    header.weightType = F_UNK;
    header.hiddenAct = HIDDEN_ACT_SILU;
    header.ropeType = ROPE_LLAMA;
    header.ropeTheta = 10000.0f;
    header.ropeScalingFactor = 1.0f;
    header.normEpsilon = 1e-5f;
    header.moeHiddenDim = 0u;

    // unique_ptr with fclose as deleter closes the file on every exit path.
    std::unique_ptr<FILE, int(*)(FILE *)> fdPtr(fopen(path, "rb"), fclose);
    FILE *fd = fdPtr.get();
    if (fd == NULL)
        throw std::runtime_error(std::string("Cannot open model file (") + path + std::string("): ") + std::strerror(errno));

    int magic;
    if (fread(&magic, sizeof(int), 1, fd) != 1)
        throw std::runtime_error("Cannot read magic value");

    // 0xABCD00/0xABCD01 are magics of the legacy converter format.
    if (magic == 0xABCD00 || magic == 0xABCD01)
        throw std::runtime_error("Old model format is not supported");
    if (magic != 0xA00ABCD)
        throw std::runtime_error("Unsupported magic number");

    if (fread(&header.headerSize, sizeof(int), 1, fd) != 1)
        throw std::runtime_error("Cannot read header size");

    // NOTE(review): headerSize ints are allocated although headerSize is a
    // byte count (4x headroom), and headerSize bytes are read even though the
    // magic + size fields (8 bytes, see nKv below) were already consumed —
    // presumably intentional slack; verify against the converter's layout.
    std::vector<int> bufferPtr(header.headerSize);
    int *buffer = &bufferPtr[0];
    if (fread(buffer, header.headerSize, 1, fd) != 1)
        throw std::runtime_error("Cannot read header values");

    // headerSize counts the magic and size fields themselves; the rest is
    // a flat list of (key, value) int pairs.
    int nKv = (header.headerSize - 2 * sizeof(int)) / sizeof(int);

    for (int i = 0; i < nKv; i += 2) {
        int key = buffer[i];
        int value = buffer[i + 1];
        if (key == VERSION) header.version = value;
        else if (key == ARCH_TYPE) header.archType = (LlmArchType)value;
        else if (key == DIM) header.dim = value;
        else if (key == HIDDEN_DIM) header.hiddenDim = value;
        else if (key == N_LAYERS) header.nLayers = value;
        else if (key == N_HEADS) header.nHeads = value;
        else if (key == N_KV_HEADS) header.nKvHeads = value;
        else if (key == N_EXPERTS) header.nExperts = value;
        else if (key == N_ACTIVE_EXPERTS) header.nActiveExperts = value;
        else if (key == VOCAB_SIZE) header.vocabSize = value;
        else if (key == SEQ_LEN) header.seqLen = value;
        else if (key == HIDDEN_ACT) header.hiddenAct = (LlmHiddenAct)value;
        else if (key == ROPE_THETA) header.ropeTheta = (float)value;
        else if (key == WEIGHT_FLOAT_TYPE) header.weightType = (NnFloatType)value;
        else if (key == ROPE_SCALING_FACTOR) header.ropeScalingFactor = (float)value;
        else if (key == ROPE_SCALING_LOW_FREQ_FACTOR) header.ropeScalingLowFreqFactor = (float)value;
        else if (key == ROPE_SCALING_HIGH_FREQ_FACTORY) header.ropeScalingHighFreqFactory = (float)value;
        else if (key == ROPE_SCALING_ORIG_MAX_SEQ_LEN) header.ropeScalingOrigMaxSeqLen = value;
        else if (key == ROPE_TYPE) header.ropeType = (NnRopeType)value;
        else if (key == HEAD_DIM) header.headDim = value;
        else if (key == NORM_EPSILON) header.normEpsilon = convertNormEpsilon(value);
        else if (key == MOE_HIDDEN_DIM) header.moeHiddenDim = value;
        else throw std::runtime_error("Unsupported header key");
    }

    if (header.weightType == F_UNK)
        throw std::runtime_error("Model does not specify weight type");

    // Keep the model's native context length, then cap the effective one.
    header.origSeqLen = header.seqLen;
    if (maxSeqLen > 0 && header.seqLen > maxSeqLen)
        header.seqLen = maxSeqLen;

    // Derive headDim when the file did not specify it explicitly.
    if (header.headDim == 0)
        header.headDim = header.dim / header.nHeads;
    header.qDim = header.headDim * header.nHeads;
    header.kvDim = header.headDim * header.nKvHeads;
    header.syncType = syncType;
    header.fileSize = (NnSize)seekToEnd(fd);

    // Qwen3 models use the Falcon (neox-style) rotary layout regardless of
    // what the header says.
    if (header.archType == QWEN3 || header.archType == QWEN3_MOE)
        header.ropeType = ROPE_FALCON;
    return header;
}
|
||||||
|
|
||||||
|
// Prints a human-readable summary of a loaded model header to stdout.
// MoE fields, the original sequence length and Llama3.1 RoPE-scaling
// parameters are printed only when they are relevant to the model.
void printLlmHeader(LlmHeader *header) {
    printf("💡 Arch: %s\n", archTypeToString(header->archType));
    printf("💡 HiddenAct: %s\n", hiddenActToString(header->hiddenAct));
    printf("💡 Dim: %u\n", header->dim);
    printf("💡 HeadDim: %u\n", header->headDim);
    printf("💡 QDim: %u\n", header->qDim);
    printf("💡 KvDim: %u\n", header->kvDim);
    printf("💡 HiddenDim: %u\n", header->hiddenDim);
    printf("💡 VocabSize: %u\n", header->vocabSize);
    printf("💡 nLayers: %u\n", header->nLayers);
    printf("💡 nHeads: %u\n", header->nHeads);
    printf("💡 nKvHeads: %u\n", header->nKvHeads);
    // Only show the native context length when it was capped by maxSeqLen.
    if (header->seqLen != header->origSeqLen) {
        printf("💡 OrigSeqLen: %u\n", header->origSeqLen);
    }
    // MoE fields apply only to mixture-of-experts models.
    if (header->nExperts > 0) {
        printf("💡 nExperts: %u\n", header->nExperts);
        printf("💡 nActiveExperts: %u\n", header->nActiveExperts);
        printf("💡 MoeHiddenDim: %u\n", header->moeHiddenDim);
    }
    printf("💡 SeqLen: %u\n", header->seqLen);
    printf("💡 NormEpsilon: %f\n", header->normEpsilon);
    printf("💡 RopeType: %s\n", ropeTypeToString(header->ropeType));
    printf("💡 RopeTheta: %.0f\n", header->ropeTheta);
    // Scaling parameters are only meaningful for the Llama3.1 RoPE variant.
    if (header->ropeType == ROPE_LLAMA3_1) {
        printf("💡 RopeScaling: f=%.1f, l=%.1f, h=%.1f, o=%d\n",
            header->ropeScalingFactor,
            header->ropeScalingLowFreqFactor,
            header->ropeScalingHighFreqFactory,
            header->ropeScalingOrigMaxSeqLen);
    }
}
|
||||||
|
|
||||||
|
// Builds the distributed execution graph for a model described by `h`,
// sliced across `nNodes` nodes with `nBatches` batch rows:
//   - computes row/column matmul slices for every weight matrix,
//   - declares the shared pipes (position, token, x, logits, zq sync),
//   - and for every node emits the segment sequence:
//       start (embedding on root) -> per layer [attention, feed-forward]
//       -> end (final norm + logits).
// The returned LlmNet owns a heap array of per-node configs; release it
// with releaseLlmNet. Op order within segments is significant — it is the
// execution order.
LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches) {
    // Dense models have nExperts == 0; clamp to 1 so MoE-shaped buffers
    // can still be allocated with a valid (degenerate) size.
    NnUint nExpertsOr1 = std::max(h->nExperts, 1u);
    NnUint nActiveExpertsOr1 = std::max(h->nActiveExperts, 1u);
    NnUint ffDim = h->hiddenDim;

    // MoE experts have their own (smaller) feed-forward width.
    if (h->archType == QWEN3_MOE)
        ffDim = h->moeHiddenDim;

    LlmNet n;
    n.tokenEmbeddingSize = size2D(F_32, h->vocabSize, h->dim);
    n.rmsNormSize = size1D(F_32, h->dim);
    n.qkRmsNormSize = size1D(F_32, h->headDim);
    n.moeGateSize = size2D(F_32, h->dim, h->nExperts);

    NnKvCacheSlice kvCacheSlice = sliceKvCache(h->kvDim, h->seqLen, nNodes);
    NnMultiHeadAttSlice multiHeadAttSlice = sliceMultiHeadAtt(h->nHeads, h->seqLen, nNodes, nBatches);

    // Attention projections: q/k/v are split by output rows, wo by input
    // columns, so each node works on its own head slice.
    n.qSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->qDim);
    n.kSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->kvDim);
    n.vSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->kvDim);
    n.woSlice = sliceColMatmul(h->weightType, nNodes, h->qDim, h->dim);

    // Feed-forward: w1/w3 row-split, w2 column-split (same pattern).
    n.w1Slice = sliceRowMatmul(h->weightType, nNodes, h->dim, ffDim);
    n.w2Slice = sliceColMatmul(h->weightType, nNodes, ffDim, h->dim);
    n.w3Slice = sliceRowMatmul(h->weightType, nNodes, h->dim, ffDim);
    n.wclsSlice = sliceRowMatmul(h->weightType, nNodes, h->dim, h->vocabSize);

    // Qwen3 applies per-head RMS norm to q and k; the inv-rms scratch buffer
    // must then hold one value per local head instead of one per batch row.
    NnUint nQNormColumns = 1;
    NnUint nKNormColumns = 1;
    NnUint nInvBufferColumns = 1;
    if (h->archType == QWEN3 || h->archType == QWEN3_MOE) {
        ASSERT_EQ(n.qSlice.d0 % h->headDim, 0);
        ASSERT_EQ(n.kSlice.d0 % h->headDim, 0);
        nQNormColumns = n.qSlice.d0 / h->headDim;
        nKNormColumns = n.kSlice.d0 / h->headDim;
        nInvBufferColumns = std::max(nQNormColumns, nKNormColumns);
    }

    NnNetConfigBuilder netBuilder(nNodes, nBatches);

    // Shared pipes: per-batch position, input token, hidden state x,
    // final logits, and zq — the quantized all-to-all sync buffer.
    n.positionPipeIndex = netBuilder.addPipe("POS", size2D(F_32, nBatches, 1));
    n.tokenPipeIndex = netBuilder.addPipe("TOK", size2D(F_32, nBatches, 1));
    n.xPipeIndex = netBuilder.addPipe("X", size2D(F_32, nBatches, h->dim));
    n.logitsPipeIndex = netBuilder.addPipe("LG", size2D(F_32, nBatches, h->vocabSize));
    const NnUint zqPipeIndex = netBuilder.addPipe("ZQ", size2D(h->syncType, nBatches, h->dim * nNodes));

    // Positions are broadcast to all nodes before each forward pass.
    netBuilder.addPreSync(n.positionPipeIndex);

    n.header = h;
    n.netConfig = netBuilder.build();
    n.nodeConfigs = new NnNodeConfig[nNodes];

    for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) {
        NnRopeSlice ropeSlice = sliceRope(h->ropeType, h->qDim, h->kvDim, h->nKvHeads, nNodes, h->seqLen, h->headDim, h->ropeTheta, nodeIndex);
        NnNodeConfigBuilder nodeBuilder(nodeIndex);

        // Hidden-state buffers; q_* variants exist only when syncType != F_32
        // (otherwise the float buffer doubles as its own quantized form).
        const NnUint xBufferIndex = nodeBuilder.addBuffer("x", size2D(F_32, nBatches, h->dim));
        const NnUint yBufferIndex = nodeBuilder.addBuffer("y", size2D(F_32, nBatches, h->dim));
        const NnUint yqBufferIndex = h->syncType == F_32
            ? yBufferIndex
            : nodeBuilder.addBuffer("q_y", size2D(h->syncType, nBatches, h->dim));

        const NnUint zBufferIndex = nodeBuilder.addBuffer("z", size2D(F_32, nBatches, h->qDim));
        const NnUint zqSliceBufferIndex = nodeBuilder.addBuffer("q_z_slice", size2D(h->syncType, nBatches, h->qDim / nNodes));

        // Attention scratch: local q slice plus pre-cache k/v staging.
        const NnUint qBufferIndex = nodeBuilder.addBuffer("q", size2D(F_32, nBatches, n.qSlice.d0));
        const NnUint kTempBufferIndex = nodeBuilder.addBuffer("k_temp", size2D(F_32, nBatches, n.kSlice.d0));
        const NnUint vTempBufferIndex = nodeBuilder.addBuffer("v_temp", size2D(F_32, nBatches, n.vSlice.d0));

        const NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, nBatches, nInvBufferColumns));

        const NnUint ropeCacheBufferIndex = nodeBuilder.addBuffer("rope_cache", ropeSlice.cacheSize);
        const NnUint attBufferIndex = nodeBuilder.addBuffer("att", multiHeadAttSlice.attSize);
        const NnUint logitsSliceBufferIndex = nodeBuilder.addBuffer("lg", size2D(F_32, nBatches, h->vocabSize / nNodes));

        // Dense (non-MoE) feed-forward buffers.
        // not moe
        const NnUint dBufferIndex = nodeBuilder.addBuffer("d", size2D(F_32, nBatches, n.w1Slice.d0));
        const NnUint dqBufferIndex = h->syncType == F_32
            ? dBufferIndex
            : nodeBuilder.addBuffer("q_d", size2D(h->syncType, nBatches, n.w1Slice.d0));
        const NnUint lBufferIndex = nodeBuilder.addBuffer("l", size2D(F_32, nBatches, n.w3Slice.d0));

        // MoE feed-forward buffers (degenerate 1-expert sizes on dense models).
        // moe
        const NnUint moeGtBufferIndex = nodeBuilder.addBuffer("gt", size2D(F_32, nBatches, nExpertsOr1));
        const NnUint moeExpertIndexesBufferIndex = nodeBuilder.addBuffer("act_exp_ix", size2D(F_32, nBatches, nActiveExpertsOr1));
        const NnUint moeYBufferIndex = nodeBuilder.addBuffer("moe_y", size3D(F_32, nActiveExpertsOr1, nBatches, h->dim));
        const NnUint moeYqBufferIndex = h->syncType == F_32
            ? moeYBufferIndex
            : nodeBuilder.addBuffer("q_moe_y", size3D(h->syncType, nActiveExpertsOr1, nBatches, h->dim));
        const NnUint moeDBufferIndex = nodeBuilder.addBuffer("moe_d", size3D(F_32, nActiveExpertsOr1, nBatches, n.w1Slice.d0));
        const NnUint moeDQBufferIndex = h->syncType == F_32
            ? moeDBufferIndex
            : nodeBuilder.addBuffer("q_moe_d", size3D(h->syncType, nActiveExpertsOr1, nBatches, n.w1Slice.d0));
        const NnUint moeLBufferIndex = nodeBuilder.addBuffer("moe_l", size3D(F_32, nActiveExpertsOr1, nBatches, n.w3Slice.d0));
        const NnUint moeSBufferIndex = nodeBuilder.addBuffer("moe_s", size3D(F_32, nActiveExpertsOr1, nBatches, 1));

        // Start segment: only the root node embeds tokens; everyone then
        // receives x from the root.
        NnSegmentConfigBuilder start;
        if (nodeIndex == 0) {
            start.addOp(
                OP_EMBEDDING, "embedding", 0,
                pointerBatchConfig(SRC_PIPE, n.tokenPipeIndex),
                pointerBatchConfig(SRC_PIPE, n.xPipeIndex),
                n.tokenEmbeddingSize,
                NnEmbeddingOpConfig{});
        }
        start.addSync(n.xPipeIndex, SYNC_WITH_ROOT);
        nodeBuilder.addSegment(start.build());

        for (NnUint layerIndex = 0; layerIndex < h->nLayers; layerIndex++) {
            // Per-layer KV cache (one pair of buffers per layer).
            const NnUint kBufferIndex = nodeBuilder.addBuffer("k", kvCacheSlice.keySize);
            const NnUint vBufferIndex = nodeBuilder.addBuffer("v", kvCacheSlice.valueSize);

            NnSegmentConfigBuilder att;
            NnSegmentConfigBuilder ff;

            // Attention segment.
            // att
            // Layer 0 reads x straight from the pipe; later layers first fold
            // the previous layer's synced output (zq) into the residual x.
            if (layerIndex == 0) {
                att.addOp(
                    OP_CAST, "block_cast_x", layerIndex,
                    pointerBatchConfig(SRC_PIPE, n.xPipeIndex),
                    pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                    size0(),
                    NnCastOpCodeConfig{});
            } else {
                att.addOp(
                    OP_MERGE_ADD, "block_merge_add", layerIndex,
                    pointerBatchConfig(SRC_PIPE, zqPipeIndex),
                    pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                    size0(),
                    NnMergeAddOpCodeConfig{});
            }

            // Pre-attention RMS norm: x -> y.
            att.addOp(
                OP_INV_RMS, "block_norm_pre_0", layerIndex,
                pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
                size0(),
                NnInvRmsOpConfig{h->normEpsilon, 1});
            att.addOp(
                OP_RMS_NORM, "block_norm_0", layerIndex,
                pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                n.rmsNormSize,
                NnRmsNormOpConfig{invRmsBufferIndex, 1});
            // Quantize y for the matmuls when the sync type isn't F32.
            if (yBufferIndex != yqBufferIndex) {
                att.addOp(
                    OP_CAST, "block_cast_y", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                    size0(),
                    NnCastOpCodeConfig{});
            }
            // q/k/v projections on this node's weight slices.
            att.addOp(
                OP_MATMUL, "block_matmul_q", layerIndex,
                pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                pointerBatchConfig(SRC_BUFFER, qBufferIndex),
                size2D(h->weightType, n.qSlice.n, n.qSlice.d0),
                NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
            att.addOp(
                OP_MATMUL, "block_matmul_k", layerIndex,
                pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                size2D(h->weightType, n.kSlice.n, n.kSlice.d0),
                NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
            att.addOp(
                OP_MATMUL, "block_matmul_v", layerIndex,
                pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                pointerBatchConfig(SRC_BUFFER, vTempBufferIndex),
                size2D(h->weightType, n.vSlice.n, n.vSlice.d0),
                NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});

            // Qwen3: per-head RMS norm on q and k before RoPE.
            if (h->archType == QWEN3 || h->archType == QWEN3_MOE) {
                att.addOp(OP_INV_RMS, "block_norm_pre_q", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, qBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
                    size0(),
                    NnInvRmsOpConfig{h->normEpsilon, nQNormColumns});
                att.addOp(
                    OP_RMS_NORM, "block_norm_q", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, qBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, qBufferIndex),
                    size2D(F_32, 1, n.header->headDim),
                    NnRmsNormOpConfig{invRmsBufferIndex, nQNormColumns});

                att.addOp(OP_INV_RMS, "block_norm_pre_k", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
                    size0(),
                    NnInvRmsOpConfig{h->normEpsilon, nKNormColumns});
                att.addOp(
                    OP_RMS_NORM, "block_norm_k", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                    size2D(F_32, 1, n.header->headDim),
                    NnRmsNormOpConfig{invRmsBufferIndex, nKNormColumns});
            }

            // Rotary embeddings (in place) on q and on the staged k.
            att.addOp(
                OP_ROPE, "block_rope_q", layerIndex,
                pointerBatchConfig(SRC_BUFFER, qBufferIndex),
                pointerBatchConfig(SRC_BUFFER, qBufferIndex),
                size0(),
                NnRopeOpConfig{n.header->ropeType, 1, n.positionPipeIndex, ropeCacheBufferIndex,
                    h->ropeScalingFactor, h->ropeScalingLowFreqFactor, h->ropeScalingHighFreqFactory, h->ropeScalingOrigMaxSeqLen,
                    ropeSlice});
            att.addOp(
                OP_ROPE, "block_rope_k", layerIndex,
                pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                size0(),
                NnRopeOpConfig{n.header->ropeType, 0, n.positionPipeIndex, ropeCacheBufferIndex,
                    h->ropeScalingFactor, h->ropeScalingLowFreqFactor, h->ropeScalingHighFreqFactory, h->ropeScalingOrigMaxSeqLen,
                    ropeSlice});
            // Shift staged k/v into the KV cache at the current positions.
            att.addOp(
                OP_SHIFT, "block_shift_k", layerIndex,
                pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
                pointerRawConfig(SRC_BUFFER, kBufferIndex),
                size0(),
                NnShiftOpCodeConfig{n.positionPipeIndex});
            att.addOp(
                OP_SHIFT, "block_shift_v", layerIndex,
                pointerBatchConfig(SRC_BUFFER, vTempBufferIndex),
                pointerRawConfig(SRC_BUFFER, vBufferIndex),
                size0(),
                NnShiftOpCodeConfig{n.positionPipeIndex});
            att.addOp(
                OP_MULTIHEAD_ATT, "block_multihead_att", layerIndex,
                pointerBatchedSliceConfig(SRC_BUFFER, zBufferIndex),
                pointerBatchedSliceConfig(SRC_BUFFER, zBufferIndex),
                size0(),
                NnMultiHeadAttOpConfig{
                    multiHeadAttSlice.nHeads, multiHeadAttSlice.nHeads0,
                    h->nKvHeads, h->headDim, h->seqLen, n.qSlice.d0, kvCacheSlice.kvDim0,
                    n.positionPipeIndex, qBufferIndex, kBufferIndex, vBufferIndex, attBufferIndex});
            // Project attention output, then publish the slice to zq and
            // sync it across nodes.
            att.addOp(
                OP_CAST, "block_cast_y2", layerIndex,
                pointerBatchedSliceConfig(SRC_BUFFER, zBufferIndex),
                pointerBatchConfig(SRC_BUFFER, zqSliceBufferIndex),
                size0(),
                NnCastOpCodeConfig{});
            att.addOp(
                OP_MATMUL, "block_matmul_wo", layerIndex,
                pointerBatchConfig(SRC_BUFFER, zqSliceBufferIndex),
                pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                size2D(h->weightType, n.woSlice.n0, n.woSlice.d),
                NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
            att.addOp(
                OP_CAST, "block_cast_d", layerIndex,
                pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                pointerBatchedSliceConfig(SRC_PIPE, zqPipeIndex),
                size0(),
                NnCastOpCodeConfig{});
            att.addSync(zqPipeIndex, SYNC_NODE_SLICES);

            // Feed-forward segment: fold synced attention output into the
            // residual, norm, then dense or MoE MLP.
            // ff
            ff.addOp(
                OP_MERGE_ADD, "block_merge_add2", layerIndex,
                pointerBatchConfig(SRC_PIPE, zqPipeIndex),
                pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                size0(),
                NnMergeAddOpCodeConfig{});
            ff.addOp(
                OP_INV_RMS, "block_norm_pre_1", layerIndex,
                pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
                size0(),
                NnInvRmsOpConfig{h->normEpsilon, 1});
            ff.addOp(
                OP_RMS_NORM, "block_norm_1", layerIndex,
                pointerBatchConfig(SRC_BUFFER, xBufferIndex),
                pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                n.rmsNormSize,
                NnRmsNormOpConfig{invRmsBufferIndex, 1});

            if (h->archType == QWEN3_MOE) {
                // MoE path: route each batch row to its top active experts,
                // run SwiGLU per expert, scale by gate weights and merge.
                ff.addOp(
                    OP_REPEAT_Z, "block_moe_y_repeat", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeYqBufferIndex),
                    size0(),
                    NnRepeatZOpCodeConfig{});
                ff.addOp(
                    OP_MATMUL, "block_moe_gate", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
                    n.moeGateSize,
                    NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_SOFTMAX, "block_moe_softmax", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
                    size0(),
                    NnSoftmaxOpCodeConfig{});
                ff.addOp(
                    OP_MOE_GATE, "block_moe_gate2", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeGtBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeSBufferIndex),
                    size0(),
                    NnMoeGateOpCodeConfig{h->nActiveExperts, 1u, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_MATMUL, "block_matmul_w1", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeYqBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
                    size3D(h->weightType, h->nExperts, n.w1Slice.n, n.w1Slice.d0),
                    NnMatmulOpConfig{h->nExperts, h->nActiveExperts, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_MATMUL, "block_matmul_w3", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeYqBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeLBufferIndex),
                    size3D(h->weightType, h->nExperts, n.w3Slice.n, n.w3Slice.d0),
                    NnMatmulOpConfig{h->nExperts, h->nActiveExperts, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_SILU, "block_act", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
                    size0(),
                    NnSiluOpCodeConfig{});
                ff.addOp(
                    OP_MUL, "block_mul", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
                    size0(),
                    NnMulOpCodeConfig{moeLBufferIndex});
                if (moeDBufferIndex != moeDQBufferIndex) {
                    ff.addOp(
                        OP_CAST, "block_cast_d2", layerIndex,
                        pointerBatchConfig(SRC_BUFFER, moeDBufferIndex),
                        pointerBatchConfig(SRC_BUFFER, moeDQBufferIndex),
                        size0(),
                        NnCastOpCodeConfig{});
                }
                ff.addOp(
                    OP_MATMUL, "block_matmul_w2", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeDQBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
                    size3D(h->weightType, h->nExperts, n.w2Slice.n0, n.w2Slice.d),
                    NnMatmulOpConfig{h->nExperts, h->nActiveExperts, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_SCALE, "block_moe_scale", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
                    size0(),
                    NnScaleOpCodeConfig{moeSBufferIndex});
                ff.addOp(
                    OP_MERGE_SUM, "block_moe_merge_sum", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, moeYBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                    size0(),
                    NnMergeSumOpCodeConfig{});
            } else {
                // Dense path: standard SwiGLU — silu(w1(y)) * w3(y) -> w2.
                if (yBufferIndex != yqBufferIndex) {
                    ff.addOp(
                        OP_CAST, "block_cast_y3", layerIndex,
                        pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                        pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                        size0(),
                        NnCastOpCodeConfig{});
                }
                ff.addOp(
                    OP_MATMUL, "block_matmul_w1", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, dBufferIndex),
                    size2D(h->weightType, n.w1Slice.n, n.w1Slice.d0),
                    NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_MATMUL, "block_matmul_w3", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, lBufferIndex),
                    size2D(h->weightType, n.w3Slice.n, n.w3Slice.d0),
                    NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
                ff.addOp(
                    OP_SILU, "block_act", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, dBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, dBufferIndex),
                    size0(),
                    NnSiluOpCodeConfig{});
                ff.addOp(
                    OP_MUL, "block_mul", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, dBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, dBufferIndex),
                    size0(),
                    NnMulOpCodeConfig{lBufferIndex});
                if (dBufferIndex != dqBufferIndex) {
                    ff.addOp(
                        OP_CAST, "block_cast_d2", layerIndex,
                        pointerBatchConfig(SRC_BUFFER, dBufferIndex),
                        pointerBatchConfig(SRC_BUFFER, dqBufferIndex),
                        size0(),
                        NnCastOpCodeConfig{});
                }
                ff.addOp(
                    OP_MATMUL, "block_matmul_w2", layerIndex,
                    pointerBatchConfig(SRC_BUFFER, dqBufferIndex),
                    pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                    size2D(h->weightType, n.w2Slice.n0, n.w2Slice.d),
                    NnMatmulOpConfig{0, 0, moeExpertIndexesBufferIndex});
            }
            // Publish this node's FF output slice and sync across nodes.
            ff.addOp(
                OP_CAST, "block_cast_d3", layerIndex,
                pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                pointerBatchedSliceConfig(SRC_PIPE, zqPipeIndex),
                size0(),
                NnCastOpCodeConfig{});
            ff.addSync(zqPipeIndex, SYNC_NODE_SLICES);

            nodeBuilder.addSegment(att.build());
            nodeBuilder.addSegment(ff.build());
        }

        // End segment: fold in the last layer's output, apply the final norm
        // and compute this node's logits slice, gathered to the root.
        NnSegmentConfigBuilder end;
        end.addOp(
            OP_MERGE_ADD, "final_merge_add", 0,
            pointerBatchConfig(SRC_PIPE, zqPipeIndex),
            pointerBatchConfig(SRC_BUFFER, xBufferIndex),
            size0(),
            NnMergeAddOpCodeConfig{});
        end.addOp(
            OP_INV_RMS, "final_norm_pre", 0,
            pointerBatchConfig(SRC_BUFFER, xBufferIndex),
            pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
            size0(),
            NnInvRmsOpConfig{h->normEpsilon, 1});
        end.addOp(
            OP_RMS_NORM, "final_norm", 0,
            pointerBatchConfig(SRC_BUFFER, xBufferIndex),
            pointerBatchConfig(SRC_BUFFER, yBufferIndex),
            n.rmsNormSize,
            NnRmsNormOpConfig{invRmsBufferIndex, 1});
        if (yBufferIndex != yqBufferIndex) {
            end.addOp(
                OP_CAST, "final_cast_y", 0,
                pointerBatchConfig(SRC_BUFFER, yBufferIndex),
                pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
                size0(),
                NnCastOpCodeConfig{});
        }
        end.addOp(
            OP_MATMUL, "final_matmul_logits", 0,
            pointerBatchConfig(SRC_BUFFER, yqBufferIndex),
            pointerBatchConfig(SRC_BUFFER, logitsSliceBufferIndex),
            size2D(h->weightType, n.wclsSlice.n, n.wclsSlice.d0),
            NnMatmulOpConfig{});
        end.addOp(
            OP_CAST, "final_cast_logits", 0,
            pointerBatchConfig(SRC_BUFFER, logitsSliceBufferIndex),
            pointerBatchedSliceConfig(SRC_PIPE, n.logitsPipeIndex),
            size0(),
            NnCastOpCodeConfig{});
        end.addSync(n.logitsPipeIndex, SYNC_NODE_SLICES_EXCEPT_ROOT);

        nodeBuilder.addSegment(end.build());
        n.nodeConfigs[nodeIndex] = nodeBuilder.build();
    }
    return n;
}
|
||||||
|
|
||||||
|
// Releases everything buildLlmNet allocated: each per-node config, the
// shared net config, and the node-config array itself.
void releaseLlmNet(LlmNet *net) {
    const NnUint nNodes = net->netConfig.nNodes;
    for (NnUint i = 0u; i < nNodes; i++)
        releaseNodeConfig(&net->nodeConfigs[i]);
    releaseNetConfig(&net->netConfig);
    delete[] net->nodeConfigs;
}
|
||||||
|
|
||||||
|
// Memory-maps the weight file at `path` and feeds every tensor, in on-disk
// order, to `loader`. The read order below *is* the file format: embedding,
// then per-layer attention/FFN matrices and norms, then the final norm and
// the classifier matrix. Throws if the cursor does not land exactly at EOF.
void loadLlmNetWeight(const char *path, LlmNet *net, NnRootWeightLoader *loader) {
    MmapFile file;
    openMmapFile(&file, path, net->header->fileSize);
#if DEBUG_USE_MMAP_FOR_WEIGHTS
    // Mmap-backed weights: the mapping must outlive this call, so it is not
    // unmapped here; only supported for a single node.
    assert(net->netConfig.nNodes == 1u);
#else
    // RAII guard: unmap the file automatically when this function returns.
    std::unique_ptr<MmapFile, void(*)(MmapFile *)> fdPtr(&file, closeMmapFile);
    printf("💿 Loading weights...\n");
#endif

    Timer timer;
    NnByte *data = (NnByte *)file.data;
    // Weights begin immediately after the header block.
    NnByte *b = &data[net->header->headerSize];
    b += loader->loadRoot("embedding", 0, net->tokenEmbeddingSize.nBytes, b);

    for (NnUint layerIndex = 0u; layerIndex < net->header->nLayers; layerIndex++) {
        // Attention projections: q/k/v are row-sliced across nodes, wo is column-sliced.
        b += loader->loadRowMatmulSlices("block_matmul_q", layerIndex, 0u, &net->qSlice, b);
        b += loader->loadRowMatmulSlices("block_matmul_k", layerIndex, 0u, &net->kSlice, b);
        b += loader->loadRowMatmulSlices("block_matmul_v", layerIndex, 0u, &net->vSlice, b);
        b += loader->loadColMatmulSlices("block_matmul_wo", layerIndex, 0u, &net->woSlice, b);

        if (net->header->nExperts > 0u) {
            // MoE layer: one gate matrix, then w1/w2/w3 per expert.
            b += loader->loadAll("block_moe_gate", layerIndex, net->moeGateSize.nBytes, b);
            for (NnUint expertIndex = 0u; expertIndex < net->header->nExperts; expertIndex++) {
                b += loader->loadRowMatmulSlices("block_matmul_w1", layerIndex, expertIndex, &net->w1Slice, b);
                b += loader->loadColMatmulSlices("block_matmul_w2", layerIndex, expertIndex, &net->w2Slice, b);
                b += loader->loadRowMatmulSlices("block_matmul_w3", layerIndex, expertIndex, &net->w3Slice, b);
            }
        } else {
            // Dense FFN: a single w1/w2/w3 set.
            b += loader->loadRowMatmulSlices("block_matmul_w1", layerIndex, 0u, &net->w1Slice, b);
            b += loader->loadColMatmulSlices("block_matmul_w2", layerIndex, 0u, &net->w2Slice, b);
            b += loader->loadRowMatmulSlices("block_matmul_w3", layerIndex, 0u, &net->w3Slice, b);
        }

        if (net->header->archType == QWEN3 || net->header->archType == QWEN3_MOE) {
            // Qwen3 variants additionally store q/k RMS-norm weights per layer.
            b += loader->loadAll("block_norm_q", layerIndex, net->qkRmsNormSize.nBytes, b);
            b += loader->loadAll("block_norm_k", layerIndex, net->qkRmsNormSize.nBytes, b);
        }

        b += loader->loadAll("block_norm_0", layerIndex, net->rmsNormSize.nBytes, b);
        b += loader->loadAll("block_norm_1", layerIndex, net->rmsNormSize.nBytes, b);

        // Progress output once loading has run longer than 10s. The timer is
        // never reset, so every subsequent layer is reported too.
        if (timer.elapsedMiliseconds() > 10000)
            printf("💿 Loaded %u/%u\n", layerIndex + 1, net->header->nLayers);
    }

    b += loader->loadAll("final_norm", 0u, net->rmsNormSize.nBytes, b);
    b += loader->loadRowMatmulSlices("final_matmul_logits", 0u, 0u, &net->wclsSlice, b);

    // Sanity check: the read cursor must end exactly at fileSize; a non-zero
    // delta means the file and the net configuration disagree.
    long long missingBytes = (long long)(b - data) - net->header->fileSize;
    if (missingBytes != 0u)
        throw std::runtime_error("Missing bytes in weight file: " + std::to_string(missingBytes));
    printf("💿 Weights loaded\n");

    loader->finish();
}
|
||||||
104
src/llm.hpp
Normal file
104
src/llm.hpp
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
#ifndef LLM_HPP
|
||||||
|
#define LLM_HPP
|
||||||
|
|
||||||
|
#include "nn/nn-core.hpp"
|
||||||
|
#include "nn/nn-executor.hpp"
|
||||||
|
#include "nn/nn-network.hpp"
|
||||||
|
|
||||||
|
// Integer keys identifying fields of the model-file header. Values are part
// of the on-disk format — never renumber existing entries; only append.
enum LlmHeaderKey {
    VERSION = 0,
    ARCH_TYPE = 1,
    DIM = 2,
    HIDDEN_DIM = 3,
    N_LAYERS = 4,
    N_HEADS = 5,
    N_KV_HEADS = 6,
    N_EXPERTS = 7,
    N_ACTIVE_EXPERTS = 8,
    VOCAB_SIZE = 9,
    SEQ_LEN = 10,
    HIDDEN_ACT = 11,
    ROPE_THETA = 12,
    WEIGHT_FLOAT_TYPE = 13,
    ROPE_SCALING_FACTOR = 14,
    ROPE_SCALING_LOW_FREQ_FACTOR = 15,
    ROPE_SCALING_HIGH_FREQ_FACTORY = 16, // NOTE(review): "FACTORY" looks like a typo for "FACTOR", but it is part of the format naming — keep as-is
    ROPE_SCALING_ORIG_MAX_SEQ_LEN = 17,
    ROPE_TYPE = 18,
    HEAD_DIM = 19,
    NORM_EPSILON = 20,
    MOE_HIDDEN_DIM = 21,
};
|
||||||
|
|
||||||
|
// Activation function used in the feed-forward block (HIDDEN_ACT header key).
enum LlmHiddenAct {
    HIDDEN_ACT_GELU,
    HIDDEN_ACT_SILU,
};
|
||||||
|
|
||||||
|
// Model architecture identifier (ARCH_TYPE header key). The 0xABCDxx magic
// values are part of the on-disk format.
enum LlmArchType {
    LLAMA = 0xABCD00,
    QWEN3 = 0xABCD01,
    QWEN3_MOE = 0xABCD02,
};
|
||||||
|
|
||||||
|
// Parsed model-file header (see LlmHeaderKey for the on-disk field keys).
// Produced by loadLlmHeader and consumed by buildLlmNet/loadLlmNetWeight.
typedef struct {
    NnSize headerSize;   // bytes occupied by the header; weights start at this offset
    NnSize fileSize;     // total weight-file size in bytes
    int version;
    LlmArchType archType;
    NnUint dim;
    NnUint nLayers;
    NnUint nHeads;
    NnUint headDim;
    NnUint nKvHeads;
    NnUint nExperts;        // 0 for dense (non-MoE) models
    NnUint nActiveExperts;
    NnUint origSeqLen; // Original model context length
    NnUint seqLen; // Limited context length by the `--max-seq-len` argument
    NnUint hiddenDim;
    NnUint moeHiddenDim;
    LlmHiddenAct hiddenAct;
    NnUint qDim;
    NnUint kvDim;
    NnUint vocabSize;
    float ropeTheta;
    NnRopeType ropeType;
    // RoPE frequency-scaling parameters (see scaleFrequencyLlama3).
    float ropeScalingFactor;
    float ropeScalingLowFreqFactor;
    float ropeScalingHighFreqFactory;
    NnUint ropeScalingOrigMaxSeqLen;
    float normEpsilon;

    NnFloatType weightType; // float type of weights stored in the file
    NnFloatType syncType;   // float type used for inter-node transfers
} LlmHeader;
|
||||||
|
|
||||||
|
// Fully built network plan: net/node configs plus the matmul slice
// descriptors and pipe/buffer indices that loadLlmNetWeight and the
// executor need. Created by buildLlmNet, released by releaseLlmNet.
typedef struct {
    LlmHeader *header;          // not owned; must outlive this struct
    NnNetConfig netConfig;
    NnNodeConfig *nodeConfigs;  // array of netConfig.nNodes entries, owned
    // Per-matrix slicing descriptors (row- or column-split across nodes).
    NnRowMatmulSlice qSlice;
    NnRowMatmulSlice kSlice;
    NnRowMatmulSlice vSlice;
    NnColMatmulSlice woSlice;
    NnRowMatmulSlice w1Slice;
    NnColMatmulSlice w2Slice;
    NnRowMatmulSlice w3Slice;
    NnRowMatmulSlice wclsSlice;
    // Indices of the shared pipes inside netConfig.
    NnUint positionPipeIndex;
    NnUint tokenPipeIndex;
    NnUint xPipeIndex;
    NnUint logitsPipeIndex;
    // Sizes of the non-sliced weight tensors.
    NnSize3D tokenEmbeddingSize;
    NnSize3D rmsNormSize;
    NnSize3D qkRmsNormSize;
    NnSize3D moeGateSize;
} LlmNet;
|
||||||
|
|
||||||
|
LlmHeader loadLlmHeader(const char* path, const unsigned int maxSeqLen, NnFloatType syncType);
|
||||||
|
void printLlmHeader(LlmHeader *header);
|
||||||
|
LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches);
|
||||||
|
void releaseLlmNet(LlmNet *net);
|
||||||
|
void loadLlmNetWeight(const char* path, LlmNet *net, NnRootWeightLoader *loader);
|
||||||
|
|
||||||
|
#endif
|
||||||
83
src/mmap.hpp
Normal file
83
src/mmap.hpp
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
#ifndef MMAP_HPP
|
||||||
|
#define MMAP_HPP
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <stdexcept>
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <windows.h>
|
||||||
|
#else
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Read-only memory-mapped file handle. Fill with openMmapFile, release with
// closeMmapFile. Holds the platform-specific handles needed for cleanup.
struct MmapFile {
    void* data;       // start of the read-only mapping
    size_t size;      // mapped length in bytes
#ifdef _WIN32
    HANDLE hFile;     // file handle (CreateFileA)
    HANDLE hMapping;  // mapping object (CreateFileMappingA)
#else
    int fd;           // file descriptor (open)
#endif
};
|
||||||
|
|
||||||
|
// Returns the byte size of an open file by seeking to its end.
// Side effect: the stream position is left at EOF.
// `inline` fixes an ODR violation: this function is defined in a header, so
// without `inline` every translation unit that includes it emits a
// duplicate external definition.
// NOTE(review): on Windows the 64-bit _ftelli64 result is narrowed to
// `long` (32-bit on MSVC), which truncates for files >= 2 GiB — widening the
// return type would fix it but changes the public signature; confirm callers.
inline long seekToEnd(FILE* file) {
#ifdef _WIN32
    _fseeki64(file, 0, SEEK_END);
    return _ftelli64(file);
#else
    fseek(file, 0, SEEK_END);
    return ftell(file);
#endif
}
|
||||||
|
|
||||||
|
// Maps `size` bytes of the file at `path` read-only into memory, filling
// `file` with the mapping and the handles closeMmapFile needs.
// `inline` fixes an ODR violation: a non-inline definition in a header is
// emitted by every including translation unit.
// NOTE(review): error handling is inconsistent across platforms — Windows
// prints and exit()s while POSIX throws std::runtime_error; unifying this
// would change observable behavior on Windows, so it is only flagged here.
inline void openMmapFile(MmapFile *file, const char *path, size_t size) {
    file->size = size;
#ifdef _WIN32
    file->hFile = CreateFileA(path, GENERIC_READ, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
    if (file->hFile == INVALID_HANDLE_VALUE) {
        printf("Cannot open file %s\n", path);
        exit(EXIT_FAILURE);
    }

    file->hMapping = CreateFileMappingA(file->hFile, NULL, PAGE_READONLY, 0, 0, NULL);
    if (file->hMapping == NULL) {
        printf("CreateFileMappingA failed, error: %lu\n", GetLastError());
        CloseHandle(file->hFile);
        exit(EXIT_FAILURE);
    }

    file->data = (void *)MapViewOfFile(file->hMapping, FILE_MAP_READ, 0, 0, 0);
    if (file->data == NULL) {
        printf("MapViewOfFile failed!\n");
        CloseHandle(file->hMapping);
        CloseHandle(file->hFile);
        exit(EXIT_FAILURE);
    }
#else
    file->fd = open(path, O_RDONLY);
    if (file->fd == -1) {
        throw std::runtime_error("Cannot open file");
    }

    file->data = mmap(NULL, size, PROT_READ, MAP_PRIVATE, file->fd, 0);
    if (file->data == MAP_FAILED) {
        // Close the descriptor before reporting so the error path leaks nothing.
        close(file->fd);
        throw std::runtime_error("Mmap failed");
    }
#endif
}
|
||||||
|
|
||||||
|
// Releases a mapping created by openMmapFile: unmaps the view and closes the
// underlying handles/descriptor. The struct itself is not freed.
// `inline` fixes an ODR violation: a non-inline definition in a header is
// emitted by every including translation unit.
inline void closeMmapFile(MmapFile *file) {
#ifdef _WIN32
    UnmapViewOfFile(file->data);
    CloseHandle(file->hMapping);
    CloseHandle(file->hFile);
#else
    munmap(file->data, file->size);
    close(file->fd);
#endif
}
|
||||||
|
|
||||||
|
#endif
|
||||||
1010
src/nn/llamafile/sgemm.cpp
Normal file
1010
src/nn/llamafile/sgemm.cpp
Normal file
File diff suppressed because it is too large
Load Diff
9
src/nn/llamafile/sgemm.hpp
Normal file
9
src/nn/llamafile/sgemm.hpp
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
#ifndef LLAMAFILE_SGEMM_H
|
||||||
|
#define LLAMAFILE_SGEMM_H
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
|
||||||
|
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype);
|
||||||
|
|
||||||
|
#endif
|
||||||
137
src/nn/nn-config-builder.hpp
Normal file
137
src/nn/nn-config-builder.hpp
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
#ifndef NN_CONFIG_BUILDER_H
|
||||||
|
#define NN_CONFIG_BUILDER_H
|
||||||
|
|
||||||
|
#include "nn-core.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
|
||||||
|
// Returns a heap-allocated (new[]) copy of a NUL-terminated string,
// including the terminator. The caller owns the result (delete[]).
// Fix: the length was previously narrowed from size_t into NnUint,
// which silently truncates very long strings; use std::size_t.
static char *cloneString(const char *str) {
    const std::size_t len = std::strlen(str);
    char *copy = new char[len + 1];
    std::memcpy(copy, str, len + 1);
    return copy;
}
|
||||||
|
|
||||||
|
// Accumulates network-wide configuration (pipes, pre-syncs) and materializes
// an NnNetConfig with plain heap arrays in build().
class NnNetConfigBuilder {
public:
    NnUint nNodes;
    NnUint nBatches;
    std::list<NnPipeConfig> pipes;
    std::list<NnPreSyncConfig> preSyncs;

    NnNetConfigBuilder(NnUint nNodes, NnUint nBatches) {
        this->nNodes = nNodes;
        this->nBatches = nBatches;
    }

    // Registers a pipe; the name is copied (the config owns it).
    // Returns the index callers use to reference this pipe.
    NnUint addPipe(const char *name, NnSize3D size) {
        NnUint pipeIndex = pipes.size();
        pipes.push_back({ cloneString(name), size });
        return pipeIndex;
    }

    // Marks a pipe as needing synchronization before execution starts.
    void addPreSync(NnUint pipeIndex) {
        preSyncs.push_back({ pipeIndex });
    }

    // Copies the accumulated lists into heap arrays owned by the returned
    // config; the caller is responsible for releasing them.
    NnNetConfig build() {
        NnNetConfig config;
        config.nNodes = nNodes;
        config.nBatches = nBatches;
        config.nPipes = pipes.size();
        config.pipes = new NnPipeConfig[config.nPipes];
        std::copy(pipes.begin(), pipes.end(), config.pipes);
        config.nPreSyncs = preSyncs.size();
        if (config.nPreSyncs > 0) {
            config.preSyncs = new NnPreSyncConfig[config.nPreSyncs];
            std::copy(preSyncs.begin(), preSyncs.end(), config.preSyncs);
        } else {
            config.preSyncs = nullptr;
        }
        return config;
    }
};
|
||||||
|
|
||||||
|
// Accumulates per-node configuration (buffers, execution segments) and
// materializes an NnNodeConfig with plain heap arrays in build().
class NnNodeConfigBuilder {
public:
    NnUint nodeIndex;
    std::list<NnBufferConfig> buffers;
    std::list<NnSegmentConfig> segments;

    NnNodeConfigBuilder(NnUint nodeIndex) {
        this->nodeIndex = nodeIndex;
    }

    // Registers a node-local buffer; the name is copied (the config owns it).
    // Returns the index callers use to reference this buffer.
    NnUint addBuffer(const char *name, NnSize3D size) {
        NnUint bufferIndex = buffers.size();
        buffers.push_back({ cloneString(name), size });
        return bufferIndex;
    }

    // Appends a segment; ownership of the segment's heap arrays transfers
    // to the built config.
    void addSegment(NnSegmentConfig segment) {
        segments.push_back(segment);
    }

    // Copies the accumulated lists into heap arrays owned by the returned
    // config (released by releaseNodeConfig). At least one segment is required.
    NnNodeConfig build() {
        NnNodeConfig config;
        config.nodeIndex = nodeIndex;
        config.nBuffers = buffers.size();
        if (config.nBuffers > 0) {
            config.buffers = new NnBufferConfig[config.nBuffers];
            std::copy(buffers.begin(), buffers.end(), config.buffers);
        } else {
            config.buffers = nullptr;
        }

        config.nSegments = segments.size();
        assert(config.nSegments > 0);
        config.segments = new NnSegmentConfig[config.nSegments];
        std::copy(segments.begin(), segments.end(), config.segments);
        return config;
    }
};
|
||||||
|
|
||||||
|
class NnSegmentConfigBuilder {
|
||||||
|
private:
|
||||||
|
std::list<NnOpConfig> ops;
|
||||||
|
std::list<NnSyncConfig> syncs;
|
||||||
|
|
||||||
|
public:
|
||||||
|
template <typename T>
|
||||||
|
void addOp(NnOpCode code, const char *name, NnUint index, NnPointerConfig input, NnPointerConfig output, NnSize3D weightSize, T config) {
|
||||||
|
NnUint configSize = sizeof(T);
|
||||||
|
NnByte *configCopy = new NnByte[configSize];
|
||||||
|
std::memcpy(configCopy, &config, configSize);
|
||||||
|
ops.push_back({
|
||||||
|
code,
|
||||||
|
cloneString(name),
|
||||||
|
index,
|
||||||
|
input,
|
||||||
|
output,
|
||||||
|
weightSize,
|
||||||
|
configCopy,
|
||||||
|
configSize
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
void addSync(NnUint pipeIndex, NnSyncType syncType) {
|
||||||
|
syncs.push_back({ pipeIndex, syncType });
|
||||||
|
}
|
||||||
|
|
||||||
|
NnSegmentConfig build() {
|
||||||
|
NnSegmentConfig segment;
|
||||||
|
segment.nOps = ops.size();
|
||||||
|
if (segment.nOps > 0) {
|
||||||
|
segment.ops = new NnOpConfig[segment.nOps];
|
||||||
|
std::copy(ops.begin(), ops.end(), segment.ops);
|
||||||
|
}
|
||||||
|
segment.nSyncs = syncs.size();
|
||||||
|
if (segment.nSyncs > 0) {
|
||||||
|
segment.syncs = new NnSyncConfig[segment.nSyncs];
|
||||||
|
std::copy(syncs.begin(), syncs.end(), segment.syncs);
|
||||||
|
}
|
||||||
|
return segment;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
||||||
381
src/nn/nn-core.cpp
Normal file
381
src/nn/nn-core.cpp
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
#ifdef _WIN32
|
||||||
|
#define _USE_MATH_DEFINES
|
||||||
|
#endif
|
||||||
|
#include "nn-core.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <cmath>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
// utility functions
|
||||||
|
|
||||||
|
// Byte size of `n` elements of the given float type. Quantized types are
// stored in fixed-size blocks, so `n` must be a multiple of the block size.
// Throws std::invalid_argument for unknown types.
NnSize getBytes(NnFloatType floatType, NnSize n) {
    switch (floatType) {
        case F_32:
            return n * sizeof(float);
        case F_16:
            return n * (sizeof(float) / 2);
        case F_Q40:
            assert(n % Q40_BLOCK_SIZE == 0);
            return (n / Q40_BLOCK_SIZE) * sizeof(NnBlockQ40);
        case F_Q80:
            assert(n % Q80_BLOCK_SIZE == 0);
            return (n / Q80_BLOCK_SIZE) * sizeof(NnBlockQ80);
        default:
            throw std::invalid_argument("Unsupported float type: " + std::to_string(floatType));
    }
}
|
||||||
|
|
||||||
|
// Number of elements stored per quantization block (1 for non-quantized
// types). Throws std::invalid_argument for unknown types.
NnSize getBlockSize(NnFloatType floatType) {
    switch (floatType) {
        case F_32:
        case F_16:
            return 1;
        case F_Q40:
            return Q40_BLOCK_SIZE;
        case F_Q80:
            return Q80_BLOCK_SIZE;
        default:
            throw std::invalid_argument("Unsupported float type");
    }
}
|
||||||
|
|
||||||
|
// Maps an (input, weight, output) float-type triple to the matching
// NnOpQuantType enumerator. Throws std::invalid_argument for any
// combination that has no kernel variant.
NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output) {
    // If weight=F_UNK, then returned enum should be <input>_<input>_<output>

    if (input == F_32 && output == F_32) {
        if (weight == F_UNK || weight == F_32)
            return F32_F32_F32;
        if (weight == F_Q40)
            return F32_Q40_F32;
    }
    if (input == F_32 && output == F_Q80) {
        if (weight == F_UNK || weight == F_32)
            return F32_F32_Q80;
        if (weight == F_Q40)
            return F32_Q40_Q80;
    }
    if (input == F_Q80 && output == F_32) {
        if (weight == F_UNK || weight == F_Q80)
            return Q80_Q80_F32;
        if (weight == F_32)
            return Q80_F32_F32;
        if (weight == F_Q40)
            return Q80_Q40_F32;
    }
    if (input == F_Q80 && output == F_Q80) {
        if (weight == F_UNK || weight == F_Q80)
            return Q80_Q80_Q80;
    }
    // Falls through for every unsupported triple.
    throw std::invalid_argument("Unsupported op quant: " +
        std::string(floatTypeToString(input)) + "/" +
        std::string(floatTypeToString(weight)) + "/" +
        std::string(floatTypeToString(output)));
}
|
||||||
|
|
||||||
|
// Human-readable name of an op code (used for logging/errors).
// Throws std::invalid_argument for values outside the enum.
const char *opCodeToString(NnOpCode code) {
    switch (code) {
        case OP_MERGE_ADD: return "MERGE_ADD";
        case OP_MERGE_SUM: return "MERGE_SUM";
        case OP_EMBEDDING: return "EMBEDDING";
        case OP_INV_RMS: return "INV_RMS";
        case OP_RMS_NORM: return "RMS_NORM";
        case OP_MATMUL: return "MATMUL";
        case OP_ROPE: return "ROPE";
        case OP_MULTIHEAD_ATT: return "MULTIHEAD_ATT";
        case OP_GELU: return "GELU";
        case OP_SILU: return "SILU";
        case OP_MUL: return "MUL";
        case OP_SCALE: return "SCALE";
        case OP_CAST: return "CAST";
        case OP_REPEAT_Z: return "REPEAT_Z";
        case OP_SHIFT: return "SHIFT";
        case OP_SOFTMAX: return "SOFTMAX";
        case OP_MOE_GATE: return "MOE_GATE";
    }
    throw std::invalid_argument("Unknown op code: " + std::to_string(code));
}
|
||||||
|
|
||||||
|
// Human-readable name of an op quantization triple (used for logging/errors).
// Throws std::invalid_argument for values outside the enum.
const char *opQuantTypeToString(NnOpQuantType type) {
    switch (type) {
        case F32_F32_F32: return "F32_F32_F32";
        case F32_Q40_F32: return "F32_Q40_F32";
        case F32_Q40_Q80: return "F32_Q40_Q80";
        case F32_F32_Q80: return "F32_F32_Q80";
        case Q80_Q80_Q80: return "Q80_Q80_Q80";
        case Q80_Q80_F32: return "Q80_Q80_F32";
        case Q80_Q40_F32: return "Q80_Q40_F32";
        case Q80_F32_F32: return "Q80_F32_F32";
    }
    throw std::invalid_argument("Unknown op quant type");
}
|
||||||
|
|
||||||
|
// Zero-sized placeholder (e.g. for ops that carry no weight tensor).
NnSize3D size0() {
    return { F_UNK, 0, 0, 0, 0, 0 };
}
|
||||||
|
|
||||||
|
// 1D size: a single row of `x` elements (z = y = 1).
NnSize3D size1D(NnFloatType floatType, NnUint x) {
    return size3D(floatType, 1, 1, x);
}
|
||||||
|
|
||||||
|
// 2D size: `y` rows of `x` elements in a single plane (z = 1).
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x) {
    return size3D(floatType, 1, y, x);
}
|
||||||
|
|
||||||
|
// 3D size: `z` planes of `y` rows × `x` columns, with the element count and
// byte sizes (whole tensor and per-plane) precomputed via getBytes.
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x) {
    NnSize len = z * y * x;
    NnSize lenXY = y * x; // elements in one z-plane
    return { floatType, z, y, x, len, getBytes(floatType, len), getBytes(floatType, lenXY) };
}
|
||||||
|
|
||||||
|
// Pointer into pipe/buffer `index`, addressed per batch (PNTR_BATCH).
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index) {
    return { source, index, PNTR_BATCH };
}
|
||||||
|
|
||||||
|
// Pointer into pipe/buffer `index`, addressed as a per-batch slice
// (PNTR_BATCHED_SLICE); not contiguous (see hasPointerContinuousMemory).
NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index) {
    return { source, index, PNTR_BATCHED_SLICE };
}
|
||||||
|
|
||||||
|
// Pointer to the whole pipe/buffer `index`, no per-batch addressing (PNTR_RAW).
NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index) {
    return { source, index, PNTR_RAW };
}
|
||||||
|
|
||||||
|
// True when the pointer addresses one contiguous memory region: raw and
// whole-batch pointers are contiguous, all other pointer types are not.
bool hasPointerContinuousMemory(NnPointerConfig *config) {
    const auto t = config->type;
    return t == PNTR_RAW || t == PNTR_BATCH;
}
|
||||||
|
|
||||||
|
// Frees the heap arrays that NnNetConfigBuilder::build() allocated:
// every cloned pipe name, the pipes array, and the preSyncs array.
void releaseNetConfig(NnNetConfig *netConfig) {
    for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++) {
        delete[] netConfig->pipes[pipeIndex].name;
    }
    delete[] netConfig->pipes;
    // Fix: the preSyncs array was never released (memory leak). The builder
    // sets it to nullptr when there are no pre-syncs, and delete[] on a
    // null pointer is a no-op, so this is safe in both cases.
    delete[] netConfig->preSyncs;
}
|
||||||
|
|
||||||
|
// Frees the heap arrays that NnNodeConfigBuilder / NnSegmentConfigBuilder
// allocated: per-op names and config blobs, op/sync arrays of every segment,
// buffer names and the buffer array, and finally the segment array itself.
void releaseNodeConfig(NnNodeConfig *nodeConfig) {
    for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
        NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex];
        if (segment->nOps > 0) {
            for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) {
                NnOpConfig *op = &segment->ops[opIndex];
                delete[] op->name;   // cloned via cloneString
                delete[] op->config; // byte copy made in addOp
            }
            delete[] segment->ops;
        }
        if (segment->nSyncs > 0)
            delete[] segment->syncs;
    }
    if (nodeConfig->nBuffers > 0) {
        for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++)
            delete[] nodeConfig->buffers[bufferIndex].name;
        delete[] nodeConfig->buffers;
    }
    delete[] nodeConfig->segments;
}
|
||||||
|
|
||||||
|
// Prints an estimate of the memory one node needs: all pipe sizes, the
// node's buffer sizes, and every op's weight + config bytes.
void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
    unsigned long total = 0; // bytes
    for (NnUint pipeIndex = 0; pipeIndex < netConfig->nPipes; pipeIndex++)
        total += netConfig->pipes[pipeIndex].size.nBytes;
    for (NnUint bufferIndex = 0; bufferIndex < nodeConfig->nBuffers; bufferIndex++)
        total += nodeConfig->buffers[bufferIndex].size.nBytes;
    for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
        NnSegmentConfig *segment = &nodeConfig->segments[segmentIndex];
        for (NnUint opIndex = 0; opIndex < segment->nOps; opIndex++) {
            total += segment->ops[opIndex].weightSize.nBytes;
            total += segment->ops[opIndex].configSize;
        }
    }
    printf("📀 RequiredMemory: %lu MB\n", total / (1024 * 1024));
}
|
||||||
|
|
||||||
|
// Starts measuring immediately at construction.
Timer::Timer() {
    reset();
}
|
||||||
|
|
||||||
|
// Restarts the measurement from now.
void Timer::reset() {
    startTime = std::chrono::high_resolution_clock::now();
}
|
||||||
|
|
||||||
|
// Milliseconds since construction or the last reset().
NnUint Timer::elapsedMiliseconds() {
    auto endTime = std::chrono::high_resolution_clock::now();
    return (NnUint)std::chrono::duration_cast<std::chrono::milliseconds>(endTime - startTime).count();
}
|
||||||
|
|
||||||
|
// Microseconds since construction or the last reset().
NnUint Timer::elapsedMicroseconds() {
    auto endTime = std::chrono::high_resolution_clock::now();
    return (NnUint)std::chrono::duration_cast<std::chrono::microseconds>(endTime - startTime).count();
}
|
||||||
|
|
||||||
|
// slicers
|
||||||
|
|
||||||
|
// Splits the KV cache evenly across nodes: each node stores seqLen rows of
// kvDim/nNodes key and value floats. kvDim must divide evenly.
NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes) {
    NnKvCacheSlice s;
    assert(kvDim % nNodes == 0);
    s.kvDim0 = kvDim / nNodes;
    s.keySize = size2D(F_32, seqLen, s.kvDim0);
    s.valueSize = size2D(F_32, seqLen, s.kvDim0);
    return s;
}
|
||||||
|
|
||||||
|
// Describes a row-wise split of an (n x d) matmul weight: each node gets all
// n input columns but only d/nNodes output rows. d must divide evenly.
NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) {
    NnRowMatmulSlice s;
    assert(d % nNodes == 0);
    s.type = type;
    s.nNodes = nNodes;
    s.d0 = d / nNodes;       // output rows per node
    s.n = n;
    s.size = size2D(type, s.n, d);       // full weight
    s.sliceSize = size2D(type, s.n, s.d0); // one node's share
    return s;
}
|
||||||
|
|
||||||
|
// Describes a column-wise split of an (n x d) matmul weight: each node gets
// n/nNodes input columns but all d output rows. n must divide evenly.
NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d) {
    NnColMatmulSlice s;
    assert(n % nNodes == 0);
    s.type = type;
    s.nNodes = nNodes;
    s.n = n;
    s.n0 = n / nNodes;       // input columns per node
    s.d = d;
    s.size = size2D(type, n, d);        // full weight
    s.sliceSize = size2D(type, s.n0, d); // one node's share
    return s;
}
|
||||||
|
|
||||||
|
// Computes this node's share of the RoPE rotation work and the size of the
// precomputed sin/cos cache (filled later by fullfillRopeCache).
// For LLaMA-style rope the node covers dimensions [kvDimStart, qDimEnd);
// for Falcon-style rope only headDim drives the cache size.
// NOTE(review): in the ROPE_FALCON branch qDimStart/kvDimStart/qShift/
// sliceDim are left unset — presumably unused by the Falcon kernel; confirm.
NnRopeSlice sliceRope(NnRopeType type, NnUint qDim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headDim, float ropeTheta, NnUint nodeIndex) {
    NnRopeSlice s;
    assert(qDim >= kvDim);
    assert(qDim % nNodes == 0);
    assert(kvDim % nNodes == 0);

    s.kvDim = kvDim;
    s.nKvHeads = nKvHeads;
    s.seqLen = seqLen;
    s.headDim = headDim;
    s.ropeTheta = ropeTheta;

    s.qDim0 = qDim / nNodes;   // q dims per node
    s.kvDim0 = kvDim / nNodes; // kv dims per node
    // Rotations operate on (cos, sin) pairs, so per-node dims must be even.
    assert(s.qDim0 % 2 == 0);
    assert(s.kvDim0 % 2 == 0);

    if (type == ROPE_LLAMA || type == ROPE_LLAMA3_1) {
        s.kvDimStart = s.kvDim0 * nodeIndex;
        s.qDimStart = s.qDim0 * nodeIndex;
        s.qDimEnd = s.qDimStart + s.qDim0;
        s.qShift = s.qDimStart - s.kvDimStart;
        s.sliceDim = s.qDimEnd - s.kvDimStart; // dims this node rotates
        assert(s.sliceDim % 2 == 0);
        s.cacheSize = size2D(F_32, seqLen, s.sliceDim);
    } else if (type == ROPE_FALCON) {
        s.cacheSize = size2D(F_32, seqLen, headDim);
    } else {
        throw std::invalid_argument("Unsupported rope type");
    }
    return s;
}
|
||||||
|
|
||||||
|
// Splits attention heads evenly across nodes and sizes the per-node
// attention-score buffer (one seqLen row per local head, per batch).
NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches) {
    NnMultiHeadAttSlice s;
    assert(nHeads % nNodes == 0);
    s.nHeads = nHeads;
    s.nHeads0 = nHeads / nNodes; // heads per node
    s.attSize = size2D(F_32, nBatches, s.nHeads0 * seqLen);
    return s;
}
|
||||||
|
|
||||||
|
// splitters
|
||||||
|
|
||||||
|
// Copies nodeIndex's row-slice out of the full weight blob `weight` into
// `weight0`. Rows are contiguous runs of n elements, stored in quantization
// blocks of `blockSize` elements (`batchBytes` bytes each); the node's rows
// start at row d0 * nodeIndex. Returns the number of bytes copied.
NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) {
    NnSize blockSize = getBlockSize(slice->type);
    NnSize batchBytes = getBytes(slice->type, blockSize); // bytes per block
    assert(slice->n % blockSize == 0);

    NnSize n = slice->n / blockSize; // blocks per row
    NnSize offset = slice->d0 * nodeIndex * n * batchBytes; // start of this node's rows
    NnSize copiedBytes = 0;
    for (NnUint d = 0; d < slice->d0; d++) {
        for (NnUint j = 0; j < n; j++) {
            NnSize o = (d * n + j) * batchBytes;
            std::memcpy(weight0 + o, weight + offset + o, batchBytes);
            copiedBytes += batchBytes;
        }
    }
    return copiedBytes;
}
|
||||||
|
|
||||||
|
// Copies nodeIndex's column-slice out of the full weight blob `weight` into
// `weight0`. For each of the d rows, the node takes a contiguous n0-element
// span starting at column n0 * nodeIndex (measured in whole quantization
// blocks). Returns the number of bytes copied.
NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0) {
    NnSize blockSize = getBlockSize(slice->type);
    NnSize batchBytes = getBytes(slice->type, blockSize); // bytes per block
    assert(slice->n0 % blockSize == 0);

    NnSize n = slice->n / blockSize;                      // blocks per full row
    NnSize rowBytes = n * batchBytes;                     // bytes per full row
    NnSize row0Bytes = (slice->n0 / blockSize) * batchBytes; // bytes of one node's span
    NnSize rowOffsetBytes = nodeIndex * row0Bytes;        // span start within a row
    NnSize copiedBytes = 0;
    for (NnUint d = 0; d < slice->d; d++) {
        std::memcpy(&weight0[row0Bytes * d], &weight[rowBytes * d + rowOffsetBytes], row0Bytes);
        copiedBytes += row0Bytes;
    }
    return copiedBytes;
}
|
||||||
|
|
||||||
|
// helper
|
||||||
|
|
||||||
|
// Llama 3.1 RoPE frequency scaling: long wavelengths are divided by the
// scaling factor, short wavelengths are untouched, and the band in between
// is linearly interpolated between the two.
static inline float scaleFrequencyLlama3(const float freq, const NnRopeOpConfig *config) {
    // https://github.com/meta-llama/llama-models/blob/4269717b2ea587627903bacbb75ccce1427ad914/models/llama3/reference_impl/model.py#L55
    const float waveLen = 2.0f * M_PI / freq;
    const float highFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingHighFreqFactor;
    if (waveLen < highFreqWavelen) {
        // High-frequency band: no scaling.
        return freq;
    }
    const float lowFreqWavelen = config->ropeScalingOrigMaxSeqLen / config->ropeScalingLowFreqFactor;
    if (waveLen > lowFreqWavelen) {
        // Low-frequency band: scale fully.
        return freq / config->ropeScalingFactor;
    }
    // Transition band: blend scaled and unscaled frequency.
    const float smooth = (config->ropeScalingOrigMaxSeqLen / waveLen - config->ropeScalingLowFreqFactor) /
        (config->ropeScalingHighFreqFactor - config->ropeScalingLowFreqFactor);
    return (1 - smooth) * freq / config->ropeScalingFactor + smooth * freq;
}
|
||||||
|
|
||||||
|
// Precomputes the LLaMA-style RoPE rotation table: for every position and
// every even dimension in this node's slice [kvDimStart, qDimEnd), stores
// the (cos, sin) pair of pos * freq, where freq depends on the dimension's
// offset within its head. Llama 3.1 frequency scaling is applied when
// ropeScalingFactor != 1. Cache layout: cache[pos * sliceDim + localDim].
static inline void fullfillRopeLlamaCache(const NnRopeOpConfig *config, float *cache) {
    assert((config->slice.qDimEnd - config->slice.kvDimStart) % 2 == 0);

    const bool applyScaling = config->ropeScalingFactor != 1.0f;
    for (NnUint pos = 0; pos < config->slice.seqLen; pos++) {
        for (NnUint i = config->slice.kvDimStart; i < config->slice.qDimEnd; i += 2) {
            const NnUint h = i % config->slice.headDim; // offset within the head
            float freq = 1.0f / powf(config->slice.ropeTheta, h / (float)config->slice.headDim);
            if (applyScaling)
                freq = scaleFrequencyLlama3(freq, config);
            const float val = pos * freq;
            const float fcr = cosf(val);
            const float fci = sinf(val);
            // Adjacent cos/sin pair, indexed relative to the slice start.
            cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart)] = fcr;
            cache[pos * config->slice.sliceDim + (i - config->slice.kvDimStart) + 1] = fci;
        }
    }
}
|
||||||
|
|
||||||
|
// Precomputes the Falcon-style (NeoX) RoPE rotation table: per position,
// cos values fill the first half of each headDim row and sin values the
// second half. Cache layout: cache[pos * headDim + j] = cos,
// cache[pos * headDim + j + headDim/2] = sin.
static inline void fullfillRopeFalconCache(const NnRopeOpConfig *config, float *cache) {
    const float hs = (float)config->slice.headDim;

    for (NnUint pos = 0; pos < config->slice.seqLen; pos++) {
        for (NnUint j = 0; j < config->slice.headDim / 2; j++) {
            // NOTE(review): `2.0f * (float)(j / hs)` relies on j being
            // promoted to float by the division — verify exponent intent.
            const float freq = 1.0f / powf(config->slice.ropeTheta, 2.0f * (float)(j / hs));
            const float val = pos * freq;
            const float fcr = cosf(val);
            const float fci = sinf(val);
            cache[pos * config->slice.headDim + j] = fcr;
            cache[pos * config->slice.headDim + j + config->slice.headDim / 2] = fci;
        }
    }
}
|
||||||
|
|
||||||
|
void fullfillRopeCache(const NnRopeOpConfig *config, float *cache) {
|
||||||
|
if (config->type == ROPE_LLAMA || config->type == ROPE_LLAMA3_1)
|
||||||
|
fullfillRopeLlamaCache(config, cache);
|
||||||
|
else if (config->type == ROPE_FALCON)
|
||||||
|
fullfillRopeFalconCache(config, cache);
|
||||||
|
else
|
||||||
|
throw std::invalid_argument("Unsupported rope type");
|
||||||
|
}
|
||||||
333
src/nn/nn-core.hpp
Normal file
333
src/nn/nn-core.hpp
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
#ifndef NN_CORE_H
|
||||||
|
#define NN_CORE_H
|
||||||
|
|
||||||
|
#include <chrono>
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
|
#include <cstdint>
|
||||||
|
#include "nn-quants.hpp"
|
||||||
|
|
||||||
|
// primitives
|
||||||
|
|
||||||
|
// Dimensions and precomputed byte sizes of a 3D tensor
// (z planes of y rows × x columns). Built via size0/size1D/size2D/size3D.
typedef struct {
    NnFloatType floatType; // element encoding (F_32 / F_16 / F_Q40 / F_Q80 / F_UNK)
    NnUint z;              // number of planes
    NnUint y;              // rows per plane
    NnUint x;              // columns per row
    NnSize length;         // z * y * x elements
    NnSize nBytes;         // bytes for the whole tensor
    NnSize nBytesXY;       // bytes for one z-plane (y * x elements)
} NnSize3D;
|
||||||
|
|
||||||
|
// slices
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint kvDim0;
|
||||||
|
NnSize3D keySize;
|
||||||
|
NnSize3D valueSize;
|
||||||
|
} NnKvCacheSlice;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnFloatType type;
|
||||||
|
NnUint nNodes;
|
||||||
|
NnUint d0;
|
||||||
|
NnUint n;
|
||||||
|
NnSize3D size;
|
||||||
|
NnSize3D sliceSize;
|
||||||
|
} NnRowMatmulSlice;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnFloatType type;
|
||||||
|
NnUint nNodes;
|
||||||
|
NnUint n;
|
||||||
|
NnUint n0;
|
||||||
|
NnUint d;
|
||||||
|
NnSize3D size;
|
||||||
|
NnSize3D sliceSize;
|
||||||
|
} NnColMatmulSlice;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint qDim0;
|
||||||
|
NnUint qDimStart;
|
||||||
|
NnUint qDimEnd;
|
||||||
|
NnUint qShift;
|
||||||
|
NnUint kvDim;
|
||||||
|
NnUint kvDim0;
|
||||||
|
NnUint kvDimStart;
|
||||||
|
NnUint sliceDim;
|
||||||
|
NnUint seqLen;
|
||||||
|
NnUint headDim;
|
||||||
|
NnUint nKvHeads;
|
||||||
|
float ropeTheta;
|
||||||
|
NnSize3D cacheSize;
|
||||||
|
} NnRopeSlice;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint nHeads;
|
||||||
|
NnUint nHeads0;
|
||||||
|
NnSize3D attSize;
|
||||||
|
} NnMultiHeadAttSlice;
|
||||||
|
|
||||||
|
// base enums
|
||||||
|
|
||||||
|
enum NnOpCode {
|
||||||
|
OP_MERGE_ADD,
|
||||||
|
OP_MERGE_SUM,
|
||||||
|
OP_EMBEDDING,
|
||||||
|
OP_INV_RMS,
|
||||||
|
OP_RMS_NORM,
|
||||||
|
OP_MATMUL,
|
||||||
|
OP_ROPE,
|
||||||
|
OP_MULTIHEAD_ATT,
|
||||||
|
OP_GELU,
|
||||||
|
OP_SILU,
|
||||||
|
OP_MUL,
|
||||||
|
OP_SCALE,
|
||||||
|
OP_CAST,
|
||||||
|
OP_REPEAT_Z,
|
||||||
|
OP_SHIFT,
|
||||||
|
OP_SOFTMAX,
|
||||||
|
OP_MOE_GATE,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum NnOpQuantType {
|
||||||
|
// <input>_<weight>_<output>
|
||||||
|
F32_F32_F32,
|
||||||
|
F32_Q40_F32,
|
||||||
|
F32_Q40_Q80,
|
||||||
|
F32_F32_Q80,
|
||||||
|
Q80_Q80_Q80,
|
||||||
|
Q80_Q80_F32,
|
||||||
|
Q80_Q40_F32,
|
||||||
|
Q80_F32_F32,
|
||||||
|
};
|
||||||
|
|
||||||
|
#define N_OP_CODES (OP_SHIFT + 1)
|
||||||
|
#define N_OP_QUANTS (Q80_F32_F32 + 1)
|
||||||
|
|
||||||
|
enum NnPointerSource {
|
||||||
|
SRC_PIPE,
|
||||||
|
SRC_BUFFER,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum NnPointerType {
|
||||||
|
PNTR_RAW,
|
||||||
|
PNTR_BATCH,
|
||||||
|
PNTR_BATCHED_SLICE
|
||||||
|
};
|
||||||
|
|
||||||
|
enum NnSyncType {
|
||||||
|
SYNC_WITH_ROOT, // whole pipe to all nodes
|
||||||
|
SYNC_NODE_SLICES, // my slice of pipe to all nodes
|
||||||
|
SYNC_NODE_SLICES_EXCEPT_ROOT, // only workers send slices to root, root does not send
|
||||||
|
};
|
||||||
|
|
||||||
|
enum NnRopeType {
|
||||||
|
ROPE_LLAMA = 0,
|
||||||
|
ROPE_FALCON = 1,
|
||||||
|
ROPE_LLAMA3_1 = 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
// base configs
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
char *name;
|
||||||
|
NnSize3D size;
|
||||||
|
} NnPipeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
char *name;
|
||||||
|
NnSize3D size;
|
||||||
|
} NnBufferConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnPointerSource source;
|
||||||
|
NnUint pointerIndex;
|
||||||
|
NnPointerType type;
|
||||||
|
} NnPointerConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnOpCode code;
|
||||||
|
char *name;
|
||||||
|
NnUint index;
|
||||||
|
NnPointerConfig input;
|
||||||
|
NnPointerConfig output;
|
||||||
|
NnSize3D weightSize;
|
||||||
|
NnByte *config;
|
||||||
|
NnUint configSize;
|
||||||
|
} NnOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint pipeIndex;
|
||||||
|
} NnPreSyncConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint pipeIndex;
|
||||||
|
NnSyncType syncType;
|
||||||
|
} NnSyncConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint nOps;
|
||||||
|
NnOpConfig *ops;
|
||||||
|
NnUint nSyncs;
|
||||||
|
NnSyncConfig *syncs;
|
||||||
|
} NnSegmentConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint nBatches;
|
||||||
|
NnUint nNodes;
|
||||||
|
NnUint nPipes;
|
||||||
|
NnPipeConfig *pipes;
|
||||||
|
NnUint nPreSyncs;
|
||||||
|
NnPreSyncConfig *preSyncs;
|
||||||
|
} NnNetConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint nodeIndex;
|
||||||
|
NnUint nBuffers;
|
||||||
|
NnBufferConfig *buffers;
|
||||||
|
NnUint nSegments;
|
||||||
|
NnSegmentConfig *segments;
|
||||||
|
} NnNodeConfig;
|
||||||
|
|
||||||
|
// op configs
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnEmbeddingOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
float epsilon;
|
||||||
|
NnUint nColumns;
|
||||||
|
} NnInvRmsOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint invRmsBufferIndex;
|
||||||
|
NnUint nColumns;
|
||||||
|
} NnRmsNormOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint nExperts;
|
||||||
|
NnUint nActiveExperts;
|
||||||
|
NnUint activeExpertIndexesBufferIndex;
|
||||||
|
} NnMatmulOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnRopeType type;
|
||||||
|
NnUint isQ; // Cannot use `bool` here due to GPU memory alignment
|
||||||
|
NnUint positionPipeIndex;
|
||||||
|
NnUint ropeCacheBufferIndex;
|
||||||
|
float ropeScalingFactor;
|
||||||
|
float ropeScalingLowFreqFactor;
|
||||||
|
float ropeScalingHighFreqFactor;
|
||||||
|
NnUint ropeScalingOrigMaxSeqLen;
|
||||||
|
NnRopeSlice slice;
|
||||||
|
} NnRopeOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint nHeads;
|
||||||
|
NnUint nHeads0;
|
||||||
|
NnUint nKvHeads;
|
||||||
|
NnUint headDim;
|
||||||
|
NnUint seqLen;
|
||||||
|
NnUint qSliceD0;
|
||||||
|
NnUint kvDim0;
|
||||||
|
NnUint positionPipeIndex;
|
||||||
|
NnUint queryBufferIndex;
|
||||||
|
NnUint keyCacheBufferIndex;
|
||||||
|
NnUint valueCacheBufferIndex;
|
||||||
|
NnUint attBufferIndex;
|
||||||
|
} NnMultiHeadAttOpConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnMergeAddOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnMergeSumOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnSiluOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint multiplierBufferIndex;
|
||||||
|
} NnMulOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint scaleBufferIndex;
|
||||||
|
} NnScaleOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnCastOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnRepeatZOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint indexPipeIndex;
|
||||||
|
} NnShiftOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// empty
|
||||||
|
} NnSoftmaxOpCodeConfig;
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
NnUint k;
|
||||||
|
NnUint normTopk;
|
||||||
|
NnUint indexesBufferIndex;
|
||||||
|
} NnMoeGateOpCodeConfig;
|
||||||
|
|
||||||
|
// utility functions
|
||||||
|
|
||||||
|
const char *opCodeToString(NnOpCode code);
|
||||||
|
const char *opQuantTypeToString(NnOpQuantType type);
|
||||||
|
|
||||||
|
NnSize getBytes(NnFloatType floatType, NnSize n);
|
||||||
|
NnSize getBlockSize(NnFloatType floatType);
|
||||||
|
NnOpQuantType getOpQuantType(NnFloatType input, NnFloatType weight, NnFloatType output);
|
||||||
|
NnSize3D size0();
|
||||||
|
NnSize3D size1D(NnFloatType floatType, NnUint x);
|
||||||
|
NnSize3D size2D(NnFloatType floatType, NnUint y, NnUint x);
|
||||||
|
NnSize3D size3D(NnFloatType floatType, NnUint z, NnUint y, NnUint x);
|
||||||
|
NnPointerConfig pointerBatchConfig(NnPointerSource source, NnUint index);
|
||||||
|
NnPointerConfig pointerBatchedSliceConfig(NnPointerSource source, NnUint index);
|
||||||
|
NnPointerConfig pointerRawConfig(NnPointerSource source, NnUint index);
|
||||||
|
bool hasPointerContinuousMemory(NnPointerConfig *config);
|
||||||
|
|
||||||
|
void releaseNetConfig(NnNetConfig *netConfig);
|
||||||
|
void releaseNodeConfig(NnNodeConfig *nodeConfig);
|
||||||
|
|
||||||
|
void printNodeRequiredMemory(NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
|
||||||
|
|
||||||
|
class Timer {
|
||||||
|
private:
|
||||||
|
std::chrono::time_point<std::chrono::high_resolution_clock> startTime;
|
||||||
|
public:
|
||||||
|
Timer();
|
||||||
|
void reset();
|
||||||
|
NnUint elapsedMiliseconds();
|
||||||
|
NnUint elapsedMicroseconds();
|
||||||
|
};
|
||||||
|
|
||||||
|
// slicers
|
||||||
|
|
||||||
|
NnKvCacheSlice sliceKvCache(NnUint kvDim, NnUint seqLen, NnUint nNodes);
|
||||||
|
NnRowMatmulSlice sliceRowMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d);
|
||||||
|
NnColMatmulSlice sliceColMatmul(NnFloatType type, NnUint nNodes, NnUint n, NnUint d);
|
||||||
|
NnRopeSlice sliceRope(NnRopeType type, NnUint qDim, NnUint kvDim, NnUint nKvHeads, NnUint nNodes, NnUint seqLen, NnUint headDim, float ropeTheta, NnUint nodeIndex);
|
||||||
|
NnMultiHeadAttSlice sliceMultiHeadAtt(NnUint nHeads, NnUint seqLen, NnUint nNodes, NnUint nBatches);
|
||||||
|
|
||||||
|
// splitters
|
||||||
|
|
||||||
|
NnUint splitRowMatmulWeight(NnRowMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0);
|
||||||
|
NnUint splitColMatmulWeight(NnColMatmulSlice *slice, NnUint nodeIndex, NnByte *weight, NnByte *weight0);
|
||||||
|
|
||||||
|
// rope
|
||||||
|
|
||||||
|
void fullfillRopeCache(const NnRopeOpConfig *config, float *cache);
|
||||||
|
|
||||||
|
#endif
|
||||||
374
src/nn/nn-cpu-ops-test.cpp
Normal file
374
src/nn/nn-cpu-ops-test.cpp
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
#include "nn-cpu-ops.cpp"
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// framework
|
||||||
|
|
||||||
|
// Reports a passed test case; flushes so the line survives a later crash.
void printPassed(const char *name) {
    fprintf(stdout, "✅ %24s passed\n", name);
    fflush(stdout);
}
|
||||||
|
|
||||||
|
void rand(float *o, const NnUint n, const NnUint seed) {
|
||||||
|
srand(seed + 123456);
|
||||||
|
for (NnUint i = 0; i < n; i++) {
|
||||||
|
float v = (float)(rand() / RAND_MAX);
|
||||||
|
o[i] = v * 2.0f - 1.0f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Element-wise comparison of two float arrays with an absolute tolerance.
// On the first mismatch it prints up to 16 values starting at the failing
// index and terminates the test binary; otherwise reports the test as passed.
void compare_F32(const char *name, const float *a, const float *b, const NnUint n, const float epsilon) {
    for (NnUint i = 0; i < n; i++) {
        float error = fabs(a[i] - b[i]);
        if (error > epsilon) {
            printf("❌ %s failed\n", name);
            // Dump a small window of both arrays to aid debugging.
            for (NnUint j = i; j < i + 16 && j < n; j++)
                printf("  [%3d] %f != %f\n", j, a[j], b[j]);
            exit(1);
        }
    }
    printPassed(name);
}
|
||||||
|
|
||||||
|
// tests
|
||||||
|
|
||||||
|
// Verifies the SPLIT_THREADS range-partitioning macro: ranges must cover the
// whole interval without overlap, and surplus threads must receive empty
// (start == end) ranges.
void testSplitThreads() {
    // <0; 32> across 3 threads
    {
        SPLIT_THREADS(a0Start, a0End, 32, 3, 0); // thread 0
        assert(a0Start == 0);
        assert(a0End == 11);
    }
    {
        SPLIT_THREADS(a1Start, a1End, 32, 3, 1); // thread 1
        assert(a1Start == 11);
        assert(a1End == 22);
    }
    {
        // The last thread picks up the remainder (10 items vs 11).
        SPLIT_THREADS(a2Start, a2End, 32, 3, 2); // thread 2
        assert(a2Start == 22);
        assert(a2End == 32);
    }

    // <0; 4> across 8 threads — more threads than work items.
    {
        SPLIT_THREADS(b0Start, b0End, 4, 8, 0); // thread 0
        assert(b0Start == 0);
        assert(b0End == 1);
    }
    {
        SPLIT_THREADS(b0Start, b0End, 4, 8, 3); // thread 3
        assert(b0Start == 3);
        assert(b0End == 4);
    }
    {
        // Threads beyond the item count get an empty range.
        SPLIT_THREADS(b0Start, b0End, 4, 8, 4); // thread 4
        assert(b0Start == 4);
        assert(b0End == 4);
    }
    {
        SPLIT_THREADS(b0Start, b0End, 4, 8, 7); // thread 7
        assert(b0Start == 4);
        assert(b0End == 4);
    }

    printPassed("splitThreads");
}
|
||||||
|
|
||||||
|
// Round-trips a few representative values through F32 -> F16 -> F32 and
// checks the result stays within half-precision tolerance.
void testConvertF32toF16() {
    float x[] = {0.0f, 0.25f, 0.3456f, 1.0f};
    for (NnUint i = 0; i < sizeof(x) / sizeof(float); i++) {
        NnFp16 f16 = CONVERT_F32_TO_F16(x[i]);
        float f32 = CONVERT_F16_TO_F32(f16);
        // 0.0005 absolute tolerance covers fp16 rounding for values <= 1.0.
        compare_F32("convertF32toF16", &x[i], &f32, 1, 0.0005);
    }
}
|
||||||
|
|
||||||
|
// quantization
|
||||||
|
// quantization
// Round-trips m blocks of random data through Q40 and Q80 quantization and
// checks the dequantized values stay within format-specific tolerances.
// NOTE(review): buffers are sized in Q40_BLOCK_SIZE units but Q80 calls use
// Q80_BLOCK_SIZE counts — this assumes both block sizes are equal (32);
// confirm against nn-quants.hpp.
void testQuantization(const NnUint m) {
    std::vector<float> a(m * Q40_BLOCK_SIZE);
    std::vector<float> aTemp(m * Q40_BLOCK_SIZE);
    std::vector<NnBlockQ40> aQ40(m);
    std::vector<NnBlockQ80> aQ80(m);

    rand(a.data(), m * Q40_BLOCK_SIZE, m);

    quantizeF32toQ40(a.data(), aQ40.data(), m * Q40_BLOCK_SIZE, 1, 0);
    dequantizeQ40toF32(aQ40.data(), aTemp.data(), m * Q40_BLOCK_SIZE, 1, 0);

    // Q40 is 4-bit, so the tolerance is much looser than Q80's.
    compare_F32("testQuantization_Q40", a.data(), aTemp.data(), m * Q40_BLOCK_SIZE, 0.13);

    quantizeF32toQ80(a.data(), aQ80.data(), m * Q80_BLOCK_SIZE, 1, 0);
    dequantizeQ80toF32(aQ80.data(), aTemp.data(), m * Q80_BLOCK_SIZE, 1, 0);

    compare_F32("testQuantization_Q80", a.data(), aTemp.data(), m * Q80_BLOCK_SIZE, 0.01);
}
|
||||||
|
|
||||||
|
// invRms
|
||||||
|
// invRms
// Checks invRms_F32 against a hand-computed expected value (1 / rms) for a
// fixed 8-element input.
void testInvRms() {
    const float epsilon = 0.00001;

    std::vector<float> x(8);
    x[0] = 0.1f;
    x[1] = 0.3f;
    x[2] = 0.2f;
    x[3] = 0.4f;
    x[4] = 0.6f;
    x[5] = 0.5f;
    x[6] = 0.0f;
    x[7] = 0.8f;

    const float y0 = invRms_F32(x.data(), 8, epsilon);
    // rms of the vector above is ~0.4402; the op returns its reciprocal.
    float ev0 = 1.0f / 0.4402f;
    compare_F32("rms_F32", &y0, &ev0, 1, 0.001f);
}
|
||||||
|
|
||||||
|
// rmsNorm
|
||||||
|
// rmsNorm
// Verifies the quantized-input rmsNorm kernel against the F32 reference
// implementation on the same random data.
void testRmsNorm(const NnUint m) {
    std::vector<float> x(m);
    std::vector<NnBlockQ80> xQ80(m / Q80_BLOCK_SIZE);
    std::vector<float> w(m);
    std::vector<float> y(m);
    std::vector<float> yTemp(m);

    rand(x.data(), m, m);
    rand(w.data(), m, m * m);
    quantizeF32toQ80(x.data(), xQ80.data(), m, 1, 0);
    // Both kernels receive the same precomputed inverse-rms scalar.
    const float rms = invRms_F32(x.data(), m, 1e-5f);

    rmsNorm_F32(y.data(), x.data(), rms, w.data(), m, 1, 0);
    rmsNorm_Q80_F32_F32(yTemp.data(), xQ80.data(), rms, w.data(), m, 1, 0);

    compare_F32("rmsNorm_Q80_F32_F32", y.data(), yTemp.data(), m, 0.01);
}
|
||||||
|
|
||||||
|
// a *= b
|
||||||
|
// a *= b
// Verifies in-place element-wise multiply with a Q80-quantized right operand
// against the pure-F32 reference.
// NOTE(review): a0 and aQ are seeded with the same (n, m) arguments, so both
// kernels start from identical left operands.
void testMul(const NnUint m) {
    const NnUint n = Q80_BLOCK_SIZE * m;

    std::vector<float> a0(n);
    std::vector<float> b0(n);

    std::vector<float> aQ(n);
    std::vector<NnBlockQ80> b1(n / Q80_BLOCK_SIZE);

    rand(a0.data(), n, m);
    rand(aQ.data(), n, m);
    rand(b0.data(), n, m);
    quantizeF32toQ80(b0.data(), b1.data(), n, 1, 0);

    mul_F32(a0.data(), a0.data(), b0.data(), n, 1, 0);
    mul_Q80_F32(aQ.data(), aQ.data(), b1.data(), n, 1, 0);

    compare_F32("mul_Q80_F32", a0.data(), aQ.data(), n, 0.005);
}
|
||||||
|
|
||||||
|
// y += x
|
||||||
|
// y += x
// Verifies the in-place accumulate with a Q80-quantized addend against the
// pure-F32 reference; y and yTemp start from identical random contents.
void testAdd(const NnUint m) {
    const NnUint n = Q80_BLOCK_SIZE * m;

    std::vector<float> y(n);
    std::vector<float> yTemp(n);
    std::vector<float> x(n);
    std::vector<NnBlockQ80> xQ80(n / Q80_BLOCK_SIZE);

    rand(y.data(), n, m);
    rand(yTemp.data(), n, m);
    rand(x.data(), n, m);
    quantizeF32toQ80(x.data(), xQ80.data(), n, 1, 0);

    add_F32(y.data(), x.data(), n, 1, 0);
    add_Q80_F32(yTemp.data(), xQ80.data(), n, 1, 0);

    compare_F32("add_Q80_F32", y.data(), yTemp.data(), n, 0.01);
}
|
||||||
|
|
||||||
|
// Verifies mergeSum_F32: two z-slices of a 2x2x2 input are summed
// element-wise into a 2x2 output addressed through per-row pointers.
void testMergeSum() {
    float inp[] = {
        /* [z0, y0] */ 0.1f, 0.2f,
        /* [z0, y1] */ 0.3f, 0.4f,
        /* [z1, y0] */ 0.5f, 0.6f,
        /* [z1, y1] */ 0.7f, 0.8f,
    };
    float out[4];

    // Input rows in (z, y) order; the kernel consumes them via this
    // pointer table rather than a contiguous array.
    float *i[4] = {
        &inp[0],
        &inp[2],
        &inp[4],
        &inp[6],
    };
    float *o[2] = {
        &out[0],
        &out[2]
    };

    mergeSum_F32(o, i, 2u, 2u, 2u, 2u, 1u, 0u);

    // Each output element is the sum over z: e.g. 0.1 + 0.5 = 0.6.
    const float expectedOutput[4] = {
        0.6f,
        0.8f,
        1.0f,
        1.2f,
    };
    compare_F32("mergeSum_F32", out, expectedOutput, 4u, 0.00000001f);
}
|
||||||
|
|
||||||
|
// Verifies the in-place softmax kernel on a fixed ramp input against
// precomputed reference values.
void testSoftmax() {
    std::vector<float> y(8);
    for (NnUint i = 0; i < 8; i++)
        y[i] = i / 8.0f;

    softmax_F32(y.data(), 8);

    // Reference softmax of [0, 1/8, ..., 7/8]; values sum to 1.
    float expectedOutput[8] = {
        0.077399f,
        0.087780f,
        0.099500f,
        0.112761f,
        0.127778f,
        0.144793f,
        0.164072f,
        0.185917f
    };
    compare_F32("softmax_F32", y.data(), expectedOutput, 8, 0.001);
}
|
||||||
|
|
||||||
|
// Verifies the in-place SiLU (x * sigmoid(x)) kernel on a fixed ramp input
// against precomputed reference values.
void testSilu() {
    std::vector<float> y(8);
    for (NnUint i = 0; i < 8; i++)
        y[i] = i / 8.0f;

    silu_F32(y.data(), 8, 1, 0);

    float expectedOutput[8] = {
        0.000000f,
        0.066401f,
        0.140544f,
        0.222250f,
        0.311233f,
        0.407116f,
        0.509461f,
        0.617802f
    };
    compare_F32("silu_F32", y.data(), expectedOutput, 8, 0.001);
}
|
||||||
|
|
||||||
|
// matmul
|
||||||
|
// matmul
// Compares the Q80-input / Q40-weight matmul kernel against the full-F32
// reference on the same random data.
void testMatmul_F32_Q40_F32(const NnUint m = 2) {
    const NnUint n = Q80_BLOCK_SIZE * m;
    const NnUint d = Q80_BLOCK_SIZE * m;

    std::vector<float> x(n);
    std::vector<float> w(n * d);
    std::vector<float> o(d);
    std::vector<float> oTemp(d);
    std::vector<NnBlockQ80> xQ80(n / Q80_BLOCK_SIZE);
    std::vector<NnBlockQ40> wQ40((n * d) / Q40_BLOCK_SIZE);

    rand(x.data(), n, m);
    rand(w.data(), n * d, m);
    quantizeF32toQ40(w.data(), wQ40.data(), n * d, 1, 0);
    quantizeF32toQ80(x.data(), xQ80.data(), n, 1, 0);

    matmul_F32_F32_F32(o.data(), x.data(), w.data(), n, d, 1, 0);

    // Loose tolerance: 4-bit weights accumulate sizeable quantization error
    // over long dot products.
    matmul_Q80_Q40_F32(oTemp.data(), xQ80.data(), wQ40.data(), n, d, 1, 0);
    compare_F32("matmul_Q80_Q40_F32", o.data(), oTemp.data(), d, 4.0f);
}
|
||||||
|
|
||||||
|
// Compares the llamafile sgemm path against the reference matmul, first in
// pure F32 and then (on dotprod-capable ARM builds) with Q40 weights and a
// Q80 input.
void testLlamafileSgemm() {
    const NnUint batchSize = 8;
    const NnUint n = 256;
    const NnUint d = 128;

    std::vector<float> x(n * batchSize);
    std::vector<NnBlockQ80> xQ((n * batchSize) / Q80_BLOCK_SIZE);
    std::vector<float> w(n * d);
    std::vector<NnBlockQ40> wQ((n * d) / Q40_BLOCK_SIZE);
    std::vector<float> o(d * batchSize);
    std::vector<float> oTemp(d * batchSize);

    rand(x.data(), n * batchSize, 12345);
    rand(w.data(), n * d, 23456);

    quantizeF32toQ80(x.data(), xQ.data(), n * batchSize, 1, 0);
    quantizeF32toQ40(w.data(), wQ.data(), n * d, 1, 0);

    // f32

    // Build the reference output one batch row at a time.
    for (NnUint i = 0; i < batchSize; i++) {
        matmul_F32_F32_F32(o.data() + i * d, x.data() + i * n, w.data(), n, d, 1, 0);
    }

    // llamafile_sgemm returns false when the shape/type combination is
    // unsupported; that would make this test meaningless, hence the assert.
    assert(llamafile_sgemm(
        d, batchSize, n,
        w.data(), n,
        x.data(), n,
        oTemp.data(), d,
        0, 1, 0,
        F_32, F_32, F_32
    ));

    compare_F32("llamafileSgemm_F32", o.data(), oTemp.data(), d * batchSize, 0.01f);

#if __ARM_FEATURE_DOTPROD
    // q40ᵀ * q80
    // Quantized path: leading dimensions are expressed in blocks, not floats.

    assert(llamafile_sgemm(
        d, batchSize, n / Q80_BLOCK_SIZE,
        wQ.data(), n / Q80_BLOCK_SIZE,
        xQ.data(), n / Q80_BLOCK_SIZE,
        oTemp.data(), d,
        0, 1, 0,
        F_Q40, F_Q80, F_32
    ));

    compare_F32("llamafileSgemm_Q80_Q40", o.data(), oTemp.data(), d * batchSize, 1.5f);
#endif
}
|
||||||
|
|
||||||
|
// Verifies scale_F32: multiplies every element by a scalar into the output.
void testScale() {
    float i[] = {1.0f, 2.0f, 3.0f, 4.0f};
    float o[4];
    scale_F32(i, o, 0.5f, 4u, 1u, 0u);
    float expectedOutput[] = {0.5f, 1.0f, 1.5f, 2.0f};
    compare_F32("scale_F32", o, expectedOutput, 4u, 0.00001f);
}
|
||||||
|
|
||||||
|
// Verifies topk_F32 returns the indexes of the k largest values in
// descending order of value (4.0 at index 1, then 3.0 at index 3).
void testTopk() {
    float x[] = {1.0f, 4.0f, 2.0f, 3.0f};
    std::vector<NnUint> topk(2);
    topk_F32(x, topk.data(), 4u, 2u);
    assert(topk[0] == 1u);
    assert(topk[1] == 3u);
    printPassed("testTopk");
}
|
||||||
|
|
||||||
|
// Test driver: initializes quantization lookup tables, then runs every CPU
// op test. Block-count-parameterized tests run with 32, 2 and 1 blocks to
// cover both the vectorized and the scalar tail paths. Failing tests call
// exit(1) from compare_F32, so reaching the end means all passed.
int main() {
    initQuants();

    printCpuInstructionSet();
    testSplitThreads();
    testConvertF32toF16();
    testQuantization(32);
    testQuantization(2);
    testQuantization(1);
    testInvRms();
    testRmsNorm(128);
    testMul(32);
    testMul(2);
    testMul(1);
    testAdd(32);
    testAdd(2);
    testAdd(1);
    testMergeSum();
    testSoftmax();
    testSilu();
    testMatmul_F32_Q40_F32(32);
    testMatmul_F32_Q40_F32(2);
    testMatmul_F32_Q40_F32(1);
    testLlamafileSgemm();
    testScale();
    testTopk();
    return 0;
}
|
||||||
1600
src/nn/nn-cpu-ops.cpp
Normal file
1600
src/nn/nn-cpu-ops.cpp
Normal file
File diff suppressed because it is too large
Load Diff
43
src/nn/nn-cpu-ops.hpp
Normal file
43
src/nn/nn-cpu-ops.hpp
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#ifndef NN_CPU_OPS_H
#define NN_CPU_OPS_H

#include "nn-core.hpp"

// Aborts the process when two values differ. Wrapped in do/while(0) so the
// macro expands to a single statement (safe inside unbraced if/else), and
// the arguments are parenthesized so expression arguments expand correctly.
// NOTE(review): %d assumes the arguments print as int — confirm call sites
// don't pass 64-bit values.
#define ASSERT_EQ(a, b) \
    do { \
        if ((a) != (b)) { \
            printf("Assertion failed: %d != %d (%s:%d)\n", (a), (b), __FILE__, __LINE__); \
            exit(-1); \
        } \
    } while (0)

// Everything a CPU op implementation needs at forward time: its own config
// blob, resolved input/output pointer tables, weights, and access to the
// node's buffers and pipes.
typedef struct {
    const char *name;
    NnByte nBatches;
    NnByte *bufferFlags;
    NnByte **buffers;
    NnBufferConfig *bufferConfigs;
    NnByte **pipes;
    NnPipeConfig *pipeConfigs;
    void *opConfig; // points at the op-specific *OpConfig struct

    NnByte **input; // one pointer per batch (or a single raw pointer)
    NnSize3D inputSize;
    bool hasInputContinuousMemory;

    NnByte **output;
    NnSize3D outputSize;
    bool hasOutputContinuousMemory;

    NnByte *weight;
    NnSize3D weightSize;
} NnCpuOpContext;

// Optional one-time initialization hook and the per-forward kernel signature.
typedef void (*NnCpuOpForwardInit)(NnCpuOpContext *context);
typedef void (*NnCpuOpForward)(NnUint nThreads, NnUint threadIndex, NnUint batchSize, NnCpuOpContext *context);

void printCpuInstructionSet();
// Return nullptr when the (op, quant) combination is unsupported.
NnCpuOpForwardInit getCpuOpForwardInit(NnOpCode code, NnOpQuantType quantType);
NnCpuOpForward getCpuOpForward(NnOpCode code, NnOpQuantType quantType);

void softmax_F32(float *x, const NnUint size);

#endif
|
||||||
85
src/nn/nn-cpu-test.cpp
Normal file
85
src/nn/nn-cpu-test.cpp
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
#include "nn-core.hpp"
|
||||||
|
#include "nn-config-builder.hpp"
|
||||||
|
#include "nn-cpu.hpp"
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#define DIM 32
|
||||||
|
#define N_BATCHES 2
|
||||||
|
|
||||||
|
// Builds a minimal single-node network: one F32 pipe "X" and a two-op
// segment (inv_rms into a scratch buffer, then rms_norm applied in place on
// the pipe). Ownership of the built configs passes to the caller, who must
// release them with releaseNetConfig / releaseNodeConfig.
void buildConfig(NnNetConfig *netConfig, NnNodeConfig *nodeConfig) {
    NnUint nNodes = 1;
    NnNetConfigBuilder netBuilder(nNodes, N_BATCHES);
    NnUint xPipeIndex = netBuilder.addPipe("X", size2D(F_32, N_BATCHES, DIM));

    NnNodeConfigBuilder nodeBuilder(0);
    // One inv-rms scalar per batch row.
    NnUint invRmsBufferIndex = nodeBuilder.addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1));
    NnSegmentConfigBuilder segmentBuilder;
    segmentBuilder.addSync(xPipeIndex, SYNC_NODE_SLICES_EXCEPT_ROOT);

    segmentBuilder.addOp(OP_INV_RMS, "inv_rms", 0,
        pointerBatchConfig(SRC_PIPE, xPipeIndex),
        pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
        size0(),
        NnInvRmsOpConfig{1e-5f, 1});

    // rms_norm reads and writes the same pipe (in-place normalization).
    segmentBuilder.addOp(OP_RMS_NORM, "rms_norm", 0,
        pointerBatchConfig(SRC_PIPE, xPipeIndex),
        pointerBatchConfig(SRC_PIPE, xPipeIndex),
        size1D(F_32, DIM),
        NnRmsNormOpConfig{invRmsBufferIndex, 1});

    nodeBuilder.addSegment(segmentBuilder.build());

    *netConfig = netBuilder.build();
    *nodeConfig = nodeBuilder.build();
}
|
||||||
|
|
||||||
|
void print2D(const char *name, NnUint x, NnUint y, float *w) {
|
||||||
|
for (NnUint i = 0; i < y; i++) {
|
||||||
|
printf("%s[%d] = ", name, i);
|
||||||
|
for (NnUint j = 0; j < x; j++)
|
||||||
|
printf("%f ", w[i * x + j]);
|
||||||
|
printf("\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Smoke test for the CPU executor: builds the inv_rms + rms_norm network,
// fills the input pipe, runs one forward pass and prints the results for
// manual inspection (no assertions).
int main() {
    initQuants();

    NnUint nThreads = 2;
    NnNetConfig netConfig;
    NnNodeConfig nodeConfig;
    buildConfig(&netConfig, &nodeConfig);

    NnNetExecution execution(nThreads, &netConfig);
    // Pipe 0 is "X"; fill each batch row with a distinct ramp.
    float *x = (float *)execution.pipes[0];
    for (NnUint b = 0; b < N_BATCHES; b++) {
        for (NnUint i = 0; i < DIM; i++)
            x[b * DIM + i] = i / (float)DIM + (float)b;
    }

    print2D("x", DIM, N_BATCHES, x);

    float rmsNormWeight[DIM];
    for (NnUint i = 0; i < DIM; i++)
        rmsNormWeight[i] = 0.5 + i / (float)DIM;

    // The executor device takes the raw device pointer; presumably the
    // NnExecutorDevice wrapper assumes ownership — confirm before changing.
    NnCpuDevice *device = new NnCpuDevice(&netConfig, &nodeConfig, &execution);
    std::vector<NnExecutorDevice> devices;
    devices.push_back(NnExecutorDevice(device, -1, -1));

    NnFakeNodeSynchronizer synchronizer;
    // Buffer 0 is the "inv_rms" scratch buffer created in buildConfig.
    float *rms = (float *)device->buffers[0];
    NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);
    executor.loadWeight("rms_norm", 0u, 0u, sizeof(rmsNormWeight), (NnByte *)rmsNormWeight);

    execution.setBatchSize(2);
    executor.forward();

    print2D("rms", N_BATCHES, 1, rms);
    print2D("x", DIM, N_BATCHES, x);

    releaseNetConfig(&netConfig);
    releaseNodeConfig(&nodeConfig);
    return 0;
}
|
||||||
232
src/nn/nn-cpu.cpp
Normal file
232
src/nn/nn-cpu.cpp
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
#include "nn-cpu.hpp"
|
||||||
|
#include "nn-cpu-ops.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <thread>
|
||||||
|
#ifdef _WIN32
|
||||||
|
#include <windows.h>
|
||||||
|
#else
|
||||||
|
#include <sys/mman.h>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define DEBUG_CPU_OP_QUANTS false
|
||||||
|
|
||||||
|
#define BUFFER_ALIGNMENT 64
|
||||||
|
|
||||||
|
// Allocates a BUFFER_ALIGNMENT-aligned (cache-line) buffer, throwing
// std::runtime_error on failure. Must be released with releaseAlignedBuffer.
static NnByte *allocAlignedBuffer(NnSize size) {
    NnByte *buffer;
#ifdef _WIN32
    buffer = (NnByte *)_aligned_malloc(size, BUFFER_ALIGNMENT);
    if (buffer == NULL)
        throw std::runtime_error("_aligned_malloc failed");
#else
    if (posix_memalign((void **)&buffer, BUFFER_ALIGNMENT, size) != 0)
        throw std::runtime_error("posix_memalign failed");
    // Best-effort: pin the pages to avoid swapping; the return value is
    // deliberately ignored (may fail without privileges / rlimit headroom).
    mlock(buffer, size);
#endif
    return buffer;
}
|
||||||
|
|
||||||
|
// Frees a buffer obtained from allocAlignedBuffer. _aligned_malloc memory
// must be released with _aligned_free; posix_memalign memory with free.
static void releaseAlignedBuffer(NnByte *buffer) {
#ifdef _WIN32
    _aligned_free(buffer);
#else
    free(buffer);
#endif
}
|
||||||
|
|
||||||
|
// Constructs the CPU device: stores the (non-owned) config/execution
// pointers and allocates one aligned buffer per node buffer config, plus a
// zeroed per-buffer flag array.
NnCpuDevice::NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution) {
    this->netConfig = netConfig;
    this->nodeConfig = nodeConfig;
    this->netExecution = netExecution;

    printCpuInstructionSet();

    nBuffers = nodeConfig->nBuffers;
    buffers = new NnByte *[nBuffers];
    for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++) {
        NnBufferConfig *config = &nodeConfig->buffers[bufferIndex];
        NnByte *buffer = allocAlignedBuffer(config->size.nBytes);
        buffers[bufferIndex] = buffer;
    }

    bufferFlags = new NnByte[nBuffers];
    std::memset(bufferFlags, 0, nBuffers * sizeof(NnByte));
}
|
||||||
|
|
||||||
|
// Releases every aligned buffer allocated in the constructor, then the
// pointer and flag arrays. The config/execution pointers are not owned.
NnCpuDevice::~NnCpuDevice() {
    for (NnUint bufferIndex = 0; bufferIndex < nBuffers; bufferIndex++)
        releaseAlignedBuffer(buffers[bufferIndex]);
    delete[] buffers;
    delete[] bufferFlags;
}
|
||||||
|
|
||||||
|
// Returns the number of hardware threads reported by the OS.
// NOTE: hardware_concurrency() may return 0 when it cannot be determined.
NnUint NnCpuDevice::maxNThreads() {
    return std::thread::hardware_concurrency();
}
|
||||||
|
|
||||||
|
// Builds an executable CPU segment: resolves each op's kernel by (op code,
// quant type), materializes per-op contexts with resolved input/output
// pointer tables, and (unless weights are mmapped) allocates aligned weight
// storage. Two passes are used so unsupported ops are rejected before any
// contexts/weights are allocated. The returned segment owns the arrays.
NnDeviceSegment *NnCpuDevice::createSegment(NnUint segmentIndex) {
    NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
    assert(segmentConfig->nOps > 0);

    std::vector<NnOpQuantType> opQuants(segmentConfig->nOps);
    std::vector<NnCpuOpForward> opForwardLocal(segmentConfig->nOps);
    std::vector<NnSize3D> inputSizes(segmentConfig->nOps);
    std::vector<NnSize3D> outputSizes(segmentConfig->nOps);

    std::vector<std::vector<NnByte *>> inputsPtr(segmentConfig->nOps);
    std::vector<std::vector<NnByte *>> outputsPtr(segmentConfig->nOps);

    // Pass 1: resolve pointers and kernels; fail fast on unsupported ops.
    for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
        NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
        NnSize3D inputSize;
        NnSize3D outputSize;
        inputsPtr[opIndex] = resolvePointer(&inputSize, &opConfig->input);
        outputsPtr[opIndex] = resolvePointer(&outputSize, &opConfig->output);
        // The quant type (input/weight/output float formats) selects the
        // concrete kernel variant.
        NnOpQuantType opQuant = getOpQuantType(
            inputSize.floatType,
            opConfig->weightSize.floatType,
            outputSize.floatType);
#if DEBUG_CPU_OP_QUANTS
        printf("%20s %2d: %s\n", opConfig->name, opConfig->index, opQuantTypeToString(opQuant));
#endif
        NnCpuOpForward forward = getCpuOpForward(opConfig->code, opQuant);
        if (forward == nullptr) {
            throw std::invalid_argument(
                std::string("Unsupported CPU op code: ") + opCodeToString(opConfig->code) +
                ", quant: " + opQuantTypeToString(opQuant) +
                ", op name: " + opConfig->name);
        }
        inputSizes[opIndex] = inputSize;
        outputSizes[opIndex] = outputSize;
        opQuants[opIndex] = opQuant;
        opForwardLocal[opIndex] = forward;
    }

    NnCpuOpForward *opForward = new NnCpuOpForward[segmentConfig->nOps];
    NnCpuOpContext *opContexts = new NnCpuOpContext[segmentConfig->nOps];

    // Pass 2: populate contexts, copy pointer tables, allocate weights and
    // run per-op init hooks.
    for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
        NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
        NnCpuOpContext *opContext = &opContexts[opIndex];
        NnCpuOpForwardInit opInit = getCpuOpForwardInit(opConfig->code, opQuants[opIndex]);
        opContext->name = opConfig->name;
        opContext->opConfig = opConfig->config;
        opContext->weightSize = opConfig->weightSize;
        opContext->nBatches = netConfig->nBatches;
        opContext->pipes = netExecution->pipes;
        opContext->pipeConfigs = netConfig->pipes;
        opContext->buffers = buffers;
        opContext->bufferConfigs = nodeConfig->buffers;
        opContext->bufferFlags = bufferFlags;

        opContext->input = new NnByte *[inputsPtr[opIndex].size()];
        opContext->inputSize = inputSizes[opIndex];
        opContext->hasInputContinuousMemory = hasPointerContinuousMemory(&opConfig->input);
        std::memcpy(opContext->input, inputsPtr[opIndex].data(), inputsPtr[opIndex].size() * sizeof(NnByte *));

        opContext->output = new NnByte *[outputsPtr[opIndex].size()];
        opContext->outputSize = outputSizes[opIndex];
        opContext->hasOutputContinuousMemory = hasPointerContinuousMemory(&opConfig->output);
        std::memcpy(opContext->output, outputsPtr[opIndex].data(), outputsPtr[opIndex].size() * sizeof(NnByte *));

#if not(DEBUG_USE_MMAP_FOR_WEIGHTS)
        // When not mmapping weights, each op owns an aligned weight buffer
        // that is filled later via loadWeight.
        if (opContext->weightSize.nBytes > 0)
            opContext->weight = allocAlignedBuffer(opContext->weightSize.nBytes);
        else
            opContext->weight = nullptr;
#endif

        if (opInit != nullptr)
            opInit(opContext);
        opForward[opIndex] = opForwardLocal[opIndex];
    }
    return new NnCpuDeviceSegment(opForward, opContexts, segmentConfig->nOps);
}
|
||||||
|
|
||||||
|
// Releases everything this segment owns: the per-op input/output pointer
// arrays and, when weights were copied (non-mmap mode), the aligned
// weight buffers, then the parallel op arrays themselves.
NnCpuDeviceSegment::~NnCpuDeviceSegment() {
    for (NnUint i = 0u; i < nOps; i++) {
        NnCpuOpContext *ctx = &opContexts[i];
        delete[] ctx->input;
        delete[] ctx->output;
#if not(DEBUG_USE_MMAP_FOR_WEIGHTS)
        if (ctx->weightSize.nBytes > 0)
            releaseAlignedBuffer(ctx->weight);
#endif
    }
    delete[] opForward;
    delete[] opContexts;
}
|
||||||
|
|
||||||
|
// Resolves a pointer config into a list of concrete addresses into the
// backing memory (a node-local buffer or a network pipe) and writes the
// resolved logical size into *pntrSize.
// - PNTR_RAW: a single pointer to the whole region.
// - PNTR_BATCH: one pointer per (z, batch) row of the source.
// - PNTR_BATCHED_SLICE: like PNTR_BATCH, but each row pointer is advanced
//   to this node's 1/nNodes slice along the x axis.
// Throws std::invalid_argument for unknown source or pointer types.
std::vector<NnByte *> NnCpuDevice::resolvePointer(NnSize3D *pntrSize, NnPointerConfig *pointerConfig) {
    NnByte *source;
    NnSize3D *sourceSize;

    // Select the backing memory region referenced by the config.
    switch (pointerConfig->source) {
    case SRC_BUFFER:
        source = buffers[pointerConfig->pointerIndex];
        sourceSize = &nodeConfig->buffers[pointerConfig->pointerIndex].size;
        break;
    case SRC_PIPE:
        source = netExecution->pipes[pointerConfig->pointerIndex];
        sourceSize = &netConfig->pipes[pointerConfig->pointerIndex].size;
        break;
    default:
        throw std::invalid_argument("Unsupported pointer type");
    }

    switch (pointerConfig->type) {
    case PNTR_RAW: {
        // The whole region as one flat pointer.
        *pntrSize = size1D(sourceSize->floatType, sourceSize->length);
        return std::vector<NnByte *>{source};
    }
    case PNTR_BATCH:
    case PNTR_BATCHED_SLICE: {
        // The source's y dimension must be the batch dimension.
        ASSERT_EQ(sourceSize->y, netConfig->nBatches);
        std::vector<NnByte *> pntr(sourceSize->z * sourceSize->y);

        // One pointer per (z, batch) row; each row is x elements wide.
        NnSize batchBytes = getBytes(sourceSize->floatType, sourceSize->x);
        for (NnUint z = 0u; z < sourceSize->z; z++) {
            for (NnUint y = 0u; y < sourceSize->y; y++)
                pntr[z * sourceSize->y + y] = &source[(z * sourceSize->y + y) * batchBytes];
        }
        *pntrSize = *sourceSize;

        if (pointerConfig->type == PNTR_BATCHED_SLICE) {
            // Advance each row pointer to this node's contiguous x-slice.
            assert(sourceSize->x % netConfig->nNodes == 0);
            NnUint xSlice = sourceSize->x / netConfig->nNodes;
            NnSize xSliceBytes = getBytes(sourceSize->floatType, xSlice);
            for (NnUint z = 0; z < sourceSize->z; z++) {
                for (NnUint y = 0; y < sourceSize->y; y++)
                    pntr[z * sourceSize->y + y] = &pntr[z * sourceSize->y + y][xSliceBytes * nodeConfig->nodeIndex];
            }
            // The resolved size shrinks along x to the slice width.
            *pntrSize = size3D(sourceSize->floatType, sourceSize->z, sourceSize->y, xSlice);
        }
        return pntr;
    }
    default:
        throw std::invalid_argument("Unsupported pointer config");
    }
}
|
||||||
|
|
||||||
|
// Copies (or, in mmap mode, adopts) `nBytes` of weight data for the op at
// `opIndex`, starting at `offset` within the op's weight buffer.
// In mmap mode the caller retains ownership of the mapping and the whole
// weight must be provided at once (offset == 0).
void NnCpuDeviceSegment::loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) {
    // Fix: NnUint is unsigned, so the original `assert(opIndex >= 0u)` was
    // always true (tautological compare); the upper-bound check is the guard.
    assert(opIndex < nOps);
    NnCpuOpContext *context = &opContexts[opIndex];
    assert(offset + nBytes <= context->weightSize.nBytes);
#if DEBUG_USE_MMAP_FOR_WEIGHTS
    assert(offset == 0u);
    context->weight = weight; // adopt the mmap-ed pointer; no copy
#else
    std::memcpy(&context->weight[offset], weight, nBytes);
#endif
}
|
||||||
|
|
||||||
|
// Runs one op of this segment on behalf of a single worker thread; the
// op's kernel itself splits the work among `nThreads` by `threadIndex`.
void NnCpuDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) {
    NnCpuOpContext *ctx = &opContexts[opIndex];
    // Debug trace (disabled):
    // printf("forward: %d %s (%d/%d)\n", opIndex, ctx->name, threadIndex + 1, nThreads); fflush(stdout);
    opForward[opIndex](nThreads, threadIndex, batchSize, ctx);
}
|
||||||
39
src/nn/nn-cpu.hpp
Normal file
39
src/nn/nn-cpu.hpp
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#ifndef NN_CPU_H
#define NN_CPU_H

#include <vector>
#include "nn-executor.hpp"
#include "nn-cpu-ops.hpp"

// When true, loadWeight() only stores the caller-owned (mmap-ed) pointer;
// when false, the segment allocates aligned buffers and copies weights in.
#define DEBUG_USE_MMAP_FOR_WEIGHTS false

// CPU backend: owns the node-local buffers and resolves pointer configs
// into raw addresses consumed by the CPU op kernels.
class NnCpuDevice : public NnDevice {
public:
    NnByte **buffers; // one allocation per buffer config of this node
private:
    NnNetConfig *netConfig;       // network-wide configuration (not owned)
    NnNodeConfig *nodeConfig;     // this node's configuration (not owned)
    NnNetExecution *netExecution; // provides the pipe memory (not owned)
    NnUint nBuffers;              // number of entries in `buffers`
    NnByte *bufferFlags;          // per-buffer flag bytes -- TODO confirm exact semantics
public:
    NnCpuDevice(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution);
    ~NnCpuDevice() override;
    NnUint maxNThreads() override;
    NnDeviceSegment *createSegment(NnUint segmentIndex) override;
    // Resolves a pointer config to concrete addresses; writes the resolved
    // logical size into *pntrSize and returns one pointer per resolved row.
    std::vector<NnByte *> resolvePointer(NnSize3D *pntrSize, NnPointerConfig *pointerConfig);
};

// A compiled segment of CPU ops: parallel arrays (forward function and
// context per op), executed by op index.
class NnCpuDeviceSegment : public NnDeviceSegment {
public:
    NnUint nOps;
    NnCpuOpForward *opForward;  // owned array, one forward fn per op
    NnCpuOpContext *opContexts; // owned array, one context per op
    NnCpuDeviceSegment(NnCpuOpForward *opForward, NnCpuOpContext *opContexts, NnUint nOps)
        : opForward(opForward), opContexts(opContexts), nOps(nOps) {}
    ~NnCpuDeviceSegment() override;
    void loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) override;
    void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override;
};

#endif
|
||||||
192
src/nn/nn-executor.cpp
Normal file
192
src/nn/nn-executor.cpp
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include "nn-executor.hpp"
|
||||||
|
|
||||||
|
// No-op synchronizer used when the net runs on a single node: there are
// no peers to exchange data with, so sync points do nothing.
void NnFakeNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
    // Intentionally empty.
}
|
||||||
|
|
||||||
|
// Allocates one zero-filled byte buffer per pipe declared in the network
// config. `batchSize` starts at 0 and must be set before forward() runs.
NnNetExecution::NnNetExecution(NnUint nThreads, NnNetConfig *netConfig) {
    this->nThreads = nThreads;
    this->nBatches = netConfig->nBatches;
    this->nPipes = netConfig->nPipes;
    this->batchSize = 0; // must be overwritten via setBatchSize() before forward()

    pipes = new NnByte *[nPipes];
    for (NnUint i = 0u; i < nPipes; i++) {
        const NnSize nBytes = netConfig->pipes[i].size.nBytes;
        NnByte *buffer = new NnByte[nBytes];
        std::memset(buffer, 0, nBytes);
        pipes[i] = buffer;
    }
}
|
||||||
|
|
||||||
|
NnNetExecution::~NnNetExecution() {
|
||||||
|
for (NnUint pipeIndex = 0; pipeIndex < nPipes; pipeIndex++)
|
||||||
|
delete[] pipes[pipeIndex];
|
||||||
|
delete[] pipes;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NnNetExecution::setBatchSize(NnUint batchSize) {
|
||||||
|
assert(batchSize <= nBatches);
|
||||||
|
this->batchSize = batchSize;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Binds a device (ownership is taken) to the inclusive range of segments
// it should execute; -1/-1 means "all segments".
NnExecutorDevice::NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo)
    : device(device), segmentFrom(segmentFrom), segmentTo(segmentTo) {}
|
||||||
|
|
||||||
|
// Compiles this node's segment list into a flat array of execution steps
// (one STEP_EXECUTE_OP per op, plus a STEP_SYNC_NODES after segments that
// require inter-node synchronization), assigns each segment to a device,
// and prepares the shared context and per-thread state used by forward().
// Throws std::invalid_argument when the thread count exceeds what the
// devices support or when a segment has no matching device.
NnExecutor::NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *devices, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark)
    : segments(nodeConfig->nSegments), steps()
{
    // The executor can run at most as many threads as the most capable device.
    NnUint maxNThreads = 0;
    for (NnExecutorDevice &d : *devices) {
        if (d.device->maxNThreads() > maxNThreads)
            maxNThreads = d.device->maxNThreads();
    }
    if (netExecution->nThreads > maxNThreads)
        throw std::invalid_argument("This configuration supports max " + std::to_string(maxNThreads) + " threads");

    this->netExecution = netExecution;
    this->nodeConfig = nodeConfig;

    // Sync steps are only needed when there is more than one node.
    bool useSynchronizer = netConfig->nNodes > 1;
    for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
        // First device whose range covers this segment wins.
        NnDevice *device = nullptr;
        for (NnExecutorDevice &d : *devices) {
            // NOTE(review): `segmentIndex >= d.segmentFrom` mixes unsigned and
            // int; a lone -1 bound (without the other also being -1) would
            // convert to a huge unsigned value -- confirm both bounds are
            // always set together by callers.
            if (
                (d.segmentFrom == -1 && d.segmentTo == -1) ||
                (segmentIndex >= d.segmentFrom && segmentIndex <= d.segmentTo)
            ) {
                device = d.device.get();
                break;
            }
        }
        if (device == nullptr)
            throw std::invalid_argument("Cannot locate device for segment " + std::to_string(segmentIndex));

        NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
        if (segmentConfig->nOps > 0) {
            // The executor owns the compiled segment; steps keep raw pointers
            // into it (arg0 = op index within the segment).
            NnDeviceSegment *segment = device->createSegment(segmentIndex);
            segments[segmentIndex] = std::unique_ptr<NnDeviceSegment>(segment);

            for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++)
                steps.push_back(NnExecutorStep{ STEP_EXECUTE_OP, segment, opIndex, &segmentConfig->ops[opIndex] });
        }
        // For sync steps arg0 carries the segment index instead.
        if (useSynchronizer && segmentConfig->nSyncs > 0)
            steps.push_back(NnExecutorStep{ STEP_SYNC_NODES, nullptr, segmentIndex, nullptr });
    }

    // Freeze the step list; context.steps points into the vector's storage,
    // so it must not grow after this.
    steps.shrink_to_fit();

    context.nThreads = netExecution->nThreads;
    context.synchronizer = synchronizer;
    context.nSteps = (NnUint)steps.size();
    context.steps = steps.data();
    if (benchmark)
        context.timer = new Timer();
    else
        context.timer = nullptr;

    // Every worker shares the same context; only threadIndex differs.
    threads = new NnExecutorThread[netExecution->nThreads];
    for (NnUint threadIndex = 0; threadIndex < netExecution->nThreads; threadIndex++) {
        NnExecutorThread *thread = &threads[threadIndex];
        thread->threadIndex = threadIndex;
        thread->context = &context;
    }
}
|
||||||
|
|
||||||
|
// Releases the benchmark timer (if any) and the per-thread state array.
NnExecutor::~NnExecutor() {
    delete context.timer; // deleting nullptr is a safe no-op
    delete[] threads;
}
|
||||||
|
|
||||||
|
// Locates the op identified by (name, logical opIndex) across all of this
// node's segments and forwards the weight bytes to the owning device
// segment. Note the translation: `opIndex` matches NnOpConfig::index,
// while the segment-local position `i` is what loadWeight() receives.
// Throws std::invalid_argument when no matching op exists on this node.
void NnExecutor::loadWeight(const char *name, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) {
    for (NnUint segmentIndex = 0; segmentIndex < nodeConfig->nSegments; segmentIndex++) {
        NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];
        for (NnUint i = 0; i < segmentConfig->nOps; i++) {
            NnOpConfig *opConfig = &segmentConfig->ops[i];
            if (opConfig->index == opIndex && std::strcmp(opConfig->name, name) == 0) {
                NnDeviceSegment *segment = segments[segmentIndex].get();
                // A segment with ops must have been created in the constructor.
                assert(segment != nullptr);
                segment->loadWeight(i, offset, nBytes, weight);
                return;
            }
        }
    }
    throw std::invalid_argument("Cannot locate op by name: " + std::string(name));
}
|
||||||
|
|
||||||
|
// Dispatches a single execution step to its handler: device op forward or
// inter-node synchronization.
inline void executeStep(NnExecutorStep *step, NnUint nThreads, NnExecutorThread *thread, NnExecutorContext *context) {
    switch (step->type) {
    case STEP_EXECUTE_OP:
        step->segment->forward(step->arg0, nThreads, thread->threadIndex, context->batchSize);
        break;
    case STEP_SYNC_NODES:
        context->synchronizer->sync(step->arg0, nThreads, thread->threadIndex);
        break;
    default:
        throw std::invalid_argument("Unsupported step type");
    }
}
|
||||||
|
|
||||||
|
// Worker loop shared by all executor threads. The threads march through
// context->steps in lockstep using a spin barrier built from two atomics:
// each thread executes the current step, then increments doneThreadCount;
// the LAST thread to finish (currentCount == nThreads - 1) records the
// benchmark time, resets the counter, and advances currentStepIndex, which
// releases the other threads from their busy-wait. Exits when every step
// has been executed.
static inline void *executorThreadHandler(void *arg) {
    NnExecutorThread *thread = (NnExecutorThread *)arg;
    NnExecutorContext *context = thread->context;
    NnUint nThreads = context->nThreads;
    NnUint doneCount = nThreads - 1;

    while (true) {
        const unsigned int currentStepIndex = context->currentStepIndex.load();
        if (currentStepIndex == context->nSteps)
            break; // all steps executed

        NnExecutorStep *step = &context->steps[currentStepIndex];
        executeStep(step, nThreads, thread, context);

        // fetch_add returns the pre-increment value, so the last finisher
        // observes doneCount.
        NnUint currentCount = context->doneThreadCount.fetch_add(1);
        if (currentCount == doneCount) {
            if (context->timer != nullptr) {
                // Benchmark mode: attribute the wall time of this step to
                // its step type.
                NnUint time = context->timer->elapsedMicroseconds();
                context->totalTime[step->type] += time;
                context->timer->reset();
            }

            // Order matters: reset the barrier counter BEFORE releasing the
            // spinning threads by advancing the step index.
            context->doneThreadCount.store(0);
            context->currentStepIndex.fetch_add(1);
        } else {
            // Busy-wait until the last finisher advances the step index.
            while (context->currentStepIndex.load() == currentStepIndex);
        }
    }
    return nullptr;
}
|
||||||
|
|
||||||
|
// Runs one forward pass over all compiled steps. Resets the shared step
// counters, spawns nThreads-1 worker threads, participates as thread 0
// itself, and joins the workers before returning. Requires a positive
// batch size to have been set via NnNetExecution::setBatchSize().
void NnExecutor::forward() {
    assert(netExecution->batchSize > 0);

    NnUint nThreads = netExecution->nThreads;
    context.currentStepIndex.exchange(0);
    context.doneThreadCount.exchange(0);
    context.batchSize = netExecution->batchSize;

    if (context.timer != nullptr) {
        // Benchmark mode: clear the per-step-type accumulators for this pass.
        std::memset(context.totalTime, 0, sizeof(context.totalTime));
        context.timer->reset();
    }

    NnUint threadIndex;
    // Threads 1..n-1 are spawned; the calling thread acts as worker 0.
    for (threadIndex = 1; threadIndex < nThreads; threadIndex++) {
        int result = pthread_create(&threads[threadIndex].handler, NULL, (PthreadFunc)executorThreadHandler, (void *)&threads[threadIndex]);
        if (result != 0)
            throw std::runtime_error("Failed to create thread");
    }
    executorThreadHandler((void *)&threads[0]);
    for (threadIndex = 1; threadIndex < nThreads; threadIndex++)
        pthread_join(threads[threadIndex].handler, NULL);
}
|
||||||
|
|
||||||
|
// Returns the accumulated microseconds spent in steps of the given type
// during the last forward() call (populated only in benchmark mode).
NnUint NnExecutor::getTotalTime(NnExecutorStepType type) {
    const NnUint index = (NnUint)type;
    assert(index < N_STEP_TYPES);
    return context.totalTime[index];
}
|
||||||
103
src/nn/nn-executor.hpp
Normal file
103
src/nn/nn-executor.hpp
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
#ifndef NN_EXECUTOR_H
#define NN_EXECUTOR_H

#include "nn-core.hpp"
#include <atomic>
#include <vector>
#include "pthread.h"

// A group of ops compiled for a concrete device; executes and receives
// weights by op index.
class NnDeviceSegment {
public:
    virtual ~NnDeviceSegment() {};
    virtual void loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) = 0;
    virtual void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) = 0;
};

// A compute backend (CPU, GPU, ...) that compiles segments of the net.
class NnDevice {
public:
    virtual NnUint maxNThreads() = 0;
    virtual ~NnDevice() {}
    virtual NnDeviceSegment *createSegment(NnUint segmentIndex) = 0;
};

// Exchanges pipe data between nodes at segment sync points.
class NnNodeSynchronizer {
public:
    virtual ~NnNodeSynchronizer() {};
    virtual void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) = 0;
};

// No-op synchronizer for single-node execution.
class NnFakeNodeSynchronizer : public NnNodeSynchronizer {
public:
    ~NnFakeNodeSynchronizer() override {};
    void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override;
};

// Runtime state of a network: the pipe buffers and the active batch size.
class NnNetExecution {
public:
    NnUint nThreads;
    NnUint nPipes;
    NnByte **pipes;   // owned; one zeroed buffer per pipe config
    NnUint batchSize; // must be set via setBatchSize() before forward()
    NnUint nBatches;  // capacity the pipe buffers were sized for
    NnNetExecution(NnUint nThreads, NnNetConfig *netConfig);
    ~NnNetExecution();
    void setBatchSize(NnUint batchSize);
};

enum NnExecutorStepType {
    STEP_EXECUTE_OP,
    STEP_SYNC_NODES,
};

// Fix: parenthesized so the macro expands safely inside any expression
// (the original `STEP_SYNC_NODES + 1` could bind wrongly, e.g. after `*`).
#define N_STEP_TYPES (STEP_SYNC_NODES + 1)

// A device plus the inclusive segment range it handles (-1/-1 = all).
class NnExecutorDevice {
public:
    std::unique_ptr<NnDevice> device;
    int segmentFrom;
    int segmentTo;
    NnExecutorDevice(NnDevice *device, int segmentFrom, int segmentTo);
};

// One unit of work in the flattened execution plan.
typedef struct {
    NnExecutorStepType type;
    NnDeviceSegment *segment; // non-null for STEP_EXECUTE_OP
    NnUint arg0;              // op index (execute) or segment index (sync)
    NnOpConfig *opConfig;     // non-null for STEP_EXECUTE_OP
} NnExecutorStep;

// Shared state for the worker threads' lockstep execution of the plan.
typedef struct {
    NnUint nThreads;
    NnUint nSteps;
    NnExecutorStep *steps; // points into NnExecutor::steps
    NnNodeSynchronizer *synchronizer;
    std::atomic_uint currentStepIndex; // spin-barrier release variable
    std::atomic_uint doneThreadCount;  // spin-barrier arrival counter
    NnUint batchSize;
    Timer *timer;                   // non-null only in benchmark mode
    NnUint totalTime[N_STEP_TYPES]; // microseconds per step type
} NnExecutorContext;

// Per-thread handle: its index, the shared context, and the OS thread.
typedef struct {
    NnUint threadIndex;
    NnExecutorContext *context;
    PthreadHandler handler;
} NnExecutorThread;

// Drives a forward pass: owns the compiled segments, the step plan, and
// the worker threads.
class NnExecutor {
private:
    NnNetExecution *netExecution;
    NnNodeConfig *nodeConfig;
    std::vector<std::unique_ptr<NnDeviceSegment>> segments;
    std::vector<NnExecutorStep> steps;
    NnExecutorThread *threads;
    NnExecutorContext context;
public:
    NnExecutor(NnNetConfig *netConfig, NnNodeConfig *nodeConfig, std::vector<NnExecutorDevice> *device, NnNetExecution *netExecution, NnNodeSynchronizer *synchronizer, bool benchmark);
    ~NnExecutor();
    void loadWeight(const char *name, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight);
    void forward();
    NnUint getTotalTime(NnExecutorStepType type);
};

#endif
|
||||||
907
src/nn/nn-network.cpp
Normal file
907
src/nn/nn-network.cpp
Normal file
@@ -0,0 +1,907 @@
|
|||||||
|
#ifdef _WIN32
|
||||||
|
#include <winsock2.h>
|
||||||
|
#include <ws2tcpip.h> // For inet_addr and other functions
|
||||||
|
#include <windows.h> // For SSIZE_T
|
||||||
|
typedef SSIZE_T ssize_t;
|
||||||
|
#else
|
||||||
|
#include <sys/socket.h>
|
||||||
|
#include <netinet/tcp.h>
|
||||||
|
#include <arpa/inet.h>
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <netdb.h> // for getaddrinfo
|
||||||
|
#endif
|
||||||
|
#include "nn-network.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <vector>
|
||||||
|
#include <fcntl.h>
|
||||||
|
|
||||||
|
#define SOCKET_LAST_ERRCODE errno
|
||||||
|
#define SOCKET_LAST_ERROR strerror(errno)
|
||||||
|
|
||||||
|
#define ACK 23571114
|
||||||
|
#define MAX_CHUNK_SIZE 4096
|
||||||
|
|
||||||
|
// Returns true when the last socket call failed only because it would
// block (EAGAIN on POSIX, WSAEWOULDBLOCK on Windows), i.e. retry is safe.
static inline bool isEagainError() {
#ifdef _WIN32
    return WSAGetLastError() == WSAEWOULDBLOCK;
#else
    return errno == EAGAIN;
#endif
}
|
||||||
|
|
||||||
|
// Switches the socket between blocking and non-blocking mode.
// Throws std::runtime_error when the underlying syscall fails.
static inline void setNonBlocking(int socket, bool enabled) {
#ifdef _WIN32
    u_long mode = enabled ? 1 : 0;
    if (ioctlsocket(socket, FIONBIO, &mode) != 0)
        throw std::runtime_error("Error setting socket to non-blocking");
#else
    int flags = fcntl(socket, F_GETFL, 0);
    flags = enabled ? (flags | O_NONBLOCK) : (flags & ~O_NONBLOCK);
    if (fcntl(socket, F_SETFL, flags) < 0)
        throw std::runtime_error("Error setting socket to non-blocking");
#endif
}
|
||||||
|
|
||||||
|
static inline void setNoDelay(int socket) {
|
||||||
|
int flag = 1;
|
||||||
|
if (setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int)) < 0)
|
||||||
|
throw std::runtime_error("Error setting socket to no-delay");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void setQuickAck(int socket) {
|
||||||
|
#ifndef _WIN32
|
||||||
|
#ifdef TCP_QUICKACK
|
||||||
|
int value = 1;
|
||||||
|
if (setsockopt(socket, IPPROTO_TCP, TCP_QUICKACK, (char*)&value, sizeof(int)) < 0)
|
||||||
|
throw std::runtime_error("Error setting quick ack");
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
void setReuseAddr(int socket) {
|
||||||
|
int opt = 1;
|
||||||
|
#ifdef _WIN32
|
||||||
|
int iresult = setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, (char*)&opt, sizeof(opt));
|
||||||
|
if (iresult == SOCKET_ERROR) {
|
||||||
|
closesocket(socket);
|
||||||
|
throw std::runtime_error("setsockopt failed: " + std::to_string(WSAGetLastError()));
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
if (setsockopt(socket, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
|
||||||
|
close(socket);
|
||||||
|
throw std::runtime_error("setsockopt failed: " + std::string(strerror(errno)));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Writes exactly `size` bytes to the socket. On EAGAIN (non-blocking
// "turbo" mode) it busy-retries until the kernel accepts the data, so a
// message is never partially sent from the caller's point of view.
// Throws NnWriteNetworkException on a hard error or peer close.
void writeSocket(int socket, const void *data, NnSize size) {
    while (size > 0) {
        ssize_t s = send(socket, (const char*)data, size, 0);
        if (s < 0) {
            if (isEagainError()) {
                continue; // non-blocking socket: spin until writable
            }
            throw NnWriteNetworkException(0, "Error writing to socket");
        } else if (s == 0) {
            throw NnWriteNetworkException(0, "Socket closed");
        }
        // Partial write: advance past the bytes the kernel accepted.
        size -= s;
        data = (const char*)data + s;
    }
}
|
||||||
|
|
||||||
|
// Reads exactly `size` bytes unless the socket would block before the
// FIRST byte arrives: only those initial EAGAINs are counted against
// `maxAttempts` (0 = retry forever). Once any data has been received, it
// spins until the remainder arrives so a message is never left half-read.
// Returns false only when the attempt limit is hit with nothing read;
// throws NnReadNetworkException on a hard error or peer close.
static inline bool tryReadSocket(int socket, void *data, NnSize size, unsigned long maxAttempts) {
    // maxAttempts = 0 means infinite attempts
    NnSize s = size;
    while (s > 0) {
        ssize_t r = recv(socket, (char*)data, s, 0);
        if (r < 0) {
            if (isEagainError()) {
                // `s == size` means no byte of this message arrived yet --
                // only then may we give up.
                if (s == size && maxAttempts > 0) {
                    maxAttempts--;
                    if (maxAttempts == 0) {
                        return false;
                    }
                }
                continue;
            }
            throw NnReadNetworkException(0, "Error reading from socket");
        } else if (r == 0) {
            throw NnReadNetworkException(0, "Socket closed");
        }
        // Partial read: advance past the received bytes.
        data = (char*)data + r;
        s -= r;
    }
    return true;
}
|
||||||
|
|
||||||
|
void readSocket(int socket, void *data, NnSize size) {
|
||||||
|
if (!tryReadSocket(socket, data, size, 0)) {
|
||||||
|
throw std::runtime_error("Error reading from socket");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void readAckPacket(int socket) {
|
||||||
|
NnUint packet;
|
||||||
|
readSocket(socket, &packet, sizeof(packet));
|
||||||
|
if (packet != ACK)
|
||||||
|
throw std::runtime_error("Invalid ack packet");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void writeAckPacket(int socket) {
|
||||||
|
NnUint packet = ACK;
|
||||||
|
writeSocket(socket, &packet, sizeof(packet));
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline int connectSocket(char *host, int port) {
|
||||||
|
struct addrinfo hints;
|
||||||
|
struct addrinfo *addr = NULL;
|
||||||
|
std::memset(&hints, 0, sizeof(hints));
|
||||||
|
hints.ai_family = AF_INET;
|
||||||
|
hints.ai_socktype = SOCK_STREAM;
|
||||||
|
hints.ai_protocol = IPPROTO_TCP;
|
||||||
|
|
||||||
|
char portStr[11];
|
||||||
|
snprintf(portStr, sizeof(portStr), "%d", port);
|
||||||
|
|
||||||
|
int addrinfoError = getaddrinfo(host, portStr, &hints, &addr);
|
||||||
|
if (addrinfoError != 0 || addr == NULL) {
|
||||||
|
printf("Cannot resolve target %s (%s)\n", host, gai_strerror(addrinfoError));
|
||||||
|
throw std::runtime_error("Cannot resolve address");
|
||||||
|
}
|
||||||
|
|
||||||
|
int sock = ::socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
|
||||||
|
if (sock < 0)
|
||||||
|
throw std::runtime_error("Cannot create socket");
|
||||||
|
|
||||||
|
int connectResult = ::connect(sock, addr->ai_addr, addr->ai_addrlen);
|
||||||
|
if (connectResult != 0) {
|
||||||
|
printf("Cannot connect to %s:%d (%s)\n", host, port, SOCKET_LAST_ERROR);
|
||||||
|
throw std::runtime_error("Cannot connect");
|
||||||
|
}
|
||||||
|
|
||||||
|
setNoDelay(sock);
|
||||||
|
setQuickAck(sock);
|
||||||
|
return sock;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a TCP server socket bound to 0.0.0.0:`port` and starts
// listening. SO_REUSEADDR is set so restarts can rebind a TIME_WAIT port;
// TCP_NODELAY/QUICKACK are applied to the listening socket.
// Throws std::runtime_error on any failure (closing the socket first).
int createServerSocket(int port) {
    const char *host = "0.0.0.0";
    struct sockaddr_in serverAddr;

    int serverSocket = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
    if (serverSocket < 0)
        throw std::runtime_error("Cannot create socket");
    setReuseAddr(serverSocket);

    memset(&serverAddr, 0, sizeof(serverAddr));
    serverAddr.sin_family = AF_INET;
    serverAddr.sin_port = htons(port);
    serverAddr.sin_addr.s_addr = inet_addr(host);

    // bind() error reporting differs per platform (WSAGetLastError vs errno).
    int bindResult;
#ifdef _WIN32
    bindResult = bind(serverSocket, (SOCKADDR*)&serverAddr, sizeof(serverAddr));
    if (bindResult == SOCKET_ERROR) {
        int error = WSAGetLastError();
        closesocket(serverSocket);
        throw std::runtime_error("Cannot bind port: " + std::to_string(error));
    }
#else
    bindResult = bind(serverSocket, (struct sockaddr*)&serverAddr, sizeof(serverAddr));
    if (bindResult < 0) {
        close(serverSocket);
        throw std::runtime_error("Cannot bind port: " + std::string(strerror(errno)));
    }
#endif

    int listenResult = listen(serverSocket, SOMAXCONN);
    if (listenResult != 0) {
#ifdef _WIN32
        closesocket(serverSocket);
        throw std::runtime_error("Cannot listen on port: " + std::to_string(WSAGetLastError()));
#else
        close(serverSocket);
        throw std::runtime_error("Cannot listen on port: " + std::string(strerror(errno)));
#endif
    }

    printf("Listening on %s:%d...\n", host, port);

    setNoDelay(serverSocket);
    setQuickAck(serverSocket);
    return serverSocket;
}
|
||||||
|
|
||||||
|
// Shuts down both directions of the connection and closes the descriptor.
void destroySocket(int serverSocket) {
    shutdown(serverSocket, 2); // 2 == SHUT_RDWR (POSIX) / SD_BOTH (Windows)
#ifdef _WIN32
    closesocket(serverSocket);
#else
    close(serverSocket);
#endif
}
|
||||||
|
|
||||||
|
// Accepts one pending connection and applies the low-latency socket
// options; throws std::runtime_error when accept() fails.
int acceptSocket(int serverSocket) {
    struct sockaddr_in peerAddr;
    socklen_t peerAddrLen = sizeof(peerAddr);
    int clientSocket = ::accept(serverSocket, (struct sockaddr*)&peerAddr, &peerAddrLen);
    if (clientSocket < 0)
        throw std::runtime_error("Error accepting connection");
    setNoDelay(clientSocket);
    setQuickAck(clientSocket);
    return clientSocket;
}
|
||||||
|
|
||||||
|
// Initializes the platform socket layer (WinSock on Windows; a no-op on
// POSIX systems).
void initSockets() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0)
        throw std::runtime_error("WSAStartup failed: " + std::to_string(WSAGetLastError()));
#endif
}

// Tears down the platform socket layer (WinSock only).
void cleanupSockets() {
#ifdef _WIN32
    WSACleanup();
#endif
}
|
||||||
|
|
||||||
|
NnReadNetworkException::NnReadNetworkException(int code, const char *message) {
|
||||||
|
this->code = code;
|
||||||
|
this->message = message;
|
||||||
|
}
|
||||||
|
|
||||||
|
NnWriteNetworkException::NnWriteNetworkException(int code, const char *message) {
|
||||||
|
this->code = code;
|
||||||
|
this->message = message;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Worker-side network bootstrap. Listens on `port`, accepts the root
// node, receives the cluster layout (socket count, this node's index, and
// the host:port of every worker), then establishes worker-to-worker
// links: this node CONNECTS to peers with index >= its own and ACCEPTS
// connections from peers with a lower index, so exactly one link exists
// per pair. sockets[0] is always the root node.
std::unique_ptr<NnNetwork> NnNetwork::serve(int port) {
    int serverSocket = createServerSocket(port);

    NnUint nSockets;
    NnUint nodeIndex;
    int rootSocket = acceptSocket(serverSocket);
    printf("⭕ The root node has connected\n");

    // Cluster layout, sent by the root.
    readSocket(rootSocket, &nSockets, sizeof(nSockets));
    NnUint nNodes = nSockets - 1; // nSockets - 1 root node
    printf("⭕ nNodes: %d\n", nNodes);
    readSocket(rootSocket, &nodeIndex, sizeof(nodeIndex));
    printf("⭕ NodeIndex: %d\n", nodeIndex);

    int *sockets = new int[nSockets];
    sockets[0] = rootSocket;
    char* *hosts = new char*[nNodes];
    int *ports = new int[nNodes];
    printf("⭕ Socket[0]: accepted root node\n");

    // Receive host/port of every worker (hostLen includes the NUL byte).
    NnUint hostLen;
    for (NnUint i = 0; i < nNodes; i++) {
        readSocket(rootSocket, &hostLen, sizeof(hostLen));
        hosts[i] = new char[hostLen];
        readSocket(rootSocket, hosts[i], hostLen);
        readSocket(rootSocket, &ports[i], sizeof(ports[i]));
    }

    // Confirm receipt of the layout to the root.
    writeAckPacket(rootSocket);

    // Wait until the root signals that ALL workers have the layout; only
    // then is it safe to start the peer-to-peer connection phase.
    readAckPacket(rootSocket);

    for (NnUint i = 0; i < nNodes; i++) {
        NnUint socketIndex = i + 1;
        if (i >= nodeIndex) {
            // Higher-or-equal index: we initiate the connection.
            printf("⭕ Socket[%d]: connecting to %s:%d worker\n", socketIndex, hosts[i], ports[i]);
            sockets[socketIndex] = connectSocket(hosts[i], ports[i]);
            printf("⭕ Socket[%d]: connected\n", socketIndex);
        } else {
            // Lower index: that peer connects to us.
            printf("⭕ Socket[%d]: wait for %s:%d worker\n", socketIndex, hosts[i], ports[i]);
            sockets[socketIndex] = acceptSocket(serverSocket);
            printf("⭕ Socket[%d]: accepted\n", socketIndex);
        }
    }

    for (NnUint i = 0; i < nNodes; i++)
        delete[] hosts[i];
    delete[] hosts;
    delete[] ports;

    // The listening socket is no longer needed once all links exist.
    destroySocket(serverSocket);
    printf("⭕ Network is initialized\n");
    return std::unique_ptr<NnNetwork>(new NnNetwork(nSockets, sockets));
}
|
||||||
|
|
||||||
|
// Root-side network bootstrap. Connects to every worker in order and
// sends each one the cluster layout: total socket count, the worker's own
// index, and the host:port of every OTHER worker. After every worker has
// acknowledged the layout, a second ACK round signals "root is ready",
// which releases the workers into their peer-to-peer connection phase
// (see NnNetwork::serve).
std::unique_ptr<NnNetwork> NnNetwork::connect(NnUint nSockets, char **hosts, NnUint *ports) {
    assert(nSockets > 0);

    int *sockets = new int[nSockets];
    struct sockaddr_in addr;
    for (NnUint i = 0; i < nSockets; i++) {
        printf("⭕ Socket[%d]: connecting to %s:%d worker\n", i, hosts[i], ports[i]);
        int socket = connectSocket(hosts[i], ports[i]);
        sockets[i] = socket;
        writeSocket(socket, &nSockets, sizeof(nSockets));
        writeSocket(socket, &i, sizeof(i)); // send node index
        // Send every peer's address except the worker's own.
        for (NnUint j = 0; j < nSockets; j++) {
            if (j == i)
                continue;
            NnUint hostLen = strlen(hosts[j]) + 1; // include the NUL byte
            writeSocket(socket, &hostLen, sizeof(hostLen));
            writeSocket(socket, hosts[j], hostLen);
            writeSocket(socket, &ports[j], sizeof(ports[j]));
        }
        // Worker confirms it received the layout.
        readAckPacket(socket);
        printf("⭕ Socket[%d]: connected\n", i);
    }
    // Second round: tell every worker the whole cluster is ready.
    for (NnUint i = 0; i < nSockets; i++) {
        writeAckPacket(sockets[i]);
    }
    printf("⭕ Network is initialized\n");
    return std::unique_ptr<NnNetwork>(new NnNetwork(nSockets, sockets));
}
|
||||||
|
|
||||||
|
// Takes ownership of the connected socket array and allocates the
// per-socket traffic counters.
NnNetwork::NnNetwork(NnUint nSockets, int *sockets) {
    this->nSockets = nSockets;
    this->sockets = sockets;
    // Fix: value-initialize (zero) the counters. The original used plain
    // `new NnSize[nSockets]`, leaving them indeterminate, so the later
    // `sentBytes[i] += size` / `recvBytes[i] += size` in write()/read()
    // read uninitialized memory (undefined behavior, garbage statistics).
    this->sentBytes = new NnSize[nSockets]();
    this->recvBytes = new NnSize[nSockets]();
}
|
||||||
|
|
||||||
|
// Closes every peer socket and frees the owned arrays.
NnNetwork::~NnNetwork() {
    delete[] sentBytes;
    delete[] recvBytes;
    for (NnUint i = 0u; i < nSockets; i++)
        destroySocket(sockets[i]);
    delete[] sockets;
    printf("⭕ Network is closed\n");
}
|
||||||
|
|
||||||
|
void NnNetwork::setTurbo(bool enabled) {
|
||||||
|
for (NnUint i = 0; i < nSockets; i++) {
|
||||||
|
::setNonBlocking(sockets[i], enabled);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void NnNetwork::write(const NnUint socketIndex, const void *data, const NnSize size) {
|
||||||
|
assert(socketIndex < nSockets);
|
||||||
|
|
||||||
|
NnByte *current = (NnByte *)data;
|
||||||
|
int s = sockets[socketIndex];
|
||||||
|
for (NnSize chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
|
||||||
|
NnSize chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
|
||||||
|
writeSocket(s, current, chunkSize);
|
||||||
|
current += chunkSize;
|
||||||
|
}
|
||||||
|
sentBytes[socketIndex] += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
void NnNetwork::read(const NnUint socketIndex, void *data, const NnSize size) {
|
||||||
|
assert(socketIndex < nSockets);
|
||||||
|
|
||||||
|
NnByte *current = (NnByte *)data;
|
||||||
|
int s = sockets[socketIndex];
|
||||||
|
for (NnSize chunk = 0; chunk < size; chunk += MAX_CHUNK_SIZE) {
|
||||||
|
NnSize chunkSize = chunk + MAX_CHUNK_SIZE < size ? MAX_CHUNK_SIZE : size - chunk;
|
||||||
|
readSocket(s, current, chunkSize);
|
||||||
|
current += chunkSize;
|
||||||
|
}
|
||||||
|
recvBytes[socketIndex] += size;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sends an ack packet to one peer.
void NnNetwork::writeAck(const NnUint socketIndex) {
    // Fix: NnUint is unsigned, so the old `socketIndex >= 0 &&` clause was
    // always true (tautological comparison); only the upper bound matters.
    assert(socketIndex < nSockets);
    writeAckPacket(sockets[socketIndex]);
}
|
||||||
|
|
||||||
|
// Waits for an ack packet from one peer.
void NnNetwork::readAck(const NnUint socketIndex) {
    // Fix: NnUint is unsigned, so the old `socketIndex >= 0 &&` clause was
    // always true (tautological comparison); only the upper bound matters.
    assert(socketIndex < nSockets);
    readAckPacket(sockets[socketIndex]);
}
|
||||||
|
|
||||||
|
// Non-blocking-style read: attempts up to `maxAttempts` times; returns true
// (and updates the recv counter) only when all `size` bytes arrived.
bool NnNetwork::tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts) {
    // Fix: NnUint is unsigned, so the old `socketIndex >= 0 &&` clause was
    // always true (tautological comparison); only the upper bound matters.
    assert(socketIndex < nSockets);
    if (tryReadSocket(sockets[socketIndex], data, size, maxAttempts)) {
        recvBytes[socketIndex] += size;
        return true;
    }
    return false;
}
|
||||||
|
|
||||||
|
// Interleaved write of `n` independent transfers. Loops over all transfers,
// sending at most one MAX_CHUNK_SIZE chunk per transfer per pass, skipping
// sockets that report EAGAIN (non-blocking mode), until every transfer is
// drained. Throws NnWriteNetworkException on a socket error or peer close.
void NnNetwork::writeMany(NnUint n, NnSocketIo *ios) {
    bool isWriting;
    // Fix: removed unused local `NnSize nBytes = 0;`.
    // Account the full payloads up front; `io->size` is consumed below.
    for (NnUint i = 0; i < n; i++) {
        NnSocketIo *io = &ios[i];
        assert(io->socketIndex < nSockets);
        sentBytes[io->socketIndex] += io->size;
    }
    do {
        isWriting = false;
        for (NnUint i = 0; i < n; i++) {
            NnSocketIo *io = &ios[i];
            if (io->size > 0) {
                isWriting = true;
                int socket = sockets[io->socketIndex];
                ssize_t chunkSize = io->size > MAX_CHUNK_SIZE ? MAX_CHUNK_SIZE : io->size;
                ssize_t s = send(socket, (const char*)io->data, chunkSize, 0);
                if (s < 0) {
                    if (isEagainError()) {
                        continue; // socket not ready; retry on a later pass
                    }
                    throw NnWriteNetworkException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR);
                } else if (s == 0) {
                    throw NnWriteNetworkException(0, "Socket closed");
                }
                // Advance this transfer's cursor by the bytes actually sent.
                io->size -= s;
                io->data = (char*)io->data + s;
            }
        }
    } while (isWriting);
}
|
||||||
|
|
||||||
|
void NnNetwork::writeAll(void *data, NnSize size) {
|
||||||
|
std::vector<NnSocketIo> ios(nSockets);
|
||||||
|
for (NnUint i = 0; i < nSockets; i++) {
|
||||||
|
NnSocketIo *io = &ios[i];
|
||||||
|
io->socketIndex = i;
|
||||||
|
io->data = data;
|
||||||
|
io->size = size;
|
||||||
|
}
|
||||||
|
writeMany(nSockets, &ios[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interleaved read of `n` independent transfers: counterpart of writeMany.
// Loops over all transfers, receiving whatever is available on each socket,
// skipping EAGAIN (non-blocking mode), until every transfer is complete.
// Throws NnReadNetworkException on a socket error or peer close.
void NnNetwork::readMany(NnUint n, NnSocketIo *ios) {
    bool isReading;
    // Fix: removed unused local `NnSize nBytes = 0;`.
    // Account the full payloads up front; `io->size` is consumed below.
    for (NnUint i = 0; i < n; i++) {
        NnSocketIo *io = &ios[i];
        assert(io->socketIndex < nSockets);
        recvBytes[io->socketIndex] += io->size;
    }
    do {
        isReading = false;
        for (NnUint i = 0; i < n; i++) {
            NnSocketIo *io = &ios[i];
            if (io->size > 0) {
                isReading = true;
                int socket = sockets[io->socketIndex];
                ssize_t r = recv(socket, (char*)io->data, io->size, 0);
                if (r < 0) {
                    if (isEagainError()) {
                        continue; // nothing buffered yet; retry on a later pass
                    }
                    throw NnReadNetworkException(SOCKET_LAST_ERRCODE, SOCKET_LAST_ERROR);
                } else if (r == 0) {
                    throw NnReadNetworkException(0, "Socket closed");
                }
                // Advance this transfer's cursor by the bytes actually received.
                io->size -= r;
                io->data = (char*)io->data + r;
            }
        }
    } while (isReading);
}
|
||||||
|
|
||||||
|
// Returns the total bytes sent/received across all sockets since the last
// call, then resets the counters (i.e. the stats are interval deltas).
void NnNetwork::getStats(NnSize *sentBytes, NnSize *recvBytes) {
    NnSize totalSent = 0;
    NnSize totalRecv = 0;
    for (NnUint socketIndex = 0; socketIndex < nSockets; socketIndex++) {
        totalSent += this->sentBytes[socketIndex];
        totalRecv += this->recvBytes[socketIndex];
    }
    *sentBytes = totalSent;
    *recvBytes = totalRecv;
    resetStats();
}
|
||||||
|
|
||||||
|
void NnNetwork::resetStats() {
|
||||||
|
for (NnUint i = 0; i < nSockets; i++) {
|
||||||
|
sentBytes[i] = 0;
|
||||||
|
recvBytes[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Broadcast of one pipe batch from the root node to all workers.
// On the root, the same buffer is written to every worker socket; sockets are
// split across executor threads (thread t handles sockets t, t+nThreads, ...).
// On a worker, only thread 0 participates and it reads from socket 0, which
// is the root's socket on the worker side.
static void syncWithRoot(NnNetwork *network, NnByte nodeIndex, NnByte *buffer, NnSize nBytes, NnUint nThreads, NnUint threadIndex) {
    if (nodeIndex == 0) {
        // root
        // Ceil-style split: the first (nSockets % nThreads) threads get one extra socket.
        NnUint nSocketsPerThread = network->nSockets / nThreads + (network->nSockets % nThreads > threadIndex ? 1 : 0);
        if (nSocketsPerThread == 0) return;

        std::vector<NnSocketIo> ios(nSocketsPerThread);
        for (NnUint i = 0; i < nSocketsPerThread; i++) {
            ios[i].socketIndex = threadIndex + i * nThreads;
            ios[i].data = buffer;   // every worker receives the same bytes
            ios[i].size = nBytes;
        }
        network->writeMany(nSocketsPerThread, &ios[0]);
    } else {
        // worker
        if (threadIndex != 0) return; // a single socket: only thread 0 reads

        NnSocketIo ios;
        ios.data = buffer;
        ios.size = nBytes;
        ios.socketIndex = 0; // root
        network->readMany(1, &ios);
    }
}
|
||||||
|
|
||||||
|
// All-gather of per-node slices of `buffer`: each node writes its own
// 1/nNodes slice to its peers and reads the peers' slices into their slots.
// With onlyFromWorkerToRoot=true, workers send only to the root (their single
// socket 0) and only the root receives. Sockets are split across executor
// threads the same way as in syncWithRoot.
static void syncNodeSlices(bool onlyFromWorkerToRoot, NnNetwork *network, NnUint nodeIndex, NnUint nNodes, NnByte *buffer, NnSize nBytes, NnUint nThreads, NnUint threadIndex) {
    bool isWorker = nodeIndex != 0;
    // A worker in worker->root mode touches only its root socket.
    NnUint nSockets = onlyFromWorkerToRoot && isWorker ? 1 : network->nSockets;
    NnUint nSocketsPerThread = nSockets / nThreads + (nSockets % nThreads > threadIndex ? 1 : 0);
    if (nSocketsPerThread == 0) return;
    NnSize sliceBytes = nBytes / nNodes; // assumes nBytes divides evenly by nNodes — no remainder handling here

    std::vector<NnSocketIo> ios(nSocketsPerThread);

    if (!onlyFromWorkerToRoot || isWorker) {
        // Send this node's own slice to every peer handled by this thread.
        NnByte *mySliceData = &buffer[sliceBytes * nodeIndex];

        for (NnUint i = 0; i < nSocketsPerThread; i++) {
            NnUint socketIndex = threadIndex + i * nThreads;
            ios[i].socketIndex = socketIndex;
            ios[i].data = mySliceData;
            ios[i].size = sliceBytes;
        }
        network->writeMany(nSocketsPerThread, &ios[0]);
    }

    if (!onlyFromWorkerToRoot || !isWorker) {
        // Receive each peer's slice into its position. A node has no socket
        // to itself, so socket k maps to node k for k < nodeIndex and to
        // node k+1 otherwise.
        for (NnUint i = 0; i < nSocketsPerThread; i++) {
            NnUint socketIndex = threadIndex + i * nThreads;
            NnUint sliceIndex = socketIndex >= nodeIndex ? socketIndex + 1 : socketIndex;
            NnByte *sliceData = &buffer[sliceBytes * sliceIndex];
            ios[i].socketIndex = socketIndex;
            ios[i].data = sliceData;
            ios[i].size = sliceBytes;
        }
        network->readMany(nSocketsPerThread, &ios[0]);
    }
}
|
||||||
|
|
||||||
|
// Stores non-owning references to the network and the execution/config state
// it synchronizes; all pointers must outlive this object.
NnNetworkNodeSynchronizer::NnNetworkNodeSynchronizer(NnNetwork *network, NnNetExecution *execution, NnNetConfig *netConfig, NnNodeConfig *nodeConfig)
    : network(network),
      execution(execution),
      netConfig(netConfig),
      nodeConfig(nodeConfig) {}
|
||||||
|
|
||||||
|
// Executes every sync declared for the given segment, one pipe batch at a
// time. All executor threads call this concurrently; the per-thread socket
// split happens inside the sync helpers.
void NnNetworkNodeSynchronizer::sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) {
    NnSegmentConfig *segmentConfig = &nodeConfig->segments[segmentIndex];

    for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) {
        NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex];
        NnByte *pipe = execution->pipes[syncConfig->pipeIndex];
        NnPipeConfig *pipeConfig = &netConfig->pipes[syncConfig->pipeIndex];
        // Size in bytes of one batch row of this pipe.
        NnSize batchBytes = getBytes(pipeConfig->size.floatType, pipeConfig->size.x);

        for (NnUint batchIndex = 0; batchIndex < execution->batchSize; batchIndex++) {
            NnByte *pipeBatch = &pipe[batchIndex * batchBytes];

            if (syncConfig->syncType == SYNC_WITH_ROOT) {
                // root -> all workers broadcast (or worker receive)
                syncWithRoot(network, nodeConfig->nodeIndex, pipeBatch, batchBytes, nThreads, threadIndex);
            } else if (syncConfig->syncType == SYNC_NODE_SLICES) {
                // full all-gather of per-node slices
                syncNodeSlices(false, network, nodeConfig->nodeIndex, netConfig->nNodes, pipeBatch, batchBytes, nThreads, threadIndex);
            } else if (syncConfig->syncType == SYNC_NODE_SLICES_EXCEPT_ROOT) {
                // workers push their slices to the root only
                syncNodeSlices(true, network, nodeConfig->nodeIndex, netConfig->nNodes, pipeBatch, batchBytes, nThreads, threadIndex);
            } else {
                throw std::invalid_argument("Unknown sync type");
            }
        }
    }
}
|
||||||
|
|
||||||
|
// Sends a length-prefixed C string (the length includes the NUL terminator).
static void writeString(NnNetwork *network, NnUint socketIndex, char *str) {
    NnUint length = (NnUint)(std::strlen(str) + 1);
    network->write(socketIndex, &length, sizeof(length));
    network->write(socketIndex, str, length);
}
|
||||||
|
|
||||||
|
// Receives a length-prefixed C string written by writeString.
// The caller owns the returned buffer (allocated with new[]).
static char *readString(NnNetwork *network, NnUint socketIndex) {
    NnUint length;
    network->read(socketIndex, &length, sizeof(length));
    char *buffer = new char[length];
    network->read(socketIndex, buffer, length);
    return buffer;
}
|
||||||
|
|
||||||
|
// Non-owning: `network` must outlive this writer.
NnRootConfigWriter::NnRootConfigWriter(NnNetwork *network)
    : network(network) {}
|
||||||
|
|
||||||
|
// Streams the network-wide configuration to one worker. The field order must
// mirror NnWorkerConfigReader::readNet exactly. The exchange is framed by an
// ack in each direction.
void NnRootConfigWriter::writeNet(NnUint socketIndex, NnNetConfig *config) {
    network->writeAck(socketIndex);
    network->write(socketIndex, &config->nBatches, sizeof(config->nBatches));
    network->write(socketIndex, &config->nNodes, sizeof(config->nNodes));
    network->write(socketIndex, &config->nPipes, sizeof(config->nPipes));
    // Per pipe: raw size struct, then length-prefixed name.
    for (NnUint pipeIndex = 0; pipeIndex < config->nPipes; pipeIndex++) {
        NnPipeConfig *pipeConfig = &config->pipes[pipeIndex];
        network->write(socketIndex, &pipeConfig->size, sizeof(pipeConfig->size));
        writeString(network, socketIndex, pipeConfig->name);
    }
    network->write(socketIndex, &config->nPreSyncs, sizeof(config->nPreSyncs));
    for (NnUint preSyncIndex = 0; preSyncIndex < config->nPreSyncs; preSyncIndex++) {
        NnPreSyncConfig *preSyncConfig = &config->preSyncs[preSyncIndex];
        network->write(socketIndex, &preSyncConfig->pipeIndex, sizeof(preSyncConfig->pipeIndex));
    }
    network->readAck(socketIndex); // worker confirms it consumed the config
}
|
||||||
|
|
||||||
|
// Streams one node's configuration (buffers, segments, syncs, ops) to that
// worker. The field order must mirror NnWorkerConfigReader::readNode exactly.
// Framed by an ack in each direction.
void NnRootConfigWriter::writeNode(NnUint socketIndex, NnNodeConfig *config) {
    network->writeAck(socketIndex);
    network->write(socketIndex, &config->nodeIndex, sizeof(config->nodeIndex));
    network->write(socketIndex, &config->nBuffers, sizeof(config->nBuffers));
    network->write(socketIndex, &config->nSegments, sizeof(config->nSegments));

    // Per buffer: raw size struct, then length-prefixed name.
    for (NnUint bufferIndex = 0; bufferIndex < config->nBuffers; bufferIndex++) {
        NnBufferConfig *bufferConfig = &config->buffers[bufferIndex];
        network->write(socketIndex, &bufferConfig->size, sizeof(bufferConfig->size));
        writeString(network, socketIndex, bufferConfig->name);
    }

    for (NnUint segmentIndex = 0; segmentIndex < config->nSegments; segmentIndex++) {
        NnSegmentConfig *segmentConfig = &config->segments[segmentIndex];
        network->write(socketIndex, &segmentConfig->nSyncs, sizeof(segmentConfig->nSyncs));
        network->write(socketIndex, &segmentConfig->nOps, sizeof(segmentConfig->nOps));

        for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) {
            NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex];
            network->write(socketIndex, &syncConfig->pipeIndex, sizeof(syncConfig->pipeIndex));
            network->write(socketIndex, &syncConfig->syncType, sizeof(syncConfig->syncType));
        }
        for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
            NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
            network->write(socketIndex, &opConfig->code, sizeof(opConfig->code));
            network->write(socketIndex, &opConfig->index, sizeof(opConfig->index));
            network->write(socketIndex, &opConfig->weightSize, sizeof(opConfig->weightSize));
            network->write(socketIndex, &opConfig->configSize, sizeof(opConfig->configSize));
            writeString(network, socketIndex, opConfig->name);
            network->write(socketIndex, &opConfig->input, sizeof(opConfig->input));
            network->write(socketIndex, &opConfig->output, sizeof(opConfig->output));
            // Op-specific config blob follows only when present.
            if (opConfig->configSize > 0)
                network->write(socketIndex, opConfig->config, opConfig->configSize);
        }
    }
    network->readAck(socketIndex); // worker confirms it consumed the config
}
|
||||||
|
|
||||||
|
// Sends the shared net config plus each worker's node config to every worker.
// Node 0 is the root itself, so worker node n uses socket n-1.
void NnRootConfigWriter::writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs) {
    for (NnUint nodeIndex = 1; nodeIndex < netConfig->nNodes; nodeIndex++) {
        const NnUint socketIndex = nodeIndex - 1;
        writeNet(socketIndex, netConfig);
        writeNode(socketIndex, &nodeConfigs[nodeIndex]);
    }
}
|
||||||
|
|
||||||
|
// Non-owning: `network` must outlive this reader.
NnWorkerConfigReader::NnWorkerConfigReader(NnNetwork *network)
    : network(network) {}
|
||||||
|
|
||||||
|
// Receives the network-wide configuration from the root (socket 0).
// The field order must mirror NnRootConfigWriter::writeNet exactly.
// Allocates config.pipes / config.preSyncs with new[]; ownership passes to
// the caller via the returned struct.
NnNetConfig NnWorkerConfigReader::readNet() {
    network->readAck(ROOT_SOCKET_INDEX);
    NnNetConfig config;
    network->read(ROOT_SOCKET_INDEX, &config.nBatches, sizeof(config.nBatches));
    network->read(ROOT_SOCKET_INDEX, &config.nNodes, sizeof(config.nNodes));
    network->read(ROOT_SOCKET_INDEX, &config.nPipes, sizeof(config.nPipes));
    config.pipes = new NnPipeConfig[config.nPipes];
    for (NnUint pipeIndex = 0; pipeIndex < config.nPipes; pipeIndex++) {
        NnPipeConfig *pipeConfig = &config.pipes[pipeIndex];
        network->read(ROOT_SOCKET_INDEX, &pipeConfig->size, sizeof(pipeConfig->size));
        pipeConfig->name = readString(network, ROOT_SOCKET_INDEX); // heap-allocated
    }
    network->read(ROOT_SOCKET_INDEX, &config.nPreSyncs, sizeof(config.nPreSyncs));
    config.preSyncs = new NnPreSyncConfig[config.nPreSyncs];
    for (NnUint preSyncIndex = 0; preSyncIndex < config.nPreSyncs; preSyncIndex++) {
        NnPreSyncConfig *preSyncConfig = &config.preSyncs[preSyncIndex];
        network->read(ROOT_SOCKET_INDEX, &preSyncConfig->pipeIndex, sizeof(preSyncConfig->pipeIndex));
    }
    network->writeAck(ROOT_SOCKET_INDEX); // confirm the config was consumed
    return config;
}
|
||||||
|
|
||||||
|
// Receives this worker's node configuration from the root (socket 0).
// The field order must mirror NnRootConfigWriter::writeNode exactly.
// All nested arrays/strings are heap-allocated with new[]; ownership passes
// to the caller via the returned struct.
// NOTE(review): when nSyncs/nOps/configSize is 0, the corresponding pointer
// is left unassigned here — presumably consumers check the count first;
// verify against NnSegmentConfig/NnOpConfig usage.
NnNodeConfig NnWorkerConfigReader::readNode() {
    network->readAck(ROOT_SOCKET_INDEX);

    NnNodeConfig config;
    network->read(ROOT_SOCKET_INDEX, &config.nodeIndex, sizeof(config.nodeIndex));
    network->read(ROOT_SOCKET_INDEX, &config.nBuffers, sizeof(config.nBuffers));
    network->read(ROOT_SOCKET_INDEX, &config.nSegments, sizeof(config.nSegments));

    config.buffers = new NnBufferConfig[config.nBuffers];
    config.segments = new NnSegmentConfig[config.nSegments];

    for (NnUint bufferIndex = 0; bufferIndex < config.nBuffers; bufferIndex++) {
        NnBufferConfig *bufferConfig = &config.buffers[bufferIndex];
        network->read(ROOT_SOCKET_INDEX, &bufferConfig->size, sizeof(bufferConfig->size));
        bufferConfig->name = readString(network, ROOT_SOCKET_INDEX);
    }

    for (NnUint segmentIndex = 0; segmentIndex < config.nSegments; segmentIndex++) {
        NnSegmentConfig *segmentConfig = &config.segments[segmentIndex];
        network->read(ROOT_SOCKET_INDEX, &segmentConfig->nSyncs, sizeof(segmentConfig->nSyncs));
        network->read(ROOT_SOCKET_INDEX, &segmentConfig->nOps, sizeof(segmentConfig->nOps));

        if (segmentConfig->nSyncs > 0) {
            segmentConfig->syncs = new NnSyncConfig[segmentConfig->nSyncs];

            for (NnUint syncIndex = 0; syncIndex < segmentConfig->nSyncs; syncIndex++) {
                NnSyncConfig *syncConfig = &segmentConfig->syncs[syncIndex];
                network->read(ROOT_SOCKET_INDEX, &syncConfig->pipeIndex, sizeof(syncConfig->pipeIndex));
                network->read(ROOT_SOCKET_INDEX, &syncConfig->syncType, sizeof(syncConfig->syncType));
            }
        }

        if (segmentConfig->nOps > 0) {
            segmentConfig->ops = new NnOpConfig[segmentConfig->nOps];

            for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
                NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
                network->read(ROOT_SOCKET_INDEX, &opConfig->code, sizeof(opConfig->code));
                network->read(ROOT_SOCKET_INDEX, &opConfig->index, sizeof(opConfig->index));
                network->read(ROOT_SOCKET_INDEX, &opConfig->weightSize, sizeof(opConfig->weightSize));
                network->read(ROOT_SOCKET_INDEX, &opConfig->configSize, sizeof(opConfig->configSize));
                opConfig->name = readString(network, ROOT_SOCKET_INDEX);
                network->read(ROOT_SOCKET_INDEX, &opConfig->input, sizeof(opConfig->input));
                network->read(ROOT_SOCKET_INDEX, &opConfig->output, sizeof(opConfig->output));
                // Op-specific config blob follows only when present.
                if (opConfig->configSize > 0) {
                    opConfig->config = new NnByte[opConfig->configSize];
                    network->read(ROOT_SOCKET_INDEX, opConfig->config, opConfig->configSize);
                }
            }
        }
    }
    network->writeAck(ROOT_SOCKET_INDEX); // confirm the config was consumed
    return config;
}
|
||||||
|
|
||||||
|
// Root-side weight distributor. `executor` and `network` are non-owning and
// must outlive this loader; `nNodes` is the total node count including root.
NnRootWeightLoader::NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes) {
    this->executor = executor;
    this->network = network;
    this->nNodes = nNodes;
    // Hardening: `temp` was left uninitialized; make the empty state explicit
    // (tempSize == 0 guards every delete[], but a null pointer is safer).
    this->temp = nullptr;
    this->tempSize = 0;
}
|
||||||
|
|
||||||
|
// Releases the scratch buffer if one was ever allocated
// (tempSize > 0 implies `temp` points at a live new[] allocation).
NnRootWeightLoader::~NnRootWeightLoader() {
    if (tempSize > 0)
        delete[] temp;
}
|
||||||
|
|
||||||
|
void NnRootWeightLoader::finish() {
|
||||||
|
NnUint zeroSize = 0;
|
||||||
|
for (NnUint socketIndex = 0; socketIndex < nNodes - 1; socketIndex++) {
|
||||||
|
network->write(socketIndex, &zeroSize, sizeof(zeroSize));
|
||||||
|
network->readAck(socketIndex);
|
||||||
|
}
|
||||||
|
if (tempSize > 0) {
|
||||||
|
delete[] temp;
|
||||||
|
tempSize = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grow-only scratch buffer: reallocates only when the request exceeds the
// current capacity. Contents are not preserved.
void NnRootWeightLoader::allocate(NnSize size) {
    if (tempSize >= size)
        return;
    if (tempSize > 0)
        delete[] temp;
    tempSize = size;
    temp = new NnByte[size];
}
|
||||||
|
|
||||||
|
// Streams one weight record to a worker. Wire format (consumed by
// NnWorkerWeightReader::read): nameSize, name (NUL-terminated), opIndex,
// offset, nBytes, then the raw payload. Worker node n uses socket n-1.
void NnRootWeightLoader::writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) {
    const NnUint socketIndex = nodeIndex - 1;
    NnUint nameSize = std::strlen(opName) + 1;
    network->write(socketIndex, &nameSize, sizeof(nameSize));
    network->write(socketIndex, opName, nameSize);
    network->write(socketIndex, &opIndex, sizeof(opIndex));
    network->write(socketIndex, &offset, sizeof(offset));
    network->write(socketIndex, &nBytes, sizeof(nBytes));
    network->write(socketIndex, weight, nBytes);
}
|
||||||
|
|
||||||
|
// Loads a weight into the root executor only (nothing is sent to workers).
// Returns the number of bytes consumed from `weight`.
NnSize NnRootWeightLoader::loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) {
    executor->loadWeight(opName, opIndex, 0u, nBytes, weight);
    return nBytes;
}
|
||||||
|
|
||||||
|
// Loads a weight into the root executor and replicates the same (unsliced)
// bytes to every worker. Returns the number of bytes consumed from `weight`.
NnSize NnRootWeightLoader::loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight) {
    executor->loadWeight(opName, opIndex, 0u, nBytes, weight);
    // Loop body does not execute on a single-node setup (nNodes == 1).
    for (NnUint nodeIndex = 1u; nodeIndex < nNodes; nodeIndex++)
        writeWeight(nodeIndex, opName, opIndex, 0u, nBytes, weight);
    return nBytes;
}
|
||||||
|
|
||||||
|
// Distributes a row-sliced matmul weight: slice 0 is loaded into the local
// (root) executor, slices 1..nNodes-1 are streamed to the workers. `offset`
// addresses one expert's region inside the op's weight storage.
// Returns the full (unsliced) byte count of the weight.
NnSize NnRootWeightLoader::loadRowMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnRowMatmulSlice *slice, NnByte *weight) {
    const NnUint offset = expertIndex * slice->sliceSize.nBytes;
    if (nNodes == 1u) {
        // Single node: no splitting, load the slice directly.
        executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, weight);
    } else {
        allocate(slice->sliceSize.nBytes); // scratch for one node's slice
        for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) {
            splitRowMatmulWeight(slice, nodeIndex, weight, temp);
            if (nodeIndex == 0u)
                executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, temp);
            else
                writeWeight(nodeIndex, opName, opIndex, offset, slice->sliceSize.nBytes, temp);
        }
    }
    return slice->size.nBytes;
}
|
||||||
|
|
||||||
|
// Distributes a column-sliced matmul weight: slice 0 is loaded into the local
// (root) executor, slices 1..nNodes-1 are streamed to the workers. `offset`
// addresses one expert's region inside the op's weight storage.
// Returns the full (unsliced) byte count of the weight.
NnSize NnRootWeightLoader::loadColMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnColMatmulSlice *slice, NnByte *weight) {
    const NnUint offset = expertIndex * slice->sliceSize.nBytes;
    if (nNodes == 1) {
        // Single node: no splitting, load the slice directly.
        executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, weight);
    } else {
        allocate(slice->sliceSize.nBytes); // scratch for one node's slice
        for (NnUint nodeIndex = 0; nodeIndex < nNodes; nodeIndex++) {
            splitColMatmulWeight(slice, nodeIndex, weight, temp);
            if (nodeIndex == 0)
                executor->loadWeight(opName, opIndex, offset, slice->sliceSize.nBytes, temp);
            else
                writeWeight(nodeIndex, opName, opIndex, offset, slice->sliceSize.nBytes, temp);
        }
    }
    return slice->size.nBytes;
}
|
||||||
|
|
||||||
|
// Worker-side weight receiver. `executor` and `network` are non-owning and
// must outlive this reader.
NnWorkerWeightReader::NnWorkerWeightReader(NnExecutor *executor, NnNetwork *network) {
    this->executor = executor;
    this->network = network;
    // Hardening: `temp` was left uninitialized; make the empty state explicit
    // (tempSize == 0 guards every delete[], but a null pointer is safer).
    this->temp = nullptr;
    this->tempSize = 0;
}
|
||||||
|
|
||||||
|
// Releases the scratch buffer if one was ever allocated
// (tempSize > 0 implies `temp` points at a live new[] allocation).
NnWorkerWeightReader::~NnWorkerWeightReader() {
    if (tempSize > 0)
        delete[] temp;
}
|
||||||
|
|
||||||
|
// Grow-only scratch buffer: reallocates only when the request exceeds the
// current capacity. Contents are not preserved.
void NnWorkerWeightReader::allocate(NnUint size) {
    if (tempSize >= size)
        return;
    if (tempSize > 0)
        delete[] temp;
    tempSize = size;
    temp = new NnByte[size];
}
|
||||||
|
|
||||||
|
void NnWorkerWeightReader::read() {
|
||||||
|
NnUint nameSize;
|
||||||
|
NnUint opIndex;
|
||||||
|
NnSize offset;
|
||||||
|
NnSize nBytes;
|
||||||
|
while (true) {
|
||||||
|
network->read(0, &nameSize, sizeof(nameSize));
|
||||||
|
if (nameSize == 0) {
|
||||||
|
network->writeAck(ROOT_SOCKET_INDEX);
|
||||||
|
if (tempSize > 0) {
|
||||||
|
delete temp;
|
||||||
|
tempSize = 0;
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::unique_ptr<char[]> opNamePtr(new char[nameSize]);
|
||||||
|
char *opName = opNamePtr.get();
|
||||||
|
network->read(ROOT_SOCKET_INDEX, opName, nameSize);
|
||||||
|
network->read(ROOT_SOCKET_INDEX, &opIndex, sizeof(opIndex));
|
||||||
|
network->read(ROOT_SOCKET_INDEX, &offset, sizeof(offset));
|
||||||
|
network->read(ROOT_SOCKET_INDEX, &nBytes, sizeof(nBytes));
|
||||||
|
allocate(nBytes);
|
||||||
|
network->read(0, temp, nBytes);
|
||||||
|
executor->loadWeight(opName, opIndex, offset, nBytes, temp);
|
||||||
|
printf("💿 Loaded %22s %3d, %12zu kB\n", opName, opIndex, nBytes / 1024);
|
||||||
|
}
|
||||||
|
printf("💿 Weights loaded\n");
|
||||||
|
}
|
||||||
129
src/nn/nn-network.hpp
Normal file
129
src/nn/nn-network.hpp
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
#ifndef NN_NETWORK_H
#define NN_NETWORK_H

#include "nn-executor.hpp"

// From a worker's point of view, the root node is always socket 0.
#define ROOT_SOCKET_INDEX 0

// Low-level, platform-abstracting socket helpers implemented in nn-network.cpp.
void initSockets();
void cleanupSockets();
int acceptSocket(int serverSocket);
void setReuseAddr(int socket);
void writeSocket(int socket, const void* data, NnSize size);
void readSocket(int socket, void* data, NnSize size);
int createServerSocket(int port);
void destroySocket(int serverSocket);
|
||||||
|
|
||||||
|
// Thrown when receiving from a socket fails (error or peer close).
// Note: what() is not overridden; callers inspect the public fields.
class NnReadNetworkException : public std::exception {
public:
    int code;            // platform socket error code (0 for "socket closed")
    const char *message; // human-readable error text
    NnReadNetworkException(int code, const char *message);
};
|
||||||
|
|
||||||
|
// Thrown when sending to a socket fails (error or peer close).
// Note: what() is not overridden; callers inspect the public fields.
class NnWriteNetworkException : public std::exception {
public:
    int code;            // platform socket error code (0 for "socket closed")
    const char *message; // human-readable error text
    NnWriteNetworkException(int code, const char *message);
};
|
||||||
|
|
||||||
|
// One pending transfer for NnNetwork::writeMany/readMany. The network code
// advances `data` and decrements `size` as bytes move, so after the call
// `size` is 0 and `data` points past the end of the buffer.
struct NnSocketIo {
    NnUint socketIndex; // target peer socket
    const void *data;   // current cursor into the buffer
    NnSize size;        // bytes remaining to transfer
};
|
||||||
|
|
||||||
|
// Full-mesh TCP fabric between nodes. On the root, socket i connects to
// worker i; on a worker, socket 0 (ROOT_SOCKET_INDEX) is the root. Owns its
// socket array and per-socket traffic counters.
class NnNetwork {
private:
    int *sockets;       // owned; closed and delete[]'d in the destructor
    NnSize *sentBytes;  // owned; per-socket bytes sent since last resetStats()
    NnSize *recvBytes;  // owned; per-socket bytes received since last resetStats()

public:
    // Worker side: listens on `port` and performs the join handshake.
    static std::unique_ptr<NnNetwork> serve(int port);
    // Root side: connects to all workers and distributes the topology.
    static std::unique_ptr<NnNetwork> connect(NnUint nSockets, char **hosts, NnUint *ports);

    NnUint nSockets; // number of peer sockets held by this node

    // Takes ownership of `sockets` (allocated with new[]).
    NnNetwork(NnUint nSockets, int *sockets);
    ~NnNetwork();

    // Toggles non-blocking mode on every socket.
    void setTurbo(bool enabled);
    // Blocking, chunked single-socket transfers.
    void write(const NnUint socketIndex, const void *data, const NnSize size);
    void read(const NnUint socketIndex, void *data, const NnSize size);
    // Ack packet exchange used to frame protocol phases.
    void writeAck(const NnUint socketIndex);
    void readAck(const NnUint socketIndex);
    // Bounded-retry read; returns false when the data did not arrive in time.
    bool tryReadWithMaxAttempts(NnUint socketIndex, void *data, NnSize size, unsigned long maxAttempts);
    // Interleaved multi-socket transfers (mutates the NnSocketIo entries).
    void writeMany(NnUint n, NnSocketIo *ios);
    void writeAll(void *data, NnSize size);
    void readMany(NnUint n, NnSocketIo *ios);
    // Returns totals since the last call, then zeroes the counters.
    void getStats(NnSize *sentBytes, NnSize *recvBytes);
    void resetStats();
};
|
||||||
|
|
||||||
|
// NnNodeSynchronizer backed by a NnNetwork: after each segment it moves pipe
// data between nodes according to the segment's sync configs.
class NnNetworkNodeSynchronizer : public NnNodeSynchronizer {
private:
    NnNetwork *network;        // non-owning
    NnNetExecution *execution; // non-owning
    NnNetConfig *netConfig;    // non-owning
    NnNodeConfig *nodeConfig;  // non-owning
public:
    NnNetworkNodeSynchronizer(NnNetwork *network, NnNetExecution *execution, NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
    ~NnNetworkNodeSynchronizer() override {};
    // Called by every executor thread; the socket workload is split by threadIndex.
    void sync(NnUint segmentIndex, NnUint nThreads, NnUint threadIndex) override;
};
|
||||||
|
|
||||||
|
// Root-side serializer that streams the net and per-node configurations to
// workers; the wire format mirrors NnWorkerConfigReader field-for-field.
class NnRootConfigWriter {
private:
    NnNetwork *network; // non-owning
public:
    NnRootConfigWriter(NnNetwork *network);
    void writeNet(NnUint socketIndex, NnNetConfig *config);
    void writeNode(NnUint socketIndex, NnNodeConfig *config);
    // Sends netConfig plus nodeConfigs[n] to each worker node n (n >= 1).
    void writeToWorkers(NnNetConfig *netConfig, NnNodeConfig *nodeConfigs);
};
|
||||||
|
|
||||||
|
// Worker-side deserializer for the configs sent by NnRootConfigWriter.
// The returned structs own heap-allocated (new[]) nested arrays/strings.
class NnWorkerConfigReader {
private:
    NnNetwork *network; // non-owning
public:
    NnWorkerConfigReader(NnNetwork *network);
    NnNetConfig readNet();
    NnNodeConfig readNode();
};
|
||||||
|
|
||||||
|
// Root-side weight distributor: loads weights into the local executor and
// streams (optionally sliced) copies to workers; the stream is terminated
// with finish(). Consumed on the worker side by NnWorkerWeightReader.
class NnRootWeightLoader {
private:
    NnExecutor *executor; // non-owning
    NnNetwork *network;   // non-owning
    NnUint nNodes;        // total node count including the root
    NnByte *temp;         // owned grow-only scratch buffer for slicing
    NnSize tempSize;      // current capacity of `temp` (0 = unallocated)
public:
    NnRootWeightLoader(NnExecutor *executor, NnNetwork *network, NnUint nNodes);
    ~NnRootWeightLoader();
    // Sends one raw weight record to worker node `nodeIndex` (socket nodeIndex-1).
    void writeWeight(NnUint nodeIndex, const char *opName, NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight);
    // Loads into the root executor only; returns bytes consumed.
    NnSize loadRoot(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight);
    // Loads locally and replicates the same bytes to every worker.
    NnSize loadAll(const char *opName, NnUint opIndex, NnSize nBytes, NnByte *weight);
    // Split a matmul weight per node (row/column slicing) and distribute it.
    NnSize loadRowMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnRowMatmulSlice *slice, NnByte *weight);
    NnSize loadColMatmulSlices(const char *opName, const NnUint opIndex, const NnUint expertIndex, NnColMatmulSlice *slice, NnByte *weight);
    // Sends the end-of-stream marker to all workers and frees the scratch buffer.
    void finish();
private:
    // Grow-only scratch allocation (contents not preserved).
    void allocate(NnSize size);
};
|
||||||
|
|
||||||
|
// Worker-side counterpart of NnRootWeightLoader: receives weight records from
// the root and loads them into the local executor until the end-of-stream
// marker arrives.
class NnWorkerWeightReader {
private:
    NnExecutor *executor; // non-owning
    NnNetwork *network;   // non-owning
    NnByte *temp;         // owned grow-only scratch buffer for payloads
    NnUint tempSize;      // current capacity of `temp` (0 = unallocated)
public:
    NnWorkerWeightReader(NnExecutor *executor, NnNetwork *network);
    ~NnWorkerWeightReader();
    // Blocks until the root sends the zero-size end marker.
    void read();
private:
    // Grow-only scratch allocation (contents not preserved).
    void allocate(NnUint size);
};
|
||||||
|
|
||||||
|
#endif
|
||||||
255
src/nn/nn-quants.cpp
Normal file
255
src/nn/nn-quants.cpp
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
#include "nn-quants.hpp"
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include <cmath>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#if defined(CONVERT_F16_TO_F32_LOOKUP)
// Precomputed f16 -> f32 table indexed by the raw 16-bit pattern;
// filled once by initQuants().
float f16ToF32Lookup[65536];
#endif

// One-time quantization setup. When the lookup-table path is compiled in,
// this must run before any f16 -> f32 conversion that uses the table.
void initQuants() {
#if defined(CONVERT_F16_TO_F32_LOOKUP)
    for (NnUint i = 0; i < 65536; i++)
        f16ToF32Lookup[i] = convertF16toF32Impl((NnFp16)i);
#endif
}
|
||||||
|
|
||||||
|
// Software fp16 -> fp32 conversion using the "magic multiply" technique:
// shift the f16 exponent+mantissa into f32 bit positions, rebias the exponent
// with one float multiply (which also handles f16 subnormals), saturate
// Inf/NaN, then re-attach the sign bit.
float convertF16toF32Impl(const NnFp16 value) {
    union Fl32 {
        uint32_t u;
        float f;
    };
    const Fl32 magic = { (254U - 15U) << 23 };  // 2^(127-15): exponent rebias factor
    const Fl32 infNan = { (127U + 16U) << 23 }; // threshold where f16 Inf/NaN land
    Fl32 result;
    result.u = (value & 0x7FFFU) << 13;  // exponent+mantissa moved to f32 layout
    result.f *= magic.f;                 // rebias; normalizes f16 subnormals too
    if (result.f >= infNan.f)
        result.u |= 255U << 23;          // saturate: f16 Inf/NaN -> f32 Inf/NaN
    result.u |= (value & 0x8000U) << 16; // sign bit
    return result.f;
}
|
||||||
|
|
||||||
|
// Software fp32 -> fp16 conversion with round-to-nearest-even, handling
// underflow to signed zero, subnormals, Inf and NaN.
NnFp16 convertF32ToF16Impl(const float x) {
    // Fix: the original read the float's bits with `*(int *)&x`, which
    // violates strict aliasing (undefined behavior). std::memcpy is the
    // well-defined equivalent and compiles to the same single move.
    int32_t i;
    std::memcpy(&i, &x, sizeof(i));
    int s = (i >> 16) & 0x00008000;                // sign, already in f16 position
    int e = ((i >> 23) & 0x000000ff) - (127 - 15); // exponent, rebased to f16 bias
    int m = i & 0x007fffff;                        // 23-bit mantissa
    if (e <= 0) {
        // Too small for a normal f16 value.
        if (e < -10) {
            return s; // underflow to signed zero
        }
        // Subnormal: shift mantissa (with hidden bit) right, rounding to nearest even.
        m = m | 0x00800000;
        int t = 14 - e;
        int a = (1 << (t - 1)) - 1;
        int b = (m >> t) & 1;
        m = (m + a + b) >> t;
        return s | m;
    }
    if (e == 0xff - (127 - 15)) {
        if (m == 0) {
            return s | 0x7c00; // infinity
        }
        // NaN: `| (m == 0)` keeps the mantissa nonzero after truncation.
        m >>= 13;
        return s | 0x7c00 | m | (m == 0);
    }
    // Normal number: round mantissa to nearest even; a carry bumps the exponent.
    m = m + 0x00000fff + ((m >> 13) & 1);
    if (m & 0x00800000) {
        m = 0;
        e += 1;
    }
    assert(e <= 30); // overflow to f16 infinity is not handled here
    return s | (e << 10) | (m >> 13);
}
|
||||||
|
|
||||||
|
// Quantizes `n` fp32 values (n must be a multiple of Q80_BLOCK_SIZE) into
// Q8_0 blocks: each block stores an fp16 scale d (= block absmax / 127) and
// 32 int8 codes so that input[j] ~= qs[j] * d. Blocks are distributed over
// threads via SPLIT_THREADS; NEON / AVX2 paths mirror the scalar fallback.
// NOTE(review): the AVX2 path rounds half-to-even (_MM_FROUND_TO_NEAREST_INT)
// while the NEON (+/-0.5 then truncate) and scalar (roundf) paths round half
// away from zero — exact .5 multiples can differ by one code across builds;
// confirm this is acceptable.
void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) {
    assert(n % Q80_BLOCK_SIZE == 0);
    const NnUint nBlocks = n / Q80_BLOCK_SIZE;
    SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex);

#if defined(__ARM_NEON)
    for (NnUint i = start; i < end; i++) {
        const float *x = &input[i * Q80_BLOCK_SIZE];
        NnBlockQ80 *y = &output[i];

        // Pass 1: absolute maximum of the block, 4 lanes at a time.
        float32x4_t amaxVec = vdupq_n_f32(0.0f);
        for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) {
            const float32x4_t vec = vld1q_f32(&x[j]);
            const float32x4_t abs_vec = vabsq_f32(vec);
            amaxVec = vmaxq_f32(amaxVec, abs_vec);
        }

        float amax = vmaxvq_f32(amaxVec); // horizontal max of the 4 lanes

        const float d = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f; // all-zero block guard

        y->d = CONVERT_F32_TO_F16(d);

        const float32x4_t vid_vec = vdupq_n_f32(id);

        // Pass 2: scale, round half away from zero, saturate-narrow to int8.
        for (NnUint j = 0; j < Q80_BLOCK_SIZE; j += 4) {
            float32x4_t vec = vld1q_f32(&x[j]);
            vec = vmulq_f32(vec, vid_vec);

            // Add +0.5 / -0.5 depending on sign so the truncating convert
            // below implements round-half-away-from-zero.
            const uint32x4_t sign_mask = vcgeq_f32(vec, vdupq_n_f32(0.0f));
            const float32x4_t half = vbslq_f32(sign_mask, vdupq_n_f32(0.5f), vdupq_n_f32(-0.5f));
            vec = vaddq_f32(vec, half);

            const int32x4_t vec_i32 = vcvtq_s32_f32(vec);
            const int16x4_t vec_i16 = vqmovn_s32(vec_i32);
            const int8x8_t vec_i8 = vqmovn_s16(vcombine_s16(vec_i16, vec_i16));

            // Store only the 4 valid int8 lanes (low 32 bits of the vector).
            vst1_lane_s32((int32_t *)(y->qs + j), vreinterpret_s32_s8(vec_i8), 0);
        }
    }
#elif defined(__AVX2__)
    for (NnUint i = start; i < end; ++i) {
        const float *x = input + i * Q80_BLOCK_SIZE;
        NnBlockQ80 *y = output + i;

        // Pass 1: absolute maximum (abs via clearing the sign bit).
        __m256 max_abs = _mm256_setzero_ps();
        for (int j = 0; j < Q80_BLOCK_SIZE; j += 8) {
            __m256 vec = _mm256_loadu_ps(x + j);
            __m256 abs_vec = _mm256_and_ps(vec, _mm256_castsi256_ps(_mm256_set1_epi32(0x7FFFFFFF)));
            max_abs = _mm256_max_ps(max_abs, abs_vec);
        }
        // Horizontal max: 256 -> 128 -> 64 -> scalar.
        __m128 max_hi = _mm256_extractf128_ps(max_abs, 1);
        __m128 max_lo = _mm256_castps256_ps128(max_abs);
        __m128 max_128 = _mm_max_ps(max_hi, max_lo);
        max_128 = _mm_max_ps(max_128, _mm_movehl_ps(max_128, max_128));
        max_128 = _mm_max_ss(max_128, _mm_shuffle_ps(max_128, max_128, _MM_SHUFFLE(1, 1, 1, 1)));
        float amax = _mm_cvtss_f32(max_128);

        const float d = amax / 127.0f;
        const float id = (d != 0.0f) ? 1.0f / d : 0.0f; // all-zero block guard
        y->d = CONVERT_F32_TO_F16(d);

        const __m256 id_vec = _mm256_set1_ps(id);
        // Picks byte 0 of each 32-bit lane (the int8 code) into bytes 0..3.
        const __m128i shuffle_mask = _mm_set_epi8(
            -1, -1, -1, -1, -1, -1, -1, -1,
            -1, -1, -1, -1, 12, 8, 4, 0
        );

        // Pass 2: scale, round to nearest (even), pack 8 codes per iteration.
        for (int j = 0; j < Q80_BLOCK_SIZE; j += 8) {
            __m256 vec = _mm256_loadu_ps(x + j);
            __m256 scaled = _mm256_mul_ps(vec, id_vec);
            __m256 rounded = _mm256_round_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
            __m256i integers = _mm256_cvtps_epi32(rounded);

            __m128i low = _mm256_extracti128_si256(integers, 0);
            __m128i high = _mm256_extracti128_si256(integers, 1);

            __m128i low_bytes = _mm_shuffle_epi8(low, shuffle_mask);
            __m128i high_bytes = _mm_shuffle_epi8(high, shuffle_mask);

            uint32_t low_part = _mm_extract_epi32(low_bytes, 0);
            uint32_t high_part = _mm_extract_epi32(high_bytes, 0);
            uint64_t packed = (static_cast<uint64_t>(high_part) << 32) | low_part;
            std::memcpy(y->qs + j, &packed, sizeof(packed)); // unaligned-safe store
        }
    }
#else
    // Scalar fallback.
    for (NnUint i = start; i < end; i++) {
        const float *x = &input[i * Q80_BLOCK_SIZE];
        NnBlockQ80 *y = &output[i];

        float amax = 0.0f;
        for (NnUint j = 0; j < Q80_BLOCK_SIZE; j++) {
            const float v = fabsf(x[j]);
            amax = amax > v ? amax : v;
        }

        const float d = amax / ((1 << 7) - 1); // absmax / 127
        const float id = d ? 1.0f / d : 0.0f;  // all-zero block guard
        y->d = CONVERT_F32_TO_F16(d);
        for (NnUint j = 0; j < Q80_BLOCK_SIZE; ++j) {
            y->qs[j] = roundf(x[j] * id);
        }
    }
#endif
}
|
||||||
|
|
||||||
|
void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex) {
|
||||||
|
assert(k % Q80_BLOCK_SIZE == 0);
|
||||||
|
const int nBlocks = k / Q80_BLOCK_SIZE;
|
||||||
|
const int blocksPerThread = nBlocks / nThreads;
|
||||||
|
const int sk = blocksPerThread * Q80_BLOCK_SIZE;
|
||||||
|
const int currentThreadBlocks = blocksPerThread + (threadIndex == nThreads - 1 ? nBlocks % nThreads : 0);
|
||||||
|
|
||||||
|
const NnBlockQ80 *x = &input[blocksPerThread * threadIndex];
|
||||||
|
float* y = &output[sk * threadIndex];
|
||||||
|
|
||||||
|
for (int i = 0; i < currentThreadBlocks; i++) {
|
||||||
|
const float d = CONVERT_F16_TO_F32(x[i].d);
|
||||||
|
for (int j = 0; j < Q80_BLOCK_SIZE; j++) {
|
||||||
|
y[i * Q80_BLOCK_SIZE + j] = x[i].qs[j] * d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Quantizes `n` fp32 values (n must be a multiple of Q40_BLOCK_SIZE) into
// Q4_0 blocks: an fp16 scale plus 32 values stored as 4-bit codes, two per
// byte (low nibble = first half of the block, high nibble = second half).
// The scale is keyed to the SIGNED extreme (d = max / -8) so that the
// extreme value maps exactly to code 0 or 15 of the [-8, 7] range.
void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) {
    assert(n % Q40_BLOCK_SIZE == 0);
    const NnUint nBlocks = n / Q40_BLOCK_SIZE;
    const NnUint halfSize = Q40_BLOCK_SIZE / 2;
    SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex);

    for (NnUint i = start; i < end; i++) {
        // Find the value with the largest magnitude, keeping its sign.
        float amax = 0.0f;
        float max = 0.0f;
        for (NnUint j = 0; j < Q40_BLOCK_SIZE; j++) {
            float v = x[i * Q40_BLOCK_SIZE + j];
            if (amax < fabsf(v)) {
                amax = fabsf(v);
                max = v;
            }
        }

        const float d = max / -8.0f;
        const float id = d ? 1.0f / d : 0.0f; // all-zero block guard

        NnBlockQ40 *o = &output[i];
        o->d = CONVERT_F32_TO_F16(d);
        for (NnUint j = 0; j < halfSize; j++) {
            const float x0 = x[i * Q40_BLOCK_SIZE + j] * id;            // low-nibble lane
            const float x1 = x[i * Q40_BLOCK_SIZE + halfSize + j] * id; // high-nibble lane

            // +8.5 shifts [-8, 8] into [0.5, 16.5] and rounds on truncation;
            // the clamp below caps the upper edge at code 15.
            uint8_t xi0 = (int8_t)(x0 + 8.5f);
            uint8_t xi1 = (int8_t)(x1 + 8.5f);
            if (xi0 > 15) xi0 = 15;
            if (xi1 > 15) xi1 = 15;

            o->qs[j] = xi0 | (xi1 << 4);
        }
    }
}
|
||||||
|
|
||||||
|
void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex) {
|
||||||
|
assert(n % Q40_BLOCK_SIZE == 0);
|
||||||
|
const NnUint nBlocks = n / Q40_BLOCK_SIZE;
|
||||||
|
SPLIT_THREADS(start, end, nBlocks, nThreads, threadIndex);
|
||||||
|
|
||||||
|
for (NnUint i = start; i < end; i++) {
|
||||||
|
const NnBlockQ40 *b = &x[i];
|
||||||
|
const float d = CONVERT_F16_TO_F32(b->d);
|
||||||
|
|
||||||
|
for (int j = 0; j < Q40_BLOCK_SIZE / 2; ++j) {
|
||||||
|
const int x0 = (b->qs[j] & 0x0F) - 8;
|
||||||
|
const int x1 = (b->qs[j] >> 4) - 8;
|
||||||
|
|
||||||
|
output[i * Q40_BLOCK_SIZE + j] = x0 * d;
|
||||||
|
output[i * Q40_BLOCK_SIZE + j + Q40_BLOCK_SIZE / 2] = x1 * d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maps a float-type tag to its stable, human-readable name.
// Throws std::invalid_argument for values outside the known set.
const char *floatTypeToString(NnFloatType type) {
    switch (type) {
        case F_UNK: return "F_UNK";
        case F_32: return "F_32";
        case F_16: return "F_16";
        case F_Q40: return "F_Q40";
        case F_Q80: return "F_Q80";
        default: throw std::invalid_argument("Unknown float type");
    }
}
|
||||||
88
src/nn/nn-quants.hpp
Normal file
88
src/nn/nn-quants.hpp
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
#ifndef NN_QUANTS_H
#define NN_QUANTS_H

#include <cstdint>
#include <cstring>
#if defined(__ARM_NEON)
#include <arm_neon.h>
#elif defined(__AVX2__) || defined(__F16C__)
#include <immintrin.h>
#endif

// Basic scalar aliases used across the nn module.
typedef std::uint8_t NnByte;
typedef std::uint32_t NnUint;
typedef std::size_t NnSize;
typedef std::uint16_t NnFp16; // raw IEEE 754 binary16 bit pattern

// Portable software fp16 <-> fp32 conversions (defined in nn-quants.cpp).
float convertF16toF32Impl(const NnFp16 value);
NnFp16 convertF32ToF16Impl(const float x);

#if defined(__ARM_NEON) && defined(__ARM_FP16_FORMAT_IEEE)
// Hardware-assisted conversions via the native __fp16 type.
inline float convertF16ToF32Neon(const NnFp16 value) {
    __fp16 fp;
    std::memcpy(&fp, &value, sizeof(fp));
    return (float)fp;
}

inline NnFp16 convertF32ToF16Neon(const float x) {
    const __fp16 h = (__fp16)x;
    NnFp16 r;
    // memcpy instead of `*(NnFp16 *)&h`: avoids strict-aliasing UB.
    std::memcpy(&r, &h, sizeof(r));
    return r;
}

#define CONVERT_F16_TO_F32(value) convertF16ToF32Neon(value)
#define CONVERT_F32_TO_F16(value) convertF32ToF16Neon(value)
#elif defined(__F16C__)
#define CONVERT_F32_TO_F16(v) _cvtss_sh((v), _MM_FROUND_TO_NEAREST_INT)
#endif

#if !defined(CONVERT_F16_TO_F32)
// Fallback: 64K-entry table; must be populated once via initQuants().
extern float f16ToF32Lookup[65536];

inline static float convertF16ToF32Lookup(const NnFp16 value) {
    return f16ToF32Lookup[value];
}

#define CONVERT_F16_TO_F32_LOOKUP
#define CONVERT_F16_TO_F32(value) convertF16ToF32Lookup(value)
#endif

#if !defined(CONVERT_F32_TO_F16)
#define CONVERT_F32_TO_F16(value) convertF32ToF16Impl(value)
#endif

#define Q40_BLOCK_SIZE 32
#define Q80_BLOCK_SIZE 32

enum NnFloatType {
    F_UNK = -1,
    F_32 = 0,
    F_16 = 1,
    F_Q40 = 2,
    F_Q80 = 3,
};

// Q4_0 block: fp16 scale + 32 values packed as 4-bit codes (two per byte).
typedef struct {
    std::uint16_t d;
    std::uint8_t qs[Q40_BLOCK_SIZE / 2];
} NnBlockQ40;

// Q8_0 block: fp16 scale + 32 int8 codes.
typedef struct {
    std::uint16_t d;
    std::int8_t qs[Q80_BLOCK_SIZE];
} NnBlockQ80;

void initQuants();
void quantizeF32toQ80(const float *input, NnBlockQ80 *output, const NnUint k, const NnUint nThreads, const NnUint threadIndex);
void dequantizeQ80toF32(const NnBlockQ80 *input, float* output, const NnUint k, const NnUint nThreads, const NnUint threadIndex);
void quantizeF32toQ40(const float *x, NnBlockQ40 *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex);
void dequantizeQ40toF32(const NnBlockQ40 *x, float *output, const NnUint n, const NnUint nThreads, const NnUint threadIndex);

const char *floatTypeToString(NnFloatType type);

// Splits [0, rangeLen) into nThreads contiguous chunks; the first
// (rangeLen % nThreads) threads each receive one extra element.
#define SPLIT_THREADS(varStart, varEnd, rangeLen, nThreads, threadIndex) \
    const NnUint rangeSlice = rangeLen / nThreads; \
    const NnUint rangeRest = rangeLen % nThreads; \
    const NnUint varStart = threadIndex * rangeSlice + (threadIndex < rangeRest ? threadIndex : rangeRest); \
    const NnUint varEnd = varStart + rangeSlice + (threadIndex < rangeRest ? 1 : 0);

#endif
|
||||||
989
src/nn/nn-vulkan-test.cpp
Normal file
989
src/nn/nn-vulkan-test.cpp
Normal file
@@ -0,0 +1,989 @@
|
|||||||
|
#include <cstdio>
|
||||||
|
#include <cmath>
|
||||||
|
#include "nn-config-builder.hpp"
|
||||||
|
#include "nn-quants.hpp"
|
||||||
|
#include "nn-vulkan.hpp"
|
||||||
|
|
||||||
|
#define N_BATCHES 4
|
||||||
|
|
||||||
|
// Reports a passing test: green check mark, test name right-aligned to 24 chars.
void printOk(const char *name) {
    fprintf(stdout, "✅ %24s passed\n", name);
}
|
||||||
|
|
||||||
|
// Verifies |value - expectedValue| <= tolerance; on violation prints a
// diagnostic with the element position and terminates the test binary.
void assertFloat(NnUint position, const float value, const float expectedValue, const float tolerance) {
    const float delta = fabs(expectedValue - value);
    if (delta <= tolerance)
        return;
    printf("❌ [%d] failed: value=%f, expectedValue=%f, diff=%f\n", position, value, expectedValue, delta);
    exit(1);
}
|
||||||
|
|
||||||
|
// Test harness: builds a single-node, single-segment network via the `build`
// callback, creates a Vulkan device (GPU index 0) and an executor around it,
// then hands everything to the `execute` callback for the act/assert phase.
void execute(
    void (*build)(NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder),
    void (*execute)(NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device)
) {
    NnUint nNodes = 1;
    NnNetConfigBuilder netBuilder(nNodes, N_BATCHES);
    NnNodeConfigBuilder nodeBuilder(0);
    NnSegmentConfigBuilder segmentBuilder;
    build(&netBuilder, &nodeBuilder, &segmentBuilder);
    nodeBuilder.addSegment(segmentBuilder.build());

    NnNetConfig netConfig = netBuilder.build();
    NnNodeConfig nodeConfig = nodeBuilder.build();
    // Release the configs automatically on every exit path.
    std::unique_ptr<NnNetConfig, void(*)(NnNetConfig *)> netConfigPtr(&netConfig, releaseNetConfig);
    std::unique_ptr<NnNodeConfig, void(*)(NnNodeConfig *)> nodeConfigPtr(&nodeConfig, releaseNodeConfig);

    NnNetExecution execution(1, &netConfig);

    NnUint gpuIndex = 0;
    std::vector<NnExecutorDevice> devices;
    // NOTE(review): raw `new` with no delete here — presumably NnExecutorDevice
    // (or NnExecutor) takes ownership of `device`; confirm against its ctor.
    NnVulkanDevice *device = new NnVulkanDevice(gpuIndex, &netConfig, &nodeConfig, &execution);
    devices.push_back(NnExecutorDevice(device, -1, -1));
    NnFakeNodeSynchronizer synchronizer;
    NnExecutor executor(&netConfig, &nodeConfig, &devices, &execution, &synchronizer, false);

    execute(&executor, &execution, device);
}
|
||||||
|
|
||||||
|
// Checks OP_INV_RMS + OP_RMS_NORM chained on the same pipe: the inverse RMS
// of each batch row is written to a node buffer, then the row is normalized
// in place and multiplied by a loaded weight vector. CPU-side expectations
// are recomputed with the same formula.
template <NnUint dim>
void testRmsNorm_F32_F32_F32() {
#define TEST_RMS_NORM_EPS 1e-5f
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
            NnUint invRmsBufferIndex = nodeBuilder->addBuffer("inv_rms", size2D(F_32, N_BATCHES, 1));
            segmentBuilder->addOp(OP_INV_RMS, "inv_rms", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_BUFFER, invRmsBufferIndex),
                size0(),
                NnInvRmsOpConfig{TEST_RMS_NORM_EPS, 1});
            segmentBuilder->addOp(OP_RMS_NORM, "rms_norm", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex), // in-place normalization
                size1D(F_32, dim),
                NnRmsNormOpConfig{invRmsBufferIndex, 1});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange — use only 2 of the N_BATCHES rows
            const NnUint batchSize = 2;
            execution->setBatchSize(batchSize);

            std::vector<float> normWeight(dim);
            for (NnUint i = 0; i < dim; i++)
                normWeight[i] = (0.25f + (float)i) / (float)dim;
            executor->loadWeight("rms_norm", 0u, 0u, normWeight.size() * sizeof(float), (NnByte *)normWeight.data());

            float *xPipe = (float *)execution->pipes[0];
            float expectedS[batchSize];
            for (NnUint b = 0; b < batchSize; b++) {
                float *xBatchPipe = &xPipe[b * dim];
                float s = 0.0f;
                for (NnUint i = 0; i < dim; i++) {
                    float u = (float)(dim - i + b) / (float)(dim / 2);
                    xBatchPipe[i] = u;
                    s += u * u;
                }
                s /= (float)dim;
                // expected inverse RMS: 1 / sqrt(mean(x^2) + eps)
                expectedS[b] = 1.0f / sqrtf(s + TEST_RMS_NORM_EPS);
            }

            // act
            executor->forward();

            // assert
            float invRmsBuffer[N_BATCHES];
            device->data.buffers[0].get()->read((NnByte *)invRmsBuffer);

            for (NnUint b = 0; b < batchSize; b++) {
                float *xBatchPipe = &xPipe[b * dim];

                const float t = 0.0000019f;
                assertFloat(b, invRmsBuffer[b], expectedS[b], t);
                // Compare against the GPU's own inv-RMS value so only the
                // normalization step (not fp accumulation order) is tested.
                const float s = invRmsBuffer[b];
                for (NnUint i = 0; i < dim; i++) {
                    float u = (float)(dim - i + b) / (float)(dim / 2);
                    assertFloat(b * dim + i, xBatchPipe[i], (u * s) * normWeight[i], t);
                }
            }
            printOk("testRmsNorm_F32_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_SILU applied in place on a pipe: every element must become
// v * sigmoid(v), compared against a CPU-computed expectation.
template <NnUint dim>
void testSilu_F32_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
            segmentBuilder->addOp(OP_SILU, "silu", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex), // in-place
                size0(),
                NnSiluOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float expectedOutput[dim * N_BATCHES];
            float *xPipe = (float *)execution->pipes[0];

            for (NnUint b = 0; b < N_BATCHES; b++) {
                const NnUint offset = b * dim;
                for (NnUint i = 0; i < dim; i++) {
                    const float v = i / (float)dim + (float)(b + 1);
                    xPipe[offset + i] = v;
                    // silu(v) = v * sigmoid(v) = v / (1 + e^-v)
                    expectedOutput[offset + i] = v / (1.0 + expf(-v));
                }
            }

            // act
            executor->forward();

            // assert
            float t = 0.00001f;
            for (NnUint b = 0; b < N_BATCHES; b++) {
                const NnUint offset = b * dim;
                for (NnUint i = 0; i < dim; i++)
                    assertFloat(offset + i, xPipe[offset + i], expectedOutput[offset + i], t);
            }

            printOk("testSilu_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_MUL: the pipe is multiplied element-wise (in place) by a node
// buffer "s" that is uploaded directly to the device before forward().
template <NnUint dim, NnUint nZ>
void testMul_F32_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
            NnUint sBufferIndex = nodeBuilder->addBuffer("s", size3D(F_32, nZ, N_BATCHES, dim));
            segmentBuilder->addOp(OP_MUL, "mul", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex), // in-place
                size0(),
                NnMulOpCodeConfig{sBufferIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float *xPipe = (float *)execution->pipes[0];
            float sBuffer[nZ * N_BATCHES * dim];
            for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++) {
                xPipe[i] = (float)i;
                sBuffer[i] = (i % 8) / 10.0f; // repeating multiplier pattern
            }

            device->data.buffers[0].get()->write((NnByte *)sBuffer);

            // act
            executor->forward();

            // assert
            for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
                assertFloat(i, xPipe[i], i * ((i % 8) / 10.0f), 0.000001f);
            printOk("testMul_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_MERGE_ADD (fp32 -> fp32): the Z pipe holds per-node slices of a
// row; the op must sum the node slices into the X pipe. With both node
// slices of batch b set to (b + 1), each output element must be 2 * (b + 1).
void testMergeAdd_F32_F32() {
#define MERGE_ADD_F32_NODES 2
#define MERGE_ADD_F32_DIM 64
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_32, N_BATCHES, MERGE_ADD_F32_DIM * MERGE_ADD_F32_NODES));
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MERGE_ADD_F32_DIM));
            segmentBuilder->addOp(OP_MERGE_ADD, "mergeAdd", 0,
                pointerBatchConfig(SRC_PIPE, zPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                size0(),
                NnMergeAddOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float *zPipe = (float *)execution->pipes[0];
            float *xPipe = (float *)execution->pipes[1];
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint n = 0; n < MERGE_ADD_F32_NODES; n++) {
                    for (NnUint i = 0; i < MERGE_ADD_F32_DIM; i++)
                        zPipe[b * MERGE_ADD_F32_NODES * MERGE_ADD_F32_DIM + n * MERGE_ADD_F32_DIM + i] = (float)(b + 1);
                }
            }

            // act
            executor->forward();

            // assert — sum over the 2 node slices: 2 * (b + 1)
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint i = 0; i < MERGE_ADD_F32_DIM; i++) {
                    NnUint j = b * MERGE_ADD_F32_DIM + i;
                    assertFloat(j, xPipe[j], (float)(2 * b + 2), 0.00001f);
                }
            }
            printOk("testMergeAdd_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_MERGE_SUM: the X pipe has nZ=2 depth slices per batch row; the
// op must reduce them (element-wise sum over z) into the single-slice Y pipe.
// Expected values below are the hand-computed sums of the two slices.
void testMergeSum_F32_F32() {
#define MERGE_SUM_F32_N_Z 2
#define MERGE_SUM_F32_DIM 4
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, MERGE_SUM_F32_N_Z, N_BATCHES, MERGE_SUM_F32_DIM));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_32, 1u, N_BATCHES, MERGE_SUM_F32_DIM));
            segmentBuilder->addOp(OP_MERGE_SUM, "merge_sum", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, yPipeIndex),
                size0(),
                NnMergeSumOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange — only 2 of the N_BATCHES rows are exercised
            execution->setBatchSize(2);

            float *xPipe = (float *)execution->pipes[0];
            float *xPipeZ0 = &xPipe[0];                               // depth slice z=0
            float *xPipeZ1 = &xPipe[N_BATCHES * MERGE_SUM_F32_DIM];   // depth slice z=1
            float *yPipe = (float *)execution->pipes[1];
            // batch 0, z=0
            xPipeZ0[0] = 1.0f;
            xPipeZ0[1] = 1.1f;
            xPipeZ0[2] = 1.2f;
            xPipeZ0[3] = 1.3f;

            // batch 1, z=0
            xPipeZ0[4] = 2.0f;
            xPipeZ0[5] = 2.1f;
            xPipeZ0[6] = 2.2f;
            xPipeZ0[7] = 2.3f;

            // batch 0, z=1
            xPipeZ1[0] = 0.5f;
            xPipeZ1[1] = 0.1f;
            xPipeZ1[2] = 0.2f;
            xPipeZ1[3] = 0.3f;

            // batch 1, z=1
            xPipeZ1[4] = 0.4f;
            xPipeZ1[5] = 0.3f;
            xPipeZ1[6] = 0.2f;
            xPipeZ1[7] = 0.1f;

            // act
            executor->forward();

            // assert — y = z0 + z1, element-wise
            const float t = 0.00001f;
            assertFloat(0, yPipe[0], 1.5f, t);
            assertFloat(1, yPipe[1], 1.2f, t);
            assertFloat(2, yPipe[2], 1.4f, t);
            assertFloat(3, yPipe[3], 1.6f, t);

            assertFloat(4, yPipe[4], 2.4f, t);
            assertFloat(5, yPipe[5], 2.4f, t);
            assertFloat(6, yPipe[6], 2.4f, t);
            assertFloat(7, yPipe[7], 2.4f, t);

            printOk("testMergeSum_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_MERGE_ADD with a Q8_0-quantized Z pipe: the fp32 fixture is
// quantized on the CPU into the pipe, and the op must dequantize and sum the
// per-node slices into the fp32 X pipe. The tolerance (0.001) absorbs the
// Q8_0 quantization error.
template <NnUint nNodes, NnUint dim>
static void testMergeAdd_Q80_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            const NnUint zPipeIndex = netBuilder->addPipe("Z", size2D(F_Q80, N_BATCHES, dim * nNodes));
            const NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
            segmentBuilder->addOp(OP_MERGE_ADD, "mergeAdd", 0,
                pointerBatchConfig(SRC_PIPE, zPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                size0(),
                NnMergeAddOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            // fp32 fixture: every element of node slice n in batch b is (b + 1)
            float z[N_BATCHES * dim * nNodes];
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint n = 0; n < nNodes; n++) {
                    for (NnUint i = 0; i < dim; i++)
                        z[b * nNodes * dim + n * dim + i] = (float)(b + 1);
                }
            }

            NnBlockQ80 *zPipe = (NnBlockQ80 *)execution->pipes[0];
            const float *xPipe = (float *)execution->pipes[1];
            quantizeF32toQ80(z, zPipe, N_BATCHES * dim * nNodes, 1, 0);

            // act
            executor->forward();

            // assert — sum over node slices: (b + 1) * nNodes
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint i = 0; i < dim; i++) {
                    float expectedValue = (float)((b + 1) * nNodes);
                    NnUint j = b * dim + i;
                    assertFloat(j, xPipe[j], expectedValue, 0.001f);
                }
            }
            printOk("testMergeAdd_Q80_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_EMBEDDING: the POS pipe supplies a token index per batch row
// (stored as float) and the op must copy row `pos` of the loaded embedding
// weight into the X pipe. Row l of the fixture is filled with (l + 4).
void testEmbedding_F32_F32() {
#define EMBEDDING_DIM 16
#define EMBEDDING_LEN 8
    assert(EMBEDDING_LEN > N_BATCHES); // positions 0..N_BATCHES-1 must be valid rows
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, EMBEDDING_DIM));
            segmentBuilder->addOp(OP_EMBEDDING, "embedding", 0,
                pointerBatchConfig(SRC_PIPE, posPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                size2D(F_32, EMBEDDING_LEN, EMBEDDING_DIM),
                NnEmbeddingOpConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float embedding[EMBEDDING_DIM * EMBEDDING_LEN];
            for (NnUint l = 0; l < EMBEDDING_LEN; l++) {
                for (NnUint i = 0; i < EMBEDDING_DIM; i++)
                    embedding[l * EMBEDDING_DIM + i] = (float)(l + 4);
            }
            float *posPipe = (float *)execution->pipes[0];
            for (NnUint b = 0; b < N_BATCHES; b++)
                posPipe[b] = (float)b; // batch b looks up embedding row b

            executor->loadWeight("embedding", 0u, 0u, EMBEDDING_DIM * EMBEDDING_LEN * sizeof(float), (NnByte *)embedding);

            // act
            executor->forward();

            // assert — batch b must contain row b's fill value (b + 4)
            float *xPipe = (float *)execution->pipes[1];
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint i = 0; i < EMBEDDING_DIM; i++) {
                    NnUint j = b * EMBEDDING_DIM + i;
                    assertFloat(j, xPipe[j], (float)(b + 4), 0.00001f);
                }
            }

            printOk("testEmbedding_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_SHIFT: each batch row of X is copied into the flat Y pipe at the
// slot selected by the POS pipe. With pos[b] = b the copy is an identity
// layout-wise, so Y must equal X row-for-row after forward().
template <NnUint dim>
void testShift_F32_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, dim));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, 1, N_BATCHES * dim));
            segmentBuilder->addOp(
                OP_SHIFT, "shift", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerRawConfig(SRC_PIPE, yPipeIndex), // raw: Y addressed as one flat row
                size0(),
                NnShiftOpCodeConfig{posPipeIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange — pipes: [0]=POS, [1]=X, [2]=Y
            execution->setBatchSize(N_BATCHES);
            float *xPipe = (float *)execution->pipes[1];
            float *yPipe = (float *)execution->pipes[2];

            float pos[N_BATCHES];
            for (NnUint b = 0; b < N_BATCHES; b++) {
                pos[b] = (float)b;
                for (NnUint i = 0; i < dim; i++)
                    xPipe[b * dim + i] = (float)(b * 100 + i); // unique per element
            }

            // upload positions straight to the device-side POS pipe
            device->data.pipes[0].get()->write((NnByte *)pos);

            // act
            executor->forward();

            // assert
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint i = 0; i < dim; i++) {
                    NnUint j = b * dim + i;
                    assertFloat(j, yPipe[j], (float)(b * 100 + i), 0.00001f);
                }
            }
            printOk("testShift_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_CAST with identical source and target types (F_32 -> F_32):
// the op degenerates to a copy, so Y must equal X exactly.
template <NnUint dim, NnUint nZ>
void testCast_F32_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_32, nZ, N_BATCHES, dim));
            segmentBuilder->addOp(
                OP_CAST, "cast", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, yPipeIndex),
                size0(),
                NnCastOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);
            float *xPipe = (float *)execution->pipes[0];
            float *yPipe = (float *)execution->pipes[1];

            for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
                xPipe[i] = (float)(i + 1);

            // act
            executor->forward();

            // assert — exact copy expected
            for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
                assertFloat(i, yPipe[i], (float)(i + 1), 0.00001f);
            printOk("testCast_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Checks OP_CAST from F_32 to F_Q80: the op must quantize X into the Q8_0
// Y pipe. The result is dequantized on the CPU and compared as a RELATIVE
// error (<= 0.9%), which is the expected Q8_0 precision per 32-value block.
template <NnUint dim, NnUint nZ>
void testCast_F32_Q80() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_Q80, nZ, N_BATCHES, dim));
            segmentBuilder->addOp(
                OP_CAST, "cast", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, yPipeIndex),
                size0(),
                NnCastOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);
            float *xPipe = (float *)execution->pipes[0];
            NnBlockQ80 *yPipe = (NnBlockQ80 *)execution->pipes[1];

            for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++)
                xPipe[i] = (float)(i + 1);

            // act
            executor->forward();

            // assert — round-trip through CPU dequantization
            float yF32[nZ * N_BATCHES * dim];
            dequantizeQ80toF32(yPipe, yF32, nZ * N_BATCHES * dim, 1, 0);

            for (NnUint i = 0; i < nZ * N_BATCHES * dim; i++) {
                const float expectedV = (float)(i + 1);
                const float change = (yF32[i] - expectedV) / expectedV; // relative error
                assertFloat(i, change, 0.0, 0.009f);
            }
            printOk("testCast_F32_Q80");
        });
}
|
||||||
|
|
||||||
|
// Generic RoPE kernel test. The op rotates pipe X in-place for two batches
// at positions 6 and 31; the ropeType-specific reference values are checked
// by the assertOutput callback (LLaMA vs Falcon element layouts differ).
template <NnRopeType ropeType, void (*assertOutput)(float *x0, float *x1)>
void testRope_F32_F32() {
#define ROPE_DIM 2048
#define ROPE_KV_DIM 512
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            const NnUint nHeads = 32;
            const NnUint seqLen = 4096;
            const NnRopeSlice slice = sliceRope(ropeType, ROPE_DIM, ROPE_KV_DIM, 8, 1, seqLen, ROPE_DIM / nHeads, 500000.0f, 0);

            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, ROPE_DIM));
            NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
            NnUint ropeCacheBufferIndex = nodeBuilder->addBuffer("ropeCache", slice.cacheSize);
            NnUint isQ = 1;

            // In-place op: input and output both point at pipe X.
            segmentBuilder->addOp(
                OP_ROPE, "rope_llama", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                size0(),
                NnRopeOpConfig{ropeType, isQ, posPipeIndex, ropeCacheBufferIndex, 32.0f, 1.0f, 4.0f, 8192, slice});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(2);

            float *xPipe = (float *)execution->pipes[0];
            float pos[N_BATCHES];
            pos[0] = (float)6;
            pos[1] = (float)31;

            // All-ones input makes the expected outputs pure functions of
            // the rotation angles (cos(a) +/- sin(a) terms).
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint i = 0; i < ROPE_DIM; i++)
                    xPipe[b * ROPE_DIM + i] = 1.0f;
            }

            // Upload positions straight into the device-side POS pipe.
            device->data.pipes[1].get()->write((NnByte *)pos);

            // act
            executor->forward();

            // assert: hand the two rotated batch rows to the checker.
            float *x0 = &xPipe[0 * ROPE_DIM];
            float *x1 = &xPipe[1 * ROPE_DIM];
            assertOutput(x0, x1);
        });
}
|
||||||
|
|
||||||
|
// Checks reference values produced by the LLaMA-style RoPE kernel for two
// batches (positions 6 and 31) with an all-ones input. The first argument
// of assertFloat is the element index reported on failure; it must match
// the element actually inspected (the last check previously reported 1078
// while inspecting x1[1079]).
void assertRopeLlama_F32_F32(float *x0, float *x1) {
    const float t = 0.000001f;

    assertFloat(0, x0[0], 1.239586f, t);
    assertFloat(1, x0[1], 0.680755f, t);
    assertFloat(2, x0[2], 0.077202f, t);
    assertFloat(3, x0[3], -1.412105f, t);
    assertFloat(1988, x0[1988], -1.356766f, t);
    assertFloat(2022, x0[2022], 0.999923f, t);
    assertFloat(2023, x0[2023], 1.000077f, t);

    assertFloat(0, x1[0], 1.318780f, t);
    assertFloat(1, x1[1], 0.510705f, t);
    assertFloat(1078, x1[1078], 0.999985f, t);
    assertFloat(1079, x1[1079], 1.000015f, t);
}
|
||||||
|
|
||||||
|
// Checks reference values produced by the Falcon-style (interleaved) RoPE
// kernel for two batches (positions 6 and 31) with an all-ones input.
// The first assertFloat argument is the index reported on failure; several
// labels previously did not match the inspected element (x0[1989] labeled
// 1988; x1[2] and x1[3] labeled 1; x1[1079] labeled 1078) — fixed here.
void assertRopeFalcon_F32_F32(float *x0, float *x1) {
    const float t = 0.000001f;

    assertFloat(0, x0[0], 1.239586f, t);
    assertFloat(1, x0[1], 0.077202f, t);
    assertFloat(2, x0[2], -1.356766f, t);
    assertFloat(3, x0[3], -1.164938f, t);
    assertFloat(1988, x0[1988], -0.522115f, t);
    assertFloat(1989, x0[1989], 0.018772f, t);
    assertFloat(2022, x0[2022], 1.361834f, t);
    assertFloat(2023, x0[2023], 1.276253f, t);

    assertFloat(0, x1[0], 1.318780f, t);
    assertFloat(1, x1[1], -1.139289f, t);
    assertFloat(2, x1[2], -0.417384f, t);
    assertFloat(3, x1[3], -1.291486f, t);
    assertFloat(1078, x1[1078], 1.003737f, t);
    assertFloat(1079, x1[1079], 1.002481f, t);
}
|
||||||
|
|
||||||
|
void testRopeLlama_F32_F32() {
|
||||||
|
testRope_F32_F32<NnRopeType::ROPE_LLAMA, assertRopeLlama_F32_F32>();
|
||||||
|
printOk("testRopeLlama_F32_F32");
|
||||||
|
}
|
||||||
|
|
||||||
|
void testRopeFalcon_F32_F32() {
|
||||||
|
testRope_F32_F32<NnRopeType::ROPE_FALCON, assertRopeFalcon_F32_F32>();
|
||||||
|
printOk("testRopeFalcon_F32_F32");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests the dense F32 x F32 matmul: y[b][d] = sum_n x[b][n] * W[d][n],
// with the (N, D) weight stored row-per-output (index d * N + n).
// GPU results are compared against a straightforward CPU accumulation.
template <NnUint N, NnUint D>
void testMatmul_F32_F32_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, N));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, D));
            // No MoE routing in this test: a 1-element dummy buffer stands
            // in for the expert-index buffer (nExperts = 0 below).
            NnUint nullBufferIndex = nodeBuilder->addBuffer("null", size1D(F_32, 1u));
            segmentBuilder->addOp(
                OP_MATMUL, "matmul", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, yPipeIndex),
                size2D(F_32, N, D),
                NnMatmulOpConfig{0u, 0u, nullBufferIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);
            float *xPipe = (float *)execution->pipes[0];
            float *yPipe = (float *)execution->pipes[1];

            // Small deterministic values keep the accumulated sum well
            // inside float precision for the tolerance used below.
            float weight[N * D];
            for (NnUint i = 0; i < N_BATCHES * N; i++)
                xPipe[i] = i * 0.0001f;
            for (NnUint i = 0; i < N * D; i++)
                weight[i] = i * 0.000001f;
            executor->loadWeight("matmul", 0u, 0u, N * D * sizeof(float), (NnByte *)weight);

            // act
            executor->forward();

            // assert: recompute each output on the CPU and compare.
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint d = 0; d < D; d++) {
                    float sum = 0.0f;
                    for (NnUint n = 0; n < N; n++)
                        sum += xPipe[b * N + n] * weight[d * N + n];

                    const NnUint p = b * D + d;
                    assertFloat(p, yPipe[p], sum, 0.0002f);
                }
            }
            printOk("testMatmul_F32_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Tests the MoE (expert-routed) variant of OP_MATMUL: each of the A active
// slots per batch selects one of E expert weight matrices via the index
// buffer, and y = W[expert] * x for that slot. Expert e's weights are the
// constant 0.1*(e+1), so each output is easy to predict by hand.
void testMatmul_F32_F32_F32_expert() {
#define MATMUL_F32_N 4
#define MATMUL_F32_D 1
#define MATMUL_F32_E 4
#define MATMUL_F32_A 2
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, MATMUL_F32_A, N_BATCHES, MATMUL_F32_N));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size3D(F_32, MATMUL_F32_A, N_BATCHES, MATMUL_F32_D));
            NnUint activeExpertIndexesIndex = nodeBuilder->addBuffer("indexes", size2D(F_32, N_BATCHES, MATMUL_F32_A));
            segmentBuilder->addOp(
                OP_MATMUL, "matmul", 0u,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, yPipeIndex),
                size3D(F_32, MATMUL_F32_E, MATMUL_F32_N, MATMUL_F32_D),
                NnMatmulOpConfig{MATMUL_F32_E, MATMUL_F32_A, activeExpertIndexesIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);
            float *xPipe = (float *)execution->pipes[0]; // (A, N_BATCHES, N)
            float *yPipe = (float *)execution->pipes[1]; // (A, N_BATCHES, D)

            constexpr NnUint wSize = MATMUL_F32_N * MATMUL_F32_D;
            constexpr NnUint wSizeBytes = wSize * sizeof(float);
            float weight[wSize];
            float indexes[N_BATCHES * MATMUL_F32_A];

            // Upload each expert's weights one slice at a time; expert e
            // gets the constant value 0.1 * (e + 1).
            for (NnUint e = 0u; e < MATMUL_F32_E; e++) {
                for (NnUint i = 0u; i < wSize; i++)
                    weight[i] = 0.1f * (float)(e + 1);
                executor->loadWeight("matmul", 0u, wSizeBytes * e, wSizeBytes, (NnByte *)weight);
            }

            for (NnUint i = 0u; i < MATMUL_F32_A * N_BATCHES; i++)
                indexes[i] = (float)(i % MATMUL_F32_E); // 0, 1, 2, 3, 0, 1, 2, 3, ...
            for (NnUint i = 0; i < MATMUL_F32_A * N_BATCHES * MATMUL_F32_N; i++)
                xPipe[i] = (float)(i / MATMUL_F32_N + 1); // 1.0, 1.0, ... 2.0, 2.0, ...

            device->data.buffers[0].get()->write((NnByte *)indexes);

            executor->forward();

            // assert: each expected value is expertWeight * (N * rowValue);
            // the inline comments track which index-buffer entry / expert
            // applies to each output slot.
            float t = 0.0001f;
            assertFloat(0, yPipe[0], 0.1f /* index=0, e=0 */ * (4 * 1.0f), t);
            assertFloat(1, yPipe[1], 0.3f /* index=2, e=2 */ * (4 * 2.0f), t);
            assertFloat(2, yPipe[2], 0.1f /* index=4, e=0 */ * (4 * 3.0f), t);
            assertFloat(3, yPipe[3], 0.3f /* index=6, e=2 */ * (4 * 4.0f), t);

            assertFloat(4, yPipe[4], 0.2f /* index=1, e=1 */ * (4 * 5.0f), t);
            assertFloat(5, yPipe[5], 0.4f /* index=3, e=3 */ * (4 * 6.0f), t);
            assertFloat(6, yPipe[6], 0.2f /* index=5, e=1 */ * (4 * 7.0f), t);
            assertFloat(7, yPipe[7], 0.4f /* index=7, e=3 */ * (4 * 8.0f), t);

            printOk("testMatmul_F32_F32_F32_expert");
        });
}
|
||||||
|
|
||||||
|
// Tests quantized matmul: Q80 activations x Q40 weights -> F32 output.
// Inputs and weights are built in F32, quantized on the host, and the GPU
// result is compared against the unquantized CPU reference using relative
// error (both quantizations are lossy).
template <NnUint N, NnUint D>
void testMatmul_Q80_Q40_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_Q80, N_BATCHES, N));
            NnUint yPipeIndex = netBuilder->addPipe("Y", size2D(F_32, N_BATCHES, D));
            // No MoE routing: dummy 1-element buffer for the index slot.
            NnUint nullBufferIndex = nodeBuilder->addBuffer("null", size1D(F_32, 1u));
            segmentBuilder->addOp(
                OP_MATMUL, "matmul", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, yPipeIndex),
                size2D(F_Q40, N, D),
                NnMatmulOpConfig{0u, 0u, nullBufferIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);
            NnBlockQ80 *xPipe = (NnBlockQ80 *)execution->pipes[0];
            float *yPipe = (float *)execution->pipes[1];

            constexpr NnUint xSize = N_BATCHES * N;
            constexpr NnUint weightSize = N * D;
            constexpr NnUint weightBlocks = weightSize / Q40_BLOCK_SIZE;

            // Heap-allocate: N and D can be large (e.g. 14336 x 4096).
            std::unique_ptr<float[]> x(new float[xSize]);
            std::unique_ptr<float[]> weight(new float[weightSize]);
            std::unique_ptr<NnBlockQ40[]> weightQ40(new NnBlockQ40[weightBlocks]);

            // Values cluster around 0.1 with a small spread, keeping the
            // quantization error within the tolerance below.
            for (NnUint i = 0; i < xSize; i++)
                x[i] = 0.1f + (i / (float)N - 0.5f) * 0.0005f;
            for (NnUint i = 0; i < weightSize; i++)
                weight[i] = 0.1f + (i / (float)D - 0.5f) * 0.0005f;

            quantizeF32toQ80(x.get(), xPipe, xSize, 1, 0);
            quantizeF32toQ40(weight.get(), weightQ40.get(), weightSize, 1, 0);

            executor->loadWeight("matmul", 0u, 0u, weightBlocks * sizeof(NnBlockQ40), (NnByte *)weightQ40.get());

            // act
            executor->forward();

            // assert: compare against the F32 CPU reference; use absolute
            // error when the reference is exactly zero to avoid dividing by it.
            for (NnUint b = 0; b < N_BATCHES; b++) {
                for (NnUint d = 0; d < D; d++) {
                    float sum = 0.0f;
                    for (NnUint n = 0; n < N; n++)
                        sum += x[b * N + n] * weight[d * N + n];
                    const NnUint p = b * D + d;

                    const float err = sum == 0.0 ? (yPipe[p] - sum) : (yPipe[p] - sum) / sum;
                    // printf("[%d] %f %f (%f)\n", b, yPipe[p], sum, err);
                    assertFloat(p, err, 0.0f, 0.009f);
                }
            }
            printOk("testMatmul_Q80_Q40_F32");
        });
}
|
||||||
|
|
||||||
|
void testMultiheadAtt_F32_F32() {
|
||||||
|
#define MULTIHEAD_ATT_DIM 128
|
||||||
|
execute(
|
||||||
|
[](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
|
||||||
|
const NnUint nHeads = 32;
|
||||||
|
const NnUint nKvHeads = 8;
|
||||||
|
const NnUint headDim = MULTIHEAD_ATT_DIM / nHeads;
|
||||||
|
const NnUint seqLen = 4096;
|
||||||
|
const NnUint qSliceD0 = 2048;
|
||||||
|
const NnUint kvDim0 = 512;
|
||||||
|
const NnKvCacheSlice kvCacheSlice = sliceKvCache(kvDim0, seqLen, 1);
|
||||||
|
const NnMultiHeadAttSlice multiHeadAttSlice = sliceMultiHeadAtt(nHeads, seqLen, 1, N_BATCHES);
|
||||||
|
|
||||||
|
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MULTIHEAD_ATT_DIM));
|
||||||
|
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
|
||||||
|
NnUint qBufferIndex = nodeBuilder->addBuffer("POS", size2D(F_32, N_BATCHES, qSliceD0));
|
||||||
|
NnUint kCacheBufferIndex = nodeBuilder->addBuffer("kCache", kvCacheSlice.keySize);
|
||||||
|
NnUint vCacheBufferIndex = nodeBuilder->addBuffer("vCache", kvCacheSlice.valueSize);
|
||||||
|
NnUint attCacheBufferIndex = nodeBuilder->addBuffer("vCache", multiHeadAttSlice.attSize);
|
||||||
|
|
||||||
|
segmentBuilder->addOp(
|
||||||
|
OP_MULTIHEAD_ATT, "multihead_att", 0,
|
||||||
|
pointerBatchConfig(SRC_PIPE, xPipeIndex),
|
||||||
|
pointerBatchConfig(SRC_PIPE, xPipeIndex),
|
||||||
|
size0(),
|
||||||
|
NnMultiHeadAttOpConfig{nHeads, nHeads, nKvHeads, headDim, seqLen, qSliceD0, kvDim0,
|
||||||
|
posPipeIndex, qBufferIndex, kCacheBufferIndex, vCacheBufferIndex, attCacheBufferIndex});
|
||||||
|
},
|
||||||
|
[](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
|
||||||
|
// TODO: for now this is a smoke test
|
||||||
|
execution->setBatchSize(N_BATCHES);
|
||||||
|
executor->forward();
|
||||||
|
printOk("testMultiheadAtt_F32_F32");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tests OP_SOFTMAX in-place on a (nZ, N_BATCHES, dim) pipe. Each row b of
// slice z is filled with the ramp i/dim + b, whose maximum is the last
// element.
// NOTE(review): the expected value is exp(v - max) with no division by the
// row sum — the kernel appears to be checked only up to the exponentiation
// stage; confirm against the softmax shader before relying on this.
template <NnUint dim, NnUint nZ>
void testSoftmax_F32_F32() {
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, nZ, N_BATCHES, dim));
            // In-place op: input and output are the same pipe.
            segmentBuilder->addOp(OP_SOFTMAX, "softmax", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                size0(),
                NnSoftmaxOpCodeConfig{});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float *xPipe = (float *)execution->pipes[0];

            // Fill each (z, b) row with a strictly increasing ramp.
            for (NnUint z = 0u; z < nZ; z++) {
                float *xPipeZ = &xPipe[z * N_BATCHES * dim];
                for (NnUint b = 0; b < N_BATCHES; b++) {
                    const NnUint offset = b * dim;
                    for (NnUint i = 0u; i < dim; i++)
                        xPipeZ[offset + i] = i / (float)dim + (float)b;
                }
            }

            // act
            executor->forward();

            // assert
            float t = 0.00001f;
            for (NnUint z = 0u; z < nZ; z++) {
                float *xPipeZ = &xPipe[z * N_BATCHES * dim];
                for (NnUint b = 0; b < N_BATCHES; b++) {
                    const NnUint offset = b * dim;
                    // The row maximum is its last ramp value.
                    float max = ((dim - 1) / (float)dim) + (float)b;
                    for (NnUint i = 0u; i < dim; i++) {
                        const float v = i / (float)dim + (float)b;
                        const float expectedOutput = expf(v - max);
                        assertFloat(offset + i, xPipeZ[offset + i], expectedOutput, t);
                    }
                }
            }

            printOk("testSoftmax_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Tests OP_SCALE: every element of row (z, y) in pipe X is multiplied by a
// single per-(z, batch) scalar taken from the scale buffer.
// NOTE(review): the scale buffer is declared as (N_Z, DIM, 1) but written
// and interpreted with stride N_BATCHES (scale[z * N_BATCHES + y]); this is
// self-consistent within the test but the declared shape looks off —
// confirm against the scale kernel's indexing.
void testScale_F32_F32() {
#define SCALE_F32_N_Z 4
#define SCALE_F32_DIM 64
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size3D(F_32, SCALE_F32_N_Z, N_BATCHES, SCALE_F32_DIM));
            NnUint scaleBufferIndex = nodeBuilder->addBuffer("scale", size3D(F_32, SCALE_F32_N_Z, SCALE_F32_DIM, 1u));
            // In-place op: input and output are the same pipe.
            segmentBuilder->addOp(OP_SCALE, "scale", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                size0(),
                NnScaleOpCodeConfig{scaleBufferIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float *xPipe = (float *)execution->pipes[0];
            float scale[SCALE_F32_N_Z * SCALE_F32_DIM];

            // Each (z, y) row gets a unique scale factor and unique inputs,
            // so cross-row mixups would fail the assert.
            for (NnUint z = 0u; z < SCALE_F32_N_Z; z++) {
                const NnUint zOffset = z * N_BATCHES * SCALE_F32_DIM;
                for (NnUint y = 0u; y < N_BATCHES; y++) {
                    scale[z * N_BATCHES + y] = 0.5f * (float)(z * 100 + y);
                    for (NnUint i = 0; i < SCALE_F32_DIM; i++)
                        xPipe[zOffset + y * SCALE_F32_DIM + i] = (float)(z * 10000 + y * 100 + i);
                }
            }

            device->data.buffers[0].get()->write((NnByte *)scale);

            // act
            executor->forward();

            // assert: output = input * scale(z, y)
            for (NnUint z = 0u; z < SCALE_F32_N_Z; z++) {
                const NnUint zOffset = z * N_BATCHES * SCALE_F32_DIM;
                for (NnUint y = 0u; y < N_BATCHES; y++) {
                    for (NnUint i = 0; i < SCALE_F32_DIM; i++) {
                        const NnUint p = zOffset + y * SCALE_F32_DIM + i;
                        const float v = xPipe[p];
                        const float expectedOutput = (float)(z * 10000 + y * 100 + i) * 0.5f * (float)(z * 100 + y);
                        assertFloat(p, v, expectedOutput, 0.00001f);
                    }
                }
            }

            printOk("testScale_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Tests OP_MOE_GATE: from an 8-wide logits row it selects the top-K (K=4)
// values into pipe g (shape (K, N_BATCHES, 1)) and writes the positions of
// those values into the index buffer. Input row 0 holds the values
// {3,1,6,5,8,4,2,7}, so top-4 values are 8,7,6,5 at positions 4,7,2,3.
// Only batch 0 is asserted.
void testMoeGate_F32_F32() {
#define MOE_GATE_F32_DIM 8
#define MOE_GATE_F32_K 4
    execute(
        [](NnNetConfigBuilder *netBuilder, NnNodeConfigBuilder *nodeBuilder, NnSegmentConfigBuilder *segmentBuilder) {
            NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, MOE_GATE_F32_DIM));
            NnUint gPipeIndex = netBuilder->addPipe("g", size3D(F_32, MOE_GATE_F32_K, N_BATCHES, 1u));
            NnUint indexesBufferIndex = nodeBuilder->addBuffer("i", size2D(F_32, N_BATCHES, MOE_GATE_F32_K));
            segmentBuilder->addOp(OP_MOE_GATE, "moe_gate", 0,
                pointerBatchConfig(SRC_PIPE, xPipeIndex),
                pointerBatchConfig(SRC_PIPE, gPipeIndex),
                size0(),
                NnMoeGateOpCodeConfig{MOE_GATE_F32_K, 0u, indexesBufferIndex});
        },
        [](NnExecutor *executor, NnNetExecution *execution, NnVulkanDevice *device) {
            // arrange
            execution->setBatchSize(N_BATCHES);

            float *xPipe = (float *)execution->pipes[0];
            float *gPipe = (float *)execution->pipes[1];
            xPipe[0] = 3.0f;
            xPipe[1] = 1.0f;
            xPipe[2] = 6.0f;
            xPipe[3] = 5.0f;
            xPipe[4] = 8.0f;
            xPipe[5] = 4.0f;
            xPipe[6] = 2.0f;
            xPipe[7] = 7.0f;

            // act
            executor->forward();

            // Read the selected expert positions back from the device.
            float pos[N_BATCHES * MOE_GATE_F32_K];
            device->data.buffers[0].get()->read((NnByte *)pos);

            // assert: g is laid out (K, N_BATCHES, 1), so batch 0's k-th
            // value sits at gPipe[k * N_BATCHES].
            const float t = 0.00001f;
            assertFloat(0, gPipe[0 * N_BATCHES], 8.0f, t);
            assertFloat(1, gPipe[1 * N_BATCHES], 7.0f, t);
            assertFloat(2, gPipe[2 * N_BATCHES], 6.0f, t);
            assertFloat(3, gPipe[3 * N_BATCHES], 5.0f, t);

            // Positions of the top-4 values in descending value order.
            assertFloat(100, pos[0], 4.0f, t);
            assertFloat(101, pos[1], 7.0f, t);
            assertFloat(102, pos[2], 2.0f, t);
            assertFloat(103, pos[3], 3.0f, t);

            printOk("testMoeGate_F32_F32");
        });
}
|
||||||
|
|
||||||
|
// Entry point: runs every Vulkan op test. Template arguments vary the
// tensor shapes to cover different chunking/alignment paths. A failing
// assert aborts the process, so reaching "return 0" means all passed.
int main() {
    // Initialize the quantization lookup tables used by Q80/Q40 helpers.
    initQuants();

    testRmsNorm_F32_F32_F32<4>();
    testRmsNorm_F32_F32_F32<1024>();
    testRmsNorm_F32_F32_F32<3196>();

    testSilu_F32_F32<4>();
    testSilu_F32_F32<32>();
    testSilu_F32_F32<104>();

    testMul_F32_F32<32, 1>();
    testMul_F32_F32<48, 4>();

    testMergeAdd_F32_F32();

    testMergeSum_F32_F32();

    testMergeAdd_Q80_F32<2, 64>();
    testMergeAdd_Q80_F32<4, 128>();
    testMergeAdd_Q80_F32<4, 160>();

    testEmbedding_F32_F32();

    testShift_F32_F32<32>();
    testShift_F32_F32<9>();

    testCast_F32_F32<128, 1>();
    testCast_F32_F32<32, 2>();
    testCast_F32_F32<8, 4>();

    testCast_F32_Q80<256, 1>();
    testCast_F32_Q80<64, 4>();

    testRopeLlama_F32_F32();
    testRopeFalcon_F32_F32();

    testMatmul_F32_F32_F32<64, 96>();
    testMatmul_F32_F32_F32<3191, 109>();

    testMatmul_F32_F32_F32_expert();

    // Shapes mirror real transformer FFN/attention matmuls.
    testMatmul_Q80_Q40_F32<14336, 4096>();
    testMatmul_Q80_Q40_F32<4096, 14336>();
    testMatmul_Q80_Q40_F32<4096, 4096>();
    testMatmul_Q80_Q40_F32<64, 48>();
    testMatmul_Q80_Q40_F32<64, 64>();
    testMatmul_Q80_Q40_F32<192, 16>();

    testMultiheadAtt_F32_F32();

    testSoftmax_F32_F32<256, 1>();
    testSoftmax_F32_F32<512, 4>();

    testScale_F32_F32();

    testMoeGate_F32_F32();
    return 0;
}
|
||||||
1163
src/nn/nn-vulkan.cpp
Normal file
1163
src/nn/nn-vulkan.cpp
Normal file
File diff suppressed because it is too large
Load Diff
173
src/nn/nn-vulkan.hpp
Normal file
173
src/nn/nn-vulkan.hpp
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
#ifndef NN_VULKAN_HPP
|
||||||
|
#define NN_VULKAN_HPP
|
||||||
|
|
||||||
|
#include <vulkan/vulkan.hpp>
|
||||||
|
#include <vector>
|
||||||
|
#include "nn-executor.hpp"
|
||||||
|
#include "nn-cpu-ops.hpp"
|
||||||
|
|
||||||
|
// Owns the process-wide Vulkan handles (instance, physical/logical device,
// compute queue and command pool) shared by every other Vulkan-side object.
class NnVulkanContext {
public:
    vk::Instance instance;
    vk::PhysicalDevice physicalDevice;
    vk::Device device;
    uint32_t queueFamilyIndex;
    vk::CommandPool commandPool;
    vk::Queue queue;
    // Presumably mirrors VkPhysicalDeviceLimits::nonCoherentAtomSize
    // (flush alignment for non-coherent host memory) — confirm in the .cpp.
    NnSize nonCoherentAtomSize;
    NnVulkanContext(const NnUint gpuIndex);
    ~NnVulkanContext();
    // Allocates a raw device buffer plus its backing memory; the caller
    // owns and must free both returned handles.
    std::pair<vk::Buffer, vk::DeviceMemory> createRawBuffer(const uint32_t memoryTypeIndex, const vk::DeviceSize bufferSize, const vk::BufferUsageFlags usageFlags);
};
|
||||||
|
|
||||||
|
// Direction of a staging transfer: host -> device or device -> host.
enum NnVulkanStagingCopierDirection {
    COPY_TO_DEVICE,
    COPY_FROM_DEVICE
};
|
||||||
|
|
||||||
|
// Reusable host-visible staging buffer used to shuttle bytes between host
// memory and device-local buffers in either direction.
class NnVulkanStagingCopier {
private:
    NnVulkanContext *context;
    uint32_t memoryTypeIndex;

    vk::DeviceSize allocatedSize;
    vk::Buffer hostBuffer;
    vk::DeviceMemory hostMemory;
    void *hostPointer; // mapped address of hostMemory
public:
    NnVulkanStagingCopier(NnVulkanContext *context);
    ~NnVulkanStagingCopier();
    // Ensures the staging buffer can hold at least `size` bytes.
    void allocate(const NnSize size);
    // Copies between `data` and the staging buffer (direction decides which way).
    void copy(NnByte *data, const NnSize size, const NnVulkanStagingCopierDirection direction);
    // Records and submits a copy between the staging buffer and `target` immediately.
    void executeCopyCommand(vk::Buffer& target, const NnSize offset, const NnSize size, const NnVulkanStagingCopierDirection direction);
    // Records the same copy into an externally managed command buffer.
    void addCopyCommand(vk::CommandBuffer& commandBuffer, vk::Buffer& target, const NnSize offset, const NnSize size, const NnVulkanStagingCopierDirection direction);
    // Releases the staging allocation if one is held.
    void tryRelease();
};
|
||||||
|
|
||||||
|
// A device buffer with read/write helpers. Host-visible buffers are
// accessed through a mapped pointer; otherwise transfers go through the
// shared staging copier.
class NnVulkanBuffer {
private:
    bool isHostVisible;
    NnVulkanContext *context;
    NnVulkanStagingCopier *copier; // used when the buffer is not host-visible
    vk::DeviceMemory deviceMemory;
    NnByte *hostPointer; // mapped address when host-visible, else unused
public:
    const char *name; // debug label
    NnSize bufferSize;
    bool isSliceable;
    vk::Buffer deviceBuffer;
    vk::BufferUsageFlags usageFlags;
    NnVulkanBuffer(NnVulkanContext *context, NnVulkanStagingCopier *copier, const char *name, const NnSize bufferSize, const bool isSliceable, vk::BufferUsageFlags usageFlags, bool fastAccess);
    ~NnVulkanBuffer();
    // Writes the whole buffer / a sub-range from host memory.
    void write(const NnByte *data);
    void write(const NnByte *data, const NnSize offset, const NnSize size);
    // Reads the whole buffer / a sub-range into host memory.
    void read(NnByte *data);
    void read(NnByte *data, const NnSize offset, const NnSize size);
    // Size of the (nominator/denominator) fraction of this buffer,
    // used when the buffer is sliced across nodes.
    NnSize calcSliceSize(const NnSize nominator, const NnSize denominator);
};
|
||||||
|
|
||||||
|
// Creates NnVulkanBuffer instances wired to the shared context and
// staging copier.
class NnVulkanBufferFactory {
private:
    NnVulkanContext *context;
    NnVulkanStagingCopier *copier;
public:
    NnVulkanBufferFactory(NnVulkanContext *context, NnVulkanStagingCopier *copier);
    std::unique_ptr<NnVulkanBuffer> create(const char *name, const NnSize bufferSize, const bool isSliceable, vk::BufferUsageFlags usageFlags, bool fastAccess);
};
|
||||||
|
|
||||||
|
// Per-(z, batch) element offsets and row widths passed to compute shaders;
// must stay layout-compatible with the BatchInfo struct declared in the
// .comp shaders.
typedef struct {
    NnUint inputOffset;
    NnUint inputSizeX;
    NnUint outputOffset;
    NnUint outputSizeX;
} NnVulkanBatchInfo;
|
||||||
|
|
||||||
|
// Device-side storage for one node: Vulkan buffers backing the network's
// pipes and the node's buffers, plus helpers that resolve pointer configs
// to concrete buffers and offsets.
class NnVulkanDeviceData {
private:
    NnNetConfig *netConfig;
    NnNodeConfig *nodeConfig;
public:
    std::vector<std::unique_ptr<NnVulkanBuffer>> pipes;
    std::vector<std::unique_ptr<NnVulkanBuffer>> buffers;
    std::vector<std::unique_ptr<NnVulkanBuffer>> internalBuffers;
    NnVulkanDeviceData(NnVulkanBufferFactory *bufferFactory, NnNetConfig *netConfig, NnNodeConfig *nodeConfig);
    ~NnVulkanDeviceData();

    // Size of the pipe/buffer a pointer config refers to.
    NnSize3D resolveBufferSize(NnPointerConfig *config);
    // The concrete Vulkan buffer behind a pointer config.
    NnVulkanBuffer *resolvePointerVulkanBuffer(NnPointerConfig *config);
    // Element offset of (batchIndex, zIndex) within the resolved buffer.
    NnUint resolveBufferBatchOffset(NnPointerConfig *config, NnUint batchIndex, NnUint zIndex);
    // Row width (elements per batch) of the resolved buffer.
    NnUint resolveBufferBatchWidth(NnPointerConfig *config);
    NnVulkanBuffer *resolvePipeByIndex(NnUint pipeIndex);
    NnVulkanBuffer *resolveBufferByIndex(NnUint bufferIndex);
};
|
||||||
|
|
||||||
|
// NnDevice implementation backed by Vulkan: owns the context, the staging
// copier, the buffer factory and the per-node device data, and creates
// NnVulkanDeviceSegment instances for execution.
class NnVulkanDevice : public NnDevice {
private:
    NnVulkanContext context;
    NnVulkanStagingCopier copier;
    NnVulkanBufferFactory bufferFactory;
    NnNetConfig *netConfig;
    NnNodeConfig *nodeConfig;
    NnNetExecution *netExecution;
public:
    NnVulkanDeviceData data;
    NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeConfig *nodeConfig, NnNetExecution *netExecution);
    ~NnVulkanDevice() override;
    NnUint maxNThreads() override;
    NnDeviceSegment *createSegment(NnUint segmentIndex) override;
};
|
||||||
|
|
||||||
|
// Per-segment lookup tables mapping each op to its batch-info, weight and
// config buffers (stored as indexes into the device data's internal
// buffer list).
class NnVulkanDeviceSegmentData {
private:
    NnVulkanDeviceData *data;
    std::vector<NnUint> batchInfoBufferIndex;
    std::vector<NnUint> weightBufferIndex;
    std::vector<NnUint> configBufferIndex;
public:
    NnVulkanDeviceSegmentData(NnVulkanBufferFactory *bufferFactory, NnVulkanDeviceData *data, NnSegmentConfig *segmentConfig, NnUint nBatches);
    NnVulkanBuffer *resolveOpBatchInfoVulkanBuffer(NnUint opIndex);
    NnVulkanBuffer *resolveOpWeightVulkanBuffer(NnUint opIndex);
    NnVulkanBuffer *resolveOpConfigVulkanBuffer(NnUint opIndex);
};
|
||||||
|
|
||||||
|
// How an op touches a buffer; presumably drives which pipeline barriers
// are inserted between ops — confirm in the segment implementation.
enum NnOpBufferAccessType {
    ACCESS_IMMUTABLE,
    ACCESS_READONLY,
    ACCESS_READ_WRITE,
};
|
||||||
|
|
||||||
|
// A buffer together with the access type an op uses it with.
typedef struct {
    NnOpBufferAccessType type;
    NnVulkanBuffer *buffer;
} NnOpBufferAccess;
|
||||||
|
|
||||||
|
// Executable Vulkan segment: compiles one compute pipeline per op
// (shader module, descriptor set, pipeline layout), records them into a
// command buffer with the required synchronization, and submits it on
// forward().
class NnVulkanDeviceSegment : public NnDeviceSegment {
private:
    NnVulkanContext *context;
    NnVulkanStagingCopier *copier;
    NnVulkanDeviceData *data;
    NnNetConfig *netConfig;
    NnUint segmentIndex;
    NnSegmentConfig *segmentConfig;
    NnNetExecution *netExecution;
    std::unique_ptr<NnVulkanDeviceSegmentData> segmentData;

    // One entry per op in the segment.
    std::vector<vk::ShaderModule> shaderModules;
    std::vector<vk::DescriptorSetLayout> descriptorSetLayouts;
    vk::DescriptorPool descriptorPool;
    std::vector<vk::DescriptorSet> descriptorSets;
    vk::Fence fence; // signaled when a submitted command buffer completes
    std::vector<vk::PipelineLayout> pipelineLayouts;
    std::vector<vk::Pipeline> pipelines;
    vk::PipelineCache pipelineCache;
    vk::CommandBuffer commandBuffer;
    // Buffers that need a barrier after each op, per op.
    std::vector<std::vector<NnVulkanBuffer *>> buffersToSync;
    // Batch size the command buffer was last recorded for; presumably
    // triggers re-recording when it changes — confirm in the .cpp.
    NnUint lastBatchSize;
public:
    NnVulkanDeviceSegment(NnVulkanContext *context, NnVulkanStagingCopier *copier, NnVulkanBufferFactory *bufferFactory, NnVulkanDeviceData *data, NnNetConfig *netConfig, NnUint segmentIndex, NnSegmentConfig *segmentConfig, NnNetExecution *netExecution);
    ~NnVulkanDeviceSegment() override;
    void loadWeight(NnUint opIndex, NnSize offset, NnSize nBytes, NnByte *weight) override;
    void forward(NnUint opIndex, NnUint nThreads, NnUint threadIndex, NnUint batchSize) override;
};
|
||||||
|
|
||||||
|
#endif
|
||||||
39
src/nn/pthread.h
Normal file
39
src/nn/pthread.h
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
// Minimal pthread compatibility shim: on Windows it maps pthread_create /
// pthread_join onto the Win32 thread API; elsewhere it simply includes the
// real <pthread.h>.
#ifndef PTHREAD_WRAPPER
#define PTHREAD_WRAPPER

#ifdef _WIN32
#include <windows.h>
#include <errno.h> // for EAGAIN, previously used below without being included

typedef HANDLE PthreadHandler;
typedef DWORD PthreadResult;
typedef DWORD (WINAPI *PthreadFunc)(void *);

// Mirrors pthread_create(3): starts func(arg) on a new thread.
// Returns 0 on success, EAGAIN when the thread cannot be created.
// The second parameter (thread attributes) is accepted but ignored.
static int pthread_create(PthreadHandler *out, void *unused, PthreadFunc func, void *arg) {
    (void) unused;
    PthreadHandler handle = CreateThread(NULL, 0, func, arg, 0, NULL);
    if (handle == NULL) {
        return EAGAIN;
    }
    *out = handle;
    return 0;
}

// Mirrors pthread_join(3): blocks until the thread exits, then releases
// its handle. The exit-value output parameter is accepted but ignored.
static int pthread_join(PthreadHandler thread, void *unused) {
    (void) unused;
    DWORD ret = WaitForSingleObject(thread, INFINITE);
    if (ret == WAIT_FAILED) {
        return -1;
    }
    CloseHandle(thread);
    return 0;
}
#else
#include <pthread.h>

typedef pthread_t PthreadHandler;
typedef void* PthreadResult;
typedef void* (*PthreadFunc)(void *);

#endif

#endif
|
||||||
37
src/nn/vulkan/cast-forward-f32-f32.comp
Normal file
37
src/nn/vulkan/cast-forward-f32-f32.comp
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
#version 450
// cast-forward-f32-f32: element-wise copy from an F32 input pipe to an F32
// output pipe. Dispatch layout: x = chunk of CHUNK_SIZE elements within a
// row, y = batch index, z = Z-slice index.

#extension GL_EXT_control_flow_attributes : enable

// Number of consecutive elements copied by one workgroup.
#define CHUNK_SIZE 4

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

// Specialization constants fixed at pipeline creation time.
layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;

// Per-(z, batch) offsets/widths; must match NnVulkanBatchInfo on the host.
struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };

void main() {
    // Which (z, batch) row this workgroup handles.
    const uint batchIndex = gl_WorkGroupID.y;
    const uint zIndex = gl_WorkGroupID.z;
    const uint b = zIndex * N_BATCHES + batchIndex;

    // Which CHUNK_SIZE-wide slice of that row to copy.
    const uint chunkIndex = gl_WorkGroupID.x;
    const BatchInfo info = infos[b];
    const uint offset = chunkIndex * CHUNK_SIZE;
    const uint xOffset = info.inputOffset + offset;
    const uint yOffset = info.outputOffset + offset;

    [[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
        y[yOffset + i] = x[xOffset + i];
    }
}
|
||||||
56
src/nn/vulkan/cast-forward-f32-q80.comp
Normal file
56
src/nn/vulkan/cast-forward-f32-q80.comp
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_shader_8bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_16bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : enable
|
||||||
|
|
||||||
|
#define Q80_BLOCK_SIZE 32
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N_Z = 1;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset; // number of Q80 blocks
|
||||||
|
uint outputSizeX; // number of Q80 blocks
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BlockQ80 {
|
||||||
|
float16_t d;
|
||||||
|
int8_t qs[Q80_BLOCK_SIZE];
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { BlockQ80 y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint zIndex = gl_WorkGroupID.z;
|
||||||
|
const uint b = zIndex * N_BATCHES + batchIndex;
|
||||||
|
|
||||||
|
const uint d = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
const BatchInfo info = infos[b];
|
||||||
|
const uint xiOffset = info.inputOffset + d * Q80_BLOCK_SIZE;
|
||||||
|
const uint yiOffset = info.outputOffset + d;
|
||||||
|
|
||||||
|
float amax = 0.0;
|
||||||
|
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
|
||||||
|
amax = max(amax, abs(x[xiOffset + j]));
|
||||||
|
}
|
||||||
|
|
||||||
|
const float dd = amax / 127.0f;
|
||||||
|
const float id = dd != 0.0f ? 1.0f / dd : 0.0f;
|
||||||
|
|
||||||
|
y[yiOffset].d = float16_t(dd);
|
||||||
|
|
||||||
|
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
|
||||||
|
const float v = x[xiOffset + j];
|
||||||
|
y[yiOffset].qs[j] = int8_t(clamp(round(v * id), -127.0f, 127.0f));
|
||||||
|
}
|
||||||
|
}
|
||||||
38
src/nn/vulkan/embedding-forward-f32-f32.comp
Normal file
38
src/nn/vulkan/embedding-forward-f32-f32.comp
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#define CHUNK_SIZE 4
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
|
||||||
|
|
||||||
|
shared uint sharedPosition;
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint chunkIndex = gl_WorkGroupID.x;
|
||||||
|
const uint position = uint(x[batchIndex]);
|
||||||
|
|
||||||
|
const BatchInfo info = infos[batchIndex];
|
||||||
|
const uint offset = chunkIndex * CHUNK_SIZE;
|
||||||
|
const uint yOffset = info.outputOffset + offset;
|
||||||
|
const uint wOffset = position * info.outputSizeX + offset;
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
|
||||||
|
y[yOffset + i] = weight[wOffset + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
55
src/nn/vulkan/inv-rms-forward-f32-f32.comp
Normal file
55
src/nn/vulkan/inv-rms-forward-f32-f32.comp
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#define N_THREADS 256
|
||||||
|
|
||||||
|
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly uniform opConfigBuffer {
|
||||||
|
float epsilon;
|
||||||
|
uint nColumns;
|
||||||
|
};
|
||||||
|
|
||||||
|
shared float sums[N_THREADS];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint threadIndex = gl_LocalInvocationID.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint colIndex = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
const BatchInfo info = infos[batchIndex];
|
||||||
|
const uint dim = info.inputSizeX / nColumns;
|
||||||
|
const uint offset = info.inputOffset + dim * colIndex;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (uint i = threadIndex; i < dim; i += N_THREADS) {
|
||||||
|
float v = x[offset + i];
|
||||||
|
sum += v * v;
|
||||||
|
}
|
||||||
|
sums[threadIndex] = sum;
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
|
||||||
|
if (threadIndex < i)
|
||||||
|
sums[threadIndex] += sums[threadIndex + i];
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIndex == 0) {
|
||||||
|
y[batchIndex * nColumns + colIndex] = inversesqrt((sums[0] / float(dim)) + epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
60
src/nn/vulkan/matmul-forward-f32-f32-f32.comp
Normal file
60
src/nn/vulkan/matmul-forward-f32-f32-f32.comp
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#define N_THREADS 128
|
||||||
|
|
||||||
|
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N_Z = 1;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
|
||||||
|
layout(binding = 4) readonly uniform opConfigBuffer {
|
||||||
|
uint nExperts;
|
||||||
|
uint nActiveExperts;
|
||||||
|
uint activeExpertIndexesBufferIndex;
|
||||||
|
};
|
||||||
|
layout(binding = 5) readonly buffer activeExpertIndexesBuffer { float activeExpertIndexes[]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint threadIndex = gl_LocalInvocationID.x;
|
||||||
|
const uint nWorkGroups = gl_NumWorkGroups.x;
|
||||||
|
const uint workGroupIndex = gl_WorkGroupID.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint zIndex = gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
BatchInfo info = infos[zIndex * N_BATCHES + batchIndex];
|
||||||
|
const uint slice = info.outputSizeX / nWorkGroups;
|
||||||
|
const uint rest = info.outputSizeX % nWorkGroups;
|
||||||
|
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
|
||||||
|
const uint offset = workGroupIndex * slice + min(workGroupIndex, rest);
|
||||||
|
|
||||||
|
const uint expertIndex = nExperts == 0
|
||||||
|
? 0
|
||||||
|
: uint(activeExpertIndexes[batchIndex * nActiveExperts + zIndex]);
|
||||||
|
const uint expertOffset = expertIndex * info.inputSizeX * info.outputSizeX;
|
||||||
|
|
||||||
|
const uint inputSizeX = info.inputSizeX;
|
||||||
|
const uint xOffset = info.inputOffset;
|
||||||
|
const uint yOffset = info.outputOffset;
|
||||||
|
|
||||||
|
for (uint i = threadIndex; i < dim; i += N_THREADS) {
|
||||||
|
const uint d = offset + i;
|
||||||
|
const uint wOffset = expertOffset + d * inputSizeX;
|
||||||
|
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (uint j = 0; j < inputSizeX; j++) {
|
||||||
|
sum += x[xOffset + j] * weight[wOffset + j];
|
||||||
|
}
|
||||||
|
y[yOffset + d] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
125
src/nn/vulkan/matmul-forward-q80-q40-f32.comp
Normal file
125
src/nn/vulkan/matmul-forward-q80-q40-f32.comp
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_shader_8bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_16bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : enable
|
||||||
|
|
||||||
|
#define TILE_SIZE_X 2
|
||||||
|
#define TILE_SIZE_D 8
|
||||||
|
|
||||||
|
#define Q80_Q40_BLOCK_SIZE 32
|
||||||
|
|
||||||
|
layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N_Z = 1;
|
||||||
|
layout(constant_id = 2) const uint N_THREADS = 32;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BlockQ80 {
|
||||||
|
float16_t d;
|
||||||
|
int8_t qs[Q80_Q40_BLOCK_SIZE];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BlockQ40 {
|
||||||
|
float16_t d;
|
||||||
|
uint8_t qs[Q80_Q40_BLOCK_SIZE / 2];
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly buffer weightBuffer { BlockQ40 weight[]; };
|
||||||
|
layout(binding = 4) readonly uniform opConfigBuffer {
|
||||||
|
uint nExperts;
|
||||||
|
uint nActiveExperts;
|
||||||
|
uint activeExpertIndexesBufferIndex;
|
||||||
|
};
|
||||||
|
layout(binding = 5) readonly buffer activeExpertIndexesBuffer { float activeExpertIndexes[]; };
|
||||||
|
|
||||||
|
shared float sums[N_THREADS * TILE_SIZE_D];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint threadIndex = gl_LocalInvocationID.x;
|
||||||
|
const uint workGroupIndex = gl_WorkGroupID.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint zIndex = gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
const uint b = zIndex * N_BATCHES + batchIndex;
|
||||||
|
const BatchInfo info = infos[b];
|
||||||
|
|
||||||
|
const uint expertIndex = nExperts == 0
|
||||||
|
? 0
|
||||||
|
: uint(activeExpertIndexes[batchIndex * nActiveExperts + zIndex]);
|
||||||
|
const uint expertOffset = expertIndex * info.inputSizeX * info.outputSizeX;
|
||||||
|
|
||||||
|
const uint inputOffset = info.inputOffset;
|
||||||
|
const uint inputSizeX = info.inputSizeX;
|
||||||
|
const uint d = TILE_SIZE_D * workGroupIndex;
|
||||||
|
|
||||||
|
vec4 xTemp[Q80_Q40_BLOCK_SIZE / 4];
|
||||||
|
|
||||||
|
for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
|
||||||
|
sums[threadIndex * TILE_SIZE_D + dt] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint it = 0; it < TILE_SIZE_X; it++) {
|
||||||
|
const uint xi = inputOffset + threadIndex + it * N_THREADS;
|
||||||
|
const float xScale = float(x[xi].d);
|
||||||
|
[[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
|
||||||
|
xTemp[j] = vec4(
|
||||||
|
x[xi].qs[j * 2],
|
||||||
|
x[xi].qs[j * 2 + Q80_Q40_BLOCK_SIZE / 2],
|
||||||
|
x[xi].qs[j * 2 + 1],
|
||||||
|
x[xi].qs[j * 2 + 1 + Q80_Q40_BLOCK_SIZE / 2]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
|
||||||
|
const uint wi = expertOffset + (d + dt) * inputSizeX + threadIndex + it * N_THREADS;
|
||||||
|
const BlockQ40 wBlock = weight[wi];
|
||||||
|
|
||||||
|
float s = 0.0f;
|
||||||
|
[[unroll]] for (uint j = 0; j < Q80_Q40_BLOCK_SIZE / 4; j++) {
|
||||||
|
uint w0 = wBlock.qs[j * 2];
|
||||||
|
uint w1 = wBlock.qs[j * 2 + 1];
|
||||||
|
s += dot(xTemp[j], vec4(
|
||||||
|
int(w0 & 0xFu) - 8,
|
||||||
|
int(w0 >> 4) - 8,
|
||||||
|
int(w1 & 0xFu) - 8,
|
||||||
|
int(w1 >> 4) - 8
|
||||||
|
));
|
||||||
|
}
|
||||||
|
sums[threadIndex * TILE_SIZE_D + dt] += s * xScale * wBlock.d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
const uint outputOffset = infos[b].outputOffset; // Hoisting fix for Raspberry PI
|
||||||
|
|
||||||
|
uint i = N_THREADS;
|
||||||
|
while (i % 2 == 0) {
|
||||||
|
i >>= 1;
|
||||||
|
for (uint dt = 0; dt < TILE_SIZE_D; dt++) {
|
||||||
|
if (threadIndex < i) {
|
||||||
|
sums[threadIndex * TILE_SIZE_D + dt] += sums[(threadIndex + i) * TILE_SIZE_D + dt];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
for (uint dt = threadIndex; dt < TILE_SIZE_D; dt += N_THREADS) {
|
||||||
|
float s = 0.0;
|
||||||
|
for (uint j = 1; j <= i; j++) {
|
||||||
|
s += sums[(j - 1) * TILE_SIZE_D + dt];
|
||||||
|
}
|
||||||
|
y[outputOffset + d + dt] = float(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
46
src/nn/vulkan/merge-add-forward-f32-f32.comp
Normal file
46
src/nn/vulkan/merge-add-forward-f32-f32.comp
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#define N_THREADS 256
|
||||||
|
|
||||||
|
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint threadIndex = gl_LocalInvocationID.x;
|
||||||
|
const uint nWorkGroups = gl_NumWorkGroups.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint workGroupIndex = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
const BatchInfo info = infos[batchIndex];
|
||||||
|
const uint slice = info.outputSizeX / nWorkGroups;
|
||||||
|
const uint rest = info.outputSizeX % nWorkGroups;
|
||||||
|
const uint offset = workGroupIndex * slice + min(rest, workGroupIndex);
|
||||||
|
|
||||||
|
const uint dim = slice + (workGroupIndex < rest ? 1 : 0);
|
||||||
|
const uint outputSizeX = info.outputSizeX;
|
||||||
|
const uint parts = info.inputSizeX / info.outputSizeX;
|
||||||
|
const uint xOffset = info.inputOffset + offset;
|
||||||
|
const uint yOffset = info.outputOffset + offset;
|
||||||
|
|
||||||
|
for (uint i = threadIndex; i < dim; i += N_THREADS) {
|
||||||
|
float sum = 0.0;
|
||||||
|
const uint iOffset = xOffset + i;
|
||||||
|
const uint oOffset = yOffset + i;
|
||||||
|
for (uint n = 0; n < parts; n++) {
|
||||||
|
sum += x[n * outputSizeX + iOffset];
|
||||||
|
}
|
||||||
|
y[oOffset] += sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
57
src/nn/vulkan/merge-add-forward-q80-f32.comp
Normal file
57
src/nn/vulkan/merge-add-forward-q80-f32.comp
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_shader_8bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_16bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : enable
|
||||||
|
|
||||||
|
#define Q80_BLOCK_SIZE 32
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset; // number of Q80 blocks
|
||||||
|
uint inputSizeX; // number of Q80 blocks
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BlockQ80 {
|
||||||
|
float16_t d;
|
||||||
|
int8_t qs[Q80_BLOCK_SIZE];
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { BlockQ80 x[]; };
|
||||||
|
layout(binding = 1) buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint i = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
const BatchInfo info = infos[batchIndex];
|
||||||
|
const uint nBlocks = info.outputSizeX / Q80_BLOCK_SIZE; // 128
|
||||||
|
const uint nSlices = info.inputSizeX / nBlocks;
|
||||||
|
float16_t sums[Q80_BLOCK_SIZE];
|
||||||
|
|
||||||
|
const uint xiOffset = info.inputOffset + i;
|
||||||
|
const uint yiOffset = info.outputOffset + i * Q80_BLOCK_SIZE;
|
||||||
|
|
||||||
|
[[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
|
||||||
|
sums[k] = float16_t(0.0f);
|
||||||
|
}
|
||||||
|
for (uint n = 0; n < nSlices; n++) {
|
||||||
|
const BlockQ80 b = x[xiOffset + n * nBlocks];
|
||||||
|
const float16_t d = b.d;
|
||||||
|
|
||||||
|
[[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
|
||||||
|
sums[k] += float16_t(b.qs[k]) * d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint k = 0; k < Q80_BLOCK_SIZE; k++) {
|
||||||
|
y[yiOffset + k] += float(sums[k]);
|
||||||
|
}
|
||||||
|
}
|
||||||
41
src/nn/vulkan/merge-sum-forward-f32-f32.comp
Normal file
41
src/nn/vulkan/merge-sum-forward-f32-f32.comp
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#define CHUNK_SIZE 4
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N_Z = 1;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint workGroupIndex = gl_WorkGroupID.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
|
||||||
|
const uint chunkOffset = workGroupIndex * CHUNK_SIZE;
|
||||||
|
|
||||||
|
const BatchInfo outputInfo = infos[batchIndex];
|
||||||
|
const uint yOffset = outputInfo.outputOffset + chunkOffset;
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (uint z = 0; z < N_Z; z++) {
|
||||||
|
const BatchInfo inputInfo = infos[batchIndex + z * N_BATCHES];
|
||||||
|
const uint xOffset = inputInfo.inputOffset + chunkOffset;
|
||||||
|
sum += x[xOffset + i];
|
||||||
|
}
|
||||||
|
y[yOffset + i] = sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
69
src/nn/vulkan/moe-gate-forward-f32-f32.comp
Normal file
69
src/nn/vulkan/moe-gate-forward-f32-f32.comp
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N = 16;
|
||||||
|
layout(constant_id = 2) const uint K = 2;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[K * N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly uniform opConfigBuffer {
|
||||||
|
uint k;
|
||||||
|
uint normTopk;
|
||||||
|
uint indexesBufferIndex;
|
||||||
|
};
|
||||||
|
layout(binding = 4) buffer indexesBuffer { float indexes[]; };
|
||||||
|
|
||||||
|
shared float topVals[K];
|
||||||
|
shared uint topIdx[K];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
// TODO: this impl is not optimal
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
BatchInfo info = infos[batchIndex];
|
||||||
|
|
||||||
|
for (uint i = 0; i < K; i++) {
|
||||||
|
topVals[i] = -1e10f;
|
||||||
|
topIdx[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint i = 0; i < info.inputSizeX; i++) {
|
||||||
|
float v = x[info.inputOffset + i];
|
||||||
|
|
||||||
|
for (uint k = 0; k < K; k++) {
|
||||||
|
if (v > topVals[k]) {
|
||||||
|
for (uint s = K - 1; s > k; s--) {
|
||||||
|
topVals[s] = topVals[s - 1];
|
||||||
|
topIdx[s] = topIdx[s - 1];
|
||||||
|
}
|
||||||
|
topVals[k] = v;
|
||||||
|
topIdx[k] = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
float sum = 1.0f;
|
||||||
|
if (normTopk == 1) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (uint k = 0; k < K; k++) {
|
||||||
|
sum += topVals[k];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint k = 0; k < K; k++) {
|
||||||
|
indexes[batchIndex * K + k] = float(topIdx[k]);
|
||||||
|
y[infos[k * N_BATCHES + batchIndex].outputOffset] = topVals[k] / sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
41
src/nn/vulkan/mul-forward-f32-f32.comp
Normal file
41
src/nn/vulkan/mul-forward-f32-f32.comp
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#define CHUNK_SIZE 4
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N_Z = 1;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly uniform configBuffer {
|
||||||
|
uint multiplierBufferIndex;
|
||||||
|
};
|
||||||
|
layout(binding = 4) readonly buffer multiplierBuffer { float m[]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint chunkIndex = gl_WorkGroupID.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint zIndex = gl_WorkGroupID.z;
|
||||||
|
const uint b = zIndex * N_BATCHES + batchIndex;
|
||||||
|
|
||||||
|
const BatchInfo info = infos[b];
|
||||||
|
const uint chunkOffset = chunkIndex * CHUNK_SIZE;
|
||||||
|
const uint xyOffset = info.inputOffset + chunkOffset;
|
||||||
|
const uint mOffset = info.inputSizeX * b + chunkOffset;
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
|
||||||
|
y[xyOffset + i] = x[xyOffset + i] * m[mOffset + i];
|
||||||
|
}
|
||||||
|
}
|
||||||
127
src/nn/vulkan/multi-head-att-forward-f32-f32.comp
Normal file
127
src/nn/vulkan/multi-head-att-forward-f32-f32.comp
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
|
#define N_THREADS 256
|
||||||
|
|
||||||
|
layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
|
||||||
|
layout(binding = 3) readonly uniform configBuffer {
|
||||||
|
uint nHeads;
|
||||||
|
uint nHeads0;
|
||||||
|
uint nKvHeads;
|
||||||
|
uint headDim;
|
||||||
|
uint seqLen;
|
||||||
|
uint qSliceD0;
|
||||||
|
uint kvDim0;
|
||||||
|
uint positionPipeIndex;
|
||||||
|
uint queryBufferIndex;
|
||||||
|
uint keyCacheBufferIndex;
|
||||||
|
uint valueCacheBufferIndex;
|
||||||
|
uint attBufferIndex;
|
||||||
|
};
|
||||||
|
layout(binding = 4) readonly buffer positionsBuffer { float positions[]; };
|
||||||
|
layout(binding = 5) readonly buffer queryBuffer { float query[]; };
|
||||||
|
layout(binding = 6) readonly buffer keyCacheBuffer { float keyCache[]; };
|
||||||
|
layout(binding = 7) readonly buffer valueCacheBuffer { float valueCache[]; };
|
||||||
|
layout(binding = 8) buffer attBufferBuffer { float att[]; };
|
||||||
|
|
||||||
|
shared uint sharedPosition;
|
||||||
|
shared float sharedMaxScore;
|
||||||
|
shared float temp[N_THREADS];
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint threadIndex = gl_LocalInvocationID.x;
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint h = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
if (threadIndex == 0) {
|
||||||
|
sharedPosition = uint(positions[batchIndex]);
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
const uint kvMul = nHeads / nKvHeads;
|
||||||
|
const uint headIndex = h / kvMul;
|
||||||
|
const float invHeadDimRoot = 1.0f / sqrt(float(headDim));
|
||||||
|
|
||||||
|
const BatchInfo info = infos[batchIndex];
|
||||||
|
const uint position = sharedPosition;
|
||||||
|
|
||||||
|
const uint attOffset = batchIndex * nHeads0 * seqLen + h * seqLen;
|
||||||
|
const uint qOffset = batchIndex * qSliceD0 + h * headDim;
|
||||||
|
const uint kvOffset = headIndex * headDim;
|
||||||
|
const uint yOffset = info.outputOffset + h * headDim;
|
||||||
|
|
||||||
|
float ms = -1e10f;
|
||||||
|
for (uint p = threadIndex; p <= position; p += N_THREADS) {
|
||||||
|
const uint kOffset = kvOffset + p * kvDim0;
|
||||||
|
|
||||||
|
float score = 0.0f;
|
||||||
|
for (uint i = 0; i < headDim; i++) {
|
||||||
|
score += query[qOffset + i] * keyCache[kOffset + i];
|
||||||
|
}
|
||||||
|
score *= invHeadDimRoot;
|
||||||
|
ms = max(ms, score);
|
||||||
|
att[attOffset + p] = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
temp[threadIndex] = ms;
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
|
||||||
|
if (threadIndex < i)
|
||||||
|
temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (threadIndex == 0) {
|
||||||
|
sharedMaxScore = temp[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
const float maxScore = sharedMaxScore;
|
||||||
|
|
||||||
|
float s = 0.0f;
|
||||||
|
for (uint p = threadIndex; p <= position; p += N_THREADS) {
|
||||||
|
float v = exp(att[attOffset + p] - maxScore);
|
||||||
|
att[attOffset + p] = v;
|
||||||
|
s += v;
|
||||||
|
}
|
||||||
|
|
||||||
|
temp[threadIndex] = s;
|
||||||
|
barrier();
|
||||||
|
|
||||||
|
[[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
|
||||||
|
if (threadIndex < i)
|
||||||
|
temp[threadIndex] += temp[threadIndex + i];
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
|
||||||
|
const float yScale = 1.0f / temp[0];
|
||||||
|
|
||||||
|
for (uint i = threadIndex; i < headDim; i += N_THREADS) {
|
||||||
|
float sum = 0.0f;
|
||||||
|
const uint vOffset = kvOffset + i;
|
||||||
|
for (uint p = 0; p <= position; p += 1) {
|
||||||
|
const float a = att[attOffset + p];
|
||||||
|
const float v = valueCache[vOffset + p * kvDim0];
|
||||||
|
sum += v * a;
|
||||||
|
}
|
||||||
|
y[yOffset + i] = sum * yScale;
|
||||||
|
}
|
||||||
|
}
|
||||||
60
src/nn/vulkan/repeatz-forward-f32-q80.comp
Normal file
60
src/nn/vulkan/repeatz-forward-f32-q80.comp
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
#extension GL_EXT_shader_8bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_16bit_storage : enable
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : enable
|
||||||
|
|
||||||
|
#define Q80_BLOCK_SIZE 32
|
||||||
|
|
||||||
|
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout(constant_id = 0) const uint N_BATCHES = 32;
|
||||||
|
layout(constant_id = 1) const uint N_Z = 1;
|
||||||
|
|
||||||
|
struct BatchInfo {
|
||||||
|
uint inputOffset;
|
||||||
|
uint inputSizeX;
|
||||||
|
uint outputOffset;
|
||||||
|
uint outputSizeX;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct BlockQ80 {
|
||||||
|
float16_t d;
|
||||||
|
int8_t qs[Q80_BLOCK_SIZE];
|
||||||
|
};
|
||||||
|
|
||||||
|
layout(binding = 0) readonly buffer inputBuffer { float x[]; };
|
||||||
|
layout(binding = 1) writeonly buffer outputBuffer { BlockQ80 y[]; };
|
||||||
|
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint batchIndex = gl_WorkGroupID.y;
|
||||||
|
const uint d = gl_WorkGroupID.x;
|
||||||
|
|
||||||
|
const BatchInfo info = infos[batchIndex];
|
||||||
|
const uint xiOffset = info.inputOffset + d * Q80_BLOCK_SIZE;
|
||||||
|
|
||||||
|
float amax = 0.0;
|
||||||
|
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
|
||||||
|
amax = max(amax, abs(x[xiOffset + j]));
|
||||||
|
}
|
||||||
|
|
||||||
|
const float dd = amax / 127.0f;
|
||||||
|
const float id = dd != 0.0f ? 1.0f / dd : 0.0f;
|
||||||
|
|
||||||
|
const float16_t dd16 = float16_t(dd);
|
||||||
|
for (uint z = 0; z < N_Z; z++) {
|
||||||
|
const uint yiOffset = infos[batchIndex + z * N_BATCHES].outputOffset + d;
|
||||||
|
y[yiOffset].d = dd16;
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint j = 0; j < Q80_BLOCK_SIZE; ++j) {
|
||||||
|
const float v = x[xiOffset + j];
|
||||||
|
const int8_t v8 = int8_t(clamp(round(v * id), -127.0f, 127.0f));
|
||||||
|
for (uint z = 0; z < N_Z; z++) {
|
||||||
|
const uint yiOffset = infos[batchIndex + z * N_BATCHES].outputOffset + d;
|
||||||
|
y[yiOffset].qs[j] = v8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
44
src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp
Normal file
44
src/nn/vulkan/rms-norm-forward-f32-f32-f32.comp
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
#version 450

#extension GL_EXT_control_flow_attributes : enable

#define CHUNK_SIZE 4

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly buffer weightBuffer { float weight[]; };
layout(binding = 4) readonly uniform configBuffer {
    uint invRmsBufferIndex; // not used
    uint nColumns;
};
layout(binding = 5) readonly buffer invRmsBuffer { float invRms[]; };

// RMS-norm forward: y = x * invRms[column] * weight[column-local index].
// One workgroup handles a CHUNK_SIZE slice (X axis) of one batch row (Y axis);
// each row is split into nColumns independent columns, each with its own
// precomputed inverse RMS.
void main() {
    const uint batch = gl_WorkGroupID.y;
    const uint chunk = gl_WorkGroupID.x;

    const BatchInfo batchInfo = infos[batch];
    const uint columnDim = batchInfo.inputSizeX / nColumns;
    const uint elemOffset = chunk * CHUNK_SIZE;
    // The whole chunk lies inside a single column, so one scale suffices.
    const float rmsScale = invRms[batch * nColumns + elemOffset / columnDim];

    const uint srcBase = batchInfo.inputOffset + elemOffset;
    const uint dstBase = batchInfo.outputOffset + elemOffset;

    [[unroll]] for (uint k = 0; k < CHUNK_SIZE; k++) {
        const float normalized = x[srcBase + k] * rmsScale;
        y[dstBase + k] = normalized * weight[(elemOffset + k) % columnDim];
    }
}
|
||||||
115
src/nn/vulkan/rope-forward-f32-f32.comp
Normal file
115
src/nn/vulkan/rope-forward-f32-f32.comp
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
#version 450

// RoPE (rotary position embedding) forward pass.
// One workgroup of N_THREADS handles one batch row (Y axis). Two layouts are
// supported: Llama-style interleaved (even, odd) pairs (ropeType 0 or 2) and
// Falcon-style half-split pairs within each head (ropeType 1). The cos/sin
// rotation factors are precomputed in ropeCache.

#define N_THREADS 256

layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

struct RopeSlice {
    uint qDim0;
    uint qDimStart;
    uint qDimEnd;
    uint qShift;
    uint kvDim;
    uint kvDim0;
    uint kvDimStart;
    uint sliceDim;
    uint seqLen;
    uint headDim;
    uint nKvHeads;
    float ropeTheta;
    // NnSize2D cacheSize;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly uniform configBuffer {
    uint ropeType;
    uint isQ;
    uint positionPipeIndex;
    uint ropeCacheBufferIndex;
    float ropeScalingFactor;
    float ropeScalingLowFreqFactor;
    float ropeScalingHighFreqFactor;
    uint ropeScalingOrigMaxSeqLen;
    RopeSlice slice;
};
layout(binding = 4) readonly buffer positionsBuffer { float positions[]; };
layout(binding = 5) readonly buffer ropeCacheBuffer { float ropeCache[]; };

shared uint sharedPosition;

void main() {
    const uint threadIndex = gl_LocalInvocationID.x;
    const uint batchIndex = gl_WorkGroupID.y;

    // Thread 0 loads this row's token position; the barrier publishes it to
    // the rest of the workgroup via shared memory.
    if (threadIndex == 0) {
        sharedPosition = uint(positions[batchIndex]);
    }

    barrier();

    const uint position = sharedPosition;
    const BatchInfo info = infos[batchIndex];

    const uint xOffset = info.inputOffset;
    const uint yOffset = info.outputOffset;
    // Number of elements this node rotates: query slice vs key/value slice.
    const uint dim0 = isQ == 1 ? slice.qDim0 : slice.kvDim0;

    if (ropeType == 0 || ropeType == 2 /* Llama */) {
        uint posOffset = position * slice.sliceDim;
        if (isQ == 1) {
            // Q values sit after the KV part of the cache row.
            posOffset += slice.qShift;
        }
        const uint dim0Half = dim0 / 2;

        // Rotate each adjacent (even, odd) pair by the cached (cos, sin).
        for (uint i = threadIndex; i < dim0Half; i += N_THREADS) {
            const uint j = i * 2;
            const uint c = posOffset + j;

            float fcr = ropeCache[c];
            float fci = ropeCache[c + 1];
            float v0 = x[xOffset + j];
            float v1 = x[xOffset + j + 1];

            const float x0 = fma(-v1, fci, v0 * fcr);
            const float x1 = fma( v0, fci, v1 * fcr);

            y[yOffset + j] = x0;
            y[yOffset + j + 1] = x1;
        }
    } else if (ropeType == 1 /* Falcon */) {
        const uint posOffset = position * slice.headDim;
        const uint headDim = slice.headDim;
        const uint headDimHalf = headDim / 2;
        const uint nHeads0 = dim0 / headDim;

        // Falcon layout: within each head, element i pairs with element
        // i + headDim/2 instead of its immediate neighbor.
        for (uint h = 0; h < nHeads0; h++) {
            const uint o = h * headDim;

            for (uint i = threadIndex; i < headDimHalf; i += N_THREADS) {
                const uint c = posOffset + i;
                float fcr = ropeCache[c];
                float fci = ropeCache[c + headDimHalf];

                float v0 = x[xOffset + o + i];
                float v1 = x[xOffset + o + i + headDimHalf];

                float x0 = v0 * fcr - v1 * fci;
                float x1 = v0 * fci + v1 * fcr;

                y[yOffset + o + i] = x0;
                y[yOffset + o + i + headDimHalf] = x1;
            }
        }
    }
}
|
||||||
43
src/nn/vulkan/scale-forward-f32-f32.comp
Normal file
43
src/nn/vulkan/scale-forward-f32-f32.comp
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#version 450

#extension GL_EXT_control_flow_attributes : enable

#define CHUNK_SIZE 4

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };
layout(binding = 3) readonly buffer configBuffer {
    uint scaleBufferIndex;
};
layout(binding = 4) readonly buffer scaleBuffer { float scale[]; };

// Multiplies a CHUNK_SIZE slice (X axis) of each (z, batch) row by that
// row's scalar factor from scaleBuffer.
void main() {
    const uint rowIndex = gl_WorkGroupID.z * N_BATCHES + gl_WorkGroupID.y;
    const uint elemOffset = gl_WorkGroupID.x * CHUNK_SIZE;

    const float factor = scale[rowIndex];
    const BatchInfo row = infos[rowIndex];
    const uint srcBase = row.inputOffset + elemOffset;
    const uint dstBase = row.outputOffset + elemOffset;

    [[unroll]] for (uint k = 0; k < CHUNK_SIZE; k++) {
        y[dstBase + k] = x[srcBase + k] * factor;
    }
}
|
||||||
40
src/nn/vulkan/shift-forward-f32-f32.comp
Normal file
40
src/nn/vulkan/shift-forward-f32-f32.comp
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
#version 450

#extension GL_EXT_control_flow_attributes : enable

#define CHUNK_SIZE 4

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_BATCHES]; };
layout(binding = 3) readonly uniform configBuffer {
    uint indexPipeIndex;
};
layout(binding = 4) readonly buffer indexBuffer { float indexes[]; };

// Copies a CHUNK_SIZE slice of each batch row into an output slot selected
// per batch by indexBuffer (e.g. writing a row into a KV-cache position).
// One workgroup handles one chunk (X axis) of one batch row (Y axis).
void main() {
    const uint batchIndex = gl_WorkGroupID.y;
    const uint chunkIndex = gl_WorkGroupID.x;

    // Destination slot for this batch row (positions arrive as floats).
    const uint index = uint(indexes[batchIndex]);

    const BatchInfo info = infos[batchIndex];
    const uint offset = chunkIndex * CHUNK_SIZE;
    const uint xOffset = info.inputOffset + offset; // fixed: stray ';;' removed
    const uint yOffset = index * info.inputSizeX + offset;

    [[unroll]] for (uint i = 0; i < CHUNK_SIZE; i++) {
        y[yOffset + i] = x[xOffset + i];
    }
}
|
||||||
38
src/nn/vulkan/silu-forward-f32-f32.comp
Normal file
38
src/nn/vulkan/silu-forward-f32-f32.comp
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
#version 450

#extension GL_EXT_control_flow_attributes : enable

#define CHUNK_SIZE 4

layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };

// SiLU activation: y = x * sigmoid(x), applied to a CHUNK_SIZE slice
// (X axis) of each (z, batch) row.
void main() {
    const uint rowIndex = gl_WorkGroupID.z * N_BATCHES + gl_WorkGroupID.y;
    const uint elemOffset = gl_WorkGroupID.x * CHUNK_SIZE;

    const BatchInfo row = infos[rowIndex];
    const uint srcBase = row.inputOffset + elemOffset;
    const uint dstBase = row.outputOffset + elemOffset;

    [[unroll]] for (uint k = 0; k < CHUNK_SIZE; k++) {
        const float value = x[srcBase + k];
        y[dstBase + k] = value / (1.0f + exp(-value));
    }
}
|
||||||
56
src/nn/vulkan/softmax-forward-f32-f32.comp
Normal file
56
src/nn/vulkan/softmax-forward-f32-f32.comp
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
#version 450

#extension GL_EXT_control_flow_attributes : enable

#define N_THREADS 256

layout(local_size_x = N_THREADS, local_size_y = 1, local_size_z = 1) in;

layout(constant_id = 0) const uint N_BATCHES = 32;
layout(constant_id = 1) const uint N_Z = 1;

struct BatchInfo {
    uint inputOffset;
    uint inputSizeX;
    uint outputOffset;
    uint outputSizeX;
};

layout(binding = 0) readonly buffer inputBuffer { float x[]; };
layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
layout(binding = 2) readonly uniform batchInfosBuffer { BatchInfo infos[N_Z * N_BATCHES]; };

shared float temp[N_THREADS];

// NOTE(review): despite the file name, this kernel only produces the
// numerically-stable numerator exp(x - max(x)); the 1/sum(exp) normalization
// is presumably applied by a later pass — confirm against the op pipeline.
void main() {
    const uint threadIndex = gl_LocalInvocationID.x;
    const uint batchIndex = gl_WorkGroupID.y;
    const uint zIndex = gl_WorkGroupID.z;
    const uint b = zIndex * N_BATCHES + batchIndex;

    const BatchInfo info = infos[b];
    const uint size = info.inputSizeX;
    const uint xOffset = info.inputOffset;
    const uint yOffset = info.outputOffset;

    // Per-thread partial max over a strided slice of the row.
    float m = -1e10f;
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        m = max(m, x[xOffset + i]);
    }

    temp[threadIndex] = m;
    barrier();

    // Tree reduction of the per-thread maxima in shared memory.
    [[unroll]] for (uint i = N_THREADS / 2; i > 0; i >>= 1) {
        if (threadIndex < i)
            temp[threadIndex] = max(temp[threadIndex], temp[threadIndex + i]);
        barrier();
    }

    barrier(); // NOTE(review): looks redundant — the loop above ends with a barrier.

    const float maxVal = temp[0];
    for (uint i = threadIndex; i < size; i += N_THREADS) {
        y[yOffset + i] = exp(x[xOffset + i] - maxVal);
    }
}
|
||||||
319
src/tokenizer-test.cpp
Normal file
319
src/tokenizer-test.cpp
Normal file
@@ -0,0 +1,319 @@
|
|||||||
|
#include <cassert>
|
||||||
|
#include <cstring>
|
||||||
|
#include "tokenizer.hpp"
|
||||||
|
|
||||||
|
#define DEV_TESTS false
|
||||||
|
|
||||||
|
// Compares two ints; on mismatch prints both values plus the source location
// and terminates the test binary.
// Fixed: wrapped in do/while(0) so the macro behaves as a single statement
// (safe in unbraced if/else), and the arguments are parenthesized so
// expressions like ASSERT_EQ(a | b, c) expand correctly.
#define ASSERT_EQ(a, b) \
    do { \
        if ((a) != (b)) { \
            printf("Assertion failed: %d != %d (%s:%d)\n", (a), (b), __FILE__, __LINE__); \
            exit(-1); \
        } \
    } while (0)
|
||||||
|
|
||||||
|
#define TEST_EOS_ID 10000
|
||||||
|
|
||||||
|
// Prints a green "passed" line for a named test case (name right-padded to 24).
void printOk(const char *name) {
    printf("✅ %24s passed\n", name);
}
|
||||||
|
|
||||||
|
// Compares an expected token array `a` (length aN) against an actual array
// `b` (length bN); on any mismatch dumps both arrays and exits with code 1,
// otherwise reports the named case as passed.
void compare(const char *name, const int *a, const int *b, const unsigned int aN, const int bN) {
    bool matches = (aN == bN);
    if (matches) {
        for (unsigned int idx = 0; idx < bN; idx++) {
            if (a[idx] != b[idx]) {
                matches = false;
                break;
            }
        }
    }
    if (!matches) {
        printf("❌ %24s failed\na: ", name);
        for (unsigned int j = 0; j < aN; j++)
            printf("%5d ", a[j]);
        printf("\nb: ");
        for (unsigned int j = 0; j < bN; j++)
            printf("%5d ", b[j]);
        printf("\n");
        exit(1);
    }
    printOk(name);
}
|
||||||
|
|
||||||
|
// Dev-only: checks encode() against expected Llama 3 token ids for a chat
// prompt, pure punctuation, and multi-byte emoji input. Requires a locally
// downloaded model tokenizer (see DEV_TESTS in main).
void dev_testEncode(Tokenizer *tokenizer) {
    int tokens[1024];
    int nTokens;

    {
        const char *text = "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
        const int expectedTokens[] = {128000, 128006, 882, 128007, 271, 15339, 128009, 128006, 78191, 128007, 271};

        tokenizer->encode((char *)text, tokens, &nTokens, true, true);
        compare("case0", expectedTokens, tokens, 11, nTokens);
    }
    {
        const char *text = "!!&&@(*x)^^!";
        const int expectedTokens[] = {128000, 3001, 7827, 31, 4163, 87, 8, 22634, 0};

        tokenizer->encode((char *)text, tokens, &nTokens, true, true);
        compare("case1", expectedTokens, tokens, 9, nTokens);
    }
    {
        const char *text = "😃!😇x";
        const int expectedTokens[] = {128000, 76460, 225, 0, 76460, 229, 87};

        tokenizer->encode((char *)text, tokens, &nTokens, true, true);
        compare("case2", expectedTokens, tokens, 7, nTokens);
    }
}
|
||||||
|
|
||||||
|
// Dev-only: feeds a broken UTF-8 token stream (an emoji prefix repeated before
// its continuation byte arrives) and expects the streaming decoder to recover,
// emitting U+FFFD for the abandoned prefix followed by the completed emoji.
void dev_testDecoderEmojiStreamRecover(Tokenizer *tokenizer) {
    char *x0 = tokenizer->decode(128000); // BOS — no visible text yet
    assert(x0 == nullptr);

    char *x1 = tokenizer->decode(76460); // first bytes of the emoji, incomplete
    assert(x1 == nullptr);

    char *x2 = tokenizer->decode(76460); // sequence restarts mid-stream
    assert(x2 == nullptr);

    char *x3 = tokenizer->decode(225); // completes the second sequence
    assert(strcmp(x3, "�😃") == 0); // U+FFFD for the broken prefix + recovered emoji

    printOk("testDecoderEmojiStreamRecover");
}
|
||||||
|
|
||||||
|
// Dev-only: the streaming decoder must buffer partial UTF-8 sequences
// (returning nullptr) and emit complete pieces once all bytes have arrived.
void dev_testDecoderEmoji(Tokenizer *tokenizer) {
    char *x0 = tokenizer->decode(128000); // BOS — no visible text
    assert(x0 == nullptr);

    char *x1 = tokenizer->decode(76460); // emoji prefix, incomplete
    assert(x1 == nullptr);

    char *x2 = tokenizer->decode(225); // continuation completes the emoji
    assert(strcmp(x2, "😃") == 0);

    char *x3 = tokenizer->decode(0); // plain ASCII tokens pass straight through
    assert(strcmp(x3, "!") == 0);

    char *x4 = tokenizer->decode(56);
    assert(strcmp(x4, "Y") == 0);

    printOk("testDecoderEmoji");
}
|
||||||
|
|
||||||
|
// Dev-only: after emitting a completed emoji the decoder must suppress the
// EOS token's text rather than emitting it as a visible piece.
void dev_testDecoderEmojiWithEos(Tokenizer *tokenizer) {
    char *x0 = tokenizer->decode(128000); // BOS — no visible text
    assert(x0 == nullptr);

    char *x1 = tokenizer->decode(76460); // emoji prefix, incomplete
    assert(x1 == nullptr);

    char *x2 = tokenizer->decode(225); // continuation completes the emoji
    assert(strcmp(x2, "😃") == 0);

    char *x3 = tokenizer->decode(128001);
    assert(x3 == nullptr); // piece should not contain <|end_of_text|>

    printOk("decoderEmojiWithEos");
}
|
||||||
|
|
||||||
|
// Verifies that a raw Llama 3 Jinja chat template string is auto-detected as
// TEMPLATE_LLAMA3 even when the generator is constructed with TEMPLATE_UNKNOWN.
void testChatTemplateDetection() {
    ChatTemplateGenerator t0(TEMPLATE_UNKNOWN, "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}", "<eos>");
    assert(t0.type == TEMPLATE_LLAMA3);

    printOk("chatTemplateDetection");
}
|
||||||
|
|
||||||
|
// EosDetector configured with two stop pieces and 1 byte of padding allowed
// on each side: verifies the EOS / MAYBE_EOS / NOT_EOS state transitions and
// the delta text exposed to the caller in each case.
void testEosDetectorWithPadding() {
    const int tokens[2] = {TEST_EOS_ID, TEST_EOS_ID + 1};
    const char *pieces[2] = { "<eos>", "<stop>" };
    EosDetector detector(2, tokens, pieces, 1, 1);

    // "<eos>"
    {
        ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
        ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS);
        ASSERT_EQ(detector.append(3, "s>"), EOS);
        assert(detector.getDelta() == nullptr);
    }

    // "<stop> "
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
        ASSERT_EQ(detector.append(2, "stop"), MAYBE_EOS);
        ASSERT_EQ(detector.append(3, "> "), EOS); // trailing space fits in the padding
        assert(detector.getDelta() == nullptr);
    }

    // " "
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, " "), NOT_EOS);

        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, " ") == 0);
    }

    // "!<eos> "
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "!<"), MAYBE_EOS);
        ASSERT_EQ(detector.append(2, "eos"), MAYBE_EOS);
        ASSERT_EQ(detector.append(3, "> "), EOS);

        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, "!") == 0); // only the leading padding byte is surfaced
    }

    // "!<eos> "   NOTE(review): comment looks stale — this case actually feeds "<eos>XY"
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "<eo"), MAYBE_EOS);
        ASSERT_EQ(detector.append(2, "s>XY"), NOT_EOS); // "XY" exceeds the 1-byte padding

        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, "<eos>XY") == 0);
    }

    // "<eo" + EOS
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "<eo"), MAYBE_EOS);
        ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS); // EOS by token id alone

        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, "<eo") == 0); // the held-back prefix is flushed
    }

    // EOS
    detector.reset();
    {
        ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
        assert(detector.getDelta() == nullptr);
    }

    // after reset it's expected to return nullptr delta if to the append() passed nullptr piece
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "x"), NOT_EOS);
        char *delta0 = detector.getDelta();
        assert(std::strcmp(delta0, "x") == 0);

        detector.reset();

        ASSERT_EQ(detector.append(2, nullptr), NOT_EOS);
        char *delta1 = detector.getDelta();
        assert(delta1 == nullptr);
    }

    printOk("eosDetectorWithPadding");
}
|
||||||
|
|
||||||
|
// EosDetector with 5 bytes of padding on each side of the stop piece:
// text longer than piece + padding must resolve to NOT_EOS and be passed
// through verbatim in the delta.
void testEosDetectorWithLongPadding() {
    const int tokens[1] = {TEST_EOS_ID};
    const char *pieces[1] = { "|end|" };
    EosDetector detector(1, tokens, pieces, 5, 5);

    // "lipsum"
    {
        ASSERT_EQ(detector.append(1, "lipsum"), NOT_EOS);
        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, "lipsum") == 0);
    }

    // "lorem"
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "lorem"), NOT_EOS);
        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, "lorem") == 0);
    }

    // "lorem|enQ"
    detector.reset();
    {
        ASSERT_EQ(detector.append(1, "lorem|"), MAYBE_EOS); // '|' could start "|end|"
        ASSERT_EQ(detector.append(2, "enQ"), NOT_EOS);      // 'Q' breaks the match
        char *delta = detector.getDelta();
        assert(delta != nullptr);
        assert(std::strcmp(delta, "lorem|enQ") == 0);
    }

    printOk("eosDetectorWithLongPadding");
}
|
||||||
|
|
||||||
|
void testEosDetectorWithoutPadding() {
|
||||||
|
const int tokens[1] = {TEST_EOS_ID};
|
||||||
|
const char *pieces[1] = { "<eos>" };
|
||||||
|
EosDetector detector(1, tokens, pieces, 0, 0);
|
||||||
|
|
||||||
|
// "<eos>"
|
||||||
|
{
|
||||||
|
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
|
||||||
|
ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS);
|
||||||
|
ASSERT_EQ(detector.append(3, "s>"), EOS);
|
||||||
|
assert(detector.getDelta() == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// " <"
|
||||||
|
detector.reset();
|
||||||
|
{
|
||||||
|
ASSERT_EQ(detector.append(1, " <"), NOT_EOS);
|
||||||
|
char *delta = detector.getDelta();
|
||||||
|
assert(delta != nullptr);
|
||||||
|
assert(std::strcmp(delta, " <") == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// "<eos> "
|
||||||
|
detector.reset();
|
||||||
|
{
|
||||||
|
ASSERT_EQ(detector.append(1, "<eos"), MAYBE_EOS);
|
||||||
|
ASSERT_EQ(detector.append(2, "> "), NOT_EOS);
|
||||||
|
char *delta = detector.getDelta();
|
||||||
|
assert(delta != nullptr);
|
||||||
|
assert(std::strcmp(delta, "<eos> ") == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// EOS
|
||||||
|
detector.reset();
|
||||||
|
{
|
||||||
|
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
|
||||||
|
assert(detector.getDelta() == nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// emoji
|
||||||
|
detector.reset();
|
||||||
|
{
|
||||||
|
ASSERT_EQ(detector.append(TEST_EOS_ID, "😃"), EOS);
|
||||||
|
char *delta = detector.getDelta();
|
||||||
|
assert(delta != nullptr);
|
||||||
|
assert(std::strcmp(delta, "😃") == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
printOk("eosDetectorWithLongPadding");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Runs the self-contained tokenizer tests; the dev_* cases additionally need
// a locally downloaded model tokenizer file and are gated behind DEV_TESTS.
int main() {
#if DEV_TESTS
    Tokenizer tokenizer("models/llama3_2_1b_instruct_q40/dllama_tokenizer_llama3_2_1b_instruct_q40.t");
    dev_testEncode(&tokenizer);
    dev_testDecoderEmoji(&tokenizer);
    dev_testDecoderEmojiWithEos(&tokenizer);
    dev_testDecoderEmojiStreamRecover(&tokenizer);
#endif

    testChatTemplateDetection();
    testEosDetectorWithPadding();
    testEosDetectorWithLongPadding();
    testEosDetectorWithoutPadding();
    return 0;
}
|
||||||
722
src/tokenizer.cpp
Normal file
722
src/tokenizer.cpp
Normal file
@@ -0,0 +1,722 @@
|
|||||||
|
#include <cstdio>
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdlib>
|
||||||
|
#include <cstring>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <ctype.h>
|
||||||
|
#include <ctime>
|
||||||
|
#include <cassert>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <sstream>
|
||||||
|
#include <vector>
|
||||||
|
#include "nn/nn-core.hpp"
|
||||||
|
#include "nn/nn-cpu-ops.hpp"
|
||||||
|
#include "tokenizer.hpp"
|
||||||
|
#if defined(__ARM_NEON)
|
||||||
|
#include <arm_neon.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define DEBUG_TOKENIZER_ENCODER false
|
||||||
|
#define DEBUG_TOKENIZER_BENCHMARK false
|
||||||
|
#define DEBUG_TEMPLATE_GENERATOR false
|
||||||
|
#define DEBUG_SAMPLER_BENCHMARK false
|
||||||
|
|
||||||
|
// One step of the xorshift* PRNG: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
// Advances *state in place and returns the high 32 bits of the multiplied value.
unsigned int randomU32(unsigned long long *state) {
    unsigned long long s = *state;
    s ^= s >> 12;
    s ^= s << 25;
    s ^= s >> 27;
    *state = s;
    return (unsigned int)((s * 0x2545F4914F6CDD1Dull) >> 32);
}
|
||||||
|
|
||||||
|
float randomF32(unsigned long long *state) {
|
||||||
|
// random float32 in <0,1)
|
||||||
|
return (randomU32(state) >> 8) / 16777216.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
int compareTokens(const void *a, const void *b) {
|
||||||
|
return strcmp(((TokenIndex*)a)->str, ((TokenIndex*)b)->str);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Loads a tokenizer file. Two formats are supported: magic 0x567123 (legacy
// fixed-size header) and 0x567124 (versioned key/value header followed by an
// optional chat template and EOS token list). Builds a sorted regular-vocab
// table for bsearch and a linear special-vocab table, and sizes the streaming
// decode buffers.
// NOTE(review): every throw path below leaks the FILE* — consider RAII
// (unique_ptr with an fclose deleter) here.
Tokenizer::Tokenizer(const char* tokenizerPath)
    : eosTokenIds() {
    bosId = -1;
    chatTemplate = nullptr;
    maxTokenLength = 0;

    // read in the file
    FILE *file = fopen(tokenizerPath, "rb");
    if (!file)
        throw std::runtime_error("Failed to open tokenizer file");
    int magic;
    if (fread(&magic, sizeof(int), 1, file) != 1)
        throw std::runtime_error("Cannot read tokenizer magic number");

    if (magic == 0x567123) {
        // Legacy format: one fixed struct carries all metadata.
        TokenizerOldHeader header;
        if (fread(&header, sizeof(TokenizerOldHeader), 1, file) != 1)
            throw std::runtime_error("Cannot read tokenizer header");
        maxTokenLength = header.maxTokenLength;
        vocabSize = header.vocabSize;
        bosId = header.bosId;
        eosTokenIds.push_back(header.eosId);
    } else if (magic == 0x567124) {
        // Versioned format: a list of (key, value) int pairs.
        int headerSize;
        if (fread(&headerSize, sizeof(int), 1, file) != 1)
            throw std::runtime_error("Cannot read tokenizer header size");
        int nKv = (headerSize - 2 * sizeof(int)) / sizeof(int);
        std::vector<int> buffer(nKv);
        if (fread(&buffer[0], nKv * sizeof(int), 1, file) != 1) {
            throw std::runtime_error("Cannot read header values");
        }
        int version = -1;
        int chatTemplateLength = -1;
        int nEosTokens = 0;
        for (int i = 0; i < nKv; i += 2) {
            int key = buffer[i];
            int value = buffer[i + 1];
            if (key == TOK_VERSION) version = value;
            else if (key == TOK_VOCAB_SIZE) vocabSize = value;
            else if (key == MAX_TOKEN_LENGTH) maxTokenLength = (unsigned int)value;
            else if (key == BOS_ID) bosId = value;
            else if (key == EOS_ID) eosTokenIds.push_back(value); // Backward compatibility
            else if (key == CHAT_EOS_ID) eosTokenIds.push_back(value); // Backward compatibility
            else if (key == CHAT_TEMPLATE) chatTemplateLength = value;
            else if (key == CHAT_STOP) fseek(file, value, SEEK_CUR); // Ignored
            else if (key == PAD_ID) {} // Ignored
            else if (key == N_EOS_TOKENS) nEosTokens = value;
            else if (key == ADD_BOS) addBos = value == 1;
            else {
                throw std::runtime_error("Invalid tokenizer header key:" + std::to_string(key));
            }
        }

        if (version != 1)
            throw std::runtime_error("Old tokenizer version, please regenerate your tokenizer")
;
        // Trailing sections follow the header in this fixed order:
        // chat template text, then the EOS token id list.
        if (chatTemplateLength > 0) {
            chatTemplate = new char[chatTemplateLength + 1];
            if (fread(chatTemplate, chatTemplateLength, 1, file) != 1)
                throw std::runtime_error("Cannot read chat template from tokenizer file");
            chatTemplate[chatTemplateLength] = '\0';
        }
        if (nEosTokens > 0) {
            int eosTokenId;
            for (int i = 0; i < nEosTokens; i++) {
                if (fread(&eosTokenId, sizeof(int), 1, file) != 1)
                    throw std::runtime_error("Cannot read eos token id from tokenizer file");
                eosTokenIds.push_back(eosTokenId);
            }
        }
    } else {
        throw std::runtime_error("Invalid tokenizer file");
    }

    if (maxTokenLength < 1)
        throw std::runtime_error("Invalid tokenizer max token length");

    // malloc space to hold the scores and the strings
    vocab = new char*[vocabSize];
    vocabLength = new unsigned int[vocabSize];
    vocabScores = new float[vocabSize];

    // Vocab entries are stored as (float score, int length, bytes).
    int length;
    for (int i = 0; i < vocabSize; i++) {
        if (fread(vocabScores + i, sizeof(float), 1, file) != 1)
            throw std::runtime_error("Cannot read size from tokenizer file");
        if (fread(&length, sizeof(int), 1, file) != 1)
            throw std::runtime_error("Cannot read length from tokenizer file");
        vocab[i] = new char[length + 1];
        if (fread(vocab[i], length, 1, file) != 1)
            throw std::runtime_error("Cannot read word from tokenizer file");
        vocab[i][length] = '\0'; // add the string terminating token
        vocabLength[i] = length;
    }

    // TODO: this is very unstable assumption that bosId splits regular and special vocab
    regularVocabSize = bosId;
    specialVocabSize = vocabSize - regularVocabSize;

    // Regular vocab is sorted by string so findRegularToken() can bsearch it.
    regularVocab = new TokenIndex[regularVocabSize];
    for (int i = 0; i < regularVocabSize; i++) {
        regularVocab[i].str = vocab[i];
        regularVocab[i].id = i;
    }
    qsort(regularVocab, regularVocabSize, sizeof(TokenIndex), compareTokens);

    // Special vocab stays in file order and is scanned linearly.
    specialVocab = new TokenIndex[specialVocabSize];
    for (int i = 0; i < specialVocabSize; i++) {
        specialVocab[i].str = vocab[i + regularVocabSize];
        specialVocab[i].id = i + regularVocabSize;
    }

    strBufferSize = maxTokenLength * 2;
    if (strBufferSize < (4 * 2)) { // ensure place for 2 utf-8 multi-byte sequence
        strBufferSize = 4 * 2;
    }
    strBufferSize += 1 + 2;
    strBuffer = new char[strBufferSize];
    utf8Buffer = new char[strBufferSize];

    if (bosId >= 0) {
        printf("📄 AddBos: %d\n", addBos ? 1 : 0);
        printf("📄 BosId: %d (%s)\n", bosId, vocab[bosId]);
    }
    if (eosTokenIds.size() > 0) {
        printf("📄 EosId: ");
        for (unsigned int i = 0; i < eosTokenIds.size(); i++) {
            printf("%d (%s) ", eosTokenIds[i], vocab[eosTokenIds[i]]);
        }
        printf("\n");
    }
    printf("📄 RegularVocabSize: %d\n", regularVocabSize);
    printf("📄 SpecialVocabSize: %d\n", specialVocabSize);

    fclose(file);
}
|
||||||
|
|
||||||
|
// Releases every buffer allocated by the constructor.
Tokenizer::~Tokenizer() {
    // delete[] on a null pointer is a no-op, so the previous
    // `if (chatTemplate != NULL)` guard was redundant (the constructor
    // initializes chatTemplate to nullptr when no template is present).
    delete[] chatTemplate;

    for (int i = 0; i < vocabSize; i++)
        delete[] vocab[i];
    delete[] vocab;
    delete[] vocabLength;
    delete[] vocabScores;
    delete[] regularVocab;
    delete[] specialVocab;
    delete[] strBuffer;
    delete[] utf8Buffer;
}
|
||||||
|
|
||||||
|
// Returns the id of the first special token that `piece` starts with,
// or -1 when no special token is a prefix of `piece`.
int Tokenizer::findSpecialTokenStartWith(char *piece) {
    for (unsigned int i = 0; i < specialVocabSize; i++) {
        unsigned int tokenId = specialVocab[i].id;
        unsigned int length = vocabLength[tokenId];
        // Prefix match: only the first `length` bytes of `piece` are compared.
        if (std::strncmp(vocab[tokenId], piece, length) == 0)
            return tokenId;
    }
    return -1;
}
|
||||||
|
|
||||||
|
int Tokenizer::findRegularToken(char *piece) {
|
||||||
|
TokenIndex tok = { .str = piece };
|
||||||
|
TokenIndex *res = (TokenIndex*)bsearch(&tok, regularVocab, regularVocabSize, sizeof(TokenIndex), compareTokens);
|
||||||
|
return res != NULL ? res->id : -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// True when `token` is one of the configured end-of-sequence token ids.
bool Tokenizer::isEos(int token) {
    for (int eosId : eosTokenIds) {
        if (eosId == token)
            return true;
    }
    return false;
}
|
||||||
|
|
||||||
|
// Discards any partially-decoded UTF-8 bytes carried over between decode() calls.
void Tokenizer::resetDecoder() {
    strBufferPos = 0;
}
|
||||||
|
|
||||||
|
// Validates the bytes accumulated in strBuffer as UTF-8 and copies complete
// sequences into utf8Buffer. An incomplete trailing sequence is kept back in
// strBuffer for the next call; invalid bytes are replaced with U+FFFD.
// Returns utf8Buffer, or nullptr when no complete character is available yet.
char *Tokenizer::detokUtf8() {
    char* src = strBuffer;
    char* dst = utf8Buffer;
    // Checkpoints mark the end of the last fully-validated character,
    // so a truncated/invalid sequence can be rolled back.
    char* checkpoint_src = src;
    char* checkpoint = dst;
    // Number of 0b10xxxxxx continuation bytes still expected.
    unsigned expect_continuation = 0;

    while (unsigned char c = *src) {
        bool need_recovery = false;
        if (expect_continuation) {
            if ((c & 0xc0) == 0x80) {
                // Valid continuation byte.
                *dst++ = *src++;
                expect_continuation--;
            } else {
                need_recovery = true;
            }
        } else if (c <= 0x7f) {
            // Single-byte ASCII.
            *dst++ = *src++;
        } else if (c >= 0xc0 && c <= 0xdf) {
            // 2-byte sequence lead.
            *dst++ = *src++;
            expect_continuation = 1;
        } else if (c >= 0xe0 && c <= 0xef) {
            // 3-byte sequence lead.
            *dst++ = *src++;
            expect_continuation = 2;
        } else if (c >= 0xf0 && c <= 0xf7) {
            // 4-byte sequence lead.
            *dst++ = *src++;
            expect_continuation = 3;
        } else {
            // 0xf8-0xff (or stray continuation) is never a valid lead byte.
            need_recovery = true;
        }

        if (!need_recovery) {
            if (!expect_continuation) {
                // A full character was emitted; advance the checkpoints.
                checkpoint = dst;
                checkpoint_src = src;
            }
        } else {
            // perform stream recover
            if (expect_continuation) {
                // Abandon the truncated sequence; re-examine the current byte.
                expect_continuation = 0;
            } else {
                // Skip the invalid byte entirely.
                ++src;
            }
            // Roll back the partial output and substitute the replacement char.
            dst = checkpoint;
            // emit 0xfffd
            *dst++ = 0xef;
            *dst++ = 0xbf;
            *dst++ = 0xbd;

            fprintf(stderr, "Tokenizer: decoded invalid utf8 -- attempting stream recover\n");
        }
    }

    if (src > checkpoint_src) {
        // Keep the incomplete tail (plus its terminator) for the next call.
        memmove(strBuffer, checkpoint_src, src - checkpoint_src + 1);
        strBufferPos = src - checkpoint_src;
    } else {
        strBufferPos = 0;
    }
    *checkpoint = '\0';
    if (checkpoint > utf8Buffer) {
        return utf8Buffer;
    } else {
        return nullptr;
    }
}
|
||||||
|
|
||||||
|
char *Tokenizer::decode(int token) {
|
||||||
|
if (token == bosId)
|
||||||
|
return nullptr;
|
||||||
|
if (isEos(token)) {
|
||||||
|
if (strBufferPos > 0)
|
||||||
|
return strBuffer;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
char *piece = vocab[token];
|
||||||
|
int pieceLen = vocabLength[token];
|
||||||
|
|
||||||
|
assert(strBufferPos + pieceLen + 1 < strBufferSize);
|
||||||
|
std::memcpy(&strBuffer[strBufferPos], piece, pieceLen * sizeof(char));
|
||||||
|
strBufferPos += pieceLen;
|
||||||
|
strBuffer[strBufferPos] = '\0';
|
||||||
|
|
||||||
|
return detokUtf8();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tokenizes `text` into the caller-allocated `tokens` array, writing the
// count to `nTokens`. Special tokens are matched greedily (when
// addSpecialTokens is set), remaining bytes are mapped to vocab entries,
// then BPE merges are applied in descending score order. When isStart is
// set and the tokenizer requests it, BOS is prepended.
// Throws std::runtime_error when `text` is null.
void Tokenizer::encode(char *text, int *tokens, int *nTokens, bool isStart, bool addSpecialTokens) {
#if DEBUG_TOKENIZER_BENCHMARK
    Timer startTime;
#endif
    if (text == nullptr)
        throw std::runtime_error("Input text is null");

    size_t strLen = 0;

    *nTokens = 0;

    if (isStart && addBos && bosId >= 0)
        tokens[(*nTokens)++] = bosId;

    for (char *c = text; *c != '\0'; c++) {
        if (addSpecialTokens) {
            // A special token consumes its full text in one step.
            int specialTokenId = findSpecialTokenStartWith(c);
            if (specialTokenId >= 0) {
                tokens[(*nTokens)++] = specialTokenId;
                c += vocabLength[specialTokenId] - 1;
                continue;
            }
        }

        // Accumulate bytes until they form a known vocab entry.
        strBuffer[strLen] = *c;
        strLen++;
        assert(strLen < strBufferSize);
        strBuffer[strLen] = '\0';

        int id = findRegularToken(strBuffer);
        if (id != -1) {
            tokens[(*nTokens)++] = id;
            strLen = 0;
        }
    }

    // Every input byte must have been consumed by some token.
    assert(strLen == 0);

    // merge the best consecutive pair each iteration, according the scores in vocab_scores
    while (1) {
        float best_score = -1e10;
        int best_id = -1;
        int best_idx = -1;

        for (int i=0; i < (*nTokens-1); i++) {
            // check if we can merge the pair (tokens[i], tokens[i+1])
            snprintf(strBuffer, strBufferSize, "%s%s", vocab[tokens[i]], vocab[tokens[i+1]]);
            int id = findRegularToken(strBuffer);
            if (id != -1 && vocabScores[id] > best_score) {
                // this merge pair exists in vocab! record its score and position
                best_score = vocabScores[id];
                best_id = id;
                best_idx = i;
            }
        }

        if (best_idx == -1) {
            break; // we couldn't find any more pairs to merge, so we're done
        }

        // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
        tokens[best_idx] = best_id;
        // delete token at position best_idx+1, shift the entire sequence back 1
        for (int i = best_idx+1; i < (*nTokens-1); i++) {
            tokens[i] = tokens[i+1];
        }
        (*nTokens)--; // token length decreased
    }

#if DEBUG_TOKENIZER_BENCHMARK
    NnUint duration = startTime.elapsedMicroseconds();
    printf("🕒 [%22s] %u μs\n", "ENCODER", duration);
#endif
#if DEBUG_TOKENIZER_ENCODER
    printf("\033[1;33m[");
    for (unsigned int i = 0; i < *nTokens; i++)
        printf("%d,", tokens[i]);
    printf("]\033[0m");
#endif
}
|
||||||
|
|
||||||
|
// Returns the index of the largest probability; ties resolve to the
// earliest index (greedy decoding).
int sample_argmax(float* probabilities, int n) {
    int best = 0;
    for (int i = 1; i < n; i++) {
        if (probabilities[i] > probabilities[best])
            best = i;
    }
    return best;
}
|
||||||
|
|
||||||
|
// Samples an index from a probability distribution (entries must sum to 1).
// `coin` is a uniform random value in [0, 1); the last index is returned
// when floating-point rounding leaves the CDF below `coin`.
int sample_mult(float* probabilities, int n, float coin) {
    float cumulative = 0.0f;
    int i = 0;
    for (; i < n; i++) {
        cumulative += probabilities[i];
        if (coin < cumulative)
            break;
    }
    return i < n ? i : n - 1;
}
|
||||||
|
|
||||||
|
int compare(const void* a, const void* b) {
|
||||||
|
ProbIndex* a_ = (ProbIndex*) a;
|
||||||
|
ProbIndex* b_ = (ProbIndex*) b;
|
||||||
|
if (a_->prob > b_->prob) return -1;
|
||||||
|
if (a_->prob < b_->prob) return 1;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Top-p (nucleus) sampling: samples only from the smallest set of tokens
// whose cumulative probability exceeds `topp`, avoiding very unlikely
// tokens. `probindex` is caller-provided scratch of at least n entries;
// `coin` is a uniform random value in [0, 1).
int sample_topp(float* probabilities, int n, float topp, ProbIndex* probindex, float coin) {
    // top-p sampling (or "nucleus sampling") samples from the smallest set of
    // tokens that exceed probability topp. This way we never sample tokens that
    // have very low probabilities and are less likely to go "off the rails".
    // coin is a random number in [0, 1), usually from random_f32()

    int n0 = 0;
    // quicksort indices in descending order of probabilities
    // values smaller than (1 - topp) / (n - 1) cannot be part of the result
    // so for efficiency we crop these out as candidates before sorting
    const float cutoff = (1.0f - topp) / (n - 1);
    for (int i = 0; i < n; i++) {
        if (probabilities[i] >= cutoff) {
            probindex[n0].index = i;
            probindex[n0].prob = probabilities[i];
            n0++;
        }
    }
    qsort(probindex, n0, sizeof(ProbIndex), compare);

    // truncate the list where cumulative probability exceeds topp
    float cumulative_prob = 0.0f;
    int last_idx = n0 - 1; // in case of rounding errors consider all elements
    for (int i = 0; i < n0; i++) {
        cumulative_prob += probindex[i].prob;
        if (cumulative_prob > topp) {
            last_idx = i;
            break; // we've exceeded topp by including last_idx
        }
    }

    // sample from the truncated list
    // rescale the coin into the truncated distribution's total mass
    float r = coin * cumulative_prob;
    float cdf = 0.0f;
    for (int i = 0; i <= last_idx; i++) {
        cdf += probindex[i].prob;
        if (r < cdf) {
            return probindex[i].index;
        }
    }
    return probindex[last_idx].index; // in case of rounding errors
}
|
||||||
|
|
||||||
|
// Initializes the sampling hyperparameters and allocates the per-vocab
// scratch buffer used by top-p sampling.
Sampler::Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed) {
    this->vocab_size = vocab_size;
    this->temperature = temperature;
    this->topp = topp;
    this->rngState = rngSeed;
    // buffer only used with nucleus sampling; may not need but it's ~small
    probindex = new ProbIndex[vocab_size];
}
|
||||||
|
|
||||||
|
// Releases the top-p scratch buffer.
Sampler::~Sampler() {
    delete[] probindex;
}
|
||||||
|
|
||||||
|
// Selects the next token id from raw logits using the configured
// temperature and top-p settings. Mutates `logits` in place when
// temperature sampling is active (scaling + softmax).
int Sampler::sample(float* logits) {
#if DEBUG_SAMPLER_BENCHMARK
    Timer startTime;
#endif
    // sample the token given the logits and some hyperparameters
    int next;
    if (temperature == 0.0f) {
        // greedy argmax sampling: take the token with the highest probability
        next = sample_argmax(logits, vocab_size);
    } else {
        // apply the temperature to the logits
        for (int q=0; q < vocab_size; q++) { logits[q] /= temperature; }
        // apply softmax to the logits to get the probabilities for next token
        softmax_F32(logits, vocab_size);
        // flip a (float) coin (this is our source of entropy for sampling)
        float coin = randomF32(&rngState);
        // we sample from this distribution to get the next token
        if (topp <= 0 || topp >= 1) {
            // simply sample from the predicted probability distribution
            next = sample_mult(logits, vocab_size, coin);
        } else {
            // top-p (nucleus) sampling, clamping the least likely tokens to zero
            next = sample_topp(logits, vocab_size, topp, probindex, coin);
        }
    }
#if DEBUG_SAMPLER_BENCHMARK
    NnUint duration = startTime.elapsedMicroseconds();
    printf("🕒 [%22s] %u μs\n", "SAMPLER", duration);
#endif
    return next;
}
|
||||||
|
|
||||||
|
// Updates the sampling temperature (0.0f switches to greedy argmax).
void Sampler::setTemp(float temp) {
    this->temperature = temp;
}
|
||||||
|
|
||||||
|
// Reseeds the pseudo-random generator used for sampling.
void Sampler::setSeed(unsigned long long seed) {
    this->rngState = seed;
}
|
||||||
|
|
||||||
|
// Builds the list of chat stop strings from the tokenizer's EOS tokens,
// borrowing the vocab strings, and records the longest stop length.
TokenizerChatStops::TokenizerChatStops(Tokenizer* tokenizer) {
    nStops = tokenizer->eosTokenIds.size();
    char** stopStrings = new char*[nStops];

    maxStopLength = 0;
    for (size_t i = 0; i < nStops; i++) {
        stopStrings[i] = tokenizer->vocab[tokenizer->eosTokenIds[i]];
        const size_t len = strlen(stopStrings[i]);
        if (len > maxStopLength)
            maxStopLength = len;
    }
    stops = (const char**)stopStrings;
}
|
||||||
|
|
||||||
|
// Frees the stop-pointer array only; the strings belong to the tokenizer.
TokenizerChatStops::~TokenizerChatStops() {
    delete[] stops;
}
|
||||||
|
|
||||||
|
static const char *chatTemplateTypeToString(const ChatTemplateType type) {
|
||||||
|
if (type == TEMPLATE_LLAMA2) return "llama2";
|
||||||
|
if (type == TEMPLATE_LLAMA3) return "llama3";
|
||||||
|
if (type == TEMPLATE_DEEP_SEEK3) return "deepSeek3";
|
||||||
|
if (type == TEMPLATE_CHATML) return "chatml";
|
||||||
|
return "unknown";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolves the chat template family: either the explicitly requested type,
// or auto-detection from characteristic markers in the raw template text.
// Throws std::runtime_error when the template is missing or unrecognized.
ChatTemplateGenerator::ChatTemplateGenerator(const ChatTemplateType type, const char* chatTemplate, const char* eos)
    : buffer()
{
    if (type != TEMPLATE_UNKNOWN) {
        this->type = type;
    } else {
        if (chatTemplate == NULL)
            throw std::runtime_error("The tokenizer does not include chat template");
        // Marker checks are ordered; the first match wins.
        if (strstr(chatTemplate, "[INST]") != NULL)
            this->type = TEMPLATE_LLAMA2;
        else if (strstr(chatTemplate, "<|start_header_id|>") != NULL)
            this->type = TEMPLATE_LLAMA3;
        else if (strstr(chatTemplate, "<|Assistant|>") != NULL)
            this->type = TEMPLATE_DEEP_SEEK3;
        else if (strstr(chatTemplate, "<|im_start|>") != NULL)
            this->type = TEMPLATE_CHATML;
        else
            throw std::runtime_error("Not supported chat template");
    }
    this->eos = eos;

    printf("⭐ Chat template: %s\n", chatTemplateTypeToString(this->type));
}
|
||||||
|
|
||||||
|
// Renders `items` into the prompt format of the configured template family.
// The returned pointers reference the generator's internal buffer and are
// invalidated by the next generate() call. When appendGenerationPrompt is
// set, the assistant turn opener is appended once after all items.
GeneratedChat ChatTemplateGenerator::generate(unsigned int nItems, ChatItem* items, bool appendGenerationPrompt) {
    buffer.clear();

    // Number of trailing bytes of the rendered prompt that should be
    // surfaced to the user (deepSeek3 exposes its "<think>\n" opener).
    size_t publicPromptSize = 0;

    if (type == TEMPLATE_LLAMA2) {
        unsigned int i = 0;
        // A leading system+user pair is folded into one [INST] block.
        if (nItems >= 2 && items[0].role == "system" && items[1].role == "user") {
            buffer += "[INST] <<SYS>>\n" + items[0].message + "\n<</SYS>>\n\n" + items[1].message + " [/INST]" + eos;
            i += 2;
        }
        for (; i < nItems; i++) {
            if (items[i].role == "assistant") {
                buffer += items[i].message + eos;
            } else if (items[i].role == "user") {
                buffer += "[INST] " + items[i].message + " [/INST]" + eos;
            }
        }
    } else if (type == TEMPLATE_LLAMA3) {
        for (unsigned int i = 0; i < nItems; i++)
            buffer += "<|start_header_id|>" + items[i].role + "<|end_header_id|>\n\n" + items[i].message + eos;
        if (appendGenerationPrompt)
            buffer += "<|start_header_id|>assistant<|end_header_id|>\n\n";
    } else if (type == TEMPLATE_DEEP_SEEK3) {
        unsigned int i = 0;
        if (nItems > 0 && items[0].role == "system") {
            buffer += items[0].message;
            i++;
        }
        for (; i < nItems; i++) {
            if (items[i].role == "user") {
                buffer += "<|User|>" + items[i].message;
            } else if (items[i].role == "assistant") {
                buffer += "<|Assistant|>" + items[i].message;
            }
        }
        if (appendGenerationPrompt) {
            buffer += "<|Assistant|><think>\n";
            publicPromptSize = 8; // strlen("<think>\n")
        }
    } else if (type == TEMPLATE_CHATML) {
        for (unsigned int i = 0; i < nItems; i++) {
            if (items[i].role == "system") {
                buffer += "<|im_start|>system\n" + items[i].message + "<|im_end|>\n";
            } else if (items[i].role == "user") {
                buffer += "<|im_start|>user\n" + items[i].message + "<|im_end|>\n";
            } else if (items[i].role == "assistant") {
                buffer += "<|im_start|>assistant\n" + items[i].message + "<|im_end|>\n";
            }
        }
        // BUGFIX: this append previously sat inside the loop above, emitting
        // one "<|im_start|>assistant\n" per chat item. It must run exactly
        // once, after all items, consistent with the other templates.
        if (appendGenerationPrompt)
            buffer += "<|im_start|>assistant\n";
    }

    const char *content = buffer.c_str();
    size_t length = buffer.size();
    const char *publicPrompt = publicPromptSize > 0
        ? &content[length - publicPromptSize]
        : nullptr;
#if DEBUG_TEMPLATE_GENERATOR
    printf("\033[1;31m[%s]\033[0m", content);
#endif
    return {content, length, publicPrompt};
}
|
||||||
|
|
||||||
|
// Configures the detector with the stop token ids and their text forms
// (both borrowed, not owned). paddingLeft/paddingRight allow that many
// stray bytes around a stop-string match.
EosDetector::EosDetector(size_t nTokens, const int *tokens, const char** pieces, int paddingLeft, int paddingRight) {
    this->nTokens = nTokens;
    this->tokens = tokens;
    this->pieces = pieces;
    this->pieceSizes = new size_t[nTokens];
    for (size_t s = 0; s < nTokens; s++) {
        pieceSizes[s] = strlen(pieces[s]);
        printf("🛑 Stop: %s\n", pieces[s]);
    }
    // FIX: initialize all remaining members; `buffer` and `eosPos` were
    // previously left uninitialized until the first append() call.
    this->buffer = nullptr;
    this->bufferPos = 0;
    this->bufferSize = 0;
    this->eosPos = -1;
    this->paddingLeft = paddingLeft;
    this->paddingRight = paddingRight;
}
|
||||||
|
|
||||||
|
// Frees the owned buffers; bufferSize > 0 guards the lazily-allocated
// accumulation buffer.
EosDetector::~EosDetector() {
    if (bufferSize > 0)
        delete[] buffer;
    delete[] pieceSizes;
}
|
||||||
|
|
||||||
|
bool EosDetector::isEos(int tokenId) {
|
||||||
|
for (size_t i = 0; i < nTokens; i++) {
|
||||||
|
if (tokenId == tokens[i])
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Feeds one generated token (and its decoded text, which may be null) into
// the detector. Returns EOS when a stop token or full stop string is seen,
// MAYBE_EOS when the buffered text could still become a stop string, and
// NOT_EOS otherwise.
EosDetectorType EosDetector::append(int tokenId, const char *piece) {
    if (piece != nullptr) {
        // Grow the accumulation buffer on demand and append the piece.
        int pieceLength = std::strlen(piece);
        int newSize = bufferPos + pieceLength + 1;
        if (newSize > bufferSize) {
            char* newBuffer = new char[newSize];
            if (bufferPos > 0)
                std::memcpy(newBuffer, buffer, bufferPos);
            if (bufferSize > 0)
                delete[] buffer;
            buffer = newBuffer;
            bufferSize = newSize;
        }
        std::memcpy(&buffer[bufferPos], piece, pieceLength);
        bufferPos += pieceLength;
        buffer[bufferPos] = '\0';
    }

    // detection

    if (isEos(tokenId)) {
        // Direct stop-token id match: the whole buffer is the visible delta.
        eosPos = bufferPos;
        return EOS;
    }
    eosPos = -1;

    for (size_t s = 0; s < nTokens; s++) {
        size_t pieceSize = pieceSizes[s];
        // Buffer already longer than this stop string plus allowed padding.
        if (bufferPos > pieceSize + paddingLeft + paddingRight) continue;

        // Try matching the stop string at each offset within the left padding.
        // NOTE(review): n is int while pieceSize is size_t; a negative n
        // (lo > bufferPos) is skipped by the unsigned comparison — confirm
        // this mixed-sign arithmetic is intended.
        for (int lo = 0; lo <= paddingLeft; lo++) {
            int n = bufferPos - lo;
            if (n == 0 || n > pieceSize + paddingRight) continue;
            if (n > pieceSize) n = pieceSize;
            if (strncmp(buffer + lo, pieces[s], n) == 0) {
                if (n == pieceSize) {
                    // Full stop string matched; truncate the visible delta.
                    eosPos = lo;
                    buffer[eosPos] = '\0';
                    return EOS;
                }
                // Prefix of a stop string: need more tokens to decide.
                return MAYBE_EOS;
            }
        }
    }
    return NOT_EOS;
}
|
||||||
|
|
||||||
|
char* EosDetector::getDelta() {
|
||||||
|
if (bufferPos == 0) return nullptr;
|
||||||
|
if (eosPos == -1) return buffer;
|
||||||
|
if (eosPos == 0) return nullptr;
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Discards all accumulated text (the allocation is kept for reuse).
void EosDetector::reset() {
    bufferPos = 0;
}
|
||||||
157
src/tokenizer.hpp
Normal file
157
src/tokenizer.hpp
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
#ifndef TOKENIZER_HPP
|
||||||
|
#define TOKENIZER_HPP
|
||||||
|
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
// Maps a vocabulary string to its token id; used for the sorted lookup tables.
typedef struct {
    char *str;       // borrowed pointer into Tokenizer::vocab
    unsigned int id; // token id in the full vocabulary
} TokenIndex;
|
||||||
|
|
||||||
|
// Fixed-layout header of old-format tokenizer files (before the key/value
// header format).
struct TokenizerOldHeader {
    unsigned int vocabSize;      // total number of tokens
    unsigned int maxTokenLength; // longest token string, in bytes
    int bosId;                   // begin-of-sequence token id
    int eosId;                   // end-of-sequence token id
    int padId;                   // padding token id
};
|
||||||
|
|
||||||
|
// Keys of the key/value entries in the tokenizer file header.
enum TokenizerHeaderKey {
    TOK_VERSION = 0,
    TOK_VOCAB_SIZE = 1,
    MAX_TOKEN_LENGTH = 2,
    BOS_ID = 3,
    EOS_ID = 4, // Backward compatibility
    PAD_ID = 5, // Ignored
    CHAT_EOS_ID = 6, // Backward compatibility
    CHAT_TEMPLATE = 7,
    CHAT_STOP = 8, // Ignored
    N_EOS_TOKENS = 9,
    ADD_BOS = 10,
};
|
||||||
|
|
||||||
|
// BPE tokenizer: encodes text to token ids and incrementally decodes token
// ids back to validated UTF-8 text.
class Tokenizer {
private:
    unsigned int maxTokenLength;   // longest vocab entry, in bytes
    unsigned int regularVocabSize; // number of regular (non-special) tokens
    unsigned int specialVocabSize; // number of special tokens (stored after regular ones)
    float *vocabScores;            // BPE merge score per token
    unsigned int *vocabLength;     // byte length of each vocab entry
    TokenIndex *regularVocab;      // regular tokens, sorted for bsearch lookup
    TokenIndex *specialVocab;      // special tokens in vocabulary order
    size_t strBufferSize;          // capacity of strBuffer/utf8Buffer
    char *strBuffer;               // decoder accumulator (may hold partial UTF-8)
    char *utf8Buffer;              // validated UTF-8 output buffer
    size_t strBufferPos;           // bytes currently pending in strBuffer


public:
    std::vector<int> eosTokenIds;  // all token ids that terminate generation
    unsigned int vocabSize;        // total vocab size (regular + special)
    char **vocab;                  // token id -> NUL-terminated string
    int bosId;                     // begin-of-sequence id; negative when absent
    bool addBos;                   // whether encode() should prepend bosId
    char *chatTemplate;            // raw chat template text, or NULL

    Tokenizer(const char *tokenizer_path);
    ~Tokenizer();
    int findSpecialTokenStartWith(char *piece);
    int findRegularToken(char *piece);
    void encode(char *text, int *tokens, int *nTokens, bool isStart, bool addSpecialTokens);
    bool isEos(int token);
    char *decode(int token);
    void resetDecoder();

private:
    char *detokUtf8();
};
|
||||||
|
|
||||||
|
// Probability/index pair used by top-p sampling to sort candidate tokens.
typedef struct {
    float prob;
    int index;
} ProbIndex;
|
||||||
|
|
||||||
|
// Converts model logits into a sampled token id, supporting greedy,
// temperature, and top-p (nucleus) sampling.
class Sampler {
private:
    int vocab_size;
    ProbIndex *probindex;        // scratch buffer for top-p sorting
    float temperature;           // 0.0f selects greedy argmax
    float topp;                  // nucleus threshold; values outside (0,1) disable top-p
    unsigned long long rngState; // pseudo-random generator state (advanced by randomF32)

public:
    Sampler(int vocab_size, float temperature, float topp, unsigned long long rngSeed);
    ~Sampler();
    int sample(float *logits);
    void setTemp(float temp);
    void setSeed(unsigned long long rngSeed);
};
|
||||||
|
|
||||||
|
// Collects the tokenizer's EOS strings as chat stop sequences.
class TokenizerChatStops {
public:
    const char **stops;   // borrowed pointers into the tokenizer's vocab
    size_t nStops;        // number of stop strings
    size_t maxStopLength; // byte length of the longest stop string
    TokenizerChatStops(Tokenizer *tokenizer);
    ~TokenizerChatStops();
};
|
||||||
|
|
||||||
|
// Supported chat prompt formats; UNKNOWN triggers auto-detection from the
// raw template text.
enum ChatTemplateType {
    TEMPLATE_UNKNOWN = 0,
    TEMPLATE_LLAMA2 = 1,
    TEMPLATE_LLAMA3 = 2,
    TEMPLATE_DEEP_SEEK3 = 3,
    TEMPLATE_CHATML = 4,
};
|
||||||
|
|
||||||
|
// One chat message: a role (e.g. "system", "user", "assistant") and its text.
struct ChatItem {
    std::string role;
    std::string message;
};
|
||||||
|
|
||||||
|
// Result of rendering a chat template. `content` points into the
// generator's internal buffer and is invalidated by the next generate() call.
struct GeneratedChat {
    const char *content;
    size_t length;
    const char *publicPrompt; // user-visible suffix of content, or null
};
|
||||||
|
|
||||||
|
// Renders chat messages into the prompt format of a given template family.
class ChatTemplateGenerator {
public:
    const char *eos;       // EOS string appended after turns (template dependent)
    ChatTemplateType type; // resolved template family
    std::string buffer;    // reused output buffer; overwritten by generate()
    ChatTemplateGenerator(const ChatTemplateType type, const char *chatTemplate, const char *eos);
    GeneratedChat generate(unsigned int nItems, ChatItem *items, bool appendGenerationPrompt);
};
|
||||||
|
|
||||||
|
// Outcome of feeding one token into the EOS detector.
enum EosDetectorType {
    MAYBE_EOS = 0, // buffered text is a prefix of some stop sequence
    EOS = 1,       // a stop token or full stop sequence was detected
    NOT_EOS = 2,   // no stop sequence in sight
};
|
||||||
|
|
||||||
|
// Incrementally detects stop tokens / stop strings in generated text,
// tolerating up to paddingLeft/paddingRight stray bytes around a match.
class EosDetector {
private:
    size_t nTokens;      // number of stop tokens
    const int *tokens;   // stop token ids (borrowed)
    const char **pieces; // stop strings (borrowed)
    size_t *pieceSizes;  // strlen of each stop string
    size_t bufferPos;    // bytes currently accumulated in `buffer`
    size_t bufferSize;   // allocated capacity of `buffer`
    int eosPos;          // offset of a detected stop in the buffer, or -1
    int paddingLeft;
    int paddingRight;
public:
    char *buffer;        // accumulated decoded text (lazily allocated)
    EosDetector(size_t nTokens, const int *tokens, const char* *pieces, int paddingLeft, int paddingRight);
    ~EosDetector();

    EosDetectorType append(int tokenId, const char *piece);
    bool isEos(int tokenId);
    char *getDelta();
    void reset();
};
|
||||||
|
|
||||||
|
#endif
|
||||||
Reference in New Issue
Block a user