thomasht86 committed on
Commit 2034346 • 1 Parent(s): 5d22e58

Upload folder using huggingface_hub

Files changed (12)
  1. README.md +2 -143
  2. backend/colpali.py +236 -270
  3. backend/stopwords.py +2 -1
  4. backend/vespa_app.py +3 -2
  5. frontend/app.py +71 -24
  6. frontend/layout.py +2 -1
  7. globals.css +65 -51
  8. icons.py +1 -1
  9. main.py +63 -73
  10. output.css +145 -61
  11. requirements.txt +1 -1
  12. static/.DS_Store +0 -0
README.md CHANGED
@@ -9,152 +9,11 @@ sdk_version: 4.44.0
  app_file: main.py
  pinned: false
  license: apache-2.0
  models:
  - vidore/colpaligemma-3b-pt-448-base
  - vidore/colpali-v1.2
  preload_from_hub:
  - vidore/colpaligemma-3b-pt-448-base config.json,model-00001-of-00002.safetensors,model-00002-of-00002.safetensors,model.safetensors.index.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa
  - vidore/colpali-v1.2 adapter_config.json,adapter_model.safetensors,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 9912ce6f8a462d8cf2269f5606eabbd2784e764f
- ---
-
- <!-- Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
-
- <picture>
- <source media="(prefers-color-scheme: dark)" srcset="https://assets.vespa.ai/logos/Vespa-logo-green-RGB.svg">
- <source media="(prefers-color-scheme: light)" srcset="https://assets.vespa.ai/logos/Vespa-logo-dark-RGB.svg">
- <img alt="#Vespa" width="200" src="https://assets.vespa.ai/logos/Vespa-logo-dark-RGB.svg" style="margin-bottom: 25px;">
- </picture>
-
- # Visual Retrieval ColPali
-
- # Prepare data and Vespa application
-
- First, install `uv`:
-
- ```bash
- curl -LsSf https://astral.sh/uv/install.sh | sh
- ```
-
- Then, run:
-
- ```bash
- uv sync --extra dev --extra feed
- ```
-
- Convert the `prepare_feed_deploy.py` to notebook to:
-
- ```bash
- jupytext --to notebook prepare_feed_deploy.py
- ```
-
- And launch a Jupyter instance, see https://docs.astral.sh/uv/guides/integration/jupyter/ for recommended approach.
-
- Open and follow the `prepare_feed_deploy.ipynb` notebook to prepare the data and deploy the Vespa application.
-
- # Developing on the web app
-
-
- Then, in this directory, run:
-
- ```bash
- uv sync --extra dev
- ```
-
- This will generate a virtual environment with the required dependencies at `.venv`.
-
- To activate the virtual environment, run:
-
- ```bash
- source .venv/bin/activate
- ```
-
- And run development server:
-
- ```bash
- python hello.py
- ```
-
- ## Preparation
-
- First, set up your `.env` file by renaming `.env.example` to `.env` and filling in the required values.
- (Token can be shared with 1password, `HF_TOKEN` is personal and must be created at huggingface)
-
- ### Deploying the Vespa app
-
- To deploy the Vespa app, run:
-
- ```bash
- python deploy_vespa_app.py --tenant_name mytenant --vespa_application_name myapp --token_id_write mytokenid_write --token_id_read mytokenid_read
- ```
-
- You should get an output like:
-
- ```bash
- Found token endpoint: https://abcde.z.vespa-app.cloud
- ````
-
- ### Feeding the data
-
- #### Dependencies
-
- In addition to the python dependencies, you also need `poppler`
- On Mac:
-
- ```bash
- brew install poppler
- ```
-
- First, you need to create a huggingface token, after you have accepted the term to use the model
- at https://huggingface.co/google/paligemma-3b-mix-448.
- Add the token to your environment variables as `HF_TOKEN`:
-
- ```bash
- export HF_TOKEN=yourtoken
- ```
-
- To feed the data, run:
-
- ```bash
- python feed_vespa.py --vespa_app_url https://myapp.z.vespa-app.cloud --vespa_cloud_secret_token mysecrettoken
- ```
-
- ### Starting the front-end
-
- ```bash
- python main.py
- ```
-
- ## Deploy to huggingface 🤗
-
- ### Compiling dependencies
-
- Before a deploy, make sure to run this to compile the `uv` lock file to `requirements.txt` if you have made changes to the dependencies:
-
- ```bash
- uv pip compile pyproject.toml -o requirements.txt
- ```
-
- ### Deploying to huggingface
-
- To deploy, run
-
- ```bash
- huggingface-cli upload vespa-engine/colpali-vespa-visual-retrieval . . --repo-type=space
- ```
-
- Note that you need to set `HF_TOKEN` environment variable first.
- This is personal, and must be created at [huggingface](https://huggingface.co/settings/tokens).
- Make sure the token has `write` access.
- Be aware that this will not delete existing files, only modify or add,
- see [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/upload#upload-from-the-cli) for more
- information.
-
- ### Making changes to CSS
-
- To make changes to output.css apply, run
-
- ```bash
- shad4fast watch # watches all files passed through the tailwind.config.js content section
-
- shad4fast build # minifies the current output.css file to reduce bundle size in production.
- ```

  app_file: main.py
  pinned: false
  license: apache-2.0
+ suggested_hardware: t4-small
  models:
  - vidore/colpaligemma-3b-pt-448-base
  - vidore/colpali-v1.2
  preload_from_hub:
  - vidore/colpaligemma-3b-pt-448-base config.json,model-00001-of-00002.safetensors,model-00002-of-00002.safetensors,model.safetensors.index.json,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 12c59eb7e23bc4c26876f7be7c17760d5d3a1ffa
  - vidore/colpali-v1.2 adapter_config.json,adapter_model.safetensors,preprocessor_config.json,special_tokens_map.json,tokenizer.json,tokenizer_config.json 9912ce6f8a462d8cf2269f5606eabbd2784e764f
+ ---

backend/colpali.py CHANGED
@@ -1,308 +1,274 @@
- #!/usr/bin/env python3
-
  import torch
  from PIL import Image
  import numpy as np
- from typing import cast, Generator
  from pathlib import Path
  import base64
  from io import BytesIO
- from typing import Union, Tuple, List
- import matplotlib
- import matplotlib.cm as cm
  import re
  import io
-
- import time
- import backend.testquery as testquery
 
  from colpali_engine.models import ColPali, ColPaliProcessor
  from colpali_engine.utils.torch_utils import get_torch_device
- from einops import rearrange
  from vidore_benchmark.interpretability.torch_utils import (
      normalize_similarity_map_per_query_token,
  )
- from vidore_benchmark.interpretability.vit_configs import VIT_CONFIG
-
- matplotlib.use("Agg")
- # Prepare the colormap once to avoid recomputation
- colormap = cm.get_cmap("viridis")
-
- COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
 
 
- def load_model() -> Tuple[ColPali, ColPaliProcessor]:
-     model_name = "vidore/colpali-v1.2"
-
-     device = get_torch_device("auto")
-     print(f"Using device: {device}")
 
-     # Load the model
-     model = cast(
-         ColPali,
-         ColPali.from_pretrained(
-             model_name,
              torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-             device_map=device,
-         ),
-     ).eval()
-
-     # Load the processor
-     processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))
-     return model, processor, device
-
-
- def load_vit_config(model):
-     # Load the ViT config
-     print(f"VIT config: {VIT_CONFIG}")
-     vit_config = VIT_CONFIG[COLPALI_GEMMA_MODEL_NAME]
-     return vit_config
-
-
- def gen_similarity_maps(
-     model: ColPali,
-     processor: ColPaliProcessor,
-     device,
-     query: str,
-     query_embs: torch.Tensor,
-     token_idx_map: dict,
-     images: List[Union[Path, str]],
-     vespa_sim_maps: List[str],
- ) -> Generator[Tuple[int, str, str], None, None]:
-     """
-     Generate similarity maps for the given images and query, and return base64-encoded blended images.
-
-     Args:
-         model (ColPali): The model used for generating embeddings.
-         processor (ColPaliProcessor): Processor for images and text.
-         device: Device to run the computations on.
-         vit_config: Configuration for the Vision Transformer.
-         query (str): The query string.
-         query_embs (torch.Tensor): Query embeddings.
-         token_idx_map (dict): Mapping from indices to tokens.
-         images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
-         vespa_sim_maps (List[str]): List of Vespa similarity maps.
-
-     Yields:
-         Tuple[int, str, str]: A tuple containing the image index, the selected token, and the base64-encoded image.
 
-     """
-     vit_config = load_vit_config(model)
-     # Process images and store original images and sizes
-     processed_images = []
-     original_images = []
-     original_sizes = []
-     for img in images:
-         if isinstance(img, Path):
-             try:
-                 img_pil = Image.open(img).convert("RGB")
-             except Exception as e:
-                 raise ValueError(f"Failed to open image from path: {e}")
-         elif isinstance(img, str):
-             try:
-                 img_pil = Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
-             except Exception as e:
-                 raise ValueError(f"Failed to open image from base64 string: {e}")
-         else:
-             raise ValueError(f"Unsupported image type: {type(img)}")
-         original_images.append(img_pil.copy())
-         original_sizes.append(img_pil.size)  # (width, height)
-         processed_images.append(img_pil)
-
-     # If similarity maps are provided, use them instead of computing them
-     if vespa_sim_maps:
-         print("Using provided similarity maps")
-         # A sim map looks like this:
-         # "quantized": [
-         #     {
-         #         "address": {
-         #             "patch": "0",
-         #             "querytoken": "0"
-         #         },
-         #         "value": 12,  # score in range [-128, 127]
-         #     },
-         #     ... and so on.
-         # Now turn these into a tensor of same shape as previous similarity map
          vespa_sim_map_tensor = torch.zeros(
-             (
-                 len(vespa_sim_maps),
-                 query_embs.size(dim=1),
-                 vit_config.n_patch_per_dim,
-                 vit_config.n_patch_per_dim,
-             )
          )
          for idx, vespa_sim_map in enumerate(vespa_sim_maps):
              for cell in vespa_sim_map["quantized"]["cells"]:
                  patch = int(cell["address"]["patch"])
-                 # if dummy model then just use 1024 as the image_seq_length
-
-                 if hasattr(processor, "image_seq_length"):
-                     image_seq_length = processor.image_seq_length
                  else:
                      image_seq_length = 1024
 
                  if patch >= image_seq_length:
                      continue
-                 query_token = int(cell["address"]["querytoken"])
-                 value = cell["value"]
                  vespa_sim_map_tensor[
                      idx,
-                     int(query_token),
-                     int(patch) // vit_config.n_patch_per_dim,
-                     int(patch) % vit_config.n_patch_per_dim,
                  ] = value
-
-         # Normalize the similarity map per query token
-         similarity_map_normalized = normalize_similarity_map_per_query_token(
-             vespa_sim_map_tensor
          )
-     else:
-         # Preprocess inputs
-         print("Computing similarity maps")
-         start2 = time.perf_counter()
-         input_image_processed = processor.process_images(processed_images).to(device)
-
-         # Forward passes
-         with torch.no_grad():
-             output_image = model.forward(**input_image_processed)
-
-         # Remove the special tokens from the output
-         output_image = output_image[:, : processor.image_seq_length, :]
 
-         # Rearrange the output image tensor to represent the 2D grid of patches
-         output_image = rearrange(
-             output_image,
-             "b (h w) c -> b h w c",
-             h=vit_config.n_patch_per_dim,
-             w=vit_config.n_patch_per_dim,
          )
-
-         # Ensure query_embs has batch dimension
-         if query_embs.dim() == 2:
-             query_embs = query_embs.unsqueeze(0).to(device)
-         else:
-             query_embs = query_embs.to(device)
-
-         # Compute the similarity map
-         similarity_map = torch.einsum(
-             "bnk,bhwk->bnhw", query_embs, output_image
-         )  # Shape: (batch_size, query_tokens, h, w)
-
-         end2 = time.perf_counter()
-         print(f"Similarity map computation took: {end2 - start2} s")
-
-         # Normalize the similarity map per query token
-         similarity_map_normalized = normalize_similarity_map_per_query_token(
-             similarity_map
          )
 
-     # Collect the blended images
-     start3 = time.perf_counter()
-     for idx, img in enumerate(original_images):
-         SCALING_FACTOR = 8
-         sim_map_resolution = (
-             max(32, int(original_sizes[idx][0] / SCALING_FACTOR)),
-             max(32, int(original_sizes[idx][1] / SCALING_FACTOR)),
          )
-
-         result_per_image = {}
-         for token_idx, token in token_idx_map.items():
-             if should_filter_token(token):
-                 continue
-
-             # Get the similarity map for this image and the selected token
-             sim_map = similarity_map_normalized[idx, token_idx, :, :]  # Shape: (h, w)
-
-             # Move the similarity map to CPU, convert to float (as BFloat16 not supported by Numpy) and convert to NumPy array
-             sim_map_np = sim_map.cpu().float().numpy()
-
-             # Resize the similarity map to the original image size
-             sim_map_img = Image.fromarray(sim_map_np)
-             sim_map_resized = sim_map_img.resize(
-                 sim_map_resolution, resample=Image.BICUBIC
-             )
-
-             # Convert the resized similarity map to a NumPy array
-             sim_map_resized_np = np.array(sim_map_resized, dtype=np.float32)
-
-             # Normalize the similarity map to range [0, 1]
-             sim_map_min = sim_map_resized_np.min()
-             sim_map_max = sim_map_resized_np.max()
-             if sim_map_max - sim_map_min > 1e-6:
-                 sim_map_normalized = (sim_map_resized_np - sim_map_min) / (
-                     sim_map_max - sim_map_min
-                 )
-             else:
-                 sim_map_normalized = np.zeros_like(sim_map_resized_np)
-
-             # Apply a colormap to the normalized similarity map
-             heatmap = colormap(sim_map_normalized)  # Returns an RGBA array
-
-             # Convert the heatmap to a PIL Image
-             heatmap_uint8 = (heatmap * 255).astype(np.uint8)
-             heatmap_img = Image.fromarray(heatmap_uint8)
-             heatmap_img_rgba = heatmap_img.convert("RGBA")
-
-             # Save the image to a BytesIO buffer
-             buffer = io.BytesIO()
-             heatmap_img_rgba.save(buffer, format="PNG")
-             buffer.seek(0)
-
-             # Encode the image to base64
-             blended_img_base64 = base64.b64encode(buffer.read()).decode("utf-8")
-
-             # Store the base64-encoded image
-             result_per_image[token] = blended_img_base64
-             yield idx, token, token_idx, blended_img_base64
-     end3 = time.perf_counter()
-     print(f"Blending images took: {end3 - start3} s")
-
-
- def get_query_embeddings_and_token_map(
-     processor, model, query
- ) -> Tuple[torch.Tensor, dict]:
-     if model is None:  # use static test query data (saves time when testing)
-         return testquery.q_embs, testquery.idx_to_token
-
-     start_time = time.perf_counter()
-     inputs = processor.process_queries([query]).to(model.device)
-     with torch.no_grad():
-         embeddings_query = model(**inputs)
-         q_emb = embeddings_query.to("cpu")[0]  # Extract the single embedding
-     # Use this cell output to choose a token using its index
-     query_tokens = processor.tokenizer.tokenize(processor.decode(inputs.input_ids[0]))
-     # reverse key, values in dictionary
-     print(query_tokens)
-     idx_to_token = {idx: val for idx, val in enumerate(query_tokens)}
-     end_time = time.perf_counter()
-     print(f"Query inference took: {end_time - start_time} s")
-     return q_emb, idx_to_token
-
-
- def should_filter_token(token: str) -> bool:
-     # Pattern to match tokens that start with '<', numbers, whitespace, special characters (except ▁), or the string 'Question'
-     # Will exclude these tokens from the similarity map generation
-     # Does NOT match:
-     # 2
-     # 0
-     # 2
-     # 3
-     # ▁2
-     # ▁hi
-     #
-     # Do match:
-     # <bos>
-     # Question
-     # :
-     # _Percentage
-     # <pad>
-     # \n
-     # ▁
-     # ?
-     # )
-     # %
-     # /)
-     pattern = re.compile(r"^<.*$|^\s+$|^(?!.*\d)(?!▁)\S+$|^Question$|^▁$")
-     if pattern.match(token):
-         return True
-     return False

  import torch
  from PIL import Image
  import numpy as np
+ from typing import Generator, Tuple, List, Union, Dict
  from pathlib import Path
  import base64
  from io import BytesIO
  import re
  import io
+ import matplotlib.cm as cm
 
  from colpali_engine.models import ColPali, ColPaliProcessor
  from colpali_engine.utils.torch_utils import get_torch_device
  from vidore_benchmark.interpretability.torch_utils import (
      normalize_similarity_map_per_query_token,
  )
 
 
+ class SimMapGenerator:
+     """
+     Generates similarity maps based on query embeddings and image patches using the ColPali model.
+     """
 
+     COLPALI_GEMMA_MODEL_NAME = "vidore/colpaligemma-3b-pt-448-base"
+     colormap = cm.get_cmap("viridis")  # Preload colormap for efficiency
+
+     def __init__(self, model_name: str = "vidore/colpali-v1.2", n_patch: int = 32):
+         """
+         Initializes the SimMapGenerator class with a specified model and patch dimension.
+
+         Args:
+             model_name (str): The model name for loading the ColPali model.
+             n_patch (int): The number of patches per dimension.
+         """
+         self.model_name = model_name
+         self.n_patch = n_patch
+         self.device = get_torch_device("auto")
+         print(f"Using device: {self.device}")
+         self.model, self.processor = self.load_model()
+
+     def load_model(self) -> Tuple[ColPali, ColPaliProcessor]:
+         """
+         Loads the ColPali model and processor.
+
+         Returns:
+             Tuple[ColPali, ColPaliProcessor]: Loaded model and processor.
+         """
+         model = ColPali.from_pretrained(
+             self.model_name,
              torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+             device_map=self.device,
+         ).eval()
+
+         processor = ColPaliProcessor.from_pretrained(self.model_name)
+         return model, processor
+
+     def gen_similarity_maps(
+         self,
+         query: str,
+         query_embs: torch.Tensor,
+         token_idx_map: Dict[int, str],
+         images: List[Union[Path, str]],
+         vespa_sim_maps: List[Dict],
+     ) -> Generator[Tuple[int, str, str], None, None]:
+         """
+         Generates similarity maps for the provided images and query, and returns base64-encoded blended images.
+
+         Args:
+             query (str): The query string.
+             query_embs (torch.Tensor): Query embeddings tensor.
+             token_idx_map (dict): Mapping from indices to tokens.
+             images (List[Union[Path, str]]): List of image paths or base64-encoded strings.
+             vespa_sim_maps (List[Dict]): List of Vespa similarity maps.
+
+         Yields:
+             Tuple[int, str, str]: A tuple containing the image index, selected token, and base64-encoded image.
+         """
+         processed_images, original_images, original_sizes = [], [], []
+         for img in images:
+             img_pil = self._load_image(img)
+             original_images.append(img_pil.copy())
+             original_sizes.append(img_pil.size)
+             processed_images.append(img_pil)
+
+         vespa_sim_map_tensor = self._prepare_similarity_map_tensor(
+             query_embs, vespa_sim_maps
+         )
+         similarity_map_normalized = normalize_similarity_map_per_query_token(
+             vespa_sim_map_tensor
+         )
 
+         for idx, img in enumerate(original_images):
+             for token_idx, token in token_idx_map.items():
+                 if self.should_filter_token(token):
+                     continue
+
+                 sim_map = similarity_map_normalized[idx, token_idx, :, :]
+                 blended_img_base64 = self._blend_image(
+                     img, sim_map, original_sizes[idx]
+                 )
+                 yield idx, token, token_idx, blended_img_base64
+
+     def _load_image(self, img: Union[Path, str]) -> Image:
+         """
+         Loads an image from a file path or a base64-encoded string.
+
+         Args:
+             img (Union[Path, str]): The image to load.
+
+         Returns:
+             Image: The loaded PIL image.
+         """
+         try:
+             if isinstance(img, Path):
+                 return Image.open(img).convert("RGB")
+             elif isinstance(img, str):
+                 return Image.open(BytesIO(base64.b64decode(img))).convert("RGB")
+         except Exception as e:
+             raise ValueError(f"Failed to load image: {e}")
+
+     def _prepare_similarity_map_tensor(
+         self, query_embs: torch.Tensor, vespa_sim_maps: List[Dict]
+     ) -> torch.Tensor:
+         """
+         Prepares a similarity map tensor from Vespa similarity maps.
+
+         Args:
+             query_embs (torch.Tensor): Query embeddings tensor.
+             vespa_sim_maps (List[Dict]): List of Vespa similarity maps.
+
+         Returns:
+             torch.Tensor: The prepared similarity map tensor.
+         """
          vespa_sim_map_tensor = torch.zeros(
+             (len(vespa_sim_maps), query_embs.size(1), self.n_patch, self.n_patch)
          )
          for idx, vespa_sim_map in enumerate(vespa_sim_maps):
              for cell in vespa_sim_map["quantized"]["cells"]:
                  patch = int(cell["address"]["patch"])
+                 query_token = int(cell["address"]["querytoken"])
+                 value = cell["value"]
+                 if hasattr(self.processor, "image_seq_length"):
+                     image_seq_length = self.processor.image_seq_length
                  else:
                      image_seq_length = 1024
 
                  if patch >= image_seq_length:
                      continue
                  vespa_sim_map_tensor[
                      idx,
+                     query_token,
+                     patch // self.n_patch,
+                     patch % self.n_patch,
                  ] = value
+         return vespa_sim_map_tensor
+
+     def _blend_image(
+         self, img: Image, sim_map: torch.Tensor, original_size: Tuple[int, int]
+     ) -> str:
+         """
+         Blends an image with a similarity map and encodes it to base64.
+
+         Args:
+             img (Image): The original image.
+             sim_map (torch.Tensor): The similarity map tensor.
+             original_size (Tuple[int, int]): The original size of the image.
+
+         Returns:
+             str: The base64-encoded blended image.
+         """
+         SCALING_FACTOR = 8
+         sim_map_resolution = (
+             max(32, int(original_size[0] / SCALING_FACTOR)),
+             max(32, int(original_size[1] / SCALING_FACTOR)),
          )
 
+         sim_map_np = sim_map.cpu().float().numpy()
+         sim_map_img = Image.fromarray(sim_map_np).resize(
+             sim_map_resolution, resample=Image.BICUBIC
          )
+         sim_map_resized_np = np.array(sim_map_img, dtype=np.float32)
+         sim_map_normalized = self._normalize_sim_map(sim_map_resized_np)
+
+         heatmap = self.colormap(sim_map_normalized)
+         heatmap_img = Image.fromarray((heatmap * 255).astype(np.uint8)).convert("RGBA")
+
+         buffer = io.BytesIO()
+         heatmap_img.save(buffer, format="PNG")
+         return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+     @staticmethod
+     def _normalize_sim_map(sim_map: np.ndarray) -> np.ndarray:
+         """
+         Normalizes a similarity map to range [0, 1].
+
+         Args:
+             sim_map (np.ndarray): The similarity map.
+
+         Returns:
+             np.ndarray: The normalized similarity map.
+         """
+         sim_map_min, sim_map_max = sim_map.min(), sim_map.max()
+         if sim_map_max - sim_map_min > 1e-6:
+             return (sim_map - sim_map_min) / (sim_map_max - sim_map_min)
+         return np.zeros_like(sim_map)
+
+     @staticmethod
+     def should_filter_token(token: str) -> bool:
+         """
+         Determines if a token should be filtered out based on predefined patterns.
+
+         The function filters out tokens that:
+
+         - Start with '<' (e.g., '<bos>')
+         - Consist entirely of whitespace
+         - Are purely punctuation (excluding tokens that contain digits or start with '▁')
+         - Start with an underscore '_'
+         - Exactly match the word 'Question'
+         - Are exactly the single character '▁'
+
+         Output of test:
+         Token: '2' | False
+         Token: '0' | False
+         Token: '2' | False
+         Token: '3' | False
+         Token: '▁2' | False
+         Token: '▁hi' | False
+         Token: 'norwegian' | False
+         Token: 'unlisted' | False
+         Token: '<bos>' | True
+         Token: 'Question' | True
+         Token: ':' | True
+         Token: '<pad>' | True
+         Token: '\n' | True
+         Token: '▁' | True
+         Token: '?' | True
+         Token: ')' | True
+         Token: '%' | True
+         Token: '/)' | True
+
+
+         Args:
+             token (str): The token to check.
+
+         Returns:
+             bool: True if the token should be filtered out, False otherwise.
+         """
+         pattern = re.compile(
+             r"^<.*$|^\s+$|^(?!.*\d)(?!▁)[^\w\s]+$|^_.*$|^Question$|^▁$"
          )
+         return bool(pattern.match(token))
+
+     # TODO: Would be nice to @lru_cache this method.
+     def get_query_embeddings_and_token_map(
+         self, query: str
+     ) -> Tuple[torch.Tensor, dict]:
+         """
+         Retrieves query embeddings and a token index map.
+
+         Args:
+             query (str): The query string.
+
+         Returns:
+             Tuple[torch.Tensor, dict]: Query embeddings and token index map.
+         """
+         inputs = self.processor.process_queries([query]).to(self.model.device)
+         with torch.no_grad():
+             q_emb = self.model(**inputs).to("cpu")[0]
 
+         query_tokens = self.processor.tokenizer.tokenize(
+             self.processor.decode(inputs.input_ids[0])
          )
+         idx_to_token = {idx: token for idx, token in enumerate(query_tokens)}
+         return q_emb, idx_to_token
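
The refactor above folds the old free functions (load_model, gen_similarity_maps, get_query_embeddings_and_token_map, should_filter_token) into a single SimMapGenerator class and drops the local similarity-map computation path: maps are now always built from Vespa's quantized cells. A minimal usage sketch of the new API follows; the blank page image and single-cell sim map are synthetic stand-ins for real Vespa responses, not part of the commit.

```python
import base64
from io import BytesIO

from PIL import Image

from backend.colpali import SimMapGenerator

# Loads vidore/colpali-v1.2 onto the auto-detected device (bfloat16 on CUDA).
generator = SimMapGenerator()

# Embed the query once; idx_to_token maps query-token positions to token strings.
q_embs, idx_to_token = generator.get_query_embeddings_and_token_map("percentage of women")

# Synthetic stand-ins: one blank 448x448 page and one quantized cell (value in [-128, 127]).
buf = BytesIO()
Image.new("RGB", (448, 448)).save(buf, format="JPEG")
page_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
sim_map = {"quantized": {"cells": [{"address": {"patch": "0", "querytoken": "0"}, "value": 12}]}}

for idx, token, token_idx, blended_b64 in generator.gen_similarity_maps(
    query="percentage of women",
    query_embs=q_embs,
    token_idx_map=idx_to_token,
    images=[page_b64],
    vespa_sim_maps=[sim_map],
):
    print(f"image {idx}, token {token_idx} ({token!r}): {len(blended_b64)} base64 chars")
```
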
backend/stopwords.py CHANGED
@@ -6,6 +6,7 @@ if not spacy.util.is_package("en_core_web_sm"):
      spacy.cli.download("en_core_web_sm")
  nlp = spacy.load("en_core_web_sm")
 
  # It would be possible to remove bolding for stopwords without removing them from the query,
  # but that would require a java plugin which we didn't want to complicate this sample app with.
  def filter(text):
@@ -14,4 +15,4 @@ def filter(text):
      if len(tokens) == 0:
          # if we remove all the words we don't have a query at all, so use the original
          return text
-     return " ".join(tokens)

      spacy.cli.download("en_core_web_sm")
  nlp = spacy.load("en_core_web_sm")
 
+
  # It would be possible to remove bolding for stopwords without removing them from the query,
  # but that would require a java plugin which we didn't want to complicate this sample app with.
  def filter(text):
 
      if len(tokens) == 0:
          # if we remove all the words we don't have a query at all, so use the original
          return text
+     return " ".join(tokens)
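
The stopwords change itself is cosmetic (an added blank line and a trailing newline at end of file). For reference, this is the behavior the comments describe — a quick check whose exact output depends on spaCy's default English stopword list, so treat it as illustrative:

```python
from backend.stopwords import filter as remove_stopwords

# Stopwords are stripped before the query is sent to Vespa.
print(remove_stopwords("what is the percentage of unlisted companies"))
# -> "percentage unlisted companies" (assuming spaCy marks the rest as stopwords)

# If filtering would remove every word, the original query is returned unchanged.
print(remove_stopwords("what is the"))
# -> "what is the"
```
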
backend/vespa_app.py CHANGED
@@ -7,9 +7,10 @@ import torch
  from dotenv import load_dotenv
  from vespa.application import Vespa
  from vespa.io import VespaQueryResponse
- from .colpali import should_filter_token
  import backend.stopwords
 
  class VespaQueryClient:
      MAX_QUERY_TERMS = 64
      VESPA_SCHEMA_NAME = "pdf_page"
@@ -364,7 +365,7 @@ class VespaQueryClient:
          fields_to_add = [
              f"sim_map_{token}_{idx}"
              for idx, token in idx_to_token.items()
-             if not should_filter_token(token)
          ]
          for child in result["root"]["children"]:
              for sim_map_key in fields_to_add:

  from dotenv import load_dotenv
  from vespa.application import Vespa
  from vespa.io import VespaQueryResponse
+ from .colpali import SimMapGenerator
  import backend.stopwords
 
+
  class VespaQueryClient:
      MAX_QUERY_TERMS = 64
      VESPA_SCHEMA_NAME = "pdf_page"
 
          fields_to_add = [
              f"sim_map_{token}_{idx}"
              for idx, token in idx_to_token.items()
+             if not SimMapGenerator.should_filter_token(token)
          ]
          for child in result["root"]["children"]:
              for sim_map_key in fields_to_add:
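
Because should_filter_token is now a @staticmethod on SimMapGenerator, the query client can filter tokens without instantiating the class (and therefore without loading the model). A small sketch of the call as used above, with illustrative tokens:

```python
from backend.colpali import SimMapGenerator

idx_to_token = {0: "<bos>", 1: "▁percentage", 2: "?"}  # illustrative tokens

fields_to_add = [
    f"sim_map_{token}_{idx}"
    for idx, token in idx_to_token.items()
    if not SimMapGenerator.should_filter_token(token)
]
print(fields_to_add)  # expected: ['sim_map_▁percentage_1']
```
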
frontend/app.py CHANGED
@@ -1,7 +1,7 @@
  from typing import Optional
  from urllib.parse import quote_plus
 
- from fasthtml.components import H1, H2, Div, Form, Img, NotStr, P, Span, H3, Br
  from fasthtml.xtend import A, Script
  from lucide_fasthtml import Lucide
  from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem, Separator
@@ -154,7 +154,7 @@ def SearchBox(with_border=False, query_value="", ranking_value="nn+colpali"):
          name="query",
          value=query_value,
          id="search-input",
-         cls="text-base pl-10 border-transparent ring-offset-transparent ring-0 focus-visible:ring-transparent awesomplete",
          data_list="#suggestions",
          style="font-size: 1rem",
          autofocus=True,
@@ -366,7 +366,23 @@ def SimMapButtonPoll(query_id, idx, token, token_idx):
      )
 
 
- def SearchResult(results: list, query_id: Optional[str] = None):
      if not results:
          return Div(
              P(
@@ -376,10 +392,13 @@ def SearchResult(results: list, query_id: Optional[str] = None):
              cls="grid p-10",
          )
 
      # Otherwise, display the search results
      result_items = []
      for idx, result in enumerate(results):
          fields = result["fields"]  # Extract the 'fields' part of each result
          blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
 
          sim_map_fields = {
@@ -472,7 +491,7 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                  Div(
                      Img(
                          src=blur_image_base64,
-                         hx_get=f"/full_image?docid={fields['id']}&query_id={query_id}&idx={idx}",
                          style="backdrop-filter: blur(5px);",
                          hx_trigger="load",
                          hx_swap="outerHTML",
@@ -493,9 +512,12 @@ def SearchResult(results: list, query_id: Optional[str] = None):
                  ),
                  Div(
                      Div(
-                         P(
-                             "Page " + str(fields["page_number"]),
-                             cls="text-foreground font-mono bold text-sm",
                          ),
                          cls="flex items-center justify-end",
                      ),
@@ -504,7 +526,10 @@ def SearchResult(results: list, query_id: Optional[str] = None):
              Div(
                  Div(
                      Div(
-                         H3("Dynamic summary", cls="text-base font-semibold"),
                          P(
                              NotStr(fields.get("snippet", "")),
                              cls="text-highlight text-muted-foreground",
@@ -517,23 +542,28 @@ def SearchResult(results: list, query_id: Optional[str] = None):
              Div(
                  Div(
                      Div(
-                         H3("Full text", cls="text-base font-semibold"),
                          Div(
                              P(
                                  NotStr(fields.get("text", "")),
                                  cls="text-highlight text-muted-foreground",
                              ),
-                             Br()
                          ),
                          cls="grid grid-rows-[auto_0px] content-start gap-y-3",
                      ),
                      id=f"result-text-full-{idx}",
                      cls="grid gap-y-3 p-8 border border-dashed",
                  ),
-                 Div(cls="absolute inset-x-0 bottom-0 bg-gradient-to-t from-white dark:from-slate-900 pt-[7%]"),
-                 cls="relative grid"
              ),
-             cls="grid grid-rows-[1fr_1fr] gap-y-8 p-8 text-sm",
          ),
          cls="grid bg-background",
      ),
@@ -545,11 +575,13 @@ def SearchResult(results: list, query_id: Optional[str] = None):
          id=f"image-text-columns-{idx}",
          cls="relative grid grid-cols-1 border-t grid-image-text-columns",
      ),
-     cls="grid grid-cols-1 grid-rows-[auto_1fr]",
  ),
  )
 
-     return Div(
          *result_items,
          image_swapping,
          toggle_text_content,
@@ -559,22 +591,37 @@ def SearchResult(results: list, query_id: Optional[str] = None):
      )
 
 
- def ChatResult(query_id: str, query: str):
      return Div(
          Div("AI-response (Gemini-8B)", cls="text-xl font-semibold p-5"),
          Div(
              Div(
-                 Div(
-                     LoadingSkeleton(),
-                     hx_ext="sse",
-                     sse_connect=f"/get-message?query_id={query_id}&query={quote_plus(query)}",
-                     sse_swap="message",
-                     sse_close="close",
-                     hx_swap="innerHTML",
-                 ),
              ),
              id="chat-messages",
              cls="overflow-auto min-h-0 grid items-end px-5",
          ),
          cls="h-full grid grid-rows-[auto_1fr_auto] min-h-0 gap-3",
      )

  from typing import Optional
  from urllib.parse import quote_plus
 
+ from fasthtml.components import H1, H2, H3, Br, Div, Form, Img, NotStr, P, Span
  from fasthtml.xtend import A, Script
  from lucide_fasthtml import Lucide
  from shad4fast import Badge, Button, Input, Label, RadioGroup, RadioGroupItem, Separator
 
          name="query",
          value=query_value,
          id="search-input",
+         cls="text-base pl-10 border-transparent ring-offset-transparent ring-0 focus-visible:ring-transparent bg-white dark:bg-background awesomplete",
          data_list="#suggestions",
          style="font-size: 1rem",
          autofocus=True,
 
      )
 
 
+ def SearchInfo(search_time, total_count):
+     return (
+         Div(
+             NotStr(
+                 f"<span>Found <strong>{total_count}</strong> results in <strong>{search_time}</strong> seconds.</span>"
+             ),
+             cls="grid bg-background border-t text-sm text-center p-3",
+         ),
+     )
+
+
+ def SearchResult(
+     results: list,
+     query: str, query_id: Optional[str] = None,
+     search_time: float = 0,
+     total_count: int = 0,
+ ):
      if not results:
          return Div(
              P(
 
              cls="grid p-10",
          )
 
+     doc_ids = []
      # Otherwise, display the search results
      result_items = []
      for idx, result in enumerate(results):
          fields = result["fields"]  # Extract the 'fields' part of each result
+         doc_id = fields["id"]
+         doc_ids.append(doc_id)
          blur_image_base64 = f"data:image/jpeg;base64,{fields['blur_image']}"
 
          sim_map_fields = {
 
                  Div(
                      Img(
                          src=blur_image_base64,
+                         hx_get=f"/full_image?doc_id={doc_id}",
                          style="backdrop-filter: blur(5px);",
                          hx_trigger="load",
                          hx_swap="outerHTML",
 
                  ),
                  Div(
                      Div(
+                         A(
+                             Lucide(icon="external-link", size="18"),
+                             f"PDF Source (Page {fields['page_number']})",
+                             href=f"{fields['url']}#page={fields['page_number'] + 1}",
+                             target="_blank",
+                             cls="flex items-center gap-1.5 font-mono bold text-sm",
+                         ),
                          cls="flex items-center justify-end",
                      ),
 
              Div(
                  Div(
                      Div(
+                         H3(
+                             "Dynamic summary",
+                             cls="text-base font-semibold",
+                         ),
                          P(
                              NotStr(fields.get("snippet", "")),
                              cls="text-highlight text-muted-foreground",
 
              Div(
                  Div(
                      Div(
+                         H3(
+                             "Full text",
+                             cls="text-base font-semibold",
+                         ),
                          Div(
                              P(
                                  NotStr(fields.get("text", "")),
                                  cls="text-highlight text-muted-foreground",
                              ),
+                             Br(),
                          ),
                          cls="grid grid-rows-[auto_0px] content-start gap-y-3",
                      ),
                      id=f"result-text-full-{idx}",
                      cls="grid gap-y-3 p-8 border border-dashed",
                  ),
+                 Div(
+                     cls="absolute inset-x-0 bottom-0 bg-gradient-to-t from-[#fcfcfd] dark:from-[#1c2024] pt-[7%]"
+                 ),
+                 cls="relative grid",
              ),
+             cls="grid grid-rows-[1fr_1fr] xl:grid-rows-[1fr_2fr] gap-y-8 p-8 text-sm",
          ),
          cls="grid bg-background",
      ),
 
          id=f"image-text-columns-{idx}",
          cls="relative grid grid-cols-1 border-t grid-image-text-columns",
      ),
+     cls="grid grid-cols-1 grid-rows-[auto_auto_1fr]",
  ),
  )
 
+     return [
+         Div(
+             SearchInfo(search_time, total_count),
          *result_items,
          image_swapping,
          toggle_text_content,
 
      )
 
 
+     ,
+     Div(
+         ChatResult(query_id=query_id, query=query, doc_ids=doc_ids),
+         hx_swap_oob="true",
+         id="chat_messages",
+     ),
+ ]
+
+
+ def ChatResult(query_id: str, query: str, doc_ids: Optional[list] = None):
+     messages = Div(LoadingSkeleton())
+
+     if doc_ids:
+         messages = Div(
+             LoadingSkeleton(),
+             hx_ext="sse",
+             sse_connect=f"/get-message?query_id={query_id}&doc_ids={','.join(doc_ids)}&query={quote_plus(query)}",
+             sse_swap="message",
+             sse_close="close",
+             hx_swap="innerHTML",
+         )
+
      return Div(
          Div("AI-response (Gemini-8B)", cls="text-xl font-semibold p-5"),
          Div(
              Div(
+                 messages,
              ),
              id="chat-messages",
              cls="overflow-auto min-h-0 grid items-end px-5",
          ),
+         id="chat_messages",
          cls="h-full grid grid-rows-[auto_1fr_auto] min-h-0 gap-3",
      )
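
ChatResult now renders a plain loading skeleton until doc_ids are known, and only then opens the SSE connection, with the ids comma-joined into the URL. A sketch of that URL round-trip; the receiving side is hypothetical, since the /get-message handler in main.py is truncated in this diff:

```python
from urllib.parse import quote_plus

query_id, query = "q-123", "what is the revenue?"
doc_ids = ["doc1", "doc2", "doc3"]

# The sse_connect URL ChatResult builds once results (and thus doc_ids) exist:
url = f"/get-message?query_id={query_id}&doc_ids={','.join(doc_ids)}&query={quote_plus(query)}"
print(url)
# /get-message?query_id=q-123&doc_ids=doc1,doc2,doc3&query=what+is+the+revenue%3F

# Hypothetical receiving side: split the comma-joined ids back out.
print(url.split("doc_ids=")[1].split("&")[0].split(","))  # ['doc1', 'doc2', 'doc3']
```
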
frontend/layout.py CHANGED
@@ -151,7 +151,7 @@ def Links():
      )
 
 
- def Layout(*c, **kwargs):
      return (
          Title("Visual Retrieval ColPali"),
          Body(
@@ -162,6 +162,7 @@ def Layout(*c, **kwargs):
              ),
              *c,
              **kwargs,
              cls="grid grid-rows-[minmax(0,55px)_minmax(0,1fr)] min-h-0",
          ),
          layout_script,

      )
 
 
+ def Layout(*c, is_home=False, **kwargs):
      return (
          Title("Visual Retrieval ColPali"),
          Body(
 
              ),
              *c,
              **kwargs,
+             data_is_home=str(is_home).lower(),
              cls="grid grid-rows-[minmax(0,55px)_minmax(0,1fr)] min-h-0",
          ),
          layout_script,
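
The new is_home flag is serialized into a data-is-home attribute on <body>, which the globals.css change below targets with body[data-is-home="true"] to enable the home-page radial gradient. A minimal, illustrative call site:

```python
from fasthtml.common import Div, Main

from frontend.layout import Layout

# Renders <body data-is-home="true" ...>; other pages get data-is-home="false".
page = Layout(Main(Div("home content")), is_home=True)
```
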
globals.css CHANGED
@@ -5,58 +5,57 @@
 
  @layer base {
    :root {
-     --background: 0 0% 100%;
-     --foreground: 222.2 84% 4.9%;
-     --card: 0 0% 100%;
-     --card-foreground: 222.2 84% 4.9%;
-     --popover: 0 0% 100%;
-     --popover-foreground: 222.2 84% 4.9%;
-     --primary: 222.2 47.4% 11.2%;
-     --primary-foreground: 210 40% 98%;
-     --secondary: 210 40% 96.1%;
-     --secondary-foreground: 222.2 47.4% 11.2%;
-     --muted: 210 40% 96.1%;
-     --muted-foreground: 215.4 16.3% 26.9%;
-     --accent: 210 40% 96.1%;
-     --accent-foreground: 222.2 47.4% 11.2%;
-     --destructive: 0 84.2% 60.2%;
-     --destructive-foreground: 210 40% 98%;
-     --border: 214.3 31.8% 81.4%;
-     --input: 214.3 31.8% 81.4%;
-     --ring: 222.2 84% 4.9%;
-     --radius: 0.5rem;
-     --chart-1: 12 76% 61%;
-     --chart-2: 173 58% 39%;
-     --chart-3: 197 37% 24%;
-     --chart-4: 43 74% 66%;
-     --chart-5: 27 87% 67%;
    }
 
    .dark {
-     --background: 222.2 84% 4.9%;
-     --foreground: 210 40% 98%;
-     --card: 222.2 84% 4.9%;
-     --card-foreground: 210 40% 98%;
-     --popover: 222.2 84% 4.9%;
-     --popover-foreground: 210 40% 98%;
-     --primary: 210 40% 98%;
-     --primary-foreground: 222.2 47.4% 11.2%;
-     --secondary: 217.2 32.6% 17.5%;
-     --secondary-foreground: 210 40% 98%;
-     --muted: 217.2 32.6% 17.5%;
-     --muted-foreground: 215 20.2% 85.1%;
-     --accent: 217.2 32.6% 17.5%;
-     --accent-foreground: 210 40% 98%;
-     --destructive: 0 62.8% 30.6%;
-     --destructive-foreground: 210 40% 98%;
-     --border: 217.2 32.6% 27.5%;
-     --input: 217.2 32.6% 27.5%;
-     --ring: 212.7 26.8% 83.9;
-     --chart-1: 220 70% 50%;
-     --chart-2: 160 60% 45%;
-     --chart-3: 30 80% 55%;
-     --chart-4: 280 65% 60%;
-     --chart-5: 340 75% 55%;
    }
  }
 
@@ -193,6 +192,16 @@ header {
    grid-column: 1/-1;
  }
 
  main {
    overflow: auto;
  }
@@ -236,14 +245,19 @@ aside {
  }
 
  .awesomplete > ul {
-   @apply text-sm space-y-0.5;
    margin: 0;
    border-top: none;
    border-left: 1px solid hsl(var(--input));
    border-right: 1px solid hsl(var(--input));
    border-bottom: 1px solid hsl(var(--input));
    border-radius: 0 0 calc(var(--radius) - 2px) calc(var(--radius) - 2px);
-   background: hsl(var(--background));
    box-shadow: none;
    text-shadow: none;
  }

 
  @layer base {
    :root {
+     --background: 240 20% 99%; /* 1 */
+     --foreground: 210 13% 13%; /* 12 */
+     --card: 240 20% 99%; /* 1 */
+     --card-foreground: 210 13% 13%; /* 12 */
+     --popover: 240 20% 99%; /* 1 */
+     --popover-foreground: 210 13% 13%; /* 12 */
+     --primary: 210 13% 13%; /* 12 */
+     --primary-foreground: 240 20% 98%; /* 2 */
+     --secondary: 240 11% 95%; /* 3 */
+     --secondary-foreground: 210 13% 13%; /* 12 */
+     --muted: 240 11% 95%; /* 3 */
+     --muted-foreground: 220 6% 40%; /* 11 */
+     --accent: 240 11% 95%; /* 3 */
+     --accent-foreground: 210 13% 13%; /* 12 */
+     --destructive: 358 75% 59%; /* 9 - red */
+     --destructive-foreground: 240 20% 98%; /* 2 */
+     --border: 240 10% 86%; /* 6 */
+     --input: 240 10% 86%; /* 6 */
+     --ring: 210 13% 13%; /* 12 */
+     --chart-1: 10 78% 54%; /* 9 - tomato */
+     --chart-2: 173 80% 36%; /* 9 - teal */
+     --chart-3: 206 100% 50%; /* 9 - blue */
+     --chart-4: 42 100% 62%; /* 9 - amber */
+     --chart-5: 23 93% 53%; /* 9 - orange */
    }
 
    .dark {
+     --background: 240 6% 7%; /* 1 */
+     --foreground: 220 9% 94%; /* 12 */
+     --card: 240 6% 7%; /* 1 */
+     --card-foreground: 220 9% 94%; /* 12 */
+     --popover: 240 6% 7%; /* 1 */
+     --popover-foreground: 220 9% 94%; /* 12 */
+     --primary: 220 9% 94%; /* 12 */
+     --primary-foreground: 220 6% 10%; /* 2 */
+     --secondary: 225 6% 14%; /* 3 */
+     --secondary-foreground: 220 9% 94%; /* 12 */
+     --muted: 225 6% 14%; /* 3 */
+     --muted-foreground: 216 7% 71%; /* 11 */
+     --accent: 225 6% 14%; /* 3 */
+     --accent-foreground: 220 9% 94%; /* 12 */
+     --destructive: 358 75% 59%; /* 9 - red */
+     --destructive-foreground: 220 9% 94%; /* 12 */
+     --border: 213 8% 23%; /* 6 */
+     --input: 213 8% 23%; /* 6 */
+     --ring: 220 9% 94%; /* 12 */
+     --chart-1: 10 78% 54%; /* 9 - tomato */
+     --chart-2: 173 80% 36%; /* 9 - teal */
+     --chart-3: 206 100% 50%; /* 9 - blue */
+     --chart-4: 42 100% 62%; /* 9 - amber */
+     --chart-5: 23 93% 53%; /* 9 - orange */
    }
  }
 
    grid-column: 1/-1;
  }
 
+ body {
+   &[data-is-home="true"] {
+     background: radial-gradient(circle at 50% 100%, #fcfcfd, #fcfcfd, #fdfdfe, #fdfdfe, #fefefe, #fefefe, #ffffff, #ffffff);
+
+     .dark & {
+       background: radial-gradient(circle at 50% 50%, #272a2d, #242629, #212326, #1e1f22, #1b1c1e, #18181b, #151517, #111113);
+     }
+   }
+ }
+
  main {
    overflow: auto;
  }
 
  }
 
  .awesomplete > ul {
+   @apply text-sm space-y-1;
    margin: 0;
    border-top: none;
    border-left: 1px solid hsl(var(--input));
    border-right: 1px solid hsl(var(--input));
    border-bottom: 1px solid hsl(var(--input));
    border-radius: 0 0 calc(var(--radius) - 2px) calc(var(--radius) - 2px);
+   background: white;
+
+   .dark & {
+     background: hsl(var(--background));
+   }
+
    box-shadow: none;
    text-shadow: none;
  }
icons.py CHANGED
@@ -1 +1 @@
- ICONS = {"chevrons-right": "<path d=\"m6 17 5-5-5-5\"></path><path d=\"m13 17 5-5-5-5\"></path>", "moon": "<path d=\"M12 3a6 6 0 0 0 9 9 9 9 0 1 1-9-9Z\"></path>", "sun": "<circle cx=\"12\" cy=\"12\" r=\"4\"></circle><path d=\"M12 2v2\"></path><path d=\"M12 20v2\"></path><path d=\"m4.93 4.93 1.41 1.41\"></path><path d=\"m17.66 17.66 1.41 1.41\"></path><path d=\"M2 12h2\"></path><path d=\"M20 12h2\"></path><path d=\"m6.34 17.66-1.41 1.41\"></path><path d=\"m19.07 4.93-1.41 1.41\"></path>", "github": "<path d=\"M15 22v-4a4.8 4.8 0 0 0-1-3.5c3 0 6-2 6-5.5.08-1.25-.27-2.48-1-3.5.28-1.15.28-2.35 0-3.5 0 0-1 0-3 1.5-2.64-.5-5.36-.5-8 0C6 2 5 2 5 2c-.3 1.15-.3 2.35 0 3.5A5.403 5.403 0 0 0 4 9c0 3.5 3 5.5 6 5.5-.39.49-.68 1.05-.85 1.65-.17.6-.22 1.23-.15 1.85v4\"></path><path d=\"M9 18c-4.51 2-5-2-7-2\"></path>", "slack": "<rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"13\" y=\"2\"></rect><path d=\"M19 8.5V10h1.5A1.5 1.5 0 1 0 19 8.5\"></path><rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"8\" y=\"14\"></rect><path d=\"M5 15.5V14H3.5A1.5 1.5 0 1 0 5 15.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"14\" y=\"13\"></rect><path d=\"M15.5 19H14v1.5a1.5 1.5 0 1 0 1.5-1.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"2\" y=\"8\"></rect><path d=\"M8.5 5H10V3.5A1.5 1.5 0 1 0 8.5 5\"></path>", "settings": "<path d=\"M12.22 2h-.44a2 2 0 0 0-2 2v.18a2 2 0 0 1-1 1.73l-.43.25a2 2 0 0 1-2 0l-.15-.08a2 2 0 0 0-2.73.73l-.22.38a2 2 0 0 0 .73 2.73l.15.1a2 2 0 0 1 1 1.72v.51a2 2 0 0 1-1 1.74l-.15.09a2 2 0 0 0-.73 2.73l.22.38a2 2 0 0 0 2.73.73l.15-.08a2 2 0 0 1 2 0l.43.25a2 2 0 0 1 1 1.73V20a2 2 0 0 0 2 2h.44a2 2 0 0 0 2-2v-.18a2 2 0 0 1 1-1.73l.43-.25a2 2 0 0 1 2 0l.15.08a2 2 0 0 0 2.73-.73l.22-.39a2 2 0 0 0-.73-2.73l-.15-.08a2 2 0 0 1-1-1.74v-.5a2 2 0 0 1 1-1.74l.15-.09a2 2 0 0 0 .73-2.73l-.22-.38a2 2 0 0 0-2.73-.73l-.15.08a2 2 0 0 1-2 0l-.43-.25a2 2 0 0 1-1-1.73V4a2 2 0 0 0-2-2z\"></path><circle cx=\"12\" cy=\"12\" r=\"3\"></circle>", "arrow-right": "<path d=\"M5 12h14\"></path><path d=\"m12 5 7 7-7 7\"></path>", "search": "<circle cx=\"11\" cy=\"11\" r=\"8\"></circle><path d=\"m21 21-4.3-4.3\"></path>", "file-search": "<path d=\"M14 2v4a2 2 0 0 0 2 2h4\"></path><path d=\"M4.268 21a2 2 0 0 0 1.727 1H18a2 2 0 0 0 2-2V7l-5-5H6a2 2 0 0 0-2 2v3\"></path><path d=\"m9 18-1.5-1.5\"></path><circle cx=\"5\" cy=\"14\" r=\"3\"></circle>", "message-circle-question": "<path d=\"M7.9 20A9 9 0 1 0 4 16.1L2 22Z\"></path><path d=\"M9.09 9a3 3 0 0 1 5.83 1c0 2-3 3-3 3\"></path><path d=\"M12 17h.01\"></path>", "text-search": "<path d=\"M21 6H3\"></path><path d=\"M10 12H3\"></path><path d=\"M10 18H3\"></path><circle cx=\"17\" cy=\"15\" r=\"3\"></circle><path d=\"m21 19-1.9-1.9\"></path>", "maximize": "<path d=\"M8 3H5a2 2 0 0 0-2 2v3\"></path><path d=\"M21 8V5a2 2 0 0 0-2-2h-3\"></path><path d=\"M3 16v3a2 2 0 0 0 2 2h3\"></path><path d=\"M16 21h3a2 2 0 0 0 2-2v-3\"></path>", "expand": "<path d=\"m21 21-6-6m6 6v-4.8m0 4.8h-4.8\"></path><path d=\"M3 16.2V21m0 0h4.8M3 21l6-6\"></path><path d=\"M21 7.8V3m0 0h-4.8M21 3l-6 6\"></path><path d=\"M3 7.8V3m0 0h4.8M3 3l6 6\"></path>", "fullscreen": "<path d=\"M3 7V5a2 2 0 0 1 2-2h2\"></path><path d=\"M17 3h2a2 2 0 0 1 2 2v2\"></path><path d=\"M21 17v2a2 2 0 0 1-2 2h-2\"></path><path d=\"M7 21H5a2 2 0 0 1-2-2v-2\"></path><rect height=\"8\" rx=\"1\" width=\"10\" x=\"7\" y=\"8\"></rect>", "images": "<path d=\"M18 22H4a2 2 0 0 1-2-2V6\"></path><path d=\"m22 13-1.296-1.296a2.41 2.41 0 0 0-3.408 0L11 18\"></path><circle cx=\"12\" cy=\"8\" r=\"2\"></circle><rect height=\"16\" 
rx=\"2\" width=\"16\" x=\"6\" y=\"2\"></rect>", "circle": "<circle cx=\"12\" cy=\"12\" r=\"10\"></circle>", "loader-circle": "<path d=\"M21 12a9 9 0 1 1-6.219-8.56\"></path>", "file-text": "<path d=\"M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7Z\"></path><path d=\"M14 2v4a2 2 0 0 0 2 2h4\"></path><path d=\"M10 9H8\"></path><path d=\"M16 13H8\"></path><path d=\"M16 17H8\"></path>", "file-question": "<path d=\"M12 17h.01\"></path><path d=\"M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7z\"></path><path d=\"M9.1 9a3 3 0 0 1 5.82 1c0 2-3 3-3 3\"></path>"}
 
+ ICONS = {"chevrons-right": "<path d=\"m6 17 5-5-5-5\"></path><path d=\"m13 17 5-5-5-5\"></path>", "moon": "<path d=\"M12 3a6 6 0 0 0 9 9 9 9 0 1 1-9-9Z\"></path>", "sun": "<circle cx=\"12\" cy=\"12\" r=\"4\"></circle><path d=\"M12 2v2\"></path><path d=\"M12 20v2\"></path><path d=\"m4.93 4.93 1.41 1.41\"></path><path d=\"m17.66 17.66 1.41 1.41\"></path><path d=\"M2 12h2\"></path><path d=\"M20 12h2\"></path><path d=\"m6.34 17.66-1.41 1.41\"></path><path d=\"m19.07 4.93-1.41 1.41\"></path>", "github": "<path d=\"M15 22v-4a4.8 4.8 0 0 0-1-3.5c3 0 6-2 6-5.5.08-1.25-.27-2.48-1-3.5.28-1.15.28-2.35 0-3.5 0 0-1 0-3 1.5-2.64-.5-5.36-.5-8 0C6 2 5 2 5 2c-.3 1.15-.3 2.35 0 3.5A5.403 5.403 0 0 0 4 9c0 3.5 3 5.5 6 5.5-.39.49-.68 1.05-.85 1.65-.17.6-.22 1.23-.15 1.85v4\"></path><path d=\"M9 18c-4.51 2-5-2-7-2\"></path>", "slack": "<rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"13\" y=\"2\"></rect><path d=\"M19 8.5V10h1.5A1.5 1.5 0 1 0 19 8.5\"></path><rect height=\"8\" rx=\"1.5\" width=\"3\" x=\"8\" y=\"14\"></rect><path d=\"M5 15.5V14H3.5A1.5 1.5 0 1 0 5 15.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"14\" y=\"13\"></rect><path d=\"M15.5 19H14v1.5a1.5 1.5 0 1 0 1.5-1.5\"></path><rect height=\"3\" rx=\"1.5\" width=\"8\" x=\"2\" y=\"8\"></rect><path d=\"M8.5 5H10V3.5A1.5 1.5 0 1 0 8.5 5\"></path>", "settings": "<path d=\"M12.22 2h-.44a2 2 0 0 0-2 2v.18a2 2 0 0 1-1 1.73l-.43.25a2 2 0 0 1-2 0l-.15-.08a2 2 0 0 0-2.73.73l-.22.38a2 2 0 0 0 .73 2.73l.15.1a2 2 0 0 1 1 1.72v.51a2 2 0 0 1-1 1.74l-.15.09a2 2 0 0 0-.73 2.73l.22.38a2 2 0 0 0 2.73.73l.15-.08a2 2 0 0 1 2 0l.43.25a2 2 0 0 1 1 1.73V20a2 2 0 0 0 2 2h.44a2 2 0 0 0 2-2v-.18a2 2 0 0 1 1-1.73l.43-.25a2 2 0 0 1 2 0l.15.08a2 2 0 0 0 2.73-.73l.22-.39a2 2 0 0 0-.73-2.73l-.15-.08a2 2 0 0 1-1-1.74v-.5a2 2 0 0 1 1-1.74l.15-.09a2 2 0 0 0 .73-2.73l-.22-.38a2 2 0 0 0-2.73-.73l-.15.08a2 2 0 0 1-2 0l-.43-.25a2 2 0 0 1-1-1.73V4a2 2 0 0 0-2-2z\"></path><circle cx=\"12\" cy=\"12\" r=\"3\"></circle>", "arrow-right": "<path d=\"M5 12h14\"></path><path d=\"m12 5 7 7-7 7\"></path>", "search": "<circle cx=\"11\" cy=\"11\" r=\"8\"></circle><path d=\"m21 21-4.3-4.3\"></path>", "file-search": "<path d=\"M14 2v4a2 2 0 0 0 2 2h4\"></path><path d=\"M4.268 21a2 2 0 0 0 1.727 1H18a2 2 0 0 0 2-2V7l-5-5H6a2 2 0 0 0-2 2v3\"></path><path d=\"m9 18-1.5-1.5\"></path><circle cx=\"5\" cy=\"14\" r=\"3\"></circle>", "message-circle-question": "<path d=\"M7.9 20A9 9 0 1 0 4 16.1L2 22Z\"></path><path d=\"M9.09 9a3 3 0 0 1 5.83 1c0 2-3 3-3 3\"></path><path d=\"M12 17h.01\"></path>", "text-search": "<path d=\"M21 6H3\"></path><path d=\"M10 12H3\"></path><path d=\"M10 18H3\"></path><circle cx=\"17\" cy=\"15\" r=\"3\"></circle><path d=\"m21 19-1.9-1.9\"></path>", "maximize": "<path d=\"M8 3H5a2 2 0 0 0-2 2v3\"></path><path d=\"M21 8V5a2 2 0 0 0-2-2h-3\"></path><path d=\"M3 16v3a2 2 0 0 0 2 2h3\"></path><path d=\"M16 21h3a2 2 0 0 0 2-2v-3\"></path>", "expand": "<path d=\"m21 21-6-6m6 6v-4.8m0 4.8h-4.8\"></path><path d=\"M3 16.2V21m0 0h4.8M3 21l6-6\"></path><path d=\"M21 7.8V3m0 0h-4.8M21 3l-6 6\"></path><path d=\"M3 7.8V3m0 0h4.8M3 3l6 6\"></path>", "fullscreen": "<path d=\"M3 7V5a2 2 0 0 1 2-2h2\"></path><path d=\"M17 3h2a2 2 0 0 1 2 2v2\"></path><path d=\"M21 17v2a2 2 0 0 1-2 2h-2\"></path><path d=\"M7 21H5a2 2 0 0 1-2-2v-2\"></path><rect height=\"8\" rx=\"1\" width=\"10\" x=\"7\" y=\"8\"></rect>", "images": "<path d=\"M18 22H4a2 2 0 0 1-2-2V6\"></path><path d=\"m22 13-1.296-1.296a2.41 2.41 0 0 0-3.408 0L11 18\"></path><circle cx=\"12\" cy=\"8\" r=\"2\"></circle><rect height=\"16\" 
rx=\"2\" width=\"16\" x=\"6\" y=\"2\"></rect>", "circle": "<circle cx=\"12\" cy=\"12\" r=\"10\"></circle>", "loader-circle": "<path d=\"M21 12a9 9 0 1 1-6.219-8.56\"></path>", "file-text": "<path d=\"M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7Z\"></path><path d=\"M14 2v4a2 2 0 0 0 2 2h4\"></path><path d=\"M10 9H8\"></path><path d=\"M16 13H8\"></path><path d=\"M16 17H8\"></path>", "file-question": "<path d=\"M12 17h.01\"></path><path d=\"M15 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V7z\"></path><path d=\"M9.1 9a3 3 0 0 1 5.82 1c0 2-3 3-3 3\"></path>", "external-link": "<path d=\"M15 3h6v6\"></path><path d=\"M10 14 21 3\"></path><path d=\"M18 13v6a2 2 0 0 1-2 2H5a2 2 0 0 1-2-2V8a2 2 0 0 1 2-2h6\"></path>"}
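
The only change to icons.py is the added "external-link" entry, which backs the new PDF-source link in frontend/app.py:

```python
from lucide_fasthtml import Lucide

# Resolves to the SVG path data registered under ICONS["external-link"].
icon = Lucide(icon="external-link", size="18")
```
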
main.py CHANGED
@@ -1,36 +1,37 @@
1
  import asyncio
 
2
  import os
3
  import time
4
- from pathlib import Path
5
- from concurrent.futures import ThreadPoolExecutor
6
  import uuid
 
 
 
7
  import google.generativeai as genai
 
8
  from fasthtml.common import (
 
9
  Div,
 
 
10
  Img,
 
 
11
  Main,
12
  P,
13
- Script,
14
- Link,
15
- fast_app,
16
- HighlightJS,
17
- FileResponse,
18
  RedirectResponse,
19
- Aside,
20
  StreamingResponse,
21
- JSONResponse,
22
  serve,
23
  )
 
24
  from shad4fast import ShadHead
25
  from vespa.application import Vespa
26
- import base64
27
- from fastcore.parallel import threaded
28
- from PIL import Image
29
 
30
- from backend.colpali import get_query_embeddings_and_token_map, gen_similarity_maps
31
- from backend.modelmanager import ModelManager
32
  from backend.vespa_app import VespaQueryClient
33
  from frontend.app import (
 
34
  ChatResult,
35
  Home,
36
  Search,
@@ -38,7 +39,6 @@ from frontend.app import (
38
  SearchResult,
39
  SimMapButtonPoll,
40
  SimMapButtonReady,
41
- AboutThisDemo,
42
  )
43
  from frontend.layout import Layout
44
 
@@ -90,10 +90,10 @@ thread_pool = ThreadPoolExecutor()
90
  genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
91
  GEMINI_SYSTEM_PROMPT = """If the user query is a question, try your best to answer it based on the provided images.
92
  If the user query can not be interpreted as a question, or if the answer to the query can not be inferred from the images,
93
- answer with the exact phrase "I am sorry, I do not have enough information in the image to answer your question.".
94
  Your response should be HTML formatted, but only simple tags, such as <b>. <p>, <i>, <br> <ul> and <li> are allowed. No HTML tables.
95
  This means that newlines will be replaced with <br> tags, bold text will be enclosed in <b> tags, and so on.
96
- But, you should NOT include backticks (`) or HTML tags in your response.
97
  """
98
  gemini_model = genai.GenerativeModel(
99
  "gemini-1.5-flash-8b", system_instruction=GEMINI_SYSTEM_PROMPT
@@ -107,7 +107,7 @@ os.makedirs(SIM_MAP_DIR, exist_ok=True)
107
 
108
  @app.on_event("startup")
109
  def load_model_on_startup():
110
- app.manager = ModelManager.get_instance()
111
  return
112
 
113
 
@@ -131,7 +131,7 @@ def serve_static(filepath: str):
131
  def get(session):
132
  if "session_id" not in session:
133
  session["session_id"] = str(uuid.uuid4())
134
- return Layout(Main(Home()))
135
 
136
 
137
  @rt("/about-this-demo")
@@ -140,19 +140,16 @@ def get():
140
 
141
 
142
  @rt("/search")
143
- def get(request):
144
- # Extract the 'query' and 'ranking' parameters from the URL
145
- query_value = request.query_params.get("query", "").strip()
146
- ranking_value = request.query_params.get("ranking", "nn+colpali")
147
- print("/search: Fetching results for ranking_value:", ranking_value)
148
 
149
  # Always render the SearchBox first
150
- if not query_value:
151
  # Show SearchBox and a message for missing query
152
  return Layout(
153
  Main(
154
  Div(
155
- SearchBox(query_value=query_value, ranking_value=ranking_value),
156
  Div(
157
  P(
158
  "No query provided. Please enter a query.",
@@ -165,35 +162,17 @@ def get(request):
165
  )
166
  )
167
  # Generate a unique query_id based on the query and ranking value
168
- query_id = generate_query_id(query_value, ranking_value)
169
  # Show the loading message if a query is provided
170
  return Layout(
171
  Main(Search(request), data_overlayscrollbars_initialize=True, cls="border-t"),
172
  Aside(
173
- ChatResult(query_id=query_id, query=query_value),
174
  cls="border-t border-l hidden md:block",
175
  ),
176
  ) # Show SearchBox and Loading message initially
177
 
178
 
179
- @rt("/fetch_results2")
180
- def get(query: str, ranking: str):
181
- # 1. Get the results from Vespa (without sim_maps and full_images)
182
- # Call search-endpoint in Vespa sync.
183
-
184
- # 2. Kick off tasks to fetch sim_maps and full_images
185
- # Sim maps - call search endpoint async.
186
- # (A) New rank_profile that does not calculate sim_maps.
187
- # (A) Make vespa endpoints take select_fields as a parameter.
188
- # One sim map per image per token.
189
- # the filename query_id_result_idx_token_idx.png
190
- # Full image. based on the doc_id.
191
- # Each of these tasks saves to disk.
192
- # Need a cleanup task to delete old files.
193
- # Polling endpoints for sim_maps and full_images checks if file exists and returns it.
194
- pass
195
-
196
-
197
  @rt("/fetch_results")
198
  async def get(session, request, query: str, ranking: str):
199
  if "hx-request" not in request.headers:
@@ -203,9 +182,10 @@ async def get(session, request, query: str, ranking: str):
     query_id = generate_query_id(query, ranking)
     print(f"Query id in /fetch_results: {query_id}")
     # Run the embedding and query against Vespa app
-    model = app.manager.model
-    processor = app.manager.processor
-    q_embs, idx_to_token = get_query_embeddings_and_token_map(processor, model, query)
+
+    q_embs, idx_to_token = app.sim_map_generator.get_query_embeddings_and_token_map(
+        query
+    )

     start = time.perf_counter()
     # Fetch real search results from Vespa
@@ -219,15 +199,20 @@
     print(
         f"Search results fetched in {end - start:.2f} seconds, Vespa says searchtime was {result['timing']['searchtime']} seconds"
     )
+    search_time = result["timing"]["searchtime"]
+    total_count = result["root"]["fields"]["totalCount"]
+
     search_results = vespa_app.results_to_search_results(result, idx_to_token)
+
     get_and_store_sim_maps(
         query_id=query_id,
         query=query,
         q_embs=q_embs,
         ranking=ranking,
         idx_to_token=idx_to_token,
+        doc_ids=[result["fields"]["id"] for result in search_results],
     )
-    return SearchResult(search_results, query_id)
+    return SearchResult(search_results, query, query_id, search_time, total_count)


 def get_results_children(result):
@@ -247,7 +232,9 @@ async def poll_vespa_keepalive():


 @threaded
-def get_and_store_sim_maps(query_id, query: str, q_embs, ranking, idx_to_token):
+def get_and_store_sim_maps(
+    query_id, query: str, q_embs, ranking, idx_to_token, doc_ids
+):
     ranking_sim = ranking + "_sim"
     vespa_sim_maps = vespa_app.get_sim_maps_from_query(
         query=query,
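The `@threaded` decorator imported from `fastcore.parallel` in the new import block is what lets `/fetch_results` return immediately while sim maps are generated in the background: the decorated function runs in a fresh `threading.Thread`, and the call returns that thread rather than the function's result. A self-contained illustration:

```python
import time
from fastcore.parallel import threaded

@threaded
def slow_job(seconds: int):
    # Runs in a background thread when called.
    time.sleep(seconds)
    print("sim maps ready")

t = slow_job(2)                        # returns a started Thread immediately
print("request handler can return now")
t.join()                               # only needed if you want to wait
```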
@@ -255,9 +242,7 @@ def get_and_store_sim_maps(query_id, query: str, q_embs, ranking, idx_to_token):
         ranking=ranking_sim,
         idx_to_token=idx_to_token,
     )
-    img_paths = [
-        IMG_DIR / f"{query_id}_{idx}.jpg" for idx in range(len(vespa_sim_maps))
-    ]
+    img_paths = [IMG_DIR / f"{doc_id}.jpg" for doc_id in doc_ids]
     # All images should be downloaded, but best to wait 5 secs
     max_wait = 5
     start_time = time.time()
@@ -269,10 +254,7 @@ def get_and_store_sim_maps(query_id, query: str, q_embs, ranking, idx_to_token):
     if not all([os.path.exists(img_path) for img_path in img_paths]):
         print(f"Images not ready in 5 seconds for query_id: {query_id}")
         return False
-    sim_map_generator = gen_similarity_maps(
-        model=app.manager.model,
-        processor=app.manager.processor,
-        device=app.manager.device,
+    sim_map_generator = app.sim_map_generator.gen_similarity_maps(
         query=query,
         query_embs=q_embs,
         token_idx_map=idx_to_token,
@@ -312,17 +294,17 @@ async def get_sim_map(query_id: str, idx: int, token: str, token_idx: int):


 @app.get("/full_image")
-async def full_image(docid: str, query_id: str, idx: int):
+async def full_image(doc_id: str):
     """
     Endpoint to get the full quality image for a given result id.
     """
-    img_path = IMG_DIR / f"{query_id}_{idx}.jpg"
+    img_path = IMG_DIR / f"{doc_id}.jpg"
     if not os.path.exists(img_path):
-        image_data = await vespa_app.get_full_image_from_vespa(docid)
+        image_data = await vespa_app.get_full_image_from_vespa(doc_id)
         # image data is base 64 encoded string. Save it to disk as jpg.
         with open(img_path, "wb") as f:
             f.write(base64.b64decode(image_data))
-        print(f"Full image saved to disk for query_id: {query_id}, idx: {idx}")
+        print(f"Full image saved to disk for doc_id: {doc_id}")
     else:
         with open(img_path, "rb") as f:
             image_data = base64.b64encode(f.read()).decode("utf-8")
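The `/full_image` endpoint above caches one decode/encode round trip per document: Vespa returns the image as a base64 string, which is decoded to raw bytes for the on-disk cache and re-encoded when served from that cache. The invariant it relies on:

```python
import base64

jpeg_bytes = b"\xff\xd8\xff\xe0"  # arbitrary bytes standing in for a JPEG
encoded = base64.b64encode(jpeg_bytes).decode("utf-8")  # the form Vespa returns
assert base64.b64decode(encoded) == jpeg_bytes          # the form written to disk
```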
@@ -334,8 +316,9 @@ async def full_image(docid: str, query_id: str, idx: int):


 @rt("/suggestions")
-async def get_suggestions(request):
-    query = request.query_params.get("query", "").lower().strip()
+async def get_suggestions(query: str = ""):
+    """Endpoint to get suggestions as user types in the search box"""
+    query = query.lower().strip()

     if query:
         suggestions = await vespa_app.get_suggestions(query)
345
  return JSONResponse({"suggestions": []})
346
 
347
 
348
- async def message_generator(query_id: str, query: str):
349
- images = []
 
350
  num_images = 3 # Number of images before firing chat request
351
  max_wait = 10 # seconds
352
  start_time = time.time()
353
  # Check if full images are ready on disk
354
- while len(images) < num_images and time.time() - start_time < max_wait:
 
 
 
355
  for idx in range(num_images):
356
- if not os.path.exists(IMG_DIR / f"{query_id}_{idx}.jpg"):
 
357
  print(
358
  f"Message generator: Full image not ready for query_id: {query_id}, idx: {idx}"
359
  )
@@ -362,12 +350,14 @@ async def message_generator(query_id: str, query: str):
362
  print(
363
  f"Message generator: image ready for query_id: {query_id}, idx: {idx}"
364
  )
365
- images.append(Image.open(IMG_DIR / f"{query_id}_{idx}.jpg"))
366
  await asyncio.sleep(0.2)
 
 
367
  # yield message with number of images ready
368
- yield f"event: message\ndata: Generating response based on {len(images)} images.\n\n"
369
  if not images:
370
- yield "event: message\ndata: I am sorry, I do not have enough information in the image to answer your question.\n\n"
371
  yield "event: close\ndata: \n\n"
372
  return
373
 
@@ -388,9 +378,9 @@ async def message_generator(query_id: str, query: str):


 @app.get("/get-message")
-async def get_message(query_id: str, query: str):
+async def get_message(query_id: str, query: str, doc_ids: str):
     return StreamingResponse(
-        message_generator(query_id=query_id, query=query),
+        message_generator(query_id=query_id, query=query, doc_ids=doc_ids.split(",")),
         media_type="text/event-stream",
     )
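`/get-message` streams server-sent events in the `event: ...\ndata: ...\n\n` framing produced by `message_generator`. In the browser this would typically be consumed with an `EventSource` or HTMX's SSE extension; a minimal Python client sketch, where the host, port and parameter values are assumptions rather than values from this commit, looks like:

```python
import httpx

params = {
    "query_id": "example-id",        # assumed value
    "query": "what is colpali?",     # assumed value
    "doc_ids": "doc1,doc2,doc3",     # comma-separated, as split by get_message
}
with httpx.stream(
    "GET", "http://localhost:7860/get-message", params=params, timeout=None
) as r:
    for line in r.iter_lines():
        if line.startswith("data: "):
            print(line[len("data: "):])  # stream ends when "event: close" arrives
```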
output.css CHANGED
@@ -555,58 +555,105 @@ video {
 }

 :root {
-  --background: 0 0% 100%;
-  --foreground: 222.2 84% 4.9%;
-  --card: 0 0% 100%;
-  --card-foreground: 222.2 84% 4.9%;
-  --popover: 0 0% 100%;
-  --popover-foreground: 222.2 84% 4.9%;
-  --primary: 222.2 47.4% 11.2%;
-  --primary-foreground: 210 40% 98%;
-  --secondary: 210 40% 96.1%;
-  --secondary-foreground: 222.2 47.4% 11.2%;
-  --muted: 210 40% 96.1%;
-  --muted-foreground: 215.4 16.3% 26.9%;
-  --accent: 210 40% 96.1%;
-  --accent-foreground: 222.2 47.4% 11.2%;
-  --destructive: 0 84.2% 60.2%;
-  --destructive-foreground: 210 40% 98%;
-  --border: 214.3 31.8% 81.4%;
-  --input: 214.3 31.8% 81.4%;
-  --ring: 222.2 84% 4.9%;
-  --radius: 0.5rem;
-  --chart-1: 12 76% 61%;
-  --chart-2: 173 58% 39%;
-  --chart-3: 197 37% 24%;
-  --chart-4: 43 74% 66%;
-  --chart-5: 27 87% 67%;
+  --background: 240 20% 99%;
+  /* 1 */
+  --foreground: 210 13% 13%;
+  /* 12 */
+  --card: 240 20% 99%;
+  /* 1 */
+  --card-foreground: 210 13% 13%;
+  /* 12 */
+  --popover: 240 20% 99%;
+  /* 1 */
+  --popover-foreground: 210 13% 13%;
+  /* 12 */
+  --primary: 210 13% 13%;
+  /* 12 */
+  --primary-foreground: 240 20% 98%;
+  /* 2 */
+  --secondary: 240 11% 95%;
+  /* 3 */
+  --secondary-foreground: 210 13% 13%;
+  /* 12 */
+  --muted: 240 11% 95%;
+  /* 3 */
+  --muted-foreground: 220 6% 40%;
+  /* 11 */
+  --accent: 240 11% 95%;
+  /* 3 */
+  --accent-foreground: 210 13% 13%;
+  /* 12 */
+  --destructive: 358 75% 59%;
+  /* 9 - red */
+  --destructive-foreground: 240 20% 98%;
+  /* 2 */
+  --border: 240 10% 86%;
+  /* 6 */
+  --input: 240 10% 86%;
+  /* 6 */
+  --ring: 210 13% 13%;
+  /* 12 */
+  --chart-1: 10 78% 54%;
+  /* 9 - tomato */
+  --chart-2: 173 80% 36%;
+  /* 9 - teal */
+  --chart-3: 206 100% 50%;
+  /* 9 - blue */
+  --chart-4: 42 100% 62%;
+  /* 9 - amber */
+  --chart-5: 23 93% 53%;
+  /* 9 - orange */
 }

 .dark {
-  --background: 222.2 84% 4.9%;
-  --foreground: 210 40% 98%;
-  --card: 222.2 84% 4.9%;
-  --card-foreground: 210 40% 98%;
-  --popover: 222.2 84% 4.9%;
-  --popover-foreground: 210 40% 98%;
-  --primary: 210 40% 98%;
-  --primary-foreground: 222.2 47.4% 11.2%;
-  --secondary: 217.2 32.6% 17.5%;
-  --secondary-foreground: 210 40% 98%;
-  --muted: 217.2 32.6% 17.5%;
-  --muted-foreground: 215 20.2% 85.1%;
-  --accent: 217.2 32.6% 17.5%;
-  --accent-foreground: 210 40% 98%;
-  --destructive: 0 62.8% 30.6%;
-  --destructive-foreground: 210 40% 98%;
-  --border: 217.2 32.6% 27.5%;
-  --input: 217.2 32.6% 27.5%;
-  --ring: 212.7 26.8% 83.9%;
-  --chart-1: 220 70% 50%;
-  --chart-2: 160 60% 45%;
-  --chart-3: 30 80% 55%;
-  --chart-4: 280 65% 60%;
-  --chart-5: 340 75% 55%;
+  --background: 240 6% 7%;
+  /* 1 */
+  --foreground: 220 9% 94%;
+  /* 12 */
+  --card: 240 6% 7%;
+  /* 1 */
+  --card-foreground: 220 9% 94%;
+  /* 12 */
+  --popover: 240 6% 7%;
+  /* 1 */
+  --popover-foreground: 220 9% 94%;
+  /* 12 */
+  --primary: 220 9% 94%;
+  /* 12 */
+  --primary-foreground: 220 6% 10%;
+  /* 2 */
+  --secondary: 225 6% 14%;
+  /* 3 */
+  --secondary-foreground: 220 9% 94%;
+  /* 12 */
+  --muted: 225 6% 14%;
+  /* 3 */
+  --muted-foreground: 216 7% 71%;
+  /* 11 */
+  --accent: 225 6% 14%;
+  /* 3 */
+  --accent-foreground: 220 9% 94%;
+  /* 12 */
+  --destructive: 358 75% 59%;
+  /* 9 - red */
+  --destructive-foreground: 220 9% 94%;
+  /* 12 */
+  --border: 213 8% 23%;
+  /* 6 */
+  --input: 213 8% 23%;
+  /* 6 */
+  --ring: 220 9% 94%;
+  /* 12 */
+  --chart-1: 10 78% 54%;
+  /* 9 - tomato */
+  --chart-2: 173 80% 36%;
+  /* 9 - teal */
+  --chart-3: 206 100% 50%;
+  /* 9 - blue */
+  --chart-4: 42 100% 62%;
+  /* 9 - amber */
+  --chart-5: 23 93% 53%;
+  /* 9 - orange */
 }

 :root:has(.no-bg-scroll) {
@@ -1134,6 +1181,10 @@ body {
   grid-template-rows: minmax(0,55px) minmax(0,1fr);
 }

+.grid-rows-\[auto_auto_1fr\] {
+  grid-template-rows: auto auto 1fr;
+}
+
 .flex-col {
   flex-direction: column;
 }
@@ -1248,6 +1299,12 @@ body {
   margin-bottom: calc(0.5rem * var(--tw-space-y-reverse));
 }

+.space-x-1 > :not([hidden]) ~ :not([hidden]) {
+  --tw-space-x-reverse: 0;
+  margin-right: calc(0.25rem * var(--tw-space-x-reverse));
+  margin-left: calc(0.25rem * calc(1 - var(--tw-space-x-reverse)));
+}
+
 .self-stretch {
   align-self: stretch;
 }
@@ -1407,6 +1464,11 @@ body {
   background-color: hsl(var(--secondary));
 }

+.bg-white {
+  --tw-bg-opacity: 1;
+  background-color: rgb(255 255 255 / var(--tw-bg-opacity));
+}
+
 .bg-gradient-to-r {
   background-image: linear-gradient(to right, var(--tw-gradient-stops));
 }
@@ -1415,15 +1477,15 @@ body {
   background-image: linear-gradient(to top, var(--tw-gradient-stops));
 }

-.from-black {
-  --tw-gradient-from: #000 var(--tw-gradient-from-position);
-  --tw-gradient-to: rgb(0 0 0 / 0) var(--tw-gradient-to-position);
+.from-\[\#fcfcfd\] {
+  --tw-gradient-from: #fcfcfd var(--tw-gradient-from-position);
+  --tw-gradient-to: rgb(252 252 253 / 0) var(--tw-gradient-to-position);
   --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
 }

-.from-white {
-  --tw-gradient-from: #fff var(--tw-gradient-from-position);
-  --tw-gradient-to: rgb(255 255 255 / 0) var(--tw-gradient-to-position);
+.from-black {
+  --tw-gradient-from: #000 var(--tw-gradient-from-position);
+  --tw-gradient-to: rgb(0 0 0 / 0) var(--tw-gradient-to-position);
   --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
 }
@@ -2084,6 +2146,15 @@ header {
   grid-column: 1/-1;
 }

+body {
+  &[data-is-home="true"] {
+    background: radial-gradient(circle at 50% 100%, #fcfcfd, #fcfcfd, #fdfdfe, #fdfdfe, #fefefe, #fefefe, #ffffff, #ffffff);
+    .dark & {
+      background: radial-gradient(circle at 50% 50%, #272a2d, #242629, #212326, #1e1f22, #1b1c1e, #18181b, #151517, #111113);
+    }
+  }
+}
+
 main {
   overflow: auto;
 }
@@ -2139,8 +2210,8 @@ aside {

 .awesomplete > ul > :not([hidden]) ~ :not([hidden]) {
   --tw-space-y-reverse: 0;
-  margin-top: calc(0.125rem * calc(1 - var(--tw-space-y-reverse)));
-  margin-bottom: calc(0.125rem * var(--tw-space-y-reverse));
+  margin-top: calc(0.25rem * calc(1 - var(--tw-space-y-reverse)));
+  margin-bottom: calc(0.25rem * var(--tw-space-y-reverse));
 }

 .awesomplete > ul {
@@ -2152,7 +2223,10 @@
   border-right: 1px solid hsl(var(--input));
   border-bottom: 1px solid hsl(var(--input));
   border-radius: 0 0 calc(var(--radius) - 2px) calc(var(--radius) - 2px);
-  background: hsl(var(--background));
+  background: white;
+  .dark & {
+    background: hsl(var(--background));
+  }
   box-shadow: none;
   text-shadow: none;
 }
@@ -2700,6 +2774,12 @@ aside {
   }
 }

+@media (min-width: 1280px) {
+  .xl\:grid-rows-\[1fr_2fr\] {
+    grid-template-rows: 1fr 2fr;
+  }
+}
+
 .dark\:block:where(.dark, .dark *) {
   display: block;
 }
@@ -2716,9 +2796,13 @@ aside {
   border-color: hsl(var(--destructive));
 }

-.dark\:from-slate-900:where(.dark, .dark *) {
-  --tw-gradient-from: #0f172a var(--tw-gradient-from-position);
-  --tw-gradient-to: rgb(15 23 42 / 0) var(--tw-gradient-to-position);
+.dark\:bg-background:where(.dark, .dark *) {
+  background-color: hsl(var(--background));
+}
+
+.dark\:from-\[\#1c2024\]:where(.dark, .dark *) {
+  --tw-gradient-from: #1c2024 var(--tw-gradient-from-position);
+  --tw-gradient-to: rgb(28 32 36 / 0) var(--tw-gradient-to-position);
   --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to);
 }
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 # This file was autogenerated by uv via the following command:
-#    uv pip compile pyproject.toml -o requirements.txt
+#    uv pip compile pyproject.toml -o src/requirements.txt
 accelerate==0.34.2
     # via peft
 aiohappyeyeballs==2.4.3
static/.DS_Store CHANGED
Binary files a/static/.DS_Store and b/static/.DS_Store differ