# Author: oscarwang2
# Commit dbc69f1 ("Update app.py"), verified
import gradio as gr
import yfinance as yf
import plotly.graph_objects as go
from statsmodels.tsa.arima.model import ARIMA
import pandas as pd
import logging
# Configure the root logger at INFO level so the progress messages emitted
# below (model training, prediction, plotting) are actually shown.
logging.basicConfig(level=logging.INFO)
def fetch_eth_price(period):
    """Download recent ETH-USD candles for one of the UI periods.

    Args:
        period: one of "1d", "5d", "1wk" or "1mo".

    Returns:
        ``(data, predict_steps, freq)`` where ``data`` is the yfinance
        history DataFrame put on a regular ``freq`` grid and trimmed to its
        last 200 rows, ``predict_steps`` is how many ``freq`` steps ahead to
        forecast, and ``freq`` is a pandas frequency alias.  Any other
        ``period`` (including ``None``, which the Radio sends before a
        selection is made) returns ``(None, None, None)`` without
        contacting Yahoo.
    """
    if period == '1d':
        history_kwargs = {'period': '1d', 'interval': '1m'}
        predict_steps = 60  # Next 60 minutes
        freq = 'min'  # Minute frequency
    elif period == '5d':
        history_kwargs = {'period': '5d', 'interval': '15m'}
        predict_steps = 96  # Next 24 hours
        freq = '15min'  # 15 minutes frequency
    elif period == '1wk':
        # yfinance accepts "1wk" only as an *interval*, not as a *period*,
        # so ask for the last 7 days explicitly instead.
        start = pd.Timestamp.now(tz='UTC') - pd.Timedelta(days=7)
        history_kwargs = {'start': start.to_pydatetime(), 'interval': '30m'}
        predict_steps = 336  # Next 7 days
        freq = '30min'  # 30 minutes frequency
    elif period == '1mo':
        history_kwargs = {'period': '1mo', 'interval': '1h'}
        predict_steps = 720  # Next 30 days
        # Hourly frequency; 'h' replaces the 'H' alias deprecated in pandas 2.2.
        freq = 'h'
    else:
        return None, None, None

    data = yf.Ticker("ETH-USD").history(**history_kwargs)
    data.index = pd.DatetimeIndex(data.index)
    # Put the rows on a regular `freq` grid (timestamps missing from the
    # download become NaN rows) so the model input and the forecast index
    # built from `freq` agree.
    data = data.asfreq(freq)
    # Limit the data to the last 200 points to reduce prediction time.
    data = data.iloc[-200:]
    return data, predict_steps, freq
def make_predictions(data, predict_steps, freq):
    """Fit an ARIMA(5, 1, 0) model on ``data['Close']`` and forecast ahead.

    Args:
        data: price DataFrame with a ``Close`` column and a DatetimeIndex
            (as produced by ``fetch_eth_price``), or ``None``.
        predict_steps: number of future ``freq`` steps to forecast, or
            ``None``.
        freq: pandas frequency alias used to lay out the forecast timestamps.

    Returns:
        DataFrame with a single float ``Prediction`` column indexed by the
        ``predict_steps`` timestamps that follow the last observation.  When
        ``data`` is missing or empty, the column is all-NaN (indexed from
        "now" when ``predict_steps``/``freq`` are usable, otherwise empty),
        so callers can always read ``['Prediction']``.
    """
    if data is None or data.empty:
        logging.error("No data available for prediction.")
        if not predict_steps or freq is None:
            return pd.DataFrame(columns=['Prediction'], dtype=float)
        future_dates = pd.date_range(start=pd.Timestamp.now(), periods=predict_steps + 1, freq=freq)[1:]
        return pd.DataFrame(index=future_dates, columns=['Prediction'], dtype=float)

    logging.info(f"Starting model training with {len(data)} data points...")
    model = ARIMA(data['Close'], order=(5, 1, 0))
    model_fit = model.fit()
    logging.info("Model training completed.")

    forecast = model_fit.forecast(steps=predict_steps)
    # Steps 1..predict_steps after the last observation; [1:] drops the
    # last observed timestamp itself, which date_range includes as `start`.
    future_dates = pd.date_range(start=data.index[-1], periods=predict_steps + 1, freq=freq)[1:]
    # Use the raw values: passing the forecast Series directly would align on
    # its name ("predicted_mean") and its index, leaving 'Prediction' all-NaN.
    forecast_df = pd.DataFrame({'Prediction': pd.Series(forecast).to_numpy()}, index=future_dates)
    logging.info("Predictions generated successfully.")
    return forecast_df
def plot_eth(period):
    """Build a Plotly chart of recent ETH-USD closes plus the ARIMA forecast.

    Args:
        period: one of the UI periods ("1d", "5d", "1wk", "1mo").  Any other
            value, including ``None`` before a selection is made, yields an
            empty chart instead of an exception.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    data, predict_steps, freq = fetch_eth_price(period)
    fig = go.Figure()

    if data is None or data.empty:
        # Unknown period or no rows downloaded: nothing to plot or model.
        logging.warning("No ETH price data available for period %r.", period)
        fig.update_layout(title=f"No ETH price data available ({period})", xaxis_title="Date", yaxis_title="Price (USD)")
        return fig

    forecast_df = make_predictions(data, predict_steps, freq)
    fig.add_trace(go.Scatter(x=data.index, y=data['Close'], mode='lines', name='ETH Price'))
    fig.add_trace(go.Scatter(x=forecast_df.index, y=forecast_df['Prediction'], mode='lines', name='Prediction', line=dict(dash='dash', color='orange')))
    fig.update_layout(title=f"ETH Price and Predictions ({period})", xaxis_title="Date", yaxis_title="Price (USD)")
    logging.info("Plotting completed.")
    return fig
def refresh_predictions(period):
    """Re-fetch ETH prices and rebuild the forecast chart for ``period``.

    Handler for the refresh button; all work is delegated to ``plot_eth``.
    """
    figure = plot_eth(period)
    return figure
# Gradio UI: picking a period or pressing the button redraws the chart.
with gr.Blocks() as iface:
    period = gr.Radio(["1d", "5d", "1wk", "1mo"], label="Select Period")
    plot = gr.Plot()
    refresh_button = gr.Button("Refresh Predictions and Prices")
    period.change(fn=plot_eth, inputs=period, outputs=plot)
    refresh_button.click(fn=refresh_predictions, inputs=period, outputs=plot)

# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()