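# ai-safety-chatty / embedding_model.py
# Author: jeevan — commit bc453aa ("recommit"), 2.15 kB.
# (Header recovered from the hosting page's file view; kept as a comment so
# the module remains valid Python.)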
import tiktoken
import os
from langchain_openai import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
import torch
from transformers import AutoModel, AutoTokenizer
from transformers import AutoModel, AutoTokenizer
from langchain_huggingface import HuggingFaceEmbeddings
# def get_embeddings_model_bge_base_en_v1_5():
# model_name = "BAAI/bge-base-en-v1.5"
# model_kwargs = {'device': 'cpu'}
# encode_kwargs = {'normalize_embeddings': False}
# embedding_model = HuggingFaceBgeEmbeddings(
# model_name=model_name,
# model_kwargs=model_kwargs,
# encode_kwargs=encode_kwargs
# )
# return embedding_model
# def get_embeddings_model_bge_en_icl():
# model_name = "BAAI/bge-en-icl"
# model_kwargs = {'device': 'cpu'}
# encode_kwargs = {'normalize_embeddings': False}
# embedding_model = HuggingFaceBgeEmbeddings(
# model_name=model_name,
# model_kwargs=model_kwargs,
# encode_kwargs=encode_kwargs
# )
# return embedding_model , 4096
# def get_embeddings_model_bge_large_en():
# model_name = "BAAI/bge-large-en"
# model_kwargs = {'device': 'cpu'}
# encode_kwargs = {'normalize_embeddings': False}
# embedding_model = HuggingFaceBgeEmbeddings(
# model_name=model_name,
# model_kwargs=model_kwargs,
# encode_kwargs=encode_kwargs
# )
# return embedding_model
def get_embeddings_openai_text_3_large():
    """Create the OpenAI ``text-embedding-3-large`` embedding model.

    Returns:
        tuple: ``(embedding_model, dimension)``. ``embedding_model`` is a
        ``langchain_openai.OpenAIEmbeddings`` configured for
        ``"text-embedding-3-large"``. ``dimension`` is the hard-coded vector
        width ``3072`` that callers use, e.g. to size a vector index.
    """
    # NOTE(review): 3072 must stay in sync with the model's output width; it
    # is not derived from the client object.
    return OpenAIEmbeddings(model="text-embedding-3-large"), 3072
# def get_embeddings_snowflake_arctic_embed_l():
# current_dir = os.path.dirname(os.path.realpath(__file__))
# model_name = "Snowflake/snowflake-arctic-embed-l"
# tokenizer = AutoTokenizer.from_pretrained(f"{current_dir}/cache/tokenizer/{model_name}")
# model = AutoModel.from_pretrained(f"{current_dir}/cache/model/{model_name}")
# return model,1024
def get_embeddings_snowflake_arctic_embed_l():
    """Create a Hugging Face embedder for ``Snowflake/snowflake-arctic-embed-l``.

    Returns:
        tuple: ``(embedding_model, dimension)``. ``embedding_model`` is a
        ``langchain_huggingface.HuggingFaceEmbeddings`` built with only
        ``model_name`` set, so all other options use the library defaults.
        ``dimension`` is the hard-coded vector width ``1024``.
    """
    hf_model_id = "Snowflake/snowflake-arctic-embed-l"
    vector_width = 1024  # NOTE(review): keep in sync with the model's output size
    embedder = HuggingFaceEmbeddings(model_name=hf_model_id)
    return embedder, vector_width