
Retrieval Augmented Generation (RAG) Application

Learn how to use Retrieval Augmented Generation on Intel Gaudi with Hugging Face

Retrieval Augmented Generation (RAG) on Intel® Gaudi® 2

A scalable Retrieval Augmented Generation (RAG) application built with Hugging Face tools and optimized for the Intel Gaudi 2 AI accelerator.

This tutorial shows how to build a RAG application on Intel Gaudi 2. The application is built from easily accessible Hugging Face tools: text-generation-inference (TGI) and text-embeddings-inference (TEI). LangChain is used to keep the code easy to follow, and Gradio provides the user interface for submitting queries at the end of the tutorial. The application runs in a Docker environment, but can easily be deployed to a Kubernetes cluster.

Retrieval-augmented generation (RAG) is a method that enhances the precision and dependability of generative AI models by incorporating facts from external sources. This technique addresses the limitations of large language models (LLMs), which, despite their ability to generate responses to general prompts rapidly, may not provide in-depth or specific information. By enabling access to external knowledge sources, RAG improves factual consistency, increases the reliability of generated responses, and helps to mitigate the issue of “hallucination” in more complex and knowledge-intensive tasks.

This tutorial walks through the steps of building the full RAG pipeline on Intel Gaudi 2. First, we set up the text generation, text embedding, and vector database tools. Next, the external dataset is prepared: the text is extracted from the source document, split into "chunks", and converted into numerical embeddings, which are then loaded into the vector database. To answer a query, the pipeline runs the embedding model on the query, retrieves the closest matching chunks from the database, and sends the combined prompt and retrieved context to the Llama 2 large language model to generate a fully formed response.

Figure 1. RAG model details

Initial Setup

These are the initial steps to ensure that your build environment is set up correctly:

  1. Set the appropriate ports for access when you ssh into the Intel Gaudi 2 node. You need to ensure that the following ports are open:

    Port 8888 (for running this Jupyter notebook)
    Port 7860 (for running the Gradio server)

  2. To do this, add the following to your overall ssh command when connecting to the Intel Gaudi node:

    ssh -L 8888:localhost:8888 -L 7860:localhost:7860 ....

  3. Before you load this notebook, run the standard Intel Gaudi Docker image, making sure to mount the /var/run/docker.sock file. Use the run and exec commands below to start your container.
docker run -itd --name RAG --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host -v /var/run/docker.sock:/var/run/docker.sock vault.habana.ai/gaudi-docker/1.15.0/ubuntu22.04/habanalabs/pytorch-installer-2.2.0:latest

docker exec -it RAG bash

cd ~ && git clone https://github.com/HabanaAI/Gaudi-tutorials

python3 -m pip install jupyterlab && python3 -m jupyterlab_server --IdentityProvider.token='' --ServerApp.password='' --allow-root --port 8888 --ServerApp.root_dir=$HOME &

Set up the Docker environment in this notebook:

At this point you have cloned the Gaudi-tutorials repository inside your Docker container and opened this notebook. Now follow the steps below. Note that you will need to install Docker again inside the Intel Gaudi container to manage the execution of the RAG tools.

cd /root/Gaudi-tutorials/PyTorch/RAG_Application
apt-get update
apt-get install docker.io curl -y

Loading the Tools for RAG

There are three building blocks in the RAG environment: text generation, text embedding, and the vector database.

Text Generation Inference (TGI)

The first building block of the application is text-generation-inference (TGI); its purpose is to serve the LLM that answers questions based on the retrieved context. To run it, we first need to build a Docker image:

Please note: The Hugging Face Text Generation Interface depends on software that is subject to non-open source licenses. If you use or redistribute this software, it is your sole responsibility to ensure compliance with such licenses.

cd /root/Gaudi-tutorials/PyTorch/RAG_Application
git clone -b v1.2.1 https://github.com/huggingface/tgi-gaudi.git
cd tgi-gaudi
docker build -t tgi-gaudi .
cd ../

After building the image, you will run it:

How to Access and Use the Llama 2 Model

To use the Llama 2 model, you will need a Hugging Face account, you must agree to the terms of use of the model on its model card on the Hugging Face Hub, and you must create a read access token. Then copy that token into the HUGGING_FACE_HUB_TOKEN variable below.

Use of the pretrained model is subject to compliance with third party licenses, including the “Llama 2 Community License Agreement” (LLAMAV2). For guidance on the intended use of the LLAMA2 model, what will be considered misuse and out-of-scope uses, who are the intended users and additional terms please review and read the instructions in this link https://ai.meta.com/llama/license/. Users bear sole liability and responsibility to follow and comply with any third party licenses, and Habana Labs disclaims and will bear no liability with respect to users’ use or compliance with third party licenses.

docker run -d -p 9001:80 \
    --runtime=habana \
    --name gaudi-tgi \
    -e HABANA_VISIBLE_DEVICES=0 \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    -e HUGGING_FACE_HUB_TOKEN="<your_token_here>" \
    --cap-add=sys_nice \
    --ipc=host \
    tgi-gaudi \
    --model-id meta-llama/Llama-2-7b-chat-hf

After starting the Docker container, it will take some time to download the model and load it onto the device. To check the status, run docker logs gaudi-tgi; you should see output similar to:

2024-02-23T16:24:35.125179Z  INFO shard-manager: text_generation_launcher: Waiting for shard to be ready... rank=0
2024-02-23T16:24:40.729388Z  INFO shard-manager: text_generation_launcher: Shard ready in 65.710470677s rank=0
2024-02-23T16:24:40.796775Z  INFO text_generation_launcher: Starting Webserver
2024-02-23T16:24:42.589516Z  WARN text_generation_router: router/src/main.rs:355: `--revision` is not set
2024-02-23T16:24:42.589551Z  WARN text_generation_router: router/src/main.rs:356: We strongly advise to set it to a known supported commit.
2024-02-23T16:24:42.842098Z  INFO text_generation_router: router/src/main.rs:377: Serving revision e852bc2e78a3fe509ec28c6d76512df3012acba7 of model Intel/neural-chat-7b-v3-1
2024-02-23T16:24:42.845898Z  INFO text_generation_router: router/src/main.rs:219: Warming up model
2024-02-23T16:24:42.846613Z  WARN text_generation_router: router/src/main.rs:230: Model does not support automatic max batch total tokens
2024-02-23T16:24:42.846620Z  INFO text_generation_router: router/src/main.rs:252: Setting max batch total tokens to 16000
2024-02-23T16:24:42.846623Z  INFO text_generation_router: router/src/main.rs:253: Connected
2024-02-23T16:24:42.846626Z  WARN text_generation_router: router/src/main.rs:258: Invalid hostname, defaulting to 0.0.0.0

Once the setup is complete, you can verify that text generation is working by sending a request to it (note that the first request could be slow due to graph compilation):

curl 127.0.0.1:9001/generate \
    -X POST \
    -d '{"inputs":"why is the earth round?","parameters":{"max_new_tokens":200}}' \
    -H 'Content-Type: application/json'
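The same request can be sent from Python with the text_generation client used later in this tutorial (a minimal sketch; it assumes the text_generation package from requirements.txt is already installed):

# Minimal sketch: query the TGI server from Python instead of curl.
# Assumes the text_generation package (installed via requirements.txt) is available.
from text_generation import Client

client = Client("http://127.0.0.1:9001")
response = client.generate("why is the earth round?", max_new_tokens=200)
print(response.generated_text)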

Text Embeddings Inference (TEI)

The next building block is text-embeddings-inference (TEI); its purpose is to serve the embedding model that produces embeddings for the vector database. To run it, we need to build a Docker image:

Please note: The Hugging Face Text Embedding Interface depends on software that is subject to non-open source licenses. If you use or redistribute this software, it is your sole responsibility to ensure compliance with such licenses.

git clone https://github.com/huggingface/tei-gaudi
cd tei-gaudi
docker build --quiet -t tei-gaudi .
cd ../../

After building the image, we can run it:

docker run -d -p 9002:80 \
    --runtime=habana \
    --name gaudi-tei \
    -e HABANA_VISIBLE_DEVICES=4 \
    -e OMPI_MCA_btl_vader_single_copy_mechanism=none \
    --cap-add=sys_nice \
    --ipc=host \
    tei-gaudi \
    --model-id BAAI/bge-large-en-v1.5
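Once the embedding server is up, you can send it a quick request to confirm it responds (a sketch; it assumes the requests package is available, and the /embed route and JSON payload follow the text-embeddings-inference API):

# Quick sanity check of the TEI server: request an embedding for a test sentence.
import requests

resp = requests.post(
    "http://127.0.0.1:9002/embed",
    json={"inputs": "What is Intel Gaudi 2?"},
)
print(len(resp.json()[0]))  # dimensionality of the returned embedding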

PGVector Database

The third building block is a vector database; in this tutorial we use PGVector. Set up the Docker container this way:

docker pull pgvector/pgvector:pg16
docker run \
    -d \
    -e POSTGRES_PASSWORD=postgres \
    -p 9003:5432 \
    pgvector/pgvector:pg16
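To confirm the database container accepts connections, you can run a small check from Python (a sketch; it assumes psycopg2 is installed, which the LangChain PGVector connection string used later also requires):

# Sanity check: connect to the PGVector container started above and print
# the PostgreSQL server version. Assumes psycopg2 is installed.
import psycopg2

conn = psycopg2.connect(
    host="localhost", port=9003,
    user="postgres", password="postgres", dbname="postgres",
)
with conn.cursor() as cur:
    cur.execute("SELECT version();")
    print(cur.fetchone()[0])
conn.close()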

Application Front End

The last building block is a frontend that serves as an HTTP server. The frontend is implemented in Python using the Gradio interface. To set up the environment, run:

cd /root/Gaudi-tutorials/PyTorch/RAG_Application
pip install -q -r requirements.txt

Data preparation

To build a good quality RAG application, we need to prepare the data. Data processing for the vector database works as follows: extract the text from the documents (for example PDFs or CSVs), then split it into chunks that do not exceed a maximum length, attaching additional metadata (for example the filename or file creation date). Then upload the preprocessed data to the vector database.

In the process of data preprocessing, text splitting plays a crucial role. It involves breaking down the text into smaller, semantically meaningful chunks for further processing and analysis. Here are some common methods of text splitting:

  • By Character: This method involves splitting the text into individual characters. It’s a straightforward approach, but it may not always be the most effective, as it doesn’t take into account the semantic meaning of words or phrases.
  • Recursive: Recursive splitting involves breaking down the text into smaller parts repeatedly until a certain condition is met. This method is particularly useful when dealing with complex structures in the text, as it allows for a more granular level of splitting.
  • HTML Specific: When dealing with HTML content, text splitting can be done based on specific HTML tags or elements. This method is useful for extracting meaningful information from web pages or other HTML documents.
  • Code Specific: In the context of programming code, text can be split based on specific code syntax or structures. This method is particularly useful for code analysis or for building tools that work with code.
  • By Tokens: Tokenization is a common method of text splitting in Natural Language Processing (NLP). It involves breaking down the text into individual words or tokens. This method is effective for understanding the semantic meaning of the text, as it allows for the analysis of individual words and their context.

In conclusion, the choice of text splitting method depends largely on the nature of the text and the specific requirements of the task at hand. It’s important to choose a method that effectively captures the semantic meaning of the text and facilitates further processing and analysis.

In this tutorial we will use the recursive method, as sketched below. To better understand the topic, you can experiment with the https://langchain-text-splitter.streamlit.app/ app.
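A minimal sketch of recursive splitting with LangChain's RecursiveCharacterTextSplitter follows; note that the loading cell later in this tutorial uses the simpler CharacterTextSplitter with the same chunk size:

# Sketch: split a piece of text recursively into chunks of at most 512 characters.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=0)
chunks = splitter.split_text(
    "Retrieval Augmented Generation combines a retriever with a generator. "
    "The retriever finds relevant chunks in a vector database; the generator "
    "uses those chunks as context to write the answer."
)
print(f"{len(chunks)} chunk(s)")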

Database Population

Database population is the step where we load the documents, embed them, and then insert them into the database.

Data Loading

For ease of use, we'll use helper functions from LangChain. Note that langchain_community is also required.

from pathlib import Path

from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from langchain.vectorstores.pgvector import PGVector
from langchain_community.embeddings import HuggingFaceHubEmbeddings

Loading Documents with Embeddings

Here we need to create the Hugging Face TEI client and the PGVector client. For PGVector, the collection name corresponds to the table name; the connection string contains the connection protocol (postgresql+psycopg2), followed by the user, password, host, port, and database name. For ease of use, pre_delete_collection is set to True to prevent duplicates in the database.

embeddings = HuggingFaceHubEmbeddings(model="http://localhost:9002", huggingfacehub_api_token="EMPTY")
store = PGVector(
    collection_name="documents",
    connection_string="postgresql+psycopg2://postgres:postgres@localhost:9003/postgres",
    embedding_function=embeddings,
    pre_delete_collection=True
)
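As an optional check (not part of the original notebook), you can embed a test string through TEI and inspect the vector length; BAAI/bge-large-en-v1.5 produces 1024-dimensional embeddings:

# Optional check: embed a test query through the TEI server.
vector = embeddings.embed_query("What is Intel Gaudi 2?")
print(len(vector))  # expected: 1024 for BAAI/bge-large-en-v1.5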

Data Loading and Splitting

Data is loaded from the text files in data/, then the documents are split into chunks of 512 characters and finally loaded into the database. Note that documents can carry metadata, which can also be stored in the vector database (see the example after the loading cell below).

You can place a new text file in the data/ folder and re-run the following cell to run the RAG pipeline on new content. This cell will rebuild the database used to answer your queries.

def load_file_to_db(path: str, store: PGVector):
    loader = TextLoader(path)
    document = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=0)
    for chunk in text_splitter.split_documents(document):
        store.add_documents([chunk])

for doc in Path("data/").glob("*.txt"):
    print(f"Loading {doc}...")
    load_file_to_db(str(doc), store)

print("Finished.")

Running the Application

To start the application, run the commands below to set up the Gradio interface.
Load a text file into the data folder, re-run the cell above, and the application will ingest it and start the chat interface so you can ask questions about the document.
You will see that it directly accesses the TGI and TEI services to ingest the content, create the embeddings and vector database, run the query through the database, and then use the LLM to generate an answer to your query.

%load_ext gradio

from langchain.vectorstores.pgvector import PGVector
from langchain_community.embeddings import HuggingFaceHubEmbeddings
from text_generation import Client

rag_prompt_intel_raw = """### System: You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. 

### User: Question: {question}

Context: {context}

### Assistant: """

def get_sources(question):
    embeddings = HuggingFaceHubEmbeddings(model="http://localhost:9002", huggingfacehub_api_token="EMPTY")
    store = PGVector(
        collection_name="documents",
        connection_string="postgresql+psycopg2://postgres:postgres@localhost:9003/postgres",
        embedding_function=embeddings,
    )
    return store.similarity_search(f"Represent this sentence for searching relevant passages: {question}", k=2)

def sources_to_str(sources):
    return "\n".join(f"{i+1}. {s.page_content}" for i, s in enumerate(sources))

def get_answer(question, sources):
    client = Client("http://localhost:9001") #change this to 9009 for the new model
    context = "\n".join(s.page_content for s in sources)
    prompt = rag_prompt_intel_raw.format(question=question, context=context)
    # return client.generate_stream(prompt, max_new_tokens=1024, stop_sequences=["### User:", "</s>"])
    return client.generate(prompt, max_new_tokens=1024, stop_sequences=["### User:", "</s>"]).generated_text

default_question = "What is the summary of this document?"

def rag_answer(question):
    sources = get_sources(question)
    answer = get_answer(question, sources)
    #return f"Sources:\n{sources_to_str(sources)}\nAnswer:\n{answer}"
    return f"{answer}"

Finally, run the Gradio application and see the output:

%%blocks

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown(f"# Intel Gaudi 2 RAG app")
    question = gr.Textbox(default_question, label="Question")
    answer = gr.Textbox(label="Answer")
    send_btn = gr.Button("Run")
    send_btn.click(fn=rag_answer, inputs=question, outputs=answer)
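If you run this code as a standalone Python script rather than through the %%blocks notebook magic, you would need to launch the interface explicitly (a sketch; port 7860 matches the port forwarded in the ssh command at the start of this tutorial):

# Launch the Gradio server explicitly when running outside the notebook;
# the %%blocks magic handles this automatically inside JupyterLab.
demo.launch(server_name="0.0.0.0", server_port=7860)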
Figure 2. The Intel Gaudi RAG application interface

Next Steps

You can add other .txt documents into the ./data folder and then re-run the steps in Data Preparation to update the vector database with the new document.

You can also try the LocalGPT tutorial: https://developer.habana.ai/tutorials/pytorch/using-localgpt-with-llama2/ which also uses RAG with the LocalGPT script to generate responses using the chroma database.

You can also try other models; for a complete list of models optimized for Intel Gaudi, visit: https://huggingface.co/docs/optimum/habana/index

Copyright© 2024 Habana Labs, Ltd. an Intel Company.

Licensed under the Apache License, Version 2.0 (the “License”);

You may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
