llama2.ipynb

This is a direct python translation of Andrej Karpathy's llama2.c: https://github.com/karpathy/llama2.c/tree/master

The purpose is learning, and to show that "AI" is nothing to fear. This couldn't possibly harm you, nor could it take over the world or cause human extinction or anything like that. It's just a computer program. People claiming this is something to fear are mistaken.

This is (eventually) meant to show a non-specialist audience what "AI" really looks like.

This is the same model (only smaller) that was recently released by Meta (Facebook), and it uses a similar architecture to modern language models like GPT. It uses weights that were trained on a dataset of children's stories, so it only writes short children's stories. You can go to the link above to learn more about it.

The program being in a jupyter notebook makes it easy to play with, but means the actual loop where the program runs is at the bottom. The upper part takes care of setting up the model and reading in the weights. Scroll to the bottom to see example output.

I translated this from the C program (which used explicit loops for all the linear algebra) and have made only minimal changes so far, just enough to get it working. I may clean it up a bit.

In [1]:
# Run the commands below to download the model (60 MB) and vocabulary if you don't have them

#!wget https://github.com/karpathy/llama2.c/raw/master/tokenizer.bin
#!wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin
In [2]:
# dependencies
# np makes matrix multiplication and manipulation easier, struct is for reading in the weights file,
# sys is for printing. There are no dependencies on any machine learning frameworks.

import numpy as np
import struct
import sys
In [3]:
# helper function to read a block of weights from the raw model file
# size is the shape of the array, offset is the byte position in the buffer,
# and b is the number of bytes per value (4 for float32)

def get_weights(buffer, size, offset, b=4):
    w_size = np.prod(size)
    w = np.array(struct.unpack(f"{w_size}f", buffer[offset:(offset+w_size*b)]))
    w = w.reshape(size)
    return w, offset+w_size*b
In [4]:
# the example model is a small LLM trained by Andrej Karpathy on a dataset of children's stories
# since the model is already trained, this code just reads it in
# the weights in the model file are used as multiplying factors in the model

model = "stories15M.bin"

with open(model, mode='rb') as file: # b is important -> binary
    fileContent = file.read()

# the model file begins with 7 integers of configuration data
# dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len
config = struct.unpack("iiiiiii", fileContent[:28])
dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len = config

# the sizes of the different types of layer, to be used in reading them in
emb_size = (vocab_size,dim)
rmsnorm_size = (n_layers, dim)
layer_size = (n_layers,dim,dim)
ffn_13_size = (n_layers,hidden_dim,dim) 
ffn_2_size = (n_layers, dim, hidden_dim) 
rmsfinal_size = (dim,)
head_size = dim // n_heads
cis_size = (seq_len, head_size//2)

# read in weights and use them to create numpy arrays for use in multiplication
weight_dict = {}
token_embedding_table, offset = get_weights(fileContent,emb_size,28) 
weight_dict['rms_att_weight'], offset = get_weights(fileContent,rmsnorm_size,offset) 
weight_dict['wq'], offset = get_weights(fileContent,layer_size,offset) 
weight_dict['wk'], offset = get_weights(fileContent,layer_size,offset)
weight_dict['wv'], offset = get_weights(fileContent,layer_size,offset)
weight_dict['wo'], offset = get_weights(fileContent,layer_size,offset)
weight_dict['rms_ffn_weight'], offset = get_weights(fileContent,rmsnorm_size,offset) 
weight_dict['w1'], offset = get_weights(fileContent,ffn_13_size,offset)
weight_dict['w2'], offset = get_weights(fileContent,ffn_2_size,offset)
weight_dict['w3'], offset = get_weights(fileContent,ffn_13_size,offset)
weight_dict['rms_final_weight'], offset = get_weights(fileContent,rmsfinal_size,offset)
weight_dict['freq_cis_real'], offset = get_weights(fileContent,cis_size,offset)
weight_dict['freq_cis_imag'], offset = get_weights(fileContent,cis_size,offset)
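A quick sanity check (my addition, not part of the original C program) is to look at how much of the file was consumed and at the shape of the embedding table:

# optional sanity check: how much of the file has been read,
# and whether the embedding table has the expected (vocab_size, dim) shape
print(f"read {offset} of {len(fileContent)} bytes")
print("token embedding table shape:", token_embedding_table.shape)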
In [5]:
# helper functions 

# normalize a vector by its root mean square (with a small epsilon for stability) and multiply by weights
def rmsnorm(x,w):
    xn = np.sqrt(np.dot(x,x)/x.shape[0]+1e-5)
    return x*w/xn

# convert scores to probabilities; only the first size entries are used, the rest are set to zero
def softmax(x,size):
    
    mv = max(x[:size])
    xi = np.exp(x[:size]-mv)
    
    return np.concatenate([xi/sum(xi),0*x[size:]])
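
For reference, the rmsnorm helper computes $$\mathrm{rmsnorm}(x, w) = \frac{x \odot w}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}$$ where $d$ is the length of the vector and $\epsilon = 10^{-5}$ keeps the division stable. The softmax helper normalizes only the first size entries of its input and zeroes the rest, which is how attention is later restricted to the positions seen so far.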

Transformer

In the cell below is the "transformer" that everybody is talking about; this is what created the new hype wave in AI. It's essentially a lookup of the next word based on past words, with some standard tricks from deep learning thrown in and parameters that can be optimized to make the lookup match well with real sentences.

It's not necessary to understand it all, just that it's a short computer program that usefully processes input text to determine a likely next word.

I recommend this blog post for a better description of how it works: https://jaykmody.com/blog/gpt-from-scratch/

Bottom line, "attention" is what performs the lookup $$A = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ where in an LLM $Q$, $K$ and $V$ are projections of the input text. The rest is just some programming around this equation to make it happen in a way that works robustly in a big deep learning model.
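
To make that formula concrete, here is a rough, self-contained NumPy sketch of scaled dot-product attention. It is only an illustration (and skips the causal masking the real model applies): the names Q, K and V are hypothetical stand-ins for the projections computed inside the transformer function below.

import numpy as np

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V are assumed to be (sequence_length, d_k) arrays
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                          # similarity of every query with every key
    scores = scores - scores.max(axis=-1, keepdims=True)     # subtract the max for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax: each row becomes a probability distribution
    return weights @ V                                       # weighted average of the values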

In [6]:
# token is the current token (a word or piece of a word)
# pos is the position within the text
# s is the state: it caches the keys and values from past positions
# w is the dictionary of weights
# the token embedding table converts tokens to vectors for processing in the transformer
def transformer(token, pos, config, s, w, token_embedding_table=token_embedding_table):
    
    #just unpack the config
    dim, hidden_dim, n_layers, n_heads, n_kv_heads, vocab_size, seq_len = config
    head_size = dim // n_heads
    
    # convert the word to a vector
    x = token_embedding_table[token]
    
    # for encoding the position
    freq_cis_real_row = weight_dict["freq_cis_real"][pos]
    freq_cis_imag_row = weight_dict["freq_cis_imag"][pos]
    
    for l in range(n_layers): # probably a better way...
        
        # normalize, then a bunch of matrix multiplications to mix in the weights
        xb = rmsnorm(x, weight_dict["rms_att_weight"][l])
        
        # these are the vectors for the key-value lookup ("attention")
        q = np.matmul(xb,weight_dict["wq"][l].T)
        k = np.matmul(xb,weight_dict["wk"][l].T)
        v = np.matmul(xb,weight_dict["wv"][l].T)
        

        # position encoding (RoPE: rotary position embeddings)
        for h in range(n_heads):
            for i in range(0,head_size,2): # n_heads is 6, head_size is 48, total 288
                 
                q0 = q[h*head_size+i]  # just copying the C code; a vectorized version is sketched after this cell
                q1 = q[h*head_size+i+1]
                k0 = k[h*head_size+i]
                k1 = k[h*head_size+i+1]
                fcr = freq_cis_real_row[i//2]
                fci = freq_cis_imag_row[i//2]
                q[h*head_size+i] = q0 * fcr - q1 * fci
                q[h*head_size+i+1] = q0 * fci + q1 * fcr
                k[h*head_size+i] = k0 * fcr - k1 * fci
                k[h*head_size+i+1] = k0 * fci + k1 * fcr
 
        # save k and v in the cache so that later positions can attend back to this position
        s["key_cache"][l][pos] = k
        s["value_cache"][l][pos] = v
    
        # attention: each head looks back over the positions seen so far
        xb = np.zeros(dim)
        for h in range(n_heads):
    
            q_t = q[h*head_size:(h+1)*head_size] # the slice of q belonging to head h
        
            for t in range(pos+1):
        
                k_t = s["key_cache"][l][t][h*head_size:(h+1)*head_size]
                score = np.dot(q_t,k_t)/np.sqrt(head_size)
                s["att"][h][t] = score
            
                
            s["att"][h] = softmax(s["att"][h],pos+1)
            
            xbh = np.zeros(head_size)
            
            for t in range(pos+1):
                
                v_t = s["value_cache"][l][t][h*head_size:(h+1)*head_size]
              
                a = s["att"][h][t]
                xbh += a*v_t
               
        
            xb[h*head_size:(h+1)*head_size] = xbh
            

        # project the attention output back with more weights, then add the residual connection
        xb2 = np.matmul(xb, weight_dict["wo"][l].T)
        x = x + xb2
        
        # feed-forward network (SwiGLU), again preceded by a normalization
        xb = rmsnorm(x,weight_dict["rms_ffn_weight"][l])
        
        hb = np.matmul(xb, weight_dict["w1"][l].T)
        hb2 = np.matmul(xb, weight_dict["w3"][l].T)
        
        # SiLU (swish) activation: x * sigmoid(x)
        hb = hb*(1/(1+np.exp(-hb)))
        
        hb = hb*hb2
        
        xb = np.matmul(hb, weight_dict["w2"][l].T)
        
        # residual connection
        x += xb


    # final normalization, then a projection onto the vocabulary to get a score (logit) for every possible next token
    x = rmsnorm(x, weight_dict["rms_final_weight"])
    logits = np.matmul(x, token_embedding_table.T) # the embedding table is reused as output weights (shared weights)
    
    
    return x, logits
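
As the comment inside the RoPE loop says, the pairwise rotation can be vectorized instead of looping over heads and pairs. A possible sketch is below, kept out of the function above to stay close to the C original; it assumes a vector v of length dim (such as q or k) and the freq_cis_real / freq_cis_imag rows for the current position.

# a sketch of a vectorized RoPE rotation (not used above)
# v has shape (dim,); fcr and fci have shape (head_size//2,)
def rope_rotate(v, fcr, fci, n_heads=n_heads, head_size=head_size):
    pairs = v.reshape(n_heads, head_size//2, 2)   # split into heads and (even, odd) pairs
    v0, v1 = pairs[..., 0], pairs[..., 1]
    out = np.empty_like(pairs)
    out[..., 0] = v0*fcr - v1*fci                 # real part of the complex rotation
    out[..., 1] = v0*fci + v1*fcr                 # imaginary part
    return out.reshape(-1)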
   
        
        
In [7]:
# sample an index from a vector of probabilities, in proportion to each probability
# (NumPy has a built-in for this; see the note after this cell)
def sample(p):
    r = np.random.random()
    cdf = 0
    for i, v in enumerate(p):
        cdf+=v
        if r<cdf: return i
        
    return len(p)-1
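
As the comment above hints, NumPy has a built-in that samples an index in proportion to a probability vector; an equivalent version (not used above) would be:

# equivalent sampling with NumPy's built-in; p is assumed to be a 1-D vector of probabilities summing to 1
def sample_np(p):
    return np.random.choice(len(p), p=p)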
In [8]:
# the vocabulary (tokenizer) data maps between tokens and numbers; this helper reads one entry: a score, a length, and the token string
def read_token(vc, p):
    # read a float score
    score = struct.unpack("f",vc[p:(p+4)])
    # read length
    tok_len = struct.unpack("i",vc[(p+4):(p+8)])[0]
    #print(tok_len)
    # read the string
    token = struct.unpack(f"{tok_len}s",vc[(p+8):(p+8+tok_len)])
    return score[0], token[0].decode(), p+8+tok_len

# file that was downloaded
vocab_file = "tokenizer.bin"

with open(vocab_file, mode='rb') as file: # b is important -> binary
    vc = file.read()

# int for max_token_length
max_token_length = struct.unpack("i",vc[:4])[0]


vocab = []
vocab_scores = []
p = 4
for i in range(vocab_size):
    s, t, p = read_token(vc,p)
    vocab.append(t)
    vocab_scores.append(s)
In [9]:
# look up a token's position in the vocabulary (a linear scan; a dict-based version is sketched after this cell)
def lookup(s,vocab=vocab):
    for i, word in enumerate(vocab):
        if s==word: return i
    return -1

# encode a text string into tokens (words or parts of words)
def bpe_encode(text, vocab=vocab, vocab_scores=vocab_scores):
    
    tokens = [lookup(c) for c in text]
    
    while(True):
        
        running_merge = [vocab[tokens[i]]+vocab[tokens[i+1]] for i in range(len(tokens)-1)]
               
        best_id = -1
        best_score = -1e10
        
        for i, m in enumerate(running_merge):
            ind = lookup(m)
            if ind>0:
                score = vocab_scores[ind]
                if score > best_score:
                    best_score = score
                    best_id = i
        
        #print("Best id", best_id)
        #print("token there:", running_merge[best_id])
        #print("position in vocab:", lookup(running_merge[best_id]))
        
        if best_id == -1: break
        
        tokens[best_id] = lookup(running_merge[best_id])
        
        tokens.pop(best_id+1)
        #print 
    
    return tokens
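
The lookup function above scans the whole 32,000-entry vocabulary for every candidate merge. A dict makes each lookup constant time; a possible replacement (not used above, to keep the translation close to the C code, and assuming the token strings are unique) is:

# build an index once, then look tokens up in constant time
vocab_index = {word: i for i, word in enumerate(vocab)}

def lookup_fast(s):
    return vocab_index.get(s, -1)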
In [10]:
# initialize the cache
state_dict = {}
state_dict["key_cache"] = np.zeros((n_layers,seq_len,dim))
state_dict["value_cache"] = np.zeros((n_layers,seq_len,dim))
state_dict["att"] = np.zeros((n_heads,seq_len)) # could be dropped and moved inside the transformer
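
For a sense of scale: with this model's configuration (dim of 288, seq_len of 256, and 6 layers for the 15M checkpoint), each cache holds 6 × 256 × 288 ≈ 0.44 million numbers, so the two caches together take roughly 7 MB in NumPy's default float64. The cache lets each new position reuse the keys and values of all earlier positions instead of recomputing them.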
In [12]:
# this is where the model runs. 

# inputs
temperature = 1 # zero makes it deterministic, bigger numbers increase the variability
steps = 256 # number of tokens to generate (should not exceed seq_len, the model's maximum context)
prompt = "In an old house"

# start with token 1 which is the "beginning of sentence" token and at position zero
token = 1
pos = 0

prompt_tokens = bpe_encode(prompt)

while(pos < steps):
    
    # at each step, the next token is calculated from the previous ones
    x, logits = transformer(token, pos, config, state_dict, weight_dict)
    
    if pos < len(prompt_tokens):
        
        next_t = prompt_tokens[pos]
        
    else:
    
        if temperature == 0:
        
            next_t = np.argmax(logits)
        
        else:
            
            logits = softmax(logits/temperature,len(logits))
            next_t = sample(logits)
        
    token_str = vocab[next_t]
    print(token_str,end="")
    sys.stdout.flush()
    
    token = next_t
    pos+=1
    
    
    
In an old house, a little girl named Amy was playing with her toys. She loved to play all day long. Her mom called her from the house, "Amy, come and help me repair your broken fork."
Amy went to her mom and said, "Mom, I can't fix your fork. It's broken." Her mom looked at the fork and said, "Amy, I can fix it for you. You need to improve your work by being a good helper."
Amy practiced being a good girl all day. She helped her mom sort the dishes, cleaned the floor, and even helped her dad in the garden. As she worked, she remembered to be good at fixing things. She kept adding to her fix skills until it was time for dinner. When dinner was ready, Amy's family was so happy. They all sat down and ate together. The fixers were so proud of Amy that they gave her a big hug. Amy felt happy that she could help and improve her family's home. The end.
<s>
 Once upon a time, there was a little girl named Lily. One night, she was scared because she heard a monster in her dream

Remarks

You can see example output above. This model has 15 million parameters (weights), so it is relatively simple, but it still writes a coherent story by predicting next words. Models like GPT-4 have hundreds of billions or trillions of parameters, which is why they appear so coherent and real. Architecturally they are the same as this program, just scaled up to have a better model of the language.

In [ ]: