|
from fastapi import FastAPI, Request, Response, HTTPException, Query, UploadFile, File |
|
from fastapi.responses import JSONResponse |
|
from fastapi.middleware.cors import CORSMiddleware |
|
import io |
|
|
|
from scripts.db_helper import insert_ban , search_ban |
|
import uuid |
|
from typing import Optional |
|
import numpy as np |
|
import cv2 |
|
import base64 |
|
|
|
from scripts.comfyui import ComfyUI |
|
from scripts.liveportrait import LP |
|
|
|
import os |
|
|
|
# Module-level handles: startup_event() assigns the ComfyUI wrapper and the
# LP client, and shutdown_event() resets both back to None.
ui = lp = None
|
|
|
def generate_response(response_message, response_status, uuid_code, metadata):
    """Build the JSON response envelope returned by the API endpoints.

    Args:
        response_message: Human-readable status message.
        response_status: Success flag for the request.
        uuid_code: Value placed under ``data["UUID"]``.
        metadata: Value placed under ``data["metadata"]``.

    Returns:
        A dict with keys ``response_message``, ``response_status`` and
        ``data``; ``data`` holds ``UUID`` and ``metadata``, in that order.
    """
    payload = {"UUID": uuid_code, "metadata": metadata}
    return dict(
        response_message=response_message,
        response_status=response_status,
        data=payload,
    )
|
|
|
def image_to_base64(image):
    """Encode an OpenCV image as a base64 PNG string.

    Args:
        image: Array accepted by ``cv2.imencode`` (presumably an HxW or
            HxWxC uint8 image — confirm against callers).

    Returns:
        The PNG bytes, base64-encoded, as a ``str``.

    Raises:
        ValueError: If ``cv2.imencode`` reports that encoding failed.
    """
    # cv2.imencode returns (success_flag, buffer). The flag used to be ignored,
    # which silently turned a failed encode into garbage or empty output.
    ok, buffer = cv2.imencode('.png', image)
    if not ok:
        raise ValueError("cv2.imencode failed to encode image as PNG")
    return base64.b64encode(buffer.tobytes()).decode('utf-8')
|
|
|
def base64_to_image(base64_string):
    """Decode a base64-encoded image into an OpenCV array.

    The bytes are decoded with ``cv2.IMREAD_COLOR``, so the result is a
    3-channel colour image. As with ``cv2.imdecode``, ``None`` is returned
    when OpenCV cannot decode the bytes.
    """
    raw = np.frombuffer(base64.b64decode(base64_string), dtype=np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR)
|
|
|
# ASGI application instance; the __main__ block at the bottom serves it as "main:app".
app: FastAPI = FastAPI()
|
|
|
@app.on_event("startup")
async def startup_event():
    """Start ComfyUI and create the module-level LP client.

    ``ui.UI()`` is attempted up to ten times, stopping at the first truthy
    result. If every attempt fails, a warning is printed and startup still
    continues, so ``lp`` is always created against ``ui.port``.
    """
    global ui, lp
    print("Application is starting up...")

    max_attempts = 10
    ui = ComfyUI()
    # any() short-circuits, so UI() is not called again after a success.
    started = any(ui.UI() for _ in range(max_attempts))
    if not started:
        # Previously this failure was completely silent.
        print(f"Warning: ComfyUI did not start after {max_attempts} attempts")

    lp = LP(port=ui.port)
|
|
|
@app.on_event("shutdown")
async def shutdown_event():
    """Drop the module-level ComfyUI wrapper and LP client references.

    NOTE(review): this only clears the references; if ComfyUI() owns an
    external process it is not explicitly stopped here — confirm.
    """
    global ui, lp
    print("Application is shutting down...")
    ui = lp = None
|
|
|
|
|
|
|
# Allow cross-origin calls from any origin, with any method and any header.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests,
# and Starlette may instead reflect the caller's Origin when credentials are
# allowed. In practice that lets any website make credentialed calls.
# Consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
|
|
|
# Connection settings passed to the scripts.db_helper functions (presumably
# PostgreSQL, given the 'postgres' user and port 5432).
# Every value can be overridden through an environment variable. The literal
# fallbacks keep existing deployments working unchanged.
# SECURITY(review): a real password and a public host IP are committed here.
# Rotate the credential, provide API_DB_PASSWORD (and API_DB_HOST) through the
# environment, and then delete these fallbacks.
db_params = {
    'dbname': os.environ.get('API_DB_NAME', 'API_DB'),
    'user': os.environ.get('API_DB_USER', 'postgres'),
    'password': os.environ.get('API_DB_PASSWORD', '4b95dfe8-4644-46ce-a4fe-648d6d4860a4'),
    'host': os.environ.get('API_DB_HOST', '44.208.52.100'),
    'port': os.environ.get('API_DB_PORT', '5432'),
}
|
|
|
def _unauthorized_response():
    """Return the uniform 401 response used for every custom_middleware rejection."""
    return JSONResponse(status_code=401, content={"detail": "Unauthorized"})


@app.middleware('http')
async def custom_middleware(request: Request, call_next):
    """Reject banned client IPs and ban clients that hit unknown routes or methods.

    Steps:
        1. If ``search_ban`` reports the client IP as banned, respond 401
           without running the route.
        2. Otherwise run the route. If it answers 404 or 405, record the IP
           with ``insert_ban`` and return 401 instead of the route's response.

    NOTE(review): search_ban/insert_ban are called synchronously (no await), so
    each DB round trip blocks the event loop.
    NOTE(review): the ban key is the direct peer address. Behind a reverse
    proxy or load balancer, that address is the proxy's, so a single stray 404
    would ban every user. Also, any 404 triggers a ban, including one raised
    deliberately by a valid route.
    """
    # Starlette types request.client as Optional. Without a peer address there
    # is nothing to check against the ban list, so fail closed instead of
    # crashing with AttributeError (which produced a 500).
    if request.client is None:
        return _unauthorized_response()
    client_ip = request.client.host

    if search_ban(db_params, client_ip):
        return _unauthorized_response()

    response = await call_next(request)

    if response.status_code in (404, 405):
        insert_ban(db_params, client_ip, "Access using unapproved method/endpoint")
        return _unauthorized_response()

    return response
|
|
|
from scripts.db_helper import get_request_data |
|
from scripts.s3 import upload_to_s3, generate_presigned_url |
|
def _run_live_portrait(request_id, request_data):
    """Run the LP pipeline for one request, retrying once after a failure.

    A fresh ``LP`` client is built from ``ui.port`` for each attempt. If the
    first attempt raises, ComfyUI is probed at ``http://127.0.0.1:<port>``.
    When it is not ready, ``restart_if_down()`` is called. The pipeline then
    runs one more time, and exceptions from that second attempt propagate.

    Args:
        request_id: String form of the request UUID.
        request_data: Row from ``get_request_data``. Item 0 is passed to
            ``LP.LP`` unchanged, and item 1 names an ``.mp4`` under
            ``/home/ubuntu/AI/videos`` (TODO confirm the column meanings).

    Returns:
        Whatever ``LP.LP`` returns. ``None`` is treated as failure by the caller.
    """
    source_video = f'/home/ubuntu/AI/videos/{request_data[1]}.mp4'
    try:
        return LP(port=ui.port).LP(request_id, request_data[0], source_video)
    except Exception as exc:
        # Narrowed from a bare `except:` so cancellation/KeyboardInterrupt propagate.
        print(f"LivePortrait attempt failed: {exc!r}; checking ComfyUI and retrying")

    server_url = f"http://127.0.0.1:{ui.port}"
    # Second argument presumably a timeout in seconds — confirm in scripts.comfyui.
    if not ui.is_server_ready(server_url, 5):
        ui.restart_if_down()
    # Rebuild the client so ui.port is re-read after a possible restart.
    return LP(port=ui.port).LP(request_id, request_data[0], source_video)


@app.post("/genrate_video")
async def genrate_video(request_id: uuid.UUID):
    """Generate a video for a stored request, upload it to S3 and return a presigned URL.

    The route path and function name keep their original spelling because
    clients depend on them.

    Args:
        request_id: ID of a request row readable by ``get_request_data``.
            FastAPI reads this scalar parameter from the query string.

    Returns:
        A ``generate_response`` envelope:
            - failure with 'Incorrect Request' when no row exists;
            - failure with 'Video Generation Failed' when the pipeline
              returns None;
            - otherwise success with ``[presigned_url]`` as metadata.

    NOTE(review): the DB, LP and S3 calls below are synchronous and block the
    event loop for the whole generation. Consider a plain ``def`` endpoint
    (run in FastAPI's threadpool) once concurrent LP/ComfyUI use is confirmed safe.
    """
    request_id = str(request_id)

    request_data = get_request_data(db_params=db_params, request_id=request_id)
    if request_data is None:
        return generate_response('Incorrect Request', False, '', [])

    print("update request status to processing")

    video = _run_live_portrait(request_id, request_data)
    if video is None:
        # Previously execution fell through to video[0] (TypeError), or hit an
        # unbound server_url (NameError) when the first attempt had not raised.
        print("update request status to error")
        return generate_response('Video Generation Failed', False, '', [])

    output_name, cleanup_paths = video[0], video[1]
    bucket = 'app-faceanimate-s3'
    local_file_path = f'/home/ubuntu/AI/ComfyUI/output/{output_name}'
    s3_key = f'{request_id}/video.mp4'
    upload_to_s3(bucket, local_file_path, s3_key)

    # Remove the files LP listed in video[1], but only after the upload succeeded.
    for video_path in cleanup_paths:
        os.remove(video_path)

    video_url = generate_presigned_url(bucket, s3_key)
    print("update request status to processed")
    return generate_response('Video Genrated', True, '', [video_url])
|
|
|
if __name__ == '__main__':
    # Imported here because they are only needed when this file is run directly.
    import uvicorn
    from scripts.comfyui import setup_database

    # uvicorn starts this many worker processes, each importing "main:app",
    # so this file is expected to be saved as main.py.
    number_of_workers = 2

    # Receives the worker count before the server starts. What it prepares lives
    # in scripts.comfyui — presumably per-worker state (TODO confirm).
    setup_database(number_of_workers)

    uvicorn.run("main:app", host='0.0.0.0', port=3000, workers=number_of_workers)
|
|