Skip to content

Commit

Permalink
Support Llama3 8bit quantized inference
Browse files Browse the repository at this point in the history
runq - add llama3 support
  • Loading branch information
trholding committed Jul 12, 2024
1 parent 63e69a3 commit e893f18
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 36 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ Learn more about the Llama2 models & architecture at Meta: [Llama 2 @ Meta](http
Llama3 models work now.

* Non quantized (fp32) is supported. run supports both llama2 and llama3 with -l 3 option.
* Quantized inference will be supported soon. Right now runq supports only llama2.
* Quantized inference with runq supported now.
* Known issues - chat mode doesn't work yet, fix coming soonish

First you'll need to obtain approval from Meta to download llama3 models on hugging face.

Expand Down
143 changes: 108 additions & 35 deletions runq.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@

int buffertokens = 1; // output token buffer size
int stats = 1; // extended status info
int llamaver = 2; // llama version (default is 2, valid 2 & 3)
float rope_sf = 10000.0; // Rope scaling factor, 10000.0 => llama2, 500000.0 > llama3
int BOS = 1; // Beginning of Sentence token value, llama2 = 1 , llama3 = 128000
int EOS = 2; // End of Sentence token value, llama2 = 2 , llama3 = 128009 (end of text)
char system_template[1024]="";
char user_template[1024]="";

// ----------------------------------------------------------------------------
// L2E Humanoid : Linux Kernel Support Directives
Expand Down Expand Up @@ -350,8 +356,8 @@ void memory_map_weights(TransformerWeights *w, Config* p, void* ptr, uint8_t sha

// L2E Addition
#if defined (INC_BIN) || defined(STRLIT)
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
int* fd, float** data, ssize_t* file_size) {
void read_checkpoint(char* checkpoint, Config* config, TransformerWeights* weights,
int* fd, float** data, ssize_t* file_size) {
// Calculate the file size from the raw data
*file_size = strlen(checkpoint);

Expand Down Expand Up @@ -568,7 +574,9 @@ float* forward(Transformer* transformer, int token, int pos) {
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
// L2E Addition
float freq = 1.0f / powf(rope_sf, head_dim / (float)head_size);
// END L2E Addition
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
Expand Down Expand Up @@ -767,8 +775,10 @@ void free_tokenizer(Tokenizer* t) {

char* decode(Tokenizer* t, int prev_token, int token) {
char *piece = t->vocab[token];
// following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == 1 && piece[0] == ' ') { piece++; }
// L2E Addition
// following BOS (1) or (2) token, sentencepiece decoder strips any leading whitespace (see PR #89)
if (prev_token == BOS && piece[0] == ' ') { piece++; }
// END L2E Addition
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
Expand Down Expand Up @@ -801,7 +811,7 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {

void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
// bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
// bos != 0 means prepend the BOS token, eos != 0 means append the EOS token
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }

if (t->sorted_vocab == NULL) {
Expand All @@ -822,17 +832,25 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
// start at 0 tokens
*n_tokens = 0;

// add optional BOS (=1) token, if desired
if (bos) tokens[(*n_tokens)++] = 1;
// L2E Addition
// add optional BOS token, if desired
if (bos) tokens[(*n_tokens)++] = BOS;
// END L2E Addition


// add_dummy_prefix is true by default
// so prepend a dummy prefix token to the input string, but only if text != ""
// TODO: pretty sure this isn't correct in the general case but I don't have the
// energy to read more of the sentencepiece code to figure out what it's doing
if (text[0] != '\0') {
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
tokens[(*n_tokens)++] = dummy_prefix;

// L2E Addition
if (llamaver == 2) {
if (text[0] != '\0') {
int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
tokens[(*n_tokens)++] = dummy_prefix;
}
}
// END L2E Addition

// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
// Code point ↔ UTF-8 conversion
Expand Down Expand Up @@ -883,13 +901,16 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
}

// merge the best consecutive pair each iteration, according the scores in vocab_scores
// L2E Addition
// merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
while (1) {
float best_score = -1e10;
int best_id = -1;
int best_idx = -1;
int best_merge = 0; // length of the best merge sequence (2 for pair, 3 for triple)

for (int i=0; i < (*n_tokens-1); i++) {
// try to find the best pair or triple to merge
for (int i = 0; i < (*n_tokens - 1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
Expand All @@ -898,27 +919,43 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
best_merge = 2;
}

// check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
if (i < (*n_tokens - 2)) {
sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
if (id != -1 && t->vocab_scores[id] > best_score) {
// this merge triple exists in vocab! record its score and position
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
best_merge = 3;
}
}
}

if (best_idx == -1) {
break; // we couldn't find any more pairs to merge, so we're done
break; // we couldn't find any more pairs or triples to merge, so we're done
}

// merge the consecutive pair (best_idx, best_idx+1) into new token best_id
// merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
tokens[best_idx] = best_id;
// delete token at position best_idx+1, shift the entire sequence back 1
for (int i = best_idx+1; i < (*n_tokens-1); i++) {
tokens[i] = tokens[i+1];
// delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
for (int i = best_idx + 1; i < (*n_tokens - best_merge + 1); i++) {
tokens[i] = tokens[i + best_merge - 1];
}
(*n_tokens)--; // token length decreased
(*n_tokens) -= (best_merge - 1); // token length decreased by the number of merged tokens minus one
}

// add optional EOS (=2) token, if desired
if (eos) tokens[(*n_tokens)++] = 2;
// add optional EOS token, if desired
if (eos) tokens[(*n_tokens)++] = EOS;

free(str_buffer);

}
// END L2E Addition

// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
Expand Down Expand Up @@ -1118,9 +1155,11 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
next = sample(sampler, logits);
}
pos++;

// data-dependent terminating condition: the BOS (=1) token delimits sequences
if (next == 1) { break; }

// L2E Addition
// data-dependent terminating condition: the BOS token delimits sequences
if (next == BOS) { break; }
// END L2E Addition

// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
Expand Down Expand Up @@ -1170,18 +1209,46 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,

// buffers for reading the system prompt and user prompt from stdin
// you'll notice they are soomewhat haphazardly and unsafely set atm
// L2E Addition
char system_prompt[512];
char user_prompt[512];
char rendered_prompt[1152];
char rendered_prompt[2048];
int num_prompt_tokens = 0;
int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
int* prompt_tokens = (int*)malloc(2048 * sizeof(int));
// END L2E Addition
int user_idx;

// start the main loop
int8_t user_turn = 1; // user starts
int next; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
int prev_token;
// L2E Addition
/* System and user prompt templates for llama 2 and llama 3
Llama 2:
System:
[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]
User:
[INST] %s [/INST]
Llama 3:
System:
<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>
User:
<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n
Assistant: (Starts Generating)
<|start_header_id|>assistant<|end_header_id|>\n\n
*/
if (llamaver == 3) {
BOS = 128000; // 128000 = <|begin_of_text|>
EOS = 128009; // 128009 = <|eot_id|> , 128001 = <|end_of_text|>
strcpy(system_template, "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>");
strcpy(user_template, "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n");
} else {
int prev_token;
strcpy(system_template,"[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]");
strcpy(user_template, "[INST] %s [/INST]");
}
// END L2E Addition
int pos = 0; // position in the sequence
while (pos < steps) {

Expand All @@ -1206,14 +1273,14 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// otherwise get user prompt from stdin
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}
// L2E Addition
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
char system_template[] = "[INST] <<SYS>>\n%s\n<</SYS>>\n\n%s [/INST]";
sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
} else {
char user_template[] = "[INST] %s [/INST]";
sprintf(rendered_prompt, user_template, user_prompt);
}
// END L2E Addition
// encode the rendered prompt into tokens
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index
Expand All @@ -1229,22 +1296,25 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// otherwise use the next token sampled from previous turn
token = next;
}
// EOS (=2) token ends the Assistant turn
if (token == 2) { user_turn = 1; }
// L2E Addition
// EOS token ends the Assistant turn
if (token == EOS) { user_turn = 1; }
// End L2E Addition

// forward the transformer to get logits for the next token
float* logits = forward(transformer, token, pos);
next = sample(sampler, logits);
pos++;

if (user_idx >= num_prompt_tokens && next != 2) {
// L2E Addition
if (user_idx >= num_prompt_tokens && next != EOS) {
// the Assistant is responding, so print its output
char* piece = decode(tokenizer, token, next);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
}
if (next == 2) { printf("\n"); }
if (next == EOS) { printf("\n"); }
}
// End L2E Addition
printf("\n");
free(prompt_tokens);
}
Expand Down Expand Up @@ -1283,7 +1353,8 @@ void error_usage() {
fprintf(stderr, " -y <string> (optional) system prompt in chat mode\n");
// L2E Addition
fprintf(stderr, " -b <int> number of tokens to buffer, default 1. 0 = max_seq_len\n");
fprintf(stderr, " -x <int> extended info / stats, default 1 = on. 0 = off\n");
fprintf(stderr, " -x <int> extended info / stats, default 1 = on. 0 = off\n");
fprintf(stderr, " -l <int> llama version / default 2 = llama2. 3 = llama3\n");
// END L2E Addition
exit(EXIT_FAILURE);
}
Expand Down Expand Up @@ -1344,9 +1415,11 @@ int main(int argc, char *argv[]) {
// L2E Addition
else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); }
else if (argv[i][1] == 'x') { stats = atoi(argv[i + 1]); }
else if (argv[i][1] == 'l') { llamaver = atoi(argv[i + 1]); }
// END L2E Addition
else { error_usage(); }
}
if (llamaver == 3){ rope_sf = 500000.0; }
// L2E Addition
#endif
// END L2E Addition
Expand Down

0 comments on commit e893f18

Please sign in to comment.