EmbeddingAlign RAG: Boosting QA Systems
The Retrieval-Augmented Generation (RAG) method improves question answering by retrieving relevant information before generating answers. In this paper, we introduce EmbeddingAlign RAG, a new way to improve document retrieval in RAG systems. We demonstrate that applying a linear transformation to both the query embeddings and the chunk embeddings produces a significant improvement in retrieval accuracy.
Our approach works with small datasets (such as a single PDF with 500 chunks), uses an out-of-the-box embedder (which can be a black-box embedding model such as the OpenAI embedding API), and leverages existing embedding collections. The model is trained on a CPU with standard hardware and increases retrieval inference time by less than 10 ms. We show that EmbeddingAlign RAG improves retrieval accuracy, boosting the hit rate from 0.89 to 0.95. The simplicity of implementing this method makes it ideal for direct application in real-world Q&A systems.
Introduction to RAG and Embedding Optimization
Challenges of existing RAG systems
A RAG system combines a retriever and a generator (such as GPT-4). The retriever fetches relevant documents or document chunks based on user queries (questions), usually by comparing query embeddings against document embeddings. Periodically, new documents are added to the knowledge base to address expanding use cases.
The mathematical representation of the RAG system is as follows:
Let $Q$ be the set of user queries, $C$ be the set of document chunks, and $E(\cdot)$ be the embedding function. For a query $q \in Q$ and a document chunk $c \in C$, the retrieval score is typically calculated as:

$$\text{score}(q, c) = \text{similarity}\big(E(q), E(c)\big)$$

where the similarity is often cosine similarity:

$$\text{similarity}\big(E(q), E(c)\big) = \frac{E(q) \cdot E(c)}{\lVert E(q) \rVert \, \lVert E(c) \rVert}$$
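As a minimal, illustrative sketch of this baseline scoring (using NumPy; the function and variable names below are ours, not part of the method):

```python
import numpy as np

def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two embedding vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def retrieve_top_k(query_emb: np.ndarray, chunk_embs: np.ndarray, k: int = 4) -> np.ndarray:
    """Return the indices of the k chunks with the highest cosine similarity to the query."""
    # Normalizing once lets a single matrix-vector product compute all cosine similarities.
    q = query_emb / np.linalg.norm(query_emb)
    c = chunk_embs / np.linalg.norm(chunk_embs, axis=1, keepdims=True)
    scores = c @ q
    return np.argsort(-scores)[:k]
```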
However, as more documents are introduced, the retriever may pull irrelevant chunks, leading to poor retrieval performance. A reason for this decrease in performance is that the existing embeddings (for either the queries or the chunks) are not well aligned with the new data.
Furthermore, trying out or training a new embedding model is expensive, as large document embedding collections are costly and difficult to update from an engineering perspective. This leaves many RAG systems at a dead end: retrieval needs to be improved to push performance further, but doing so is too costly.
Our Approach: Embedding Alignment via Linear Transformation
We propose training a single linear transformation that is applied to the embedding vectors of both the user queries and the text chunks. The goal is to learn a transformation that brings the embeddings of user queries closer to those of the most relevant document chunks, improving retrieval accuracy.
Let $W \in \mathbb{R}^{d \times d}$ be the linear transformation matrix we want to learn. The new retrieval score becomes:

$$\text{score}_W(q, c) = \text{similarity}\big(W E(q),\, W E(c)\big)$$

where $E(q)$ and $E(c)$ are the original query and chunk embeddings, and the same matrix $W$ is applied to both.

In our case, $d = 1536$ (the dimension of the embeddings from text-embedding-3-small by OpenAI). To find a matrix $W$ that improves the retrieval quality of our RAG pipeline, we need the following data:
- Query embeddings (user questions)
- Good chunk embeddings (relevant chunks from documents)
- Distractor chunk embeddings (irrelevant or misleading chunks)
In an already deployed RAG system, this data can easily be collected through feedback mechanisms or annotated datasets. For example, user or annotator feedback on end results (thumbs up/down) can be used to label good chunks and bad chunks.
From an inference perspective, the linear layer needs to be applied to the user query embedding during the retrieval step, but also to all of the embeddings stored in the vectorstore. For latency and compute efficiency, we recommend applying the linear transformation to all of the vectors in the vectorstore once, before starting to serve retrievals.
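A minimal sketch of this inference-time setup, assuming the learned matrix `W` (shape d × d) and the stored chunk embeddings `chunk_embs` (shape n_chunks × d) are available as NumPy arrays (the names are illustrative):

```python
import numpy as np

# Offline, one-off step: transform and normalize every stored chunk embedding.
aligned_chunks = chunk_embs @ W.T
aligned_chunks /= np.linalg.norm(aligned_chunks, axis=1, keepdims=True)

# Online step: only the incoming query embedding needs to be transformed at query time.
def aligned_retrieve(query_emb: np.ndarray, k: int = 4) -> np.ndarray:
    q = W @ query_emb
    q = q / np.linalg.norm(q)
    return np.argsort(-(aligned_chunks @ q))[:k]
```

Because the chunk side is precomputed, the only extra work per query is a single d × d matrix-vector product.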
Synthetic Dataset Generation
In our case, we don’t have access to proprietary production data. To overcome this limitation, we simulate a realistic scenario by generating a synthetic dataset from existing documents (representing the knowledge base). This approach is also useful in cases where production data is absent or insufficient, as one only needs the text chunks of the knowledge base to improve the RAG.
We base our synthetic dataset on publicly available documents and models:
- Documents: SEC Form 10K reports from Uber and Lyft, which contain rich text on similar business activities (mobility services).
- Models: GPT-4o is used to generate queries and LlamaIndex to generate document chunks. For each chunk, a query (question) is generated to form (query, chunk) pairs.
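As an illustrative sketch of this generation step (the file name, chunk size, and prompt below are our own assumptions; the LlamaIndex import paths correspond to recent llama-index versions):

```python
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

# Split the source document (e.g. a 10-K report) into text chunks with LlamaIndex.
documents = SimpleDirectoryReader(input_files=["uber_10k.pdf"]).load_data()
nodes = SentenceSplitter(chunk_size=512).get_nodes_from_documents(documents)
chunks = [node.text for node in nodes]

def generate_query(chunk: str) -> str:
    """Ask GPT-4o for one question that the given chunk answers."""
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{
            "role": "user",
            "content": f"Write one question that the following passage answers:\n\n{chunk}",
        }],
    )
    return response.choices[0].message.content.strip()

pairs = [(generate_query(chunk), chunk) for chunk in chunks]
```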
Below is an example of a (query, chunk) pair from our dataset:
| Query (generated question) | Chunk (original document) |
|---|---|
| How is the commercial agreement for the utilization of Lyft rideshare and fleet data by Woven Planet accounted for, and what financial impact did it have on the company’s deferred revenue? | the consolidated statement of operations for the quarter ended September 30, 2021. The commercial agreements for the utilization of Lyft rideshare and fleet data by Woven Planet is accounted for under ASC 606 and the Company recorded a deferred revenue liability of $42.5 million related to the performance obligations under these commercial agreements as part of the transaction at closing. The Company also derecognized $3.4 million in assets held for sale. |
Then, both the queries and chunks are embedded using the OpenAI text-embedding-3-small model. In an already deployed RAG system, the embeddings of the chunks would already exist and be stored in a vectorstore. With the Uber document, we generate 686 embedding pairs for the training and validation datasets. With the Lyft document, we generate 818 embedding pairs, used as the test dataset. The cost of generating this dataset is modest.
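Embedding the pairs can be done with the OpenAI embeddings endpoint; a minimal sketch (the batching and variable names are illustrative) looks like this:

```python
from openai import OpenAI

client = OpenAI()

def embed_texts(texts: list[str], model: str = "text-embedding-3-small") -> list[list[float]]:
    """Embed a batch of texts and return one 1536-dimensional vector per input."""
    response = client.embeddings.create(model=model, input=texts)
    return [item.embedding for item in response.data]

query_embeddings = embed_texts([query for query, _ in pairs])
chunk_embeddings = embed_texts([chunk for _, chunk in pairs])
```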
Data preparation & augmentation
To train the linear transformation that aligns user queries with the corresponding text chunks, we use a triplet loss. This loss brings the correct (query, chunk) pairs closer together and pushes the incorrect chunks further apart. In addition, generating triplets significantly increases the size of the training dataset virtually for free.
A triplet consists of three elements:
- Query (question embedding)
- Chunk (correct document chunk embedding)
- Distractor (an incorrect document chunk embedding)
Each (query, chunk) pair is matched with several distractors, which are incorrect chunks. We call the augmentation factor the ratio of incorrect chunks associated with each pair. With an augmentation factor of 0.3, we create 141,100 triplets for the training and validation dataset. By controlling the triplets, we can also penalize specific associations that one may want to avoid, further refining retrieval performance. Such distractors could, for instance, be collected through a user feedback mechanism (thumbs up or down buttons).
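A minimal sketch of how such triplets can be built, assuming matched lists of query and chunk embeddings; the way we map the augmentation factor to a number of distractors per pair here is our own illustrative reading, not the exact procedure:

```python
import random

def build_triplets(query_embs, chunk_embs, augmentation_factor=0.3, seed=0):
    """Pair each (query, correct chunk) with several distractor chunks sampled from the rest."""
    rng = random.Random(seed)
    n = len(query_embs)
    n_distractors = max(1, int(augmentation_factor * (n - 1)))
    triplets = []
    for i in range(n):
        candidates = [j for j in range(n) if j != i]  # every other chunk is a potential distractor
        for j in rng.sample(candidates, n_distractors):
            triplets.append((query_embs[i], chunk_embs[i], chunk_embs[j]))
    return triplets
```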
Training process
To improve the retrieval performance, we train a linear layer over the embeddings to minimize the triplet loss. The triplet loss encourages the model to minimize the distance between the query and the correct chunk while maximizing the distance between the query and the distractor.
The triplet loss formula is:

$$\mathcal{L}(q, c^{+}, c^{-}) = \max\left(0,\; d(Wq, Wc^{+}) - d(Wq, Wc^{-}) + \alpha\right)$$

where:

- $q$ is the query embedding,
- $c^{+}$ is the correct chunk embedding,
- $c^{-}$ is the distractor chunk embedding,
- $d(\cdot,\cdot)$ represents the distance between two embeddings; in our case, we use the normalized cosine distance,
- $\alpha$ is a hyperparameter: a positive margin that ensures the model creates sufficient separation between correct and incorrect pairs.
For training, we rely only on a consumer-grade CPU and do not use a GPU, as the linear layer and cosine distance require little computing power.
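A minimal PyTorch sketch of this training setup, assuming the triplets have already been converted to tensors and batched by a `triplet_loader`; the margin, learning rate, and number of epochs below are illustrative, not the exact values we used:

```python
import torch
import torch.nn as nn

d = 1536  # dimension of text-embedding-3-small embeddings

# A single linear layer without bias: this is the matrix W applied to every embedding.
linear = nn.Linear(d, d, bias=False)

def cosine_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return 1.0 - nn.functional.cosine_similarity(a, b, dim=-1)

loss_fn = nn.TripletMarginWithDistanceLoss(distance_function=cosine_distance, margin=0.2)
optimizer = torch.optim.Adam(linear.parameters(), lr=1e-4)

for epoch in range(10):
    for query, positive, distractor in triplet_loader:  # batches of (q, c+, c-) embeddings
        optimizer.zero_grad()
        loss = loss_fn(linear(query), linear(positive), linear(distractor))
        loss.backward()
        optimizer.step()
```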
Results
We compare our adapted user query embeddings (which use the trained linear transformation) against the original, unmodified embeddings for the retrieval of document chunks, using two metrics:
MRR (Mean Reciprocal Rank): Measures how highly the correct answer is ranked, with higher MRR indicating better ranking performance.

$$\text{MRR} = \frac{1}{|Q|} \sum_{i=1}^{|Q|} \frac{1}{\text{rank}_i}$$

where $\text{rank}_i$ is the position of the first relevant item for query $i$ in the top-k results.
Hit Rate: The percentage of queries where the correct answer appears in the top-k results, reflecting overall retrieval quality.

$$\text{Hit Rate@}k = \frac{1}{|Q|} \sum_{i=1}^{|Q|} \mathbb{1}\left[\text{rank}_i \le k\right]$$

where $\mathbb{1}[\cdot]$ is the indicator function defined as:

$$\mathbb{1}\left[\text{rank}_i \le k\right] = \begin{cases} 1 & \text{if the correct chunk for query } i \text{ appears in the top-}k \text{ results} \\ 0 & \text{otherwise} \end{cases}$$
In other words, the hit rate calculates the proportion of queries for which at least one correct answer is retrieved within the top-k results. This metric provides a measure of the retrieval system’s ability to surface relevant information for a given query within a specified number of top results. The value of k is typically chosen based on the specific requirements of the RAG system, such as the number of results shown to the user or the number of chunks processed by the generation step. Common values for k include 1, 5, or 10, depending on the application. Please note that the absolute value of these scores depends on the chosen value for k. We mostly use the Hit Rate for the top 4 results (k = 4) and MRR as they best highlight improvements in retrieval accuracy.
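Both metrics can be computed directly from the rank of the correct chunk for each query; a small sketch (assuming `ranks[i]` is the 1-based position of the correct chunk for query i, or `None` if it was not retrieved):

```python
def mean_reciprocal_rank(ranks: list[int | None]) -> float:
    """MRR over all queries; a query whose correct chunk is never retrieved contributes 0."""
    return sum(1.0 / r for r in ranks if r is not None) / len(ranks)

def hit_rate(ranks: list[int | None], k: int = 4) -> float:
    """Fraction of queries whose correct chunk appears in the top-k results."""
    return sum(1 for r in ranks if r is not None and r <= k) / len(ranks)
```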
| Metric | Simple embedding (reference) | Adapted embedding (ours) |
|---|---|---|
| Hit Rate | 0.89 | 0.95 |
| MRR | 0.69 | 0.83 |
The adapted embedding shows a significant improvement on both metrics.
Impact on inference time
On our standard hardware, the model increases the retrieval time by 8.6% (less than 10ms). In practice, this delay is barely noticeable, as the later generation step in RAG may take several seconds.
Conclusion
In this paper, we introduce EmbeddingAlign RAG, a new way to improve document retrieval in Retrieval-Augmented Generation (RAG) systems. By adapting the embeddings to incorporate domain-specific context, we significantly improve retrieval accuracy at low computational cost while leveraging existing embeddings. Our method is efficient, scalable, and easy to implement. This makes EmbeddingAlign RAG ideal for practical question-answering systems, especially with smaller datasets, as it improves performance while increasing retrieval latency by less than 10 ms.
If you have any questions about this work, feel free to reach out to the phospho team at [email protected] or check out our work at phospho.ai 🧪