from functools import partial
from typing import Dict, Optional

import gradio as gr
import numpy as np
import plotly.graph_objects as go
from huggingface_hub import from_pretrained_keras
# Base location of the FordA tab-separated files; the train and test URLs
# below are derived from it.
ROOT_DATA_URL = "https://raw.githubusercontent.com/hfawaz/cd-diagram/master/FordA"
TRAIN_DATA_URL = f"{ROOT_DATA_URL}/FordA_TRAIN.tsv"
TEST_DATA_URL = f"{ROOT_DATA_URL}/FordA_TEST.tsv"

# Number of samples per time series; rows are reshaped to
# (1, TIMESERIES_LEN, 1) before being handed to the model.
TIMESERIES_LEN = 500

# Class labels, paired positionally with the model's output scores.
# This must be an ordered sequence: a set of strings iterates in an order that
# can change between interpreter runs (string hash randomization), which would
# silently attach scores to the wrong label.
# NOTE(review): the label order is kept as originally written and is assumed
# to follow the model's output indices — confirm against the model card.
CLASSES = ("Symptom does NOT exist", "Symptom exists")
# Load the pretrained Keras model published on the Hugging Face Hub under this
# repo id. This runs once, at import time, and the result is shared by every
# prediction request.
model = from_pretrained_keras("keras-io/timeseries-classification-from-scratch")
# Read data
def read_data(file_url: str):
    """Load a tab-separated numeric table and split it into features and labels.

    Args:
        file_url: Source handed straight to ``np.loadtxt``.

    Returns:
        A ``(x, y)`` tuple: ``x`` holds every column after the first, and
        ``y`` holds the first column converted to integers.
    """
    table = np.loadtxt(file_url, delimiter="\t")
    # The first column carries the label; the remaining columns are the series.
    labels = table[:, 0].astype(int)
    features = table[:, 1:]
    return features, labels
# Load both splits once at import time; each call reads and parses the TSV
# behind the given URL and returns (features, integer labels).
x_train, y_train = read_data(file_url=TRAIN_DATA_URL)
x_test, y_test = read_data(file_url=TEST_DATA_URL)
# Helper functions
def get_prediction(row_index: int, data: np.ndarray) -> Dict[str, float]:
    """Classify one row of ``data`` and return a score for every class label.

    Args:
        row_index: Index of the row to classify. Coerced with ``int`` because
            UI sliders may deliver the value as a float, which NumPy rejects
            as an index.
        data: Array whose rows each hold ``TIMESERIES_LEN`` samples.

    Returns:
        Mapping from each label in ``CLASSES`` to the corresponding model
        output as a plain ``float``.
    """
    sample = data[int(row_index)].reshape((1, TIMESERIES_LEN, 1))
    scores = model.predict(sample).flatten()
    # Labels and scores are paired purely by position, so CLASSES has to
    # iterate in the same order as the model's outputs.
    return {label: float(score) for label, score in zip(CLASSES, scores)}
def create_plot(
    row_index: int,
    dataset_name: str,
    data: Optional[np.ndarray] = None,
) -> go.Figure:
    """Build a line chart of a single time series.

    Args:
        row_index: Index of the row to draw. Coerced with ``int`` because UI
            sliders may deliver the value as a float.
        dataset_name: Human-readable name of the dataset; only used in the
            figure title.
        data: Array to take the row from. When omitted, the module-level
            ``x_train`` is used, which keeps the two-argument call form
            working exactly as before.

    Returns:
        A Plotly figure containing one "lines+markers" trace of the row.
    """
    source = x_train if data is None else data
    row_index = int(row_index)
    values = source[row_index].flatten()
    trace = go.Scatter(
        # Derive the x axis from the row itself so x and y always line up.
        x=list(range(len(values))),
        y=values,
        mode="lines+markers",
    )
    fig = go.Figure(data=trace)
    fig.update_layout(title=f"Timeseries in row {row_index} of {dataset_name} set ")
    return fig


def show_tab_section(data: np.ndarray, dataset_name: str):
    """Create the widgets for one dataset tab and wire up the Predict button.

    Components are created in display order: row slider, "Predict" button,
    plot, then label. One button click triggers both the plot and the
    prediction for the selected row.

    Args:
        data: Dataset backing this tab; both the plot and the prediction read
            the selected row from it.
        dataset_name: Name shown in the plot title.
    """
    num_indexes = data.shape[0]
    index = gr.Slider(
        minimum=0,
        maximum=num_indexes - 1,
        # Whole-number steps: the value is used as a row index.
        step=1,
        label="Select the index of the row you want to classify:",
    )
    button = gr.Button("Predict")
    plot = gr.Plot()
    # Bind this tab's own data so the plot shows the same row that is
    # classified, not a row from the training set.
    create_plot_data = partial(create_plot, dataset_name=dataset_name, data=data)
    button.click(create_plot_data, inputs=[index], outputs=[plot])
    get_prediction_data = partial(get_prediction, data=data)
    label = gr.Label()
    button.click(get_prediction_data, inputs=[index], outputs=[label])
# Gradio Demo
# Page heading text. The leading "# " is Markdown level-1 heading syntax
# (presumably rendered through a Markdown component; that call is outside this view).
title = "# Timeseries classification from scratch"
description = """
Select a time series in the Training or Test dataset and ask the model to classify it!
The model was trained on the FordA dataset. Each row is a diagnostic session run on an automotive subsystem. In each session 500 samples were collected. Given a time series, the model was trained to identify if a specific symptom exists or it does not exist.
Model: https://huggingface.co/keras-io/timeseries-classification-from-scratch
Keras Example: https://keras.io/examples/timeseries/timeseries_classification_from_scratch/