"""Gradio app that turns a PDF into a summarized Markdown "presentation".

Pipeline: extract font-size-tagged text from the PDF with ``TextExtractor``,
group it into slides, summarize every paragraph with a distilled Pegasus
model, then write the result as a Markdown file via ``mdutils``.
"""
import copy
import logging
import os
from pathlib import Path

import fitz
import gradio as gr
import torch
from mdutils.mdutils import MdUtils
from tqdm import tqdm
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import pipeline

from text_extractor import TextExtractor

logger = logging.getLogger(__name__)

# Stem of the most recently processed PDF. Kept for backward compatibility;
# convert2markdown accepts the name explicitly so that concurrent requests
# do not race on this shared global.
FILENAME = ""

preprocess = TextExtractor()

model_name = "sshleifer/distill-pegasus-cnn-16-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = PegasusTokenizer.from_pretrained(model_name, max_length=500)
model = PegasusForConditionalGeneration.from_pretrained(model_name).to(device)

# Pegasus CNN/DailyMail checkpoints emit the literal token "<n>" where the
# reference summaries had line breaks. It is not a special token, so
# skip_special_tokens=True leaves it in the decoded text.
PEGASUS_NEWLINE = "<n>"

# Markdown (and therefore MdUtils) only supports header levels 1-6.
MAX_HEADER_LEVEL = 6


def summarize(slides):
    """Return a deep copy of *slides* with every paragraph summarized.

    Args:
        slides: Mapping of page -> list of ``(tag, text)`` tuples, as
            produced by ``TextExtractor.get_slides``. Entries whose tag
            starts with ``"p"`` are summarized; all others are copied as-is.

    Returns:
        A new mapping with the same structure; *slides* itself is not
        modified. A paragraph whose summarization fails keeps its original
        text (best effort), and the failure is logged.
    """
    generated_slides = copy.deepcopy(slides)
    for page, contents in tqdm(generated_slides.items()):
        for idx, (tag, content) in enumerate(contents):
            if not tag.startswith("p"):
                continue
            try:
                inputs = tokenizer(
                    content,
                    truncation=True,
                    padding="longest",
                    return_tensors="pt",
                ).to(device)
                output_ids = model.generate(**inputs)
                summary = tokenizer.batch_decode(
                    output_ids, skip_special_tokens=True
                )[0]
            except Exception as exc:  # best effort: keep the original text
                logger.warning("Summarization Fails on page %r: %s", page, exc)
                continue
            # Replacing the element at the current index is safe while
            # enumerating, because the list length does not change.
            contents[idx] = (tag, summary)
    return generated_slides


def convert2markdown(generate_slides, filename=None):
    """Write *generate_slides* to ``<filename>.md`` and return that path.

    Args:
        generate_slides: Mapping of page -> list of ``(tag, text)`` tuples,
            typically the output of ``summarize``. ``h<N>`` tags become
            headers, ``p`` tags become paragraphs, and other tags are skipped.
        filename: Output file stem (without ``.md``). Defaults to the
            module-level ``FILENAME`` for backward compatibility.

    Returns:
        The path of the written Markdown file, ``f"{filename}.md"``,
        relative to the current working directory.
    """
    if filename is None:
        filename = FILENAME
    md_file = MdUtils(file_name=filename, title=f'{filename} Presentation')
    for sections in generate_slides.values():
        # Horizontal rule between slides.
        md_file.new_paragraph('---')
        for tag, content in sections:
            if tag.startswith('h'):
                # Parse every digit (not just tag[1]) so "h12" is not read
                # as level 1, then clamp to the deepest Markdown level.
                level = min(int(tag[1:]), MAX_HEADER_LEVEL)
                md_file.new_header(level=level, title=content)
            elif tag == 'p':
                # NOTE(review): summarize() handles every tag starting with
                # "p", but only the exact tag "p" is rendered here. Confirm
                # which tags TextExtractor actually emits.
                for line in content.split(PEGASUS_NEWLINE):
                    line = line.strip()
                    if line:
                        md_file.new_paragraph(line)
    md_file.create_md_file()
    return f"{filename}.md"


def inference(document):
    """Gradio handler: summarize the uploaded PDF and return a Markdown path.

    Args:
        document: The uploaded PDF. Either a filesystem path (newer Gradio)
            or a tempfile-like object exposing ``.name`` (older Gradio).

    Returns:
        Path of the generated Markdown file.
    """
    global FILENAME
    path = getattr(document, "name", document)
    logger.info("Processing %s", path)
    # Close the PDF as soon as text extraction is finished; summarization
    # only needs the extracted slides.
    with fitz.open(path) as doc:
        filename = Path(doc.name).stem
        FILENAME = filename  # kept for callers that still read the global
        font_counts, styles = preprocess.get_font_info(doc, granularity=False)
        size_tag = preprocess.get_font_tags(font_counts, styles)
        texts = preprocess.assign_tags(doc, size_tag)
        slides = preprocess.get_slides(texts)
    generated_slides = summarize(slides)
    return convert2markdown(generated_slides, filename)


with gr.Blocks() as demo:
    # Gradio expects dotted file extensions (or category names like "image").
    inp = gr.File(file_types=['.pdf'])
    out = gr.File(label="Markdown File")
    inference_btn = gr.Button("Summarized PDF")
    inference_btn.click(
        fn=inference,
        inputs=inp,
        outputs=out,
        show_progress=True,
        api_name="summarize",
    )

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo.launch()