AFischer1985 committed
Commit
41ec323
1 Parent(s): 0408088

Initial commit

Files changed (3)
  1. README.md +8 -8
  2. requirements.txt +3 -0
  3. run.py +117 -0
README.md CHANGED
@@ -1,12 +1,12 @@
+
  ---
- title: RAG Interface To Hub
- emoji: 👁
- colorFrom: green
- colorTo: pink
+ title: RAG-Interface-to-Hub
+ emoji: 🔥
+ colorFrom: indigo
+ colorTo: indigo
  sdk: gradio
- sdk_version: 4.12.0
- app_file: app.py
+ sdk_version: 3.47.1
+ app_file: run.py
  pinned: false
+ hf_oauth: false
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

requirements.txt ADDED
@@ -0,0 +1,3 @@
+ llama-cpp-python[server]
+ chromadb
+ sentence_transformers
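
To reproduce the environment locally, the dependencies can be installed with pip install -r requirements.txt; note that no versions are pinned in this commit. Also note that run.py below streams completions from the Hugging Face Inference API, so llama-cpp-python[server] is only needed if a local llama.cpp server is swapped in instead.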
run.py ADDED
@@ -0,0 +1,117 @@
+ #########################################################################################
+ # Title: Gradio Interface to LLM-chatbot with RAG-functionality and ChromaDB on HF-Hub
+ # Author: Andreas Fischer
+ # Date: December 29th, 2023
+ # Last update: December 29th, 2023
+ ##########################################################################################
+
+
+ # Chroma-DB
+ #-----------
+ import os
+ import chromadb
+
+ # Use the local development path if it exists, otherwise the path inside the Space
+ dbPath = "/home/af/Schreibtisch/gradio/Chroma/db"
+ if not os.path.exists(dbPath):
+     dbPath = "/home/user/app/db"
+ print(dbPath)
+
+ #client = chromadb.Client()  # in-memory alternative (not persistent)
+ client = chromadb.PersistentClient(path=dbPath)
+ print(client.heartbeat())
+ print(client.get_version())
+ print(client.list_collections())
+
+ from chromadb.utils import embedding_functions
+ default_ef = embedding_functions.DefaultEmbeddingFunction()
+ # Bilingual (English/German) sentence-embedding model for retrieval
+ sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
+ #instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
+ print(str(client.list_collections()))
+
+ # Reuse the collection if it already exists in the persistent DB; otherwise create and seed it
+ if "name=ChromaDB1" in str(client.list_collections()):
+     print("ChromaDB1 found!")
+     collection = client.get_collection(name="ChromaDB1", embedding_function=sentence_transformer_ef)
+ else:
+     print("ChromaDB1 created!")
+     collection = client.create_collection(
+         "ChromaDB1",
+         embedding_function=sentence_transformer_ef,
+         metadata={"hnsw:space": "cosine"})
+     collection.add(
+         documents=["The meaning of life is to love.", "This is a sentence", "This is a sentence too"],
+         metadatas=[{"source": "notion"}, {"source": "google-docs"}, {"source": "google-docs"}],
+         ids=["doc1", "doc2", "doc3"],
+     )
+
+ print("Database ready!")
+ print(collection.count())
+
+
+ # Model
+ #-------
+
+ from huggingface_hub import InferenceClient
+
+ # Named inference_client so it does not shadow the ChromaDB client above
+ inference_client = InferenceClient(
+     "mistralai/Mixtral-8x7B-Instruct-v0.1"
+     #"mistralai/Mistral-7B-Instruct-v0.1"
+ )
+
+
+ # Gradio-GUI
+ #------------
+
+ import gradio as gr
+ import json
+
+ def format_prompt(message, history):
+     # Mixtral/Mistral instruction format: <s>[INST] ... [/INST] answer</s>
+     prompt = "<s>"
+     for user_prompt, bot_response in history:
+         prompt += f"[INST] {user_prompt} [/INST]"
+         prompt += f" {bot_response}</s> "
+     prompt += f"[INST] {message} [/INST]"
+     return prompt
+
+ def response(
+     prompt, history, temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
+ ):
+     temperature = float(temperature)
+     if temperature < 1e-2: temperature = 1e-2
+     top_p = float(top_p)
+     generate_kwargs = dict(
+         temperature=temperature,
+         max_new_tokens=max_new_tokens,
+         top_p=top_p,
+         repetition_penalty=repetition_penalty,
+         do_sample=True,
+         seed=42,
+     )
+     # Retrieve the two most similar documents from the vector DB
+     addon = ""
+     results = collection.query(
+         query_texts=[prompt],
+         n_results=2,
+         #where={"source": "google-docs"}
+         #where_document={"$contains": "search_string"}
+     )
+     # Turn cosine distance into a relevance score rounded to two decimals
+     dists = ["<small>(relevance: "+str(round((1-d)*100)/100)+";" for d in results['distances'][0]]
+     sources = ["source: "+s["source"]+")</small>" for s in results['metadatas'][0]]
+     results = results['documents'][0]
+     combination = zip(results, dists, sources)
+     combination = [' '.join(triplet) for triplet in combination]
+     print(combination)
+     if len(results) > 1:
+         # German: "Please take the following excerpts from our database into account in your answer,
+         # if they are relevant to the answer. Answer the question briefly and precisely. Ignore
+         # unsuitable database excerpts WITHOUT commenting on, mentioning, or listing them:"
+         addon = " Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Beantworte die Frage knapp und präzise. Ignoriere unpassende Datenbank-Auszüge OHNE sie zu kommentieren, zu erwähnen oder aufzulisten:\n"+"\n".join(results)
+     # German: "You are an AI-based assistance system." / "User request:"
+     system = "Du bist ein KI-basiertes Assistenzsystem."+addon+"\n\nUser-Anliegen:"
+     #body={"prompt":system+"### Instruktion:\n"+message+"\n\n### Antwort:","max_tokens":500, "echo":"False","stream":"True"} #e.g. SauerkrautLM
+     formatted_prompt = format_prompt(system+"\n"+prompt, history)
+     # Stream tokens from the HF Inference API and yield partial output for the chat UI
+     stream = inference_client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
+     output = ""
+     for chunk in stream:
+         output += chunk.token.text
+         yield output
+     # Append the retrieved sources below the final answer
+     output = output+"\n\n<br><details open><summary><strong>Sources</strong></summary><br><ul>"+"".join(["<li>" + s + "</li>" for s in combination])+"</ul></details>"
+     yield output
+
+ gr.ChatInterface(response, chatbot=gr.Chatbot(render_markdown=True), title="RAG-Interface").queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
+ print("Interface up and running!")
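
For reference, a minimal sketch of the retrieval step in isolation, assuming the same persistent DB path, collection name, and embedding model as in run.py above (the query string is purely illustrative):

    import chromadb
    from chromadb.utils import embedding_functions

    # Same persistent DB and embedding model as run.py (path is the Space fallback)
    client = chromadb.PersistentClient(path="/home/user/app/db")
    ef = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
    collection = client.get_collection(name="ChromaDB1", embedding_function=ef)

    # Query as in response(); distances are cosine, so 1-d serves as a rough relevance score
    results = collection.query(query_texts=["What is the meaning of life?"], n_results=2)
    for doc, dist, meta in zip(results["documents"][0], results["distances"][0], results["metadatas"][0]):
        print(f"{doc} (relevance: {round((1 - dist) * 100) / 100}; source: {meta['source']})")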