lnyan commited on
Commit
09c675d
1 Parent(s): cbdeb43

Update files

Browse files
PyPatchMatch/.gitignore CHANGED
@@ -1,4 +1,4 @@
1
- /build/
2
- /*.so
3
- __pycache__
4
- *.py[cod]
 
1
+ /build/
2
+ /*.so
3
+ __pycache__
4
+ *.py[cod]
PyPatchMatch/LICENSE CHANGED
@@ -1,21 +1,21 @@
1
- MIT License
2
-
3
- Copyright (c) 2020 Jiayuan Mao
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jiayuan Mao
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
PyPatchMatch/Makefile CHANGED
@@ -1,54 +1,54 @@
1
- #
2
- # Makefile
3
- # Jiayuan Mao, 2019-01-09 13:59
4
- #
5
-
6
- SRC_DIR = csrc
7
- INC_DIR = csrc
8
- OBJ_DIR = build/obj
9
- TARGET = libpatchmatch.so
10
-
11
- LIB_TARGET = $(TARGET)
12
- INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
13
-
14
- CXX = $(ENVIRONMENT_OPTIONS) g++
15
- CXXFLAGS = -std=c++14
16
- CXXFLAGS += -Ofast -ffast-math -w
17
- # CXXFLAGS += -g
18
- CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
19
- CXXFLAGS += $(INCLUDE_DIR)
20
- LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
21
-
22
-
23
- CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
24
- OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
25
- DEPFILES = $(OBJS:.o=.d)
26
-
27
- .PHONY: all clean rebuild test
28
-
29
- all: $(LIB_TARGET)
30
-
31
- $(OBJ_DIR)/%.o: %.cpp
32
- @echo "[CC] $< ..."
33
- @$(CXX) -c $< $(CXXFLAGS) -o $@
34
-
35
- $(OBJ_DIR)/%.d: %.cpp
36
- @mkdir -pv $(dir $@)
37
- @echo "[dep] $< ..."
38
- @$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
39
-
40
- sinclude $(DEPFILES)
41
-
42
- $(LIB_TARGET): $(OBJS)
43
- @echo "[link] $(LIB_TARGET) ..."
44
- @$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
45
-
46
- clean:
47
- rm -rf $(OBJ_DIR) $(LIB_TARGET)
48
-
49
- rebuild:
50
- +@make clean
51
- +@make
52
-
53
- # vim:ft=make
54
- #
 
1
+ #
2
+ # Makefile
3
+ # Jiayuan Mao, 2019-01-09 13:59
4
+ #
5
+
6
+ SRC_DIR = csrc
7
+ INC_DIR = csrc
8
+ OBJ_DIR = build/obj
9
+ TARGET = libpatchmatch.so
10
+
11
+ LIB_TARGET = $(TARGET)
12
+ INCLUDE_DIR = -I $(SRC_DIR) -I $(INC_DIR)
13
+
14
+ CXX = $(ENVIRONMENT_OPTIONS) g++
15
+ CXXFLAGS = -std=c++14
16
+ CXXFLAGS += -Ofast -ffast-math -w
17
+ # CXXFLAGS += -g
18
+ CXXFLAGS += $(shell pkg-config --cflags opencv) -fPIC
19
+ CXXFLAGS += $(INCLUDE_DIR)
20
+ LDFLAGS = $(shell pkg-config --cflags --libs opencv) -shared -fPIC
21
+
22
+
23
+ CXXSOURCES = $(shell find $(SRC_DIR)/ -name "*.cpp")
24
+ OBJS = $(addprefix $(OBJ_DIR)/,$(CXXSOURCES:.cpp=.o))
25
+ DEPFILES = $(OBJS:.o=.d)
26
+
27
+ .PHONY: all clean rebuild test
28
+
29
+ all: $(LIB_TARGET)
30
+
31
+ $(OBJ_DIR)/%.o: %.cpp
32
+ @echo "[CC] $< ..."
33
+ @$(CXX) -c $< $(CXXFLAGS) -o $@
34
+
35
+ $(OBJ_DIR)/%.d: %.cpp
36
+ @mkdir -pv $(dir $@)
37
+ @echo "[dep] $< ..."
38
+ @$(CXX) $(INCLUDE_DIR) $(CXXFLAGS) -MM -MT "$(OBJ_DIR)/$(<:.cpp=.o) $(OBJ_DIR)/$(<:.cpp=.d)" "$<" > "$@"
39
+
40
+ sinclude $(DEPFILES)
41
+
42
+ $(LIB_TARGET): $(OBJS)
43
+ @echo "[link] $(LIB_TARGET) ..."
44
+ @$(CXX) $(OBJS) -o $@ $(CXXFLAGS) $(LDFLAGS)
45
+
46
+ clean:
47
+ rm -rf $(OBJ_DIR) $(LIB_TARGET)
48
+
49
+ rebuild:
50
+ +@make clean
51
+ +@make
52
+
53
+ # vim:ft=make
54
+ #
PyPatchMatch/README.md CHANGED
@@ -1,64 +1,64 @@
1
- PatchMatch based Inpainting
2
- =====================================
3
- This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
4
- This implementation is heavily based on the implementation by Younesse ANDAM:
5
- (younesse-cv/PatchMatch)[https://github.com/younesse-cv/PatchMatch], with some bugs fix.
6
-
7
- Usage
8
- -------------------------------------
9
-
10
- You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
11
- shared library `libpatchmatch.so`.
12
-
13
- For Python users (example available at `examples/py_example.py`)
14
-
15
- ```python
16
- import patch_match
17
-
18
- image = ... # either a numpy ndarray or a PIL Image object.
19
- mask = ... # either a numpy ndarray or a PIL Image object.
20
- result = patch_match.inpaint(image, mask, patch_size=5)
21
- ```
22
-
23
- For C++ users (examples available at `examples/cpp_example.cpp`)
24
-
25
- ```cpp
26
- #include "inpaint.h"
27
-
28
- int main() {
29
- cv::Mat image = ...
30
- cv::Mat mask = ...
31
-
32
- cv::Mat result = Inpainting(image, mask, 5).run();
33
-
34
- return 0;
35
- }
36
- ```
37
-
38
-
39
- README and COPYRIGHT by Younesse ANDAM
40
- -------------------------------------
41
- @Author: Younesse ANDAM
42
-
43
- @Contact: [email protected]
44
-
45
- Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
46
- The algorithm is presented in the following paper
47
- PatchMatch A Randomized Correspondence Algorithm
48
- for Structural Image Editing
49
- by C.Barnes,E.Shechtman,A.Finkelstein and Dan B.Goldman
50
- ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
51
-
52
- For more information please refer to
53
- http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
54
-
55
- Copyright (c) 2010-2011
56
-
57
-
58
- Requirements
59
- -------------------------------------
60
-
61
- To run the project you need to install Opencv library and link it to your project.
62
- Opencv can be download it here
63
- http://opencv.org/downloads.html
64
-
 
1
+ PatchMatch based Inpainting
2
+ =====================================
3
+ This library implements the PatchMatch based inpainting algorithm. It provides both C++ and Python interfaces.
4
+ This implementation is heavily based on the implementation by Younesse ANDAM:
5
+ (younesse-cv/PatchMatch)[https://github.com/younesse-cv/PatchMatch], with some bugs fix.
6
+
7
+ Usage
8
+ -------------------------------------
9
+
10
+ You need to first install OpenCV to compile the C++ libraries. Then, run `make` to compile the
11
+ shared library `libpatchmatch.so`.
12
+
13
+ For Python users (example available at `examples/py_example.py`)
14
+
15
+ ```python
16
+ import patch_match
17
+
18
+ image = ... # either a numpy ndarray or a PIL Image object.
19
+ mask = ... # either a numpy ndarray or a PIL Image object.
20
+ result = patch_match.inpaint(image, mask, patch_size=5)
21
+ ```
22
+
23
+ For C++ users (examples available at `examples/cpp_example.cpp`)
24
+
25
+ ```cpp
26
+ #include "inpaint.h"
27
+
28
+ int main() {
29
+ cv::Mat image = ...
30
+ cv::Mat mask = ...
31
+
32
+ cv::Mat result = Inpainting(image, mask, 5).run();
33
+
34
+ return 0;
35
+ }
36
+ ```
37
+
38
+
39
+ README and COPYRIGHT by Younesse ANDAM
40
+ -------------------------------------
41
+ @Author: Younesse ANDAM
42
+
43
+ @Contact: [email protected]
44
+
45
+ Description: This project is a personal implementation of an algorithm called PATCHMATCH that restores missing areas in an image.
46
+ The algorithm is presented in the following paper
47
+ PatchMatch A Randomized Correspondence Algorithm
48
+ for Structural Image Editing
49
+ by C.Barnes,E.Shechtman,A.Finkelstein and Dan B.Goldman
50
+ ACM Transactions on Graphics (Proc. SIGGRAPH), vol.28, aug-2009
51
+
52
+ For more information please refer to
53
+ http://www.cs.princeton.edu/gfx/pubs/Barnes_2009_PAR/index.php
54
+
55
+ Copyright (c) 2010-2011
56
+
57
+
58
+ Requirements
59
+ -------------------------------------
60
+
61
+ To run the project you need to install Opencv library and link it to your project.
62
+ Opencv can be download it here
63
+ http://opencv.org/downloads.html
64
+
PyPatchMatch/csrc/inpaint.cpp CHANGED
@@ -1,234 +1,234 @@
1
- #include <algorithm>
2
- #include <iostream>
3
- #include <opencv2/imgcodecs.hpp>
4
- #include <opencv2/imgproc.hpp>
5
- #include <opencv2/highgui.hpp>
6
-
7
- #include "inpaint.h"
8
-
9
- namespace {
10
- static std::vector<double> kDistance2Similarity;
11
-
12
- void init_kDistance2Similarity() {
13
- double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
14
- int length = (PatchDistanceMetric::kDistanceScale + 1);
15
- kDistance2Similarity.resize(length);
16
- for (int i = 0; i < length; ++i) {
17
- double t = (double) i / length;
18
- int j = (int) (100 * t);
19
- int k = j + 1;
20
- double vj = (j < 11) ? base[j] : 0;
21
- double vk = (k < 11) ? base[k] : 0;
22
- kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
23
- }
24
- }
25
-
26
-
27
- inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
28
- if (source.is_masked(ys, xs)) return;
29
- if (source.is_globally_masked(ys, xs)) return;
30
-
31
- auto source_ptr = source.get_image(ys, xs);
32
- auto target_ptr = target.ptr<double>(yt, xt);
33
-
34
- #pragma unroll
35
- for (int c = 0; c < 3; ++c)
36
- target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
37
- target_ptr[3] += weight;
38
- }
39
- }
40
-
41
- /**
42
- * This algorithme uses a version proposed by Xavier Philippeau.
43
- */
44
-
45
- Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
46
- : m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
47
- _initialize_pyramid();
48
- }
49
-
50
- Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
51
- : m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
52
- _initialize_pyramid();
53
- }
54
-
55
- void Inpainting::_initialize_pyramid() {
56
- auto source = m_initial;
57
- m_pyramid.push_back(source);
58
- while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
59
- source = source.downsample();
60
- m_pyramid.push_back(source);
61
- }
62
-
63
- if (kDistance2Similarity.size() == 0) {
64
- init_kDistance2Similarity();
65
- }
66
- }
67
-
68
- cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
69
- srand(random_seed);
70
- const int nr_levels = m_pyramid.size();
71
-
72
- MaskedImage source, target;
73
- for (int level = nr_levels - 1; level >= 0; --level) {
74
- if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
75
-
76
- source = m_pyramid[level];
77
-
78
- if (level == nr_levels - 1) {
79
- target = source.clone();
80
- target.clear_mask();
81
- m_source2target = NearestNeighborField(source, target, m_distance_metric);
82
- m_target2source = NearestNeighborField(target, source, m_distance_metric);
83
- } else {
84
- m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
85
- m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
86
- }
87
-
88
- if (verbose) std::cerr << "Initialization done." << std::endl;
89
-
90
- if (verbose_visualize) {
91
- auto visualize_size = m_initial.size();
92
- cv::Mat source_visualize(visualize_size, m_initial.image().type());
93
- cv::resize(source.image(), source_visualize, visualize_size);
94
- cv::imshow("Source", source_visualize);
95
- cv::Mat target_visualize(visualize_size, m_initial.image().type());
96
- cv::resize(target.image(), target_visualize, visualize_size);
97
- cv::imshow("Target", target_visualize);
98
- cv::waitKey(0);
99
- }
100
-
101
- target = _expectation_maximization(source, target, level, verbose);
102
- }
103
-
104
- return target.image();
105
- }
106
-
107
- // EM-Like algorithm (see "PatchMatch" - page 6).
108
- // Returns a double sized target image (unless level = 0).
109
- MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
110
- const int nr_iters_em = 1 + 2 * level;
111
- const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
112
- const int patch_size = m_distance_metric->patch_size();
113
-
114
- MaskedImage new_source, new_target;
115
-
116
- for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
117
- if (iter_em != 0) {
118
- m_source2target.set_target(new_target);
119
- m_target2source.set_source(new_target);
120
- target = new_target;
121
- }
122
-
123
- if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
124
-
125
- auto size = source.size();
126
- for (int i = 0; i < size.height; ++i) {
127
- for (int j = 0; j < size.width; ++j) {
128
- if (!source.contains_mask(i, j, patch_size)) {
129
- m_source2target.set_identity(i, j);
130
- m_target2source.set_identity(i, j);
131
- }
132
- }
133
- }
134
- if (verbose) std::cerr << " NNF minimization started." << std::endl;
135
- m_source2target.minimize(nr_iters_nnf);
136
- m_target2source.minimize(nr_iters_nnf);
137
- if (verbose) std::cerr << " NNF minimization finished." << std::endl;
138
-
139
- // Instead of upsizing the final target, we build the last target from the next level source image.
140
- // Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
141
- bool upscaled = false;
142
- if (level >= 1 && iter_em == nr_iters_em - 1) {
143
- new_source = m_pyramid[level - 1];
144
- new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
145
- upscaled = true;
146
- } else {
147
- new_source = m_pyramid[level];
148
- new_target = target.clone();
149
- }
150
-
151
- auto vote = cv::Mat(new_target.size(), CV_64FC4);
152
- vote.setTo(cv::Scalar::all(0));
153
-
154
- // Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
155
- _expectation_step(m_source2target, 1, vote, new_source, upscaled);
156
- if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
157
- _expectation_step(m_target2source, 0, vote, new_source, upscaled);
158
- if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
159
-
160
- // Compile votes and update pixel values.
161
- _maximization_step(new_target, vote);
162
- if (verbose) std::cerr << " Minimization step finished." << std::endl;
163
- }
164
-
165
- return new_target;
166
- }
167
-
168
- // Expectation step: vote for best estimations of each pixel.
169
- void Inpainting::_expectation_step(
170
- const NearestNeighborField &nnf, bool source2target,
171
- cv::Mat &vote, const MaskedImage &source, bool upscaled
172
- ) {
173
- auto source_size = nnf.source_size();
174
- auto target_size = nnf.target_size();
175
- const int patch_size = m_distance_metric->patch_size();
176
-
177
- for (int i = 0; i < source_size.height; ++i) {
178
- for (int j = 0; j < source_size.width; ++j) {
179
- if (nnf.source().is_globally_masked(i, j)) continue;
180
- int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
181
- double w = kDistance2Similarity[dp];
182
-
183
- for (int di = -patch_size; di <= patch_size; ++di) {
184
- for (int dj = -patch_size; dj <= patch_size; ++dj) {
185
- int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
186
- if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
187
- if (nnf.source().is_globally_masked(ys, xs)) continue;
188
- if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
189
- if (nnf.target().is_globally_masked(yt, xt)) continue;
190
-
191
- if (!source2target) {
192
- std::swap(ys, yt);
193
- std::swap(xs, xt);
194
- }
195
-
196
- if (upscaled) {
197
- for (int uy = 0; uy < 2; ++uy) {
198
- for (int ux = 0; ux < 2; ++ux) {
199
- _weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
200
- }
201
- }
202
- } else {
203
- _weighted_copy(source, ys, xs, vote, yt, xt, w);
204
- }
205
- }
206
- }
207
- }
208
- }
209
- }
210
-
211
- // Maximization Step: maximum likelihood of target pixel.
212
- void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
213
- auto target_size = target.size();
214
- for (int i = 0; i < target_size.height; ++i) {
215
- for (int j = 0; j < target_size.width; ++j) {
216
- const double *source_ptr = vote.ptr<double>(i, j);
217
- unsigned char *target_ptr = target.get_mutable_image(i, j);
218
-
219
- if (target.is_globally_masked(i, j)) {
220
- continue;
221
- }
222
-
223
- if (source_ptr[3] > 0) {
224
- unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
225
- unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
226
- unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
227
- target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
228
- } else {
229
- target.set_mask(i, j, 0);
230
- }
231
- }
232
- }
233
- }
234
-
 
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <opencv2/imgcodecs.hpp>
4
+ #include <opencv2/imgproc.hpp>
5
+ #include <opencv2/highgui.hpp>
6
+
7
+ #include "inpaint.h"
8
+
9
+ namespace {
10
+ static std::vector<double> kDistance2Similarity;
11
+
12
+ void init_kDistance2Similarity() {
13
+ double base[11] = {1.0, 0.99, 0.96, 0.83, 0.38, 0.11, 0.02, 0.005, 0.0006, 0.0001, 0};
14
+ int length = (PatchDistanceMetric::kDistanceScale + 1);
15
+ kDistance2Similarity.resize(length);
16
+ for (int i = 0; i < length; ++i) {
17
+ double t = (double) i / length;
18
+ int j = (int) (100 * t);
19
+ int k = j + 1;
20
+ double vj = (j < 11) ? base[j] : 0;
21
+ double vk = (k < 11) ? base[k] : 0;
22
+ kDistance2Similarity[i] = vj + (100 * t - j) * (vk - vj);
23
+ }
24
+ }
25
+
26
+
27
+ inline void _weighted_copy(const MaskedImage &source, int ys, int xs, cv::Mat &target, int yt, int xt, double weight) {
28
+ if (source.is_masked(ys, xs)) return;
29
+ if (source.is_globally_masked(ys, xs)) return;
30
+
31
+ auto source_ptr = source.get_image(ys, xs);
32
+ auto target_ptr = target.ptr<double>(yt, xt);
33
+
34
+ #pragma unroll
35
+ for (int c = 0; c < 3; ++c)
36
+ target_ptr[c] += static_cast<double>(source_ptr[c]) * weight;
37
+ target_ptr[3] += weight;
38
+ }
39
+ }
40
+
41
+ /**
42
+ * This algorithme uses a version proposed by Xavier Philippeau.
43
+ */
44
+
45
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric)
46
+ : m_initial(image, mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
47
+ _initialize_pyramid();
48
+ }
49
+
50
+ Inpainting::Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric)
51
+ : m_initial(image, mask, global_mask), m_distance_metric(metric), m_pyramid(), m_source2target(), m_target2source() {
52
+ _initialize_pyramid();
53
+ }
54
+
55
+ void Inpainting::_initialize_pyramid() {
56
+ auto source = m_initial;
57
+ m_pyramid.push_back(source);
58
+ while (source.size().height > m_distance_metric->patch_size() && source.size().width > m_distance_metric->patch_size()) {
59
+ source = source.downsample();
60
+ m_pyramid.push_back(source);
61
+ }
62
+
63
+ if (kDistance2Similarity.size() == 0) {
64
+ init_kDistance2Similarity();
65
+ }
66
+ }
67
+
68
+ cv::Mat Inpainting::run(bool verbose, bool verbose_visualize, unsigned int random_seed) {
69
+ srand(random_seed);
70
+ const int nr_levels = m_pyramid.size();
71
+
72
+ MaskedImage source, target;
73
+ for (int level = nr_levels - 1; level >= 0; --level) {
74
+ if (verbose) std::cerr << "Inpainting level: " << level << std::endl;
75
+
76
+ source = m_pyramid[level];
77
+
78
+ if (level == nr_levels - 1) {
79
+ target = source.clone();
80
+ target.clear_mask();
81
+ m_source2target = NearestNeighborField(source, target, m_distance_metric);
82
+ m_target2source = NearestNeighborField(target, source, m_distance_metric);
83
+ } else {
84
+ m_source2target = NearestNeighborField(source, target, m_distance_metric, m_source2target);
85
+ m_target2source = NearestNeighborField(target, source, m_distance_metric, m_target2source);
86
+ }
87
+
88
+ if (verbose) std::cerr << "Initialization done." << std::endl;
89
+
90
+ if (verbose_visualize) {
91
+ auto visualize_size = m_initial.size();
92
+ cv::Mat source_visualize(visualize_size, m_initial.image().type());
93
+ cv::resize(source.image(), source_visualize, visualize_size);
94
+ cv::imshow("Source", source_visualize);
95
+ cv::Mat target_visualize(visualize_size, m_initial.image().type());
96
+ cv::resize(target.image(), target_visualize, visualize_size);
97
+ cv::imshow("Target", target_visualize);
98
+ cv::waitKey(0);
99
+ }
100
+
101
+ target = _expectation_maximization(source, target, level, verbose);
102
+ }
103
+
104
+ return target.image();
105
+ }
106
+
107
+ // EM-Like algorithm (see "PatchMatch" - page 6).
108
+ // Returns a double sized target image (unless level = 0).
109
+ MaskedImage Inpainting::_expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose) {
110
+ const int nr_iters_em = 1 + 2 * level;
111
+ const int nr_iters_nnf = static_cast<int>(std::min(7, 1 + level));
112
+ const int patch_size = m_distance_metric->patch_size();
113
+
114
+ MaskedImage new_source, new_target;
115
+
116
+ for (int iter_em = 0; iter_em < nr_iters_em; ++iter_em) {
117
+ if (iter_em != 0) {
118
+ m_source2target.set_target(new_target);
119
+ m_target2source.set_source(new_target);
120
+ target = new_target;
121
+ }
122
+
123
+ if (verbose) std::cerr << "EM Iteration: " << iter_em << std::endl;
124
+
125
+ auto size = source.size();
126
+ for (int i = 0; i < size.height; ++i) {
127
+ for (int j = 0; j < size.width; ++j) {
128
+ if (!source.contains_mask(i, j, patch_size)) {
129
+ m_source2target.set_identity(i, j);
130
+ m_target2source.set_identity(i, j);
131
+ }
132
+ }
133
+ }
134
+ if (verbose) std::cerr << " NNF minimization started." << std::endl;
135
+ m_source2target.minimize(nr_iters_nnf);
136
+ m_target2source.minimize(nr_iters_nnf);
137
+ if (verbose) std::cerr << " NNF minimization finished." << std::endl;
138
+
139
+ // Instead of upsizing the final target, we build the last target from the next level source image.
140
+ // Thus, the final target is less blurry (see "Space-Time Video Completion" - page 5).
141
+ bool upscaled = false;
142
+ if (level >= 1 && iter_em == nr_iters_em - 1) {
143
+ new_source = m_pyramid[level - 1];
144
+ new_target = target.upsample(new_source.size().width, new_source.size().height, m_pyramid[level - 1].global_mask());
145
+ upscaled = true;
146
+ } else {
147
+ new_source = m_pyramid[level];
148
+ new_target = target.clone();
149
+ }
150
+
151
+ auto vote = cv::Mat(new_target.size(), CV_64FC4);
152
+ vote.setTo(cv::Scalar::all(0));
153
+
154
+ // Votes for best patch from NNF Source->Target (completeness) and Target->Source (coherence).
155
+ _expectation_step(m_source2target, 1, vote, new_source, upscaled);
156
+ if (verbose) std::cerr << " Expectation source to target finished." << std::endl;
157
+ _expectation_step(m_target2source, 0, vote, new_source, upscaled);
158
+ if (verbose) std::cerr << " Expectation target to source finished." << std::endl;
159
+
160
+ // Compile votes and update pixel values.
161
+ _maximization_step(new_target, vote);
162
+ if (verbose) std::cerr << " Minimization step finished." << std::endl;
163
+ }
164
+
165
+ return new_target;
166
+ }
167
+
168
+ // Expectation step: vote for best estimations of each pixel.
169
+ void Inpainting::_expectation_step(
170
+ const NearestNeighborField &nnf, bool source2target,
171
+ cv::Mat &vote, const MaskedImage &source, bool upscaled
172
+ ) {
173
+ auto source_size = nnf.source_size();
174
+ auto target_size = nnf.target_size();
175
+ const int patch_size = m_distance_metric->patch_size();
176
+
177
+ for (int i = 0; i < source_size.height; ++i) {
178
+ for (int j = 0; j < source_size.width; ++j) {
179
+ if (nnf.source().is_globally_masked(i, j)) continue;
180
+ int yp = nnf.at(i, j, 0), xp = nnf.at(i, j, 1), dp = nnf.at(i, j, 2);
181
+ double w = kDistance2Similarity[dp];
182
+
183
+ for (int di = -patch_size; di <= patch_size; ++di) {
184
+ for (int dj = -patch_size; dj <= patch_size; ++dj) {
185
+ int ys = i + di, xs = j + dj, yt = yp + di, xt = xp + dj;
186
+ if (!(ys >= 0 && ys < source_size.height && xs >= 0 && xs < source_size.width)) continue;
187
+ if (nnf.source().is_globally_masked(ys, xs)) continue;
188
+ if (!(yt >= 0 && yt < target_size.height && xt >= 0 && xt < target_size.width)) continue;
189
+ if (nnf.target().is_globally_masked(yt, xt)) continue;
190
+
191
+ if (!source2target) {
192
+ std::swap(ys, yt);
193
+ std::swap(xs, xt);
194
+ }
195
+
196
+ if (upscaled) {
197
+ for (int uy = 0; uy < 2; ++uy) {
198
+ for (int ux = 0; ux < 2; ++ux) {
199
+ _weighted_copy(source, 2 * ys + uy, 2 * xs + ux, vote, 2 * yt + uy, 2 * xt + ux, w);
200
+ }
201
+ }
202
+ } else {
203
+ _weighted_copy(source, ys, xs, vote, yt, xt, w);
204
+ }
205
+ }
206
+ }
207
+ }
208
+ }
209
+ }
210
+
211
+ // Maximization Step: maximum likelihood of target pixel.
212
+ void Inpainting::_maximization_step(MaskedImage &target, const cv::Mat &vote) {
213
+ auto target_size = target.size();
214
+ for (int i = 0; i < target_size.height; ++i) {
215
+ for (int j = 0; j < target_size.width; ++j) {
216
+ const double *source_ptr = vote.ptr<double>(i, j);
217
+ unsigned char *target_ptr = target.get_mutable_image(i, j);
218
+
219
+ if (target.is_globally_masked(i, j)) {
220
+ continue;
221
+ }
222
+
223
+ if (source_ptr[3] > 0) {
224
+ unsigned char r = cv::saturate_cast<unsigned char>(source_ptr[0] / source_ptr[3]);
225
+ unsigned char g = cv::saturate_cast<unsigned char>(source_ptr[1] / source_ptr[3]);
226
+ unsigned char b = cv::saturate_cast<unsigned char>(source_ptr[2] / source_ptr[3]);
227
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
228
+ } else {
229
+ target.set_mask(i, j, 0);
230
+ }
231
+ }
232
+ }
233
+ }
234
+
PyPatchMatch/csrc/inpaint.h CHANGED
@@ -1,27 +1,27 @@
1
- #pragma once
2
-
3
- #include <vector>
4
-
5
- #include "masked_image.h"
6
- #include "nnf.h"
7
-
8
- class Inpainting {
9
- public:
10
- Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
11
- Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
12
- cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
13
-
14
- private:
15
- void _initialize_pyramid(void);
16
- MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
17
- void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
18
- void _maximization_step(MaskedImage &target, const cv::Mat &vote);
19
-
20
- MaskedImage m_initial;
21
- std::vector<MaskedImage> m_pyramid;
22
-
23
- NearestNeighborField m_source2target;
24
- NearestNeighborField m_target2source;
25
- const PatchDistanceMetric *m_distance_metric;
26
- };
27
-
 
1
+ #pragma once
2
+
3
+ #include <vector>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ class Inpainting {
9
+ public:
10
+ Inpainting(cv::Mat image, cv::Mat mask, const PatchDistanceMetric *metric);
11
+ Inpainting(cv::Mat image, cv::Mat mask, cv::Mat global_mask, const PatchDistanceMetric *metric);
12
+ cv::Mat run(bool verbose = false, bool verbose_visualize = false, unsigned int random_seed = 1212);
13
+
14
+ private:
15
+ void _initialize_pyramid(void);
16
+ MaskedImage _expectation_maximization(MaskedImage source, MaskedImage target, int level, bool verbose);
17
+ void _expectation_step(const NearestNeighborField &nnf, bool source2target, cv::Mat &vote, const MaskedImage &source, bool upscaled);
18
+ void _maximization_step(MaskedImage &target, const cv::Mat &vote);
19
+
20
+ MaskedImage m_initial;
21
+ std::vector<MaskedImage> m_pyramid;
22
+
23
+ NearestNeighborField m_source2target;
24
+ NearestNeighborField m_target2source;
25
+ const PatchDistanceMetric *m_distance_metric;
26
+ };
27
+
PyPatchMatch/csrc/masked_image.cpp CHANGED
@@ -1,138 +1,138 @@
1
- #include "masked_image.h"
2
- #include <algorithm>
3
- #include <iostream>
4
-
5
- const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
6
- const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
7
-
8
- bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
9
- auto mask_size = size();
10
- for (int dy = -patch_size; dy <= patch_size; ++dy) {
11
- for (int dx = -patch_size; dx <= patch_size; ++dx) {
12
- int yy = y + dy, xx = x + dx;
13
- if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
14
- if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
15
- }
16
- }
17
- }
18
- return false;
19
- }
20
-
21
- MaskedImage MaskedImage::downsample() const {
22
- const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
23
- const auto &kernel = MaskedImage::kDownsampleKernel;
24
-
25
- const auto size = this->size();
26
- const auto new_size = cv::Size(size.width / 2, size.height / 2);
27
-
28
- auto ret = MaskedImage(new_size.width, new_size.height);
29
- if (!m_global_mask.empty()) ret.init_global_mask_mat();
30
- for (int y = 0; y < size.height - 1; y += 2) {
31
- for (int x = 0; x < size.width - 1; x += 2) {
32
- int r = 0, g = 0, b = 0, ksum = 0;
33
- bool is_gmasked = true;
34
-
35
- for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
36
- for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
37
- int yy = y + dy, xx = x + dx;
38
- if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
39
- if (!is_globally_masked(yy, xx)) {
40
- is_gmasked = false;
41
- }
42
- if (!is_masked(yy, xx)) {
43
- auto source_ptr = get_image(yy, xx);
44
- int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
45
- r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
46
- ksum += k;
47
- }
48
- }
49
- }
50
- }
51
-
52
- if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
53
-
54
- if (!m_global_mask.empty()) {
55
- ret.set_global_mask(y / 2, x / 2, is_gmasked);
56
- }
57
- if (ksum > 0) {
58
- auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
59
- target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
60
- ret.set_mask(y / 2, x / 2, 0);
61
- } else {
62
- ret.set_mask(y / 2, x / 2, 1);
63
- }
64
- }
65
- }
66
-
67
- return ret;
68
- }
69
-
70
- MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
71
- const auto size = this->size();
72
- auto ret = MaskedImage(new_w, new_h);
73
- if (!m_global_mask.empty()) ret.init_global_mask_mat();
74
- for (int y = 0; y < new_h; ++y) {
75
- for (int x = 0; x < new_w; ++x) {
76
- int yy = y * size.height / new_h;
77
- int xx = x * size.width / new_w;
78
-
79
- if (is_globally_masked(yy, xx)) {
80
- ret.set_global_mask(y, x, 1);
81
- ret.set_mask(y, x, 1);
82
- } else {
83
- if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
84
-
85
- if (is_masked(yy, xx)) {
86
- ret.set_mask(y, x, 1);
87
- } else {
88
- auto source_ptr = get_image(yy, xx);
89
- auto target_ptr = ret.get_mutable_image(y, x);
90
- for (int c = 0; c < 3; ++c)
91
- target_ptr[c] = source_ptr[c];
92
- ret.set_mask(y, x, 0);
93
- }
94
- }
95
- }
96
- }
97
-
98
- return ret;
99
- }
100
-
101
- MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
102
- auto ret = upsample(new_w, new_h);
103
- ret.set_global_mask_mat(new_global_mask);
104
- return ret;
105
- }
106
-
107
- void MaskedImage::compute_image_gradients() {
108
- if (m_image_grad_computed) {
109
- return;
110
- }
111
-
112
- const auto size = m_image.size();
113
- m_image_grady = cv::Mat(size, CV_8UC3);
114
- m_image_gradx = cv::Mat(size, CV_8UC3);
115
- m_image_grady = cv::Scalar::all(0);
116
- m_image_gradx = cv::Scalar::all(0);
117
-
118
- for (int i = 1; i < size.height - 1; ++i) {
119
- const auto *ptr = m_image.ptr<unsigned char>(i, 0);
120
- const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
121
- const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
122
- const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
123
- const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
124
- auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
125
- auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
126
- for (int j = 3; j < size.width * 3 - 3; ++j) {
127
- mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
128
- mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
129
- }
130
- }
131
-
132
- m_image_grad_computed = true;
133
- }
134
-
135
- void MaskedImage::compute_image_gradients() const {
136
- const_cast<MaskedImage *>(this)->compute_image_gradients();
137
- }
138
-
 
1
+ #include "masked_image.h"
2
+ #include <algorithm>
3
+ #include <iostream>
4
+
5
+ const cv::Size MaskedImage::kDownsampleKernelSize = cv::Size(6, 6);
6
+ const int MaskedImage::kDownsampleKernel[6] = {1, 5, 10, 10, 5, 1};
7
+
8
+ bool MaskedImage::contains_mask(int y, int x, int patch_size) const {
9
+ auto mask_size = size();
10
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
11
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
12
+ int yy = y + dy, xx = x + dx;
13
+ if (yy >= 0 && yy < mask_size.height && xx >= 0 && xx < mask_size.width) {
14
+ if (is_masked(yy, xx) && !is_globally_masked(yy, xx)) return true;
15
+ }
16
+ }
17
+ }
18
+ return false;
19
+ }
20
+
21
+ MaskedImage MaskedImage::downsample() const {
22
+ const auto &kernel_size = MaskedImage::kDownsampleKernelSize;
23
+ const auto &kernel = MaskedImage::kDownsampleKernel;
24
+
25
+ const auto size = this->size();
26
+ const auto new_size = cv::Size(size.width / 2, size.height / 2);
27
+
28
+ auto ret = MaskedImage(new_size.width, new_size.height);
29
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
30
+ for (int y = 0; y < size.height - 1; y += 2) {
31
+ for (int x = 0; x < size.width - 1; x += 2) {
32
+ int r = 0, g = 0, b = 0, ksum = 0;
33
+ bool is_gmasked = true;
34
+
35
+ for (int dy = -kernel_size.height / 2 + 1; dy <= kernel_size.height / 2; ++dy) {
36
+ for (int dx = -kernel_size.width / 2 + 1; dx <= kernel_size.width / 2; ++dx) {
37
+ int yy = y + dy, xx = x + dx;
38
+ if (yy >= 0 && yy < size.height && xx >= 0 && xx < size.width) {
39
+ if (!is_globally_masked(yy, xx)) {
40
+ is_gmasked = false;
41
+ }
42
+ if (!is_masked(yy, xx)) {
43
+ auto source_ptr = get_image(yy, xx);
44
+ int k = kernel[kernel_size.height / 2 - 1 + dy] * kernel[kernel_size.width / 2 - 1 + dx];
45
+ r += source_ptr[0] * k, g += source_ptr[1] * k, b += source_ptr[2] * k;
46
+ ksum += k;
47
+ }
48
+ }
49
+ }
50
+ }
51
+
52
+ if (ksum > 0) r /= ksum, g /= ksum, b /= ksum;
53
+
54
+ if (!m_global_mask.empty()) {
55
+ ret.set_global_mask(y / 2, x / 2, is_gmasked);
56
+ }
57
+ if (ksum > 0) {
58
+ auto target_ptr = ret.get_mutable_image(y / 2, x / 2);
59
+ target_ptr[0] = r, target_ptr[1] = g, target_ptr[2] = b;
60
+ ret.set_mask(y / 2, x / 2, 0);
61
+ } else {
62
+ ret.set_mask(y / 2, x / 2, 1);
63
+ }
64
+ }
65
+ }
66
+
67
+ return ret;
68
+ }
69
+
70
+ MaskedImage MaskedImage::upsample(int new_w, int new_h) const {
71
+ const auto size = this->size();
72
+ auto ret = MaskedImage(new_w, new_h);
73
+ if (!m_global_mask.empty()) ret.init_global_mask_mat();
74
+ for (int y = 0; y < new_h; ++y) {
75
+ for (int x = 0; x < new_w; ++x) {
76
+ int yy = y * size.height / new_h;
77
+ int xx = x * size.width / new_w;
78
+
79
+ if (is_globally_masked(yy, xx)) {
80
+ ret.set_global_mask(y, x, 1);
81
+ ret.set_mask(y, x, 1);
82
+ } else {
83
+ if (!m_global_mask.empty()) ret.set_global_mask(y, x, 0);
84
+
85
+ if (is_masked(yy, xx)) {
86
+ ret.set_mask(y, x, 1);
87
+ } else {
88
+ auto source_ptr = get_image(yy, xx);
89
+ auto target_ptr = ret.get_mutable_image(y, x);
90
+ for (int c = 0; c < 3; ++c)
91
+ target_ptr[c] = source_ptr[c];
92
+ ret.set_mask(y, x, 0);
93
+ }
94
+ }
95
+ }
96
+ }
97
+
98
+ return ret;
99
+ }
100
+
101
+ MaskedImage MaskedImage::upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const {
102
+ auto ret = upsample(new_w, new_h);
103
+ ret.set_global_mask_mat(new_global_mask);
104
+ return ret;
105
+ }
106
+
107
+ void MaskedImage::compute_image_gradients() {
108
+ if (m_image_grad_computed) {
109
+ return;
110
+ }
111
+
112
+ const auto size = m_image.size();
113
+ m_image_grady = cv::Mat(size, CV_8UC3);
114
+ m_image_gradx = cv::Mat(size, CV_8UC3);
115
+ m_image_grady = cv::Scalar::all(0);
116
+ m_image_gradx = cv::Scalar::all(0);
117
+
118
+ for (int i = 1; i < size.height - 1; ++i) {
119
+ const auto *ptr = m_image.ptr<unsigned char>(i, 0);
120
+ const auto *ptry1 = m_image.ptr<unsigned char>(i + 1, 0);
121
+ const auto *ptry2 = m_image.ptr<unsigned char>(i - 1, 0);
122
+ const auto *ptrx1 = m_image.ptr<unsigned char>(i, 0) + 3;
123
+ const auto *ptrx2 = m_image.ptr<unsigned char>(i, 0) - 3;
124
+ auto *mptry = m_image_grady.ptr<unsigned char>(i, 0);
125
+ auto *mptrx = m_image_gradx.ptr<unsigned char>(i, 0);
126
+ for (int j = 3; j < size.width * 3 - 3; ++j) {
127
+ mptry[j] = (ptry1[j] / 2 - ptry2[j] / 2) + 128;
128
+ mptrx[j] = (ptrx1[j] / 2 - ptrx2[j] / 2) + 128;
129
+ }
130
+ }
131
+
132
+ m_image_grad_computed = true;
133
+ }
134
+
135
+ void MaskedImage::compute_image_gradients() const {
136
+ const_cast<MaskedImage *>(this)->compute_image_gradients();
137
+ }
138
+
PyPatchMatch/csrc/masked_image.h CHANGED
@@ -1,112 +1,112 @@
1
- #pragma once
2
-
3
- #include <opencv2/core.hpp>
4
-
5
- class MaskedImage {
6
- public:
7
- MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
8
- // pass
9
- }
10
- MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
11
- // pass
12
- }
13
- MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
14
- // pass
15
- }
16
- MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
17
- m_image(image), m_mask(mask), m_global_mask(global_mask),
18
- m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
19
- // pass
20
- }
21
- MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
22
- m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
23
- m_image = cv::Scalar::all(0);
24
-
25
- m_mask = cv::Mat(cv::Size(width, height), CV_8U);
26
- m_mask = cv::Scalar::all(0);
27
- }
28
- inline MaskedImage clone() {
29
- return MaskedImage(
30
- m_image.clone(), m_mask.clone(), m_global_mask.clone(),
31
- m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
32
- );
33
- }
34
-
35
- inline cv::Size size() const {
36
- return m_image.size();
37
- }
38
- inline const cv::Mat &image() const {
39
- return m_image;
40
- }
41
- inline const cv::Mat &mask() const {
42
- return m_mask;
43
- }
44
- inline const cv::Mat &global_mask() const {
45
- return m_global_mask;
46
- }
47
- inline const cv::Mat &grady() const {
48
- assert(m_image_grad_computed);
49
- return m_image_grady;
50
- }
51
- inline const cv::Mat &gradx() const {
52
- assert(m_image_grad_computed);
53
- return m_image_gradx;
54
- }
55
-
56
- inline void init_global_mask_mat() {
57
- m_global_mask = cv::Mat(m_mask.size(), CV_8U);
58
- m_global_mask.setTo(cv::Scalar(0));
59
- }
60
- inline void set_global_mask_mat(const cv::Mat &other) {
61
- m_global_mask = other;
62
- }
63
-
64
- inline bool is_masked(int y, int x) const {
65
- return static_cast<bool>(m_mask.at<unsigned char>(y, x));
66
- }
67
- inline bool is_globally_masked(int y, int x) const {
68
- return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
69
- }
70
- inline void set_mask(int y, int x, bool value) {
71
- m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
72
- }
73
- inline void set_global_mask(int y, int x, bool value) {
74
- m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
75
- }
76
- inline void clear_mask() {
77
- m_mask.setTo(cv::Scalar(0));
78
- }
79
-
80
- inline const unsigned char *get_image(int y, int x) const {
81
- return m_image.ptr<unsigned char>(y, x);
82
- }
83
- inline unsigned char *get_mutable_image(int y, int x) {
84
- return m_image.ptr<unsigned char>(y, x);
85
- }
86
-
87
- inline unsigned char get_image(int y, int x, int c) const {
88
- return m_image.ptr<unsigned char>(y, x)[c];
89
- }
90
- inline int get_image_int(int y, int x, int c) const {
91
- return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
92
- }
93
-
94
- bool contains_mask(int y, int x, int patch_size) const;
95
- MaskedImage downsample() const;
96
- MaskedImage upsample(int new_w, int new_h) const;
97
- MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
98
- void compute_image_gradients();
99
- void compute_image_gradients() const;
100
-
101
- static const cv::Size kDownsampleKernelSize;
102
- static const int kDownsampleKernel[6];
103
-
104
- private:
105
- cv::Mat m_image;
106
- cv::Mat m_mask;
107
- cv::Mat m_global_mask;
108
- cv::Mat m_image_grady;
109
- cv::Mat m_image_gradx;
110
- bool m_image_grad_computed = false;
111
- };
112
-
 
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+
5
+ class MaskedImage {
6
+ public:
7
+ MaskedImage() : m_image(), m_mask(), m_global_mask(), m_image_grady(), m_image_gradx(), m_image_grad_computed(false) {
8
+ // pass
9
+ }
10
+ MaskedImage(cv::Mat image, cv::Mat mask) : m_image(image), m_mask(mask), m_image_grad_computed(false) {
11
+ // pass
12
+ }
13
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask) : m_image(image), m_mask(mask), m_global_mask(global_mask), m_image_grad_computed(false) {
14
+ // pass
15
+ }
16
+ MaskedImage(cv::Mat image, cv::Mat mask, cv::Mat global_mask, cv::Mat grady, cv::Mat gradx, bool grad_computed) :
17
+ m_image(image), m_mask(mask), m_global_mask(global_mask),
18
+ m_image_grady(grady), m_image_gradx(gradx), m_image_grad_computed(grad_computed) {
19
+ // pass
20
+ }
21
+ MaskedImage(int width, int height) : m_global_mask(), m_image_grady(), m_image_gradx() {
22
+ m_image = cv::Mat(cv::Size(width, height), CV_8UC3);
23
+ m_image = cv::Scalar::all(0);
24
+
25
+ m_mask = cv::Mat(cv::Size(width, height), CV_8U);
26
+ m_mask = cv::Scalar::all(0);
27
+ }
28
+ inline MaskedImage clone() {
29
+ return MaskedImage(
30
+ m_image.clone(), m_mask.clone(), m_global_mask.clone(),
31
+ m_image_grady.clone(), m_image_gradx.clone(), m_image_grad_computed
32
+ );
33
+ }
34
+
35
+ inline cv::Size size() const {
36
+ return m_image.size();
37
+ }
38
+ inline const cv::Mat &image() const {
39
+ return m_image;
40
+ }
41
+ inline const cv::Mat &mask() const {
42
+ return m_mask;
43
+ }
44
+ inline const cv::Mat &global_mask() const {
45
+ return m_global_mask;
46
+ }
47
+ inline const cv::Mat &grady() const {
48
+ assert(m_image_grad_computed);
49
+ return m_image_grady;
50
+ }
51
+ inline const cv::Mat &gradx() const {
52
+ assert(m_image_grad_computed);
53
+ return m_image_gradx;
54
+ }
55
+
56
+ inline void init_global_mask_mat() {
57
+ m_global_mask = cv::Mat(m_mask.size(), CV_8U);
58
+ m_global_mask.setTo(cv::Scalar(0));
59
+ }
60
+ inline void set_global_mask_mat(const cv::Mat &other) {
61
+ m_global_mask = other;
62
+ }
63
+
64
+ inline bool is_masked(int y, int x) const {
65
+ return static_cast<bool>(m_mask.at<unsigned char>(y, x));
66
+ }
67
+ inline bool is_globally_masked(int y, int x) const {
68
+ return !m_global_mask.empty() && static_cast<bool>(m_global_mask.at<unsigned char>(y, x));
69
+ }
70
+ inline void set_mask(int y, int x, bool value) {
71
+ m_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
72
+ }
73
+ inline void set_global_mask(int y, int x, bool value) {
74
+ m_global_mask.at<unsigned char>(y, x) = static_cast<unsigned char>(value);
75
+ }
76
+ inline void clear_mask() {
77
+ m_mask.setTo(cv::Scalar(0));
78
+ }
79
+
80
+ inline const unsigned char *get_image(int y, int x) const {
81
+ return m_image.ptr<unsigned char>(y, x);
82
+ }
83
+ inline unsigned char *get_mutable_image(int y, int x) {
84
+ return m_image.ptr<unsigned char>(y, x);
85
+ }
86
+
87
+ inline unsigned char get_image(int y, int x, int c) const {
88
+ return m_image.ptr<unsigned char>(y, x)[c];
89
+ }
90
+ inline int get_image_int(int y, int x, int c) const {
91
+ return static_cast<int>(m_image.ptr<unsigned char>(y, x)[c]);
92
+ }
93
+
94
+ bool contains_mask(int y, int x, int patch_size) const;
95
+ MaskedImage downsample() const;
96
+ MaskedImage upsample(int new_w, int new_h) const;
97
+ MaskedImage upsample(int new_w, int new_h, const cv::Mat &new_global_mask) const;
98
+ void compute_image_gradients();
99
+ void compute_image_gradients() const;
100
+
101
+ static const cv::Size kDownsampleKernelSize;
102
+ static const int kDownsampleKernel[6];
103
+
104
+ private:
105
+ cv::Mat m_image;
106
+ cv::Mat m_mask;
107
+ cv::Mat m_global_mask;
108
+ cv::Mat m_image_grady;
109
+ cv::Mat m_image_gradx;
110
+ bool m_image_grad_computed = false;
111
+ };
112
+
PyPatchMatch/csrc/nnf.cpp CHANGED
@@ -1,268 +1,268 @@
1
- #include <algorithm>
2
- #include <iostream>
3
- #include <cmath>
4
-
5
- #include "masked_image.h"
6
- #include "nnf.h"
7
-
8
- /**
9
- * Nearest-Neighbor Field (see PatchMatch algorithm).
10
- * This algorithme uses a version proposed by Xavier Philippeau.
11
- *
12
- */
13
-
14
- template <typename T>
15
- T clamp(T value, T min_value, T max_value) {
16
- return std::min(std::max(value, min_value), max_value);
17
- }
18
-
19
- void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
20
- auto this_size = source_size();
21
- for (int i = 0; i < this_size.height; ++i) {
22
- for (int j = 0; j < this_size.width; ++j) {
23
- if (m_source.is_globally_masked(i, j)) continue;
24
-
25
- auto this_ptr = mutable_ptr(i, j);
26
- int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
27
- if (distance < PatchDistanceMetric::kDistanceScale) {
28
- continue;
29
- }
30
-
31
- int i_target = 0, j_target = 0;
32
- for (int t = 0; t < max_retry; ++t) {
33
- i_target = rand() % this_size.height;
34
- j_target = rand() % this_size.width;
35
- if (m_target.is_globally_masked(i_target, j_target)) continue;
36
-
37
- distance = _distance(i, j, i_target, j_target);
38
- if (distance < PatchDistanceMetric::kDistanceScale)
39
- break;
40
- }
41
-
42
- this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
43
- }
44
- }
45
- }
46
-
47
- void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
48
- const auto &this_size = source_size();
49
- const auto &other_size = other.source_size();
50
- double fi = static_cast<double>(this_size.height) / other_size.height;
51
- double fj = static_cast<double>(this_size.width) / other_size.width;
52
-
53
- for (int i = 0; i < this_size.height; ++i) {
54
- for (int j = 0; j < this_size.width; ++j) {
55
- if (m_source.is_globally_masked(i, j)) continue;
56
-
57
- int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
58
- int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
59
- auto this_value = mutable_ptr(i, j);
60
- auto other_value = other.ptr(ilow, jlow);
61
-
62
- this_value[0] = static_cast<int>(other_value[0] * fi);
63
- this_value[1] = static_cast<int>(other_value[1] * fj);
64
- this_value[2] = _distance(i, j, this_value[0], this_value[1]);
65
- }
66
- }
67
-
68
- _randomize_field(max_retry, false);
69
- }
70
-
71
- void NearestNeighborField::minimize(int nr_pass) {
72
- const auto &this_size = source_size();
73
- while (nr_pass--) {
74
- for (int i = 0; i < this_size.height; ++i)
75
- for (int j = 0; j < this_size.width; ++j) {
76
- if (m_source.is_globally_masked(i, j)) continue;
77
- if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
78
- }
79
- for (int i = this_size.height - 1; i >= 0; --i)
80
- for (int j = this_size.width - 1; j >= 0; --j) {
81
- if (m_source.is_globally_masked(i, j)) continue;
82
- if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
83
- }
84
- }
85
- }
86
-
87
- void NearestNeighborField::_minimize_link(int y, int x, int direction) {
88
- const auto &this_size = source_size();
89
- const auto &this_target_size = target_size();
90
- auto this_ptr = mutable_ptr(y, x);
91
-
92
- // propagation along the y direction.
93
- if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
94
- int yp = at(y - direction, x, 0) + direction;
95
- int xp = at(y - direction, x, 1);
96
- int dp = _distance(y, x, yp, xp);
97
- if (dp < at(y, x, 2)) {
98
- this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
99
- }
100
- }
101
-
102
- // propagation along the x direction.
103
- if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
104
- int yp = at(y, x - direction, 0);
105
- int xp = at(y, x - direction, 1) + direction;
106
- int dp = _distance(y, x, yp, xp);
107
- if (dp < at(y, x, 2)) {
108
- this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
109
- }
110
- }
111
-
112
- // random search with a progressive step size.
113
- int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
114
- while (random_scale > 0) {
115
- int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
116
- int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
117
- yp = clamp(yp, 0, target_size().height - 1);
118
- xp = clamp(xp, 0, target_size().width - 1);
119
-
120
- if (m_target.is_globally_masked(yp, xp)) {
121
- random_scale /= 2;
122
- }
123
-
124
- int dp = _distance(y, x, yp, xp);
125
- if (dp < at(y, x, 2)) {
126
- this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
127
- }
128
- random_scale /= 2;
129
- }
130
- }
131
-
132
- const int PatchDistanceMetric::kDistanceScale = 65535;
133
- const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
134
-
135
- namespace {
136
-
137
- inline int pow2(int i) {
138
- return i * i;
139
- }
140
-
141
- int distance_masked_images(
142
- const MaskedImage &source, int ys, int xs,
143
- const MaskedImage &target, int yt, int xt,
144
- int patch_size
145
- ) {
146
- long double distance = 0;
147
- long double wsum = 0;
148
-
149
- source.compute_image_gradients();
150
- target.compute_image_gradients();
151
-
152
- auto source_size = source.size();
153
- auto target_size = target.size();
154
-
155
- for (int dy = -patch_size; dy <= patch_size; ++dy) {
156
- const int yys = ys + dy, yyt = yt + dy;
157
-
158
- if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
159
- distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
160
- wsum += 2 * patch_size + 1;
161
- continue;
162
- }
163
-
164
- const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
165
- const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
166
- const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
167
- const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
168
-
169
- const unsigned char *p_sgm = nullptr;
170
- const unsigned char *p_tgm = nullptr;
171
- if (!source.global_mask().empty()) {
172
- p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
173
- p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
174
- }
175
-
176
- const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
177
- const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
178
- const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
179
- const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
180
-
181
- for (int dx = -patch_size; dx <= patch_size; ++dx) {
182
- int xxs = xs + dx, xxt = xt + dx;
183
- wsum += 1;
184
-
185
- if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
186
- distance += PatchSSDDistanceMetric::kSSDScale;
187
- continue;
188
- }
189
-
190
- if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
191
- distance += PatchSSDDistanceMetric::kSSDScale;
192
- continue;
193
- }
194
-
195
- int ssd = 0;
196
- for (int c = 0; c < 3; ++c) {
197
- int s_value = p_si[xxs * 3 + c];
198
- int t_value = p_ti[xxt * 3 + c];
199
- int s_gy = p_sgy[xxs * 3 + c];
200
- int t_gy = p_tgy[xxt * 3 + c];
201
- int s_gx = p_sgx[xxs * 3 + c];
202
- int t_gx = p_tgx[xxt * 3 + c];
203
-
204
- ssd += pow2(static_cast<int>(s_value) - t_value);
205
- ssd += pow2(static_cast<int>(s_gx) - t_gx);
206
- ssd += pow2(static_cast<int>(s_gy) - t_gy);
207
- }
208
- distance += ssd;
209
- }
210
- }
211
-
212
- distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
213
-
214
- int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
215
- if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
216
- return res;
217
- }
218
-
219
- }
220
-
221
- int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
222
- return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
223
- }
224
-
225
- int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
226
- fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
227
- return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
228
- }
229
-
230
- int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
231
- double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
232
- double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
233
-
234
- double score1 = sqrt(dx * dx + dy *dy) / m_scale;
235
- if (score1 < 0 || score1 > 1) score1 = 1;
236
- score1 *= PatchDistanceMetric::kDistanceScale;
237
-
238
- double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
239
- double score = score1 * m_weight + score2 / (1 + m_weight);
240
- return static_cast<int>(score / (1 + m_weight));
241
- }
242
-
243
- int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
244
- if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
245
- return PatchDistanceMetric::kDistanceScale;
246
-
247
- int source_scale = m_ijmap.size().height / source.size().height;
248
- int target_scale = m_ijmap.size().height / target.size().height;
249
-
250
- // fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
251
-
252
- double score1 = PatchDistanceMetric::kDistanceScale;
253
- if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
254
- auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
255
- auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
256
-
257
- float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
258
- float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
259
- score1 = sqrt(di * di + dj *dj) / 0.707;
260
- if (score1 < 0 || score1 > 1) score1 = 1;
261
- score1 *= PatchDistanceMetric::kDistanceScale;
262
- }
263
-
264
- double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
265
- double score = score1 * m_weight + score2;
266
- return int(score / (1 + m_weight));
267
- }
268
-
 
1
+ #include <algorithm>
2
+ #include <iostream>
3
+ #include <cmath>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+
8
+ /**
9
+ * Nearest-Neighbor Field (see PatchMatch algorithm).
10
+ * This algorithme uses a version proposed by Xavier Philippeau.
11
+ *
12
+ */
13
+
14
+ template <typename T>
15
+ T clamp(T value, T min_value, T max_value) {
16
+ return std::min(std::max(value, min_value), max_value);
17
+ }
18
+
19
+ void NearestNeighborField::_randomize_field(int max_retry, bool reset) {
20
+ auto this_size = source_size();
21
+ for (int i = 0; i < this_size.height; ++i) {
22
+ for (int j = 0; j < this_size.width; ++j) {
23
+ if (m_source.is_globally_masked(i, j)) continue;
24
+
25
+ auto this_ptr = mutable_ptr(i, j);
26
+ int distance = reset ? PatchDistanceMetric::kDistanceScale : this_ptr[2];
27
+ if (distance < PatchDistanceMetric::kDistanceScale) {
28
+ continue;
29
+ }
30
+
31
+ int i_target = 0, j_target = 0;
32
+ for (int t = 0; t < max_retry; ++t) {
33
+ i_target = rand() % this_size.height;
34
+ j_target = rand() % this_size.width;
35
+ if (m_target.is_globally_masked(i_target, j_target)) continue;
36
+
37
+ distance = _distance(i, j, i_target, j_target);
38
+ if (distance < PatchDistanceMetric::kDistanceScale)
39
+ break;
40
+ }
41
+
42
+ this_ptr[0] = i_target, this_ptr[1] = j_target, this_ptr[2] = distance;
43
+ }
44
+ }
45
+ }
46
+
47
+ void NearestNeighborField::_initialize_field_from(const NearestNeighborField &other, int max_retry) {
48
+ const auto &this_size = source_size();
49
+ const auto &other_size = other.source_size();
50
+ double fi = static_cast<double>(this_size.height) / other_size.height;
51
+ double fj = static_cast<double>(this_size.width) / other_size.width;
52
+
53
+ for (int i = 0; i < this_size.height; ++i) {
54
+ for (int j = 0; j < this_size.width; ++j) {
55
+ if (m_source.is_globally_masked(i, j)) continue;
56
+
57
+ int ilow = static_cast<int>(std::min(i / fi, static_cast<double>(other_size.height - 1)));
58
+ int jlow = static_cast<int>(std::min(j / fj, static_cast<double>(other_size.width - 1)));
59
+ auto this_value = mutable_ptr(i, j);
60
+ auto other_value = other.ptr(ilow, jlow);
61
+
62
+ this_value[0] = static_cast<int>(other_value[0] * fi);
63
+ this_value[1] = static_cast<int>(other_value[1] * fj);
64
+ this_value[2] = _distance(i, j, this_value[0], this_value[1]);
65
+ }
66
+ }
67
+
68
+ _randomize_field(max_retry, false);
69
+ }
70
+
71
+ void NearestNeighborField::minimize(int nr_pass) {
72
+ const auto &this_size = source_size();
73
+ while (nr_pass--) {
74
+ for (int i = 0; i < this_size.height; ++i)
75
+ for (int j = 0; j < this_size.width; ++j) {
76
+ if (m_source.is_globally_masked(i, j)) continue;
77
+ if (at(i, j, 2) > 0) _minimize_link(i, j, +1);
78
+ }
79
+ for (int i = this_size.height - 1; i >= 0; --i)
80
+ for (int j = this_size.width - 1; j >= 0; --j) {
81
+ if (m_source.is_globally_masked(i, j)) continue;
82
+ if (at(i, j, 2) > 0) _minimize_link(i, j, -1);
83
+ }
84
+ }
85
+ }
86
+
87
+ void NearestNeighborField::_minimize_link(int y, int x, int direction) {
88
+ const auto &this_size = source_size();
89
+ const auto &this_target_size = target_size();
90
+ auto this_ptr = mutable_ptr(y, x);
91
+
92
+ // propagation along the y direction.
93
+ if (y - direction >= 0 && y - direction < this_size.height && !m_source.is_globally_masked(y - direction, x)) {
94
+ int yp = at(y - direction, x, 0) + direction;
95
+ int xp = at(y - direction, x, 1);
96
+ int dp = _distance(y, x, yp, xp);
97
+ if (dp < at(y, x, 2)) {
98
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
99
+ }
100
+ }
101
+
102
+ // propagation along the x direction.
103
+ if (x - direction >= 0 && x - direction < this_size.width && !m_source.is_globally_masked(y, x - direction)) {
104
+ int yp = at(y, x - direction, 0);
105
+ int xp = at(y, x - direction, 1) + direction;
106
+ int dp = _distance(y, x, yp, xp);
107
+ if (dp < at(y, x, 2)) {
108
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
109
+ }
110
+ }
111
+
112
+ // random search with a progressive step size.
113
+ int random_scale = (std::min(this_target_size.height, this_target_size.width) - 1) / 2;
114
+ while (random_scale > 0) {
115
+ int yp = this_ptr[0] + (rand() % (2 * random_scale + 1) - random_scale);
116
+ int xp = this_ptr[1] + (rand() % (2 * random_scale + 1) - random_scale);
117
+ yp = clamp(yp, 0, target_size().height - 1);
118
+ xp = clamp(xp, 0, target_size().width - 1);
119
+
120
+ if (m_target.is_globally_masked(yp, xp)) {
121
+ random_scale /= 2;
122
+ }
123
+
124
+ int dp = _distance(y, x, yp, xp);
125
+ if (dp < at(y, x, 2)) {
126
+ this_ptr[0] = yp, this_ptr[1] = xp, this_ptr[2] = dp;
127
+ }
128
+ random_scale /= 2;
129
+ }
130
+ }
131
+
132
+ const int PatchDistanceMetric::kDistanceScale = 65535;
133
+ const int PatchSSDDistanceMetric::kSSDScale = 9 * 255 * 255;
134
+
135
+ namespace {
136
+
137
+ inline int pow2(int i) {
138
+ return i * i;
139
+ }
140
+
141
+ int distance_masked_images(
142
+ const MaskedImage &source, int ys, int xs,
143
+ const MaskedImage &target, int yt, int xt,
144
+ int patch_size
145
+ ) {
146
+ long double distance = 0;
147
+ long double wsum = 0;
148
+
149
+ source.compute_image_gradients();
150
+ target.compute_image_gradients();
151
+
152
+ auto source_size = source.size();
153
+ auto target_size = target.size();
154
+
155
+ for (int dy = -patch_size; dy <= patch_size; ++dy) {
156
+ const int yys = ys + dy, yyt = yt + dy;
157
+
158
+ if (yys <= 0 || yys >= source_size.height - 1 || yyt <= 0 || yyt >= target_size.height - 1) {
159
+ distance += (long double)(PatchSSDDistanceMetric::kSSDScale) * (2 * patch_size + 1);
160
+ wsum += 2 * patch_size + 1;
161
+ continue;
162
+ }
163
+
164
+ const auto *p_si = source.image().ptr<unsigned char>(yys, 0);
165
+ const auto *p_ti = target.image().ptr<unsigned char>(yyt, 0);
166
+ const auto *p_sm = source.mask().ptr<unsigned char>(yys, 0);
167
+ const auto *p_tm = target.mask().ptr<unsigned char>(yyt, 0);
168
+
169
+ const unsigned char *p_sgm = nullptr;
170
+ const unsigned char *p_tgm = nullptr;
171
+ if (!source.global_mask().empty()) {
172
+ p_sgm = source.global_mask().ptr<unsigned char>(yys, 0);
173
+ p_tgm = target.global_mask().ptr<unsigned char>(yyt, 0);
174
+ }
175
+
176
+ const auto *p_sgy = source.grady().ptr<unsigned char>(yys, 0);
177
+ const auto *p_tgy = target.grady().ptr<unsigned char>(yyt, 0);
178
+ const auto *p_sgx = source.gradx().ptr<unsigned char>(yys, 0);
179
+ const auto *p_tgx = target.gradx().ptr<unsigned char>(yyt, 0);
180
+
181
+ for (int dx = -patch_size; dx <= patch_size; ++dx) {
182
+ int xxs = xs + dx, xxt = xt + dx;
183
+ wsum += 1;
184
+
185
+ if (xxs <= 0 || xxs >= source_size.width - 1 || xxt <= 0 || xxt >= source_size.width - 1) {
186
+ distance += PatchSSDDistanceMetric::kSSDScale;
187
+ continue;
188
+ }
189
+
190
+ if (p_sm[xxs] || p_tm[xxt] || (p_sgm && p_sgm[xxs]) || (p_tgm && p_tgm[xxt]) ) {
191
+ distance += PatchSSDDistanceMetric::kSSDScale;
192
+ continue;
193
+ }
194
+
195
+ int ssd = 0;
196
+ for (int c = 0; c < 3; ++c) {
197
+ int s_value = p_si[xxs * 3 + c];
198
+ int t_value = p_ti[xxt * 3 + c];
199
+ int s_gy = p_sgy[xxs * 3 + c];
200
+ int t_gy = p_tgy[xxt * 3 + c];
201
+ int s_gx = p_sgx[xxs * 3 + c];
202
+ int t_gx = p_tgx[xxt * 3 + c];
203
+
204
+ ssd += pow2(static_cast<int>(s_value) - t_value);
205
+ ssd += pow2(static_cast<int>(s_gx) - t_gx);
206
+ ssd += pow2(static_cast<int>(s_gy) - t_gy);
207
+ }
208
+ distance += ssd;
209
+ }
210
+ }
211
+
212
+ distance /= (long double)(PatchSSDDistanceMetric::kSSDScale);
213
+
214
+ int res = int(PatchDistanceMetric::kDistanceScale * distance / wsum);
215
+ if (res < 0 || res > PatchDistanceMetric::kDistanceScale) return PatchDistanceMetric::kDistanceScale;
216
+ return res;
217
+ }
218
+
219
+ }
220
+
221
+ int PatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
222
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
223
+ }
224
+
225
+ int DebugPatchSSDDistanceMetric::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
226
+ fprintf(stderr, "DebugPatchSSDDistanceMetric: %d %d %d %d\n", source.size().width, source.size().height, m_width, m_height);
227
+ return distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
228
+ }
229
+
230
+ int RegularityGuidedPatchDistanceMetricV1::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
231
+ double dx = remainder(double(source_x - target_x) / source.size().width, m_dx1);
232
+ double dy = remainder(double(source_y - target_y) / source.size().height, m_dy2);
233
+
234
+ double score1 = sqrt(dx * dx + dy *dy) / m_scale;
235
+ if (score1 < 0 || score1 > 1) score1 = 1;
236
+ score1 *= PatchDistanceMetric::kDistanceScale;
237
+
238
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
239
+ double score = score1 * m_weight + score2 / (1 + m_weight);
240
+ return static_cast<int>(score / (1 + m_weight));
241
+ }
242
+
243
+ int RegularityGuidedPatchDistanceMetricV2::operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const {
244
+ if (target_y < 0 || target_y >= target.size().height || target_x < 0 || target_x >= target.size().width)
245
+ return PatchDistanceMetric::kDistanceScale;
246
+
247
+ int source_scale = m_ijmap.size().height / source.size().height;
248
+ int target_scale = m_ijmap.size().height / target.size().height;
249
+
250
+ // fprintf(stderr, "RegularityGuidedPatchDistanceMetricV2 %d %d %d %d\n", source_y * source_scale, m_ijmap.size().height, source_x * source_scale, m_ijmap.size().width);
251
+
252
+ double score1 = PatchDistanceMetric::kDistanceScale;
253
+ if (!source.is_globally_masked(source_y, source_x) && !target.is_globally_masked(target_y, target_x)) {
254
+ auto source_ij = m_ijmap.ptr<float>(source_y * source_scale, source_x * source_scale);
255
+ auto target_ij = m_ijmap.ptr<float>(target_y * target_scale, target_x * target_scale);
256
+
257
+ float di = fabs(source_ij[0] - target_ij[0]); if (di > 0.5) di = 1 - di;
258
+ float dj = fabs(source_ij[1] - target_ij[1]); if (dj > 0.5) dj = 1 - dj;
259
+ score1 = sqrt(di * di + dj *dj) / 0.707;
260
+ if (score1 < 0 || score1 > 1) score1 = 1;
261
+ score1 *= PatchDistanceMetric::kDistanceScale;
262
+ }
263
+
264
+ double score2 = distance_masked_images(source, source_y, source_x, target, target_y, target_x, m_patch_size);
265
+ double score = score1 * m_weight + score2;
266
+ return int(score / (1 + m_weight));
267
+ }
268
+
PyPatchMatch/csrc/nnf.h CHANGED
@@ -1,133 +1,133 @@
1
- #pragma once
2
-
3
- #include <opencv2/core.hpp>
4
- #include "masked_image.h"
5
-
6
- class PatchDistanceMetric {
7
- public:
8
- PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
9
- virtual ~PatchDistanceMetric() = default;
10
-
11
- inline int patch_size() const { return m_patch_size; }
12
- virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
13
- static const int kDistanceScale;
14
-
15
- protected:
16
- int m_patch_size;
17
- };
18
-
19
- class NearestNeighborField {
20
- public:
21
- NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
22
- // pass
23
- }
24
- NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
25
- : m_source(source), m_target(target), m_distance_metric(metric) {
26
- m_field = cv::Mat(m_source.size(), CV_32SC3);
27
- _randomize_field(max_retry);
28
- }
29
- NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
30
- : m_source(source), m_target(target), m_distance_metric(metric) {
31
- m_field = cv::Mat(m_source.size(), CV_32SC3);
32
- _initialize_field_from(other, max_retry);
33
- }
34
-
35
- const MaskedImage &source() const {
36
- return m_source;
37
- }
38
- const MaskedImage &target() const {
39
- return m_target;
40
- }
41
- inline cv::Size source_size() const {
42
- return m_source.size();
43
- }
44
- inline cv::Size target_size() const {
45
- return m_target.size();
46
- }
47
- inline void set_source(const MaskedImage &source) {
48
- m_source = source;
49
- }
50
- inline void set_target(const MaskedImage &target) {
51
- m_target = target;
52
- }
53
-
54
- inline int *mutable_ptr(int y, int x) {
55
- return m_field.ptr<int>(y, x);
56
- }
57
- inline const int *ptr(int y, int x) const {
58
- return m_field.ptr<int>(y, x);
59
- }
60
-
61
- inline int at(int y, int x, int c) const {
62
- return m_field.ptr<int>(y, x)[c];
63
- }
64
- inline int &at(int y, int x, int c) {
65
- return m_field.ptr<int>(y, x)[c];
66
- }
67
- inline void set_identity(int y, int x) {
68
- auto ptr = mutable_ptr(y, x);
69
- ptr[0] = y, ptr[1] = x, ptr[2] = 0;
70
- }
71
-
72
- void minimize(int nr_pass);
73
-
74
- private:
75
- inline int _distance(int source_y, int source_x, int target_y, int target_x) {
76
- return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
77
- }
78
-
79
- void _randomize_field(int max_retry = 20, bool reset = true);
80
- void _initialize_field_from(const NearestNeighborField &other, int max_retry);
81
- void _minimize_link(int y, int x, int direction);
82
-
83
- MaskedImage m_source;
84
- MaskedImage m_target;
85
- cv::Mat m_field; // { y_target, x_target, distance_scaled }
86
- const PatchDistanceMetric *m_distance_metric;
87
- };
88
-
89
-
90
- class PatchSSDDistanceMetric : public PatchDistanceMetric {
91
- public:
92
- using PatchDistanceMetric::PatchDistanceMetric;
93
- virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
94
- static const int kSSDScale;
95
- };
96
-
97
- class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
98
- public:
99
- DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
100
- virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
101
- protected:
102
- int m_width, m_height;
103
- };
104
-
105
- class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
106
- public:
107
- RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
108
- : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
109
-
110
- assert(m_dy1 == 0);
111
- assert(m_dx2 == 0);
112
- m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
113
- }
114
- virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
115
-
116
- protected:
117
- double m_dx1, m_dy1, m_dx2, m_dy2;
118
- double m_scale, m_weight;
119
- };
120
-
121
- class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
122
- public:
123
- RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
124
- : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
125
-
126
- }
127
- virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
128
-
129
- protected:
130
- cv::Mat m_ijmap;
131
- double m_width, m_height, m_weight;
132
- };
133
-
 
1
+ #pragma once
2
+
3
+ #include <opencv2/core.hpp>
4
+ #include "masked_image.h"
5
+
6
+ class PatchDistanceMetric {
7
+ public:
8
+ PatchDistanceMetric(int patch_size) : m_patch_size(patch_size) {}
9
+ virtual ~PatchDistanceMetric() = default;
10
+
11
+ inline int patch_size() const { return m_patch_size; }
12
+ virtual int operator()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const = 0;
13
+ static const int kDistanceScale;
14
+
15
+ protected:
16
+ int m_patch_size;
17
+ };
18
+
19
+ class NearestNeighborField {
20
+ public:
21
+ NearestNeighborField() : m_source(), m_target(), m_field(), m_distance_metric(nullptr) {
22
+ // pass
23
+ }
24
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, int max_retry = 20)
25
+ : m_source(source), m_target(target), m_distance_metric(metric) {
26
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
27
+ _randomize_field(max_retry);
28
+ }
29
+ NearestNeighborField(const MaskedImage &source, const MaskedImage &target, const PatchDistanceMetric *metric, const NearestNeighborField &other, int max_retry = 20)
30
+ : m_source(source), m_target(target), m_distance_metric(metric) {
31
+ m_field = cv::Mat(m_source.size(), CV_32SC3);
32
+ _initialize_field_from(other, max_retry);
33
+ }
34
+
35
+ const MaskedImage &source() const {
36
+ return m_source;
37
+ }
38
+ const MaskedImage &target() const {
39
+ return m_target;
40
+ }
41
+ inline cv::Size source_size() const {
42
+ return m_source.size();
43
+ }
44
+ inline cv::Size target_size() const {
45
+ return m_target.size();
46
+ }
47
+ inline void set_source(const MaskedImage &source) {
48
+ m_source = source;
49
+ }
50
+ inline void set_target(const MaskedImage &target) {
51
+ m_target = target;
52
+ }
53
+
54
+ inline int *mutable_ptr(int y, int x) {
55
+ return m_field.ptr<int>(y, x);
56
+ }
57
+ inline const int *ptr(int y, int x) const {
58
+ return m_field.ptr<int>(y, x);
59
+ }
60
+
61
+ inline int at(int y, int x, int c) const {
62
+ return m_field.ptr<int>(y, x)[c];
63
+ }
64
+ inline int &at(int y, int x, int c) {
65
+ return m_field.ptr<int>(y, x)[c];
66
+ }
67
+ inline void set_identity(int y, int x) {
68
+ auto ptr = mutable_ptr(y, x);
69
+ ptr[0] = y, ptr[1] = x, ptr[2] = 0;
70
+ }
71
+
72
+ void minimize(int nr_pass);
73
+
74
+ private:
75
+ inline int _distance(int source_y, int source_x, int target_y, int target_x) {
76
+ return (*m_distance_metric)(m_source, source_y, source_x, m_target, target_y, target_x);
77
+ }
78
+
79
+ void _randomize_field(int max_retry = 20, bool reset = true);
80
+ void _initialize_field_from(const NearestNeighborField &other, int max_retry);
81
+ void _minimize_link(int y, int x, int direction);
82
+
83
+ MaskedImage m_source;
84
+ MaskedImage m_target;
85
+ cv::Mat m_field; // { y_target, x_target, distance_scaled }
86
+ const PatchDistanceMetric *m_distance_metric;
87
+ };
88
+
89
+
90
+ class PatchSSDDistanceMetric : public PatchDistanceMetric {
91
+ public:
92
+ using PatchDistanceMetric::PatchDistanceMetric;
93
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
94
+ static const int kSSDScale;
95
+ };
96
+
97
+ class DebugPatchSSDDistanceMetric : public PatchDistanceMetric {
98
+ public:
99
+ DebugPatchSSDDistanceMetric(int patch_size, int width, int height) : PatchDistanceMetric(patch_size), m_width(width), m_height(height) {}
100
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
101
+ protected:
102
+ int m_width, m_height;
103
+ };
104
+
105
+ class RegularityGuidedPatchDistanceMetricV1 : public PatchDistanceMetric {
106
+ public:
107
+ RegularityGuidedPatchDistanceMetricV1(int patch_size, double dx1, double dy1, double dx2, double dy2, double weight)
108
+ : PatchDistanceMetric(patch_size), m_dx1(dx1), m_dy1(dy1), m_dx2(dx2), m_dy2(dy2), m_weight(weight) {
109
+
110
+ assert(m_dy1 == 0);
111
+ assert(m_dx2 == 0);
112
+ m_scale = sqrt(m_dx1 * m_dx1 + m_dy2 * m_dy2) / 4;
113
+ }
114
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
115
+
116
+ protected:
117
+ double m_dx1, m_dy1, m_dx2, m_dy2;
118
+ double m_scale, m_weight;
119
+ };
120
+
121
+ class RegularityGuidedPatchDistanceMetricV2 : public PatchDistanceMetric {
122
+ public:
123
+ RegularityGuidedPatchDistanceMetricV2(int patch_size, cv::Mat ijmap, double weight)
124
+ : PatchDistanceMetric(patch_size), m_ijmap(ijmap), m_weight(weight) {
125
+
126
+ }
127
+ virtual int operator ()(const MaskedImage &source, int source_y, int source_x, const MaskedImage &target, int target_y, int target_x) const;
128
+
129
+ protected:
130
+ cv::Mat m_ijmap;
131
+ double m_width, m_height, m_weight;
132
+ };
133
+
PyPatchMatch/csrc/pyinterface.cpp CHANGED
@@ -1,107 +1,107 @@
1
- #include "pyinterface.h"
2
- #include "inpaint.h"
3
-
4
- static unsigned int PM_seed = 1212;
5
- static bool PM_verbose = false;
6
-
7
- int _dtype_py_to_cv(int dtype_py);
8
- int _dtype_cv_to_py(int dtype_cv);
9
- cv::Mat _py_to_cv2(PM_mat_t pymat);
10
- PM_mat_t _cv2_to_py(cv::Mat cvmat);
11
-
12
- void PM_set_random_seed(unsigned int seed) {
13
- PM_seed = seed;
14
- }
15
-
16
- void PM_set_verbose(int value) {
17
- PM_verbose = static_cast<bool>(value);
18
- }
19
-
20
- void PM_free_pymat(PM_mat_t pymat) {
21
- free(pymat.data_ptr);
22
- }
23
-
24
- PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
25
- cv::Mat source = _py_to_cv2(source_py);
26
- cv::Mat mask = _py_to_cv2(mask_py);
27
- auto metric = PatchSSDDistanceMetric(patch_size);
28
- cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
29
- return _cv2_to_py(result);
30
- }
31
-
32
- PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
33
- cv::Mat source = _py_to_cv2(source_py);
34
- cv::Mat mask = _py_to_cv2(mask_py);
35
- cv::Mat ijmap = _py_to_cv2(ijmap_py);
36
-
37
- auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
38
- cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
39
- return _cv2_to_py(result);
40
- }
41
-
42
- PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
43
- cv::Mat source = _py_to_cv2(source_py);
44
- cv::Mat mask = _py_to_cv2(mask_py);
45
- cv::Mat global_mask = _py_to_cv2(global_mask_py);
46
-
47
- auto metric = PatchSSDDistanceMetric(patch_size);
48
- cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
49
- return _cv2_to_py(result);
50
- }
51
-
52
- PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
53
- cv::Mat source = _py_to_cv2(source_py);
54
- cv::Mat mask = _py_to_cv2(mask_py);
55
- cv::Mat global_mask = _py_to_cv2(global_mask_py);
56
- cv::Mat ijmap = _py_to_cv2(ijmap_py);
57
-
58
- auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
59
- cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
60
- return _cv2_to_py(result);
61
- }
62
-
63
- int _dtype_py_to_cv(int dtype_py) {
64
- switch (dtype_py) {
65
- case PM_UINT8: return CV_8U;
66
- case PM_INT8: return CV_8S;
67
- case PM_UINT16: return CV_16U;
68
- case PM_INT16: return CV_16S;
69
- case PM_INT32: return CV_32S;
70
- case PM_FLOAT32: return CV_32F;
71
- case PM_FLOAT64: return CV_64F;
72
- }
73
-
74
- return CV_8U;
75
- }
76
-
77
- int _dtype_cv_to_py(int dtype_cv) {
78
- switch (dtype_cv) {
79
- case CV_8U: return PM_UINT8;
80
- case CV_8S: return PM_INT8;
81
- case CV_16U: return PM_UINT16;
82
- case CV_16S: return PM_INT16;
83
- case CV_32S: return PM_INT32;
84
- case CV_32F: return PM_FLOAT32;
85
- case CV_64F: return PM_FLOAT64;
86
- }
87
-
88
- return PM_UINT8;
89
- }
90
-
91
- cv::Mat _py_to_cv2(PM_mat_t pymat) {
92
- int dtype = _dtype_py_to_cv(pymat.dtype);
93
- dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
94
- return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
95
- }
96
-
97
- PM_mat_t _cv2_to_py(cv::Mat cvmat) {
98
- PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
99
- int dtype = _dtype_cv_to_py(cvmat.depth());
100
- size_t dsize = cvmat.total() * cvmat.elemSize();
101
-
102
- void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
103
- memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
104
-
105
- return PM_mat_t {data_ptr, shape, dtype};
106
- }
107
-
 
1
+ #include "pyinterface.h"
2
+ #include "inpaint.h"
3
+
4
+ static unsigned int PM_seed = 1212;
5
+ static bool PM_verbose = false;
6
+
7
+ int _dtype_py_to_cv(int dtype_py);
8
+ int _dtype_cv_to_py(int dtype_cv);
9
+ cv::Mat _py_to_cv2(PM_mat_t pymat);
10
+ PM_mat_t _cv2_to_py(cv::Mat cvmat);
11
+
12
+ void PM_set_random_seed(unsigned int seed) {
13
+ PM_seed = seed;
14
+ }
15
+
16
+ void PM_set_verbose(int value) {
17
+ PM_verbose = static_cast<bool>(value);
18
+ }
19
+
20
+ void PM_free_pymat(PM_mat_t pymat) {
21
+ free(pymat.data_ptr);
22
+ }
23
+
24
+ PM_mat_t PM_inpaint(PM_mat_t source_py, PM_mat_t mask_py, int patch_size) {
25
+ cv::Mat source = _py_to_cv2(source_py);
26
+ cv::Mat mask = _py_to_cv2(mask_py);
27
+ auto metric = PatchSSDDistanceMetric(patch_size);
28
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
29
+ return _cv2_to_py(result);
30
+ }
31
+
32
+ PM_mat_t PM_inpaint_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
33
+ cv::Mat source = _py_to_cv2(source_py);
34
+ cv::Mat mask = _py_to_cv2(mask_py);
35
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
36
+
37
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
38
+ cv::Mat result = Inpainting(source, mask, &metric).run(PM_verbose, false, PM_seed);
39
+ return _cv2_to_py(result);
40
+ }
41
+
42
+ PM_mat_t PM_inpaint2(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, int patch_size) {
43
+ cv::Mat source = _py_to_cv2(source_py);
44
+ cv::Mat mask = _py_to_cv2(mask_py);
45
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
46
+
47
+ auto metric = PatchSSDDistanceMetric(patch_size);
48
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
49
+ return _cv2_to_py(result);
50
+ }
51
+
52
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t source_py, PM_mat_t mask_py, PM_mat_t global_mask_py, PM_mat_t ijmap_py, int patch_size, float guide_weight) {
53
+ cv::Mat source = _py_to_cv2(source_py);
54
+ cv::Mat mask = _py_to_cv2(mask_py);
55
+ cv::Mat global_mask = _py_to_cv2(global_mask_py);
56
+ cv::Mat ijmap = _py_to_cv2(ijmap_py);
57
+
58
+ auto metric = RegularityGuidedPatchDistanceMetricV2(patch_size, ijmap, guide_weight);
59
+ cv::Mat result = Inpainting(source, mask, global_mask, &metric).run(PM_verbose, false, PM_seed);
60
+ return _cv2_to_py(result);
61
+ }
62
+
63
+ int _dtype_py_to_cv(int dtype_py) {
64
+ switch (dtype_py) {
65
+ case PM_UINT8: return CV_8U;
66
+ case PM_INT8: return CV_8S;
67
+ case PM_UINT16: return CV_16U;
68
+ case PM_INT16: return CV_16S;
69
+ case PM_INT32: return CV_32S;
70
+ case PM_FLOAT32: return CV_32F;
71
+ case PM_FLOAT64: return CV_64F;
72
+ }
73
+
74
+ return CV_8U;
75
+ }
76
+
77
+ int _dtype_cv_to_py(int dtype_cv) {
78
+ switch (dtype_cv) {
79
+ case CV_8U: return PM_UINT8;
80
+ case CV_8S: return PM_INT8;
81
+ case CV_16U: return PM_UINT16;
82
+ case CV_16S: return PM_INT16;
83
+ case CV_32S: return PM_INT32;
84
+ case CV_32F: return PM_FLOAT32;
85
+ case CV_64F: return PM_FLOAT64;
86
+ }
87
+
88
+ return PM_UINT8;
89
+ }
90
+
91
+ cv::Mat _py_to_cv2(PM_mat_t pymat) {
92
+ int dtype = _dtype_py_to_cv(pymat.dtype);
93
+ dtype = CV_MAKETYPE(pymat.dtype, pymat.shape.channels);
94
+ return cv::Mat(cv::Size(pymat.shape.width, pymat.shape.height), dtype, pymat.data_ptr).clone();
95
+ }
96
+
97
+ PM_mat_t _cv2_to_py(cv::Mat cvmat) {
98
+ PM_shape_t shape = {cvmat.size().width, cvmat.size().height, cvmat.channels()};
99
+ int dtype = _dtype_cv_to_py(cvmat.depth());
100
+ size_t dsize = cvmat.total() * cvmat.elemSize();
101
+
102
+ void *data_ptr = reinterpret_cast<void *>(malloc(dsize));
103
+ memcpy(data_ptr, reinterpret_cast<void *>(cvmat.data), dsize);
104
+
105
+ return PM_mat_t {data_ptr, shape, dtype};
106
+ }
107
+
PyPatchMatch/csrc/pyinterface.h CHANGED
@@ -1,38 +1,38 @@
1
- #include <opencv2/core.hpp>
2
- #include <cstdlib>
3
- #include <cstdio>
4
- #include <cstring>
5
-
6
- extern "C" {
7
-
8
- struct PM_shape_t {
9
- int width, height, channels;
10
- };
11
-
12
- enum PM_dtype_e {
13
- PM_UINT8,
14
- PM_INT8,
15
- PM_UINT16,
16
- PM_INT16,
17
- PM_INT32,
18
- PM_FLOAT32,
19
- PM_FLOAT64,
20
- };
21
-
22
- struct PM_mat_t {
23
- void *data_ptr;
24
- PM_shape_t shape;
25
- int dtype;
26
- };
27
-
28
- void PM_set_random_seed(unsigned int seed);
29
- void PM_set_verbose(int value);
30
-
31
- void PM_free_pymat(PM_mat_t pymat);
32
- PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
33
- PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
34
- PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
35
- PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
36
-
37
- } /* extern "C" */
38
-
 
1
+ #include <opencv2/core.hpp>
2
+ #include <cstdlib>
3
+ #include <cstdio>
4
+ #include <cstring>
5
+
6
+ extern "C" {
7
+
8
+ struct PM_shape_t {
9
+ int width, height, channels;
10
+ };
11
+
12
+ enum PM_dtype_e {
13
+ PM_UINT8,
14
+ PM_INT8,
15
+ PM_UINT16,
16
+ PM_INT16,
17
+ PM_INT32,
18
+ PM_FLOAT32,
19
+ PM_FLOAT64,
20
+ };
21
+
22
+ struct PM_mat_t {
23
+ void *data_ptr;
24
+ PM_shape_t shape;
25
+ int dtype;
26
+ };
27
+
28
+ void PM_set_random_seed(unsigned int seed);
29
+ void PM_set_verbose(int value);
30
+
31
+ void PM_free_pymat(PM_mat_t pymat);
32
+ PM_mat_t PM_inpaint(PM_mat_t image, PM_mat_t mask, int patch_size);
33
+ PM_mat_t PM_inpaint_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t ijmap, int patch_size, float guide_weight);
34
+ PM_mat_t PM_inpaint2(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, int patch_size);
35
+ PM_mat_t PM_inpaint2_regularity(PM_mat_t image, PM_mat_t mask, PM_mat_t global_mask, PM_mat_t ijmap, int patch_size, float guide_weight);
36
+
37
+ } /* extern "C" */
38
+
PyPatchMatch/examples/.gitignore CHANGED
@@ -1,2 +1,2 @@
1
- /cpp_example.exe
2
- /images/*recovered.bmp
 
1
+ /cpp_example.exe
2
+ /images/*recovered.bmp
PyPatchMatch/examples/cpp_example.cpp CHANGED
@@ -1,31 +1,31 @@
1
- #include <iostream>
2
- #include <opencv2/imgcodecs.hpp>
3
- #include <opencv2/highgui.hpp>
4
-
5
- #include "masked_image.h"
6
- #include "nnf.h"
7
- #include "inpaint.h"
8
-
9
- int main() {
10
- auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
11
-
12
- auto mask = cv::Mat(source.size(), CV_8UC1);
13
- mask = cv::Scalar::all(0);
14
- for (int i = 0; i < source.size().height; ++i) {
15
- for (int j = 0; j < source.size().width; ++j) {
16
- auto source_ptr = source.ptr<unsigned char>(i, j);
17
- if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
18
- mask.at<unsigned char>(i, j) = 1;
19
- }
20
- }
21
- }
22
-
23
- auto metric = PatchSSDDistanceMetric(3);
24
- auto result = Inpainting(source, mask, &metric).run(true, true);
25
- // cv::imwrite("./images/forest_recovered.bmp", result);
26
- // cv::imshow("Result", result);
27
- // cv::waitKey();
28
-
29
- return 0;
30
- }
31
-
 
1
+ #include <iostream>
2
+ #include <opencv2/imgcodecs.hpp>
3
+ #include <opencv2/highgui.hpp>
4
+
5
+ #include "masked_image.h"
6
+ #include "nnf.h"
7
+ #include "inpaint.h"
8
+
9
+ int main() {
10
+ auto source = cv::imread("./images/forest_pruned.bmp", cv::IMREAD_COLOR);
11
+
12
+ auto mask = cv::Mat(source.size(), CV_8UC1);
13
+ mask = cv::Scalar::all(0);
14
+ for (int i = 0; i < source.size().height; ++i) {
15
+ for (int j = 0; j < source.size().width; ++j) {
16
+ auto source_ptr = source.ptr<unsigned char>(i, j);
17
+ if (source_ptr[0] == 255 && source_ptr[1] == 255 && source_ptr[2] == 255) {
18
+ mask.at<unsigned char>(i, j) = 1;
19
+ }
20
+ }
21
+ }
22
+
23
+ auto metric = PatchSSDDistanceMetric(3);
24
+ auto result = Inpainting(source, mask, &metric).run(true, true);
25
+ // cv::imwrite("./images/forest_recovered.bmp", result);
26
+ // cv::imshow("Result", result);
27
+ // cv::waitKey();
28
+
29
+ return 0;
30
+ }
31
+
PyPatchMatch/examples/cpp_example_run.sh CHANGED
@@ -1,18 +1,18 @@
1
- #! /bin/bash
2
- #
3
- # cpp_example_run.sh
4
- # Copyright (C) 2020 Jiayuan Mao <[email protected]>
5
- #
6
- # Distributed under terms of the MIT license.
7
- #
8
-
9
- set -x
10
-
11
- CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
12
- LDFLAGS="$(pkg-config --libs opencv)"
13
- g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
14
-
15
- export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
16
- export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
17
- time ./cpp_example.exe
18
-
 
1
+ #! /bin/bash
2
+ #
3
+ # cpp_example_run.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <[email protected]>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ set -x
10
+
11
+ CFLAGS="-std=c++14 -O2 $(pkg-config --cflags opencv)"
12
+ LDFLAGS="$(pkg-config --libs opencv)"
13
+ g++ $CFLAGS cpp_example.cpp -I../csrc/ -L../ -lpatchmatch $LDFLAGS -o cpp_example.exe
14
+
15
+ export DYLD_LIBRARY_PATH=../:$DYLD_LIBRARY_PATH # For macOS
16
+ export LD_LIBRARY_PATH=../:$LD_LIBRARY_PATH # For Linux
17
+ time ./cpp_example.exe
18
+
PyPatchMatch/examples/py_example.py CHANGED
@@ -1,21 +1,21 @@
1
- #! /usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # File : test.py
4
- # Author : Jiayuan Mao
5
- # Email : [email protected]
6
- # Date : 01/09/2020
7
- #
8
- # Distributed under terms of the MIT license.
9
-
10
- from PIL import Image
11
-
12
- import sys
13
- sys.path.insert(0, '../')
14
- import patch_match
15
-
16
-
17
- if __name__ == '__main__':
18
- source = Image.open('./images/forest_pruned.bmp')
19
- result = patch_match.inpaint(source, patch_size=3)
20
- Image.fromarray(result).save('./images/forest_recovered.bmp')
21
-
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ from PIL import Image
11
+
12
+ import sys
13
+ sys.path.insert(0, '../')
14
+ import patch_match
15
+
16
+
17
+ if __name__ == '__main__':
18
+ source = Image.open('./images/forest_pruned.bmp')
19
+ result = patch_match.inpaint(source, patch_size=3)
20
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
21
+
PyPatchMatch/examples/py_example_global_mask.py CHANGED
@@ -1,27 +1,27 @@
1
- #! /usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # File : test.py
4
- # Author : Jiayuan Mao
5
- # Email : [email protected]
6
- # Date : 01/09/2020
7
- #
8
- # Distributed under terms of the MIT license.
9
-
10
- import numpy as np
11
- from PIL import Image
12
-
13
- import sys
14
- sys.path.insert(0, '../')
15
- import patch_match
16
-
17
-
18
- if __name__ == '__main__':
19
- patch_match.set_verbose(True)
20
- source = Image.open('./images/forest_pruned.bmp')
21
- source = np.array(source)
22
- source[:100, :100] = 255
23
- global_mask = np.zeros_like(source[..., 0])
24
- global_mask[:100, :100] = 1
25
- result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
26
- Image.fromarray(result).save('./images/forest_recovered.bmp')
27
-
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : test.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import numpy as np
11
+ from PIL import Image
12
+
13
+ import sys
14
+ sys.path.insert(0, '../')
15
+ import patch_match
16
+
17
+
18
+ if __name__ == '__main__':
19
+ patch_match.set_verbose(True)
20
+ source = Image.open('./images/forest_pruned.bmp')
21
+ source = np.array(source)
22
+ source[:100, :100] = 255
23
+ global_mask = np.zeros_like(source[..., 0])
24
+ global_mask[:100, :100] = 1
25
+ result = patch_match.inpaint(source, global_mask=global_mask, patch_size=3)
26
+ Image.fromarray(result).save('./images/forest_recovered.bmp')
27
+
PyPatchMatch/patch_match.py CHANGED
@@ -1,201 +1,263 @@
1
- #! /usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
- # File : patch_match.py
4
- # Author : Jiayuan Mao
5
- # Email : [email protected]
6
- # Date : 01/09/2020
7
- #
8
- # Distributed under terms of the MIT license.
9
-
10
- import ctypes
11
- import os.path as osp
12
- from typing import Optional, Union
13
-
14
- import numpy as np
15
- from PIL import Image
16
-
17
- try:
18
- # If the Jacinle library (https://github.com/vacancy/Jacinle) is present, use its auto_travis feature.
19
- from jacinle.jit.cext import auto_travis
20
- auto_travis(__file__, required_files=['*.so'])
21
- except ImportError as e:
22
- # Otherwise, fall back to the subprocess.
23
- import subprocess
24
- print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
25
- subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)
26
-
27
-
28
- __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
29
-
30
-
31
- class CShapeT(ctypes.Structure):
32
- _fields_ = [
33
- ('width', ctypes.c_int),
34
- ('height', ctypes.c_int),
35
- ('channels', ctypes.c_int),
36
- ]
37
-
38
-
39
- class CMatT(ctypes.Structure):
40
- _fields_ = [
41
- ('data_ptr', ctypes.c_void_p),
42
- ('shape', CShapeT),
43
- ('dtype', ctypes.c_int)
44
- ]
45
-
46
-
47
- PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
48
-
49
- PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
50
- PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
51
- PMLIB.PM_free_pymat.argtypes = [CMatT]
52
- PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
53
- PMLIB.PM_inpaint.restype = CMatT
54
- PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
55
- PMLIB.PM_inpaint_regularity.restype = CMatT
56
- PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
57
- PMLIB.PM_inpaint2.restype = CMatT
58
- PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
59
- PMLIB.PM_inpaint2_regularity.restype = CMatT
60
-
61
-
62
- def set_random_seed(seed: int):
63
- PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
64
-
65
-
66
- def set_verbose(verbose: bool):
67
- PMLIB.PM_set_verbose(ctypes.c_int(verbose))
68
-
69
-
70
- def inpaint(
71
- image: Union[np.ndarray, Image.Image],
72
- mask: Optional[Union[np.ndarray, Image.Image]] = None,
73
- *,
74
- global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
75
- patch_size: int = 15
76
- ) -> np.ndarray:
77
- """
78
- PatchMatch based inpainting proposed in:
79
-
80
- PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
81
- C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
82
- SIGGRAPH 2009
83
-
84
- Args:
85
- image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
86
- mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
87
- If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
88
- global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
89
- patch_size (int): the patch size for the inpainting algorithm.
90
-
91
- Return:
92
- result (np.ndarray): the repaired image, of the same size as the input image.
93
- """
94
-
95
- if isinstance(image, Image.Image):
96
- image = np.array(image)
97
- image = np.ascontiguousarray(image)
98
- assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
99
-
100
- if mask is None:
101
- mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
102
- mask = np.ascontiguousarray(mask)
103
- else:
104
- mask = _canonize_mask_array(mask)
105
-
106
- if global_mask is None:
107
- ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
108
- else:
109
- global_mask = _canonize_mask_array(global_mask)
110
- ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
111
-
112
- ret_npmat = pymat_to_np(ret_pymat)
113
- PMLIB.PM_free_pymat(ret_pymat)
114
-
115
- return ret_npmat
116
-
117
-
118
- def inpaint_regularity(
119
- image: Union[np.ndarray, Image.Image],
120
- mask: Optional[Union[np.ndarray, Image.Image]],
121
- ijmap: np.ndarray,
122
- *,
123
- global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
124
- patch_size: int = 15, guide_weight: float = 0.25
125
- ) -> np.ndarray:
126
- if isinstance(image, Image.Image):
127
- image = np.array(image)
128
- image = np.ascontiguousarray(image)
129
-
130
- assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
131
- ijmap = np.ascontiguousarray(ijmap)
132
-
133
- assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
134
- if mask is None:
135
- mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
136
- mask = np.ascontiguousarray(mask)
137
- else:
138
- mask = _canonize_mask_array(mask)
139
-
140
-
141
- if global_mask is None:
142
- ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
143
- else:
144
- global_mask = _canonize_mask_array(global_mask)
145
- ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
146
-
147
- ret_npmat = pymat_to_np(ret_pymat)
148
- PMLIB.PM_free_pymat(ret_pymat)
149
-
150
- return ret_npmat
151
-
152
-
153
- def _canonize_mask_array(mask):
154
- if isinstance(mask, Image.Image):
155
- mask = np.array(mask)
156
- if mask.ndim == 2 and mask.dtype == 'uint8':
157
- mask = mask[..., np.newaxis]
158
- assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
159
- return np.ascontiguousarray(mask)
160
-
161
-
162
- dtype_pymat_to_ctypes = [
163
- ctypes.c_uint8,
164
- ctypes.c_int8,
165
- ctypes.c_uint16,
166
- ctypes.c_int16,
167
- ctypes.c_int32,
168
- ctypes.c_float,
169
- ctypes.c_double,
170
- ]
171
-
172
-
173
- dtype_np_to_pymat = {
174
- 'uint8': 0,
175
- 'int8': 1,
176
- 'uint16': 2,
177
- 'int16': 3,
178
- 'int32': 4,
179
- 'float32': 5,
180
- 'float64': 6,
181
- }
182
-
183
-
184
- def np_to_pymat(npmat):
185
- assert npmat.ndim == 3
186
- return CMatT(
187
- ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
188
- CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
189
- dtype_np_to_pymat[str(npmat.dtype)]
190
- )
191
-
192
-
193
- def pymat_to_np(pymat):
194
- npmat = np.ctypeslib.as_array(
195
- ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
196
- (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
197
- )
198
- ret = np.empty(npmat.shape, npmat.dtype)
199
- ret[:] = npmat
200
- return ret
201
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ # File : patch_match.py
4
+ # Author : Jiayuan Mao
5
+ # Email : [email protected]
6
+ # Date : 01/09/2020
7
+ #
8
+ # Distributed under terms of the MIT license.
9
+
10
+ import ctypes
11
+ import os.path as osp
12
+ from typing import Optional, Union
13
+
14
+ import numpy as np
15
+ from PIL import Image
16
+
17
+
18
+ import os
19
+ if os.name!="nt":
20
+ # Otherwise, fall back to the subprocess.
21
+ import subprocess
22
+ print('Compiling and loading c extensions from "{}".'.format(osp.realpath(osp.dirname(__file__))))
23
+ # subprocess.check_call(['./travis.sh'], cwd=osp.dirname(__file__))
24
+ subprocess.check_call("make clean && make", cwd=osp.dirname(__file__), shell=True)
25
+
26
+
27
+ __all__ = ['set_random_seed', 'set_verbose', 'inpaint', 'inpaint_regularity']
28
+
29
+
30
+ class CShapeT(ctypes.Structure):
31
+ _fields_ = [
32
+ ('width', ctypes.c_int),
33
+ ('height', ctypes.c_int),
34
+ ('channels', ctypes.c_int),
35
+ ]
36
+
37
+
38
+ class CMatT(ctypes.Structure):
39
+ _fields_ = [
40
+ ('data_ptr', ctypes.c_void_p),
41
+ ('shape', CShapeT),
42
+ ('dtype', ctypes.c_int)
43
+ ]
44
+
45
+ import tempfile
46
+ from urllib.request import urlopen, Request
47
+ import shutil
48
+ from pathlib import Path
49
+ from tqdm import tqdm
50
+
51
+ def download_url_to_file(url, dst, hash_prefix=None, progress=True):
52
+ r"""Download object at the given URL to a local path.
53
+
54
+ Args:
55
+ url (string): URL of the object to download
56
+ dst (string): Full path where object will be saved, e.g. ``/tmp/temporary_file``
57
+ hash_prefix (string, optional): If not None, the SHA256 downloaded file should start with ``hash_prefix``.
58
+ Default: None
59
+ progress (bool, optional): whether or not to display a progress bar to stderr
60
+ Default: True
61
+ https://pytorch.org/docs/stable/_modules/torch/hub.html#load_state_dict_from_url
62
+ """
63
+ file_size = None
64
+ req = Request(url)
65
+ u = urlopen(req)
66
+ meta = u.info()
67
+ if hasattr(meta, 'getheaders'):
68
+ content_length = meta.getheaders("Content-Length")
69
+ else:
70
+ content_length = meta.get_all("Content-Length")
71
+ if content_length is not None and len(content_length) > 0:
72
+ file_size = int(content_length[0])
73
+
74
+ # We deliberately save it in a temp file and move it after
75
+ # download is complete. This prevents a local working checkpoint
76
+ # being overridden by a broken download.
77
+ dst = os.path.expanduser(dst)
78
+ dst_dir = os.path.dirname(dst)
79
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
80
+
81
+ try:
82
+ with tqdm(total=file_size, disable=not progress,
83
+ unit='B', unit_scale=True, unit_divisor=1024) as pbar:
84
+ while True:
85
+ buffer = u.read(8192)
86
+ if len(buffer) == 0:
87
+ break
88
+ f.write(buffer)
89
+ pbar.update(len(buffer))
90
+
91
+ f.close()
92
+ shutil.move(f.name, dst)
93
+ finally:
94
+ f.close()
95
+ if os.path.exists(f.name):
96
+ os.remove(f.name)
97
+
98
+ if os.name!="nt":
99
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.so'))
100
+ else:
101
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
102
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll",dst=osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
103
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
104
+ download_url_to_file(url="https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll",dst=osp.join(osp.dirname(__file__), 'opencv_world460.dll'))
105
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'libpatchmatch.dll')):
106
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/libpatchmatch.dll and put it into the PyPatchMatch folder")
107
+ if not os.path.exists(osp.join(osp.dirname(__file__), 'opencv_world460.dll')):
108
+ print("[Dependency Missing] Please download https://github.com/lkwq007/PyPatchMatch/releases/download/v0.1/opencv_world460.dll and put it into the PyPatchMatch folder")
109
+ PMLIB = ctypes.CDLL(osp.join(osp.dirname(__file__), 'libpatchmatch.dll'))
110
+
111
+ PMLIB.PM_set_random_seed.argtypes = [ctypes.c_uint]
112
+ PMLIB.PM_set_verbose.argtypes = [ctypes.c_int]
113
+ PMLIB.PM_free_pymat.argtypes = [CMatT]
114
+ PMLIB.PM_inpaint.argtypes = [CMatT, CMatT, ctypes.c_int]
115
+ PMLIB.PM_inpaint.restype = CMatT
116
+ PMLIB.PM_inpaint_regularity.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
117
+ PMLIB.PM_inpaint_regularity.restype = CMatT
118
+ PMLIB.PM_inpaint2.argtypes = [CMatT, CMatT, CMatT, ctypes.c_int]
119
+ PMLIB.PM_inpaint2.restype = CMatT
120
+ PMLIB.PM_inpaint2_regularity.argtypes = [CMatT, CMatT, CMatT, CMatT, ctypes.c_int, ctypes.c_float]
121
+ PMLIB.PM_inpaint2_regularity.restype = CMatT
122
+
123
+
124
+ def set_random_seed(seed: int):
125
+ PMLIB.PM_set_random_seed(ctypes.c_uint(seed))
126
+
127
+
128
+ def set_verbose(verbose: bool):
129
+ PMLIB.PM_set_verbose(ctypes.c_int(verbose))
130
+
131
+
132
+ def inpaint(
133
+ image: Union[np.ndarray, Image.Image],
134
+ mask: Optional[Union[np.ndarray, Image.Image]] = None,
135
+ *,
136
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
137
+ patch_size: int = 15
138
+ ) -> np.ndarray:
139
+ """
140
+ PatchMatch based inpainting proposed in:
141
+
142
+ PatchMatch : A Randomized Correspondence Algorithm for Structural Image Editing
143
+ C.Barnes, E.Shechtman, A.Finkelstein and Dan B.Goldman
144
+ SIGGRAPH 2009
145
+
146
+ Args:
147
+ image (Union[np.ndarray, Image.Image]): the input image, should be 3-channel RGB/BGR.
148
+ mask (Union[np.array, Image.Image], optional): the mask of the hole(s) to be filled, should be 1-channel.
149
+ If not provided (None), the algorithm will treat all purely white pixels as the holes (255, 255, 255).
150
+ global_mask (Union[np.array, Image.Image], optional): the target mask of the output image.
151
+ patch_size (int): the patch size for the inpainting algorithm.
152
+
153
+ Return:
154
+ result (np.ndarray): the repaired image, of the same size as the input image.
155
+ """
156
+
157
+ if isinstance(image, Image.Image):
158
+ image = np.array(image)
159
+ image = np.ascontiguousarray(image)
160
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
161
+
162
+ if mask is None:
163
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
164
+ mask = np.ascontiguousarray(mask)
165
+ else:
166
+ mask = _canonize_mask_array(mask)
167
+
168
+ if global_mask is None:
169
+ ret_pymat = PMLIB.PM_inpaint(np_to_pymat(image), np_to_pymat(mask), ctypes.c_int(patch_size))
170
+ else:
171
+ global_mask = _canonize_mask_array(global_mask)
172
+ ret_pymat = PMLIB.PM_inpaint2(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), ctypes.c_int(patch_size))
173
+
174
+ ret_npmat = pymat_to_np(ret_pymat)
175
+ PMLIB.PM_free_pymat(ret_pymat)
176
+
177
+ return ret_npmat
178
+
179
+
180
+ def inpaint_regularity(
181
+ image: Union[np.ndarray, Image.Image],
182
+ mask: Optional[Union[np.ndarray, Image.Image]],
183
+ ijmap: np.ndarray,
184
+ *,
185
+ global_mask: Optional[Union[np.ndarray, Image.Image]] = None,
186
+ patch_size: int = 15, guide_weight: float = 0.25
187
+ ) -> np.ndarray:
188
+ if isinstance(image, Image.Image):
189
+ image = np.array(image)
190
+ image = np.ascontiguousarray(image)
191
+
192
+ assert isinstance(ijmap, np.ndarray) and ijmap.ndim == 3 and ijmap.shape[2] == 3 and ijmap.dtype == 'float32'
193
+ ijmap = np.ascontiguousarray(ijmap)
194
+
195
+ assert image.ndim == 3 and image.shape[2] == 3 and image.dtype == 'uint8'
196
+ if mask is None:
197
+ mask = (image == (255, 255, 255)).all(axis=2, keepdims=True).astype('uint8')
198
+ mask = np.ascontiguousarray(mask)
199
+ else:
200
+ mask = _canonize_mask_array(mask)
201
+
202
+
203
+ if global_mask is None:
204
+ ret_pymat = PMLIB.PM_inpaint_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
205
+ else:
206
+ global_mask = _canonize_mask_array(global_mask)
207
+ ret_pymat = PMLIB.PM_inpaint2_regularity(np_to_pymat(image), np_to_pymat(mask), np_to_pymat(global_mask), np_to_pymat(ijmap), ctypes.c_int(patch_size), ctypes.c_float(guide_weight))
208
+
209
+ ret_npmat = pymat_to_np(ret_pymat)
210
+ PMLIB.PM_free_pymat(ret_pymat)
211
+
212
+ return ret_npmat
213
+
214
+
215
+ def _canonize_mask_array(mask):
216
+ if isinstance(mask, Image.Image):
217
+ mask = np.array(mask)
218
+ if mask.ndim == 2 and mask.dtype == 'uint8':
219
+ mask = mask[..., np.newaxis]
220
+ assert mask.ndim == 3 and mask.shape[2] == 1 and mask.dtype == 'uint8'
221
+ return np.ascontiguousarray(mask)
222
+
223
+
224
+ dtype_pymat_to_ctypes = [
225
+ ctypes.c_uint8,
226
+ ctypes.c_int8,
227
+ ctypes.c_uint16,
228
+ ctypes.c_int16,
229
+ ctypes.c_int32,
230
+ ctypes.c_float,
231
+ ctypes.c_double,
232
+ ]
233
+
234
+
235
+ dtype_np_to_pymat = {
236
+ 'uint8': 0,
237
+ 'int8': 1,
238
+ 'uint16': 2,
239
+ 'int16': 3,
240
+ 'int32': 4,
241
+ 'float32': 5,
242
+ 'float64': 6,
243
+ }
244
+
245
+
246
+ def np_to_pymat(npmat):
247
+ assert npmat.ndim == 3
248
+ return CMatT(
249
+ ctypes.cast(npmat.ctypes.data, ctypes.c_void_p),
250
+ CShapeT(npmat.shape[1], npmat.shape[0], npmat.shape[2]),
251
+ dtype_np_to_pymat[str(npmat.dtype)]
252
+ )
253
+
254
+
255
+ def pymat_to_np(pymat):
256
+ npmat = np.ctypeslib.as_array(
257
+ ctypes.cast(pymat.data_ptr, ctypes.POINTER(dtype_pymat_to_ctypes[pymat.dtype])),
258
+ (pymat.shape.height, pymat.shape.width, pymat.shape.channels)
259
+ )
260
+ ret = np.empty(npmat.shape, npmat.dtype)
261
+ ret[:] = npmat
262
+ return ret
263
+
PyPatchMatch/travis.sh CHANGED
@@ -1,9 +1,9 @@
1
- #! /bin/bash
2
- #
3
- # travis.sh
4
- # Copyright (C) 2020 Jiayuan Mao <[email protected]>
5
- #
6
- # Distributed under terms of the MIT license.
7
- #
8
-
9
- make clean && make
 
1
+ #! /bin/bash
2
+ #
3
+ # travis.sh
4
+ # Copyright (C) 2020 Jiayuan Mao <[email protected]>
5
+ #
6
+ # Distributed under terms of the MIT license.
7
+ #
8
+
9
+ make clean && make
config.yaml ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ shortcut:
2
+ clear: Escape
3
+ load: Ctrl+o
4
+ save: Ctrl+s
5
+ export: Ctrl+e
6
+ upload: Ctrl+u
7
+ selection: 1
8
+ canvas: 2
9
+ eraser: 3
10
+ outpaint: d
11
+ accept: a
12
+ cancel: c
13
+ retry: r
14
+ prev: q
15
+ next: e
16
+ zoom_in: z
17
+ zoom_out: x
18
+ random_seed: s
convert_checkpoint.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py
16
+ """ Conversion script for the LDM checkpoints. """
17
+
18
+ import argparse
19
+ import os
20
+
21
+ import torch
22
+
23
+
24
+ try:
25
+ from omegaconf import OmegaConf
26
+ except ImportError:
27
+ raise ImportError(
28
+ "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
29
+ )
30
+
31
+ from diffusers import (
32
+ AutoencoderKL,
33
+ DDIMScheduler,
34
+ LDMTextToImagePipeline,
35
+ LMSDiscreteScheduler,
36
+ PNDMScheduler,
37
+ StableDiffusionPipeline,
38
+ UNet2DConditionModel,
39
+ )
40
+ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig, LDMBertModel
41
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
42
+ from transformers import AutoFeatureExtractor, BertTokenizerFast, CLIPTextModel, CLIPTokenizer
43
+
44
+
45
+ def shave_segments(path, n_shave_prefix_segments=1):
46
+ """
47
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
48
+ """
49
+ if n_shave_prefix_segments >= 0:
50
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
51
+ else:
52
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
53
+
54
+
55
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
56
+ """
57
+ Updates paths inside resnets to the new naming scheme (local renaming)
58
+ """
59
+ mapping = []
60
+ for old_item in old_list:
61
+ new_item = old_item.replace("in_layers.0", "norm1")
62
+ new_item = new_item.replace("in_layers.2", "conv1")
63
+
64
+ new_item = new_item.replace("out_layers.0", "norm2")
65
+ new_item = new_item.replace("out_layers.3", "conv2")
66
+
67
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
68
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
69
+
70
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
71
+
72
+ mapping.append({"old": old_item, "new": new_item})
73
+
74
+ return mapping
75
+
76
+
77
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
78
+ """
79
+ Updates paths inside resnets to the new naming scheme (local renaming)
80
+ """
81
+ mapping = []
82
+ for old_item in old_list:
83
+ new_item = old_item
84
+
85
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
86
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
87
+
88
+ mapping.append({"old": old_item, "new": new_item})
89
+
90
+ return mapping
91
+
92
+
93
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
94
+ """
95
+ Updates paths inside attentions to the new naming scheme (local renaming)
96
+ """
97
+ mapping = []
98
+ for old_item in old_list:
99
+ new_item = old_item
100
+
101
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
102
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
103
+
104
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
105
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
106
+
107
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
108
+
109
+ mapping.append({"old": old_item, "new": new_item})
110
+
111
+ return mapping
112
+
113
+
114
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
115
+ """
116
+ Updates paths inside attentions to the new naming scheme (local renaming)
117
+ """
118
+ mapping = []
119
+ for old_item in old_list:
120
+ new_item = old_item
121
+
122
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
123
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
124
+
125
+ new_item = new_item.replace("q.weight", "query.weight")
126
+ new_item = new_item.replace("q.bias", "query.bias")
127
+
128
+ new_item = new_item.replace("k.weight", "key.weight")
129
+ new_item = new_item.replace("k.bias", "key.bias")
130
+
131
+ new_item = new_item.replace("v.weight", "value.weight")
132
+ new_item = new_item.replace("v.bias", "value.bias")
133
+
134
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
135
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
136
+
137
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
138
+
139
+ mapping.append({"old": old_item, "new": new_item})
140
+
141
+ return mapping
142
+
143
+
144
+ def assign_to_checkpoint(
145
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
146
+ ):
147
+ """
148
+ This does the final conversion step: take locally converted weights and apply a global renaming
149
+ to them. It splits attention layers, and takes into account additional replacements
150
+ that may arise.
151
+
152
+ Assigns the weights to the new checkpoint.
153
+ """
154
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
155
+
156
+ # Splits the attention layers into three variables.
157
+ if attention_paths_to_split is not None:
158
+ for path, path_map in attention_paths_to_split.items():
159
+ old_tensor = old_checkpoint[path]
160
+ channels = old_tensor.shape[0] // 3
161
+
162
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
163
+
164
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
165
+
166
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
167
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
168
+
169
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
170
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
171
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
172
+
173
+ for path in paths:
174
+ new_path = path["new"]
175
+
176
+ # These have already been assigned
177
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
178
+ continue
179
+
180
+ # Global renaming happens here
181
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
182
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
183
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
184
+
185
+ if additional_replacements is not None:
186
+ for replacement in additional_replacements:
187
+ new_path = new_path.replace(replacement["old"], replacement["new"])
188
+
189
+ # proj_attn.weight has to be converted from conv 1D to linear
190
+ if "proj_attn.weight" in new_path:
191
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
192
+ else:
193
+ checkpoint[new_path] = old_checkpoint[path["old"]]
194
+
195
+
196
+ def conv_attn_to_linear(checkpoint):
197
+ keys = list(checkpoint.keys())
198
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
199
+ for key in keys:
200
+ if ".".join(key.split(".")[-2:]) in attn_keys:
201
+ if checkpoint[key].ndim > 2:
202
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
203
+ elif "proj_attn.weight" in key:
204
+ if checkpoint[key].ndim > 2:
205
+ checkpoint[key] = checkpoint[key][:, :, 0]
206
+
207
+
208
+ def create_unet_diffusers_config(original_config):
209
+ """
210
+ Creates a config for the diffusers based on the config of the LDM model.
211
+ """
212
+ unet_params = original_config.model.params.unet_config.params
213
+
214
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
215
+
216
+ down_block_types = []
217
+ resolution = 1
218
+ for i in range(len(block_out_channels)):
219
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
220
+ down_block_types.append(block_type)
221
+ if i != len(block_out_channels) - 1:
222
+ resolution *= 2
223
+
224
+ up_block_types = []
225
+ for i in range(len(block_out_channels)):
226
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
227
+ up_block_types.append(block_type)
228
+ resolution //= 2
229
+
230
+ config = dict(
231
+ sample_size=unet_params.image_size,
232
+ in_channels=unet_params.in_channels,
233
+ out_channels=unet_params.out_channels,
234
+ down_block_types=tuple(down_block_types),
235
+ up_block_types=tuple(up_block_types),
236
+ block_out_channels=tuple(block_out_channels),
237
+ layers_per_block=unet_params.num_res_blocks,
238
+ cross_attention_dim=unet_params.context_dim,
239
+ attention_head_dim=unet_params.num_heads,
240
+ )
241
+
242
+ return config
243
+
244
+
245
+ def create_vae_diffusers_config(original_config):
246
+ """
247
+ Creates a config for the diffusers based on the config of the LDM model.
248
+ """
249
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
250
+ _ = original_config.model.params.first_stage_config.params.embed_dim
251
+
252
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
253
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
254
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
255
+
256
+ config = dict(
257
+ sample_size=vae_params.resolution,
258
+ in_channels=vae_params.in_channels,
259
+ out_channels=vae_params.out_ch,
260
+ down_block_types=tuple(down_block_types),
261
+ up_block_types=tuple(up_block_types),
262
+ block_out_channels=tuple(block_out_channels),
263
+ latent_channels=vae_params.z_channels,
264
+ layers_per_block=vae_params.num_res_blocks,
265
+ )
266
+ return config
267
+
268
+
269
+ def create_diffusers_schedular(original_config):
270
+ schedular = DDIMScheduler(
271
+ num_train_timesteps=original_config.model.params.timesteps,
272
+ beta_start=original_config.model.params.linear_start,
273
+ beta_end=original_config.model.params.linear_end,
274
+ beta_schedule="scaled_linear",
275
+ )
276
+ return schedular
277
+
278
+
279
+ def create_ldm_bert_config(original_config):
280
+ bert_params = original_config.model.parms.cond_stage_config.params
281
+ config = LDMBertConfig(
282
+ d_model=bert_params.n_embed,
283
+ encoder_layers=bert_params.n_layer,
284
+ encoder_ffn_dim=bert_params.n_embed * 4,
285
+ )
286
+ return config
287
+
288
+
289
+ def convert_ldm_unet_checkpoint(checkpoint, config):
290
+ """
291
+ Takes a state dict and a config, and returns a converted checkpoint.
292
+ """
293
+
294
+ # extract state_dict for UNet
295
+ unet_state_dict = {}
296
+ unet_key = "model.diffusion_model."
297
+ keys = list(checkpoint.keys())
298
+ for key in keys:
299
+ if key.startswith(unet_key):
300
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
301
+
302
+ new_checkpoint = {}
303
+
304
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
305
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
306
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
307
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
308
+
309
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
310
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
311
+
312
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
313
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
314
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
315
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
316
+
317
+ # Retrieves the keys for the input blocks only
318
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
319
+ input_blocks = {
320
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
321
+ for layer_id in range(num_input_blocks)
322
+ }
323
+
324
+ # Retrieves the keys for the middle blocks only
325
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
326
+ middle_blocks = {
327
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
328
+ for layer_id in range(num_middle_blocks)
329
+ }
330
+
331
+ # Retrieves the keys for the output blocks only
332
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
333
+ output_blocks = {
334
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
335
+ for layer_id in range(num_output_blocks)
336
+ }
337
+
338
+ for i in range(1, num_input_blocks):
339
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
340
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
341
+
342
+ resnets = [
343
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
344
+ ]
345
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
346
+
347
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
348
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
349
+ f"input_blocks.{i}.0.op.weight"
350
+ )
351
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
352
+ f"input_blocks.{i}.0.op.bias"
353
+ )
354
+
355
+ paths = renew_resnet_paths(resnets)
356
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
357
+ assign_to_checkpoint(
358
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
+ )
360
+
361
+ if len(attentions):
362
+ paths = renew_attention_paths(attentions)
363
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
364
+ assign_to_checkpoint(
365
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
366
+ )
367
+
368
+ resnet_0 = middle_blocks[0]
369
+ attentions = middle_blocks[1]
370
+ resnet_1 = middle_blocks[2]
371
+
372
+ resnet_0_paths = renew_resnet_paths(resnet_0)
373
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
374
+
375
+ resnet_1_paths = renew_resnet_paths(resnet_1)
376
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
377
+
378
+ attentions_paths = renew_attention_paths(attentions)
379
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
380
+ assign_to_checkpoint(
381
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
382
+ )
383
+
384
+ for i in range(num_output_blocks):
385
+ block_id = i // (config["layers_per_block"] + 1)
386
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
387
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
388
+ output_block_list = {}
389
+
390
+ for layer in output_block_layers:
391
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
392
+ if layer_id in output_block_list:
393
+ output_block_list[layer_id].append(layer_name)
394
+ else:
395
+ output_block_list[layer_id] = [layer_name]
396
+
397
+ if len(output_block_list) > 1:
398
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
399
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
400
+
401
+ resnet_0_paths = renew_resnet_paths(resnets)
402
+ paths = renew_resnet_paths(resnets)
403
+
404
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
405
+ assign_to_checkpoint(
406
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
407
+ )
408
+
409
+ if ["conv.weight", "conv.bias"] in output_block_list.values():
410
+ index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
411
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
412
+ f"output_blocks.{i}.{index}.conv.weight"
413
+ ]
414
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
415
+ f"output_blocks.{i}.{index}.conv.bias"
416
+ ]
417
+
418
+ # Clear attentions as they have been attributed above.
419
+ if len(attentions) == 2:
420
+ attentions = []
421
+
422
+ if len(attentions):
423
+ paths = renew_attention_paths(attentions)
424
+ meta_path = {
425
+ "old": f"output_blocks.{i}.1",
426
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
427
+ }
428
+ assign_to_checkpoint(
429
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
430
+ )
431
+ else:
432
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
433
+ for path in resnet_0_paths:
434
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
435
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
436
+
437
+ new_checkpoint[new_path] = unet_state_dict[old_path]
438
+
439
+ return new_checkpoint
440
+
441
+
442
+ def convert_ldm_vae_checkpoint(checkpoint, config):
443
+ # extract state dict for VAE
444
+ vae_state_dict = {}
445
+ vae_key = "first_stage_model."
446
+ keys = list(checkpoint.keys())
447
+ for key in keys:
448
+ if key.startswith(vae_key):
449
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
450
+
451
+ new_checkpoint = {}
452
+
453
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
454
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
455
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
456
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
457
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
458
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
459
+
460
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
461
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
462
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
463
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
464
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
465
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
466
+
467
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
468
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
469
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
470
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
471
+
472
+ # Retrieves the keys for the encoder down blocks only
473
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
474
+ down_blocks = {
475
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
476
+ }
477
+
478
+ # Retrieves the keys for the decoder up blocks only
479
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
480
+ up_blocks = {
481
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
482
+ }
483
+
484
+ for i in range(num_down_blocks):
485
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
486
+
487
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
488
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
489
+ f"encoder.down.{i}.downsample.conv.weight"
490
+ )
491
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
492
+ f"encoder.down.{i}.downsample.conv.bias"
493
+ )
494
+
495
+ paths = renew_vae_resnet_paths(resnets)
496
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
497
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
498
+
499
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
500
+ num_mid_res_blocks = 2
501
+ for i in range(1, num_mid_res_blocks + 1):
502
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
503
+
504
+ paths = renew_vae_resnet_paths(resnets)
505
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
506
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
507
+
508
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
509
+ paths = renew_vae_attention_paths(mid_attentions)
510
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
511
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
512
+ conv_attn_to_linear(new_checkpoint)
513
+
514
+ for i in range(num_up_blocks):
515
+ block_id = num_up_blocks - 1 - i
516
+ resnets = [
517
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
518
+ ]
519
+
520
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
521
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
522
+ f"decoder.up.{block_id}.upsample.conv.weight"
523
+ ]
524
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
525
+ f"decoder.up.{block_id}.upsample.conv.bias"
526
+ ]
527
+
528
+ paths = renew_vae_resnet_paths(resnets)
529
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
530
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
531
+
532
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
533
+ num_mid_res_blocks = 2
534
+ for i in range(1, num_mid_res_blocks + 1):
535
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
536
+
537
+ paths = renew_vae_resnet_paths(resnets)
538
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
539
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
540
+
541
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
542
+ paths = renew_vae_attention_paths(mid_attentions)
543
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
544
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
545
+ conv_attn_to_linear(new_checkpoint)
546
+ return new_checkpoint
547
+
548
+
549
+ def convert_ldm_bert_checkpoint(checkpoint, config):
550
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
551
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
552
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
553
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
554
+
555
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
556
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
557
+
558
+ def _copy_linear(hf_linear, pt_linear):
559
+ hf_linear.weight = pt_linear.weight
560
+ hf_linear.bias = pt_linear.bias
561
+
562
+ def _copy_layer(hf_layer, pt_layer):
563
+ # copy layer norms
564
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
565
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
566
+
567
+ # copy attn
568
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
569
+
570
+ # copy MLP
571
+ pt_mlp = pt_layer[1][1]
572
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
573
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
574
+
575
+ def _copy_layers(hf_layers, pt_layers):
576
+ for i, hf_layer in enumerate(hf_layers):
577
+ if i != 0:
578
+ i += i
579
+ pt_layer = pt_layers[i : i + 2]
580
+ _copy_layer(hf_layer, pt_layer)
581
+
582
+ hf_model = LDMBertModel(config).eval()
583
+
584
+ # copy embeds
585
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
586
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
587
+
588
+ # copy layer norm
589
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
590
+
591
+ # copy hidden layers
592
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
593
+
594
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
595
+
596
+ return hf_model
597
+
598
+
599
+ def convert_ldm_clip_checkpoint(checkpoint):
600
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
601
+
602
+ keys = list(checkpoint.keys())
603
+
604
+ text_model_dict = {}
605
+
606
+ for key in keys:
607
+ if key.startswith("cond_stage_model.transformer"):
608
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
609
+
610
+ text_model.load_state_dict(text_model_dict)
611
+
612
+ return text_model
613
+
614
+ import os
615
+ def convert_checkpoint(checkpoint_path, inpainting=False):
616
+ parser = argparse.ArgumentParser()
617
+
618
+ parser.add_argument(
619
+ "--checkpoint_path", default=checkpoint_path, type=str, help="Path to the checkpoint to convert."
620
+ )
621
+ # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
622
+ parser.add_argument(
623
+ "--original_config_file",
624
+ default=None,
625
+ type=str,
626
+ help="The YAML config file corresponding to the original architecture.",
627
+ )
628
+ parser.add_argument(
629
+ "--scheduler_type",
630
+ default="pndm",
631
+ type=str,
632
+ help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
633
+ )
634
+ parser.add_argument("--dump_path", default=None, type=str, help="Path to the output model.")
635
+
636
+ args = parser.parse_args([])
637
+ if args.original_config_file is None:
638
+ if inpainting:
639
+ args.original_config_file = "./models/v1-inpainting-inference.yaml"
640
+ else:
641
+ args.original_config_file = "./models/v1-inference.yaml"
642
+
643
+ original_config = OmegaConf.load(args.original_config_file)
644
+ checkpoint = torch.load(args.checkpoint_path)["state_dict"]
645
+
646
+ num_train_timesteps = original_config.model.params.timesteps
647
+ beta_start = original_config.model.params.linear_start
648
+ beta_end = original_config.model.params.linear_end
649
+ if args.scheduler_type == "pndm":
650
+ scheduler = PNDMScheduler(
651
+ beta_end=beta_end,
652
+ beta_schedule="scaled_linear",
653
+ beta_start=beta_start,
654
+ num_train_timesteps=num_train_timesteps,
655
+ skip_prk_steps=True,
656
+ )
657
+ elif args.scheduler_type == "lms":
658
+ scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
659
+ elif args.scheduler_type == "ddim":
660
+ scheduler = DDIMScheduler(
661
+ beta_start=beta_start,
662
+ beta_end=beta_end,
663
+ beta_schedule="scaled_linear",
664
+ clip_sample=False,
665
+ set_alpha_to_one=False,
666
+ )
667
+ else:
668
+ raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
669
+
670
+ # Convert the UNet2DConditionModel model.
671
+ unet_config = create_unet_diffusers_config(original_config)
672
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
673
+
674
+ unet = UNet2DConditionModel(**unet_config)
675
+ unet.load_state_dict(converted_unet_checkpoint)
676
+
677
+ # Convert the VAE model.
678
+ vae_config = create_vae_diffusers_config(original_config)
679
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
680
+
681
+ vae = AutoencoderKL(**vae_config)
682
+ vae.load_state_dict(converted_vae_checkpoint)
683
+
684
+ # Convert the text model.
685
+ text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
686
+ if text_model_type == "FrozenCLIPEmbedder":
687
+ text_model = convert_ldm_clip_checkpoint(checkpoint)
688
+ tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
689
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
690
+ feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
691
+ pipe = StableDiffusionPipeline(
692
+ vae=vae,
693
+ text_encoder=text_model,
694
+ tokenizer=tokenizer,
695
+ unet=unet,
696
+ scheduler=scheduler,
697
+ safety_checker=safety_checker,
698
+ feature_extractor=feature_extractor,
699
+ )
700
+ else:
701
+ text_config = create_ldm_bert_config(original_config)
702
+ text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
703
+ tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
704
+ pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
705
+
706
+ return pipe
css/w2ui.min.css ADDED
The diff for this file is too large to render. See raw diff
 
js/fabric.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/keyboard.js ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ window.my_setup_keyboard=setInterval(function(){
3
+ let app=document.querySelector("gradio-app");
4
+ app=app.shadowRoot??app;
5
+ let frame=app.querySelector("#sdinfframe").contentWindow;
6
+ console.log("Check iframe...");
7
+ if(frame.setup_shortcut)
8
+ {
9
+ frame.setup_shortcut(json);
10
+ clearInterval(window.my_setup_keyboard);
11
+ }
12
+ }, 1000);
13
+ var config=JSON.parse(json);
14
+ var key_map={};
15
+ Object.keys(config.shortcut).forEach(k=>{
16
+ key_map[config.shortcut[k]]=k;
17
+ });
18
+ document.addEventListener("keydown", e => {
19
+ if(e.target.tagName!="INPUT"&&e.target.tagName!="GRADIO-APP"&&e.target.tagName!="TEXTAREA")
20
+ {
21
+ let key=e.key;
22
+ if(e.ctrlKey)
23
+ {
24
+ key="Ctrl+"+e.key;
25
+ if(key in key_map)
26
+ {
27
+ e.preventDefault();
28
+ }
29
+ }
30
+ let app=document.querySelector("gradio-app");
31
+ app=app.shadowRoot??app;
32
+ let frame=app.querySelector("#sdinfframe").contentDocument;
33
+ frame.dispatchEvent(
34
+ new KeyboardEvent("keydown", {key: e.key, ctrlKey: e.ctrlKey})
35
+ );
36
+ }
37
+ })
js/mode.js CHANGED
@@ -1,6 +1,6 @@
1
- function(mode){
2
- let app=document.querySelector("gradio-app");
3
- let frame=app.querySelector("#sdinfframe").contentWindow;
4
- frame.postMessage(["mode", mode], "*");
5
- return mode;
6
  }
 
1
+ function(mode){
2
+ let app=document.querySelector("gradio-app").shadowRoot;
3
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
4
+ frame.querySelector("#mode").value=mode;
5
+ return mode;
6
  }
js/outpaint.js CHANGED
@@ -1,31 +1,23 @@
1
- function(a){
2
- if(!window.my_observe_outpaint)
3
- {
4
- console.log("setup outpaint here");
5
- window.my_observe_outpaint = new MutationObserver(function (event) {
6
- console.log(event);
7
- let app=document.querySelector("gradio-app");
8
- let frame=app.querySelector("#sdinfframe").contentWindow;
9
- var str=document.querySelector("gradio-app").querySelector("#output textarea").value;
10
- frame.postMessage(["outpaint", str], "*");
11
- });
12
- window.my_observe_outpaint_target=document.querySelector("gradio-app").querySelector("#output span")
13
- window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
14
- attributes: false,
15
- subtree: true,
16
- childList: true,
17
- characterData: true
18
- });
19
- window.addEventListener("message", function(e){
20
- if(e.data[0]=="transfer")
21
- {
22
- document.querySelector("gradio-app").querySelector("#input textarea").value=e.data[1];
23
- document.querySelector("gradio-app").querySelector("#proceed").click();
24
- }
25
- });
26
- }
27
- let app=document.querySelector("gradio-app");
28
- let frame=app.querySelector("#sdinfframe").contentWindow;
29
- frame.postMessage(["transfer"],"*")
30
- return a;
31
  }
 
1
+ function(a){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ return a;
 
 
 
 
 
 
 
 
23
  }
js/proceed.js CHANGED
@@ -1,22 +1,42 @@
1
- function(sel_buffer_str,
2
- prompt_text,
3
- strength,
4
- guidance,
5
- step,
6
- resize_check,
7
- fill_mode,
8
- enable_safety,
9
- state){
10
- sel_buffer = document.querySelector("gradio-app").querySelector("#input textarea").value;
11
- return [
12
- sel_buffer,
13
- prompt_text,
14
- strength,
15
- guidance,
16
- step,
17
- resize_check,
18
- fill_mode,
19
- enable_safety,
20
- state
21
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  }
 
1
+ function(sel_buffer_str,
2
+ prompt_text,
3
+ negative_prompt_text,
4
+ strength,
5
+ guidance,
6
+ step,
7
+ resize_check,
8
+ fill_mode,
9
+ enable_safety,
10
+ use_correction,
11
+ enable_img2img,
12
+ use_seed,
13
+ seed_val,
14
+ generate_num,
15
+ scheduler,
16
+ scheduler_eta,
17
+ state){
18
+ let app=document.querySelector("gradio-app");
19
+ app=app.shadowRoot??app;
20
+ sel_buffer=app.querySelector("#input textarea").value;
21
+ let use_correction_bak=false;
22
+ ({resize_check,enable_safety,use_correction_bak,enable_img2img,use_seed,seed_val}=window.config_obj);
23
+ return [
24
+ sel_buffer,
25
+ prompt_text,
26
+ negative_prompt_text,
27
+ strength,
28
+ guidance,
29
+ step,
30
+ resize_check,
31
+ fill_mode,
32
+ enable_safety,
33
+ use_correction,
34
+ enable_img2img,
35
+ use_seed,
36
+ seed_val,
37
+ generate_num,
38
+ scheduler,
39
+ scheduler_eta,
40
+ state,
41
+ ]
42
  }
js/setup.js CHANGED
@@ -1,22 +1,28 @@
1
- function(token_val, width, height, size){
2
- let app=document.querySelector("gradio-app");
3
- app.querySelector("#sdinfframe").style.height=height+"px";
4
- let frame=app.querySelector("#sdinfframe").contentWindow.document;
5
- if(frame.querySelector("#setup").value=="0")
6
- {
7
- window.my_setup=setInterval(function(){
8
- let frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow.document;
9
- console.log("Check PyScript...")
10
- if(frame.querySelector("#setup").value=="1")
11
- {
12
- frame.querySelector("#draw").click();
13
- clearInterval(window.my_setup);
14
- }
15
- },100)
16
- }
17
- else
18
- {
19
- frame.querySelector("#draw").click();
20
- }
21
- return [token_val, width, height, size];
 
 
 
 
 
 
22
  }
 
1
+ function(token_val, width, height, size, model_choice, model_path){
2
+ let app=document.querySelector("gradio-app");
3
+ app=app.shadowRoot??app;
4
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
5
+ // app.querySelector("#setup_row").style.display="none";
6
+ app.querySelector("#model_path_input").style.display="none";
7
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
8
+
9
+ if(frame.querySelector("#setup").value=="0")
10
+ {
11
+ window.my_setup=setInterval(function(){
12
+ let app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ let frame=app.querySelector("#sdinfframe").contentWindow.document;
15
+ console.log("Check PyScript...")
16
+ if(frame.querySelector("#setup").value=="1")
17
+ {
18
+ frame.querySelector("#draw").click();
19
+ clearInterval(window.my_setup);
20
+ }
21
+ }, 100)
22
+ }
23
+ else
24
+ {
25
+ frame.querySelector("#draw").click();
26
+ }
27
+ return [token_val, width, height, size, model_choice, model_path];
28
  }
js/toolbar.js ADDED
@@ -0,0 +1,581 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://rawgit.com/vitmalina/w2ui/master/dist/w2ui.es6.min.js"
2
+ // import { w2ui,w2toolbar,w2field,query,w2alert, w2utils,w2confirm} from "https://cdn.jsdelivr.net/gh/vitmalina/w2ui@master/dist/w2ui.es6.min.js"
3
+
4
+ // https://stackoverflow.com/questions/36280818/how-to-convert-file-to-base64-in-javascript
5
+ function getBase64(file) {
6
+ var reader = new FileReader();
7
+ reader.readAsDataURL(file);
8
+ reader.onload = function () {
9
+ add_image(reader.result);
10
+ // console.log(reader.result);
11
+ };
12
+ reader.onerror = function (error) {
13
+ console.log("Error: ", error);
14
+ };
15
+ }
16
+
17
+ function getText(file) {
18
+ var reader = new FileReader();
19
+ reader.readAsText(file);
20
+ reader.onload = function () {
21
+ window.postMessage(["load",reader.result],"*")
22
+ // console.log(reader.result);
23
+ };
24
+ reader.onerror = function (error) {
25
+ console.log("Error: ", error);
26
+ };
27
+ }
28
+
29
+ document.querySelector("#upload_file").addEventListener("change", (event)=>{
30
+ console.log(event);
31
+ let file = document.querySelector("#upload_file").files[0];
32
+ getBase64(file);
33
+ })
34
+
35
+ document.querySelector("#upload_state").addEventListener("change", (event)=>{
36
+ console.log(event);
37
+ let file = document.querySelector("#upload_state").files[0];
38
+ getText(file);
39
+ })
40
+
41
+ open_setting = function() {
42
+ if (!w2ui.foo) {
43
+ new w2form({
44
+ name: "foo",
45
+ style: "border: 0px; background-color: transparent;",
46
+ fields: [{
47
+ field: "canvas_width",
48
+ type: "int",
49
+ required: true,
50
+ html: {
51
+ label: "Canvas Width"
52
+ }
53
+ },
54
+ {
55
+ field: "canvas_height",
56
+ type: "int",
57
+ required: true,
58
+ html: {
59
+ label: "Canvas Height"
60
+ }
61
+ },
62
+ ],
63
+ record: {
64
+ canvas_width: 1200,
65
+ canvas_height: 600,
66
+ },
67
+ actions: {
68
+ Save() {
69
+ this.validate();
70
+ let record = this.getCleanRecord();
71
+ window.postMessage(["resize",record.canvas_width,record.canvas_height],"*");
72
+ w2popup.close();
73
+ },
74
+ custom: {
75
+ text: "Cancel",
76
+ style: "text-transform: uppercase",
77
+ onClick(event) {
78
+ w2popup.close();
79
+ }
80
+ }
81
+ }
82
+ });
83
+ }
84
+ w2popup.open({
85
+ title: "Form in a Popup",
86
+ body: "<div id='form' style='width: 100%; height: 100%;''></div>",
87
+ style: "padding: 15px 0px 0px 0px",
88
+ width: 500,
89
+ height: 280,
90
+ showMax: true,
91
+ async onToggle(event) {
92
+ await event.complete
93
+ w2ui.foo.resize();
94
+ }
95
+ })
96
+ .then((event) => {
97
+ w2ui.foo.render("#form")
98
+ });
99
+ }
100
+
101
+ var button_lst=["clear", "load", "save", "export", "upload", "selection", "canvas", "eraser", "outpaint", "accept", "cancel", "retry", "prev", "current", "next", "eraser_size_btn", "eraser_size", "resize_selection", "scale", "zoom_in", "zoom_out", "help"];
102
+ var upload_button_lst=['clear', 'load', 'save', "upload", 'export', 'outpaint', 'resize_selection', 'help', "setting"];
103
+ var resize_button_lst=['clear', 'load', 'save', "upload", 'export', "selection", "canvas", "eraser", 'outpaint', 'resize_selection',"zoom_in", "zoom_out", 'help', "setting"];
104
+ var outpaint_button_lst=['clear', 'load', 'save', "canvas", "eraser", "upload", 'export', 'resize_selection', "zoom_in", "zoom_out",'help', "setting"];
105
+ var outpaint_result_lst=["accept", "cancel", "retry", "prev", "current", "next"];
106
+ var outpaint_result_func_lst=["accept", "retry", "prev", "current", "next"];
107
+
108
+ function check_button(id,text="",checked=true,tooltip="")
109
+ {
110
+ return { type: "check", id: id, text: text, icon: checked?"fa-solid fa-square-check":"fa-regular fa-square", checked: checked, tooltip: tooltip };
111
+ }
112
+
113
+ var toolbar=new w2toolbar({
114
+ box: "#toolbar",
115
+ name: "toolbar",
116
+ tooltip: "top",
117
+ items: [
118
+ { type: "button", id: "clear", text: "Reset", tooltip: "Reset Canvas", icon: "fa-solid fa-rectangle-xmark" },
119
+ { type: "break" },
120
+ { type: "button", id: "load", tooltip: "Load Canvas", icon: "fa-solid fa-file-import" },
121
+ { type: "button", id: "save", tooltip: "Save Canvas", icon: "fa-solid fa-file-export" },
122
+ { type: "button", id: "export", tooltip: "Export Image", icon: "fa-solid fa-floppy-disk" },
123
+ { type: "break" },
124
+ { type: "button", id: "upload", text: "Upload Image", icon: "fa-solid fa-upload" },
125
+ { type: "break" },
126
+ { type: "radio", id: "selection", group: "1", tooltip: "Selection", icon: "fa-solid fa-arrows-up-down-left-right", checked: true },
127
+ { type: "radio", id: "canvas", group: "1", tooltip: "Canvas", icon: "fa-solid fa-image" },
128
+ { type: "radio", id: "eraser", group: "1", tooltip: "Eraser", icon: "fa-solid fa-eraser" },
129
+ { type: "break" },
130
+ { type: "button", id: "outpaint", text: "Outpaint", tooltip: "Run Outpainting", icon: "fa-solid fa-brush" },
131
+ { type: "break" },
132
+ { type: "button", id: "accept", text: "Accept", tooltip: "Accept current result", icon: "fa-solid fa-check", hidden: true, disable:true,},
133
+ { type: "button", id: "cancel", text: "Cancel", tooltip: "Cancel current outpainting/error", icon: "fa-solid fa-ban", hidden: true},
134
+ { type: "button", id: "retry", text: "Retry", tooltip: "Retry", icon: "fa-solid fa-rotate", hidden: true, disable:true,},
135
+ { type: "button", id: "prev", tooltip: "Prev Result", icon: "fa-solid fa-caret-left", hidden: true, disable:true,},
136
+ { type: "html", id: "current", hidden: true, disable:true,
137
+ async onRefresh(event) {
138
+ await event.complete
139
+ let fragment = query.html(`
140
+ <div class="w2ui-tb-text">
141
+ <div class="w2ui-tb-count">
142
+ <span>${this.sel_value ?? "1/1"}</span>
143
+ </div> </div>`)
144
+ query(this.box).find("#tb_toolbar_item_current").append(fragment)
145
+ }
146
+ },
147
+ { type: "button", id: "next", tooltip: "Next Result", icon: "fa-solid fa-caret-right", hidden: true,disable:true,},
148
+ { type: "button", id: "add_image", text: "Add Image", icon: "fa-solid fa-file-circle-plus", hidden: true,disable:true,},
149
+ { type: "button", id: "delete_image", text: "Delete Image", icon: "fa-solid fa-trash-can", hidden: true,disable:true,},
150
+ { type: "button", id: "confirm", text: "Confirm", icon: "fa-solid fa-check", hidden: true,disable:true,},
151
+ { type: "button", id: "cancel_overlay", text: "Cancel", icon: "fa-solid fa-ban", hidden: true,disable:true,},
152
+ { type: "break" },
153
+ { type: "spacer" },
154
+ { type: "break" },
155
+ { type: "button", id: "eraser_size_btn", tooltip: "Eraser Size", text:"Size", icon: "fa-solid fa-eraser", hidden: true, count: 32},
156
+ { type: "html", id: "eraser_size", hidden: true,
157
+ async onRefresh(event) {
158
+ await event.complete
159
+ // let fragment = query.html(`
160
+ // <input type="number" size="${this.eraser_size ? this.eraser_size.length:"2"}" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
161
+ // <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">`)
162
+ let fragment = query.html(`
163
+ <input type="range" style="margin: 0px 3px; padding: 4px;" min="8" max="${this.eraser_max ?? "256"}" value="${this.eraser_size ?? "32"}">
164
+ `)
165
+ fragment.filter("input").on("change", event => {
166
+ this.eraser_size = event.target.value;
167
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
168
+ this.setCount("eraser_size_btn", event.target.value);
169
+ window.postMessage(["eraser_size", event.target.value],"*")
170
+ this.refresh();
171
+ })
172
+ query(this.box).find("#tb_toolbar_item_eraser_size").append(fragment)
173
+ }
174
+ },
175
+ // { type: "button", id: "resize_eraser", tooltip: "Resize Eraser", icon: "fa-solid fa-sliders" },
176
+ { type: "button", id: "resize_selection", text: "Resize Selection", tooltip: "Resize Selection", icon: "fa-solid fa-expand" },
177
+ { type: "break" },
178
+ { type: "html", id: "scale",
179
+ async onRefresh(event) {
180
+ await event.complete
181
+ let fragment = query.html(`
182
+ <div class="">
183
+ <div style="padding: 4px; border: 1px solid silver">
184
+ <span>${this.scale_value ?? "100%"}</span>
185
+ </div></div>`)
186
+ query(this.box).find("#tb_toolbar_item_scale").append(fragment)
187
+ }
188
+ },
189
+ { type: "button", id: "zoom_in", tooltip: "Zoom In", icon: "fa-solid fa-magnifying-glass-plus" },
190
+ { type: "button", id: "zoom_out", tooltip: "Zoom Out", icon: "fa-solid fa-magnifying-glass-minus" },
191
+ { type: "break" },
192
+ { type: "button", id: "help", tooltip: "Help", icon: "fa-solid fa-circle-info" },
193
+ { type: "new-line"},
194
+ { type: "button", id: "setting", text: "Canvas Setting", tooltip: "Resize Canvas Here", icon: "fa-solid fa-sliders" },
195
+ { type: "break" },
196
+ check_button("enable_img2img","Enable Img2Img",false),
197
+ // check_button("use_correction","Photometric Correction",false),
198
+ check_button("resize_check","Resize Small Input",true),
199
+ check_button("enable_safety","Enable Safety Checker",true),
200
+ check_button("square_selection","Square Selection Only",false),
201
+ {type: "break"},
202
+ check_button("use_seed","Use Seed:",false),
203
+ { type: "html", id: "seed_val",
204
+ async onRefresh(event) {
205
+ await event.complete
206
+ let fragment = query.html(`
207
+ <input type="number" style="margin: 0px 3px; padding: 4px; width:100px;" value="${this.config_obj.seed_val ?? "0"}">`)
208
+ fragment.filter("input").on("change", event => {
209
+ this.config_obj.seed_val = event.target.value;
210
+ parent.config_obj=this.config_obj;
211
+ this.refresh();
212
+ })
213
+ query(this.box).find("#tb_toolbar_item_seed_val").append(fragment)
214
+ }
215
+ },
216
+ { type: "button", id: "random_seed", tooltip: "Set a random seed", icon: "fa-solid fa-dice" },
217
+ ],
218
+ onClick(event) {
219
+ switch(event.target){
220
+ case "setting":
221
+ open_setting();
222
+ break;
223
+ case "upload":
224
+ this.upload_mode=true
225
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
226
+ this.click("canvas");
227
+ this.click("selection");
228
+ this.show("confirm","cancel_overlay","add_image","delete_image");
229
+ this.enable("confirm","cancel_overlay","add_image","delete_image");
230
+ this.disable(...upload_button_lst);
231
+ query("#upload_file").click();
232
+ if(this.upload_tip)
233
+ {
234
+ this.upload_tip=false;
235
+ w2utils.notify("Note that only visible images will be added to canvas",{timeout:10000,where:query("#container")})
236
+ }
237
+ break;
238
+ case "resize_selection":
239
+ this.resize_mode=true;
240
+ this.disable(...resize_button_lst);
241
+ this.enable("confirm","cancel_overlay");
242
+ this.show("confirm","cancel_overlay");
243
+ window.postMessage(["resize_selection",""],"*");
244
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
245
+ break;
246
+ case "confirm":
247
+ if(this.upload_mode)
248
+ {
249
+ export_image();
250
+ }
251
+ else
252
+ {
253
+ let sel_box=this.selection_box;
254
+ window.postMessage(["resize_selection",sel_box.x,sel_box.y,sel_box.width,sel_box.height],"*");
255
+ }
256
+ case "cancel_overlay":
257
+ end_overlay();
258
+ this.hide("confirm","cancel_overlay","add_image","delete_image");
259
+ if(this.upload_mode){
260
+ this.enable(...upload_button_lst);
261
+ }
262
+ else
263
+ {
264
+ this.enable(...resize_button_lst);
265
+ window.postMessage(["resize_selection","",""],"*");
266
+ if(event.target=="cancel_overlay")
267
+ {
268
+ this.selection_box=this.selection_box_bak;
269
+ }
270
+ }
271
+ if(this.selection_box)
272
+ {
273
+ this.setCount("resize_selection",`${Math.floor(this.selection_box.width/8)*8}x${Math.floor(this.selection_box.height/8)*8}`);
274
+ }
275
+ this.disable("confirm","cancel_overlay","add_image","delete_image");
276
+ this.upload_mode=false;
277
+ this.resize_mode=false;
278
+ this.click("selection");
279
+ break;
280
+ case "add_image":
281
+ query("#upload_file").click();
282
+ break;
283
+ case "delete_image":
284
+ let active_obj = window.overlay.getActiveObject();
285
+ if(active_obj)
286
+ {
287
+ window.overlay.remove(active_obj);
288
+ window.overlay.renderAll();
289
+ }
290
+ else
291
+ {
292
+ w2utils.notify("You need to select an image first",{error:true,timeout:2000,where:query("#container")})
293
+ }
294
+ break;
295
+ case "load":
296
+ query("#upload_state").click();
297
+ this.selection_box=null;
298
+ this.setCount("resize_selection","");
299
+ break;
300
+ case "next":
301
+ case "prev":
302
+ window.postMessage(["outpaint", "", event.target], "*");
303
+ break;
304
+ case "outpaint":
305
+ this.click("selection");
306
+ this.disable(...outpaint_button_lst);
307
+ this.show(...outpaint_result_lst);
308
+ if(this.outpaint_tip)
309
+ {
310
+ this.outpaint_tip=false;
311
+ w2utils.notify("The canvas stays locked until you accept/cancel current outpainting",{timeout:10000,where:query("#container")})
312
+ }
313
+ document.querySelector("#container").style.pointerEvents="none";
314
+ case "retry":
315
+ this.disable(...outpaint_result_func_lst);
316
+ window.postMessage(["transfer",""],"*")
317
+ break;
318
+ case "accept":
319
+ case "cancel":
320
+ this.hide(...outpaint_result_lst);
321
+ this.disable(...outpaint_result_func_lst);
322
+ this.enable(...outpaint_button_lst);
323
+ document.querySelector("#container").style.pointerEvents="auto";
324
+ window.postMessage(["click", event.target],"*");
325
+ let app=parent.document.querySelector("gradio-app");
326
+ app=app.shadowRoot??app;
327
+ app.querySelector("#cancel").click();
328
+ break;
329
+ case "eraser":
330
+ case "selection":
331
+ case "canvas":
332
+ if(event.target=="eraser")
333
+ {
334
+ this.show("eraser_size","eraser_size_btn");
335
+ window.overlay.freeDrawingBrush.width=this.eraser_size;
336
+ window.overlay.isDrawingMode = true;
337
+ }
338
+ else
339
+ {
340
+ this.hide("eraser_size","eraser_size_btn");
341
+ window.overlay.isDrawingMode = false;
342
+ }
343
+ if(this.upload_mode)
344
+ {
345
+ if(event.target=="canvas")
346
+ {
347
+ window.postMessage(["mode", event.target],"*")
348
+ document.querySelector("#overlay_container").style.pointerEvents="none";
349
+ document.querySelector("#overlay_container").style.opacity = 0.5;
350
+ }
351
+ else
352
+ {
353
+ document.querySelector("#overlay_container").style.pointerEvents="auto";
354
+ document.querySelector("#overlay_container").style.opacity = 1.0;
355
+ }
356
+ }
357
+ else
358
+ {
359
+ window.postMessage(["mode", event.target],"*")
360
+ }
361
+ break;
362
+ case "help":
363
+ w2popup.open({
364
+ title: "Document",
365
+ body: "Usage: <a href='https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md' target='_blank'>https://github.com/lkwq007/stablediffusion-infinity/blob/master/docs/usage.md</a>"
366
+ })
367
+ break;
368
+ case "clear":
369
+ w2confirm("Reset canvas?").yes(() => {
370
+ window.postMessage(["click", event.target],"*");
371
+ }).no(() => {})
372
+ break;
373
+ case "random_seed":
374
+ this.config_obj.seed_val=Math.floor(Math.random() * 3000000000);
375
+ parent.config_obj=this.config_obj;
376
+ this.refresh();
377
+ break;
378
+ case "enable_img2img":
379
+ case "use_correction":
380
+ case "resize_check":
381
+ case "enable_safety":
382
+ case "use_seed":
383
+ case "square_selection":
384
+ let target=this.get(event.target);
385
+ target.icon=target.checked?"fa-regular fa-square":"fa-solid fa-square-check";
386
+ this.config_obj[event.target]=!target.checked;
387
+ parent.config_obj=this.config_obj;
388
+ this.refresh();
389
+ break;
390
+ case "save":
391
+ case "export":
392
+ ask_filename(event.target);
393
+ break;
394
+ default:
395
+ // clear, save, export, outpaint, retry
396
+ // break, save, export, accept, retry, outpaint
397
+ window.postMessage(["click", event.target],"*")
398
+ }
399
+ console.log("Target: "+ event.target, event)
400
+ }
401
+ })
402
+ window.w2ui=w2ui;
403
+ w2ui.toolbar.config_obj={
404
+ resize_check: true,
405
+ enable_safety: true,
406
+ use_correction: false,
407
+ enable_img2img: false,
408
+ use_seed: false,
409
+ seed_val: 0,
410
+ square_selection: false,
411
+ };
412
+ w2ui.toolbar.outpaint_tip=true;
413
+ w2ui.toolbar.upload_tip=true;
414
+ window.update_count=function(cur,total){
415
+ w2ui.toolbar.sel_value=`${cur}/${total}`;
416
+ w2ui.toolbar.refresh();
417
+ }
418
+ window.update_eraser=function(val,max_val){
419
+ w2ui.toolbar.eraser_size=`${val}`;
420
+ w2ui.toolbar.eraser_max=`${max_val}`;
421
+ w2ui.toolbar.setCount("eraser_size_btn", `${val}`);
422
+ w2ui.toolbar.refresh();
423
+ }
424
+ window.update_scale=function(val){
425
+ w2ui.toolbar.scale_value=`${val}`;
426
+ w2ui.toolbar.refresh();
427
+ }
428
+ window.enable_result_lst=function(){
429
+ w2ui.toolbar.enable(...outpaint_result_lst);
430
+ }
431
+ function onObjectScaled(e)
432
+ {
433
+ let object = e.target;
434
+ if(object.isType("rect"))
435
+ {
436
+ let width=object.getScaledWidth();
437
+ let height=object.getScaledHeight();
438
+ object.scale(1);
439
+ width=Math.max(Math.min(width,window.overlay.width-object.left),256);
440
+ height=Math.max(Math.min(height,window.overlay.height-object.top),256);
441
+ let l=Math.max(Math.min(object.left,window.overlay.width-width-object.strokeWidth),0);
442
+ let t=Math.max(Math.min(object.top,window.overlay.height-height-object.strokeWidth),0);
443
+ if(window.w2ui.toolbar.config_obj.square_selection)
444
+ {
445
+ let max_val = Math.min(Math.max(width,height),window.overlay.width,window.overlay.height);
446
+ width=max_val;
447
+ height=max_val;
448
+ }
449
+ object.set({ width: width, height: height, left:l,top:t})
450
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:object.left, y:object.top};
451
+ window.w2ui.toolbar.setCount("resize_selection",`${Math.floor(width/8)*8}x${Math.floor(height/8)*8}`);
452
+ window.w2ui.toolbar.refresh();
453
+ }
454
+ }
455
+ function onObjectMoved(e)
456
+ {
457
+ let object = e.target;
458
+ if(object.isType("rect"))
459
+ {
460
+ let l=Math.max(Math.min(object.left,window.overlay.width-object.width-object.strokeWidth),0);
461
+ let t=Math.max(Math.min(object.top,window.overlay.height-object.height-object.strokeWidth),0);
462
+ object.set({left:l,top:t});
463
+ window.w2ui.toolbar.selection_box={width: object.width, height: object.height, x:object.left, y:object.top};
464
+ }
465
+ }
466
+ window.setup_overlay=function(width,height)
467
+ {
468
+ if(window.overlay)
469
+ {
470
+ window.overlay.setDimensions({width:width,height:height});
471
+ let app=parent.document.querySelector("gradio-app");
472
+ app=app.shadowRoot??app;
473
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
474
+ document.querySelector("#container").style.height= height+"px";
475
+ document.querySelector("#container").style.width = width+"px";
476
+ }
477
+ else
478
+ {
479
+ canvas=new fabric.Canvas("overlay_canvas");
480
+ canvas.setDimensions({width:width,height:height});
481
+ let app=parent.document.querySelector("gradio-app");
482
+ app=app.shadowRoot??app;
483
+ app.querySelector("#sdinfframe").style.height=80+Number(height)+"px";
484
+ canvas.freeDrawingBrush = new fabric.EraserBrush(canvas);
485
+ canvas.on("object:scaling", onObjectScaled);
486
+ canvas.on("object:moving", onObjectMoved);
487
+ window.overlay=canvas;
488
+ }
489
+ document.querySelector("#overlay_container").style.pointerEvents="none";
490
+ }
491
+ window.update_overlay=function(width,height)
492
+ {
493
+ window.overlay.setDimensions({width:width,height:height},{backstoreOnly:true});
494
+ // document.querySelector("#overlay_container").style.pointerEvents="none";
495
+ }
496
+ window.adjust_selection=function(x,y,width,height)
497
+ {
498
+ var rect = new fabric.Rect({
499
+ left: x,
500
+ top: y,
501
+ fill: "rgba(0,0,0,0)",
502
+ strokeWidth: 3,
503
+ stroke: "rgba(0,0,0,0.7)",
504
+ cornerColor: "red",
505
+ cornerStrokeColor: "red",
506
+ borderColor: "rgba(255, 0, 0, 1.0)",
507
+ width: width,
508
+ height: height,
509
+ lockRotation: true,
510
+ });
511
+ rect.setControlsVisibility({ mtr: false });
512
+ window.overlay.add(rect);
513
+ window.overlay.setActiveObject(window.overlay.item(0));
514
+ window.w2ui.toolbar.selection_box={width: width, height: height, x:x, y:y};
515
+ window.w2ui.toolbar.selection_box_bak={width: width, height: height, x:x, y:y};
516
+ }
517
+ function add_image(url)
518
+ {
519
+ fabric.Image.fromURL(url,function(img){
520
+ window.overlay.add(img);
521
+ window.overlay.setActiveObject(img);
522
+ },{left:100,top:100});
523
+ }
524
+ function export_image()
525
+ {
526
+ data=window.overlay.toDataURL();
527
+ document.querySelector("#upload_content").value=data;
528
+ window.postMessage(["upload",""],"*");
529
+ end_overlay();
530
+ }
531
+ function end_overlay()
532
+ {
533
+ window.overlay.clear();
534
+ document.querySelector("#overlay_container").style.opacity = 1.0;
535
+ document.querySelector("#overlay_container").style.pointerEvents="none";
536
+ }
537
+ function ask_filename(target)
538
+ {
539
+ w2prompt({
540
+ label: "Enter filename",
541
+ value: `outpaint_${((new Date(Date.now() -(new Date()).getTimezoneOffset() * 60000))).toISOString().replace("T","_").replace(/[^0-9_]/g, "").substring(0,15)}`,
542
+ })
543
+ .change((event) => {
544
+ console.log("change", event.detail.originalEvent.target.value);
545
+ })
546
+ .ok((event) => {
547
+ console.log("value=", event.detail.value);
548
+ window.postMessage(["click",target,event.detail.value],"*");
549
+ })
550
+ .cancel((event) => {
551
+ console.log("cancel");
552
+ });
553
+ }
554
+
555
+ document.querySelector("#container").addEventListener("wheel",(e)=>{e.preventDefault()})
556
+ window.setup_shortcut=function(json)
557
+ {
558
+ var config=JSON.parse(json);
559
+ var key_map={};
560
+ Object.keys(config.shortcut).forEach(k=>{
561
+ key_map[config.shortcut[k]]=k;
562
+ })
563
+ document.addEventListener("keydown",(e)=>{
564
+ if(e.target.tagName!="INPUT")
565
+ {
566
+ let key=e.key;
567
+ if(e.ctrlKey)
568
+ {
569
+ key="Ctrl+"+e.key;
570
+ if(key in key_map)
571
+ {
572
+ e.preventDefault();
573
+ }
574
+ }
575
+ if(key in key_map)
576
+ {
577
+ w2ui.toolbar.click(key_map[key]);
578
+ }
579
+ }
580
+ })
581
+ }
js/upload.js CHANGED
@@ -1,20 +1,19 @@
1
- function(a,b){
2
- if(!window.my_observe_upload)
3
- {
4
- console.log("setup upload here");
5
- window.my_observe_upload = new MutationObserver(function (event) {
6
- console.log(event);
7
- var frame=document.querySelector("gradio-app").querySelector("#sdinfframe").contentWindow;
8
- var str=document.querySelector("gradio-app").querySelector("#upload textarea").value;
9
- frame.postMessage(["upload", str], "*");
10
- });
11
- window.my_observe_upload_target = document.querySelector("gradio-app").querySelector("#upload span");
12
- window.my_observe_upload.observe(window.my_observe_upload_target, {
13
- attributes: false,
14
- subtree: true,
15
- childList: true,
16
- characterData: true
17
- });
18
- }
19
- return [a,b];
20
  }
 
1
+ function(a,b){
2
+ if(!window.my_observe_upload)
3
+ {
4
+ console.log("setup upload here");
5
+ window.my_observe_upload = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ var frame=document.querySelector("gradio-app").shadowRoot.querySelector("#sdinfframe").contentWindow.document;
8
+ frame.querySelector("#upload").click();
9
+ });
10
+ window.my_observe_upload_target = document.querySelector("gradio-app").shadowRoot.querySelector("#upload span");
11
+ window.my_observe_upload.observe(window.my_observe_upload_target, {
12
+ attributes: false,
13
+ subtree: true,
14
+ childList: true,
15
+ characterData: true
16
+ });
17
+ }
18
+ return [a,b];
 
19
  }
js/w2ui.min.js ADDED
The diff for this file is too large to render. See raw diff
 
js/xss.js ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ var setup_outpaint=function(){
2
+ if(!window.my_observe_outpaint)
3
+ {
4
+ console.log("setup outpaint here");
5
+ window.my_observe_outpaint = new MutationObserver(function (event) {
6
+ console.log(event);
7
+ let app=document.querySelector("gradio-app");
8
+ app=app.shadowRoot??app;
9
+ let frame=app.querySelector("#sdinfframe").contentWindow;
10
+ frame.postMessage(["outpaint", ""], "*");
11
+ });
12
+ var app=document.querySelector("gradio-app");
13
+ app=app.shadowRoot??app;
14
+ window.my_observe_outpaint_target=app.querySelector("#output span");
15
+ window.my_observe_outpaint.observe(window.my_observe_outpaint_target, {
16
+ attributes: false,
17
+ subtree: true,
18
+ childList: true,
19
+ characterData: true
20
+ });
21
+ }
22
+ };
23
+ window.config_obj={
24
+ resize_check: true,
25
+ enable_safety: true,
26
+ use_correction: false,
27
+ enable_img2img: false,
28
+ use_seed: false,
29
+ seed_val: 0,
30
+ };
31
+ setup_outpaint();
models/v1-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-04
3
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: crossattn
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ use_ema: False
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 10000 ]
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 4
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
models/v1-inpainting-inference.yaml ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 7.5e-05
3
+ target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ image_size: 64
13
+ channels: 4
14
+ cond_stage_trainable: false # Note: different from the one we trained before
15
+ conditioning_key: hybrid # important
16
+ monitor: val/loss_simple_ema
17
+ scale_factor: 0.18215
18
+ finetune_keys: null
19
+
20
+ scheduler_config: # 10000 warmup steps
21
+ target: ldm.lr_scheduler.LambdaLinearScheduler
22
+ params:
23
+ warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch
24
+ cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25
+ f_start: [ 1.e-6 ]
26
+ f_max: [ 1. ]
27
+ f_min: [ 1. ]
28
+
29
+ unet_config:
30
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31
+ params:
32
+ image_size: 32 # unused
33
+ in_channels: 9 # 4 data + 4 downscaled image + 1 mask
34
+ out_channels: 4
35
+ model_channels: 320
36
+ attention_resolutions: [ 4, 2, 1 ]
37
+ num_res_blocks: 2
38
+ channel_mult: [ 1, 2, 4, 4 ]
39
+ num_heads: 8
40
+ use_spatial_transformer: True
41
+ transformer_depth: 1
42
+ context_dim: 768
43
+ use_checkpoint: True
44
+ legacy: False
45
+
46
+ first_stage_config:
47
+ target: ldm.models.autoencoder.AutoencoderKL
48
+ params:
49
+ embed_dim: 4
50
+ monitor: val/rec_loss
51
+ ddconfig:
52
+ double_z: true
53
+ z_channels: 4
54
+ resolution: 256
55
+ in_channels: 3
56
+ out_ch: 3
57
+ ch: 128
58
+ ch_mult:
59
+ - 1
60
+ - 2
61
+ - 4
62
+ - 4
63
+ num_res_blocks: 2
64
+ attn_resolutions: []
65
+ dropout: 0.0
66
+ lossconfig:
67
+ target: torch.nn.Identity
68
+
69
+ cond_stage_config:
70
+ target: ldm.modules.encoders.modules.FrozenCLIPEmbedder