# Daryl Lim
# Update app.py
# 2709a97
"""
This module provides an interface for summarizing text using the BART Large CNN model.
The interface allows users to upload a PDF of a research paper.
The user will receive a generated summary of the research paper's abstract.
"""
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import pypdf
import re
# Checkpoint identifier shared by the tokenizer and the model
_CHECKPOINT = "facebook/bart-large-cnn"

# Request the fast tokenizer implementation for this checkpoint
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, use_fast=True)

# Load the sequence-to-sequence model with bfloat16 weights to reduce memory usage
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT, torch_dtype=torch.bfloat16)
def extract_abstract(pdf_path):
    """
    Extract the text of the abstract section from a PDF file.

    Only the first two pages are searched (a single-page PDF is searched in
    full). The abstract is taken to be the text between the first occurrence
    of the word "abstract" (any letter case) and the next occurrence of an
    "Introduction" heading.

    Args:
        pdf_path (str): Path to the PDF file.

    Returns:
        str: The extracted abstract text with surrounding whitespace removed.
            Returns an empty string if the file cannot be read, or if either
            the abstract marker or the following introduction heading is
            not found.
    """
    try:
        with open(pdf_path, "rb") as f:
            reader = pypdf.PdfReader(f)
            # Read at most the first two pages; a one-page PDF must not fail
            # with an IndexError just because it has no second page.
            page_count = min(2, len(reader.pages))
            page_texts = [
                reader.pages[i].extract_text(orientations=(0,))
                for i in range(page_count)
            ]
            # Collapse every run of whitespace (including line breaks) into a
            # single space. Deleting the newlines outright would glue the last
            # word of each line to the first word of the next one.
            text = " ".join(" ".join(page_texts).split())
    except FileNotFoundError:
        print(f"File not found: {pdf_path}")
        return ""
    except Exception as e:
        # Top-level boundary for PDF parsing problems: report and degrade to
        # "no abstract" instead of crashing the caller.
        print(f"Error reading PDF file: {e}")
        return ""

    # Search for the abstract section
    abstract_regex = re.compile(r"Abstract|ABSTRACT", re.IGNORECASE)
    abstract_match = re.search(abstract_regex, text)
    if not abstract_match:
        return ""
    # The abstract body starts right after the matched heading word
    abstract_start = abstract_match.end()

    # Search for the introduction section to determine where the abstract ends.
    # The "ntroduction"/"NTRODUCTION" alternatives also match headings whose
    # first letter was extracted separately from the rest of the word.
    introduction_regex = re.compile(r"Introduction|ntroduction|INTRODUCTION|NTRODUCTION")
    intro_match = re.search(introduction_regex, text[abstract_start:])
    if not intro_match:
        return ""  # Return empty string if no introduction section is found
    # intro_match offsets are relative to the slice, so shift them back
    abstract_end = abstract_start + intro_match.start()

    return text[abstract_start:abstract_end].strip()
@spaces.GPU
def summarize_abstract(pdf_path):
    """
    Generate a summary of the text from the abstract section in a PDF file.

    Args:
        pdf_path (str): The path to the PDF file.

    Returns:
        str: The generated summary of the abstract, or an empty string if no
            abstract is found or if summarization fails.
    """
    try:
        # Extract abstract text from the PDF; extract_abstract reports its own
        # file/parse errors and returns "" in those cases.
        abstract_text = extract_abstract(pdf_path)
        if not abstract_text:
            print("No abstract found in the PDF.")
            return ""

        # Truncate only at the model's input limit. Cutting the input to a
        # summary-sized length would drop most of a typical abstract before
        # the model ever sees it.
        max_input_tokens = getattr(model.config, "max_position_embeddings", 1024)
        tokenized_input = tokenizer(
            abstract_text,
            truncation=True,
            max_length=max_input_tokens,
            return_tensors="pt",
        )
        # Keep the inputs on the same device as the model weights
        tokenized_input = tokenized_input.to(model.device)

        # Generate a summary prediction using the model
        summary_ids = model.generate(**tokenized_input)

        # Decode the generated summary
        summary = tokenizer.batch_decode(summary_ids, skip_special_tokens=True)
        return summary[0] if summary else ""
    except Exception as e:
        # Boundary for the UI callback: report the failure and show an empty
        # result rather than surfacing a traceback to the user.
        print(f"Error generating summary: {e}")
        return ""
# Components of the Gradio UI: a single file upload in, a single text box out
_pdf_upload = gr.File(label="Upload PDF")
_summary_box = gr.Textbox(label="Abstract Summary")

# Wire the components to the summarization function
demo = gr.Interface(
    fn=summarize_abstract,
    inputs=[_pdf_upload],
    outputs=[_summary_box],
    title="BART Large CNN Abstract Summarizer",
    description="Upload a research paper in PDF format to generate a summary of its abstract.",
)

# Start serving the interface
demo.launch()