# Spaces: Runtime error  (page status text captured with this file; not part of the program)
import pickle

import tensorflow as tf
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from tensorflow import keras
# Load the model once at import time so every request reuses the same
# in-memory instance. The path is relative, so it is resolved against the
# process's current working directory.
model_path = 'doctor_ai_model.h5'
model = keras.models.load_model(model_path)

# Create FastAPI app
app = FastAPI()
# Define the request model
class InputData(BaseModel):
    """Request body for a prediction call."""

    # Feature values to run through the model. Declared as a bare ``list``,
    # so element types and shape are not validated at this level.
    input_data: list
# Define the prediction endpoint
# NOTE(review): the handler was never registered on ``app``; the '/predict'
# path is an assumption -- confirm against the clients that call this API.
@app.post('/predict')
async def predict(data: InputData):
    """Run the loaded model on the submitted features and return class indices.

    Accepts either a single sample (a flat list of 27 values) or a batch
    (a list of 27-value lists). A single sample gets a leading batch axis
    before prediction; a batch is passed through unchanged.

    Args:
        data: Request body whose ``input_data`` holds the feature values.

    Returns:
        ``{'predicted_class': [...]}`` with one argmax index per sample, or
        ``{'error': ...}`` when the input cannot be shaped to
        ``expected_shape``.
    """
    # Model input shape as assumed by this handler: (batch, 27 features).
    expected_shape = (None, 27)
    shape_error = {'error': f'Input data must have shape: {expected_shape}'}

    # Prepare input data; ragged or mixed-type lists cannot become a tensor.
    try:
        input_array = tf.convert_to_tensor(data.input_data)
    except (ValueError, TypeError):
        return shape_error

    rank = len(input_array.shape)
    if rank == 1:
        # Single sample: add the batch axis so the tensor is (1, n_features).
        input_array = tf.expand_dims(input_array, axis=0)
    elif rank != 2:
        # Scalars and 3-D+ inputs can never match (batch, features).
        return shape_error

    # Check if input shape matches the model's input shape
    if input_array.shape[1] != expected_shape[1]:
        return shape_error

    # Make a prediction (blocking call; runs on the event loop thread).
    prediction = model.predict(input_array)
    predicted_class = tf.argmax(prediction, axis=1).numpy().tolist()
    return {'predicted_class': predicted_class}
# Start the FastAPI server (this will run offline)
if __name__ == '__main__':
    # Bind to the loopback interface only, so the server is reachable just
    # from this machine (localhost for offline mode).
    uvicorn.run(app, host='127.0.0.1', port=8000)