GodfreyOwino commited on
Commit
6e498bf
1 Parent(s): b6fa28f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -3,6 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  import joblib
5
  import numpy as np
 
6
 
7
  app = FastAPI()
8
  app.add_middleware(
@@ -31,6 +32,9 @@ class InputData(BaseModel):
31
  @app.post("/predict")
32
  async def predict(data: InputData):
33
  try:
 
 
 
34
  input_data = pd.DataFrame({
35
  'crop_name': [data.crop_name],
36
  'target_yield': [data.target_yield],
@@ -45,6 +49,11 @@ async def predict(data: InputData):
45
 
46
  # Use the encoder to transform the crop_name
47
  input_data['crop_name'] = le.transform(input_data['crop_name'])
 
 
 
 
 
48
 
49
  prediction = model.predict(input_data)
50
  return {
@@ -55,6 +64,7 @@ async def predict(data: InputData):
55
  "lime_need": float(prediction[0][4])
56
  }
57
  except Exception as e:
 
58
  raise HTTPException(status_code=500, detail=str(e))
59
 
60
  @app.get("/")
 
3
  from pydantic import BaseModel
4
  import joblib
5
  import numpy as np
6
+ import pandas as pd
7
 
8
  app = FastAPI()
9
  app.add_middleware(
 
32
  @app.post("/predict")
33
  async def predict(data: InputData):
34
  try:
35
+ # Validating crop_name
36
+ if data.crop_name not in le.classes_:
37
+ raise ValueError(f"Invalid crop_name: {data.crop_name}")
38
  input_data = pd.DataFrame({
39
  'crop_name': [data.crop_name],
40
  'target_yield': [data.target_yield],
 
49
 
50
  # Use the encoder to transform the crop_name
51
  input_data['crop_name'] = le.transform(input_data['crop_name'])
52
+ # Validating the input shape
53
+ expected_shape = model.n_features_in_
54
+ if input_data.shape[1] != expected_shape:
55
+ raise ValueError(f"Input shape mismatch. Expected {expected_shape} features, got {input_data.shape[1]}")
56
+
57
 
58
  prediction = model.predict(input_data)
59
  return {
 
64
  "lime_need": float(prediction[0][4])
65
  }
66
  except Exception as e:
67
+ logging.error(f"Error in predict function: {str(e)}")
68
  raise HTTPException(status_code=500, detail=str(e))
69
 
70
  @app.get("/")