Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,41 +1,34 @@
|
|
1 |
import os
|
2 |
import tensorflow as tf
|
3 |
from tensorflow import keras
|
4 |
-
|
5 |
-
from pydantic import BaseModel
|
6 |
-
import uvicorn
|
7 |
-
|
8 |
-
# Suppress TensorFlow logging
|
9 |
-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
|
10 |
|
11 |
# Load the model
|
12 |
-
model_path = 'doctor_ai_model.h5' #
|
13 |
model = keras.models.load_model(model_path)
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
# Define the request model
|
19 |
-
class InputData(BaseModel):
|
20 |
-
input_data: list
|
21 |
|
22 |
-
#
|
23 |
-
|
24 |
-
|
25 |
-
input_array = tf.convert_to_tensor(data.input_data)
|
26 |
-
|
27 |
-
# Check input shape
|
28 |
-
if input_array.shape[1] != 27:
|
29 |
-
return {'error': 'Input data must have shape: (None, 27)'}
|
30 |
|
31 |
-
# Expand dimensions
|
32 |
-
|
33 |
|
34 |
-
|
|
|
35 |
predicted_class = tf.argmax(prediction, axis=1).numpy().tolist()
|
36 |
|
37 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
-
#
|
40 |
-
|
41 |
-
uvicorn.run(app, host='127.0.0.1', port=8000)
|
|
|
1 |
import os
|
2 |
import tensorflow as tf
|
3 |
from tensorflow import keras
|
4 |
+
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# Load the model
|
7 |
+
model_path = 'doctor_ai_model.h5' # or .h5
|
8 |
model = keras.models.load_model(model_path)
|
9 |
|
10 |
+
def predict(input_data):
|
11 |
+
# Convert input to tensor
|
12 |
+
input_tensor = tf.convert_to_tensor(input_data)
|
|
|
|
|
|
|
13 |
|
14 |
+
# Ensure the input shape is correct
|
15 |
+
if input_tensor.shape[1] != 27: # Adjust based on your input shape
|
16 |
+
return "Input data must have shape: (None, 27)"
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
+
# Expand dimensions for the model
|
19 |
+
input_tensor = tf.expand_dims(input_tensor, axis=0)
|
20 |
|
21 |
+
# Make prediction
|
22 |
+
prediction = model.predict(input_tensor)
|
23 |
predicted_class = tf.argmax(prediction, axis=1).numpy().tolist()
|
24 |
|
25 |
+
return predicted_class
|
26 |
+
|
27 |
+
# Create Gradio interface
|
28 |
+
iface = gr.Interface(fn=predict,
|
29 |
+
inputs=gr.inputs.Dataframe(type='numpy',
|
30 |
+
label='Input Data (should have shape (None, 27))'),
|
31 |
+
outputs='json')
|
32 |
|
33 |
+
# Launch the Gradio app
|
34 |
+
iface.launch()
|
|