|
import asyncio
import datetime
import json
import os
import secrets
import uuid

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from socketio import AsyncClient
|
|
|
|
|
# Runtime configuration, read once at import time. os.getenv yields None for
# any variable that is not set.
API_KEY = os.getenv("PPLX_KEY")  # expected value of the x-api-key request header

# Cookie and User-Agent presented to perplexity.ai when opening the socket.
PPLX_COOKIE = os.getenv("PPLX_COOKIE")
USER_AGENT = os.getenv("USER_AGENT")

# Optional proxy for the outbound Perplexity connection.
PROXY_URL = os.getenv("PROXY_URL")
|
|
|
app = FastAPI()


async def validate_api_key(request: Request, call_next):
    """Reject any request whose ``x-api-key`` header does not match ``API_KEY``.

    Registered below as an HTTP middleware, so it runs for every route.

    Args:
        request: the incoming request.
        call_next: continuation that runs the rest of the app.

    Returns:
        A 401 JSON response (``{"detail": "Invalid API key"}``) when the header
        is missing, wrong, or no key is configured; otherwise the downstream
        response.
    """
    api_key = request.headers.get("x-api-key")
    # Fail closed: with PPLX_KEY unset, API_KEY is None and a request without
    # the header would otherwise compare equal (None == None) and be let in.
    # compare_digest keeps the comparison time independent of the key content.
    authorized = (
        bool(API_KEY)
        and api_key is not None
        and secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    )
    if not authorized:
        log_request(request, 401)
        # An HTTPException raised inside an HTTP middleware is not handled by
        # FastAPI's exception handlers (it surfaces as a 500), so the 401
        # response is built directly.
        return JSONResponse(status_code=401, content={"detail": "Invalid API key"})
    return await call_next(request)


app.middleware("http")(validate_api_key)
|
|
|
|
|
|
|
def log_request(request: "Request", status_code: int) -> None:
    """Print a one-line access-log entry for *request* to stdout.

    Format: ``<local ISO-8601 timestamp> - <client host> - <path> - <status>``.

    Args:
        request: the incoming request; only ``client.host`` and ``url.path``
            are read.
        status_code: the HTTP status to record.
    """
    timestamp = datetime.datetime.now().isoformat()
    # request.client can be None when the ASGI server supplies no peer
    # address; log a placeholder instead of raising AttributeError.
    ip = request.client.host if request.client is not None else "unknown"
    route = request.url.path
    print(f"{timestamp} - {ip} - {route} - {status_code}")
|
|
|
|
|
|
|
@app.get("/")
async def root(request: Request):
    """Return a static index describing the proxy's message endpoint.

    The request is logged with status 200 before the index is returned.
    """
    log_request(request, 200)

    messages_endpoint = {
        "method": "POST",
        "description": "Send a message to the AI",
        "headers": {
            "x-api-key": "Your API key (required)",
            "Content-Type": "application/json",
        },
        "body": {
            "messages": "Array of message objects",
            "stream": "Boolean (true for streaming response)",
        },
    }
    endpoints = {"/ai/v1/messages": messages_endpoint}

    return {
        "message": "Welcome to the Perplexity AI Proxy API",
        "endpoints": endpoints,
    }
|
|
|
|
|
|
|
def _content_to_text(content) -> str:
    """Flatten one message ``content`` value into plain text.

    A string is returned unchanged and ``None`` becomes ``""``. A list is
    treated as Anthropic Messages-style content blocks: the ``text`` of every
    ``{"type": "text", ...}`` block is joined with newlines and other blocks
    are skipped. Any other value is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(
            block.get("text", "")
            for block in content
            if isinstance(block, dict) and block.get("type") == "text"
        )
    return str(content)


def _build_prompt(json_body: dict) -> str:
    """Collapse a request body into the single query string sent to Perplexity.

    The optional top-level ``system`` prompt comes first, followed by the
    ``content`` of every entry in ``messages`` in order, separated by blank
    lines. Message roles are not preserved.

    Raises:
        KeyError, TypeError: if a message is not a mapping with ``content``.
    """
    parts = []
    if json_body.get("system"):
        parts.append(json_body["system"])
    parts.extend(msg["content"] for msg in json_body.get("messages") or [])
    return "\n\n".join(_content_to_text(part) for part in parts)


async def _event_stream(request: Request, prompt: str):
    """Relay Perplexity's answer to *prompt* as Anthropic-style SSE events.

    Emits ``message_start``, ``content_block_start`` and ``ping`` immediately.
    It then emits one ``content_block_delta`` per text chunk taken from
    Perplexity's ``query_progress`` socket.io events, and finally
    ``content_block_stop``, ``message_delta`` and ``message_stop``. If anything
    fails, an ``error`` event is emitted instead of the closing events.
    """
    msgid = str(uuid.uuid4())

    yield create_event(
        "message_start",
        {
            "type": "message_start",
            "message": {
                "id": msgid,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": "claude-3-opus-20240229",
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {"input_tokens": 8, "output_tokens": 1},
            },
        },
    )
    yield create_event(
        "content_block_start",
        {"type": "content_block_start", "index": 0, "content_block": {"type": "text", "text": ""}},
    )
    yield create_event("ping", {"type": "ping"})

    # socket.io handlers run as separate callbacks and cannot yield into this
    # generator. They push text chunks onto a queue that the loop below drains.
    chunks: asyncio.Queue = asyncio.Queue()
    end_of_stream = object()

    engineio_kwargs = {}
    if PROXY_URL:
        # NOTE(review): python-engineio's AsyncClient takes no http_proxy /
        # https_proxy options. aiohttp's ws_connect accepts `proxy`, which
        # engineio forwards via websocket_extra_options. Confirm this against
        # the installed python-engineio version.
        engineio_kwargs = {"websocket_extra_options": {"proxy": PROXY_URL}}

    socket = AsyncClient(logger=True, engineio_logger=True, **engineio_kwargs)

    # Handlers must be registered before connecting and emitting; otherwise the
    # progress events arrive with no listener attached.
    @socket.on("query_progress")
    async def on_query_progress(data):
        if not isinstance(data, dict):
            return
        if data.get("text"):
            text = json.loads(data["text"])
            # Only the last chunk of each event is forwarded (presumably the
            # earlier ones were delivered with previous events; confirm).
            chunk = text["chunks"][-1] if isinstance(text, dict) and text.get("chunks") else None
            if chunk:
                await chunks.put(chunk)
        # NOTE(review): assumes Perplexity flags its last progress event with a
        # truthy "final" key. If it never does, the stream ends on disconnect.
        if data.get("final"):
            await chunks.put(end_of_stream)

    @socket.on("disconnect")
    async def on_disconnect(*_args):
        # If the server closes the connection, stop waiting for more chunks.
        # *_args tolerates socketio versions that pass a disconnect reason.
        await chunks.put(end_of_stream)

    headers = {
        "Cookie": PPLX_COOKIE,
        "User-Agent": USER_AGENT,
        "Accept": "*/*",
        "priority": "u=1, i",
        "Referer": "https://www.perplexity.ai/",
    }

    try:
        await socket.connect(
            "https://www.perplexity.ai/",
            # Omit headers whose environment variable is unset (None).
            headers={name: value for name, value in headers.items() if value is not None},
            transports=["websocket"],
        )
        print(" > [Connected]")

        # python-socketio sends several event arguments as a tuple; a second
        # positional argument would be taken as the namespace instead.
        await socket.emit(
            "perplexity_ask",
            (
                prompt,
                {
                    "version": "2.9",
                    "source": "default",
                    "attachments": [],
                    "language": "en-GB",
                    "timezone": "Europe/London",
                    "search_focus": "writing",
                    "frontend_uuid": str(uuid.uuid4()),
                    "mode": "concise",
                    "is_related_query": False,
                    "is_default_related_query": False,
                    "visitor_id": str(uuid.uuid4()),
                    "frontend_context_uuid": str(uuid.uuid4()),
                    "prompt_source": "user",
                    "query_source": "home",
                },
            ),
        )

        while True:
            chunk = await chunks.get()
            if chunk is end_of_stream:
                break
            yield create_event(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": 0,
                    "delta": {"type": "text_delta", "text": chunk},
                },
            )

        yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
        yield create_event(
            "message_delta",
            {
                "type": "message_delta",
                "delta": {"stop_reason": "end_turn", "stop_sequence": None},
                "usage": {"output_tokens": 12},
            },
        )
        yield create_event("message_stop", {"type": "message_stop"})
        log_request(request, 200)
    except Exception as e:
        print(e)
        log_request(request, 500)
        # The 200 status line has already been sent, so report the failure
        # in-band as an SSE "error" event rather than truncating silently.
        yield create_event(
            "error",
            {"type": "error", "error": {"type": "api_error", "message": str(e)}},
        )
    finally:
        await socket.disconnect()


@app.post("/ai/v1/messages")
async def handle_ai_message(request: Request):
    """Messages-style endpoint backed by Perplexity.

    Non-streaming requests receive a fixed placeholder reply asking the client
    to enable streaming. Streaming requests are relayed as server-sent events
    by ``_event_stream``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or its messages
            cannot be turned into a prompt.
    """
    try:
        json_body = await request.json()
        if not isinstance(json_body, dict):
            raise ValueError("Request body must be a JSON object")
        # Build the prompt up front so malformed messages yield a 400 here
        # instead of failing after the streaming response has started.
        prompt = _build_prompt(json_body) if json_body.get("stream") else None
    except Exception as e:
        print(e)
        log_request(request, 400)
        raise HTTPException(status_code=400, detail=str(e)) from e

    if not json_body.get("stream"):
        log_request(request, 200)
        return {
            "id": str(uuid.uuid4()),
            "content": [
                {"text": "Please turn on streaming."},
                {"id": "string", "name": "string", "input": {}},
            ],
            "model": "string",
            "stop_reason": "end_turn",
            "stop_sequence": "string",
            "usage": {"input_tokens": 0, "output_tokens": 0},
        }

    return StreamingResponse(
        _event_stream(request, prompt),
        media_type="text/event-stream;charset=utf-8",
    )
|
|
|
|
|
|
|
def create_event(event: str, data: object):
    """Format one server-sent event (SSE) message.

    Args:
        event: the SSE ``event:`` name.
        data: a dict is serialized with ``json.dumps``; any other value is
            converted with ``str()``.

    Returns:
        ``"event: <event>\\ndata: <data>\\n\\n"``. If the data text contains
        line breaks, each line gets its own ``data:`` field. SSE requires this,
        because a raw newline would otherwise end the field early.
    """
    text = json.dumps(data) if isinstance(data, dict) else str(data)
    # Normalize CRLF/CR so every SSE line terminator starts a new data field.
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    data_fields = "".join(f"data: {line}\n" for line in lines)
    return f"event: {event}\n{data_fields}\n"
|
|
|
|
|
if __name__ == "__main__":
    # Imported lazily so the module can be imported without uvicorn installed.
    import uvicorn

    if API_KEY is None or API_KEY == "":
        print("Warning: PPLX_KEY environment variable is not set. API key validation will fail.")

    # Listen on all interfaces; PORT overrides the default of 8081.
    listen_port = int(os.environ.get("PORT", 8081))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|
|
|