DmitrMakeev
commited on
Commit
•
78ae62d
1
Parent(s):
8b2a131
Upload 5 files
Browse files- sync_batchnorm/__init__.py +12 -0
- sync_batchnorm/batchnorm.py +315 -0
- sync_batchnorm/comm.py +137 -0
- sync_batchnorm/replicate.py +94 -0
- sync_batchnorm/unittest.py +29 -0
sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
12 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
17 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
18 |
+
|
19 |
+
from .comm import SyncMaster
|
20 |
+
|
21 |
+
__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
|
22 |
+
|
23 |
+
|
24 |
+
def _sum_ft(tensor):
|
25 |
+
"""sum over the first and last dimention"""
|
26 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
27 |
+
|
28 |
+
|
29 |
+
def _unsqueeze_ft(tensor):
|
30 |
+
"""add new dementions at the front and the tail"""
|
31 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
32 |
+
|
33 |
+
|
34 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
35 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
36 |
+
|
37 |
+
|
38 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
39 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
|
40 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
|
41 |
+
|
42 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
43 |
+
|
44 |
+
self._is_parallel = False
|
45 |
+
self._parallel_id = None
|
46 |
+
self._slave_pipe = None
|
47 |
+
|
48 |
+
def forward(self, input):
|
49 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
50 |
+
if not (self._is_parallel and self.training):
|
51 |
+
return F.batch_norm(
|
52 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
53 |
+
self.training, self.momentum, self.eps)
|
54 |
+
|
55 |
+
# Resize the input to (B, C, -1).
|
56 |
+
input_shape = input.size()
|
57 |
+
input = input.view(input.size(0), self.num_features, -1)
|
58 |
+
|
59 |
+
# Compute the sum and square-sum.
|
60 |
+
sum_size = input.size(0) * input.size(2)
|
61 |
+
input_sum = _sum_ft(input)
|
62 |
+
input_ssum = _sum_ft(input ** 2)
|
63 |
+
|
64 |
+
# Reduce-and-broadcast the statistics.
|
65 |
+
if self._parallel_id == 0:
|
66 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
67 |
+
else:
|
68 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
69 |
+
|
70 |
+
# Compute the output.
|
71 |
+
if self.affine:
|
72 |
+
# MJY:: Fuse the multiplication for speed.
|
73 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
74 |
+
else:
|
75 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
76 |
+
|
77 |
+
# Reshape it.
|
78 |
+
return output.view(input_shape)
|
79 |
+
|
80 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
81 |
+
self._is_parallel = True
|
82 |
+
self._parallel_id = copy_id
|
83 |
+
|
84 |
+
# parallel_id == 0 means master device.
|
85 |
+
if self._parallel_id == 0:
|
86 |
+
ctx.sync_master = self._sync_master
|
87 |
+
else:
|
88 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
89 |
+
|
90 |
+
def _data_parallel_master(self, intermediates):
|
91 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
92 |
+
|
93 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
94 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
95 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
96 |
+
|
97 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
98 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
99 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
100 |
+
|
101 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
102 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
103 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
104 |
+
|
105 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
106 |
+
|
107 |
+
outputs = []
|
108 |
+
for i, rec in enumerate(intermediates):
|
109 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
110 |
+
|
111 |
+
return outputs
|
112 |
+
|
113 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
114 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
115 |
+
also maintains the moving average on the master device."""
|
116 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
117 |
+
mean = sum_ / size
|
118 |
+
sumvar = ssum - sum_ * mean
|
119 |
+
unbias_var = sumvar / (size - 1)
|
120 |
+
bias_var = sumvar / size
|
121 |
+
|
122 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
123 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
124 |
+
|
125 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
126 |
+
|
127 |
+
|
128 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
129 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
130 |
+
mini-batch.
|
131 |
+
|
132 |
+
.. math::
|
133 |
+
|
134 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
135 |
+
|
136 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
137 |
+
standard-deviation are reduced across all devices during training.
|
138 |
+
|
139 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
140 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
141 |
+
the statistics only on that device, which accelerated the computation and
|
142 |
+
is also easy to implement, but the statistics might be inaccurate.
|
143 |
+
Instead, in this synchronized version, the statistics will be computed
|
144 |
+
over all training samples distributed on multiple devices.
|
145 |
+
|
146 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
147 |
+
as the built-in PyTorch implementation.
|
148 |
+
|
149 |
+
The mean and standard-deviation are calculated per-dimension over
|
150 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
151 |
+
of size C (where C is the input size).
|
152 |
+
|
153 |
+
During training, this layer keeps a running estimate of its computed mean
|
154 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
155 |
+
|
156 |
+
During evaluation, this running mean/variance is used for normalization.
|
157 |
+
|
158 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
159 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
160 |
+
|
161 |
+
Args:
|
162 |
+
num_features: num_features from an expected input of size
|
163 |
+
`batch_size x num_features [x width]`
|
164 |
+
eps: a value added to the denominator for numerical stability.
|
165 |
+
Default: 1e-5
|
166 |
+
momentum: the value used for the running_mean and running_var
|
167 |
+
computation. Default: 0.1
|
168 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
169 |
+
affine parameters. Default: ``True``
|
170 |
+
|
171 |
+
Shape:
|
172 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
173 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
174 |
+
|
175 |
+
Examples:
|
176 |
+
>>> # With Learnable Parameters
|
177 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
178 |
+
>>> # Without Learnable Parameters
|
179 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
180 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
181 |
+
>>> output = m(input)
|
182 |
+
"""
|
183 |
+
|
184 |
+
def _check_input_dim(self, input):
|
185 |
+
if input.dim() != 2 and input.dim() != 3:
|
186 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
187 |
+
.format(input.dim()))
|
188 |
+
super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
189 |
+
|
190 |
+
|
191 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
192 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
193 |
+
of 3d inputs
|
194 |
+
|
195 |
+
.. math::
|
196 |
+
|
197 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
198 |
+
|
199 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
200 |
+
standard-deviation are reduced across all devices during training.
|
201 |
+
|
202 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
203 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
204 |
+
the statistics only on that device, which accelerated the computation and
|
205 |
+
is also easy to implement, but the statistics might be inaccurate.
|
206 |
+
Instead, in this synchronized version, the statistics will be computed
|
207 |
+
over all training samples distributed on multiple devices.
|
208 |
+
|
209 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
210 |
+
as the built-in PyTorch implementation.
|
211 |
+
|
212 |
+
The mean and standard-deviation are calculated per-dimension over
|
213 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
214 |
+
of size C (where C is the input size).
|
215 |
+
|
216 |
+
During training, this layer keeps a running estimate of its computed mean
|
217 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
218 |
+
|
219 |
+
During evaluation, this running mean/variance is used for normalization.
|
220 |
+
|
221 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
222 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
223 |
+
|
224 |
+
Args:
|
225 |
+
num_features: num_features from an expected input of
|
226 |
+
size batch_size x num_features x height x width
|
227 |
+
eps: a value added to the denominator for numerical stability.
|
228 |
+
Default: 1e-5
|
229 |
+
momentum: the value used for the running_mean and running_var
|
230 |
+
computation. Default: 0.1
|
231 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
232 |
+
affine parameters. Default: ``True``
|
233 |
+
|
234 |
+
Shape:
|
235 |
+
- Input: :math:`(N, C, H, W)`
|
236 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
237 |
+
|
238 |
+
Examples:
|
239 |
+
>>> # With Learnable Parameters
|
240 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
241 |
+
>>> # Without Learnable Parameters
|
242 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
243 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
244 |
+
>>> output = m(input)
|
245 |
+
"""
|
246 |
+
|
247 |
+
def _check_input_dim(self, input):
|
248 |
+
if input.dim() != 4:
|
249 |
+
raise ValueError('expected 4D input (got {}D input)'
|
250 |
+
.format(input.dim()))
|
251 |
+
super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
252 |
+
|
253 |
+
|
254 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
255 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
256 |
+
of 4d inputs
|
257 |
+
|
258 |
+
.. math::
|
259 |
+
|
260 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
261 |
+
|
262 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
263 |
+
standard-deviation are reduced across all devices during training.
|
264 |
+
|
265 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
266 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
267 |
+
the statistics only on that device, which accelerated the computation and
|
268 |
+
is also easy to implement, but the statistics might be inaccurate.
|
269 |
+
Instead, in this synchronized version, the statistics will be computed
|
270 |
+
over all training samples distributed on multiple devices.
|
271 |
+
|
272 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
273 |
+
as the built-in PyTorch implementation.
|
274 |
+
|
275 |
+
The mean and standard-deviation are calculated per-dimension over
|
276 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
277 |
+
of size C (where C is the input size).
|
278 |
+
|
279 |
+
During training, this layer keeps a running estimate of its computed mean
|
280 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
281 |
+
|
282 |
+
During evaluation, this running mean/variance is used for normalization.
|
283 |
+
|
284 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
285 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
286 |
+
or Spatio-temporal BatchNorm
|
287 |
+
|
288 |
+
Args:
|
289 |
+
num_features: num_features from an expected input of
|
290 |
+
size batch_size x num_features x depth x height x width
|
291 |
+
eps: a value added to the denominator for numerical stability.
|
292 |
+
Default: 1e-5
|
293 |
+
momentum: the value used for the running_mean and running_var
|
294 |
+
computation. Default: 0.1
|
295 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
296 |
+
affine parameters. Default: ``True``
|
297 |
+
|
298 |
+
Shape:
|
299 |
+
- Input: :math:`(N, C, D, H, W)`
|
300 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
301 |
+
|
302 |
+
Examples:
|
303 |
+
>>> # With Learnable Parameters
|
304 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
305 |
+
>>> # Without Learnable Parameters
|
306 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
307 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
308 |
+
>>> output = m(input)
|
309 |
+
"""
|
310 |
+
|
311 |
+
def _check_input_dim(self, input):
|
312 |
+
if input.dim() != 5:
|
313 |
+
raise ValueError('expected 5D input (got {}D input)'
|
314 |
+
.format(input.dim()))
|
315 |
+
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
return {'master_callback': self._master_callback}
|
80 |
+
|
81 |
+
def __setstate__(self, state):
|
82 |
+
self.__init__(state['master_callback'])
|
83 |
+
|
84 |
+
def register_slave(self, identifier):
|
85 |
+
"""
|
86 |
+
Register an slave device.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
identifier: an identifier, usually is the device id.
|
90 |
+
|
91 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
92 |
+
|
93 |
+
"""
|
94 |
+
if self._activated:
|
95 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
96 |
+
self._activated = False
|
97 |
+
self._registry.clear()
|
98 |
+
future = FutureResult()
|
99 |
+
self._registry[identifier] = _MasterRegistry(future)
|
100 |
+
return SlavePipe(identifier, self._queue, future)
|
101 |
+
|
102 |
+
def run_master(self, master_msg):
|
103 |
+
"""
|
104 |
+
Main entry for the master device in each forward pass.
|
105 |
+
The messages were first collected from each devices (including the master device), and then
|
106 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
107 |
+
(including the master device).
|
108 |
+
|
109 |
+
Args:
|
110 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
111 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
112 |
+
|
113 |
+
Returns: the message to be sent back to the master device.
|
114 |
+
|
115 |
+
"""
|
116 |
+
self._activated = True
|
117 |
+
|
118 |
+
intermediates = [(0, master_msg)]
|
119 |
+
for i in range(self.nr_slaves):
|
120 |
+
intermediates.append(self._queue.get())
|
121 |
+
|
122 |
+
results = self._master_callback(intermediates)
|
123 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
124 |
+
|
125 |
+
for i, res in results:
|
126 |
+
if i == 0:
|
127 |
+
continue
|
128 |
+
self._registry[i].result.put(res)
|
129 |
+
|
130 |
+
for i in range(self.nr_slaves):
|
131 |
+
assert self._queue.get() is True
|
132 |
+
|
133 |
+
return results[0][1]
|
134 |
+
|
135 |
+
@property
|
136 |
+
def nr_slaves(self):
|
137 |
+
return len(self._registry)
|
sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
|
31 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
32 |
+
|
33 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
34 |
+
(shared among multiple copies of this module on different devices).
|
35 |
+
Through this context, different copies can share some information.
|
36 |
+
|
37 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
38 |
+
of any slave copies.
|
39 |
+
"""
|
40 |
+
master_copy = modules[0]
|
41 |
+
nr_modules = len(list(master_copy.modules()))
|
42 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
43 |
+
|
44 |
+
for i, module in enumerate(modules):
|
45 |
+
for j, m in enumerate(module.modules()):
|
46 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
47 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
48 |
+
|
49 |
+
|
50 |
+
class DataParallelWithCallback(DataParallel):
|
51 |
+
"""
|
52 |
+
Data Parallel with a replication callback.
|
53 |
+
|
54 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
55 |
+
original `replicate` function.
|
56 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
60 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
61 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def replicate(self, module, device_ids):
|
65 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
66 |
+
execute_replication_callbacks(modules)
|
67 |
+
return modules
|
68 |
+
|
69 |
+
|
70 |
+
def patch_replication_callback(data_parallel):
|
71 |
+
"""
|
72 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
73 |
+
Useful when you have customized `DataParallel` implementation.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
77 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
78 |
+
> patch_replication_callback(sync_bn)
|
79 |
+
# this is equivalent to
|
80 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert isinstance(data_parallel, DataParallel)
|
85 |
+
|
86 |
+
old_replicate = data_parallel.replicate
|
87 |
+
|
88 |
+
@functools.wraps(old_replicate)
|
89 |
+
def new_replicate(module, device_ids):
|
90 |
+
modules = old_replicate(module, device_ids)
|
91 |
+
execute_replication_callbacks(modules)
|
92 |
+
return modules
|
93 |
+
|
94 |
+
data_parallel.replicate = new_replicate
|
sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
|
13 |
+
import numpy as np
|
14 |
+
from torch.autograd import Variable
|
15 |
+
|
16 |
+
|
17 |
+
def as_numpy(v):
|
18 |
+
if isinstance(v, Variable):
|
19 |
+
v = v.data
|
20 |
+
return v.cpu().numpy()
|
21 |
+
|
22 |
+
|
23 |
+
class TorchTestCase(unittest.TestCase):
|
24 |
+
def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
|
25 |
+
npa, npb = as_numpy(a), as_numpy(b)
|
26 |
+
self.assertTrue(
|
27 |
+
np.allclose(npa, npb, atol=atol),
|
28 |
+
'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
|
29 |
+
)
|