// Spaces:
// Runtime error
// Runtime error
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
namespace csrblocksparse { | |
constexpr int kGenericSIMDWidth = 4; | |
// TODO(b/188702959): Rename arguments to match gru_gates.h. | |
template <typename GRUStateType, typename GRUMatMulOutType, typename QR_W_Type, | |
typename SampleType, ARInputsMode kInputsMode, | |
bool SplitGates = false> | |
void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr, | |
const GRUMatMulOutType* gru_gates_ptr, | |
const GRUMatMulOutType* gru_gates_other_ptr, | |
const GRUMatMulOutType* conditioning_ptr, | |
GRUStateType* gru_h_ptr, const QR_W_Type* w_hat, | |
int proj_size, const SampleType* coarse_at_sminus1, | |
const SampleType* fine_at_sminus1, | |
const SampleType* coarse_at_s = nullptr) { | |
float qr_cell = 0.0f, reset, update, cell; | |
for (int i = start; i < end; ++i) { | |
if (kInputsMode == ARInputsMode::k0ARInputs) { | |
reset = static_cast<float>(gru_gates_ptr[i]); | |
update = static_cast<float>(gru_gates_ptr[proj_size + i]); | |
} else { | |
float qr_c_reset = static_cast<float>(qr_ptr[2 * i + 0]); | |
float qr_f_reset = static_cast<float>(qr_ptr[2 * i + 1]); | |
float qr_c_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 0]); | |
float qr_f_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 1]); | |
float qr_c_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 0]); | |
float qr_f_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 1]); | |
float w_hat_i_reset = 0.0f; | |
float w_hat_i_update = 0.0f; | |
float w_hat_i_cell = 0.0f; | |
if (kInputsMode == ARInputsMode::k3ARInputs) { | |
w_hat_i_reset = static_cast<float>(w_hat[i]); | |
w_hat_i_update = static_cast<float>(w_hat[proj_size + i]); | |
w_hat_i_cell = static_cast<float>(w_hat[2 * proj_size + i]); | |
} | |
float coarse = static_cast<float>(coarse_at_sminus1[0]); | |
float fine = static_cast<float>(fine_at_sminus1[0]); | |
reset = qr_c_reset * coarse + qr_f_reset * fine; | |
update = qr_c_update * coarse + qr_f_update * fine; | |
qr_cell = qr_c_cell * coarse + qr_f_cell * fine; | |
if (kInputsMode == ARInputsMode::k3ARInputs) { | |
float coarse = static_cast<float>(coarse_at_s[0]); | |
reset += w_hat_i_reset * coarse; | |
update += w_hat_i_update * coarse; | |
qr_cell += w_hat_i_cell * coarse; | |
} | |
reset += static_cast<float>(gru_gates_ptr[i]); | |
update += static_cast<float>(gru_gates_ptr[proj_size + i]); | |
} | |
cell = static_cast<float>(gru_gates_ptr[2 * proj_size + i]); | |
if (SplitGates) { | |
reset += static_cast<float>(gru_gates_other_ptr[i]); | |
update += static_cast<float>(gru_gates_other_ptr[proj_size + i]); | |
cell += static_cast<float>(gru_gates_other_ptr[2 * proj_size + i]); | |
} | |
float reset_conditioning = static_cast<float>(conditioning_ptr[i]); | |
float update_conditioning = | |
static_cast<float>(conditioning_ptr[proj_size + i]); | |
float cell_conditioning = | |
static_cast<float>(conditioning_ptr[2 * proj_size + i]); | |
reset = fast_sigmoid(reset + reset_conditioning); | |
update = fast_sigmoid(update + update_conditioning); | |
float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning); | |
int h_index = i; | |
float prev_h = static_cast<float>(gru_h_ptr[h_index]); | |
float diff = prev_h - hbar; | |
float new_h = hbar + diff * update; | |
gru_h_ptr[h_index] = static_cast<GRUStateType>(new_h); | |
} | |
} | |
} // namespace csrblocksparse | |