# Spaces: Runtime error  (page-status text captured with this source; not part of the program)
import os

# Suppress TensorFlow's C++ log output. TensorFlow reads this variable when
# its native runtime is loaded, so it must be set *before* the first
# `import tensorflow`; assigning it after the import has little or no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf  # noqa: E402  (must follow the env-var assignment)
from tensorflow import keras  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
import uvicorn  # noqa: E402

# Load the model once at import time so every request reuses the same object.
# NOTE(review): this path is relative, so it is resolved against the process's
# working directory rather than this file's location. It also names a
# directory, presumably a TF SavedModel; newer Keras releases only load
# `.keras`/`.h5` files through `load_model`. Confirm both against deployment.
model_path = 'doctor_ai_model'
model = keras.models.load_model(model_path)

# Create the FastAPI application object served by uvicorn.
app = FastAPI()
# Request-body schema for the prediction endpoint.
class InputData(BaseModel):
    """JSON body accepted by the prediction endpoint.

    Attributes:
        input_data: Raw feature values sent by the client. It is declared
            only as ``list``, so nesting depth and element types are not
            validated by this model.
    """

    input_data: list
# Define the prediction endpoint.
# Number of values each input row must contain.
N_FEATURES = 27
_SHAPE_ERROR_MSG = (
    f'Input data must have shape: ({N_FEATURES},) or (None, {N_FEATURES})'
)


@app.post('/predict')
async def predict(data: InputData):
    """Classify one sample or a batch of samples with the loaded model.

    Args:
        data: Request body. ``data.input_data`` is either a single sample
            (a flat list of ``N_FEATURES`` numbers) or a batch (a list of
            such lists).

    Returns:
        ``{'predicted_class': [...]}`` holding one argmax index per sample,
        or ``{'error': <message>}`` when the input cannot be turned into a
        numeric tensor of an accepted shape.
    """
    try:
        input_array = tf.convert_to_tensor(data.input_data)
    except (ValueError, TypeError):
        # Ragged or mixed-type nested lists cannot form a dense tensor.
        return {'error': _SHAPE_ERROR_MSG}

    # String values would pass the shape check and then fail inside the
    # model as an unhandled server error; reject them up front.
    if input_array.dtype == tf.string:
        return {'error': 'Input data must be numeric'}

    shape = input_array.shape
    # Check rank before indexing so a flat sample never raises IndexError.
    if shape.rank not in (1, 2) or shape[-1] != N_FEATURES:
        return {'error': _SHAPE_ERROR_MSG}

    # Add the batch axis only for a single flat sample; a 2-D batch is
    # already (batch, N_FEATURES) and must not be expanded to 3-D.
    if shape.rank == 1:
        input_array = tf.expand_dims(input_array, axis=0)

    # NOTE(review): model.predict is synchronous and runs on the event-loop
    # thread inside this async handler, so concurrent requests wait on it.
    prediction = model.predict(input_array)
    # Assumes prediction is (batch, n_classes) scores -- TODO confirm.
    predicted_class = tf.argmax(prediction, axis=1).numpy().tolist()
    return {'predicted_class': predicted_class}
# Start the FastAPI server when this file is run as a script.
if __name__ == '__main__':
    # Bind address defaults to 127.0.0.1:8000. It can be overridden with the
    # HOST / PORT environment variables; a server inside a container usually
    # needs HOST=0.0.0.0 to be reachable from outside it.
    uvicorn.run(
        app,
        host=os.environ.get('HOST', '127.0.0.1'),
        port=int(os.environ.get('PORT', '8000')),
    )