zyznull committed
Commit
2d7b768
1 Parent(s): 0119b51

Update README.md

Files changed (1)
  1. README.md +75 -0
README.md CHANGED
@@ -30,3 +30,78 @@ transformers>=4.39.2
flash_attn>=2.5.6
```
## Usage

### Get Dense Embeddings with Transformers

```python
# Requires transformers>=4.36.0
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

input_texts = [
    "what is the capital of China?",
    "how to implement quick sort in python?",
    "北京",
    "快排算法介绍"
]

model_path = 'Alibaba-NLP/gte-multilingual-base'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

# Tokenize the input texts
batch_dict = tokenizer(input_texts, max_length=8192, padding=True, truncation=True, return_tensors='pt')

outputs = model(**batch_dict)

# The embedding is the CLS token of the last hidden state, truncated to the
# desired output dimension (any value in [128, 768])
dimension = 768
embeddings = outputs.last_hidden_state[:, 0][:, :dimension]

embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:1] @ embeddings[1:].T) * 100
print(scores.tolist())
```
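The `dimension` slice works because the model produces elastic embeddings: the leading coordinates of the CLS vector still form a usable lower-dimensional embedding, which is why the comment allows any value in [128, 768]. A minimal sketch of the same scoring at 128 dimensions, reusing `outputs` and `F` from the snippet above:

```python
# Hypothetical follow-up: truncate the CLS embedding to 128 dims and re-normalize
small_dim = 128
small_embeddings = outputs.last_hidden_state[:, 0][:, :small_dim]
small_embeddings = F.normalize(small_embeddings, p=2, dim=1)
small_scores = (small_embeddings[:1] @ small_embeddings[1:].T) * 100
print(small_scores.tolist())
```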

### Use with sentence-transformers

```python
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

input_texts = [
    "what is the capital of China?",
    "how to implement quick sort in python?",
    "北京",
    "快排算法介绍"
]

model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)
embeddings = model.encode(input_texts)

# Cosine similarity between the first text and the rest
print(cos_sim(embeddings[:1], embeddings[1:]))
```
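Recent sentence-transformers releases (2.7 and later, an assumption worth checking against your installed version) also expose embedding truncation directly via the `truncate_dim` argument, which pairs naturally with this model's elastic embeddings. A minimal sketch:

```python
from sentence_transformers import SentenceTransformer

# Assumes sentence-transformers>=2.7, which added the truncate_dim argument
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base',
                            trust_remote_code=True, truncate_dim=128)
embeddings = model.encode(["what is the capital of China?", "北京"])
print(embeddings.shape)  # expected: (2, 128)
```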

### Use custom code to get dense embeddings and sparse token weights

```python
# You can find gte_embedding.py at https://huggingface.co/Alibaba-NLP/gte-multilingual-base/blob/main/scripts/gte_embedding.py
# (the class name below is spelled "GTEEmbeddidng" in that script)
from gte_embedding import GTEEmbeddidng

model_path = 'Alibaba-NLP/gte-multilingual-base'
model = GTEEmbeddidng(model_path)
query = "中国的首都在哪儿"  # "Where is the capital of China?"

docs = [
    "what is the capital of China?",
    "how to implement quick sort in python?",
    "北京",
    "快排算法介绍"
]

embs = model.encode(docs, return_dense=True, return_sparse=True)
print('dense_embeddings vecs', embs['dense_embeddings'])
print('token_weights', embs['token_weights'])

pairs = [(query, doc) for doc in docs]
dense_scores = model.compute_scores(pairs, dense_weight=1.0, sparse_weight=0.0)
sparse_scores = model.compute_scores(pairs, dense_weight=0.0, sparse_weight=1.0)
hybrid_scores = model.compute_scores(pairs, dense_weight=1.0, sparse_weight=0.3)

print('dense_scores', dense_scores)
print('sparse_scores', sparse_scores)
print('hybrid_scores', hybrid_scores)
```
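Judging from the weight arguments, `compute_scores` presumably combines the dense and sparse relevance signals linearly. A sketch of that interpretation, assuming the returned scores are plain per-pair floats:

```python
# Hypothetical re-derivation of the hybrid score, assuming a linear combination
dense_weight, sparse_weight = 1.0, 0.3
manual_hybrid = [dense_weight * d + sparse_weight * s
                 for d, s in zip(dense_scores, sparse_scores)]
print('manual_hybrid', manual_hybrid)  # should roughly match hybrid_scores
```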