import streamlit as st
from transformers import AutoTokenizer, T5ForConditionalGeneration
import post_ocr


# Sidebar information; rendered as Markdown by `st.markdown(info)` below.
# NOTE(review): this text links to `viklofg/swedish-ocr-correction` while
# `load_model()` loads `KBLab/swedish-ocr-correction` -- confirm both names
# refer to the same model.
info = '''Welcome to the demo of the [swedish-ocr-correction](https://huggingface.co/viklofg/swedish-ocr-correction) model. Enter or upload OCR output and the model will attempt to correct it. :clock2: Slow generation? Try a shorter input. '''


# Example inputs, keyed by the title shown in the selectbox.
# The first entry maps to None so that, while it is selected, the text area
# receives `value=None` instead of an example text.
examples = {
    'Examples': None,
    'Example 1': 'En Gosse fur plats nu genast ! inetallyrkc, JU 83 Drottninggatan.',
    'Example 2': '— Storartad gåfva till Göteborgs Museum. Den i HandelstidniDgens g&rdagsnnmmer omtalade hvalfisken, sorn fångats i Frölnndaviken, har i dag af hr brukspatronen James Dickson blifvit inköpt för 1,500 rdr och skänkt till härvarande Museum.',
    'Example 3': 'Sn underlig race att ſtudera, desfa uppſinnare! utropar en Londontidnings fronifôr. Wet ni hur ſtort antalet är af patenter, ſom ſiſtlidet är utfärdades i British Patent Office? Jo, 14,000 ſty>en !! Det kan man ju fkalla en rif rd! Fjorton tuſen uppfinninnar! Herre Gud, hwilfet märkrwoärdigt tidehrvarf wi lefroa i!'
}


# Load model
@st.cache_resource
def load_model() -> T5ForConditionalGeneration:
    """Return the pretrained correction model.

    Wrapped in `st.cache_resource` so the model object is created once and
    reused on later script reruns instead of being reloaded every time.
    """
    return T5ForConditionalGeneration.from_pretrained('KBLab/swedish-ocr-correction')


model = load_model()


# Load tokenizer
@st.cache_resource
def load_tokenizer():
    """Return the pretrained `google/byt5-small` tokenizer (cached like the model)."""
    return AutoTokenizer.from_pretrained('google/byt5-small')


tokenizer = load_tokenizer()

# Hand the model and tokenizer over to the post_ocr module, whose
# `process` and `diff` functions are used by `handle_input` below.
post_ocr.set_model(model, tokenizer)

# Title
st.title(':memo: Swedish OCR correction')

# Input and output areas
tab1, tab2 = st.tabs(["Text input", "From file"])

# Initialize session states: per tab, the last input that was processed and
# the output generated from it. Both stay None until the first generation.
if 'inputs' not in st.session_state:
    st.session_state.inputs = {'tab1': None, 'tab2': None}

if 'outputs' not in st.session_state:
    st.session_state.outputs = {'tab1': None, 'tab2': None}

# Sidebar (info)
with st.sidebar:
    st.header('About')
    st.markdown(info)


def handle_input(input_: str, id_: str) -> None:
    """Generate (when needed) and display the output for one tab.

    The output is only regenerated when `input_` is non-empty and differs
    from the input stored for this tab; otherwise the stored output is shown
    again, so toggling `Show changes` does not re-run the model.

    Args:
        input_: Text currently held by the tab's input area.
        id_: Tab identifier. Must be a key of `st.session_state.inputs` and
            `st.session_state.outputs` (`'tab1'` or `'tab2'`); it is also
            used to build the key of the `Show changes` toggle.
    """

    # Put everything output-related in a bordered container
    with st.container(border=True):
        st.caption('Output')

        # Only update the output if the input has been updated
        if input_ and st.session_state.inputs[id_] != input_:
            with st.spinner('Generating...'):
                output = post_ocr.process(input_)
            # Record the input together with its output, and only after
            # generation succeeded. Otherwise a failed run would mark the
            # input as processed and the next rerun would present the
            # previous output as if it belonged to this input.
            st.session_state.inputs[id_] = input_
            st.session_state.outputs[id_] = output

        # This container is needed to display the `show changes` toggle
        # after the output text
        container = st.container()

        st.divider()
        show_changes = st.toggle('Show changes', key=f'show_changes_{id_}')

        with container:
            # Display output
            output = st.session_state.outputs[id_]
            if output is not None:
                # Diff against the input that produced this output. When the
                # input area has been emptied, generation is skipped and the
                # stored output is still shown; diffing it against the empty
                # `input_` would mark the whole text as changed.
                processed_input = st.session_state.inputs[id_]
                st.write(post_ocr.diff(processed_input, output) if show_changes else output)


# Manual entry tab
with tab1:
    col1, col2 = st.columns([4, 1])

    # The selectbox is created first because its selection provides the
    # value of the text area in the other column.
    with col2:
        example_title = st.selectbox('Examples', options=examples, label_visibility='collapsed')

    with col1:
        text = st.text_area(
            label='Input text',
            value=examples[example_title],
            height=200,
            label_visibility='collapsed',
            placeholder='Enter OCR generated text or choose an example')

    # No output area is drawn while the text area holds no value at all.
    if text is not None:
        handle_input(text, 'tab1')


# File upload tab
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')

    # Display file content
    if uploaded_file is not None:
        try:
            # `utf-8-sig` decodes plain UTF-8 exactly like `utf-8` but also
            # drops a leading byte-order mark, which would otherwise be
            # passed to the model as part of the text.
            file_content = uploaded_file.getvalue().decode('utf-8-sig')
        except UnicodeDecodeError:
            # The upload is untrusted input: report an unreadable file in the
            # page instead of letting the script fail with a traceback.
            st.error('Could not read the file: it is not valid UTF-8 text.')
        else:
            text = st.text_area('File content', value=file_content, height=300)
            handle_input(text, 'tab2')