Spaces:
Build error
Build error
import hmac
import os
from typing import Literal, Optional, Set

import gradio as gr
from fastapi import Request
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel
class GradioWebhookApp:
    """Serve webhook endpoints next to a Gradio UI.

    The Gradio UI is launched (non-blocking) at construction time. Webhook
    routes are then registered as POST endpoints on the underlying FastAPI app
    with :meth:`add_webhook`. Call :meth:`ready` last: it prints the webhook
    URLs and blocks the main thread.

    ```py
    import gradio as gr
    from gradio_webhooks import GradioWebhookApp

    with gr.Blocks() as ui:
        ...

    app = GradioWebhookApp(ui)

    @app.add_webhook("/test_webhook")
    async def hello():
        return {"in_gradio": True}

    app.ready()
    ```

    Args:
        ui: The Gradio Blocks app to launch alongside the webhooks.
        webhook_secret: Value expected in the ``X-Webhook-Secret`` header of
            every webhook request. When not provided (or empty), the
            ``WEBHOOK_SECRET`` environment variable is used instead. If neither
            is set, webhook endpoints accept any request.
    """

    def __init__(
        self,
        ui: gr.Blocks,
        webhook_secret: Optional[str] = None,
    ) -> None:
        # Launch gradio app:
        # - as non-blocking so that webhooks can be added afterwards
        # - as shared if launched locally (to receive webhooks)
        app, _, _ = ui.launch(prevent_thread_lock=True, share=not ui.is_space)
        self.gradio_app = ui
        self.fastapi_app = app
        # Paths registered through `add_webhook`; only these are secret-checked.
        self.webhook_paths: Set[str] = set()

        # Add auth middleware to check the "X-Webhook-Secret" header
        self._webhook_secret = webhook_secret or os.getenv("WEBHOOK_SECRET")
        if self._webhook_secret is None:
            print(
                "\nWebhook secret is not defined. This means your webhook endpoints will be open to everyone."
            )
            print(
                "To add a secret, set `WEBHOOK_SECRET` as environment variable or pass it at initialization: "
                "\n\t`app = GradioWebhookApp(webhook_secret='my_secret', ...)`"
            )
            print(
                "For more details about Webhook secrets, please refer to https://huggingface.co/docs/hub/webhooks#webhook-secret."
            )
        else:
            print("\nWebhook secret is correctly defined.")
        # NOTE(review): recent Starlette releases raise RuntimeError when a
        # middleware is added after the application has started handling
        # requests, and the app is already launched above. Confirm this works
        # with the pinned fastapi/starlette versions.
        app.middleware("http")(self._webhook_secret_middleware)

    def add_webhook(self, path: str):
        """Return a decorator registering a POST webhook handler at `path`.

        The path is recorded so the secret middleware checks requests to it.
        """
        self.webhook_paths.add(path)
        return self.fastapi_app.post(path)

    def ready(self) -> None:
        """Print the webhook URLs, then block the main thread to keep the app running."""
        # Prefer the public share link (local launch) over the local URL.
        base_url = (
            self.gradio_app.share_url
            if self.gradio_app.share_url is not None
            else self.gradio_app.local_url
        ).rstrip("/")
        print("\nWebhooks are correctly setup and ready to use:")
        # Sorted so the listing is stable across runs (the paths live in a set).
        print("\n".join(f"  - POST {base_url}{webhook}" for webhook in sorted(self.webhook_paths)))
        print("Go to https://huggingface.co/settings/webhooks to setup your webhooks.")
        self.gradio_app.block_thread()

    async def _webhook_secret_middleware(self, request: Request, call_next) -> Response:
        """Check the "X-Webhook-Secret" header on every webhook request.

        Requests to paths not registered through `add_webhook`, or any request
        when no secret is configured, are passed through unchanged.

        Returns:
            A 401 JSON response if the header is missing, a 403 JSON response
            if it does not match, otherwise the downstream response.
        """
        if request.url.path in self.webhook_paths and self._webhook_secret is not None:
            request_secret = request.headers.get("x-webhook-secret")
            if request_secret is None:
                return JSONResponse(
                    {"error": "x-webhook-secret header not set."}, status_code=401
                )
            # Constant-time comparison: a plain `!=` can leak, through response
            # timing, how much of a guessed secret is correct. Comparing UTF-8
            # bytes keeps exact string-equality semantics and avoids
            # compare_digest's TypeError on non-ASCII `str` inputs.
            if not hmac.compare_digest(
                request_secret.encode("utf-8"), self._webhook_secret.encode("utf-8")
            ):
                return JSONResponse(
                    {"error": "Invalid webhook secret."}, status_code=403
                )
        return await call_next(request)
class WebhookPayloadEvent(BaseModel):
    """`event` section of a Hugging Face webhook payload."""

    # Kind of change that triggered the webhook. "move" is included because the
    # Hub sends it when a repository is renamed/moved; without it such payloads
    # fail validation.
    action: Literal["create", "update", "delete", "move"]
    # Scope of the event, kept as a free-form string (not validated here).
    scope: str
class WebhookPayloadRepo(BaseModel):
    """`repo` section of a Hugging Face webhook payload.

    Only the fields declared below are parsed; any other keys in the payload's
    `repo` object are subject to pydantic's default extra-field handling.
    """

    # Kind of repository the event is about.
    type: Literal["dataset", "model", "space"]
    # Repository name as sent in the payload (presumably "owner/name" — TODO confirm).
    name: str
    # Whether the repository is private.
    private: bool
class WebhookPayloadDiscussion(BaseModel):
    """`discussion` section of a Hugging Face webhook payload."""

    # Discussion / pull-request number within the repository.
    num: int
    # True when the discussion is a pull request (field name mirrors the JSON key).
    isPullRequest: bool
    # Current status. "draft" is included because draft pull requests report it;
    # without it their payloads fail validation.
    status: Literal["open", "closed", "merged", "draft"]
class WebhookPayload(BaseModel):
    """Top-level body of a Hugging Face webhook request.

    Can be used as the type of a webhook handler's body parameter so that
    FastAPI parses and validates the incoming JSON.
    """

    event: WebhookPayloadEvent
    repo: WebhookPayloadRepo
    # Only present for discussion-related events. The explicit `= None` default
    # is required: under pydantic v2 an `Optional[...]` field without a default
    # is still *required*, so payloads omitting the key would be rejected.
    discussion: Optional[WebhookPayloadDiscussion] = None