Model max_seq_length

#6, opened by shuyuej

Should we explicitly set model.max_seq_length = 512?

StellaEncoder org

Hi, the model was trained on data of length 512. I recommend keeping model.max_seq_length <= 1024.
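Here is a minimal sketch of applying this advice. It assumes you want to cap any requested length at the 1024 upper bound recommended above. The constants and the `clamp_max_seq_length` helper are hypothetical names used for illustration. The only library detail it relies on is that a loaded `SentenceTransformer` exposes a settable `max_seq_length` attribute. That part appears only in a comment, so the block runs without downloading the model.

```python
# Values taken from the reply above; the helper itself is hypothetical.
TRAINED_MAX_LEN = 512        # training data length, per the model author
RECOMMENDED_MAX_LEN = 1024   # recommended upper bound for max_seq_length


def clamp_max_seq_length(requested: int) -> int:
    """Return a max_seq_length that does not exceed the recommended bound."""
    if requested <= 0:
        raise ValueError("max_seq_length must be positive")
    return min(requested, RECOMMENDED_MAX_LEN)


# With a loaded model (not executed here, to avoid downloading weights):
#   from sentence_transformers import SentenceTransformer
#   model = SentenceTransformer("infgrad/stella_en_1.5B_v5", trust_remote_code=True)
#   model.max_seq_length = clamp_max_seq_length(TRAINED_MAX_LEN)

print(clamp_max_seq_length(TRAINED_MAX_LEN))  # 512
print(clamp_max_seq_length(8192))             # 1024
```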

@infgrad Thank you very much for your reply! I really appreciate it!

I ran a quick test and found that the similarity between queries and docs is quite low.
For example, the cosine similarity between "What is apple?" and "Apple is a kind of fruit." is only 0.24.
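For reference, `model.similarity` in sentence-transformers 3.x computes pairwise cosine similarity by default. It returns a (queries × docs) matrix like the 2×2 tensors reported in this thread. Below is a minimal NumPy sketch of that computation. It uses toy 3-dimensional vectors in place of the real 1024-dimensional embeddings. `cosine_similarity_matrix` is a hypothetical helper, not a library function.

```python
import numpy as np


def cosine_similarity_matrix(queries: np.ndarray, docs: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity: entry [i, j] compares query i with doc j."""
    # Normalize each row to unit length; the dot product of unit vectors is the cosine.
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    d = docs / np.linalg.norm(docs, axis=1, keepdims=True)
    return q @ d.T


# Toy vectors standing in for real embeddings (rows: queries / docs).
query_embeddings = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
doc_embeddings = np.array([[2.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
sims = cosine_similarity_matrix(query_embeddings, doc_embeddings)
print(sims.shape)  # (2, 2)
print(np.round(sims, 4))
```

Because each row is normalized first, vector length does not matter. For example, `[2, 0, 0]` and `[1, 0, 0]` have a similarity of exactly 1.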

from sentence_transformers import SentenceTransformer

query_prompt_name = "s2s_query"  # or s2p_query 
queries = [
    "What is apple?",
    "What is the treatment of dementia?",
]

# docs do not need any prompts
docs = [
    "Apple is a kind of fruit.",
    "Thank you, Representative Waxman, for taking the time to speak with me today about your health policy work. This is Bridget Keene with JAMA.",
]

# The default dimension is 1024. For another dimension, clone the model and edit `modules.json`,
# replacing `2_Dense_1024` with another folder such as `2_Dense_256` or `2_Dense_8192`.
model = SentenceTransformer("infgrad/stella_en_1.5B_v5")
query_embeddings = model.encode(queries, prompt_name=query_prompt_name)
doc_embeddings = model.encode(docs)

print(query_embeddings.shape, doc_embeddings.shape)
similarities = model.similarity(query_embeddings, doc_embeddings)
print(similarities)

The result is as follows:

(2, 1024) (2, 1024)
tensor([[0.2433, 0.2636],
        [0.2337, 0.2678]])
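The code above encodes queries with `prompt_name="s2s_query"` and documents with no prompt. As far as I know, sentence-transformers resolves `prompt_name` by looking up the string in `model.prompts` and prepending it to each input text. The sketch below mimics only that step in plain Python. The prompt strings are placeholders, not the model's real instructions, and `apply_prompt` is a hypothetical helper.

```python
# Hypothetical prompt table; the real strings come from the model's configuration
# and are exposed as `model.prompts` once the model is loaded.
PROMPTS = {
    "s2s_query": "<s2s instruction> ",
    "s2p_query": "<s2p instruction> ",
}


def apply_prompt(texts, prompts, prompt_name=None):
    """Prepend the named prompt to every text; with no prompt_name, texts are unchanged."""
    if prompt_name is None:
        return list(texts)
    if prompt_name not in prompts:
        raise KeyError(f"unknown prompt_name: {prompt_name!r}")
    prefix = prompts[prompt_name]
    return [prefix + t for t in texts]


queries = ["What is apple?", "What is the treatment of dementia?"]
docs = ["Apple is a kind of fruit."]
print(apply_prompt(queries, PROMPTS, "s2s_query"))
print(apply_prompt(docs, PROMPTS))  # documents are encoded without a prompt
```

Passing a `prompt_name` that is not in the table raises an error here. This makes a misspelled name like `s2p_querry` fail loudly instead of silently encoding without a prompt.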

Could you please help me check whether there is anything wrong in the code?

Thank you very much in advance!

Best regards,

Shuyue
July 17th, 2024
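The comment in the code above describes switching the output dimension by editing `modules.json` in a local clone. Here is a minimal sketch of that edit. It assumes `modules.json` is a JSON list of module entries with a `path` field, and that the Dense module's path follows the `2_Dense_<dim>` naming from the comment. The sample content and the `set_dense_dimension` helper are illustrative, not copied from the model repository. The target folder (e.g. `2_Dense_256`) must already exist in the clone.

```python
import json
import re

# Assumed layout of a sentence-transformers modules.json (illustrative only).
SAMPLE_MODULES_JSON = """[
  {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
  {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
  {"idx": 2, "name": "2", "path": "2_Dense_1024", "type": "sentence_transformers.models.Dense"}
]"""

DENSE_PATH = re.compile(r"^2_Dense_\d+$")


def set_dense_dimension(modules_json: str, dim: int) -> str:
    """Point the Dense module entry at the folder for the requested output dimension."""
    modules = json.loads(modules_json)
    replaced = False
    for module in modules:
        if DENSE_PATH.match(module["path"]):
            module["path"] = f"2_Dense_{dim}"
            replaced = True
    if not replaced:
        raise ValueError("no 2_Dense_<dim> entry found in modules.json")
    return json.dumps(modules, indent=2)


print(set_dense_dimension(SAMPLE_MODULES_JSON, 256))
```

In a real clone, you would read `modules.json` from disk, pass its text through the function, and write the result back before loading the model from that local path.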

StellaEncoder org

Hi, this is really weird!
Here is my code:

from sentence_transformers import SentenceTransformer

if __name__ == "__main__":
    query_prompt_name = "s2s_query"  # or s2p_query
    queries = [
        "What is apple?",
        "What is the treatment of dementia?",
    ]

    # docs do not need any prompts
    docs = [
        "Apple is a kind of fruit.",
        "Thank you, Representative Waxman, for taking the time to speak with me today about your health policy work. This is Bridget Keene with JAMA.",
    ]

    # The default dimension is 1024. For another dimension, clone the model and edit `modules.json`,
    # replacing `2_Dense_1024` with another folder such as `2_Dense_256` or `2_Dense_8192`.
    model = SentenceTransformer("MODEL_PATH", trust_remote_code=True)
    query_embeddings = model.encode(queries, prompt_name=query_prompt_name)
    doc_embeddings = model.encode(docs)

    print(query_embeddings.shape, doc_embeddings.shape)
    similarities = model.similarity(query_embeddings, doc_embeddings)
    print(similarities)

The output is:

(2, 1024) (2, 1024)
tensor([[0.6056, 0.1705],
        [0.1453, 0.1893]])
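Apart from the absolute scores, a quick sanity check is which document each query ranks first. The sketch below applies `argmax` per row to the two matrices reported in this thread. In the output above, the apple query ranks the apple document first. In the earlier output, it ranks the other document first. `best_doc_per_query` is a hypothetical helper.

```python
import numpy as np

# Similarity matrices reported in this thread (rows: queries, columns: docs).
reported_by_shuyuej = np.array([[0.2433, 0.2636],
                                [0.2337, 0.2678]])
reported_by_infgrad = np.array([[0.6056, 0.1705],
                                [0.1453, 0.1893]])


def best_doc_per_query(similarities: np.ndarray) -> list:
    """Index of the highest-scoring document for each query row."""
    return similarities.argmax(axis=1).tolist()


for label, sims in [("shuyuej", reported_by_shuyuej), ("infgrad", reported_by_infgrad)]:
    print(label, best_doc_per_query(sims))
# shuyuej [1, 1]
# infgrad [0, 1]
```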

Please update your model and try again.

If the program prints any warnings, please let me know.

Best regards,

StellaEncoder org

@shuyuej
Here is my environment:

accelerate                        0.31.0
aiofiles                          23.2.1
aiohttp                           3.9.5
aiosignal                         1.3.1
altair                            5.2.0
annotated-types                   0.6.0
anyio                             4.3.0
asttokens                         2.0.5
async-timeout                     4.0.3
attrs                             23.2.0
beautifulsoup4                    4.12.3
beir                              2.0.0
bitsandbytes                      0.43.0
cachetools                        5.3.3
certifi                           2024.2.2
charset-normalizer                3.3.2
click                             8.1.7
cloudpickle                       3.0.0
cmake                             3.29.0.1
colorama                          0.4.6
comm                              0.2.2
contourpy                         1.2.0
cupy-cuda12x                      12.1.0
cycler                            0.12.1
datasets                          2.20.0
debugpy                           1.6.7
decorator                         5.1.1
deepspeed                         0.14.4
dill                              0.3.8
diskcache                         5.6.3
distro                            1.9.0
dnspython                         2.6.1
docker-pycreds                    0.4.0
einops                            0.8.0
elasticsearch                     7.9.1
et-xmlfile                        1.1.0
eval_type_backport                0.2.0
exceptiongroup                    1.2.0
executing                         0.8.3
faiss-cpu                         1.8.0
fastapi                           0.110.0
fastrlock                         0.8.2
ffmpy                             0.3.2
filelock                          3.13.1
flash-attn                        2.5.9.post1
fonttools                         4.49.0
frozenlist                        1.4.1
fsspec                            2024.2.0
gitdb                             4.0.11
GitPython                         3.1.43
google                            3.0.0
GPUtil                            1.4.0
gradio                            4.36.1
gradio_client                     1.0.1
h11                               0.14.0
hjson                             3.1.0
httpcore                          1.0.4
httptools                         0.6.1
httpx                             0.27.0
huggingface-hub                   0.23.4
idna                              3.6
importlib_metadata                7.1.0
importlib_resources               6.3.0
interegular                       0.3.3
ipykernel                         6.29.3
ipython                           8.20.0
jedi                              0.18.1
jieba                             0.42.1
Jinja2                            3.1.3
jiojio                            1.2.5
jionlp                            1.5.7
joblib                            1.3.2
jsonlines                         4.0.0
jsonschema                        4.21.1
jsonschema-specifications         2023.12.1
jupyter_client                    8.6.1
jupyter_core                      5.7.2
kiwisolver                        1.4.5
lark                              1.1.9
llvmlite                          0.42.0
lm-format-enforcer                0.10.1
loguru                            0.7.2
markdown-it-py                    3.0.0
markdown_to_json                  2.1.1
MarkupSafe                        2.1.5
matplotlib                        3.8.3
matplotlib-inline                 0.1.6
mdurl                             0.1.2
mpmath                            1.3.0
msgpack                           1.0.8
mteb                              1.12.79
multidict                         6.0.5
multiprocess                      0.70.16
nest_asyncio                      1.6.0
networkx                          3.2.1
ninja                             1.11.1.1
numba                             0.59.0
numpy                             1.26.4
nvidia-cublas-cu12                12.1.3.1
nvidia-cuda-cupti-cu12            12.1.105
nvidia-cuda-nvrtc-cu12            12.1.105
nvidia-cuda-runtime-cu12          12.1.105
nvidia-cudnn-cu12                 8.9.2.26
nvidia-cufft-cu12                 11.0.2.54
nvidia-curand-cu12                10.3.2.106
nvidia-cusolver-cu12              11.4.5.107
nvidia-cusparse-cu12              12.1.0.106
nvidia-ml-py                      12.550.52
nvidia-nccl-cu12                  2.20.5
nvidia-nvjitlink-cu12             12.4.99
nvidia-nvtx-cu12                  12.1.105
openai                            1.14.0
openpyxl                          3.1.2
orjson                            3.9.15
outlines                          0.0.46
packaging                         24.0
pandas                            2.2.1
parso                             0.8.3
peft                              0.9.0
pexpect                           4.8.0
pillow                            10.2.0
pip                               23.3.1
platformdirs                      4.2.0
polars                            0.20.31
prometheus_client                 0.20.0
prometheus-fastapi-instrumentator 7.0.0
prompt-toolkit                    3.0.43
protobuf                          5.26.0
psutil                            5.9.8
ptyprocess                        0.7.0
pure-eval                         0.2.2
py                                1.11.0
py-cpuinfo                        9.0.0
pyairports                        2.1.1
pyarrow                           16.1.0
pyarrow-hotfix                    0.6
pycountry                         24.6.1
pycryptodome                      3.9.9
pydantic                          2.6.4
pydantic_core                     2.16.3
pydub                             0.25.1
Pygments                          2.15.1
PyJWT                             2.8.0
pymongo                           4.8.0
pynvml                            11.5.0
pyparsing                         3.1.2
python-dateutil                   2.9.0
python-dotenv                     1.0.1
python-multipart                  0.0.9
pytrec-eval                       0.5
pytrec-eval-terrier               0.5.6
pytz                              2020.5
PyYAML                            6.0.1
pyzmq                             25.1.2
ray                               2.9.3
referencing                       0.33.0
regex                             2023.12.25
requests                          2.32.3
retry                             0.9.2
rich                              13.7.1
rjieba                            0.1.11
roformer                          0.4.3
rpds-py                           0.18.0
ruff                              0.3.2
safetensors                       0.4.2
scikit-learn                      1.4.1.post1
scipy                             1.12.0
semantic-version                  2.10.0
sentence-transformers             3.0.1
sentencepiece                     0.2.0
sentry-sdk                        2.7.1
setproctitle                      1.3.3
setuptools                        68.2.2
shellingham                       1.5.4
six                               1.16.0
smmap                             5.0.1
sniffio                           1.3.1
soupsieve                         2.5
stack-data                        0.2.0
starlette                         0.36.3
sympy                             1.12
threadpoolctl                     3.3.0
tiktoken                          0.6.0
tokenizers                        0.19.1
tomlkit                           0.12.0
toolz                             0.12.1
torch                             2.3.0
torchvision                       0.18.0
tornado                           6.4
tqdm                              4.66.4
traitlets                         5.7.1
transformers                      4.42.3
triton                            2.3.0
typer                             0.12.3
typing_extensions                 4.10.0
tzdata                            2024.1
urllib3                           2.2.1
uvicorn                           0.28.0
uvloop                            0.19.0
vllm                              0.5.1
vllm-flash-attn                   2.5.9
vllm-nccl-cu12                    2.18.1.0.1.0
volcengine                        1.0.133
volcengine-python-sdk             1.0.86
wandb                             0.17.3
watchfiles                        0.21.0
wcwidth                           0.2.5
websockets                        11.0.3
wheel                             0.41.2
xformers                          0.0.26.post1
xxhash                            3.4.1
yarl                              1.9.4
zhconv                            1.4.3
zhipuai                           2.0.1
zipfile36                         0.1.3
zipp                              3.17.0
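To compare against the environment above without a full `pip list`, here is a short sketch. It uses the standard-library `importlib.metadata` to print only the versions most likely to matter for loading and running this model. The choice of packages is an assumption, not something the thread identifies as the cause. For reference, the list above shows sentence-transformers 3.0.1, transformers 4.42.3, torch 2.3.0, tokenizers 0.19.1 and huggingface-hub 0.23.4.

```python
from importlib import metadata

# Packages most directly involved in loading and running the model
# (a judgment call; extend the list as needed).
RELEVANT_PACKAGES = ["sentence-transformers", "transformers", "torch", "tokenizers", "huggingface-hub"]


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(RELEVANT_PACKAGES).items():
    print(f"{name:<25} {version}")
```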

@infgrad Got it! Thank you very much! I really appreciate it!

Did you figure out what the issue was? I'm pretty curious.

  • Tom Aarsen
