
Fine-Tune Gemma Models With Custom Data in Keras using LoRA

Hello all, my name is Krish Naik. Welcome to my YouTube channel! In this video, we are going to fine-tune Gemma models in Keras using the LoRA technique. If you have seen my fine-tuning playlist, I've created numerous videos on topics such as fine-tuning, LoRA, QLoRA, quantization, and many more. In this particular video, we will delve into fine-tuning the Gemma model using our own custom data.

Setting Up the Environment

First, you need to complete the setup instructions in the Gemma setup documentation. If you click on the provided link, you can see the entire documentation. The first requirement is an API key. Here's how you can get it:

  1. Go to https://aistudio.google.com and click on Get API Key.
  2. Create the API key by selecting a project and giving it a name. Copy this key for future use.

My API key is already created, so I'll use that one. It's also worth requesting access to Google Gemini 1.5 Pro, available at the same URL.

Setting Up Kaggle and Google Colab

Next, visit https://kaggle.com and request access to the Gemma models. After logging in, consent to the Gemma license and you will be able to access the models. The models are available for JAX, TensorFlow, and PyTorch.

For this setup, I am using a paid Google Colab Pro account to ensure sufficient RAM for fine-tuning. Complete the following steps:

  1. Create your Kaggle API token and note the username and key it contains.
  2. Optionally, create a Hugging Face token if necessary.
  3. Set your Kaggle username and key in the Colab environment (for example as Colab secrets or environment variables) so KerasNLP can download the model.
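The credential setup above can be sketched as follows. This is a minimal sketch: the placeholder values are hypothetical and must be replaced with your own Kaggle username and API key (in Colab you can instead store them as secrets and read them with google.colab.userdata).

```python
import os

# Placeholder credentials -- replace with your own Kaggle username and API key.
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"

# KerasNLP reads these variables when downloading Gemma weights from Kaggle.
print(os.environ["KAGGLE_USERNAME"])
```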

With the environment set up, we can now install the required libraries and configure the backend:


# Install Keras 3 and KerasNLP (the Gemma presets require Keras 3, not Keras 2.x)
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

import os

# Select the backend before importing Keras (JAX is recommended for Gemma)
os.environ["KERAS_BACKEND"] = "jax"
# Reduce TensorFlow log noise and keep JAX from grabbing all GPU memory at once
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".95"

import keras
import keras_nlp

Importing the Data

To fine-tune the Gemma model, we need a dataset in JSONL format (one JSON object per line). Below is an example record:


{
  "instruction": "What should I do on a trip to Europe?",
  "context": "",
  "response": "First, get a passport and visa. Then plan your trip."
}

Use the above structure, where instruction is the question, context is optional supporting text, and response is the answer. You can create your own dataset or use an open-source one like the Databricks Dolly 15K dataset.
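Since Gemma is a causal language model, each record must be flattened into a single prompt string before training. Here is a minimal sketch of that step; the exact template wording is an assumption, not fixed by Gemma:

```python
import json

# A Dolly-style record (the "response" field holds the answer)
record_json = (
    '{"instruction": "What should I do on a trip to Europe?", '
    '"context": "", '
    '"response": "First, get a passport and visa. Then plan your trip."}'
)
record = json.loads(record_json)

# Flatten instruction and response into one training string
template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
prompt = template.format(**record)
print(prompt)
```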

Loading and Fine-Tuning the Model

We use the following code to load the model from Kaggle and prepare the training data:


import keras_nlp

# Load the Gemma model from Kaggle via KerasNLP
model = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
model.summary()

# Load the data
import json

data = []
with open("databricks-dolly-15k.jsonl", "r") as file:
    for line in file:
        features = json.loads(line)
        # Skip examples that need extra context, to keep the prompts simple
        if features["context"]:
            continue
        # Format each example as a single prompt string for the causal LM
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Use a subset of the data to keep training fast
data = data[:1000]

Enable LoRA with a rank of 4 and set up the training parameters:


import keras

# Enable LoRA (rank 4) on the model backbone
model.backbone.enable_lora(rank=4)

# Limit the input sequence length to control memory use
model.preprocessor.sequence_length = 512

# AdamW is a common choice for fine-tuning transformers
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
# Exclude layernorm and bias terms from weight decay
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

# Compile the model
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train the model
model.fit(data, epochs=1, batch_size=1)

Testing the Fine-Tuned Model

After fine-tuning, we can test the model to ensure it delivers accurate responses:


# Testing the fine-tuned model
prompt = "Instruction:\nWhat should I do on a trip to Europe?\n\nResponse:\n"
print(model.generate(prompt, max_length=256))

# Expected output (close to the training answer):
# "First, get a passport and visa. Then plan your trip."

Similarly, another example:


prompt = (
    "Instruction:\nExplain the process of photosynthesis "
    "in a way a child could understand.\n\nResponse:\n"
)
print(model.generate(prompt, max_length=256))

# Outputs a simplified explanation of photosynthesis.

Conclusion

In this tutorial, we have successfully fine-tuned a Gemma model using Keras and the LoRA technique. You can experiment with different datasets and hyperparameters to improve the model's performance further. Fine-tuning can significantly enhance your model’s capabilities and yield more accurate results.

Call-to-action:

Watch the complete tutorial on YouTube for detailed steps and a deeper understanding.