Files
dllama/src/tokenizer-test.cpp
Chris 42172cbb6f
Some checks failed
main / Linux (amd64, ubuntu-22.04) (push) Successful in 49s
main / Linux (arm64, ubuntu-24.04-arm) (push) Has been cancelled
main / Windows (push) Has been cancelled
init
2025-10-24 11:42:14 +02:00

320 lines
9.0 KiB
C++
Raw Blame History

#include <cassert>
#include <cstring>
#include "tokenizer.hpp"
#define DEV_TESTS false
#define ASSERT_EQ(a, b) \
if (a != b) { \
printf("Assertion failed: %d != %d (%s:%d)\n", a, b, __FILE__, __LINE__); \
exit(-1); \
}
#define TEST_EOS_ID 10000
void printOk(const char *name) {
printf("✅ %24s passed\n", name);
}
void compare(const char *name, const int *a, const int *b, const unsigned int aN, const int bN) {
bool passed = true;
if (aN != bN) {
passed = false;
} else {
for (unsigned int i = 0; i < bN; i++) {
if (a[i] != b[i]) {
passed = false;
break;
}
}
}
if (!passed) {
printf("❌ %24s failed\na: ", name);
for (unsigned int j = 0; j < aN; j++)
printf("%5d ", a[j]);
printf("\nb: ");
for (unsigned int j = 0; j < bN; j++)
printf("%5d ", b[j]);
printf("\n");
exit(1);
}
printOk(name);
}
void dev_testEncode(Tokenizer *tokenizer) {
int tokens[1024];
int nTokens;
{
const char *text = "<|start_header_id|>user<|end_header_id|>\n\nhello<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n";
const int expectedTokens[] = {128000, 128006, 882, 128007, 271, 15339, 128009, 128006, 78191, 128007, 271};
tokenizer->encode((char *)text, tokens, &nTokens, true, true);
compare("case0", expectedTokens, tokens, 11, nTokens);
}
{
const char *text = "!!&&@(*x)^^!";
const int expectedTokens[] = {128000, 3001, 7827, 31, 4163, 87, 8, 22634, 0};
tokenizer->encode((char *)text, tokens, &nTokens, true, true);
compare("case1", expectedTokens, tokens, 9, nTokens);
}
{
const char *text = "😃!😇x";
const int expectedTokens[] = {128000, 76460, 225, 0, 76460, 229, 87};
tokenizer->encode((char *)text, tokens, &nTokens, true, true);
compare("case2", expectedTokens, tokens, 7, nTokens);
}
}
void dev_testDecoderEmojiStreamRecover(Tokenizer *tokenizer) {
char *x0 = tokenizer->decode(128000);
assert(x0 == nullptr);
char *x1 = tokenizer->decode(76460);
assert(x1 == nullptr);
char *x2 = tokenizer->decode(76460);
assert(x2 == nullptr);
char *x3 = tokenizer->decode(225);
assert(strcmp(x3, "<EFBFBD>😃") == 0);
printOk("testDecoderEmojiStreamRecover");
}
void dev_testDecoderEmoji(Tokenizer *tokenizer) {
char *x0 = tokenizer->decode(128000);
assert(x0 == nullptr);
char *x1 = tokenizer->decode(76460);
assert(x1 == nullptr);
char *x2 = tokenizer->decode(225);
assert(strcmp(x2, "😃") == 0);
char *x3 = tokenizer->decode(0);
assert(strcmp(x3, "!") == 0);
char *x4 = tokenizer->decode(56);
assert(strcmp(x4, "Y") == 0);
printOk("testDecoderEmoji");
}
void dev_testDecoderEmojiWithEos(Tokenizer *tokenizer) {
char *x0 = tokenizer->decode(128000);
assert(x0 == nullptr);
char *x1 = tokenizer->decode(76460);
assert(x1 == nullptr);
char *x2 = tokenizer->decode(225);
assert(strcmp(x2, "😃") == 0);
char *x3 = tokenizer->decode(128001);
assert(x3 == nullptr); // piece should not contain <|end_of_text|>
printOk("decoderEmojiWithEos");
}
void testChatTemplateDetection() {
ChatTemplateGenerator t0(TEMPLATE_UNKNOWN, "{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "<eos>");
assert(t0.type == TEMPLATE_LLAMA3);
printOk("chatTemplateDetection");
}
void testEosDetectorWithPadding() {
const int tokens[2] = {TEST_EOS_ID, TEST_EOS_ID + 1};
const char *pieces[2] = { "<eos>", "<stop>" };
EosDetector detector(2, tokens, pieces, 1, 1);
// "<eos>"
{
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "s>"), EOS);
assert(detector.getDelta() == nullptr);
}
// "<stop> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "stop"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "> "), EOS);
assert(detector.getDelta() == nullptr);
}
// " "
detector.reset();
{
ASSERT_EQ(detector.append(1, " "), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, " ") == 0);
}
// "!<eos> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "!<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "eos"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "> "), EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "!") == 0);
}
// "!<eos> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "<eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "s>XY"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "<eos>XY") == 0);
}
// "<eo" + EOS
detector.reset();
{
ASSERT_EQ(detector.append(1, "<eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "<eo") == 0);
}
// EOS
detector.reset();
{
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
assert(detector.getDelta() == nullptr);
}
// after reset it's expected to return nullptr delta if to the append() passed nullptr piece
detector.reset();
{
ASSERT_EQ(detector.append(1, "x"), NOT_EOS);
char *delta0 = detector.getDelta();
assert(std::strcmp(delta0, "x") == 0);
detector.reset();
ASSERT_EQ(detector.append(2, nullptr), NOT_EOS);
char *delta1 = detector.getDelta();
assert(delta1 == nullptr);
}
printOk("eosDetectorWithPadding");
}
void testEosDetectorWithLongPadding() {
const int tokens[1] = {TEST_EOS_ID};
const char *pieces[1] = { "|end|" };
EosDetector detector(1, tokens, pieces, 5, 5);
// "lipsum"
{
ASSERT_EQ(detector.append(1, "lipsum"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "lipsum") == 0);
}
// "lorem"
detector.reset();
{
ASSERT_EQ(detector.append(1, "lorem"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "lorem") == 0);
}
// "lorem|enQ"
detector.reset();
{
ASSERT_EQ(detector.append(1, "lorem|"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "enQ"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "lorem|enQ") == 0);
}
printOk("eosDetectorWithLongPadding");
}
void testEosDetectorWithoutPadding() {
const int tokens[1] = {TEST_EOS_ID};
const char *pieces[1] = { "<eos>" };
EosDetector detector(1, tokens, pieces, 0, 0);
// "<eos>"
{
ASSERT_EQ(detector.append(1, "<"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "eo"), MAYBE_EOS);
ASSERT_EQ(detector.append(3, "s>"), EOS);
assert(detector.getDelta() == nullptr);
}
// " <"
detector.reset();
{
ASSERT_EQ(detector.append(1, " <"), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, " <") == 0);
}
// "<eos> "
detector.reset();
{
ASSERT_EQ(detector.append(1, "<eos"), MAYBE_EOS);
ASSERT_EQ(detector.append(2, "> "), NOT_EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "<eos> ") == 0);
}
// EOS
detector.reset();
{
ASSERT_EQ(detector.append(TEST_EOS_ID, nullptr), EOS);
assert(detector.getDelta() == nullptr);
}
// emoji
detector.reset();
{
ASSERT_EQ(detector.append(TEST_EOS_ID, "😃"), EOS);
char *delta = detector.getDelta();
assert(delta != nullptr);
assert(std::strcmp(delta, "😃") == 0);
}
printOk("eosDetectorWithLongPadding");
}
int main() {
#if DEV_TESTS
Tokenizer tokenizer("models/llama3_2_1b_instruct_q40/dllama_tokenizer_llama3_2_1b_instruct_q40.t");
dev_testEncode(&tokenizer);
dev_testDecoderEmoji(&tokenizer);
dev_testDecoderEmojiWithEos(&tokenizer);
dev_testDecoderEmojiStreamRecover(&tokenizer);
#endif
testChatTemplateDetection();
testEosDetectorWithPadding();
testEosDetectorWithLongPadding();
testEosDetectorWithoutPadding();
return 0;
}