how to finetune
#9
by
Yudum
- opened
Thank you for the multilingual reranker. How can we fine-tune this reranker to improve its performance on a specific language?
You can use the sentence_transformers library to fine-tune it. Here is the code to do it:
import math
import json
from torch.utils.data import DataLoader
from sentence_transformers import InputExample, LoggingHandler
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator
import logging
# Route log records through sentence-transformers' LoggingHandler and
# timestamp every message.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)
logger = logging.getLogger(__name__)

# Training configuration.
train_batch_size = 4
num_epochs = 3
model_id = "Alibaba-NLP/gte-multilingual-reranker-base"
model_save_path = "gte-multilingual-reranker-base-new"
data_dir = "training_data"

# num_labels=1 requests a single output score per text pair.
model = CrossEncoder(model_id, num_labels=1, trust_remote_code=True)


def load_examples(path):
    """Build one InputExample per record of the JSON file at *path*.

    The file must contain a JSON list of objects, each carrying the keys
    "sentence1", "sentence2" and "label". The two sentences form the text
    pair scored by the cross-encoder; "label" is the target score.

    Raises:
        FileNotFoundError: if *path* does not exist.
        KeyError: if a record lacks one of the three required keys.
    """
    # Explicit UTF-8 so multilingual text is decoded the same way on every
    # platform; the context manager guarantees the handle is closed.
    with open(path, encoding="utf-8") as handle:
        records = json.load(handle)
    # The pair must be (sentence1, sentence2). Pairing sentence1 with itself
    # would train the model on identical texts only.
    return [
        InputExample(
            texts=[record["sentence1"], record["sentence2"]],
            label=record["label"],
        )
        for record in records
    ]


train_data = load_examples(f"{data_dir}/train.json")
dev_data = load_examples(f"{data_dir}/dev.json")
test_data = load_examples(f"{data_dir}/test.json")
logger.info(f"Total Train Data: {len(train_data)}, Dev Data: {len(dev_data)}, Test Data: {len(test_data)}")

logger.info("Training the model")
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=train_batch_size)
# The dev split is evaluated during training via the evaluator passed to fit().
evaluator = CECorrelationEvaluator.from_input_examples(dev_data, name="sts-dev")

# Configure the training: warm up over the first 10% of all training steps
# (batches per epoch * number of epochs).
warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)
logger.info(f"Warmup-steps: {warmup_steps}")

# Train the model; output is written to model_save_path.
model.fit(
    train_dataloader=train_dataloader,
    evaluator=evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    use_amp=True,
)

# Load the saved model from disk and evaluate it on the held-out test set.
model = CrossEncoder(model_save_path, trust_remote_code=True)
evaluator = CECorrelationEvaluator.from_input_examples(test_data, name="sts-test")
logger.info(evaluator(model))
Is there a sample dataset we can use as a reference when creating our own dataset?
You can check this: https://sbert.net/docs/cross_encoder/training/examples.html