# RAG utilities: MongoDB Atlas vector search with HuggingFace embeddings and an
# LLM endpoint, wired together with LangChain (LCEL).
import os

from pymongo import MongoClient

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain.prompts import ChatPromptTemplate

# All configuration is read from environment variables.
config = {
    'MONGODB_CONN_STRING': os.getenv('MONGODB_CONN_STRING'),
    'HUGGINGFACEHUB_API_TOKEN': os.getenv('HUGGINGFACEHUB_API_TOKEN'),
    'DB_NAME': os.getenv('DB_NAME'),
    'VECTOR_SEARCH_INDEX': os.getenv('VECTOR_SEARCH_INDEX'),
    'PASSWORD_DB': os.getenv('PASSWORD_DB'),
}
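# Optional sanity check (a minimal sketch, not part of the original module): fail
# fast with a clear message if any of the settings above is unset, instead of
# hitting a less obvious error later inside pymongo or the HuggingFace client.
_missing = [key for key, value in config.items() if value is None]
if _missing:
    raise EnvironmentError(f"Missing environment variables: {', '.join(_missing)}")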
# Shared MongoDB client, embedding model and LLM used by every vector store.
client = MongoClient(config['MONGODB_CONN_STRING'])

# intfloat/e5-large-v2 produces 1024-dimensional embeddings; the Atlas vector
# search index must be defined with the same dimensionality.
embeddings = HuggingFaceEmbeddings(model_name="intfloat/e5-large-v2")

llm_model = HuggingFaceEndpoint(repo_id='mistralai/Mistral-7B-Instruct-v0.2',
                                huggingfacehub_api_token=config['HUGGINGFACEHUB_API_TOKEN'],
                                temperature=0.3)
# Prompt in the Mistral-instruct style; {context} is filled by the retriever and
# {question} by the user's input.
template = """
<s>[INST] Instruction: You are a helpful chatbot who can answer data science, anime and manga questions.
Follow these rules strictly when answering the question from the context:
1. Do not use the word "context" or the phrase "based on the context" in your answers.
2. If no context is available, answer in at most 128 words.
3. The context is a series of documents; combine them into your own best answer.
[/INST]
context:
{context}</s>
### QUESTION:
{question} [/INST]
"""
prompt = ChatPromptTemplate.from_template(template=template)

# Converts the raw LLM output into a plain string.
parser = StrOutputParser()
def get_all_collections():
    """Map each collection name in the database to a human-readable display name."""
    database = client[config['DB_NAME']]
    names = database.list_collection_names()
    coll_dict = {}
    for name in names:
        # Replace underscores with spaces and capitalise the first word.
        coll_dict[name] = ' '.join(str(name).capitalize().split('_'))
    return coll_dict
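# With hypothetical collection names, the mapping above would look like:
#   {"naruto_anime": "Naruto anime", "one_piece_manga": "One piece manga"}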
class VECTORDB_STORE:

    def __init__(self, coll_name):
        # coll_name is the display name shown to users (a value produced by
        # get_all_collections()); translate it back to the real collection name.
        collection_name = self.get_collection_name(coll_name)
        if collection_name is None:
            raise ValueError(f"No collection found for display name '{coll_name}'")
        collection = client[config['DB_NAME']][collection_name]
        self.vectordb_store = MongoDBAtlasVectorSearch(collection=collection,
                                                       embedding=embeddings,
                                                       index_name=config['VECTOR_SEARCH_INDEX'])

    @staticmethod
    def get_collection_name(coll_name):
        # Reverse lookup: display name -> actual MongoDB collection name.
        for key, value in get_all_collections().items():
            if coll_name == value:
                return key
        return None

    def chain(self):
        # Retrieve the 10 most similar documents and pass them, together with the
        # question, through the prompt, the LLM and the string parser.
        retriever = self.vectordb_store.as_retriever(search_kwargs={"k": 10})
        chain = {'context': retriever, 'question': RunnablePassthrough()} | prompt | llm_model | parser
        return chain
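# Example usage (a minimal sketch; it assumes the environment variables above are
# set and that "Naruto anime" is one of the display names returned by
# get_all_collections() -- substitute a real one from your database).
if __name__ == "__main__":
    store = VECTORDB_STORE("Naruto anime")
    rag_chain = store.chain()
    print(rag_chain.invoke("Summarise the main story arcs."))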