Tonic committed on
Commit 5362be0
1 Parent(s): ebe18c1

use in memory chroma client

Files changed (2)
  1. .gitignore +2 -1
  2. app.py +14 -42
.gitignore CHANGED
@@ -2,4 +2,5 @@
 chroma_data/
 __pycache__/
 chroma.log
-.venv/
+.venv/
+pad.py
app.py CHANGED
@@ -38,7 +38,9 @@ hf_token, yi_token = load_env_variables()
 def clear_cuda_cache():
     torch.cuda.empty_cache()

-client = OpenAI(api_key=yi_token, base_url=API_BASE)
+client = OpenAI(api_key=yi_token, base_url=API_BASE)
+chroma_client = HttpClient(host="localhost", port=8000)
+chroma_collection = chroma_client.create_collection("all-my-documents")

 class EmbeddingGenerator:
     def __init__(self, model_name: str, token: str, intention_client):
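Note: a minimal sketch of the in-memory client the commit title refers to, assuming a chromadb release that ships EphemeralClient; the collection name mirrors the one above, and this snippet is not taken from the commit itself:

import chromadb

# Ephemeral (in-memory) client: no separate `chroma run` server process is needed.
chroma_client = chromadb.EphemeralClient()

# get_or_create avoids an error if the collection already exists in this process.
chroma_collection = chroma_client.get_or_create_collection("all-my-documents")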
@@ -125,59 +127,29 @@ def load_documents(file_path: str, mode: str = "elements"):
     docs = loader.load()
     return [doc.page_content for doc in docs]

-def wait_for_chroma_server(client, retries=10, delay=0.5):
-    for _ in range(retries):
-        try:
-            client.heartbeat()
-            print("Chroma server is up and running!")
-            return True
-        except Exception as e:
-            print(f"Attempt to connect to Chroma server failed: {e}")
-            time.sleep(delay)
-    print("Failed to connect to Chroma server after multiple attempts.")
-    return False
-
 def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
-    host = 'localhost'
-    port = 8000
-
-    client = HttpClient(host=host, port=port, settings=Settings(allow_reset=True, anonymized_telemetry=False))
-
-    if not wait_for_chroma_server(client):
-        raise ConnectionError("Could not connect to Chroma server. Ensure it is running.")
-
-    client.reset() # Empties and completely resets the database
-    collection = client.create_collection(collection_name)
-    return client, collection
-
-def add_documents_to_chroma(client, collection, documents: list, embedding_function: MyEmbeddingFunction):
+    db = Chroma(client=chroma_client, collection_name=collection_name, embedding_function=embedding_function)
+    return db
+
+def add_documents_to_chroma(documents: list, embedding_function: MyEmbeddingFunction):
     for doc in documents:
         embeddings, metadata = embedding_function.embedding_generator.compute_embeddings(doc)
         for embedding, meta in zip(embeddings, metadata):
-            collection.add(
+            chroma_collection.add(
                 ids=[str(uuid.uuid1())],
                 documents=[doc],
                 embeddings=[embedding],
                 metadatas=[meta]
             )
-
-def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
-    # Compute query embeddings and metadata
-    query_embeddings, query_metadata = embedding_function.embedding_generator.compute_embeddings(query_text)

-    # Initialize Chroma with the collection
-    db = Chroma(client=client, collection_name=collection_name, embedding_function=embedding_function)
-
-    # Perform similarity search using the query embeddings and metadata
-    result_docs = db.similarity_search(
-        query_embeddings=query_embeddings,
-        query_metadata=query_metadata
+def query_chroma(query_text: str, embedding_function: MyEmbeddingFunction):
+    query_embeddings, query_metadata = embedding_function.embedding_generator.compute_embeddings(query_text)
+    result_docs = chroma_collection.query(
+        query_texts=[query_text],
+        n_results=2
     )
-
     return result_docs

-
-
 # Initialize clients
 intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
 embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
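Note: a self-contained sketch of the add/query flow above, assuming an in-memory client and Chroma's built-in default embedding function in place of the app's EmbeddingGenerator; the sample documents and metadata are illustrative only:

import uuid
import chromadb

client = chromadb.EphemeralClient()                 # in-memory, nothing persisted to disk
collection = client.get_or_create_collection("all-my-documents")

docs = ["Chroma is a vector database.", "Gradio builds quick ML demos."]
collection.add(
    ids=[str(uuid.uuid1()) for _ in docs],          # one unique id per document
    documents=docs,                                 # embedded by the default embedding function
    metadatas=[{"source": "example"} for _ in docs],
)

# query() returns a dict with keys such as "ids", "documents", "distances", "metadatas".
results = collection.query(query_texts=["What is Chroma?"], n_results=2)
print(results["documents"])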
@@ -246,5 +218,5 @@ with gr.Blocks() as demo:
     query_button.click(query_documents, inputs=query_input, outputs=query_output)

 if __name__ == "__main__":
-    os.system("chroma run --host localhost --port 8000 &")
+    # os.system("chroma run --host localhost --port 8000 &")
     demo.launch()
 