"""Embed molecule names with a SentenceTransformer model and index them in Milvus.

Steps performed when this script runs:

1. Read the ``cid`` and ``cmpdname`` columns from ``molecules-small.csv``.
2. Encode every (length-limited) molecule name into a dense vector.
3. Connect to Milvus using ``uri``/``token`` from the ``[example]`` section of
   ``config.ini``, drop and recreate the ``molecule_embeddings`` collection,
   insert the data in batches, flush, build an index and load the collection.
4. Run one example similarity search, then always disconnect.
"""
import configparser
import random
import time

import pandas as pd
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)
from sentence_transformers import SentenceTransformer
from tqdm import tqdm


def _truncate_utf8(text, max_bytes):
    """Return *text* shortened so its UTF-8 encoding is at most *max_bytes* bytes.

    A byte limit is never looser than a character limit of the same size, so
    the result fits a VARCHAR(max_bytes) field whichever unit the server uses.
    """
    encoded = text.encode("utf-8")
    if len(encoded) <= max_bytes:
        return text
    # errors="ignore" drops a multi-byte character that the cut split in half.
    return encoded[:max_bytes].decode("utf-8", errors="ignore")


# Initialize SentenceTransformer model for embeddings
embedding_model = SentenceTransformer(model_name_or_path="bert-base-uncased")

# Read molecule names and ids from CSV
csv_path = "molecules-small.csv"
df = pd.read_csv(csv_path)
max_name_length = 256  # also used as the VARCHAR max_length in the schema below
# Missing names are read as NaN (a float), which has no len(); map them to "".
molecules = [
    _truncate_utf8(name, max_name_length)
    for name in df["cmpdname"].fillna("").astype(str)
]
cids = df["cid"].tolist()

# Encode all names in batches (much faster than one encode() call per name).
# list() turns the (n, dim) array into a list of 1-D row vectors for slicing.
embeddings_list = list(embedding_model.encode(molecules, show_progress_bar=True))

# Dimensionality comes from the model instead of being hard-coded.
dim = embedding_model.get_sentence_embedding_dimension()

cfp = configparser.RawConfigParser()
# read() returns the list of files it parsed; empty means the file is missing.
if not cfp.read("config.ini"):
    raise FileNotFoundError("config.ini not found or unreadable")
milvus_uri = cfp.get("example", "uri")
token = cfp.get("example", "token")

print(f"Connecting to DB: {milvus_uri}")
connections.connect("default", uri=milvus_uri, token=token)

try:
    # Define collection name and start from a fresh collection each run.
    collection_name = "molecule_embeddings"
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)
        print("Success!")

    # Define collection schema
    molecule_cid = FieldSchema(
        name="molecule_cid", dtype=DataType.INT64, description="cid", is_primary=True
    )
    molecule_name = FieldSchema(
        name="molecule_name",
        dtype=DataType.VARCHAR,
        max_length=max_name_length,
        description="name",
    )
    molecule_embeddings = FieldSchema(
        name="molecule_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim
    )
    schema = CollectionSchema(
        fields=[molecule_cid, molecule_name, molecule_embeddings],
        auto_id=False,
        description="my first collection!",
    )
    print(f"Creating example collection: {collection_name}")
    collection = Collection(name=collection_name, schema=schema)
    print(f"Schema: {schema}")
    print("Success!")

    # Insert column-ordered batches matching the schema field order.
    batch_size = 1000
    total_rt = 0.0
    print(f"Inserting {len(embeddings_list)} entities... ")
    for i in tqdm(range(0, len(embeddings_list), batch_size), desc="Inserting Embeddings"):
        batch = slice(i, i + batch_size)
        entities = [cids[batch], molecules[batch], embeddings_list[batch]]
        t0 = time.perf_counter()
        collection.insert(entities)
        total_rt += time.perf_counter() - t0  # time spent in insert() calls only
    print(f"Succeed in inserting {len(embeddings_list)} entities in {round(total_rt, 4)} seconds!")

    # Flush collection
    print("Flushing collection...")
    collection.flush()

    # Build index, then load the collection so it can be searched.
    index_params = {"index_type": "AUTOINDEX", "metric_type": "L2", "params": {}}
    print("Building index...")
    collection.create_index(field_name="molecule_embedding", index_params=index_params)
    collection.load()

    # Example search. Query with the embedding of real molecule names, because
    # a uniform-random vector does not resemble BERT embeddings. Fall back to a
    # random vector only when there is no data to draw a query from.
    nq = 1
    search_params = {"metric_type": "L2"}
    topk = 5
    query_names = molecules[:nq]
    if query_names:
        search_vec = [vec.tolist() for vec in embedding_model.encode(query_names)]
        print(f"Searching for: {query_names}")
    else:
        search_vec = [[random.random() for _ in range(dim)] for _ in range(nq)]
        print(f"Searching vector: {search_vec}")
    results = collection.search(
        search_vec,
        anns_field="molecule_embedding",
        param=search_params,
        limit=topk,
        output_fields=["molecule_name"],
    )
    print(f"Search results: {results}")
finally:
    # Always release the connection, even if a step above raised.
    connections.disconnect("default")
    print("Disconnected from Milvus server.")