import re
import os
import difflib
from concurrent.futures import ThreadPoolExecutor

import torch
import gradio as gr
import spaces
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# OCR Correction Model
model_name = "PleIAs/OCRonos-Vintage"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load pre-trained model and tokenizer
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# CSS for formatting (currently empty)
css = """
"""


def generate_html_diff(old_text, new_text):
    """Build a word-level diff of the two texts, highlighting words added by
    the correction. Words removed from the original are dropped."""
    d = difflib.Differ()
    diff = list(d.compare(old_text.split(), new_text.split()))
    html_diff = []
    for word in diff:
        if word.startswith('  '):  # unchanged word (Differ uses a two-space prefix)
            html_diff.append(word[2:])
        elif word.startswith('+ '):  # word added by the correction
            # The highlight markup was stripped from the source; a <span> with a
            # light-green background is a minimal reconstruction (an assumption).
            html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
    return ' '.join(html_diff)


def split_text(text, max_tokens=400):
    """Split the input into chunks of at most max_tokens tokens, so that each
    chunk plus its generated correction fits in the model's context window."""
    tokens = tokenizer.tokenize(text)
    chunks = []
    current_chunk = []
    for token in tokens:
        current_chunk.append(token)
        if len(current_chunk) >= max_tokens:
            chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
            current_chunk = []
    if current_chunk:
        chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
    return chunks


@spaces.GPU
def ocr_correction(prompt, max_new_tokens=600, num_threads=os.cpu_count()):
    # num_threads is currently unused in the body; kept from the original signature.
    # Wrap the raw OCR text in the prompt template the model was trained on.
    prompt = f"""### Text ###\n{prompt}\n\n\n### Correction ###\n"""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # Greedy decoding (do_sample=False); top_k has no effect without sampling.
    output = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=50,
        num_return_sequences=1,
        do_sample=False,
    )

    # Keep only the text generated after the "### Correction ###" marker.
    result = tokenizer.decode(output[0], skip_special_tokens=True)
    return result.split("### Correction ###")[1].strip()


def process_text(user_message):
    # Correct each chunk independently, then reassemble the full text.
    chunks = split_text(user_message)
    corrected_chunks = []
    for chunk in chunks:
        corrected_chunk = ocr_correction(chunk)
        corrected_chunks.append(corrected_chunk)
    corrected_text = ' '.join(corrected_chunks)

    html_diff = generate_html_diff(user_message, corrected_text)
    # The original HTML wrapper was truncated in the source; a bare container
    # around the diff is a minimal reconstruction (an assumption).
    ocr_result = f'<div>{html_diff}</div>'
    return ocr_result
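
# ---------------------------------------------------------------------------
# Gradio wiring: a minimal sketch, since the source is cut off after
# process_text. The Blocks layout, component names, and labels below are
# illustrative assumptions, not the original Space's exact UI; process_text
# takes the raw OCR text and returns the HTML diff rendered by gr.HTML.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    gr.Markdown("# OCR Correction with OCRonos-Vintage")
    text_input = gr.Textbox(label="OCR output to correct", lines=10)
    submit_button = gr.Button("Correct text")
    diff_output = gr.HTML(label="Corrected text (insertions highlighted)")
    submit_button.click(process_text, inputs=text_input, outputs=diff_output)

if __name__ == "__main__":
    demo.launch()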