Update README.md
Browse files
README.md
CHANGED
@@ -30,3 +30,78 @@ transformers>=4.39.2
|
|
30 |
flash_attn>=2.5.6
|
31 |
```
|
32 |
## Usage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
flash_attn>=2.5.6
|
31 |
```
|
32 |
## Usage
|
33 |
+
|
34 |
+
Get Dense Embeddings with Transformers
|
35 |
+
```
|
36 |
+
# Requires transformers>=4.36.0
|
37 |
+
|
38 |
+
import torch.nn.functional as F
|
39 |
+
from transformers import AutoModel, AutoTokenizer
|
40 |
+
|
41 |
+
input_texts = [
|
42 |
+
"what is the capital of China?",
|
43 |
+
"how to implement quick sort in python?",
|
44 |
+
"北京",
|
45 |
+
"快排算法介绍"
|
46 |
+
]
|
47 |
+
|
48 |
+
model_path = 'Alibaba-NLP/gte-multilingual-base'
|
49 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
50 |
+
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
|
51 |
+
|
52 |
+
# Tokenize the input texts
|
53 |
+
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='pt')
|
54 |
+
|
55 |
+
outputs = model(**batch_dict)
|
56 |
+
|
57 |
+
dimension=768 # The output dimension of the output embedding, should be in [128, 768]
|
58 |
+
embeddings = outputs.last_hidden_state[:, 0][:dimension]
|
59 |
+
|
60 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
61 |
+
scores = (embeddings[:1] @ embeddings[1:].T) * 100
|
62 |
+
print(scores.tolist())
|
63 |
+
```
|
64 |
+
|
65 |
+
Use with sentence-transformers
|
66 |
+
```
|
67 |
+
from sentence_transformers import SentenceTransformer
|
68 |
+
from sentence_transformers.util import cos_sim
|
69 |
+
|
70 |
+
input_texts = [
|
71 |
+
"what is the capital of China?",
|
72 |
+
"how to implement quick sort in python?",
|
73 |
+
"北京",
|
74 |
+
"快排算法介绍"
|
75 |
+
]
|
76 |
+
|
77 |
+
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
|
78 |
+
embeddings = model.encode(input_texts)
|
79 |
+
```
|
80 |
+
|
81 |
+
Use with custom code to get dense embeddigns and sparse token weights
|
82 |
+
```
|
83 |
+
# You can find the gte_embeddings.py in https://huggingface.co/Alibaba-NLP/gte-multilingual-base/blob/main/scripts/gte_embedding.py
|
84 |
+
from gte_embeddings import GTEEmbeddidng
|
85 |
+
|
86 |
+
model_path = 'Alibaba-NLP/gte-multilingual-base'
|
87 |
+
model = GTEEmbeddidng(model_path)
|
88 |
+
query = "中国的首都在哪儿"
|
89 |
+
|
90 |
+
docs = [
|
91 |
+
"what is the capital of China?",
|
92 |
+
"how to implement quick sort in python?",
|
93 |
+
"北京",
|
94 |
+
"快排算法介绍"
|
95 |
+
]
|
96 |
+
|
97 |
+
embs = model.encode(docs, return_dense=True,return_sparse=True)
|
98 |
+
print('dense_embeddings vecs', embs['dense_embeddings'])
|
99 |
+
print('token_weights', embs['token_weights'])
|
100 |
+
pairs = [(query, doc) for doc in docs]
|
101 |
+
dense_scores = model.compute_scores(pairs, dense_weight=1.0, sparse_weight=0.0)
|
102 |
+
sparse_scores = model.compute_scores(pairs, dense_weight=0.0, sparse_weight=1.0)
|
103 |
+
hybird_scores = model.compute_scores(pairs, dense_weight=1.0, sparse_weight=0.3)
|
104 |
+
print('dense_scores', dense_scores)
|
105 |
+
print('sparse_scores', sparse_scores)
|
106 |
+
print('hybird_scores', hybird_scores)
|
107 |
+
```
|