thefreeham commited on
Commit
73203ad
1 Parent(s): c82c2e4

Create app.py.bck

Browse files
Files changed (1) hide show
  1. app.py.bck +77 -0
app.py.bck ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import os
4
+ from pathlib import Path
5
+ from io import BytesIO
6
+ import time
7
+
8
+ from flask import Flask, request, jsonify
9
+ from flask_cors import CORS, cross_origin
10
+ from consts import IMAGES_OUTPUT_DIR
11
+ from utils import parse_arg_boolean, parse_arg_dalle_version
12
+ from consts import ModelSize
13
+
14
+
15
+ import gradio as gr
16
+
17
+ def greet(name):
18
+ return "Hello " + name + "!!"
19
+
20
+ iface = gr.Interface(fn=greet, inputs="text", outputs="text")
21
+ iface.launch()
22
+
23
+
24
+ app = Flask(__name__)
25
+ CORS(app)
26
+ print("--> Starting DALL-E Server. This might take up to two minutes.")
27
+
28
+ from dalle_model import DalleModel
29
+ dalle_model = None
30
+
31
+ parser = argparse.ArgumentParser(description = "A DALL-E app to turn your textual prompts into visionary delights")
32
+ parser.add_argument("--port", type=int, default=8000, help = "backend port")
33
+ parser.add_argument("--model_version", type = parse_arg_dalle_version, default = ModelSize.MINI, help = "Mini, Mega, or Mega_full")
34
+ parser.add_argument("--save_to_disk", type = parse_arg_boolean, default = False, help = "Should save generated images to disk")
35
+ args = parser.parse_args()
36
+
37
+ @app.route("/dalle", methods=["POST"])
38
+ @cross_origin()
39
+ def generate_images_api():
40
+ json_data = request.get_json(force=True)
41
+ text_prompt = json_data["text"]
42
+ num_images = json_data["num_images"]
43
+ generated_imgs = dalle_model.generate_images(text_prompt, num_images)
44
+
45
+ generated_images = []
46
+ if args.save_to_disk:
47
+ dir_name = os.path.join(IMAGES_OUTPUT_DIR,f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
48
+ Path(dir_name).mkdir(parents=True, exist_ok=True)
49
+
50
+ for idx, img in enumerate(generated_imgs):
51
+ if args.save_to_disk:
52
+ img.save(os.path.join(dir_name, f'{idx}.jpeg'), format="JPEG")
53
+
54
+ buffered = BytesIO()
55
+ img.save(buffered, format="JPEG")
56
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
57
+ generated_images.append(img_str)
58
+
59
+ print(f"Created {num_images} images from text prompt [{text_prompt}]")
60
+ return jsonify(generated_images)
61
+
62
+
63
+ @app.route("/", methods=["GET"])
64
+ @cross_origin()
65
+ def health_check():
66
+ return jsonify(success=True)
67
+
68
+
69
+ with app.app_context():
70
+ dalle_model = DalleModel(args.model_version)
71
+ dalle_model.generate_images("warm-up", 1)
72
+ print("--> DALL-E Server is up and running!")
73
+ print(f"--> Model selected - DALL-E {args.model_version}")
74
+
75
+
76
+ if __name__ == "__main__":
77
+ app.run(host="0.0.0.0", port=args.port, debug=False)