Ron Au committed on
Commit
81ea62c
1 Parent(s): b1e2dc7
Files changed (7) hide show
  1. .gitignore +4 -0
  2. Pipfile +17 -0
  3. Pipfile.lock +0 -0
  4. README.md +2 -4
  5. app.py +48 -0
  6. modules/sprites.py +266 -0
  7. requirements.txt +80 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ !**/.gitkeep
2
+
3
+ cache/*
4
+ output/*
Pipfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+ diffusers = { version = "==0.7.*", extras = ["torch"] }
8
+ gradio = "==3.9.*"
9
+ scipy = "==1.9.*"
10
+ torch = "==1.13.*"
11
+ torchvision = "==0.14.*"
12
+ transformers = "==4.24.*"
13
+
14
+ [dev-packages]
15
+
16
+ [requires]
17
+ python_version = "3.10"
Pipfile.lock ADDED
The diff for this file is too large to render. See raw diff
 
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Sd Spritesheets
3
- emoji: 🌖
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
@@ -8,5 +8,3 @@ sdk_version: 3.9
8
  app_file: app.py
9
  pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Stable Diffusion Sprite Sheets
3
+ emoji: 🚶‍♀️
4
  colorFrom: purple
5
  colorTo: gray
6
  sdk: gradio
 
8
  app_file: app.py
9
  pinned: false
10
  ---
 
 
app.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import time
4
+ import gradio as gr
5
+ from modules.sprites import generate_sides, build_gifs, build_spritesheet
6
+
7
+
8
def generate(prompt, thresh):
    """
    Gradio callback: turn a text prompt into a sprite sheet plus four GIFs.

    `thresh` is forwarded as the background removal threshold to both the
    sprite sheet and GIF builders. A single timestamp is shared by both
    builders for this request.

    Returns the sprite sheet image followed by the first four saved GIF file
    paths, in the order `build_gifs` reports them.
    """

    timestamp = int(time.time())

    # Allow up to 3 attempts per side at producing a safe image.
    sides, _ = generate_sides(prompt, 3)
    spritesheet, _ = build_spritesheet(
        sides, prompt, timestamp=timestamp, thresh=thresh
    )
    _, filepaths = build_gifs(
        sides, prompt, save=True, timestamp=timestamp, thresh=thresh
    )

    return (spritesheet, *[filepaths[index] for index in range(4)])
17
+
18
+
19
# Gradio UI definition for the sprite sheet demo.
# NOTE(review): the nesting of the layout blocks and the indentation inside the
# Markdown string are reconstructed from a flattened diff — confirm against the repo.
demo = gr.Blocks()

with demo:
    # Page header and short description of what the demo produces.
    gr.Markdown("""
    # Stable Diffusion Sprite Sheets

    Generate a sprite sheet of pixel art character sides and their corresponding walk animations! Checkpoint by [Onodofthenorth](https://huggingface.co/Onodofthenorth/SD_PixelArt_SpriteSheet_Generator). Sprites are 32x32 pixels scaled up to 96x96. NSFW content replaced with blank sprites.
    """)

    with gr.Row():
        with gr.Column():
            # Inputs passed to `generate` as (prompt, thresh).
            prompt = gr.Textbox(label="Prompt", placeholder="Enter text prompt")
            threshold = gr.Slider(label="Background removal threshold", placeholder="Tweak how strong the background removal is", minimum=0, maximum=255, value=128)

            button = gr.Button("Generate")

        with gr.Box():
            with gr.Row():
                # Receives the first value returned by `generate`.
                spritesheet = gr.Image(label="Sprite Sheet")

            with gr.Row():
                # Receive the four GIF file paths returned by `generate`.
                front = gr.Image(label="Front")
                back = gr.Image(label="Back")
                left = gr.Image(label="Left")
                right = gr.Image(label="Right")

    # Output order must match the tuple returned by `generate`.
    button.click(fn=generate, inputs=[prompt, threshold], outputs=[spritesheet, front, back, left, right])

# Enable request queuing, then start the app with the API page hidden.
demo.queue()
demo.launch(show_api=False)
modules/sprites.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import re
4
+ import time
5
+ import random
6
+ import torch
7
+ from typing import Final, List, Optional, Tuple, cast
8
+
9
+ from PIL import Image, ImageDraw, ImageEnhance
10
+ from PIL.Image import Image as PILImage
11
+ from diffusers import StableDiffusionPipeline
12
+
13
# Hugging Face model id of the sprite sheet checkpoint loaded below.
model_id: Final = "Onodofthenorth/SD_PixelArt_SpriteSheet_Generator"
# The pipeline is loaded at import time in half precision; `cache_dir` points
# the download cache at the local "cache" directory.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, cache_dir="cache"
)
# NOTE(review): importing this module presumably fails on machines without CUDA — confirm.
pipe = pipe.to("cuda")

# Maps each sprite side to the token that `generate_sides` appends to the
# prompt when rendering that side.
sprite_sides: Final = {
    "front": "PixelArtFSS",
    "right": "PixelArtRSS",
    "back": "PixelArtBSS",
    "left": "PixelArtLSS",
}
25
+
26
+
27
def torchGenerator(seed: Optional[int], max: int = 1024) -> Tuple[torch.Generator, int]:
    """
    Create a CUDA random generator seeded with `seed`.

    When `seed` is None, a seed is drawn with `random.randrange(0, max)`.
    An explicit seed of 0 is honoured instead of being replaced.

    Returns the seeded generator together with the seed that was used.
    """

    # `seed or ...` would silently discard a legitimate seed of 0,
    # so only fall back to a random seed when none was supplied.
    if seed is None:
        seed = random.randrange(0, max)

    return torch.Generator("cuda").manual_seed(seed), seed
31
+
32
+
33
def generate(
    prompt: str,
    sfw_retries: int = 1,
    seed: Optional[int] = None,
) -> PILImage:
    """
    Generate a sprite image from a text description.

    The pipeline is run up to `sfw_retries` times (at least once), switching
    to a different random seed after every attempt that is flagged as NSFW.
    The image of the last attempt is returned, so if every attempt is flagged
    the caller receives whatever the pipeline produced for the flagged output.

    `seed` seeds the first attempt; None lets `torchGenerator` pick one.
    """

    # Always make one attempt so the function never returns None,
    # even when called with `sfw_retries <= 0`.
    attempts = max(1, sfw_retries)
    current_seed = seed
    image: PILImage | None = None

    for attempt in range(attempts):
        generator = torchGenerator(current_seed)[0]
        pipe_output = pipe(prompt, generator=generator, width=512, height=512)
        image = pipe_output.images[0]

        if not pipe_output.nsfw_content_detected[0]:
            break

        # Nothing left to retry: skip drawing (and logging) an unused seed.
        if attempt + 1 == attempts:
            break

        print(f"Regenerating `{prompt}` with different seed.")

        # Draw until the seed differs from the one that was just flagged.
        flagged_seed = current_seed
        while current_seed == flagged_seed:
            current_seed = random.randrange(0, 1024)

    return cast(PILImage, image)
63
+
64
+
65
def generate_sides(
    prompt: str, sfw_retries: int = 1, sides: dict[str, str] = sprite_sides
) -> Tuple[dict[str, PILImage], str]:
    """
    Render one sprite image per requested side of the described character.

    All sides are generated from the same random seed. When `sides` contains
    both "left" and "right", only the left image is generated and the right
    image is obtained by mirroring it horizontally.

    Returns the images keyed by side, together with the unmodified prompt.
    """

    print(f"Generating sprites for `{prompt}`")

    shared_seed = random.randrange(0, 1024)
    mirror_right = "left" in sides and "right" in sides

    # Skip "right" here when it will be derived from "left" below.
    rendered = {
        name: generate(f"({prompt}) [nsfw] [photograph] {token}", sfw_retries, shared_seed)
        for name, token in sides.items()
        if not (mirror_right and name == "right")
    }

    if mirror_right:
        rendered["right"] = rendered["left"].transpose(Image.Transpose.FLIP_LEFT_RIGHT)

    return rendered, prompt
90
+
91
+
92
def clean_sprite(
    image: PILImage,
    size: Tuple[int, int] = (192, 192),
    sharpness: float = 1.5,
    thresh: int = 128,
    rescaling: Optional[int] = None,
) -> PILImage:
    """
    Make a generated image look more like a sprite.

    The image is sharpened by `sharpness`, converted to RGBA, and flood
    filled from its top-left pixel with transparent white using `thresh`
    as the fill tolerance. It is then resized to `size` with
    nearest-neighbour resampling.

    When `rescaling` is an int, the image is first scaled down by that
    factor before being scaled to `size`.
    """

    source_width, source_height = image.size

    sprite = ImageEnhance.Sharpness(image).enhance(sharpness).convert("RGBA")
    # Knock out the background region connected to the top-left corner.
    ImageDraw.floodfill(sprite, (0, 0), (255, 255, 255, 0), thresh=thresh)

    if type(rescaling) is int:
        reduced = (int(source_width / rescaling), int(source_height / rescaling))
        sprite = sprite.resize(reduced, resample=Image.Resampling.NEAREST)

    return sprite.resize(size, resample=Image.Resampling.NEAREST)
121
+
122
+
123
def split_sprites(image: PILImage, size: Tuple[int, int] = (96, 96)) -> List[PILImage]:
    """
    Cut a sprite image into four frames, each centred on its own canvas.

    The image is divided into four columns. From every column, the band
    between half the target height and 75% of the image height is cropped
    and pasted onto a transparent canvas of `size`, occupying the middle
    half of its width.
    """

    sheet_width, sheet_height = image.size
    w, h = size

    column = int(sheet_width / 4)
    top = int(h / 2)
    bottom = int(sheet_height * 0.75)

    # Left/right edges of each column; the last one runs to the full width.
    bounds = [
        (0, column),
        (column, column * 2),
        (column * 2, column * 3),
        (column * 3, sheet_width),
    ]

    blank = Image.new("RGBA", size, (255, 255, 255, 0))
    target_box = (int(w / 4), 0, int(w * 0.75), h)

    centred = []
    for left, right in bounds:
        tile = blank.copy()
        tile.paste(image.crop((left, top, right, bottom)), target_box)
        centred.append(tile)

    return centred
166
+
167
+
168
def build_spritesheet(
    images: dict[str, PILImage],
    text: str = "sd_pixelart",
    sprite_size: Tuple[int, int] = (96, 96),
    dir: str = "output",
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[PILImage, str | None]:
    """
    Build sprite sheet from sides.

    1. Clean and scale each image
    2. Split each image into individual frames
    3. Create a new spritesheet canvas for all sides[frames]
    4. Paste each individual frame onto canvas

    Each side becomes one row and each frame one column. `text` is stripped
    of characters other than word characters, brackets, `_` and `-` before
    being used in the file name.

    Returns the sheet as an image and, when `save` is true, the path it was
    written to (otherwise None).

    Raises ValueError when `images` is empty.
    """

    if not images:
        raise ValueError("`images` must contain at least one side")

    width, height = sprite_size
    text = re.sub(r"[^\w()[\]_-]", "", text)
    filepath = None

    frames = {
        side: split_sprites(
            clean_sprite(image, (width * 2, height * 2), thresh=thresh),
            sprite_size,
        )
        for side, image in images.items()
    }

    # Size the sheet from whichever side comes first rather than assuming
    # a "front" side exists.
    frame_count = len(next(iter(frames.values())))

    canvas = Image.new(
        "RGBA",
        (width * frame_count, height * len(frames)),
        (255, 255, 255, 0),
    )

    for row, side_frames in enumerate(frames.values()):
        for column, frame in enumerate(side_frames):
            canvas.paste(
                frame,
                (
                    column * width,
                    row * height,
                    column * width + width,
                    row * height + height,
                ),
            )

    spritesheet = io.BytesIO()
    canvas.save(spritesheet, "PNG")

    if save:
        if timestamp is None:
            timestamp = int(time.time())

        # The output directory may not exist on a fresh checkout.
        os.makedirs(dir, exist_ok=True)
        filepath = os.path.join(dir, f"{timestamp}_{text}.png")
        canvas.save(filepath)

    return Image.open(spritesheet), filepath
222
+
223
+
224
def build_gifs(
    images: dict[str, PILImage],
    text: str = "sd_spritesheet",
    dir: str = "output",
    duration: int | List[int] | Tuple[int, ...] = (300, 450, 300, 450),
    save: bool = False,
    timestamp: Optional[int] = None,
    thresh: int = 128,
) -> Tuple[dict[str, PILImage], List[str] | None]:
    """
    Build animated GIFs from side frames.

    Every image is cleaned, split into frames and encoded as a looping GIF
    whose frame timing is given by `duration`. `text` is stripped of
    characters other than word characters, brackets, `_` and `-` before
    being used in file names.

    Returns the opened GIF per side and, when `save` is true, the list of
    written file paths in the iteration order of `images` (otherwise None).
    """

    gifs = {}
    text = re.sub(r"[^\w()[\]_-]", "", text)
    filepaths = [] if save else None

    if save:
        if timestamp is None:
            timestamp = int(time.time())

        # The output directory may not exist on a fresh checkout.
        os.makedirs(dir, exist_ok=True)

    for side, image in images.items():
        frames = split_sprites(clean_sprite(image, thresh=thresh))

        gif = io.BytesIO()
        frames[0].save(
            gif,
            format="GIF",
            save_all=True,
            append_images=frames[1:],
            disposal=3,
            duration=duration,
            loop=0,
        )
        gifs[side] = Image.open(gif)

        if filepaths is not None:
            filepath = os.path.join(dir, f"{timestamp}_{text}_{side}.gif")

            # Reuse the bytes already encoded above instead of encoding twice.
            with open(filepath, "wb") as handle:
                handle.write(gif.getvalue())

            filepaths.append(filepath)

    return gifs, filepaths
requirements.txt ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -i https://pypi.org/simple
2
+ accelerate==0.14.0
3
+ aiohttp==3.8.3
4
+ aiosignal==1.3.1
5
+ anyio==3.6.2
6
+ async-timeout==4.0.2
7
+ attrs==22.1.0
8
+ bcrypt==4.0.1
9
+ certifi==2022.9.24
10
+ cffi==1.15.1
11
+ charset-normalizer==2.1.1
12
+ click==8.1.3
13
+ contourpy==1.0.6
14
+ cryptography==38.0.3
15
+ cycler==0.11.0
16
+ diffusers==0.7.2
17
+ fastapi==0.86.0
18
+ ffmpy==0.3.0
19
+ filelock==3.8.0
20
+ fonttools==4.38.0
21
+ frozenlist==1.3.3
22
+ fsspec==2022.11.0
23
+ gradio==3.9.1
24
+ h11==0.12.0
25
+ httpcore==0.15.0
26
+ httpx==0.23.0
27
+ huggingface-hub==0.10.1
28
+ idna==3.4
29
+ importlib-metadata==5.0.0
30
+ jinja2==3.1.2
31
+ kiwisolver==1.4.4
32
+ linkify-it-py==1.0.3
33
+ markdown-it-py==2.1.0
34
+ markupsafe==2.1.1
35
+ matplotlib==3.6.2
36
+ mdit-py-plugins==0.3.1
37
+ mdurl==0.1.2
38
+ multidict==6.0.2
39
+ numpy==1.23.4
40
+ nvidia-cublas-cu11==11.10.3.66
41
+ nvidia-cuda-nvrtc-cu11==11.7.99
42
+ nvidia-cuda-runtime-cu11==11.7.99
43
+ nvidia-cudnn-cu11==8.5.0.96
44
+ orjson==3.8.1
45
+ packaging==21.3
46
+ pandas==1.5.1
47
+ paramiko==2.12.0
48
+ pillow==9.3.0
49
+ psutil==5.9.4
50
+ pycparser==2.21
51
+ pycryptodome==3.15.0
52
+ pydantic==1.10.2
53
+ pydub==0.25.1
54
+ pynacl==1.5.0
55
+ pyparsing==3.0.9
56
+ python-dateutil==2.8.2
57
+ python-multipart==0.0.5
58
+ pytz==2022.6
59
+ pyyaml==6.0
60
+ regex==2022.10.31
61
+ requests==2.28.1
62
+ rfc3986==1.5.0
63
+ scipy==1.9.3
64
+ setuptools==65.5.1
65
+ six==1.16.0
66
+ sniffio==1.3.0
67
+ starlette==0.20.4
68
+ tokenizers==0.13.2
69
+ torch==1.13.0
70
+ torchvision==0.14.0
71
+ tqdm==4.64.1
72
+ transformers==4.24.0
73
+ typing-extensions==4.4.0
74
+ uc-micro-py==1.0.1
75
+ urllib3==1.26.12
76
+ uvicorn==0.19.0
77
+ websockets==10.4
78
+ wheel==0.38.4
79
+ yarl==1.8.1
80
+ zipp==3.10.0