ginipick commited on
Commit
9e77c17
1 Parent(s): d9943ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -4
app.py CHANGED
@@ -22,8 +22,8 @@ from torch import Tensor, nn
22
  from transformers import CLIPTextModel, CLIPTokenizer
23
  from transformers import T5EncoderModel, T5Tokenizer
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
-
26
-
27
  # ---------------- Encoders ----------------
28
 
29
 
@@ -759,6 +759,12 @@ def generate_image(
759
  do_img2img, init_image, image2image_strength, resize_img,
760
  progress=gr.Progress(track_tqdm=True),
761
  ):
 
 
 
 
 
 
762
  if seed == 0:
763
  seed = int(random.random() * 1000000)
764
 
@@ -809,8 +815,9 @@ def generate_image(
809
  x = rearrange(x[0], "c h w -> h w c")
810
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
811
 
812
- return img, seed
813
-
 
814
  css = """
815
  footer {
816
  visibility: hidden;
 
22
  from transformers import CLIPTextModel, CLIPTokenizer
23
  from transformers import T5EncoderModel, T5Tokenizer
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
+ from transformers import pipeline
26
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
27
  # ---------------- Encoders ----------------
28
 
29
 
 
759
  do_img2img, init_image, image2image_strength, resize_img,
760
  progress=gr.Progress(track_tqdm=True),
761
  ):
762
+ translated_prompt = prompt
763
+ if any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in prompt):
764
+ translated_prompt = translator(prompt, max_length=512)[0]['translation_text']
765
+ print(f"Translated prompt: {translated_prompt}")
766
+ prompt = translated_prompt
767
+
768
  if seed == 0:
769
  seed = int(random.random() * 1000000)
770
 
 
815
  x = rearrange(x[0], "c h w -> h w c")
816
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
817
 
818
+
819
+ return img, seed, translated_prompt
820
+
821
  css = """
822
  footer {
823
  visibility: hidden;