Tonic committed
Commit 3adc147 · Parent: 190f21f

add automatic metadata

Files changed (1): app.py (+34, -10)
app.py CHANGED
@@ -1,10 +1,5 @@
 # main.py
 import spaces
-import os
-import uuid
-import gradio as gr
-import torch
-import torch.nn.functional as F
 from torch.nn import DataParallel
 from torch import Tensor
 from transformers import AutoTokenizer, AutoModel
@@ -15,9 +10,16 @@ from langchain_chroma import Chroma
 from chromadb import Documents, EmbeddingFunction, Embeddings
 from chromadb.config import Settings
 from chromadb import HttpClient
+import os
+import re
+import uuid
+import gradio as gr
+import torch
+import torch.nn.functional as F
 from dotenv import load_dotenv
+from utils import load_env_variables, parse_and_route
+from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt
+
 
 load_dotenv()
 
@@ -34,7 +36,7 @@ def clear_cuda_cache():
     torch.cuda.empty_cache()
 
 client = OpenAI(api_key=yi_token, base_url=API_BASE)
-
+
 class EmbeddingGenerator:
     def __init__(self, model_name: str, token: str, intention_client):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -44,7 +46,7 @@ class EmbeddingGenerator:
 
     def clear_cuda_cache(self):
         torch.cuda.empty_cache()
-
+
     @spaces.GPU
     def compute_embeddings(self, input_text: str):
         # Get the intention
@@ -71,6 +73,17 @@ class EmbeddingGenerator:
         query_prefix = f"Instruct: {task_description}\nQuery: "
         queries = [input_text]
 
+        # Get the metadata
+        metadata_completion = self.intention_client.chat.completions.create(
+            model="yi-large",
+            messages=[
+                {"role": "system", "content": metadata_prompt},
+                {"role": "user", "content": input_text}
+            ]
+        )
+        metadata_output = metadata_completion.choices[0].message['content']
+        metadata = self.extract_metadata(metadata_output)
+
         # Get the embeddings
         with torch.no_grad():
             inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
@@ -80,8 +93,19 @@ class EmbeddingGenerator:
         # Normalize embeddings
         query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
         embeddings_list = query_embeddings.detach().cpu().numpy().tolist()
+
+        # Include metadata in the embeddings
+        embeddings_with_metadata = [{"embedding": emb, "metadata": metadata} for emb in embeddings_list]
+
         self.clear_cuda_cache()
-        return embeddings_list
+        return embeddings_with_metadata
+
+    def extract_metadata(self, metadata_output: str):
+        # Regex pattern to extract key-value pairs
+        pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
+        matches = pattern.findall(metadata_output)
+        metadata = {key: value for key, value in matches}
+        return metadata
 
 class MyEmbeddingFunction(EmbeddingFunction):
     def __init__(self, embedding_generator: EmbeddingGenerator):
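
For readers skimming the diff: the new extract_metadata helper pulls quoted "key": "value" pairs out of the model reply with a regex, so it tolerates output that is JSON-like but not strictly valid JSON. A minimal standalone sketch of that behavior (sample_output is invented for illustration; the real reply shape depends on metadata_prompt):

    import re

    def extract_metadata(metadata_output: str) -> dict:
        # Same regex as the commit: collects "key": "value" pairs anywhere in the text
        pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
        return dict(pattern.findall(metadata_output))

    # Invented example reply, for illustration only
    sample_output = '{"title": "Quarterly Report", "topic": "finance", "language": "en"}'
    print(extract_metadata(sample_output))
    # {'title': 'Quarterly Report', 'topic': 'finance', 'language': 'en'}

One caveat worth flagging: the OpenAI(api_key=..., base_url=...) constructor suggests the v1 openai Python SDK, whose response messages are objects rather than mappings, so metadata_completion.choices[0].message['content'] would raise a TypeError there; message.content is the v1 accessor.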
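Note also that compute_embeddings now returns records of the form {"embedding": ..., "metadata": ...} instead of bare vectors, which changes the contract for downstream callers such as MyEmbeddingFunction (Chroma expects an embedding function to return plain float lists). A hypothetical adapter, not part of this commit and using assumed generator and collection objects, showing how a caller could split the new shape apart:

    import uuid

    # Unpack the {"embedding", "metadata"} records so Chroma still receives plain
    # vectors, with the extracted metadata stored alongside each entry.
    records = generator.compute_embeddings("quarterly finance report")
    vectors = [rec["embedding"] for rec in records]
    metadatas = [rec["metadata"] for rec in records]

    collection.add(
        ids=[str(uuid.uuid4()) for _ in vectors],
        embeddings=vectors,
        metadatas=metadatas,
    )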