|
import os |
|
import gradio as gr |
|
from collections import OrderedDict |
|
from PIL import Image, ImageDraw, ImageFont |
|
from io import BytesIO |
|
import time |
|
import tempfile |
|
import PyPDF2 |
|
import pdf2image |
|
from datasets import load_dataset |
|
|
|
# Upper bound on the number of PDF pages handled per document.
MAX_PAGES = 50

# Upper bound on the size of a PDF (presumably in bytes -- TODO confirm).
MAX_PDF_SIZE = 100_000_000

# Embedded images that are smaller than BOTH of these dimensions (in pixels)
# are skipped when extracting images from a PDF.
MIN_WIDTH = 150
MIN_HEIGHT = 150
|
|
|
|
|
def equal_image_grid(images):
    """Tile *images* into one roughly square contact-sheet image.

    Images with a zero width or height are dropped. The remaining images are
    scaled (aspect ratio preserved) to the width of the narrowest one and
    pasted left-to-right, top-to-bottom into equally sized cells.

    Args:
        images: Sequence of ``PIL.Image.Image`` objects.

    Returns:
        A new RGB ``PIL.Image.Image`` containing the grid, or ``None`` when no
        image with a non-zero size is left to draw.
    """

    def compute_grid(n, max_cols=6):
        """Return ``(rows, cols)`` with ``rows * cols >= n`` and ``cols <= max_cols``."""
        side = int(n ** 0.5)
        rows = side
        cols = min(side, max_cols)
        # Prefer one extra column over an extra row, but never go past the cap.
        if rows * cols < n and cols < max_cols:
            cols += 1
        while rows * cols < n:
            rows += 1
        return rows, cols

    drawable = [im for im in images if im.width > 0 and im.height > 0]
    if not drawable:
        return None

    # Size the grid from what is actually pasted, so that filtered-out images
    # do not leave empty cells behind.
    rows, cols = compute_grid(len(drawable))

    cell_w = min(im.width for im in drawable)
    scaled = [
        im.resize(
            # Clamp to 1px: extremely wide images would otherwise truncate to
            # a height of 0, which cannot be resized to.
            (cell_w, max(1, int(im.height * cell_w / im.width))),
            resample=Image.BICUBIC,
        )
        for im in drawable
    ]
    cell_h = max(im.height for im in scaled)

    grid = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, im in enumerate(scaled):
        row, col = divmod(index, cols)
        grid.paste(im, box=(col * cell_w, row * cell_h))
    return grid
|
|
|
|
|
def add_pagenumbers(im_list, height_scale=40):
    """Draw a 1-based sequence number near the bottom-right corner of each image.

    The images are modified in place; nothing is returned.

    Args:
        im_list: Iterable of ``PIL.Image.Image`` objects, in page order.
        height_scale: Divisor applied to ``sqrt(width * height)`` to obtain the
            font size; a larger value yields smaller numbers.
    """

    def load_font(fontsize):
        """Return Arial at *fontsize*, or PIL's built-in font if Arial is unavailable."""
        try:
            return ImageFont.truetype("Arial.ttf", fontsize)
        except OSError:
            # Arial is not installed on many Linux hosts; a plain number is
            # better than failing the whole request.
            return ImageFont.load_default()

    def add_pagenumber(image, i):
        width, height = image.size
        # Clamp to 1: very small images would otherwise produce a font size
        # of 0, which is not a valid TrueType size.
        fontsize = max(1, int((width * height) ** 0.5 / height_scale))
        margin = 2 * fontsize
        ImageDraw.Draw(image).text(
            (width - margin, height - margin),
            str(i + 1),
            fill="#D00917",
            font=load_font(fontsize),
            spacing=4,
            align="right",
        )

    for i, image in enumerate(im_list):
        add_pagenumber(image, i)
|
|
|
|
|
def pdf_to_grid(pdf_path):
    """Build a numbered image grid that previews a PDF.

    The images embedded in the first ``MAX_PAGES`` pages are extracted with
    PyPDF2; images smaller than both ``MIN_WIDTH`` and ``MIN_HEIGHT`` are
    skipped. If opening the PDF or extracting its images fails, the first
    ``MAX_PAGES`` pages are rendered with pdf2image instead.

    Args:
        pdf_path: A filesystem path or a binary file-like object (for example
            ``io.BytesIO``) holding the PDF.

    Returns:
        The grid produced by ``equal_image_grid``, or ``None`` when the PDF
        yields no usable image.

    Raises:
        Exception: Whatever pdf2image raises if the fallback rendering fails
            as well.
    """

    def read_pdf_bytes(source):
        """Return the raw PDF bytes behind *source* (bytes, stream or path)."""
        if isinstance(source, (bytes, bytearray)):
            return bytes(source)
        if hasattr(source, "getvalue"):
            return source.getvalue()
        if hasattr(source, "read"):
            # PyPDF2 may have moved the cursor; read the stream from the start.
            source.seek(0)
            return source.read()
        with open(source, "rb") as handle:
            return handle.read()

    images = []
    try:
        # Opening the reader is inside the try so that a PDF PyPDF2 cannot
        # parse still gets a chance through the pdf2image fallback.
        reader = PyPDF2.PdfReader(pdf_path)
        for page_index, page in enumerate(reader.pages):
            if page_index >= MAX_PAGES:
                break
            for image in page.images:
                im = Image.open(BytesIO(image.data))
                if im.width < MIN_WIDTH and im.height < MIN_HEIGHT:
                    continue
                images.append(im)
    except Exception as e:
        print(f"{pdf_path} PyPDF get_images {e}")
        # convert_from_bytes needs the raw bytes, not the stream object.
        images = pdf2image.convert_from_bytes(
            read_pdf_bytes(pdf_path), last_page=MAX_PAGES
        )

    if not images:
        return None
    add_pagenumbers(images)
    return equal_image_grid(images)
|
|
|
|
|
def main(dataset, label):
    """Pick a random sample from a dataset and prepare it for display.

    The dataset is shuffled with a time-derived seed and scanned until a
    sample is found that carries the requested label and can be turned into
    an image grid.

    Args:
        dataset: Key into ``DATASETS``.
        label: Class name from ``_CLASSES`` to filter on; an empty value
            disables the filter.

    Returns:
        A ``(label name, grid image, pdf file path)`` tuple. The path points
        to a temporary copy of the PDF; it is created with ``delete=False``
        so it still exists after this function returns.

    Raises:
        gr.Error: If *dataset* is unknown or no matching sample can be
            rendered.
    """
    if dataset not in DATASETS:
        raise gr.Error(f"Unknown dataset: {dataset!r}")
    source = DATASETS[dataset]

    seed = int(time.time() * 1000) % 1000000

    try:
        shuffled_dataset = source.shuffle(buffer_size=10, seed=seed)
    except TypeError:
        # This dataset's shuffle() does not accept ``buffer_size``.
        shuffled_dataset = source.shuffle(seed=seed)

    for sample in shuffled_dataset:
        label_column = "label" if "label" in sample else "labels"
        filelabel = _CLASSES[sample[label_column]]
        if label and filelabel != label:
            continue

        pdf_bytes = sample["file"]
        grid = pdf_to_grid(BytesIO(pdf_bytes))
        if grid is None:
            continue

        with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
            tmp_file.write(pdf_bytes)
        # Return only after the file is closed, so the copy is fully flushed.
        return filelabel, grid, tmp_file.name

    raise gr.Error(
        f"No displayable sample found in dataset {dataset!r} for label {label!r}."
    )
|
|
|
|
|
# Class names, indexed by the integer label stored in each dataset sample.
# The trailing empty string is the "no filter" choice of the label dropdown.
_CLASSES = [
    "letter", "form", "email", "handwritten",
    "advertisement", "scientific report", "scientific publication", "specification",
    "file folder", "news article", "budget", "invoice",
    "presentation", "questionnaire", "resume", "memo",
    "",
]
|
|
|
|
|
# Datasets selectable in the UI, keyed by display name. Each entry is loaded
# when this module is executed.
DATASETS = OrderedDict()
DATASETS["rvl_cdip_N"] = load_dataset("bdpc/rvl_cdip_n_mp", split="test")
|
|
|
# Choices for the two dropdown inputs of ``main``. The dataset names are taken
# from DATASETS so the UI never offers a dataset that has not been loaded.
meta_cats = {"dataset": list(DATASETS), "label": _CLASSES}

# One dropdown per input; each defaults to the last entry of its choices.
sliders = [
    gr.Dropdown(choices=choices, value=choices[-1], label=label)
    for label, choices in meta_cats.items()
]
slider_defaults = [sliders[0].value, None]
|
|
|
|
|
|
|
|
|
# Output widgets, in the same order as the tuple returned by ``main``:
# label name, image grid, PDF file.
outputs = [gr.Textbox(label="label"), gr.Image(label="image grid of PDF"), gr.File(label="PDF")]
|
|
|
# Markdown text passed to the interface as its description.
DESCRIPTION = (
    "\n"
    "Visualize PDF samples from multi-page (PDF) document classification datasets @ https://huggingface.co/datasets/bdpc\n"
    "\n"
    "- **dataset**: dataset name\n"
    "- **label**: label name\n"
    "\n"
    "The first time that the app is launched, it will download the datasets, which can take a few minutes.\n"
    "For fastest response, choose the rvl_cdip_N dataset, which is considerably smaller to iterate over.\n"
)
|
|
|
|
|
# Wire ``main`` to the dropdown inputs and the three output widgets.
iface = gr.Interface(
    main,
    sliders,
    outputs,
    title="Beyond Document Page Classification: Examples",
    description=DESCRIPTION,
)

# Start the app as soon as the module is executed, with sharing enabled.
iface.launch(share=True)
|
|