Pclanglais commited on
Commit
21257a3
1 Parent(s): fb02ef9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -62
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import transformers
2
  import re
3
- from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
4
  from vllm import LLM, SamplingParams
5
  import torch
6
  import gradio as gr
@@ -8,34 +8,39 @@ import json
8
  import os
9
  import shutil
10
  import requests
11
- import chromadb
12
- import difflib
13
  import pandas as pd
14
- from chromadb.config import Settings
15
- from chromadb.utils import embedding_functions
16
 
17
  # Define the device
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
- model_name = "Pclanglais/ocronos2"
 
 
21
 
22
- llm = LLM(model_name, max_model_len=8128)
 
 
 
 
23
 
24
-
25
- #CSS for references formatting
 
26
  css = """
 
27
  .generation {
28
- margin-left:2em;
29
- margin-right:2em;
30
- size:1.2em;
31
  }
32
  :target {
33
  background-color: #CCF3DF;
34
  }
35
  .source {
36
- float:left;
37
- max-width:17%;
38
- margin-left:2%;
39
  }
40
  .tooltip {
41
  position: relative;
@@ -43,7 +48,6 @@ css = """
43
  font-variant-position: super;
44
  color: #97999b;
45
  }
46
-
47
  .tooltip:hover::after {
48
  content: attr(data-text);
49
  position: absolute;
@@ -61,7 +65,6 @@ css = """
61
  display: block;
62
  box-shadow: 0 4px 8px rgba(0,0,0,0.1);
63
  }
64
- /* New styles for diff */
65
  .deleted {
66
  background-color: #ffcccb;
67
  text-decoration: line-through;
@@ -69,75 +72,186 @@ css = """
69
  .inserted {
70
  background-color: #90EE90;
71
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """
73
 
74
- #Curtesy of claude
75
  def generate_html_diff(old_text, new_text):
76
  d = difflib.Differ()
77
  diff = list(d.compare(old_text.split(), new_text.split()))
78
-
79
  html_diff = []
80
  for word in diff:
81
- if word.startswith(' '):
82
  html_diff.append(word[2:])
83
  elif word.startswith('+ '):
84
  html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
85
- # We're not adding anything for words that start with '- '
86
-
87
  return ' '.join(html_diff)
88
 
89
- # Class to encapsulate the Falcon chatbot
90
- class MistralChatBot:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
92
  self.system_prompt = system_prompt
93
 
94
- def predict(self, user_message):
95
  sampling_params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, presence_penalty=0, stop=["#END#"])
96
  detailed_prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
97
- print(detailed_prompt)
98
  prompts = [detailed_prompt]
99
- outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
100
  generated_text = outputs[0].outputs[0].text
101
-
102
- # Generate HTML diff
103
  html_diff = generate_html_diff(user_message, generated_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
- generated_text = '<h2 style="text-align:center">Réponse</h3>\n<div class="generation">' + html_diff + "</div>"
106
- return generated_text
107
 
108
- # Create the Falcon chatbot instance
109
- mistral_bot = MistralChatBot()
110
 
111
  # Define the Gradio interface
112
- title = "Correction d'OCR"
113
- description = "Un outil expérimental de correction d'OCR basé sur des modèles de langue"
114
- examples = [
115
- [
116
- "Qui peut bénéficier de l'AIP?", # user_message
117
- 0.7 # temperature
118
- ]
119
- ]
120
-
121
- additional_inputs=[
122
- gr.Slider(
123
- label="Température",
124
- value=0.2, # Default value
125
- minimum=0.05,
126
- maximum=1.0,
127
- step=0.05,
128
- interactive=True,
129
- info="Des valeurs plus élevées donne plus de créativité, mais aussi d'étrangeté",
130
- ),
131
- ]
132
-
133
- demo = gr.Blocks()
134
-
135
- with gr.Blocks(theme='JohnSmith9982/small_and_pretty', css=css) as demo:
136
- gr.HTML("""<h1 style="text-align:center">Correction d'OCR</h1>""")
137
- text_input = gr.Textbox(label="Votre texte.", type="text", lines=1)
138
- text_button = gr.Button("Corriger l'OCR")
139
- text_output = gr.HTML(label="Le texte corrigé")
140
- text_button.click(mistral_bot.predict, inputs=text_input, outputs=[text_output])
141
 
142
  if __name__ == "__main__":
143
  demo.queue().launch()
 
1
  import transformers
2
  import re
3
+ from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM, pipeline
4
  from vllm import LLM, SamplingParams
5
  import torch
6
  import gradio as gr
 
8
  import os
9
  import shutil
10
  import requests
 
 
11
  import pandas as pd
12
+ import difflib
 
13
 
14
  # Define the device
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
+ # OCR Correction Model
18
+ ocr_model_name = "Pclanglais/ocronos2"
19
+ ocr_llm = LLM(ocr_model_name, max_model_len=8128)
20
 
21
+ # Editorial Segmentation Model
22
+ editorial_model = "PleIAs/Estienne"
23
+ token_classifier = pipeline(
24
+ "token-classification", model=editorial_model, aggregation_strategy="simple", device=device
25
+ )
26
 
27
+ tokenizer = AutoTokenizer.from_pretrained(editorial_model, model_max_length=512)
28
+
29
+ # CSS for formatting
30
  css = """
31
+ <style>
32
  .generation {
33
+ margin-left: 2em;
34
+ margin-right: 2em;
35
+ font-size: 1.2em;
36
  }
37
  :target {
38
  background-color: #CCF3DF;
39
  }
40
  .source {
41
+ float: left;
42
+ max-width: 17%;
43
+ margin-left: 2%;
44
  }
45
  .tooltip {
46
  position: relative;
 
48
  font-variant-position: super;
49
  color: #97999b;
50
  }
 
51
  .tooltip:hover::after {
52
  content: attr(data-text);
53
  position: absolute;
 
65
  display: block;
66
  box-shadow: 0 4px 8px rgba(0,0,0,0.1);
67
  }
 
68
  .deleted {
69
  background-color: #ffcccb;
70
  text-decoration: line-through;
 
72
  .inserted {
73
  background-color: #90EE90;
74
  }
75
+ .manuscript {
76
+ display: flex;
77
+ margin-bottom: 10px;
78
+ align-items: baseline;
79
+ }
80
+ .annotation {
81
+ width: 15%;
82
+ padding-right: 20px;
83
+ color: grey !important;
84
+ font-style: italic;
85
+ text-align: right;
86
+ }
87
+ .content {
88
+ width: 80%;
89
+ }
90
+ h2 {
91
+ margin: 0;
92
+ font-size: 1.5em;
93
+ }
94
+ .title-content h2 {
95
+ font-weight: bold;
96
+ }
97
+ .bibliography-content {
98
+ color: darkgreen !important;
99
+ margin-top: -5px;
100
+ }
101
+ .paratext-content {
102
+ color: #a4a4a4 !important;
103
+ margin-top: -5px;
104
+ }
105
+ </style>
106
  """
107
 
108
+ # Helper functions
109
  def generate_html_diff(old_text, new_text):
110
  d = difflib.Differ()
111
  diff = list(d.compare(old_text.split(), new_text.split()))
 
112
  html_diff = []
113
  for word in diff:
114
+ if word.startswith(' '):
115
  html_diff.append(word[2:])
116
  elif word.startswith('+ '):
117
  html_diff.append(f'<span style="background-color: #90EE90;">{word[2:]}</span>')
 
 
118
  return ' '.join(html_diff)
119
 
120
+ def preprocess_text(text):
121
+ text = re.sub(r'<[^>]+>', '', text)
122
+ text = re.sub(r'\n', ' ', text)
123
+ text = re.sub(r'\s+', ' ', text)
124
+ return text.strip()
125
+
126
+ def split_text(text, max_tokens=500):
127
+ parts = text.split("\n")
128
+ chunks = []
129
+ current_chunk = ""
130
+
131
+ for part in parts:
132
+ if current_chunk:
133
+ temp_chunk = current_chunk + "\n" + part
134
+ else:
135
+ temp_chunk = part
136
+
137
+ num_tokens = len(tokenizer.tokenize(temp_chunk))
138
+
139
+ if num_tokens <= max_tokens:
140
+ current_chunk = temp_chunk
141
+ else:
142
+ if current_chunk:
143
+ chunks.append(current_chunk)
144
+ current_chunk = part
145
+
146
+ if current_chunk:
147
+ chunks.append(current_chunk)
148
+
149
+ if len(chunks) == 1 and len(tokenizer.tokenize(chunks[0])) > max_tokens:
150
+ long_text = chunks[0]
151
+ chunks = []
152
+ while len(tokenizer.tokenize(long_text)) > max_tokens:
153
+ split_point = len(long_text) // 2
154
+ while split_point < len(long_text) and not re.match(r'\s', long_text[split_point]):
155
+ split_point += 1
156
+ if split_point >= len(long_text):
157
+ split_point = len(long_text) - 1
158
+ chunks.append(long_text[:split_point].strip())
159
+ long_text = long_text[split_point:].strip()
160
+ if long_text:
161
+ chunks.append(long_text)
162
+
163
+ return chunks
164
+
165
+ def transform_chunks(marianne_segmentation):
166
+ marianne_segmentation = pd.DataFrame(marianne_segmentation)
167
+ marianne_segmentation = marianne_segmentation[marianne_segmentation['entity_group'] != 'separator']
168
+ marianne_segmentation['word'] = marianne_segmentation['word'].astype(str).str.replace('¶', '\n', regex=False)
169
+ marianne_segmentation['word'] = marianne_segmentation['word'].astype(str).apply(preprocess_text)
170
+ marianne_segmentation = marianne_segmentation[marianne_segmentation['word'].notna() & (marianne_segmentation['word'] != '') & (marianne_segmentation['word'] != ' ')]
171
+
172
+ html_output = []
173
+ for _, row in marianne_segmentation.iterrows():
174
+ entity_group = row['entity_group']
175
+ result_entity = "[" + entity_group.capitalize() + "]"
176
+ word = row['word']
177
+
178
+ if entity_group == 'title':
179
+ html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content title-content"><h2>{word}</h2></div></div>')
180
+ elif entity_group == 'bibliography':
181
+ html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content bibliography-content">{word}</div></div>')
182
+ elif entity_group == 'paratext':
183
+ html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content paratext-content">{word}</div></div>')
184
+ else:
185
+ html_output.append(f'<div class="manuscript"><div class="annotation">{result_entity}</div><div class="content">{word}</div></div>')
186
+
187
+ final_html = '\n'.join(html_output)
188
+ return final_html
189
+
190
+ # OCR Correction Class
191
+ class OCRCorrector:
192
  def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
193
  self.system_prompt = system_prompt
194
 
195
+ def correct(self, user_message):
196
  sampling_params = SamplingParams(temperature=0.9, top_p=0.95, max_tokens=4000, presence_penalty=0, stop=["#END#"])
197
  detailed_prompt = f"### TEXT ###\n{user_message}\n\n### CORRECTION ###\n"
 
198
  prompts = [detailed_prompt]
199
+ outputs = ocr_llm.generate(prompts, sampling_params, use_tqdm=False)
200
  generated_text = outputs[0].outputs[0].text
 
 
201
  html_diff = generate_html_diff(user_message, generated_text)
202
+ return generated_text, html_diff
203
+
204
+ # Editorial Segmentation Class
205
+ class EditorialSegmenter:
206
+ def segment(self, text):
207
+ editorial_text = re.sub("\n", " ¶ ", text)
208
+ num_tokens = len(tokenizer.tokenize(editorial_text))
209
+
210
+ if num_tokens > 500:
211
+ batch_prompts = split_text(editorial_text, max_tokens=500)
212
+ else:
213
+ batch_prompts = [editorial_text]
214
+
215
+ out = token_classifier(batch_prompts)
216
+ classified_list = []
217
+ for classification in out:
218
+ df = pd.DataFrame(classification)
219
+ classified_list.append(df)
220
+
221
+ classified_list = pd.concat(classified_list)
222
+ out = transform_chunks(classified_list)
223
+ return out
224
+
225
+ # Combined Processing Class
226
+ class TextProcessor:
227
+ def __init__(self):
228
+ self.ocr_corrector = OCRCorrector()
229
+ self.editorial_segmenter = EditorialSegmenter()
230
+
231
+ def process(self, user_message):
232
+ # Step 1: OCR Correction
233
+ corrected_text, html_diff = self.ocr_corrector.correct(user_message)
234
+
235
+ # Step 2: Editorial Segmentation
236
+ segmented_text = self.editorial_segmenter.segment(corrected_text)
237
+
238
+ # Combine results
239
+ ocr_result = f'<h2 style="text-align:center">OCR Correction</h2>\n<div class="generation">{html_diff}</div>'
240
+ editorial_result = f'<h2 style="text-align:center">Editorial Segmentation</h2>\n<div class="generation">{segmented_text}</div>'
241
 
242
+ final_output = f"{css}{ocr_result}<br><br>{editorial_result}"
243
+ return final_output
244
 
245
+ # Create the TextProcessor instance
246
+ text_processor = TextProcessor()
247
 
248
  # Define the Gradio interface
249
+ with gr.Blocks(theme='JohnSmith9982/small_and_pretty') as demo:
250
+ gr.HTML("""<h1 style="text-align:center">LM Document Processing</h1>""")
251
+ text_input = gr.Textbox(label="Your text", type="text", lines=5)
252
+ process_button = gr.Button("Process Text")
253
+ text_output = gr.HTML(label="Processed text")
254
+ process_button.click(text_processor.process, inputs=text_input, outputs=[text_output])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  if __name__ == "__main__":
257
  demo.queue().launch()