Correct the output dtype of rmsnorm_func

#13
by ag0 - opened

Currently the output dtype of rmsnorm_func does not match the input dtype: a float16 input comes back as float32 when weight is float32. I'm not sure whether this is the intended behaviour, but it looks like a bug.

How to reproduce:

import torch

hidden_size = 8

hidden_states = torch.rand((4, hidden_size), dtype=torch.float16)
weight = torch.ones(hidden_size, dtype=torch.float32)
variance_epsilon = torch.tensor(1e-6)

def rmsnorm_func(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    # weight is float32, so this product is promoted back to float32
    return weight * hidden_states.to(input_dtype)

print('input', hidden_states.dtype)
print('output', rmsnorm_func(hidden_states, weight, variance_epsilon).dtype)

Result:

input torch.float16
output torch.float32
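The float32 output comes from PyTorch's type promotion. The last line of rmsnorm_func multiplies a float32 weight by a float16 tensor, and the result takes the wider of the two dtypes. The variables a and b below are stand-ins, not from the original code:

```python
import torch

# Stand-ins for weight (float32) and the normalized states (float16).
a = torch.ones(8, dtype=torch.float32)
b = torch.ones(4, 8, dtype=torch.float16)

# Mixing float16 and float32 tensors promotes the result to float32.
print(torch.promote_types(torch.float16, torch.float32))  # torch.float32
print((a * b).dtype)  # torch.float32
```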

With this PR:

input torch.float16
output torch.float16

Together org
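The diff itself isn't reproduced on this page. One way to get the output above, which is an assumption and not necessarily the change in this PR, is to apply the weight while still in float32 and cast only the final result back to the input dtype:

```python
import torch

def rmsnorm_func(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    # Compute the RMS statistics in float32 for numerical stability.
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    # Apply the weight first, then cast so the output matches the input dtype.
    return (weight * hidden_states).to(input_dtype)

hidden_size = 8
x = torch.rand((4, hidden_size), dtype=torch.float16)
w = torch.ones(hidden_size, dtype=torch.float32)
eps = torch.tensor(1e-6)

out = rmsnorm_func(x, w, eps)
print('input', x.dtype)     # torch.float16
print('output', out.dtype)  # torch.float16
```

Casting after the multiplication keeps the weight product in float32. The alternative of casting weight down to input_dtype before multiplying also preserves the dtype, but it does that multiplication in half precision.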

Thanks @ag0 !! @juewang can you look into this one?

Ce

Together org

LGTM

juewang changed pull request status to merged
