philipp-zettl commited on
Commit
1ccde3b
1 Parent(s): ddc0abc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -63
app.py CHANGED
@@ -158,6 +158,7 @@ def find_best_parameters(eval_data, model, tokenizer, max_length=85):
158
  4: [2],
159
  6: [2], # 6x3 == 4x2
160
  8: [2], # 8x4 == 6x3 == 4x2
 
161
  10: [2], # 10x5 == 8x4 == 6x3 == 4x2
162
  }
163
 
@@ -249,7 +250,9 @@ def gen(content, temperature_qg=0.5, temperature_qa=0.75, num_return_sequences_q
249
  )
250
 
251
  if optimize_questions:
252
- q_params = find_best_parameters(list(chain.from_iterable(question)), qg_model, tokenizer, max_length=max_length)
 
 
253
 
254
  question = run_model(
255
  inputs,
@@ -308,69 +311,89 @@ def create_file_download(qnas):
308
  return 'qnas.tsv'
309
 
310
 
311
- with gr.Blocks(css='.hidden_input {display: none;}') as demo:
312
- with gr.Row(equal_height=True):
313
- gr.Markdown(
314
- """
315
- # QA-Generator
316
- A combination of fine-tuned flan-T5(-small) models chained into sequence
317
- to generate:
318
-
319
- a) a versatile set of questions
320
- b) an accurate set of matching answers
321
-
322
- according to a given piece of text content.
323
-
324
- The idea is simple:
325
-
326
- 1. Add your content
327
- 2. Select the amount of questions you want to generate
328
- 2.2 (optional) Select the amount of answers you want to generate per given question
329
- 3. Press generate
330
- 4. ???
331
- 5. Profit
332
-
333
- If you're satisfied with the generated data set, you can export it as TSV
334
- to edit or import it into your favourite tool.
335
- """)
336
- with gr.Row(equal_height=True):
337
- with gr.Group("Content"):
338
- content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
339
- with gr.Group("Settings"):
340
- temperature_qg = gr.Slider(label='Temperature QG', value=0.2, minimum=0, maximum=1, step=0.01)
341
- temperature_qa = gr.Slider(label='Temperature QA', value=0.5, minimum=0, maximum=1, step=0.01)
342
- max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
343
- num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
344
- num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
345
- seed = gr.Number(label="seed", value=42069)
346
- optimize_questions = gr.Checkbox(label="Optimize questions?", value=False)
347
-
348
- with gr.Row():
349
- gen_btn = gr.Button("Generate")
350
-
351
- @gr.render(
352
- inputs=[
353
- content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
354
- max_length, seed, optimize_questions
355
- ],
356
- triggers=[gen_btn.click]
357
- )
358
- def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length, seed, optimize_questions):
359
- qnas = gen(
360
- content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
361
- max_length, seed, optimize_questions
362
- )
363
- df = gr.Dataframe(
364
- value=[u.values() for u in qnas],
365
- headers=['Question', 'Answer'],
366
- col_count=2,
367
- wrap=True
 
 
 
 
 
 
368
  )
369
- pd_df = pd.DataFrame([u.values() for u in qnas], columns=['Question', 'Answer'])
370
-
371
- download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))
372
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
 
375
  demo.queue()
376
- demo.launch()
 
158
  4: [2],
159
  6: [2], # 6x3 == 4x2
160
  8: [2], # 8x4 == 6x3 == 4x2
161
+ 9: [3],
162
  10: [2], # 10x5 == 8x4 == 6x3 == 4x2
163
  }
164
 
 
250
  )
251
 
252
  if optimize_questions:
253
+ q_params = find_best_parameters(
254
+ list(chain.from_iterable(question)), qg_model, tokenizer, max_length=max_length
255
+ )
256
 
257
  question = run_model(
258
  inputs,
 
311
  return 'qnas.tsv'
312
 
313
 
314
+ with gr.Blocks() as demo:
315
+ with gr.Tab(label='Description'):
316
+ with gr.Row(equal_height=True):
317
+ gr.Markdown(
318
+ """
319
+ # QA-Generator
320
+ A combination of fine-tuned flan-T5(-small) models chained into sequence
321
+ to generate:
322
+
323
+ a) a versatile set of questions
324
+ b) an accurate set of matching answers
325
+
326
+ according to a given piece of text content.
327
+ The idea is simple:
328
+
329
+ 1. Add your content
330
+ 2. Select the amount of questions you want to generate
331
+ 2.2 (optional) Select the amount of answers you want to generate per given question
332
+ 3. Press generate
333
+ 4. ???
334
+ 5. Profit
335
+ If you're satisfied with the generated data set, you can export it as TSV
336
+ to edit or import it into your favourite tool.
337
+ """)
338
+ with gr.Row(equal_height=True):
339
+ with gr.Accordion(label='Optimization', open=False):
340
+ gr.Markdown("""
341
+ For optimization of the question generation we apply the following combined score:
342
+
343
+ $$\\text{combined} = \\text{dist1} + \\text{dist2} - \\text{fluency} + \\text{contextual} - \\text{jsd}$$
344
+
345
+ Here's a brief explanation of each component:
346
+
347
+ 1. **dist1 and dist2**: These represent the diversity of the generated outputs. dist1 measures the ratio of unique unigrams to total unigrams, and dist2 measures the ratio of unique bigrams to total bigrams. <u>**Higher values indicate more diverse outputs.**</u>
348
+
349
+ 2. **fluency**: This is the perplexity of the generated outputs, which measures how well the outputs match the language model's expectations. <u>**Lower values indicate better fluency.**</u>
350
+
351
+ 3. **contextual**: This measures the similarity between the input and generated outputs using embedding similarity. <u>**Higher values indicate better contextual relevance.**</u>
352
+
353
+ 4. **jsd**: This is the Jensen-Shannon Divergence between the n-gram distributions of the generated outputs and the reference data. <u>**Lower values indicate greater similarity between distributions.**</u>
354
+ """, latex_delimiters=[{'display': False, 'left': '$$', 'right': '$$'}])
355
+ with gr.Tab(label='QA Generator'):
356
+ with gr.Row(equal_height=True):
357
+ with gr.Group("Content"):
358
+ content = gr.Textbox(label='Content', lines=15, placeholder='Enter text here', max_lines=10_000)
359
+ with gr.Group("Settings"):
360
+ temperature_qg = gr.Slider(label='Diversity Penalty QG', value=0.2, minimum=0, maximum=1, step=0.01)
361
+ temperature_qa = gr.Slider(label='Diversity Penalty QA', value=0.5, minimum=0, maximum=1, step=0.01)
362
+ max_length = gr.Number(label='Max Length', value=85, minimum=1, step=1, maximum=512)
363
+ num_return_sequences_qg = gr.Number(label='Number Questions', value=max_questions, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
364
+ num_return_sequences_qa = gr.Number(label="Number Answers", value=max_answers, minimum=1, step=1, maximum=max(max_questions, max_elem_value))
365
+ seed = gr.Number(label="seed", value=42069)
366
+ optimize_questions = gr.Checkbox(label="Optimize questions?", value=False)
367
+
368
+ with gr.Row():
369
+ gen_btn = gr.Button("Generate")
370
+
371
+ @gr.render(
372
+ inputs=[
373
+ content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
374
+ max_length, seed, optimize_questions
375
+ ],
376
+ triggers=[gen_btn.click]
377
  )
378
+ def render_results(content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa, max_length, seed, optimize_questions):
379
+ if not content.strip():
380
+ raise gr.Error('Please enter some content to generate questions and answers.')
381
+ qnas = gen(
382
+ content, temperature_qg, temperature_qa, num_return_sequences_qg, num_return_sequences_qa,
383
+ max_length, seed, optimize_questions
384
+ )
385
+ df = gr.Dataframe(
386
+ value=[u.values() for u in qnas],
387
+ headers=['Question', 'Answer'],
388
+ col_count=2,
389
+ wrap=True
390
+ )
391
+ pd_df = pd.DataFrame([u.values() for u in qnas], columns=['Question', 'Answer'])
392
+
393
+ download = gr.DownloadButton(label='Download (without headers)', value=create_file_download(pd_df))
394
+
395
+ content.change(lambda x: x.strip(), content)
396
 
397
 
398
  demo.queue()
399
+ demo.launch()