@@ -19,8 +19,8 @@ typedef enum{
 
 template <typename T, typename GRAD_T>
 __global__ void adam_cuda_kernel(
-        T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
+        GRAD_T* __restrict__ p,
+        T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
         T* __restrict__ m,
         T* __restrict__ v,
         const GRAD_T * __restrict__ g,
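This hunk swaps the storage types in the kernel signature: the parameters p now live in the gradient precision GRAD_T (fp16 when the gradients are half), while the moments m and v and the p_copy slot keep the accumulation type T (fp32). The pattern this enables, low-precision storage with full-precision arithmetic, is sketched below in a minimal standalone kernel (my own illustration, not part of the patch; instantiate with T = float, GRAD_T = __half):

    #include <cuda_fp16.h>

    // Minimal sketch of the "store in GRAD_T, compute in T" pattern.
    template <typename T, typename GRAD_T>
    __global__ void scaled_update(GRAD_T* p, const GRAD_T* g, T lr, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            // widen to T for the arithmetic, narrow to GRAD_T only at the store
            T x = (T)p[i] - lr * (T)g[i];
            p[i] = (GRAD_T)x;
        }
    }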
|
@@ -50,7 +50,7 @@ __global__ void adam_cuda_kernel(
         else // Mode 1
             denom = sqrtf(v[j]) + eps;
         float update = (m[j]/denom) + (decay*p[j]);
-        p[j] = p[j] - (step_size*update);
+        p[j] = (GRAD_T) (p[j] - (step_size*update));
         if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
     }
 }
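The update arithmetic stays in fp32: m[j], v[j], denom, and update are all float, and only the final store into p narrows to GRAD_T. Note that the unchanged p_copy line still casts to GRAD_T; after the type swap that cast is a no-op on p[j] before the implicit widening to T, and the branch is dead anyway since the launch below passes NULL for p_copy. A device-side sketch of the same rounding behavior (my own illustration, Mode 1 denominator only):

    #include <cuda_fp16.h>
    #include <math.h>

    // One Adam step for a single fp16 parameter with fp32 moments:
    // every intermediate stays float; rounding to half happens once, at the end.
    __device__ __half adam_step_fp16(__half p, float m, float v,
                                     float eps, float step_size, float decay) {
        float denom = sqrtf(v) + eps;                           // Mode 1
        float update = (m / denom) + decay * __half2float(p);
        return __float2half(__half2float(p) - step_size * update);
    }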
|
@@ -93,14 +93,14 @@ void fused_adam_cuda(
 
     if (g.scalar_type() == at::ScalarType::Half) {
         //all other values should be fp32 for half gradients
-        AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
+        // AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dispatch is done on the gradient type
         using namespace at; // prevents "toString is undefined" errors
         DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
             using accscalar_t = at::acc_type<scalar_t_0, true>;
             adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                p.data<accscalar_t>(),
-                p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
+                p.data<scalar_t_0>(),
+                NULL, // don't output the fp32 p_copy, it's a wasted write
                 m.data<accscalar_t>(),
                 v.data<accscalar_t>(),
                 g.data<scalar_t_0>(),
|
|
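On the dispatch side, the kernel is still dispatched on the gradient dtype, but for half gradients p is now read as scalar_t_0 (half) rather than accscalar_t (float), so the fp32-parameter assertion is commented out, and p_copy is dropped because p itself already holds the fp16 values. If a guard is still wanted here, a natural replacement (a sketch, not part of the patch) would check that the parameter dtype matches the gradient dtype:

    // Hypothetical replacement for the commented-out check: with fp16
    // parameter storage, p should have the same dtype as g.
    AT_ASSERTM(p.scalar_type() == g.scalar_type(),
               "expected parameter dtype to match gradient dtype");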