|
import pandas as pd |
|
import torch |
|
from torch.utils.data import Dataset |
|
from PIL import Image |
|
import argparse |
|
from evaluate import load |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator, AdamW |
|
import torchvision.transforms as transforms |
|
from augments import RandAug, RandRotate |
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser('arguments for the code')

parser.add_argument('--root_path', type=str, default="",
                    help='Root path to data files.')
parser.add_argument('--tr_data_path', type=str, default="/path/to/train_data.csv",
                    help='Path to .csv file containing the training data.')
parser.add_argument('--val_data_path', type=str, default="/path/to/val_data.csv",
                    help='Path to .csv file containing the validation data.')
parser.add_argument('--output_path', type=str, default="./output/path/",
                    help='Path for saving training results.')
parser.add_argument('--model_path', type=str, default="/model/path/",
                    help='Path to trocr model')
parser.add_argument('--processor_path', type=str, default="/processor/path/",
                    help='Path to trocr processor')
parser.add_argument('--epochs', type=int, default=15,
                    help='Training epochs.')
# BUG FIX: help text previously said 'Training epochs.' (copy-paste error).
parser.add_argument('--batch_size', type=int, default=16,
                    help='Training batch size.')
parser.add_argument('--device', type=str, default="cuda:0",
                    help='Device used for training.')
# Kept as int for backward compatibility with existing invocations
# (e.g. --augment 1); any non-zero value enables augmentation.
parser.add_argument('--augment', type=int, default=0,
                    help='Defines if image augmentations are used during training (0/1).')

args = parser.parse_args()
|
|
|
|
|
# Load the TrOCR processor (image feature extractor + tokenizer) and the
# vision-encoder/text-decoder model from the paths given on the command line,
# then move the model onto the requested device.
processor = TrOCRProcessor.from_pretrained(args.processor_path)

model = VisionEncoderDecoderModel.from_pretrained(args.model_path)

model.to(args.device)


# Evaluation metrics loaded via the HuggingFace `evaluate` library:
# character error rate and word error rate.
cer_metric = load("cer")

wer_metric = load("wer")


# Training / validation tables; each is expected to provide 'file_name'
# and 'text' columns (consumed by TextlineDataset further down).
train_df = pd.read_csv(args.tr_data_path)

val_df = pd.read_csv(args.val_data_path)


# Reset to a clean 0..n-1 integer index so positional lookups in the
# Dataset (`df['col'][idx]`) are well-defined.
train_df.reset_index(drop=True, inplace=True)

val_df.reset_index(drop=True, inplace=True)
|
|
|
|
|
class TextlineDataset(Dataset):
    """Dataset of text-line images paired with their transcriptions.

    Each row of ``df`` must provide a ``file_name`` column (path appended
    verbatim to ``root_dir``) and a ``text`` column (target transcription).
    ``__getitem__`` returns a dict with ``pixel_values`` (processed image
    tensor) and ``labels`` (token ids, padding replaced by -100).
    """

    def __init__(self, root_dir, df, processor, max_target_length=128, augment=False):
        self.root_dir = root_dir                    # prefix concatenated onto each file_name
        self.df = df                                # DataFrame with 'file_name' / 'text' columns
        self.processor = processor                  # TrOCR processor (image + tokenizer)
        self.max_target_length = max_target_length  # token cap for label sequences
        self.augment = augment                      # truthy -> apply RandAug to images
        self.augmentator = RandAug()
        # NOTE(review): rotator is instantiated but never used in this class —
        # confirm whether rotation augmentation was meant to be applied.
        self.rotator = RandRotate()

    def __len__(self):
        """Number of samples (rows in the backing DataFrame)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, (optionally) augment, and encode the sample at ``idx``."""
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]

        # BUG FIX: use a context manager so the underlying file handle is
        # closed promptly. PIL opens images lazily and would otherwise keep
        # the file open until garbage collection (a descriptor leak when
        # iterating large datasets). convert("RGB") forces the load and
        # returns an independent image.
        with Image.open(self.root_dir + file_name) as img:
            image = img.convert("RGB")

        if self.augment:
            image = self.augmentator(image)

        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        labels = self.processor.tokenizer(text,
                                          padding="max_length", truncation=True,
                                          max_length=self.max_target_length).input_ids

        # Replace pad tokens with -100 so the loss function ignores them.
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100
                  for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding
|
|
|
|
|
# Build the training dataset (augmented if --augment was non-zero) and the
# validation dataset (never augmented); both share the same processor.
train_dataset = TextlineDataset(root_dir=args.root_path,
                                df=train_df,
                                processor=processor,
                                augment=args.augment)

eval_dataset = TextlineDataset(root_dir=args.root_path,
                               df=val_df,
                               processor=processor,
                               augment=False)

print("Number of training examples:", len(train_dataset))

print("Number of validation examples:", len(eval_dataset))
|
|
|
|
|
|
|
|
|
# Wire the special tokens the decoder needs for generation: CLS starts
# decoding, PAD marks padding, SEP acts as end-of-sequence.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id

model.config.pad_token_id = processor.tokenizer.pad_token_id

# Mirror the decoder's vocabulary size onto the top-level config so the
# generation utilities see the correct output dimension.
model.config.vocab_size = model.config.decoder.vocab_size

model.config.eos_token_id = processor.tokenizer.sep_token_id
# Generation settings used by predict_with_generate during evaluation.
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
# NOTE(review): length_penalty only affects beam search; with num_beams=1
# (greedy decoding) it has no effect — confirm whether beams were intended.
model.config.length_penalty = 2.0
model.config.num_beams = 1
|
|
|
|
|
|
|
# Trainer configuration: evaluate and checkpoint once per epoch, keep only
# the single best checkpoint (lowest CER), train in fp16 on the GPU.
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,            # evaluate with model.generate() so CER/WER see decoded text
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    load_best_model_at_end=True,           # restore the best checkpoint after training
    metric_for_best_model='cer',           # key returned by compute_metrics
    greater_is_better=False,               # lower CER is better
    fp16=True,                             # mixed precision; requires a CUDA device
    num_train_epochs=args.epochs,
    save_total_limit=1,                    # keep only the most recent/best checkpoint on disk
    output_dir=args.output_path,
    optim='adamw_torch'
)
|
|
|
|
|
def compute_metrics(pred):
    """Compute character and word error rates for an evaluation batch.

    ``pred`` is the Trainer's EvalPrediction; its ``predictions`` hold the
    generated token ids and ``label_ids`` the references (with -100 marking
    positions ignored by the loss).
    """
    generated_ids = pred.predictions
    reference_ids = pred.label_ids

    decoded_preds = processor.batch_decode(generated_ids, skip_special_tokens=True)

    # Map the -100 loss-ignore sentinel back to the pad id so that the
    # tokenizer can decode the references (pad is then skipped as a
    # special token).
    reference_ids[reference_ids == -100] = processor.tokenizer.pad_token_id
    decoded_refs = processor.batch_decode(reference_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=decoded_preds, references=decoded_refs)
    wer = wer_metric.compute(predictions=decoded_preds, references=decoded_refs)

    return {"cer": cer, "wer": wer}
|
|
|
|
|
|
|
|
|
# Assemble the trainer and run fine-tuning.
# NOTE(review): `tokenizer` is given the image processor (not the text
# tokenizer) so that saved checkpoints include the image preprocessing
# config — presumably intentional, but confirm against the transformers
# version in use (newer versions prefer the `processing_class` argument).
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.image_processor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)


trainer.train()

# Persist the final (best, via load_best_model_at_end) model and the
# processor next to the training outputs.
model.save_pretrained(args.output_path)

processor.save_pretrained(args.output_path + "/processor")