AlhitawiMohammed22 commited on
Commit
9002e70
1 Parent(s): 8c112b4

text generation model

Browse files
Files changed (1) hide show
  1. trocr.py +30 -0
trocr.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
+
5
+
6
+ device = "cuda" if torch.cuda.is_available() else "cpu"
7
+
8
+
9
+ class IAMDataset(Dataset):
10
+ def __init__(self, crops, processor):
11
+ self.crops = crops
12
+ self.processor = processor
13
+
14
+ def __len__(self):
15
+ return len(self.crops)
16
+
17
+ def __getitem__(self, idx):
18
+ crp = self.crops[idx]
19
+ pixel_values = self.processor(crp, return_tensors="pt").pixel_values
20
+ encoding = {"pixel_values": pixel_values.squeeze()}
21
+ return encoding
22
+
23
+ def get_processor_model(checkpoint:str):
24
+ rec_processor = TrOCRProcessor.from_pretrained('trocr_printed_processor/')
25
+ rec_model = VisionEncoderDecoderModel.from_pretrained('trocr_printed_model/')
26
+ rec_model.config.eos_token_id = 2
27
+ rec_model.config.pad_token_id = 2
28
+ rec_model.to(device)
29
+ rec_model.eval()
30
+ return rec_processor, rec_model