How to Build a Local RAG Knowledge Base with Google Gemma 2 2B

Sebastian Petrus
5 min read · Sep 4, 2024


Google’s recent release of Gemma 2 2B has sparked excitement in the AI community: it is a capable language model that runs entirely locally on a range of devices, including Macs. This article walks you through running Gemma 2 2B on your local machine and building a retrieval-augmented generation (RAG) knowledge base on top of it.


Google Gemma 2 2B: Perfect for Local Device RAG?

Google Gemma 2 2B Benchmarks

Consider this: Google reported that Gemma 2 2B scores above Mixtral 8x7B and GPT-3.5-Turbo on the LMSYS Chatbot Arena leaderboard, a striking result for a model with only about 2 billion parameters.

Google’s Gemma 2 2B offers an excellent balance of performance and efficiency, making it ideal for local deployment in RAG systems. Its compact size (2 billion parameters) allows for quick inference while maintaining high-quality outputs.

Setting Up Your RAG System with Gemma 2 2B


Step 1: Local Deployment of Gemma 2 2B

Before diving into RAG, it’s crucial to have Gemma 2 2B running locally. While there are multiple methods to achieve this, Ollama provides a straightforward approach:

  1. Install Ollama from https://ollama.com/
  2. Pull the Gemma 2 2B model:

ollama pull gemma2:2b

  3. Run the model:

ollama run gemma2:2b

This setup allows you to interact with Gemma 2 2B directly on your local machine, ensuring privacy and reducing latency.
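Beyond the interactive prompt, Ollama also serves a local HTTP API (by default at http://localhost:11434), so you can call the model from your own scripts. The sketch below uses only the Python standard library to show roughly what such a call looks like. The helper names build_generate_request and ask_gemma are illustrative, and the code assumes the Ollama server is running with gemma2:2b already pulled.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default local endpoint


def build_generate_request(prompt, model="gemma2:2b"):
    """Build (but do not send) a non-streaming generate request."""
    payload = {"model": model, "prompt": prompt, "stream": False}
    return urllib.request.Request(
        OLLAMA_URL,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask_gemma(prompt, model="gemma2:2b", timeout=120):
    """Send the prompt to the local Ollama server and return the generated text."""
    request = build_generate_request(prompt, model)
    with urllib.request.urlopen(request, timeout=timeout) as reply:
        body = json.loads(reply.read().decode("utf-8"))
    return body["response"]
```

With the server running, ask_gemma("In one sentence, what is RAG?") should return the model's reply as a plain string. Setting "stream" to False asks for a single JSON object instead of a stream of partial responses, which keeps the parsing simple.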

Step 2: Implementing the RAG Architecture

A. Choose a Vector Database

For our RAG system, we’ll use Marqo as our vector database. Marqo excels in creating and managing vector embeddings, crucial for efficient information retrieval.

  1. Install Marqo following the official documentation.
  2. Set up a Marqo instance to host your vector embeddings.

B. Prepare Your Knowledge Base

  1. Collect relevant documents, articles, or data sources that will form your knowledge base.
  2. Preprocess the data, ensuring it’s in a format suitable for indexing (e.g., plain text, JSON).
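Part of preprocessing is usually splitting long documents into smaller chunks, so that each retrieved passage fits comfortably into the prompt of a small model. Here is a minimal sketch of word-based chunking with overlap. The function name chunk_text and the default sizes are assumptions for illustration, not values recommended by Marqo or Google; "_id" is the field Marqo uses as the document identifier.

```python
def chunk_text(text, doc_id, max_words=200, overlap=40):
    """Split text into overlapping word windows, one dict per chunk."""
    if not 0 <= overlap < max_words:
        raise ValueError("overlap must be >= 0 and smaller than max_words")
    words = text.split()
    step = max_words - overlap
    chunks = []
    for start in range(0, len(words), step):
        window = words[start:start + max_words]
        chunks.append({
            "_id": f"{doc_id}-{len(chunks)}",
            "content": " ".join(window),
        })
        if start + max_words >= len(words):
            break  # the last window already reached the end of the text
    return chunks
```

The overlap repeats a few words between neighboring chunks, so a sentence that straddles a boundary still appears intact in at least one of them. The returned dictionaries can be passed to the indexing step as documents.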

C. Index Your Data with Marqo

  1. Use Marqo’s API to create embeddings of your knowledge base:
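import marqo

# Connect to the local Marqo instance (8882 is Marqo's default port)
mq = marqo.Client(url="http://localhost:8882")

# Create an index
mq.create_index("my_knowledge_base")

# Add documents to the index; Marqo uses "_id" as the document identifier
documents = [
    {"_id": "1", "content": "Your document text here"},
    {"_id": "2", "content": "Another document text"},
]
# tensor_fields tells Marqo which fields to embed for vector search
mq.index("my_knowledge_base").add_documents(documents, tensor_fields=["content"])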

This process converts your text data into vector representations, enabling efficient similarity searches.
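Marqo does the embedding and the search for you, but it helps to see what "similarity" means here. The following sketch ranks documents by cosine similarity in plain Python. The three-dimensional vectors are made up for illustration (real embeddings typically have hundreds of dimensions), and the function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # define similarity with a zero vector as 0
    return dot / (norm_a * norm_b)


def rank_by_similarity(query_vector, document_vectors):
    """Return document ids ordered from most to least similar to the query."""
    return sorted(
        document_vectors,
        key=lambda doc_id: cosine_similarity(query_vector, document_vectors[doc_id]),
        reverse=True,
    )


# Toy vectors, invented for this example
toy_vectors = {
    "paris_doc": [0.9, 0.1, 0.0],
    "cooking_doc": [0.0, 0.2, 0.9],
}
toy_query = [1.0, 0.0, 0.1]
print(rank_by_similarity(toy_query, toy_vectors))  # ['paris_doc', 'cooking_doc']
```

Documents whose vectors point in nearly the same direction as the query vector rank first, regardless of vector length.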

D. Develop the RAG Pipeline

Create a Python script that integrates Gemma 2 2B with Marqo:

import marqo
from ollama import Client

mq = marqo.Client()
ollama_client = Client()

def rag_query(user_query, top_k=3):
    # Retrieve relevant documents
    results = mq.index("my_knowledge_base").search(user_query, limit=top_k)

    # Construct context from retrieved documents
    context = " ".join([result['content'] for result in results['hits']])

    # Prepare prompt for Gemma 2 2B
    prompt = f"Context: {context}\n\nQuestion: {user_query}\n\nAnswer:"

    # Generate response using Gemma 2 2B
    response = ollama_client.generate(model="gemma2:2b", prompt=prompt)

    return response['response']

# Example usage
query = "What is the capital of France?"
answer = rag_query(query)
print(answer)

This pipeline:

  1. Takes a user query
  2. Retrieves relevant information from the Marqo index
  3. Constructs a prompt combining the retrieved information and the query
  4. Feeds this to Gemma 2 2B
  5. Returns the generated response
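Step 3 is where small models benefit most from care. A minimal sketch of a slightly more defensive prompt builder follows: it numbers the retrieved passages, caps the context length, and tells the model what to do when nothing relevant was found. The function name build_prompt, the instruction wording, and the 4000-character cap are assumptions for illustration.

```python
def build_prompt(hits, question, field="content", max_chars=4000):
    """Assemble a grounded prompt from retrieved hits (a list of dicts)."""
    if not hits:
        context = "(no relevant documents were found)"
    else:
        numbered = [f"[{i}] {hit[field]}" for i, hit in enumerate(hits, start=1)]
        context = "\n".join(numbered)[:max_chars]  # crude cap on context size
    return (
        "Answer the question using only the context below. "
        "If the context does not contain the answer, say so.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\n\nAnswer:"
    )
```

You could call it as build_prompt(results["hits"], user_query) and pass the result to the model in place of the simpler f-string prompt.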

Enhancing Your RAG System

Fine-tuning Gemma 2 2B

Fine-tuning Gemma 2 2B for your specific domain involves several technical steps:

1. Prepare the dataset:

  • Create a dataset of question-answer pairs relevant to your domain.
  • Format the data as JSON lines, with each line containing “prompt” and “completion” fields.
  • Example format:
{"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."}

2. Set up the fine-tuning environment:

  • Install the necessary libraries: transformers, peft, bitsandbytes, accelerate, and datasets.
  • Load the pre-trained Gemma 2 2B model:
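from transformers import AutoModelForCausalLM, AutoTokenizer

# "google/gemma-2-2b" is the Gemma 2 2B checkpoint ("google/gemma-2b" is the first-generation model)
model_id = "google/gemma-2-2b"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)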

3. Configure LoRA (Low-Rank Adaptation):

  • Use Parameter-Efficient Fine-Tuning (PEFT) to reduce memory requirements:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
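Why does this reduce memory requirements? LoRA freezes each targeted weight matrix W (of shape d_out × d_in) and trains only two small matrices, B (d_out × r) and A (r × d_in), whose product is added to W. The arithmetic below is a rough illustration; the layer size is a hypothetical round number, not Gemma's actual dimensions.

```python
def lora_trainable_params(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


def full_params(d_in, d_out):
    """Parameters in the frozen d_out x d_in weight matrix itself."""
    return d_in * d_out


# Hypothetical layer size chosen for round numbers
d_in, d_out, rank = 2048, 2048, 8
added = lora_trainable_params(d_in, d_out, rank)
frozen = full_params(d_in, d_out)
print(added, frozen, f"{added / frozen:.2%}")  # 32768 4194304 0.78%
```

With r=8, the adapter trains under 1% as many parameters as the matrix it adapts in this example, so gradients and optimizer state are needed only for that small fraction.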

4. Set up the training arguments:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./gemma-2b-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    save_total_limit=3,
)

5. Train the model:

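from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, Trainer

# "train.jsonl" is a placeholder path for the prompt/completion file from step 1
raw_dataset = load_dataset("json", data_files="train.jsonl")["train"]

def tokenize(example):
    text = example["prompt"] + "\n" + example["completion"] + tokenizer.eos_token
    return tokenizer(text, truncation=True, max_length=512)

train_dataset = raw_dataset.map(tokenize, remove_columns=raw_dataset.column_names)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)
trainer.train()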

6. Save the fine-tuned model:

model.save_pretrained("./gemma-2b-finetuned")
tokenizer.save_pretrained("./gemma-2b-finetuned")

Optimizing Retrieval

To improve retrieval performance:

1. Adjust the number of retrieved documents:

  • Experiment with different top_k values in your vector search (with Marqo, the value is passed as the limit argument):
results = mq.index("my_knowledge_base").search(query, limit=5)  # Adjust this value

2. Implement re-ranking:

  • Use a cross-encoder model for re-ranking:
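from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def rerank(query, documents):
    # Score each (query, document) pair, then sort by descending relevance
    pairs = [[query, doc] for doc in documents]
    scores = reranker.predict(pairs)
    reranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in reranked]

# initial_results is the list of document texts from the first-stage search
initial_results = [hit["content"] for hit in results["hits"]]
reranked_results = rerank(query, initial_results)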

3. Hybrid retrieval:

  • Combine BM25 (keyword-based) and semantic search:
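import numpy as np
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity

def normalize(scores):  # min-max scale to [0, 1] so both signals are comparable
    scores = np.asarray(scores, dtype=float)
    span = scores.max() - scores.min()
    return (scores - scores.min()) / span if span > 0 else np.zeros_like(scores)

def hybrid_search(query, corpus, query_embedding, embeddings, alpha=0.5):
    # BM25Okapi expects a tokenized corpus (a list of token lists)
    bm25 = BM25Okapi([doc.lower().split() for doc in corpus])
    bm25_scores = bm25.get_scores(query.lower().split())
    # query_embedding and embeddings must come from the same embedding model
    semantic_scores = cosine_similarity([query_embedding], embeddings)[0]
    hybrid_scores = alpha * normalize(bm25_scores) + (1 - alpha) * normalize(semantic_scores)
    return sorted(range(len(hybrid_scores)), key=lambda i: hybrid_scores[i], reverse=True)
results = hybrid_search(query, corpus, query_embedding, embeddings)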

Implementing Caching

To set up a caching layer:

Step 1. Install and set up Redis:

import redis

redis_client = redis.Redis(host='localhost', port=6379, db=0)

Step 2. Implement caching in your RAG pipeline:

import json

def get_rag_response(query):
    cache_key = f"rag_response:{query}"
    cached_response = redis_client.get(cache_key)

    if cached_response is not None:
        return json.loads(cached_response)

    # Your existing RAG pipeline here (rag_query from Step 2)
    response = rag_query(query)

    # Cache the response
    redis_client.setex(cache_key, 3600, json.dumps(response))  # Cache for 1 hour

    return response
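One thing to watch with a key built from the raw query: "What is RAG?" and "what is  rag?" become different keys, and very long queries produce very long keys. A minimal sketch of a normalized, hashed key is shown below. The function name make_cache_key is hypothetical, and lowercasing is an assumption that may not suit domains where case matters.

```python
import hashlib


def make_cache_key(query, model="gemma2:2b", prefix="rag_response"):
    """Build a compact, normalized Redis key for a query."""
    normalized = " ".join(query.lower().split())  # ignore case and extra spaces
    digest = hashlib.sha256(f"{model}\n{normalized}".encode("utf-8")).hexdigest()
    return f"{prefix}:{digest}"
```

Including the model name in the hash means cached answers are not reused by accident after you switch to a different or fine-tuned model.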

Building Multi-modal RAG

To extend your system for multi-modal queries:

1. Use a vision-language model:

from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

# ViTImageProcessor replaces the deprecated ViTFeatureExtractor
model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

def generate_image_caption(image):
    # Convert to RGB so grayscale or RGBA images are handled too
    pixel_values = image_processor(images=image.convert("RGB"), return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]

2. Index image embeddings:

from PIL import Image

def index_image(image_path):
    image = Image.open(image_path)
    caption = generate_image_caption(image)
    # Marqo embeds the caption itself; the "images" index must already exist
    # (create it once with mq.create_index("images"))
    mq.index("images").add_documents(
        [{"_id": image_path, "caption": caption}],
        tensor_fields=["caption"],
    )

3. Modify the pipeline for multi-modal queries:

def multi_modal_rag(query, image=None):
    if image is not None:
        image_caption = generate_image_caption(image)
        query = f"{query} Image description: {image_caption}"

    # Proceed with your existing RAG pipeline (rag_query from Step 2) using the enhanced query
    return rag_query(query)

Working with Conversational RAG

To implement a chat interface with conversation history:

1. Maintain conversation history:

conversation_history = []

def add_to_history(role, content):
    conversation_history.append({"role": role, "content": content})

2. Use history for context:

def get_conversational_context():
    return "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation_history[-5:]])

3. Implement sliding window context:

def get_sliding_window_context(window_size=5):
    return "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation_history[-window_size:]])

def conversational_rag(query):
    context = get_sliding_window_context()
    enhanced_query = f"{context}\nHuman: {query}"
    response = rag_query(enhanced_query)  # the pipeline function from Step 2
    add_to_history("Human", query)
    add_to_history("AI", response)
    return response
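The list-based history above keeps every message forever and only slices off the last few when building context. If you would rather bound memory as well, one alternative is a deque with a maximum length, sketched below. The class name ConversationWindow is hypothetical and not part of any library.

```python
from collections import deque


class ConversationWindow:
    """Keep only the most recent messages of a conversation."""

    def __init__(self, window_size=5):
        self._messages = deque(maxlen=window_size)  # old items drop off the left

    def add(self, role, content):
        self._messages.append({"role": role, "content": content})

    def as_context(self):
        return "\n".join(f"{m['role']}: {m['content']}" for m in self._messages)


window = ConversationWindow(window_size=2)
window.add("Human", "Hi")
window.add("AI", "Hello!")
window.add("Human", "What is RAG?")
print(window.as_context())
# AI: Hello!
# Human: What is RAG?
```

The trade-off is that dropped messages are gone for good, so keep the plain list if you also want a full transcript.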

Conclusion

Building a RAG Knowledge Base with Google Gemma 2 2B offers a powerful way to create intelligent, context-aware systems that run entirely on your local machine. By combining the strengths of efficient language models like Gemma 2 2B with vector databases like Marqo, you can create sophisticated AI applications that provide accurate, contextually relevant responses while maintaining data privacy and reducing latency.

Written by Sebastian Petrus

Asist Prof @U of Waterloo, AI/ML, e/acc
