What Is Retrieval Augmented Generation?
Retrieval Augmented Generation (RAG) is an AI framework for improving the quality of LLM-generated responses by grounding the model on external sources of knowledge to supplement the LLM’s internal representation of information. Implementing RAG in an LLM-based question-answering system has two main benefits: it ensures that the model has access to the most current, reliable facts and that users have access to the model’s sources, ensuring that its claims can be checked for accuracy and ultimately trusted.
Here we will implement a RAG application that interacts with a codebase of our choice and generates sample code grounded in that repository. Interacting with a codebase through a RAG application offers numerous advantages for modern software development. It significantly improves efficiency by combining the power of retrieval and generation, saving developers the considerable time they would otherwise spend searching for relevant code snippets or crafting solutions from scratch.
Moreover, using the codebase as a reference point keeps the generated code accurate and aligned with project standards and conventions while reducing the risk of errors. RAG applications also facilitate seamless knowledge transfer within development teams by providing instant access to pertinent code examples and best practices, which accelerates onboarding for new members and promotes consistency across projects. Finally, the scalability of RAG applications lets them handle larger codebases and more intricate development tasks, so developers can keep benefiting from them as projects evolve.
A few assumptions were made while implementing the RAG application:
Only files with extensions such as .py, .md, and .txt are considered.
All of these files are converted into plain-text format before indexing.
Implementation Steps
Install the dependencies.
Clone the GitHub repo you want to talk to into the notebook.
Convert the repo into .txt files.
Loop over these .txt files and use LangChain to split them into "chunks", vectorize them, and write them to the Qdrant vector store.
Use the CodeLlama-7b model (CodeLlama-7b-hf) as the LLM.
Add a prompt template (optional).
Talk to the repository using LLM and matching context from the vector store based on the query asked.
Technology Stack Used
Qdrant: Vector store
CodeLlama-7b-hf: LLM
LangChain: Application Framework
Gradio: User Interface
Code Implementation Steps
Install required dependencies
!pip install langchain langchain-community transformers accelerate sentence-transformers
!pip install qdrant-client
!pip install langchainhub
!pip install gradio
Import required dependencies
import os

from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
from langchain_community.vectorstores import Qdrant
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
Clone the repo you want to talk to
!git clone https://github.com/alejandro-ao/chat-with-websites codebase
Convert the repo to text and prepare it to be vectorized
def convert_files_to_txt(src_dir, dst_dir):
    # Walk the cloned repo and copy every readable file into dst_dir with a .txt suffix,
    # preserving the original directory structure.
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)
    for root, dirs, files in os.walk(src_dir):
        for file in files:
            if not file.endswith('.jpg'):  # skip binary image files
                file_path = os.path.join(root, file)
                rel_path = os.path.relpath(file_path, src_dir)
                new_root = os.path.join(dst_dir, os.path.dirname(rel_path))
                os.makedirs(new_root, exist_ok=True)
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = f.read()
                except UnicodeDecodeError:
                    try:
                        with open(file_path, 'r', encoding='latin-1') as f:
                            data = f.read()
                    except UnicodeDecodeError:
                        print(f"Failed to decode the file: {file_path}")
                        continue
                new_file_path = os.path.join(new_root, file + '.txt')
                with open(new_file_path, 'w', encoding='utf-8') as f:
                    f.write(data)

convert_files_to_txt('/content/codebase', '/content/converted_codebase')
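The function above keeps every file that is not a .jpg. To enforce the stated assumption of only .py, .md, and .txt files, one optional variation (a hypothetical helper, not part of the original script) would be to swap the extension check:
# Hypothetical stricter filter: convert only the extensions listed in the assumptions.
ALLOWED_EXTENSIONS = ('.py', '.md', '.txt')

def is_allowed(filename):
    # True only for the file types we actually want to index.
    return filename.lower().endswith(ALLOWED_EXTENSIONS)

# Inside convert_files_to_txt, replace `if not file.endswith('.jpg'):` with:
#     if is_allowed(file):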
Perform chunking prior to vector store loading
src_dir = "/content/converted_codebase"
loader = DirectoryLoader(src_dir, show_progress=True, loader_cls=TextLoader)
repo_files = loader.load()
print(f"Number of files loaded: {len(repo_files)}")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
documents = text_splitter.split_documents(documents=repo_files)
print(f"Number of documents : {len(documents)}")
Set up the metadata
# Strip the appended ".txt" so each chunk's source metadata points at the original repo file.
for doc in documents:
    old_path_with_txt_extension = doc.metadata["source"]
    new_path_without_txt_extension = old_path_with_txt_extension.replace(".txt", "")
    doc.metadata.update({"source": new_path_without_txt_extension})
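As a quick optional check, you can print one chunk's metadata to confirm the source path no longer carries the .txt suffix:
# The source should now point at the original repo file, e.g. ".../src/app.py" rather than ".../src/app.py.txt".
print(documents[0].metadata)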
Instantiate the Embedding Model
model_name = "BAAI/bge-small-en-v1.5"
model_kwargs = {"device": "cpu"}
encode_kwargs = {"normalize_embeddings":True}
embeddings = HuggingFaceBgeEmbeddings(model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs,
)
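To confirm the embedding model loads correctly, you can embed a short string and inspect the vector size; bge-small-en-v1.5 produces 384-dimensional vectors. A small optional check:
# Embed a sample snippet and verify the embedding dimension.
sample_vector = embeddings.embed_query("def hello_world(): print('hello')")
print(len(sample_vector))  # expected: 384 for BAAI/bge-small-en-v1.5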
Instantiate the LLM
gpu_llm = HuggingFacePipeline.from_model_id(
model_id="codellama/CodeLlama-7b-hf",
task="text-generation",
device_map="auto",
pipeline_kwargs={"max_new_tokens": 100},
)
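Before wiring the model into a chain, a quick optional smoke test confirms the pipeline generates text on the available device:
# Ask the raw model for a short completion.
print(gpu_llm.invoke("# Python function to add two numbers\ndef add(a, b):"))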
Instantiate Qdrant Vector Store and Load the documents
qdrant = Qdrant.from_documents(
documents,
embeddings,
path="/content/local_qdrant",
collection_name="my_documents",
)
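Qdrant.from_documents embeds the chunks and persists them under the given local path. If you later need to reopen that on-disk collection without re-embedding everything, a minimal sketch using qdrant-client directly (assuming the same path and collection name as above):
from qdrant_client import QdrantClient

# Reopen the local collection created above and wrap it for LangChain.
client = QdrantClient(path="/content/local_qdrant")
qdrant = Qdrant(client=client, collection_name="my_documents", embeddings=embeddings)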
Helper function to display documents
def pretty_print_docs(documents):
    # Print each retrieved chunk's metadata, a separator, and its content.
    for doc in documents:
        print(doc.metadata)
        print(" - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - ")
        print(doc.page_content)
Test and check whether the vector store has been loaded successfully
query = "what is the syntax to import text_splitter using langchain"
found_docs = qdrant.similarity_search(query)
pretty_print_docs(found_docs)
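If you also want to see how close each match is, the LangChain Qdrant wrapper exposes a scored variant of the same search; a small optional example:
# Retrieve documents together with their similarity scores.
scored_docs = qdrant.similarity_search_with_score(query, k=2)
for doc, score in scored_docs:
    print(score, doc.metadata["source"])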
Instantiate the Query Engine
Method 1: Using LCEL
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

prompt = hub.pull("rlm/rag-prompt")
Prompt: ChatPromptTemplate(input_variables=['context', 'question'], messages=[HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['context', 'question'], template="You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\nQuestion: {question} \nContext: {context} \nAnswer:"))])
chain = prompt | gpu_llm | StrOutputParser()
response = chain.invoke({"question":query,"context":found_docs})
print(response)
Response:
You can use the following code to import text_splitter using langchain
```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
```
Context: [Document(page_content='# pip install streamlit langchain lanchain-openai beautifulsoup4 python-dotenv chromadb\n\nimport streamlit as st
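In the chain above, the retrieved documents are passed in by hand via found_docs. As a variation, you can wire the retriever directly into the LCEL chain so the context is fetched automatically for each question; a minimal sketch assuming the objects defined earlier:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

retriever = qdrant.as_retriever(search_kwargs={"k": 2})

# The retriever fills {context} and the raw question fills {question} before reaching the prompt.
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | prompt
    | gpu_llm
    | StrOutputParser()
)
print(rag_chain.invoke("what is the syntax to import text_splitter using langchain"))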
Method 2: Using RetrievalQA Chain
from langchain.chains import RetrievalQA
qa = RetrievalQA.from_chain_type(llm=gpu_llm,
chain_type="stuff",
retriever=qdrant.as_retriever(search_kwargs={"k":2}),
return_source_documents=True)
response = qa.invoke("""complete the below code:\n#load web documents using lanchain \n#\n from lanchain_community.document_loaders""")
print(response)
print(response['result'])
{'query': 'complete the below code:\n#load web documents using lanchain \n#\n from lanchain_community.document_loaders', 'result': '\n\nfrom lanchain_community.document_loaders import WebBaseLoader\n\n#load web documents using lanchain \n#\n loader = WebBaseLoader(url)\n\n#load web documents using lanchain \n#\n document = loader.load()\n\n#split the document into chunks\n#\n text_splitter = RecursiveCharacterTextSplitter()\n\n#split the document into chunks\n#\n document_chunks = text_split', 'source_documents': [Document(page_content='# pip install streamlit langchain lanchain-openai beautifulsoup4 python-dotenv chromadb\n\nimport streamlit as st\nfrom langchain_core.messages import AIMessage, HumanMessage\nfrom langchain_community.document_loaders import WebBaseLoader\nfrom langchain.text_splitter import RecursiveCharacterTextSplitter\nfrom langchain_community.vectorstores import Chroma\nfrom langchain_openai import OpenAIEmbeddings, ChatOpenAI\nfrom dotenv import load_dotenv\nfrom langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder\nfrom langchain.chains import create_history_aware_retriever, create_retrieval_chain\nfrom langchain.chains.combine_documents import create_stuff_documents_chain\n\n\nload_dotenv()\n\ndef get_vectorstore_from_url(url):\n # get the text in document form\n loader = WebBaseLoader(url)\n document = loader.load()\n \n # split the document into chunks\n text_splitter = RecursiveCharacterTextSplitter()\n document_chunks = text_splitter.split_documents(document)\n \n # create a vectorstore from the chunks\n vector_store = Chroma.from_documents(document_chunks, OpenAIEmbeddings())\n\n return vector_store', metadata={'source': '/content/converted_codebase/src/app.py', '_id': 'e03ae45d29424e988ba8ccdee5f16859', '_collection_name': 'my_documents'}), Document(page_content='langchain==0.1.4\nlangchain_community==0.0.16\nlangchain_core==0.1.17\nlangchain_openai==0.0.5\npython-dotenv==1.0.1\nstreamlit==1.30.0\nchromadb==0.3.29\nbs4==0.0.2', metadata={'source': '/content/converted_codebase/requirements', '_id': 'ea9fc86bf8ed4e279005dd9ff1a25b28', '_collection_name': 'my_documents'})]}
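Because return_source_documents=True, the response dictionary also carries the retrieved chunks, so you can list exactly which files grounded the answer:
# Print the repo path of each chunk the answer was grounded on.
for doc in response['source_documents']:
    print(doc.metadata['source'])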
Gradio Implementation
Helper function to load documents into vector store
def load_documents(folder_path):
    # Re-use the global qdrant handle so the "Ask Question" tab queries the store built here.
    global qdrant
    src_dir = folder_path
    loader = DirectoryLoader(src_dir, show_progress=True, loader_cls=TextLoader)
    repo_files = loader.load()
    print(f"Number of files loaded: {len(repo_files)}")
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
    documents = text_splitter.split_documents(documents=repo_files)
    print(f"Number of documents: {len(documents)}")
    qdrant = Qdrant.from_documents(
        documents,
        embeddings,
        path="/content/local_qdrant_gradio",
        collection_name="my_documents",
    )
    return "documents loaded successfully"
Helper function to retrieve response based on the query
def retrieve_response(query):
    # Build a RetrievalQA chain over the current vector store and return only the answer text.
    qa = RetrievalQA.from_chain_type(llm=gpu_llm,
                                     chain_type="stuff",
                                     retriever=qdrant.as_retriever(search_kwargs={"k": 2}),
                                     return_source_documents=True)
    response = qa.invoke(query)
    return response['result']
Gradio Application
import gradio as gr

app1 = gr.Interface(fn=load_documents, inputs=gr.Textbox(label="Enter the code base folder path"), outputs="text")
app2 = gr.Interface(fn=retrieve_response, inputs=gr.Textbox(label="Enter your question here."), outputs="textbox")
demo = gr.TabbedInterface([app1, app2], ["Load Documents", "Ask Question"])

if __name__ == "__main__":
    demo.launch()
Load the documents into the vector store.
Ask a query once the documents have been loaded successfully.
Conclusion
In conclusion, incorporating RAG over the codebase reduced LLM hallucinations by giving the model a concrete reference from which to generate sample code.