|
import copy |
|
from typing import Any, Dict |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
from transformers import AutoConfig, AutoModel, PretrainedConfig |
|
|
|
from fim.models.blocks import AModel, ModelFactory, RNNEncoder, TransformerEncoder |
|
from fim.models.utils import create_matrix_from_off_diagonal, create_padding_mask, get_off_diagonal_elements |
|
from fim.utils.helper import create_class_instance |
|
|
|
|
|
class FIMMJPConfig(PretrainedConfig):
    """
    Hugging Face configuration for the FIMMJP model.

    Args:
        n_states: Number of states of the Markov jump process.
        use_adjacency_matrix: Whether the model is additionally conditioned on an adjacency matrix.
        ts_encoder: Sub-config dict for the per-path time-series encoder (a "name" entry plus constructor kwargs).
        pos_encodings: Sub-config dict for the positional encoding of observation times.
        path_attention: Sub-config dict for the attention module that aggregates the encoded paths.
        intensity_matrix_decoder: Sub-config dict for the intensity-matrix decoder.
        initial_distribution_decoder: Sub-config dict for the initial-distribution decoder.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "fimmjp"

    def __init__(
        self,
        n_states: int = 2,
        use_adjacency_matrix: bool = False,
        ts_encoder: dict = None,
        pos_encodings: dict = None,
        path_attention: dict = None,
        intensity_matrix_decoder: dict = None,
        initial_distribution_decoder: dict = None,
        **kwargs,
    ):
        # Scalar hyper-parameters of the process.
        self.n_states = n_states
        self.use_adjacency_matrix = use_adjacency_matrix

        # Sub-module descriptions are stored verbatim (insertion order kept); the model builds its modules from them.
        module_settings = {
            "ts_encoder": ts_encoder,
            "pos_encodings": pos_encodings,
            "path_attention": path_attention,
            "intensity_matrix_decoder": intensity_matrix_decoder,
            "initial_distribution_decoder": initial_distribution_decoder,
        }
        for attribute, setting in module_settings.items():
            setattr(self, attribute, setting)

        super().__init__(**kwargs)
|
|
|
|
|
class FIMMJP(AModel):
    """
    FIMMJP: A Neural Recognition Model for Zero-Shot Inference of Markov Jump Processes

    This class implements a neural recognition model for zero-shot inference of Markov jump processes (MJPs) on bounded
    state spaces from noisy and sparse observations. It is the model from the paper "Foundation Inference Models for
    Markov Jump Processes" (https://arxiv.org/abs/2406.06419), whose abstract reads:

    Markov jump processes are continuous-time stochastic processes which describe dynamical systems evolving in discrete
    state spaces. These processes find wide application in the natural sciences and machine learning, but their
    inference is known to be far from trivial. In this work we introduce a methodology for zero-shot inference of Markov
    jump processes (MJPs), on bounded state spaces, from noisy and sparse observations, which consists of two components.
    First, a broad probability distribution over families of MJPs, as well as over possible observation times and noise
    mechanisms, with which we simulate a synthetic dataset of hidden MJPs and their noisy observations. Second, a neural
    recognition model that processes subsets of the simulated observations, and that is trained to output the initial
    condition and rate matrix of the target MJP in a supervised way. We empirically demonstrate that one and the same
    (pretrained) recognition model can infer, in a zero-shot fashion, hidden MJPs evolving in state spaces of different
    dimensionalities. Specifically, we infer MJPs which describe (i) discrete flashing ratchet systems, which are a type
    of Brownian motors, and the conformational dynamics in (ii) molecular simulations, (iii) experimental ion channel data
    and (iv) simple protein folding models. What is more, we show that our model performs on par with state-of-the-art
    models which are trained on the target datasets.

    Attributes:
        n_states (int): Number of states in the Markov jump process.
        use_adjacency_matrix (bool): Whether the off-diagonal entries of an adjacency matrix are appended to the
            encoded representation before decoding.
        total_offdiagonal_transitions (int): n_states**2 - n_states, the number of off-diagonal intensity entries.
        ts_encoder (TransformerEncoder | RNNEncoder): Per-path time-series encoder.
        pos_encodings (nn.Module): Positional encoding of the normalized observation times.
        path_attention (nn.Module): Attention mechanism aggregating the encoded paths.
        intensity_matrix_decoder (nn.Module): Decodes off-diagonal log-means and log-stds of the intensity matrix.
        initial_distribution_decoder (nn.Module): Decodes the initial-distribution logits.
        gaussian_nll (nn.GaussianNLLLoss): Gaussian negative log-likelihood loss (per entry).
        init_cross_entropy (nn.CrossEntropyLoss): Cross-entropy loss for the initial distribution.

    Methods:
        forward(x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
            Forward pass of the model.
        __decode(h: Tensor) -> tuple[Tensor, Tensor]:
            Decode the hidden representation into off-diagonal [log-mean | log-std] and initial-condition logits.
        __encode(x: dict[str, Tensor]) -> Tensor:
            Encode the input observations into one hidden representation per batch element.
        __denormalize_offdiag_mean_logstd(norm_constants: Tensor, pred_offdiag_im_logmean_logstd: Tensor)
            -> tuple[Tensor, Tensor]:
            Map the predicted off-diagonal mean and log-std back to the original time scale.
        __normalize_obs_grid(obs_grid: Tensor) -> tuple[Tensor, Tensor]:
            Normalize the observation grid.
        loss(pred_im: Tensor, pred_logstd_im: Tensor, pred_init_cond: Tensor, target: dict,
             normalization_constants: Tensor, schedulers: dict = None, step: int = None) -> dict:
            Compute the losses for the model.
        metric(y: Any, y_target: Any) -> Dict:
            Compute the metric for the model.
    """

    config_class = FIMMJPConfig

    def __init__(self, config: FIMMJPConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.n_states = config.n_states
        self.use_adjacency_matrix = config.use_adjacency_matrix
        # Number of off-diagonal entries of an (n_states x n_states) intensity matrix.
        self.total_offdiagonal_transitions = self.n_states**2 - self.n_states

        self.__create_modules()

        # reduction="none": the per-entry Gaussian NLL is masked by the adjacency matrix in `loss`.
        self.gaussian_nll = nn.GaussianNLLLoss(full=True, reduction="none")
        self.init_cross_entropy = nn.CrossEntropyLoss(reduction="none")

    def __create_modules(self):
        """
        Instantiate all sub-modules from the sub-config dicts stored in ``self.config``.

        Raises:
            ValueError: if any required sub-config is None.
        """
        required = ("pos_encodings", "ts_encoder", "path_attention", "intensity_matrix_decoder", "initial_distribution_decoder")
        missing = [name for name in required if getattr(self.config, name) is None]
        if missing:
            raise ValueError(f"FIMMJPConfig is missing required sub-configurations: {', '.join(missing)}")

        # Deep copies: the dicts are mutated below (pop("name"), feature sizes) and self.config must stay untouched.
        pos_encodings = copy.deepcopy(self.config.pos_encodings)
        ts_encoder = copy.deepcopy(self.config.ts_encoder)
        path_attention = copy.deepcopy(self.config.path_attention)
        intensity_matrix_decoder = copy.deepcopy(self.config.intensity_matrix_decoder)
        initial_distribution_decoder = copy.deepcopy(self.config.initial_distribution_decoder)

        if ts_encoder["name"] == "fim.models.blocks.base.TransformerEncoder":
            # The one-hot state is concatenated to the positional encoding (see __encode); shrink the encoding so the
            # Transformer input width equals the configured pos_encodings["out_features"].
            pos_encodings["out_features"] -= self.n_states
        self.pos_encodings = create_class_instance(pos_encodings.pop("name"), pos_encodings)

        ts_encoder["in_features"] = self.n_states + self.pos_encodings.out_features
        self.ts_encoder = create_class_instance(ts_encoder.pop("name"), ts_encoder)

        self.path_attention = create_class_instance(path_attention.pop("name"), path_attention)

        # Default decoder input: encoder output + 1 path-count feature (+ flattened off-diagonal adjacency).
        decoder_in_features = self.ts_encoder.out_features + 1
        if self.use_adjacency_matrix:
            decoder_in_features += self.total_offdiagonal_transitions

        intensity_matrix_decoder.setdefault("in_features", decoder_in_features)
        # One log-mean and one log-std per off-diagonal entry.
        intensity_matrix_decoder["out_features"] = 2 * self.total_offdiagonal_transitions
        self.intensity_matrix_decoder = create_class_instance(intensity_matrix_decoder.pop("name"), intensity_matrix_decoder)

        initial_distribution_decoder.setdefault("in_features", decoder_in_features)
        initial_distribution_decoder["out_features"] = self.n_states
        self.initial_distribution_decoder = create_class_instance(initial_distribution_decoder.pop("name"), initial_distribution_decoder)

    def forward(self, x: dict[str, Tensor], n_states: int = None, schedulers: dict = None, step: int = None) -> dict:
        """
        Forward pass for the model.

        Args:
            x (dict[str, Tensor]): A dictionary containing the input tensors. NOTE: it is updated in place with the keys
                "observation_grid_normalized", "observation_values_one_hot" and, if absent,
                "time_normalization_factors".
                - "observation_grid": Observation times; rank 4, assumed [B, P, L, 1] (batch, paths, observations)
                  — TODO confirm the last dim.
                - "observation_values": Observed states (integer-valued); one-hot encoded with n_states classes after
                  squeezing the last dim.
                - "seq_lengths": Number of valid observations per path (B * P values).
                - Optional keys:
                    - "time_normalization_factors": Per-batch time scale. When given, "observation_grid" is used
                      unchanged (presumably already normalized); otherwise it is divided by its per-batch maximum.
                    - "adjacency_matrices": Required when use_adjacency_matrix is True (the legacy key
                      "adjacency_matrix" is also accepted).
                - Optional keys for loss calculation (the loss is computed when the first two are present):
                    - "intensity_matrices": Target intensity matrices.
                    - "initial_distributions": Target initial distributions.
                    - "adjacency_matrices": Target adjacency matrices (required by `loss`).
            n_states (int, optional): Forwarded as ``n_states=`` to create_matrix_from_off_diagonal.
                Defaults to self.n_states.
            schedulers (dict, optional): Loss-weight schedulers, see `loss`. Default is None.
            step (int, optional): The current step in the training process. Default is None.

        Returns:
            dict: A dictionary containing the following keys:
                - "intensity_matrices": Intensity matrices built from the predicted off-diagonal means
                  (mode="negative_sum_row").
                - "intensity_matrices_variance": Built the same way from exp(predicted log-std).
                  NOTE(review): these entries are standard deviations, not variances (`loss` squares them to obtain
                  the variance); the key name is kept for compatibility.
                - "initial_condition": Raw initial-distribution decoder output (used as logits by the cross-entropy
                  in `loss`).
                - "losses" (optional): The dict returned by `loss`, if the required keys are present in `x`.
        """
        obs_grid = x["observation_grid"]
        if "time_normalization_factors" not in x:
            norm_constants, obs_grid = self.__normalize_obs_grid(obs_grid)
            x["time_normalization_factors"] = norm_constants
        else:
            norm_constants = x["time_normalization_factors"]
        x["observation_grid_normalized"] = obs_grid

        x["observation_values_one_hot"] = torch.nn.functional.one_hot(
            x["observation_values"].long().squeeze(-1), num_classes=self.n_states
        )

        h = self.__encode(x)
        pred_offdiag_im_logmean_logstd, init_cond = self.__decode(h)
        pred_offdiag_im_mean, pred_offdiag_im_logstd = self.__denormalize_offdiag_mean_logstd(
            norm_constants, pred_offdiag_im_logmean_logstd
        )

        out_n_states = self.n_states if n_states is None else n_states
        out = {
            "intensity_matrices": create_matrix_from_off_diagonal(
                pred_offdiag_im_mean, self.n_states, mode="negative_sum_row", n_states=out_n_states
            ),
            "intensity_matrices_variance": create_matrix_from_off_diagonal(
                torch.exp(pred_offdiag_im_logstd),
                self.n_states,
                mode="negative_sum_row",
                n_states=out_n_states,
            ),
            "initial_condition": init_cond,
        }
        if "intensity_matrices" in x and "initial_distributions" in x:
            out["losses"] = self.loss(
                pred_offdiag_im_mean, pred_offdiag_im_logstd, init_cond, x, norm_constants.reshape(-1, 1), schedulers, step
            )
        return out

    def __decode(self, h: Tensor) -> tuple[Tensor, Tensor]:
        """Decode `h` into off-diagonal [log-mean | log-std] predictions and initial-distribution logits."""
        pred_offdiag_logmean_logstd = self.intensity_matrix_decoder(h)
        init_cond = self.initial_distribution_decoder(h)
        return pred_offdiag_logmean_logstd, init_cond

    def __encode(self, x: dict[str, Tensor]) -> Tensor:
        """
        Encode all paths of each batch element into one vector.

        Each observation is represented as [positional encoding of its normalized time | one-hot state]. Every path is
        encoded by `ts_encoder`, the P path encodings are aggregated by `path_attention`, and the scalar feature P / 100
        (plus, optionally, the flattened off-diagonal adjacency matrix) is appended.

        Raises:
            TypeError: if `ts_encoder` is neither a TransformerEncoder nor an RNNEncoder.
            KeyError: if use_adjacency_matrix is set but `x` holds no adjacency matrix.
        """
        obs_grid_normalized = x["observation_grid_normalized"]
        obs_values_one_hot = x["observation_values_one_hot"]
        B, P, L = obs_grid_normalized.shape[:3]
        # reshape rather than view: tolerates a non-contiguous seq_lengths tensor.
        seq_lengths = x["seq_lengths"].reshape(B * P)

        pos_enc = self.pos_encodings(obs_grid_normalized)
        path = torch.cat([pos_enc, obs_values_one_hot], dim=-1).view(B * P, L, -1)

        if isinstance(self.ts_encoder, TransformerEncoder):
            padding_mask = create_padding_mask(seq_lengths, L)
            padding_mask[:, 0] = True
            h = self.ts_encoder(path, padding_mask)[:, 1, :].view(B, P, -1)
        elif isinstance(self.ts_encoder, RNNEncoder):
            h = self.ts_encoder(path, seq_lengths)
            # Keep the hidden state at the last valid observation of every path.
            last_observation = seq_lengths - 1
            h = h[torch.arange(B * P, device=last_observation.device), last_observation].view(B, P, -1)
        else:
            raise TypeError(f"Unsupported ts_encoder type: {type(self.ts_encoder).__name__}")

        h = self.__attend_paths(h)

        # Path-count feature (P / 100), created directly with h's device and dtype.
        path_count = torch.ones(B, 1, device=h.device, dtype=h.dtype) / 100.0 * P
        h = torch.cat([h, path_count], dim=-1)
        if self.use_adjacency_matrix:
            adjacency = x.get("adjacency_matrices", x.get("adjacency_matrix"))
            if adjacency is None:
                raise KeyError("use_adjacency_matrix=True requires 'adjacency_matrices' in the input dict")
            h = torch.cat([h, get_off_diagonal_elements(adjacency).to(h.dtype)], dim=-1)
        return h

    def __attend_paths(self, h: Tensor) -> Tensor:
        """Aggregate the per-path encodings `h` ([B, P, D]) with `path_attention` (self-attention, q = k = v = h)."""
        if isinstance(self.path_attention, nn.MultiheadAttention):
            # nn.MultiheadAttention returns (output, weights); keep the output at the last index of dim 1.
            return self.path_attention(h, h, h)[0][:, -1]
        return self.path_attention(h, h, h)

    def __denormalize_offdiag_mean_logstd(self, norm_constants: Tensor, pred_offdiag_im_logmean_logstd: Tensor) -> tuple[Tensor, Tensor]:
        """
        Map the decoder output back to the original (un-normalized) time scale.

        The last dim is split in half into log-means and log-stds of the off-diagonal rates. With per-batch time
        normalization constant c, the mean becomes exp(log-mean) / c and the log-std is shifted by -log(c)
        (i.e. the std is divided by c).

        Returns:
            (mean, log_std) of the off-diagonal intensities.
        """
        pred_offdiag_im_logmean, pred_offdiag_im_logstd = pred_offdiag_im_logmean_logstd.chunk(2, dim=-1)
        norm_constants = norm_constants.reshape(-1, 1)
        pred_offdiag_im_mean = torch.exp(pred_offdiag_im_logmean) / norm_constants
        pred_offdiag_im_logstd = pred_offdiag_im_logstd - torch.log(norm_constants)
        return pred_offdiag_im_mean, pred_offdiag_im_logstd

    def __normalize_obs_grid(self, obs_grid: Tensor) -> tuple[Tensor, Tensor]:
        """Divide each batch element's grid by its maximum over the last three dims; return (maxima, normalized grid)."""
        norm_constants = obs_grid.amax(dim=[-3, -2, -1])
        obs_grid_normalized = obs_grid / norm_constants.view(-1, 1, 1, 1)
        return norm_constants, obs_grid_normalized

    def __loss_weight(self, schedulers: dict, name: str, step: int) -> Tensor:
        """Return ``schedulers[name](step)`` as a tensor on the model device, or 1.0 if no such scheduler exists."""
        scheduler = schedulers.get(name) if schedulers else None
        weight = scheduler(step) if scheduler is not None else 1.0
        return torch.as_tensor(weight, device=self.device)

    def loss(
        self,
        pred_im: Tensor,
        pred_logstd_im: Tensor,
        pred_init_cond: Tensor,
        target: dict,
        normalization_constants: Tensor,
        schedulers: dict = None,
        step: int = None,
    ) -> dict:
        """
        Compute the training losses.

        Args:
            pred_im: Predicted off-diagonal intensity means (original time scale).
            pred_logstd_im: Predicted off-diagonal log standard deviations (original time scale).
            pred_init_cond: Initial-distribution logits.
            target: Input dict; reads "intensity_matrices", "initial_distributions", "adjacency_matrices" and
                "observation_grid".
            normalization_constants: Per-batch time normalization constants ([B, 1] when called from `forward`).
            schedulers: Optional mapping with "gauss_nll", "init_cross_entropy" and "missing_link" callables of `step`
                returning loss weights; a missing scheduler means weight 1.
            step: Training step passed to the schedulers.

        Returns:
            dict with the weighted total "loss", its components, "rmse_loss" (monitoring only, not part of "loss"),
            the loss weights and "number_of_paths".
        """
        target_im = target["intensity_matrices"]
        P = target["observation_grid"].shape[1]

        target_mean = get_off_diagonal_elements(target_im)
        adjacency = get_off_diagonal_elements(target["adjacency_matrices"])
        if not adjacency.is_floating_point():
            # Boolean/integer masks: `1.0 - bool_tensor` raises, so cast to the prediction dtype.
            adjacency = adjacency.to(pred_im.dtype)
        target_init_cond = torch.argmax(target["initial_distributions"], dim=-1)

        pred_im_std = torch.exp(pred_logstd_im)

        # Gaussian NLL restricted to existing transitions, averaged over them.
        loss_gauss = adjacency * self.gaussian_nll(pred_im, target_mean, torch.pow(pred_im_std, 2))
        loss_gauss = loss_gauss.sum() / (adjacency.sum() + 1e-8)

        loss_initial = self.init_cross_entropy(pred_init_cond, target_init_cond).mean()

        # Penalize predicted mean^2 + variance on absent transitions, scaled by the normalization constant.
        zero_entries = 1.0 - adjacency
        loss_missing_link = normalization_constants * zero_entries * (torch.pow(pred_im, 2) + torch.pow(pred_im_std, 2))
        loss_missing_link = loss_missing_link.sum() / (zero_entries.sum() + 1e-8)

        rmse_loss = torch.sqrt(torch.mean((target_mean - pred_im) ** 2))

        gaus_cons = self.__loss_weight(schedulers, "gauss_nll", step)
        init_cons = self.__loss_weight(schedulers, "init_cross_entropy", step)
        missing_link_cons = self.__loss_weight(schedulers, "missing_link", step)

        loss = gaus_cons * loss_gauss + init_cons * loss_initial + missing_link_cons * loss_missing_link

        return {
            "loss": loss,
            "loss_gauss": loss_gauss,
            "loss_initial": loss_initial,
            "loss_missing_link": loss_missing_link,
            "rmse_loss": rmse_loss,
            "beta_gauss_nll": gaus_cons,
            "beta_init_cross_entropy": init_cons,
            "beta_missing_link": missing_link_cons,
            "number_of_paths": torch.tensor(P, device=self.device),
        }

    def metric(self, y: Any, y_target: Any) -> Dict:
        """Delegate to the base-class metric."""
        return super().metric(y, y_target)
|
|
|
|
|
# Import-time registration of FIMMJP under its model type "fimmjp":
# - the project's ModelFactory (presumably so the model can be built from its type name — confirm in fim.models.blocks);
# - Hugging Face's AutoConfig / AutoModel registries, so the Auto* classes can resolve FIMMJPConfig and FIMMJP.
ModelFactory.register(FIMMJPConfig.model_type, FIMMJP)
AutoConfig.register(FIMMJPConfig.model_type, FIMMJPConfig)
AutoModel.register(FIMMJPConfig, FIMMJP)
|
|