PatchTSMixer
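Overview
The PatchTSMixer model was proposed in TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting by Vijay Ekambaram, Arindam Jati, Nam Nguyen, Phanwadee Sinthong and Jayant Kalagnanam.
PatchTSMixer is a lightweight time-series modeling approach based on the MLP-Mixer architecture. The Hugging Face implementation lets you perform lightweight mixing across patches, channels, and hidden features for effective multivariate time-series modeling. It also supports a range of attention mechanisms, from simple gated attention to more complex, user-customized self-attention blocks. The model can be pretrained and then used for various downstream tasks such as forecasting, classification, and regression.
The abstract from the paper is the following:
TSMixer is a lightweight neural architecture composed exclusively of multi-layer perceptron (MLP) modules, designed for multivariate forecasting and representation learning on patched time series. Our model draws inspiration from the success of MLP-Mixer models in computer vision. We demonstrate the challenges involved in adapting Vision MLP-Mixer for time series and introduce empirically validated components to enhance accuracy. These include a novel design paradigm of attaching online reconciliation heads to the MLP-Mixer backbone, for explicitly modeling time-series properties such as hierarchy and channel correlations. We also propose a hybrid channel modeling approach to effectively handle noisy channel interactions and to generalize across diverse datasets, a common challenge in existing patch channel-mixing methods. Additionally, a simple gated attention mechanism is introduced in the backbone to prioritize important features. By incorporating these lightweight components, we significantly enhance the learning capability of simple MLP structures, outperforming complex Transformer models with minimal computing usage. Moreover, TSMixer's modular design is compatible with both supervised and masked self-supervised learning methods, making it a promising building block for time-series foundation models. TSMixer outperforms state-of-the-art MLP and Transformer models in forecasting by a considerable margin of 8-60%. It also outperforms the latest strong benchmarks of Patch-Transformer models (by 1-2%) with a significant reduction in memory and runtime (2-3X).
This model was contributed by ajati, vijaye12, gsinthong, namctin, wmgifford, kashif.
Usage example
The code snippet below shows how to randomly initialize a PatchTSMixer model. The PatchTSMixer model is compatible with the Trainer API.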
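The gated attention mentioned above can be illustrated in a few lines. This is a minimal, dependency-free sketch that assumes the gate is a softmax over a learned linear projection of the features (bias omitted), multiplied element-wise back into the input. The weight matrix and the helper names are hypothetical, and this is not the library's implementation.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def softmax(values: Vector) -> Vector:
    """Numerically stable softmax over one feature vector."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def linear(weight: Matrix, x: Vector) -> Vector:
    """Hypothetical learned projection without bias: y[i] = sum_j weight[i][j] * x[j]."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in weight]


def gated_attention(weight: Matrix, x: Vector) -> Vector:
    """Scale each feature of x by a softmax gate computed from x itself."""
    gate = softmax(linear(weight, x))
    return [x_i * g_i for x_i, g_i in zip(x, gate)]


if __name__ == "__main__":
    identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    # With an identity projection the gate is softmax(x), so larger features get larger gates.
    print(gated_attention(identity, [0.5, 2.0, -1.0]))
```

With an all-zero weight matrix, every gate equals 1/len(x), so the block only rescales its input. Learning the projection is what allows the model to emphasize informative features.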
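from transformers import PatchTSMixerConfig, PatchTSMixerForPrediction
from transformers import Trainer, TrainingArguments

# train_dataset, valid_dataset and test_dataset are assumed to be prepared beforehand,
# e.g. datasets whose samples provide "past_values" and "future_values" tensors.
config = PatchTSMixerConfig(context_length=512, prediction_length=96)
model = PatchTSMixerForPrediction(config)
training_args = TrainingArguments(output_dir="./patchtsmixer_output")  # hypothetical output path
trainer = Trainer(model=model, args=training_args,
                  train_dataset=train_dataset,
                  eval_dataset=valid_dataset)
trainer.train()
results = trainer.evaluate(test_dataset)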
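The snippet above assumes that the datasets already exist. One common way to build such samples is to slide a window over a multivariate series and cut each window into a `past_values` part of `context_length` steps and a `future_values` part of `prediction_length` steps. This matches the `(seq_length, num_input_channels)` layout documented for `forward`. The helper `make_windows` is hypothetical and written without dependencies. In practice, each field would be converted to a `torch.FloatTensor` before it is passed to Trainer.

```python
from typing import Dict, List

Series = List[List[float]]  # shape: (time_steps, num_input_channels)


def make_windows(
    series: Series, context_length: int, prediction_length: int, stride: int = 1
) -> List[Dict[str, Series]]:
    """Cut a multivariate series into (past_values, future_values) training samples."""
    window = context_length + prediction_length
    samples = []
    for start in range(0, len(series) - window + 1, stride):
        samples.append(
            {
                "past_values": series[start : start + context_length],
                "future_values": series[start + context_length : start + window],
            }
        )
    return samples


if __name__ == "__main__":
    # Toy 2-channel series with 10 time steps.
    toy = [[float(t), float(10 * t)] for t in range(10)]
    windows = make_windows(toy, context_length=4, prediction_length=2, stride=2)
    print(len(windows), windows[0]["future_values"])
```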
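Usage tips
The model can also be used for time series classification and time series regression. See the PatchTSMixerForTimeSeriesClassification and PatchTSMixerForRegression classes, respectively.
Resources
- A blog post explaining PatchTSMixer in depth can be found here. The blog can also be opened in Google Colab.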
PatchTSMixerConfig
class transformers.PatchTSMixerConfig
( context_length: int = 32, patch_length: int = 8, num_input_channels: int = 1, patch_stride: int = 8, num_parallel_samples: int = 100, d_model: int = 8, expansion_factor: int = 2, num_layers: int = 3, dropout: float = 0.2, mode: str = 'common_channel', gated_attn: bool = True, norm_mlp: str = 'LayerNorm', self_attn: bool = False, self_attn_heads: int = 1, use_positional_encoding: bool = False, positional_encoding_type: str = 'sincos', scaling: Union = 'std', loss: str = 'mse', init_std: float = 0.02, post_init: bool = False, norm_eps: float = 1e-05, mask_type: str = 'random', random_mask_ratio: float = 0.5, num_forecast_mask_patches: Union = [2], mask_value: int = 0, masked_loss: bool = True, channel_consistent_masking: bool = True, unmasked_channel_indices: Optional = None, head_dropout: float = 0.2, distribution_output: str = 'student_t', prediction_length: int = 16, prediction_channel_indices: list = None, num_targets: int = 3, output_range: list = None, head_aggregation: str = 'max_pool', **kwargs )
Parameters
- context_length (int, optional, defaults to 32) — The context/history length for the input sequence.
- patch_length (int, optional, defaults to 8) — The patch length for the input sequence.
- num_input_channels (int, optional, defaults to 1) — Number of input variates. For a univariate series, set it to 1.
- patch_stride (int, optional, defaults to 8) — Determines the overlap between two consecutive patches. Set it to patch_length (or greater) for non-overlapping patches.
- num_parallel_samples (int, optional, defaults to 100) — The number of samples to generate in parallel for probabilistic forecasting.
- d_model (int, optional, defaults to 8) — Hidden dimension of the model. It is recommended to set it to a multiple of patch_length (i.e. 2-5X of patch_length). Larger values yield a more complex model.
- expansion_factor (int, optional, defaults to 2) — Expansion factor used inside the MLP. Recommended range is 2-5. Larger values yield a more complex model.
- num_layers (int, optional, defaults to 3) — Number of layers to use. Recommended range is 3-15. Larger values yield a more complex model.
- dropout (float, optional, defaults to 0.2) — The dropout probability of the PatchTSMixer backbone. Recommended range is 0.2-0.7.
- mode (str, optional, defaults to "common_channel") — Mixer mode. Determines how the channels are processed. Allowed values: "common_channel", "mix_channel". In "common_channel" mode, channels are modeled independently with no explicit channel mixing; channel mixing happens implicitly through weights shared across channels (the preferred first approach). In "mix_channel" mode, explicit channel mixing is performed in addition to the patch and feature mixers (the preferred approach when channel correlations are very important to model).
- gated_attn (bool, optional, defaults to True) — Enable gated attention.
- norm_mlp (str, optional, defaults to "LayerNorm") — Normalization layer (BatchNorm or LayerNorm).
- self_attn (bool, optional, defaults to False) — Enable tiny self-attention across patches. This can be enabled when the output of the vanilla PatchTSMixer with gated attention is not satisfactory. Enabling it leads to explicit pair-wise attention and modeling across patches.
- self_attn_heads (int, optional, defaults to 1) — Number of self-attention heads. Works only when self_attn is set to True.
- use_positional_encoding (bool, optional, defaults to False) — Enable the use of positional embedding for the tiny self-attention layers. Works only when self_attn is set to True.
- positional_encoding_type (str, optional, defaults to "sincos") — Positional encodings. The options "random" and "sincos" are supported. Works only when use_positional_encoding is set to True.
- scaling (string or bool, optional, defaults to "std") — Whether to scale the input targets via a "mean" scaler, a "std" scaler, or no scaler if None. If True, the scaler is set to "mean".
- loss (string, optional, defaults to "mse") — The loss function for the model corresponding to the distribution_output head. For parametric distributions it is the negative log likelihood ("nll") and for point estimates it is the mean squared error ("mse").
- init_std (float, optional, defaults to 0.02) — The standard deviation of the truncated normal weight initialization distribution.
- post_init (bool, optional, defaults to False) — Whether to use custom weight initialization from the transformers library, or the default initialization in PyTorch. Setting it to False performs PyTorch weight initialization.
- norm_eps (float, optional, defaults to 1e-05) — A value added to the denominator for numerical stability of normalization.
- mask_type (str, optional, defaults to "random") — Type of masking to use for masked pretraining mode. Allowed values are "random" and "forecast". With random masking, points are masked randomly. With forecast masking, points are masked towards the end.
- random_mask_ratio (float, optional, defaults to 0.5) — Masking ratio to use when mask_type is "random". Higher values mean more masking.
- num_forecast_mask_patches (int or list, optional, defaults to [2]) — Number of patches to be masked at the end of each batch sample. If it is an integer, all samples in the batch will have the same number of masked patches. If it is a list, samples in the batch will be randomly masked by the numbers defined in the list. This argument is only used for forecast pretraining.
- mask_value (float, optional, defaults to 0.0) — Mask value to use.
- masked_loss (bool, optional, defaults to True) — Whether to compute the pretraining loss only on the masked portions, or on the entire output.
- channel_consistent_masking (bool, optional, defaults to True) — When True, masking is the same across all channels of a time series. Otherwise, masking positions vary across channels.
- unmasked_channel_indices (list, optional) — Channels that are not masked during pretraining.
- head_dropout (float, optional, defaults to 0.2) — The dropout probability of the PatchTSMixer head.
- distribution_output (string, optional, defaults to "student_t") — The distribution emission head for the model when loss is "nll". Can be "student_t", "normal" or "negative_binomial".
- prediction_length (int, optional, defaults to 16) — Number of time steps to forecast for a forecasting task. Also known as the forecast horizon.
- prediction_channel_indices (list, optional) — List of channel indices to forecast. If None, all channels are forecast. Target data is expected to have all channels, and the channels are explicitly filtered in both prediction and target before loss computation.
- num_targets (int, optional, defaults to 3) — Number of targets (dimensionality of the regressed variable) for a regression task.
- output_range (list, optional) — Output range to restrict for the regression task. Defaults to None.
- head_aggregation (str, optional, defaults to "max_pool") — Aggregation mode to enable for the classification or regression task. Allowed values are None, "use_last", "max_pool", "avg_pool".
This is the configuration class to store the configuration of a PatchTSMixerModel. It is used to instantiate a PatchTSMixer model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the PatchTSMixer ibm/patchtsmixer-etth1-pretrain architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import PatchTSMixerConfig, PatchTSMixerModel
>>> # Initializing a default PatchTSMixer configuration
>>> configuration = PatchTSMixerConfig()
>>> # Randomly initializing a model (with random weights) from the configuration
>>> model = PatchTSMixerModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
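The `context_length`, `patch_length` and `patch_stride` parameters together determine how many patches the backbone sees. The sketch below makes two assumptions. First, it uses the formula `num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1`. Second, when the stride does not tile the context exactly, it drops the earliest time steps so that the last patch ends at the most recent value. Both choices are there only to illustrate the idea. `patchify` is a hypothetical single-channel helper, not a library function.

```python
from typing import List


def num_patches(context_length: int, patch_length: int, patch_stride: int) -> int:
    """Number of patches per context window (assumed formula)."""
    return (max(context_length, patch_length) - patch_length) // patch_stride + 1


def patchify(values: List[float], patch_length: int, patch_stride: int) -> List[List[float]]:
    """Split one channel into patches so that the last patch ends at the last value."""
    if len(values) < patch_length:
        raise ValueError("sequence must be at least patch_length long")
    n = num_patches(len(values), patch_length, patch_stride)
    used = patch_length + patch_stride * (n - 1)
    start = len(values) - used  # leading steps that do not fit a full patch are dropped
    return [
        values[start + i * patch_stride : start + i * patch_stride + patch_length]
        for i in range(n)
    ]


if __name__ == "__main__":
    # PatchTSMixerConfig defaults: context_length=32, patch_length=8, patch_stride=8.
    print(num_patches(32, 8, 8))  # 4 non-overlapping patches
    # A patch_stride smaller than patch_length gives overlapping patches.
    print(patchify([float(t) for t in range(10)], patch_length=4, patch_stride=3))
```

Under the same assumption, the `context_length=512` configuration from the usage example with the default patch settings would give 64 patches.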
PatchTSMixerModel
class transformers.PatchTSMixerModel
( config: PatchTSMixerConfig, mask_input: bool = False )
Parameters
- config (PatchTSMixerConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
- mask_input (bool, optional, defaults to False) — If True, masking is enabled; otherwise it is disabled.
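When masking is enabled, the config parameters `mask_type`, `random_mask_ratio` and `num_forecast_mask_patches` decide which patches are hidden. The hypothetical, dependency-free sketch below shows both strategies on one channel's patch sequence. Random masking hides roughly `random_mask_ratio` of the patches at random. Forecast masking hides the last `num_forecast_mask_patches` patches. The function names and the exact rounding rule are assumptions, not the library's code.

```python
import random
from typing import List, Optional


def random_patch_mask(num_patches: int, mask_ratio: float, rng: Optional[random.Random] = None) -> List[bool]:
    """True marks a masked patch; int(num_patches * (1 - mask_ratio)) patches stay visible."""
    rng = rng or random.Random()
    num_keep = int(num_patches * (1 - mask_ratio))
    masked = set(rng.sample(range(num_patches), num_patches - num_keep))
    return [i in masked for i in range(num_patches)]


def forecast_patch_mask(num_patches: int, num_forecast_mask_patches: int) -> List[bool]:
    """Mask the trailing patches, so the model must reconstruct the end of the window."""
    if not 0 <= num_forecast_mask_patches <= num_patches:
        raise ValueError("num_forecast_mask_patches must lie in [0, num_patches]")
    return [i >= num_patches - num_forecast_mask_patches for i in range(num_patches)]


def apply_mask(patches: List[List[float]], mask: List[bool], mask_value: float = 0.0) -> List[List[float]]:
    """Replace every value of a masked patch with mask_value."""
    return [[mask_value] * len(p) if m else list(p) for p, m in zip(patches, mask)]


if __name__ == "__main__":
    patches = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    # With channel_consistent_masking=True, one mask like this would be reused for every channel.
    mask = random_patch_mask(len(patches), mask_ratio=0.5, rng=random.Random(0))
    print(apply_mask(patches, mask))
    print(apply_mask(patches, forecast_patch_mask(len(patches), 2)))
```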
The PatchTSMixer Model for time-series forecasting.
This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.)
This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
forward
( past_values: Tensor, observed_mask: Optional = None, output_hidden_states: Optional = False, return_dict: Optional = None ) → transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerModelOutput or tuple(torch.FloatTensor)
Parameters
- past_values (torch.FloatTensor of shape (batch_size, seq_length, num_input_channels)) — Context values of the time series. For a pretraining task, this denotes the input time series whose masked portion is predicted. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, the num_input_channels dimension should be 1. For multivariate time series, it is greater than 1.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- observed_mask (torch.FloatTensor of shape (batch_size, sequence_length, num_input_channels), optional) — Boolean mask to indicate which past_values were observed and which were missing. Mask values selected in [0, 1]:
  - 1 for values that are observed,
  - 0 for values that are missing (i.e. NaNs that were replaced by zeros).
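The `observed_mask` convention (1 for observed values, 0 for missing values that were replaced by zeros) can be produced directly from raw data containing NaNs. Below is a minimal sketch in plain Python that assumes the input is a `(seq_length, num_input_channels)` nested list. `fill_missing` is a hypothetical helper name. With PyTorch, `torch.isnan` would typically serve the same purpose.

```python
import math
from typing import List, Tuple

Grid = List[List[float]]  # (seq_length, num_input_channels)


def fill_missing(values: Grid) -> Tuple[Grid, Grid]:
    """Return (past_values with NaNs replaced by 0.0, observed_mask with 1.0/0.0 entries)."""
    filled, mask = [], []
    for row in values:
        filled.append([0.0 if math.isnan(v) else v for v in row])
        mask.append([0.0 if math.isnan(v) else 1.0 for v in row])
    return filled, mask


if __name__ == "__main__":
    raw = [[1.0, float("nan")], [float("nan"), 4.0], [5.0, 6.0]]
    past_values, observed_mask = fill_missing(raw)
    print(past_values)
    print(observed_mask)
```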
Returns
transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerModelOutput or tuple(torch.FloatTensor)
A transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (PatchTSMixerConfig) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, num_channels, num_patches, d_model)) — Hidden state at the output of the last layer of the model.
- hidden_states (tuple(torch.FloatTensor), optional) — Hidden states of the model at the output of each layer.
- patch_input (torch.FloatTensor of shape (batch_size, num_channels, num_patches, patch_length)) — Patched input data to the model.
- mask (torch.FloatTensor of shape (batch_size, num_channels, num_patches), optional) — Boolean tensor that is True for masked patches and False otherwise.
- loc (torch.FloatTensor of shape (batch_size, 1, num_channels), optional) — The per-channel mean of the context window. Used for RevIN denormalization outside the model, if RevIN is enabled.
- scale (torch.FloatTensor of shape (batch_size, 1, num_channels), optional) — The per-channel standard deviation of the context window. Used for RevIN denormalization outside the model, if RevIN is enabled.
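The `loc` and `scale` outputs describe a per-channel standardization of the context window. The sketch below assumes the `"std"` scaler behaves like a mean and population standard deviation computed over observed steps only. It adds a small constant (`min_scale`, an assumed value) to the variance so that constant channels do not cause a division by zero. It then shows how values in the scaled space map back to the original units with `loc` and `scale`. The helper names are hypothetical.

```python
import math
from typing import List, Tuple

Grid = List[List[float]]  # (seq_length, num_channels)


def std_scale(values: Grid, observed: Grid, min_scale: float = 1e-5) -> Tuple[Grid, List[float], List[float]]:
    """Standardize each channel using only observed steps; returns (scaled, loc, scale)."""
    num_channels = len(values[0])
    loc, scale = [], []
    for c in range(num_channels):
        obs = [row[c] for row, m in zip(values, observed) if m[c] > 0]
        n = max(len(obs), 1)  # avoid dividing by zero for a fully missing channel
        mean = sum(obs) / n
        var = sum((v - mean) ** 2 for v in obs) / n
        loc.append(mean)
        scale.append(math.sqrt(var + min_scale))  # constant keeps flat channels finite
    scaled = [[(row[c] - loc[c]) / scale[c] for c in range(num_channels)] for row in values]
    return scaled, loc, scale


def unscale(scaled: Grid, loc: List[float], scale: List[float]) -> Grid:
    """Map values from the scaled space back to the original units."""
    return [[v * scale[c] + loc[c] for c, v in enumerate(row)] for row in scaled]


if __name__ == "__main__":
    context = [[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]]
    observed = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
    scaled, loc, scale = std_scale(context, observed)
    print(loc, scale)
    print(unscale(scaled, loc, scale))
```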
The PatchTSMixerModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
PatchTSMixerForPrediction
class transformers.PatchTSMixerForPrediction
( config: PatchTSMixerConfig )
PatchTSMixer for forecasting applications.
forward
( past_values: Tensor, observed_mask: Optional = None, future_values: Optional = None, output_hidden_states: Optional = False, return_loss: bool = True, return_dict: Optional = None ) → transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForPredictionOutput or tuple(torch.FloatTensor)
Parameters
- past_values (torch.FloatTensor of shape (batch_size, seq_length, num_input_channels)) — Context values of the time series. For a pretraining task, this denotes the input time series whose masked portion is predicted. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, the num_input_channels dimension should be 1. For multivariate time series, it is greater than 1.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- observed_mask (torch.FloatTensor of shape (batch_size, sequence_length, num_input_channels), optional) — Boolean mask to indicate which past_values were observed and which were missing. Mask values selected in [0, 1]:
  - 1 for values that are observed,
  - 0 for values that are missing (i.e. NaNs that were replaced by zeros).
- future_values (torch.FloatTensor of shape (batch_size, target_len, num_input_channels) for forecasting, (batch_size, num_targets) for regression, or (batch_size,) for classification, optional) — Target values of the time series that serve as labels for the model. The future_values are what the model learns to output during training, given the past_values. Note that this is NOT required for a pretraining task. For a forecasting task, the shape is (batch_size, target_len, num_input_channels). Even if only specific channels are forecast by setting the indices in the prediction_channel_indices parameter, pass the target data with all channels, as channel filtering for both prediction and target is applied manually before the loss computation.
- return_loss (bool, optional) — Whether to return the loss in the forward call.
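As described for `future_values`, targets are passed with all channels, and the channel selection from `prediction_channel_indices` is applied to both prediction and target before the loss is computed. The plain-Python sketch below shows that filtering followed by a mean squared error, which corresponds to the default `"mse"` loss. `filtered_mse` and `select_channels` are hypothetical names, and the inputs are `(prediction_length, num_channels)` nested lists.

```python
from typing import List, Optional, Sequence

Grid = List[List[float]]  # (prediction_length, num_channels)


def select_channels(values: Grid, indices: Optional[Sequence[int]]) -> Grid:
    """Keep only the requested channels; None keeps all of them."""
    if indices is None:
        return values
    return [[row[i] for i in indices] for row in values]


def filtered_mse(prediction: Grid, target: Grid, indices: Optional[Sequence[int]] = None) -> float:
    """Mean squared error over the selected channels only."""
    pred = select_channels(prediction, indices)
    tgt = select_channels(target, indices)
    errors = [(p - t) ** 2 for p_row, t_row in zip(pred, tgt) for p, t in zip(p_row, t_row)]
    return sum(errors) / len(errors)


if __name__ == "__main__":
    prediction = [[1.0, 5.0, 0.0], [2.0, 5.0, 0.0]]
    target = [[1.0, 0.0, 1.0], [4.0, 0.0, 1.0]]
    # Only channel 0 contributes, so the large errors in channel 1 are ignored.
    print(filtered_mse(prediction, target, indices=[0]))
```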
Returns
transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForPredictionOutput or tuple(torch.FloatTensor)
A transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForPredictionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (PatchTSMixerConfig) and inputs.
- prediction_outputs (torch.FloatTensor of shape (batch_size, prediction_length, num_input_channels)) — Prediction output from the forecast head.
- last_hidden_state (torch.FloatTensor of shape (batch_size, num_input_channels, num_patches, d_model)) — Backbone embeddings before passing through the head.
- hidden_states (tuple(torch.FloatTensor), optional) — Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
- loss (torch.FloatTensor of shape (), optional, returned when future_values is provided) — Total loss.
- loc (torch.FloatTensor of shape (batch_size, 1, num_input_channels), optional) — Input mean.
- scale (torch.FloatTensor of shape (batch_size, 1, num_input_channels), optional) — Input standard deviation.
The PatchTSMixerForPrediction forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
PatchTSMixerForTimeSeriesClassification
class transformers.PatchTSMixerForTimeSeriesClassification
( config: PatchTSMixerConfig )
PatchTSMixer for classification applications.
forward
( past_values: Tensor, target_values: Tensor = None, output_hidden_states: Optional = False, return_loss: bool = True, return_dict: Optional = None ) → transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForTimeSeriesClassificationOutput or tuple(torch.FloatTensor)
Parameters
- past_values (torch.FloatTensor of shape (batch_size, seq_length, num_input_channels)) — Context values of the time series. For a pretraining task, this denotes the input time series whose masked portion is predicted. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, the num_input_channels dimension should be 1. For multivariate time series, it is greater than 1.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- target_values (torch.FloatTensor of shape (batch_size, target_len, num_input_channels) for forecasting, (batch_size, num_targets) for regression, or (batch_size,) for classification, optional) — Target values of the time series that serve as labels for the model. The target_values are what the model learns to output during training, given the past_values. Note that this is NOT required for a pretraining task. For a forecasting task, the shape is (batch_size, target_len, num_input_channels). Even if only specific channels are forecast by setting the indices in the prediction_channel_indices parameter, pass the target data with all channels, as channel filtering for both prediction and target is applied manually before the loss computation. For a classification task, it has a shape of (batch_size,). For a regression task, it has a shape of (batch_size, num_targets).
- return_loss (bool, optional) — Whether to return the loss in the forward call.
Returns
transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForTimeSeriesClassificationOutput or tuple(torch.FloatTensor)
A transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForTimeSeriesClassificationOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (PatchTSMixerConfig) and inputs.
- prediction_outputs (torch.FloatTensor of shape (batch_size, num_labels)) — Prediction output from the classification head.
- last_hidden_state (torch.FloatTensor of shape (batch_size, num_input_channels, num_patches, d_model)) — Backbone embeddings before passing through the head.
- hidden_states (tuple(torch.FloatTensor), optional) — Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
- loss (torch.FloatTensor of shape (), optional, returned when target_values is provided) — Total loss.
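Before the classification head produces `prediction_outputs`, the per-patch embeddings must be reduced to a single vector, and the `head_aggregation` config option selects how. The sketch below illustrates the three named modes on one channel's `(num_patches, d_model)` embedding stored as nested lists. It is a hypothetical helper. It assumes `"use_last"` takes the final patch, `"max_pool"` takes the per-feature maximum across patches and `"avg_pool"` takes the per-feature mean. The `None` option (no aggregation) is not covered.

```python
from typing import List

Matrix = List[List[float]]  # (num_patches, d_model)


def aggregate_patches(embedding: Matrix, mode: str = "max_pool") -> List[float]:
    """Reduce a (num_patches, d_model) embedding to one d_model vector."""
    columns = list(zip(*embedding))  # one tuple per feature, spanning all patches
    if mode == "use_last":
        return list(embedding[-1])
    if mode == "max_pool":
        return [max(col) for col in columns]
    if mode == "avg_pool":
        return [sum(col) / len(col) for col in columns]
    raise ValueError(f"unsupported head_aggregation: {mode!r}")


if __name__ == "__main__":
    emb = [[1.0, 4.0], [3.0, 2.0], [2.0, 0.0]]
    for mode in ("use_last", "max_pool", "avg_pool"):
        print(mode, aggregate_patches(emb, mode))
```

Max pooling (the default) keeps the strongest activation of each feature anywhere in the window. `use_last` relies only on the most recent patch.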
The PatchTSMixerForTimeSeriesClassification forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
PatchTSMixerForPretraining
class transformers.PatchTSMixerForPretraining
( config: PatchTSMixerConfig )
PatchTSMixer for masked pretraining.
forward
( past_values: Tensor, observed_mask: Optional = None, output_hidden_states: Optional = False, return_loss: bool = True, return_dict: Optional = None ) → transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForPreTrainingOutput or tuple(torch.FloatTensor)
Parameters
- past_values (torch.FloatTensor of shape (batch_size, seq_length, num_input_channels)) — Context values of the time series. For a pretraining task, this denotes the input time series whose masked portion is predicted. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, the num_input_channels dimension should be 1. For multivariate time series, it is greater than 1.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- observed_mask (torch.FloatTensor of shape (batch_size, sequence_length, num_input_channels), optional) — Boolean mask to indicate which past_values were observed and which were missing. Mask values selected in [0, 1]:
  - 1 for values that are observed,
  - 0 for values that are missing (i.e. NaNs that were replaced by zeros).
- return_loss (bool, optional) — Whether to return the loss in the forward call.
Returns
transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForPreTrainingOutput or tuple(torch.FloatTensor)
A transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForPreTrainingOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (PatchTSMixerConfig) and inputs.
- prediction_outputs (torch.FloatTensor of shape (batch_size, num_input_channels, num_patches, patch_length)) — Prediction output from the pretraining head.
- hidden_states (tuple(torch.FloatTensor), optional) — Hidden states of the model at the output of each layer.
- last_hidden_state (torch.FloatTensor of shape (batch_size, num_input_channels, num_patches, d_model)) — Backbone embeddings before passing through the head.
- loss (torch.FloatTensor of shape (), optional, returned when return_loss is True) — Total loss.
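With `masked_loss=True`, the pretraining loss only covers the masked patches. The minimal sketch below assumes two steps: the squared error is first averaged within each patch over `patch_length`, and then averaged over the masked patches only, with a tiny constant guarding against an empty mask. Inputs are `(num_patches, patch_length)` lists for one channel, and `masked_mse` is a hypothetical name.

```python
from typing import List

Patches = List[List[float]]  # (num_patches, patch_length)


def masked_mse(reconstruction: Patches, target: Patches, mask: List[bool], masked_loss: bool = True) -> float:
    """MSE over masked patches only (masked_loss=True) or over every patch."""
    per_patch = [
        sum((r - t) ** 2 for r, t in zip(r_patch, t_patch)) / len(t_patch)
        for r_patch, t_patch in zip(reconstruction, target)
    ]
    if not masked_loss:
        return sum(per_patch) / len(per_patch)
    weights = [1.0 if m else 0.0 for m in mask]
    # The small constant avoids division by zero when nothing is masked.
    return sum(l * w for l, w in zip(per_patch, weights)) / (sum(weights) + 1e-10)


if __name__ == "__main__":
    reconstruction = [[0.0, 0.0], [1.0, 1.0]]
    target = [[2.0, 2.0], [1.0, 3.0]]
    mask = [False, True]  # only the second patch was hidden from the model
    print(masked_mse(reconstruction, target, mask))
    print(masked_mse(reconstruction, target, mask, masked_loss=False))
```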
The PatchTSMixerForPretraining forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
PatchTSMixerForRegression
class transformers.PatchTSMixerForRegression
( config: PatchTSMixerConfig )
PatchTSMixer for regression applications.
forward
( past_values: Tensor, target_values: Tensor = None, output_hidden_states: Optional = False, return_loss: bool = True, return_dict: Optional = None ) → transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForRegressionOutput or tuple(torch.FloatTensor)
Parameters
- past_values (torch.FloatTensor of shape (batch_size, seq_length, num_input_channels)) — Context values of the time series. For a pretraining task, this denotes the input time series whose masked portion is predicted. For a forecasting task, this denotes the history/past time series values. Similarly, for classification or regression tasks, it denotes the appropriate context values of the time series. For univariate time series, the num_input_channels dimension should be 1. For multivariate time series, it is greater than 1.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- target_values (torch.FloatTensor of shape (batch_size, target_len, num_input_channels) for forecasting, (batch_size, num_targets) for regression, or (batch_size,) for classification, optional) — Target values of the time series that serve as labels for the model. The target_values are what the model learns to output during training, given the past_values. Note that this is NOT required for a pretraining task. For a forecasting task, the shape is (batch_size, target_len, num_input_channels). Even if only specific channels are forecast by setting the indices in the prediction_channel_indices parameter, pass the target data with all channels, as channel filtering for both prediction and target is applied manually before the loss computation. For a classification task, it has a shape of (batch_size,). For a regression task, it has a shape of (batch_size, num_targets).
- return_loss (bool, optional) — Whether to return the loss in the forward call.
Returns
transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForRegressionOutput or tuple(torch.FloatTensor)
A transformers.models.patchtsmixer.modeling_patchtsmixer.PatchTSMixerForRegressionOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (PatchTSMixerConfig) and inputs.
- regression_outputs (torch.FloatTensor of shape (batch_size, num_targets)) — Prediction output from the regression head.
- last_hidden_state (torch.FloatTensor of shape (batch_size, num_input_channels, num_patches, d_model)) — Backbone embeddings before passing through the head.
- hidden_states (tuple(torch.FloatTensor), optional) — Hidden states of the model at the output of each layer plus the optional initial embedding outputs.
- loss (torch.FloatTensor of shape (), optional, returned when target_values is provided) — Total loss.
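The `output_range` config option restricts the regression outputs. One way to enforce such a bound, assumed here, is to squash each raw head output with a sigmoid and rescale it to `[low, high]`. The sketch below shows this on plain floats. `bound_outputs` is a hypothetical helper, and the numerically stable sigmoid avoids overflow for large negative inputs.

```python
import math
from typing import List, Optional, Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bound_outputs(raw: List[float], output_range: Optional[Sequence[float]] = None) -> List[float]:
    """Map unbounded head outputs into [low, high]; no-op when output_range is None."""
    if output_range is None:
        return list(raw)
    low, high = output_range
    if not low < high:
        raise ValueError("output_range must be [low, high] with low < high")
    return [low + (high - low) * sigmoid(x) for x in raw]


if __name__ == "__main__":
    # e.g. a regression target that is known to lie between 0 and 10
    print(bound_outputs([-3.0, 0.0, 3.0], output_range=[0.0, 10.0]))
```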
The PatchTSMixerForRegression forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.