kanji_lookup / database.py
etrotta's picture
Change the vector database used and embed the embeddings within the program
63a1db6
raw
history blame contribute delete
843 Bytes
import torch
import lancedb
from lancedb.pydantic import LanceModel
import pydantic
# import time
from config import lancedb_location
# Connect to the LanceDB database once, when this module is imported, and keep
# a handle to the "kanji" table so search_vector reuses it on every call.
_KANJI_TABLE_NAME = "kanji"
db = lancedb.connect(lancedb_location)
table = db.open_table(_KANJI_TABLE_NAME)
class SearchResult(LanceModel):
    """A single hit produced by a vector search over the kanji table."""

    # The ``kanji`` value of the matching row.
    kanji: str
    # Distance reported for the hit. Validation accepts it under either the
    # plain ``distance`` key or LanceDB's ``_distance`` result key.
    distance: float = pydantic.Field(
        validation_alias=pydantic.AliasChoices("distance", "_distance"),
    )
def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[SearchResult]:
    """Return the rows of the kanji table nearest to ``query_vector``.

    Args:
        query_vector: Embedding to search with. It is detached from the
            autograd graph and moved to the CPU before being handed to
            LanceDB, so tensors that require grad or live on an accelerator
            are accepted as well as plain CPU tensors.
        limit: Maximum number of results to return (default 20).

    Returns:
        A list of ``SearchResult`` objects, in the order returned by LanceDB.
    """
    # ``Tensor.numpy()`` raises for tensors that require grad or are not on
    # the CPU. For a plain CPU tensor, ``detach()`` returns a view and
    # ``cpu()`` returns the same object, so no data is copied.
    query = query_vector.detach().cpu().numpy()
    rows = (
        table.search(query, vector_column_name="vector", query_type="vector")
        .limit(limit)
        .to_list()
    )
    # Rows come back as plain dicts. Validating them here lets SearchResult's
    # field alias accept LanceDB's ``_distance`` key as ``distance``.
    return [SearchResult.model_validate(row) for row in rows]