How to Build a Local RAG Knowledge Base with Google Gemma 2 2B
Google’s recent release of Gemma 2 2B has sparked excitement in the AI community, offering a capable language model that can run entirely locally on a range of devices, including Macs. This article walks you through setting up Gemma 2 2B on your local machine and building a Retrieval-Augmented Generation (RAG) knowledge base on top of it.
Before we get started, if you want to manage all your AI models in one place, I strongly suggest you take a look at Anakin AI, where you can use virtually any AI model without the pain of managing 10+ subscriptions.
Google Gemma 2 2B: Perfect for Local Device RAG?
Consider this: in Chatbot Arena evaluations, Gemma 2 2B has reportedly outscored much larger models such as Mixtral 8x7B and GPT-3.5 Turbo, which is remarkable for a 2-billion-parameter model.
Google’s Gemma 2 2B offers an excellent balance of performance and efficiency, making it ideal for local deployment in RAG systems. Its compact size (2 billion parameters) allows for quick inference while maintaining high-quality outputs.
Setting Up Your RAG System with Gemma 2 2B
Hey, if you are working with AI APIs, Apidog is here to make your life easier. It’s an all-in-one API development tool that streamlines the entire process — from design and documentation to testing and debugging.
Step 1: Local Deployment of Gemma 2 2B
Before diving into RAG, it’s crucial to have Gemma 2 2B running locally. While there are multiple methods to achieve this, Ollama provides a straightforward approach:
1. Install Ollama from https://ollama.com/
2. Pull the Gemma 2 2B model:
ollama pull gemma2:2b
3. Run the model:
ollama run gemma2:2b
This setup allows you to interact with Gemma 2 2B directly on your local machine, ensuring privacy and reducing latency.
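As a quick sanity check before wiring up retrieval, you can call the local model from Python. A minimal sketch, assuming the ollama Python package is installed (pip install ollama):

import ollama

# One-off generation against the locally running Gemma 2 2B model
response = ollama.generate(model="gemma2:2b", prompt="Summarize what a RAG system does in one sentence.")
print(response["response"])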
Step 2: Implementing the RAG Architecture
A. Choose a Vector Database
For our RAG system, we’ll use Marqo as our vector database. Marqo excels in creating and managing vector embeddings, crucial for efficient information retrieval.
- Install Marqo following the official documentation.
- Set up a Marqo instance to host your vector embeddings.
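Once the instance is up, a quick connectivity check from Python (assuming Marqo’s default local URL, http://localhost:8882) might look like this:

import marqo

# Connect to the local Marqo instance (adjust the URL if your setup differs)
mq = marqo.Client(url="http://localhost:8882")
print(mq.get_indexes())  # lists existing indexes; empty on a fresh instance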
B. Prepare Your Knowledge Base
- Collect relevant documents, articles, or data sources that will form your knowledge base.
- Preprocess the data so it’s in a format suitable for indexing (e.g., plain text or JSON), splitting long documents into smaller chunks as sketched below.
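A minimal chunking sketch (the 500-word chunk size and the file name are placeholders):

def chunk_text(text, max_words=500):
    # Split a long document into roughly fixed-size chunks of words
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

with open("my_article.txt") as f:  # hypothetical source file
    documents = [{"id": f"doc-{i}", "content": chunk} for i, chunk in enumerate(chunk_text(f.read()))]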
C. Index Your Data with Marqo
- Use Marqo’s API to create embeddings of your knowledge base:
import marqo

mq = marqo.Client()

# Create an index
mq.create_index("my_knowledge_base")

# Add documents to the index (recent Marqo versions require tensor_fields to specify which fields get embedded)
documents = [
    {"id": "1", "content": "Your document text here"},
    {"id": "2", "content": "Another document text"}
]
mq.index("my_knowledge_base").add_documents(documents, tensor_fields=["content"])
This process converts your text data into vector representations, enabling efficient similarity searches.
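To verify the index, you can run a quick similarity search against it (a minimal check using the client created above):

# Retrieve the two closest documents for a test query
results = mq.index("my_knowledge_base").search("document text", limit=2)
for hit in results["hits"]:
    print(hit["_score"], hit["content"])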
D. Develop the RAG Pipeline
Create a Python script that integrates Gemma 2 2B with Marqo:
import marqo
from ollama import Client

mq = marqo.Client()
ollama_client = Client()

def rag_query(user_query, top_k=3):
    # Retrieve relevant documents
    results = mq.index("my_knowledge_base").search(user_query, limit=top_k)

    # Construct context from retrieved documents
    context = " ".join([result['content'] for result in results['hits']])

    # Prepare prompt for Gemma 2 2B
    prompt = f"Context: {context}\n\nQuestion: {user_query}\n\nAnswer:"

    # Generate response using Gemma 2 2B
    response = ollama_client.generate(model="gemma2:2b", prompt=prompt)

    return response['response']

# Example usage
query = "What is the capital of France?"
answer = rag_query(query)
print(answer)
This pipeline:
- Takes a user query
- Retrieves relevant information from the Marqo index
- Constructs a prompt combining the retrieved information and the query
- Feeds this to Gemma 2 2B
- Returns the generated response
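For quick experimentation, the pipeline can be wrapped in a simple interactive loop (a minimal sketch, not part of the pipeline itself):

# Ask questions repeatedly until the user types 'quit'
while True:
    user_input = input("Ask a question (or type 'quit' to exit): ")
    if user_input.strip().lower() == "quit":
        break
    print(rag_query(user_input))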
Enhancing Your RAG System
Fine-tuning Gemma 2 2B
Fine-tuning Gemma 2 2B for your specific domain involves several technical steps:
1. Prepare the dataset:
- Create a dataset of question-answer pairs relevant to your domain.
- Format the data as JSON lines, with each line containing a “prompt” and “completion” field.
- Example format:
{"prompt": "What is the capital of France?", "completion": "The capital of France is Paris."}
2. Set up the fine-tuning environment:
- Install the necessary libraries: transformers, peft, bitsandbytes, accelerate, and datasets.
- Load the pre-trained Gemma 2 2B model:
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
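With the tokenizer loaded, you can also build the train_dataset and data_collator that step 5 expects. A minimal sketch, assuming the JSONL file from step 1 is saved as train.jsonl:

from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling

raw_dataset = load_dataset("json", data_files="train.jsonl", split="train")

def tokenize(example):
    # Concatenate prompt and completion into one causal-LM training sequence
    return tokenizer(example["prompt"] + " " + example["completion"], truncation=True, max_length=512)

train_dataset = raw_dataset.map(tokenize, remove_columns=["prompt", "completion"])
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)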
3. Configure LoRA (Low-Rank Adaptation):
- Use Parameter-Efficient Fine-Tuning (PEFT) to reduce memory requirements:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config)
4. Set up the training arguments:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./gemma-2-2b-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    save_total_limit=3,
)
5. Train the model:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # tokenized dataset built from your JSONL file
    data_collator=data_collator,  # e.g. a causal-LM data collator
)
trainer.train()
6. Save the fine-tuned model:
model.save_pretrained("./gemma-2-2b-finetuned")
tokenizer.save_pretrained("./gemma-2-2b-finetuned")
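To use the fine-tuned adapter for inference later, load it back on top of the base model. A minimal sketch with peft, assuming the save path above:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
tokenizer = AutoTokenizer.from_pretrained("./gemma-2-2b-finetuned")
model = PeftModel.from_pretrained(base_model, "./gemma-2-2b-finetuned")

inputs = tokenizer("What is the capital of France?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))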
Optimizing Retrieval
To improve retrieval performance:
1. Adjust the number of retrieved documents:
- Experiment with different top_k values in your vector search:
results = mq.index("my_knowledge_base").search(user_query, limit=5)  # Adjust this value
2. Implement re-ranking:
- Use a cross-encoder model for re-ranking:
from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

def rerank(query, documents):
    pairs = [[query, doc] for doc in documents]
    scores = reranker.predict(pairs)
    reranked = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    return [doc for doc, _ in reranked]
reranked_results = rerank(query, initial_results)
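To plug re-ranking into the earlier pipeline, retrieve a larger candidate set from Marqo, re-rank it, and keep only the best documents for the prompt. A minimal sketch based on the rag_query function defined above (the candidate and final sizes are arbitrary starting points):

def rag_query_reranked(user_query, candidate_k=10, final_k=3):
    # Over-retrieve candidates, then let the cross-encoder pick the best ones
    results = mq.index("my_knowledge_base").search(user_query, limit=candidate_k)
    docs = [hit["content"] for hit in results["hits"]]
    context = " ".join(rerank(user_query, docs)[:final_k])
    prompt = f"Context: {context}\n\nQuestion: {user_query}\n\nAnswer:"
    return ollama_client.generate(model="gemma2:2b", prompt=prompt)["response"]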
3. Hybrid retrieval:
- Combine BM25 (keyword-based) and semantic search:
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import minmax_scale

def hybrid_search(query, corpus, embeddings, alpha=0.5):
    bm25 = BM25Okapi([doc.split() for doc in corpus])  # BM25 expects a tokenized corpus
    bm25_scores = minmax_scale(bm25.get_scores(query.split()))
    query_embedding = get_embedding(query)  # your existing embedding function
    semantic_scores = minmax_scale(cosine_similarity([query_embedding], embeddings)[0])
    hybrid_scores = alpha * bm25_scores + (1 - alpha) * semantic_scores
    return sorted(range(len(hybrid_scores)), key=lambda i: hybrid_scores[i], reverse=True)
results = hybrid_search(query, corpus, embeddings)
Implementing Caching
To set up a caching layer:
Step 1. Install and set up Redis:
import redis
redis_client = redis.Redis(host='localhost', port=6379, db=0)
Step 2. Implement caching in your RAG pipeline:
import json

def get_rag_response(query):
    cache_key = f"rag_response:{query}"
    cached_response = redis_client.get(cache_key)
    if cached_response:
        return json.loads(cached_response)

    # Your existing RAG pipeline here
    response = run_rag_pipeline(query)

    # Cache the response for 1 hour
    redis_client.setex(cache_key, 3600, json.dumps(response))
    return response
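Here run_rag_pipeline is a placeholder; you can point it at the rag_query function defined earlier and see the cache take effect on repeated queries:

# Reuse the earlier RAG pipeline as the cache's backing function
run_rag_pipeline = rag_query

print(get_rag_response("What is the capital of France?"))  # first call runs the full pipeline
print(get_rag_response("What is the capital of France?"))  # repeated call is served from Redis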
Building Multi-modal RAG
To extend your system for multi-modal queries:
1. Use a vision-language model:
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")

def generate_image_caption(image):
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    output_ids = model.generate(pixel_values)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
2. Index image embeddings:
from PIL import Image

def index_image(image_path):
    image = Image.open(image_path)
    caption = generate_image_caption(image)
    embedding = get_embedding(caption)  # Your existing embedding function
    mq.index("images").add_documents([{
        "id": image_path,
        "caption": caption,
        "embedding": embedding
    }], tensor_fields=["caption"])
3. Modify the pipeline for multi-modal queries:
def multi_modal_rag(query, image=None):
    if image:
        image_caption = generate_image_caption(image)
        query = f"{query} Image description: {image_caption}"

    # Proceed with your existing RAG pipeline using the enhanced query
    return run_rag_pipeline(query)
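A hypothetical usage example (the image path is a placeholder):

from PIL import Image

# Ask a question about a local image alongside the knowledge base
image = Image.open("chart.png")
print(multi_modal_rag("What does this image show?", image=image))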
Working with Conversational RAG
To implement a chat interface with conversation history:
1. Maintain conversation history:
conversation_history = []

def add_to_history(role, content):
    conversation_history.append({"role": role, "content": content})
2. Use history for context:
def get_conversational_context():
    return "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation_history[-5:]])
3. Implement sliding window context:
def get_sliding_window_context(window_size=5):
    return "\n".join([f"{msg['role']}: {msg['content']}" for msg in conversation_history[-window_size:]])

def conversational_rag(query):
    context = get_sliding_window_context()
    enhanced_query = f"{context}\nHuman: {query}"
    response = run_rag_pipeline(enhanced_query)
    add_to_history("Human", query)
    add_to_history("AI", response)
    return response
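A short hypothetical exchange showing how the history carries context between turns:

print(conversational_rag("What is the capital of France?"))
print(conversational_rag("What is its population?"))  # the follow-up is resolved using the stored history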
Conclusion
Building a RAG knowledge base with Google Gemma 2 2B is a practical way to create intelligent, context-aware systems that run entirely on your local machine. By combining an efficient language model like Gemma 2 2B with a vector database like Marqo, you can build applications that deliver accurate, contextually relevant responses while preserving data privacy and keeping latency low.