MasonCriner / apex /apex.patch
MasonCrinr's picture
Upload 331 files
8026e91
diff --git a/csrc/fused_adam_cuda_kernel.cu b/csrc/fused_adam_cuda_kernel.cu
index 34f7aa2..95581d1 100644
--- a/csrc/fused_adam_cuda_kernel.cu
+++ b/csrc/fused_adam_cuda_kernel.cu
@@ -19,8 +19,8 @@ typedef enum{
template <typename T, typename GRAD_T>
__global__ void adam_cuda_kernel(
- T* __restrict__ p,
- GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
+ GRAD_T* __restrict__ p,
+ T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
T* __restrict__ m,
T* __restrict__ v,
const GRAD_T * __restrict__ g,
@@ -50,7 +50,7 @@ __global__ void adam_cuda_kernel(
else // Mode 1
denom = sqrtf(v[j]) + eps;
float update = (m[j]/denom) + (decay*p[j]);
- p[j] = p[j] - (step_size*update);
+ p[j] = (GRAD_T) (p[j] - (step_size*update));
if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
}
}
@@ -93,14 +93,14 @@ void fused_adam_cuda(
if (g.scalar_type() == at::ScalarType::Half) {
//all other values should be fp32 for half gradients
- AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
+// AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
//dispatch is done on the gradient type
using namespace at; // prevents "toString is undefined" errors
DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
using accscalar_t = at::acc_type<scalar_t_0, true>;
adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
- p.data<accscalar_t>(),
- p_copy.numel() ? p_copy.data<scalar_t_0>() : NULL,
+ p.data<scalar_t_0>(),
+ NULL, //don't output p_copy for fp32, it's wasted write
m.data<accscalar_t>(),
v.data<accscalar_t>(),
g.data<scalar_t_0>(),