File size: 3,896 Bytes
9c00f5c
 
 
 
8be0786
 
9c00f5c
 
80597e4
 
 
 
9c00f5c
80597e4
9c00f5c
 
 
 
 
 
 
80597e4
 
 
9c00f5c
 
 
8be0786
9c00f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80597e4
 
9c00f5c
 
 
 
 
 
 
 
 
8be0786
9c00f5c
 
 
 
 
 
 
 
8be0786
80597e4
 
 
9c00f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8be0786
80597e4
8be0786
 
9c00f5c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#!/usr/bin/env python

from __future__ import annotations

import gradio as gr

from model import AppModel

# Markdown header rendered at the top of the page (first gr.Markdown in main()).
DESCRIPTION = '# <a href="https://github.com/THUDM/CogView2">CogView2</a> (text2image)'
# Markdown bullet list rendered below the input/output layout in main():
# credits the upstream demo repo and the Space used for EN->ZH translation.
NOTES = '''
- This app is adapted from <a href="https://github.com/hysts/CogView2_demo">https://github.com/hysts/CogView2_demo</a>. It would be recommended to use the repo if you want to run the app yourself.
- [This Space](https://huggingface.co/spaces/chinhon/translation_eng2ch) is used for translation from English to Chinese.
'''
# Raw HTML <img> tag (a visitor-counter badge) passed to gr.Markdown as the
# last element of the page.
FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=THUDM.CogView2" />'


def set_example_text(example: list) -> dict:
    """Return a Textbox update that fills the input with a clicked example.

    ``example`` is the sample row delivered by the ``gr.Dataset`` click
    event. Only its first column is used; in this app each row holds a
    single prompt string.
    """
    prompt = example[0]
    return gr.Textbox.update(value=prompt)


def _load_samples(path: str = 'samples.txt') -> list[list[str]]:
    """Read example prompts for the ``gr.Dataset``, one prompt per line.

    Each non-blank line of *path* becomes a single-column sample row
    (``[prompt]``), matching the single component (the input Textbox) the
    Dataset is bound to. Surrounding whitespace is stripped. Blank lines are
    skipped so they do not appear as empty, clickable examples.

    The file is decoded as UTF-8 explicitly. The platform default encoding
    is locale-dependent and can fail or mis-decode non-ASCII prompts.

    Raises:
        OSError: if *path* cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
    """
    with open(path, encoding='utf-8') as f:
        prompts = (line.strip() for line in f)
        return [[prompt] for prompt in prompts if prompt]


def main():
    """Build the CogView2 text-to-image Gradio UI and launch it.

    A single ``AppModel`` is created up front, and its
    ``run_with_translation`` method serves every "Run" click. The
    "Only First Stage" checkbox is pre-set from the ``only_first_stage``
    flag below and is visible only when that flag is False. With the
    current setting (True), the option is fixed and hidden from users.
    """
    only_first_stage = True
    max_inference_batch_size = 4
    model = AppModel(max_inference_batch_size, only_first_stage)

    # Read the example prompts before building the layout, so that file I/O
    # is not buried inside the nested UI context managers.
    samples = _load_samples('samples.txt')

    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                with gr.Group():
                    text = gr.Textbox(label='Input Text')
                    translate = gr.Checkbox(label='Translate to Chinese',
                                            value=False)
                    style = gr.Dropdown(
                        choices=[
                            'mainbody',
                            'photo',
                            'flat',
                            'comics',
                            'oil',
                            'sketch',
                            'isometric',
                            'chinese',
                            'watercolor',
                        ],
                        label='Style')
                    seed = gr.Slider(0,
                                     100000,
                                     step=1,
                                     value=1234,
                                     label='Seed')
                    # Kept distinct from the ``only_first_stage`` bool. The bool
                    # configures the model and this widget's default/visibility.
                    # The widget supplies the per-request value to the handler.
                    only_first_stage_checkbox = gr.Checkbox(
                        label='Only First Stage',
                        value=only_first_stage,
                        visible=not only_first_stage)
                    num_images = gr.Slider(1,
                                           16,
                                           step=1,
                                           value=8,
                                           label='Number of Images')
                    examples = gr.Dataset(components=[text], samples=samples)
                    run_button = gr.Button('Run')

            with gr.Column():
                with gr.Group():
                    translated_text = gr.Textbox(label='Translated Text')
                    with gr.Tabs():
                        with gr.TabItem('Output (Grid View)'):
                            result_grid = gr.Image(show_label=False)
                        with gr.TabItem('Output (Gallery)'):
                            result_gallery = gr.Gallery(show_label=False)

        gr.Markdown(NOTES)
        gr.Markdown(FOOTER)

        # Component values are passed positionally. This order presumably
        # matches AppModel.run_with_translation's parameters (defined in
        # model.py, not visible here) -- keep them in sync.
        run_button.click(fn=model.run_with_translation,
                         inputs=[
                             text,
                             translate,
                             style,
                             seed,
                             only_first_stage_checkbox,
                             num_images,
                         ],
                         outputs=[
                             translated_text,
                             result_grid,
                             result_gallery,
                         ])
        # Clicking an example row copies its prompt into the input Textbox.
        examples.click(fn=set_example_text,
                       inputs=examples,
                       outputs=examples.components)

    # NOTE(review): ``enable_queue`` is a Gradio 3.x ``launch()`` argument;
    # newer releases use ``demo.queue()`` instead -- confirm against the
    # pinned gradio version before upgrading.
    demo.launch(enable_queue=True)


# Build and launch the demo only when run as a script, not when imported.
if __name__ == '__main__':
    main()