KenjieDec commited on
Commit
f4d8a87
1 Parent(s): ef928a1
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -5,11 +5,13 @@ import gradio as gr
5
  import os
6
  import cv2
7
 
8
- def inference(file, mask, af):
9
  im = cv2.imread(file, cv2.IMREAD_COLOR)
10
  cv2.imwrite(os.path.join("input.png"), im)
11
 
12
  from rembg import remove
 
 
13
 
14
  input_path = 'input.png'
15
  output_path = 'output.png'
@@ -17,7 +19,15 @@ def inference(file, mask, af):
17
  with open(input_path, 'rb') as i:
18
  with open(output_path, 'wb') as o:
19
  input = i.read()
20
- output = remove(input, alpha_matting_erode_size = af, only_mask = (True if mask == "Mask only" else False))
 
 
 
 
 
 
 
 
21
  o.write(output)
22
  return os.path.join("output.png")
23
 
@@ -37,14 +47,25 @@ gr.Interface(
37
  "Mask only"
38
  ],
39
  type="value",
40
- default="Alpha matting",
41
  label="Choices"
42
- )
 
 
 
 
 
 
 
 
 
 
 
43
  ],
44
  gr.outputs.Image(type="file", label="Output"),
45
  title=title,
46
  description=description,
47
  article=article,
48
- examples=[["lion.png", 10, "Default"], ["girl.jpg", 10, "Default"]],
49
  enable_queue=True
50
  ).launch()
 
5
  import os
6
  import cv2
7
 
8
+ def inference(file, af, mask, model):
9
  im = cv2.imread(file, cv2.IMREAD_COLOR)
10
  cv2.imwrite(os.path.join("input.png"), im)
11
 
12
  from rembg import remove
13
+ from rembg.session_base import BaseSession
14
+ from rembg.session_factory import new_session
15
 
16
  input_path = 'input.png'
17
  output_path = 'output.png'
 
19
  with open(input_path, 'rb') as i:
20
  with open(output_path, 'wb') as o:
21
  input = i.read()
22
+ sessions: dict[str, BaseSession] = {}
23
+ output = remove(
24
+ input,
25
+ session=sessions.setdefault(
26
+ model, new_session(model)
27
+ ),
28
+ alpha_matting_erode_size = af,
29
+ only_mask = (True if mask == "Mask only" else False)
30
+ )
31
  o.write(output)
32
  return os.path.join("output.png")
33
 
 
47
  "Mask only"
48
  ],
49
  type="value",
50
+ default="Default",
51
  label="Choices"
52
+ ),
53
+ gr.inputs.Dropdown([
54
+ "u2net",
55
+ "u2netp",
56
+ "u2net_human_seg",
57
+ "u2net_cloth_seg",
58
+ "silueta"
59
+ ],
60
+ type="value",
61
+ default="u2net",
62
+ label="Models"
63
+ ),
64
  ],
65
  gr.outputs.Image(type="file", label="Output"),
66
  title=title,
67
  description=description,
68
  article=article,
69
+ examples=[["lion.png", 10, "Default", "u2net"], ["girl.jpg", 10, "Default", "u2net"]],
70
  enable_queue=True
71
  ).launch()