# tablecell-htr / train_trocr.py
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
import argparse
from evaluate import load
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from augments import RandAug, RandRotate
parser = argparse.ArgumentParser('arguments for the code')
parser.add_argument('--root_path', type=str, default="",
help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/path/to/train_data.csv",
help='Path to .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/path/to/val_data.csv",
help='Path to .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="./output/path/",
help='Path for saving training results.')
parser.add_argument('--model_path', type=str, default="/model/path/",
help='Path to trocr model')
parser.add_argument('--processor_path', type=str, default="/processor/path/",
help='Path to trocr processor')
parser.add_argument('--epochs', type=int, default=15,
help='Training epochs.')
parser.add_argument('--batch_size', type=int, default=16,
                    help='Training batch size.')
parser.add_argument('--device', type=str, default="cuda:0",
help='Device used for training.')
parser.add_argument('--augment', type=int, default=0,
                    help='Set to 1 to apply image augmentations during training, 0 to disable.')
args = parser.parse_args()
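
# A minimal example invocation (all paths below are placeholders, not real files;
# --augment 1 enables the RandAug augmentations defined in augments.py):
#   python train_trocr.py --root_path /data/images/ \
#       --tr_data_path /data/train.csv --val_data_path /data/val.csv \
#       --model_path microsoft/trocr-base-handwritten \
#       --processor_path microsoft/trocr-base-handwritten \
#       --output_path ./output --epochs 15 --batch_size 16 --augment 1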
# Initialize processor and model
processor = TrOCRProcessor.from_pretrained(args.processor_path)
model = VisionEncoderDecoderModel.from_pretrained(args.model_path)
model.to(args.device)
# Initialize metrics
cer_metric = load("cer")
wer_metric = load("wer")
# Load train and validation data to dataframes
train_df = pd.read_csv(args.tr_data_path)
val_df = pd.read_csv(args.val_data_path)
# Reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
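
# Both CSV files are expected to contain 'file_name' and 'text' columns;
# 'file_name' is prefixed with --root_path to build the full image path
# (see TextlineDataset below).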
# Torch dataset
class TextlineDataset(Dataset):
def __init__(self, root_dir, df, processor, max_target_length=128, augment=False):
self.root_dir = root_dir
self.df = df
self.processor = processor
self.max_target_length = max_target_length
self.augment = augment
        self.augmentator = RandAug()
        self.rotator = RandRotate()  # instantiated but not applied in __getitem__
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# get file name + text
file_name = self.df['file_name'][idx]
text = self.df['text'][idx]
# prepare image (i.e. resize + normalize)
image = Image.open(self.root_dir + file_name).convert("RGB")
if self.augment:
image = self.augmentator(image)
pixel_values = self.processor(image, return_tensors="pt").pixel_values
# add labels (input_ids) by encoding the text
labels = self.processor.tokenizer(text,
padding="max_length", truncation=True,
max_length=self.max_target_length).input_ids
# important: make sure that PAD tokens are ignored by the loss function
labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
return encoding
# Create train and validation datasets
train_dataset = TextlineDataset(root_dir=args.root_path,
df=train_df,
processor=processor,
augment=args.augment)
eval_dataset = TextlineDataset(root_dir=args.root_path,
df=val_df,
processor=processor,
augment=False)
print("Number of training examples:", len(train_dataset))
print("Number of validation examples:", len(eval_dataset))
# Define model configuration
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 1
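
# Note: no_repeat_ngram_size applies to any decoding strategy, but early_stopping
# and length_penalty only affect beam search, so they stay inactive while num_beams=1.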
# Set arguments for model training
# For all arguments see https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
eval_strategy="epoch",
save_strategy="epoch",
logging_strategy="steps",
logging_steps=50,
per_device_train_batch_size=args.batch_size,
per_device_eval_batch_size=args.batch_size,
load_best_model_at_end=True,
metric_for_best_model='cer',
greater_is_better=False,
fp16=True,
num_train_epochs=args.epochs,
save_total_limit=1,
output_dir=args.output_path,
optim='adamw_torch'
)
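
# fp16 mixed-precision training requires a CUDA-capable GPU; set fp16=False to train on CPU.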
# Function for computing CER and WER metrics for the prediction results
def compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
cer = cer_metric.compute(predictions=pred_str, references=label_str)
wer = wer_metric.compute(predictions=pred_str, references=label_str)
return {"cer": cer, "wer": wer}
# Instantiate trainer
# For all parameters see: https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainer
trainer = Seq2SeqTrainer(
model=model,
    tokenizer=processor.image_processor,  # passed so the Trainer saves it alongside checkpoints
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
# Train the model
trainer.train()
# trainer.train(resume_from_checkpoint=True)  # uncomment to resume from the latest checkpoint
model.save_pretrained(args.output_path)
processor.save_pretrained(args.output_path + "/processor")
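
# A minimal inference sketch for the saved artifacts (commented out so the training
# script itself is unchanged; 'page_line.jpg' is a hypothetical example image):
#
# processor = TrOCRProcessor.from_pretrained(args.output_path + "/processor")
# model = VisionEncoderDecoderModel.from_pretrained(args.output_path).to(args.device)
# image = Image.open("page_line.jpg").convert("RGB")
# pixel_values = processor(image, return_tensors="pt").pixel_values.to(args.device)
# generated_ids = model.generate(pixel_values)
# text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# print(text)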