Fine tune Llama 2 with the RTX 4090 and serve it with Nx

Published October 15, 2023 by Toran Billups

You can fine tune with Bumblebee, but large models like Llama 2 require more than 100GB of vRAM to fine tune with full precision. To fine tune efficiently on a single RTX 4090 with only 24GB of vRAM, I reached for a Python project called lit-gpt. This let me fine tune on local hardware, which offers several advantages, most notably keeping proprietary data away from third-party cloud providers like OpenAI.
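
A rough sketch of where those numbers come from, assuming a 7B-parameter model and the usual cost of full-precision Adam training (fp32 weights, gradients, and two optimizer moments, about 16 bytes per parameter) versus 4-bit quantized base weights. These are back-of-the-envelope approximations that ignore activations and framework overhead:

```python
# Back-of-the-envelope VRAM estimates for a 7B-parameter model.
# Approximations only: activations, KV caches, and framework
# overhead are ignored.

PARAMS = 7e9  # Llama 2 7B

def gib(n_bytes):
    return n_bytes / 1024**3

# Full-precision Adam training: fp32 weights (4 B) + fp32 gradients (4 B)
# + two fp32 optimizer moments (8 B) = ~16 bytes per parameter.
full_precision_train = gib(PARAMS * 16)

# 4-bit (nf4) quantized base weights: ~0.5 bytes per parameter.
nf4_weights = gib(PARAMS * 0.5)

print(f"full-precision training: ~{full_precision_train:.0f} GiB")
print(f"nf4 base weights:        ~{nf4_weights:.1f} GiB")
```

The first number lands comfortably above 100GB, which is why full-precision fine tuning is out of reach here, while 4-bit weights leave plenty of the 4090's 24GB for adapters and activations.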


The setup for this is fairly straightforward, but I'll detail the steps for anyone who wants to try it out.

    $ git clone https://github.com/Lightning-AI/lit-gpt lit
    $ cd lit
    $ git checkout bf60124fa72a56436c7d4fecc093c7fc48e84433
    $ pip install -r requirements.txt
    $ python3 scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
    $ python3 scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf

Data engineering

Next we need to create a custom dataset to fine tune the model with. For the Q&A use case it will be similar to the Stanford Alpaca example, where questions are labeled `instruction` and answers are labeled `output`.

        "input": "",
        "instruction": "Who is the president of the United States?",
        "output": "Joe Biden is the president of the United States."

Now that we have the instruction data as JSON, we need to copy that file into the lit directory and run a script that prepares the dataset for fine tuning.

    $ mkdir -p data/alpaca
    $ cd data/alpaca
    $ cp ~/somefolder/demo.json .
    $ cd ../../
    $ python3 scripts/prepare_alpaca.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf --data_file_name demo.json

Fine tune Llama 2

With the instruction data split into test and training sets, we can run the LoRA script to fine tune Llama 2. It's worth mentioning that we are not fine tuning with full precision: the base weights are quantized to 4-bit (nf4) and only the low-rank adapters are trained. The tradeoff is some precision, but we can fine tune on a single RTX 4090 in about 3 hours.

    $ python3 finetune/lora.py --data_dir data/alpaca --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf --precision bf16-true --quantize bnb.nf4
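
The LoRA idea itself can be sketched in a few lines of NumPy. A frozen weight matrix `W` gets a trainable update `B @ A` of small rank `r`, and merging the weights afterwards just folds that product back into `W`. The dimensions and hyperparameters below are illustrative, not the values lit-gpt uses:

```python
import numpy as np

rng = np.random.default_rng(0)

d_out, d_in, r = 64, 64, 8  # r << d is the low-rank bottleneck
alpha = 16                  # LoRA scaling hyperparameter

W = rng.normal(size=(d_out, d_in))     # frozen pretrained weight
A = rng.normal(size=(r, d_in)) * 0.01  # trainable, rank r
B = np.zeros((d_out, r))               # trainable, initialized to zero

x = rng.normal(size=(d_in,))

# During fine tuning: the frozen base output plus the scaled
# low-rank update. Only A and B receive gradients.
y_lora = W @ x + (alpha / r) * (B @ (A @ x))

# Merging afterwards folds the adapter into W, so inference
# needs no extra matmuls.
W_merged = W + (alpha / r) * (B @ A)
y_merged = W_merged @ x

assert np.allclose(y_lora, y_merged)
```

Because only `A` and `B` are trained, the optimizer state covers a tiny fraction of the parameters, which is what makes the 24GB budget workable.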

After the fine tuning process completes, we need to merge the LoRA weights back into the base model.

    $ mkdir -p out/lora_merged/Llama-2-7b-chat-hf
    $ python3 scripts/merge_lora.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf --lora_path out/lora/alpaca/lit_model_lora_finetuned.pth --out_dir out/lora_merged/Llama-2-7b-chat-hf

To run this model we first need to copy over a few files from the original model.

    $ cd out/lora_merged/Llama-2-7b-chat-hf
    $ cp ~/lit/checkpoints/meta-llama/Llama-2-7b-chat-hf/tokenizer.model .
    $ cp ~/lit/checkpoints/meta-llama/Llama-2-7b-chat-hf/*.json .


Before we serve the model with Nx it's important to evaluate it first. One intuitive way to check whether the model has learned anything is to run the chat script and ask it a question.

    $ pip install sentencepiece
    $ python3 chat/base.py --checkpoint_dir out/lora_merged/Llama-2-7b-chat-hf

Serving with Nx

Now that we have a working model, we need to copy over two weight files and copy lit_config.json to config.json so Bumblebee can find it.

    $ cd out/lora_merged/Llama-2-7b-chat-hf
    $ cp ~/lit/checkpoints/meta-llama/Llama-2-7b-chat-hf/pytorch_model-00001-of-00002.bin .
    $ cp ~/lit/checkpoints/meta-llama/Llama-2-7b-chat-hf/pytorch_model-00002-of-00002.bin .
    $ cp lit_config.json config.json

To test this end to end we point Nx at the file system instead of pulling Llama 2 from Hugging Face.

    def serving() do
      llama = {:local, "/home/toranb/lit/out/lora_merged/Llama-2-7b-chat-hf"}
      {:ok, spec} = Bumblebee.load_spec(llama, module: Bumblebee.Text.Llama, architecture: :for_causal_language_modeling)
      {:ok, model_info} = Bumblebee.load_model(llama, spec: spec, backend: {EXLA.Backend, client: :host})
      {:ok, tokenizer} = Bumblebee.load_tokenizer(llama, module: Bumblebee.Text.LlamaTokenizer)
      {:ok, generation_config} = Bumblebee.load_generation_config(llama, spec_module: Bumblebee.Text.Llama)
      generation_config = Bumblebee.configure(generation_config, max_new_tokens: 500)
      Bumblebee.Text.generation(model_info, tokenizer, generation_config, defn_options: [compiler: EXLA])
    end

Next you can wire this up in your application.ex

    def start(_type, _args) do
      children = [
        {Nx.Serving, serving: serving(), name: ChatServing}
      ]

      Supervisor.start_link(children, strategy: :one_for_one)
    end

And finally, you can prompt the model from Elixir code with Nx.Serving.

    Nx.Serving.batched_run(ChatServing, prompt)
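
Because the model was fine tuned on Alpaca-formatted records, prompts at inference time should match that shape. A sketch of the template, approximating the one lit-gpt's prepare script wraps around each record (the exact wording may differ from your checkout):

```python
# Approximation of the Alpaca-style prompt template used during
# fine tuning; inference prompts should follow the same shape.
TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)

def build_prompt(instruction: str) -> str:
    """Wrap a raw question in the fine-tuning prompt template."""
    return TEMPLATE.format(instruction=instruction)

prompt = build_prompt("Who is the president of the United States?")
print(prompt)
```

The string this produces is what you would pass as `prompt` to `Nx.Serving.batched_run/2` above.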
