File size: 3,995 Bytes
1953b4d
 
52f607b
1953b4d
bc7ed47
a78ea5d
a5853e1
1953b4d
52f607b
a78ea5d
 
 
 
a5d1c4e
 
fab1a10
a5d1c4e
b735536
1953b4d
 
 
52f607b
 
 
f8d271e
52f607b
1953b4d
f8d271e
bc7ed47
 
 
 
 
 
 
 
 
 
 
52f607b
1953b4d
f8d271e
52f607b
 
 
 
 
 
 
f8d271e
1953b4d
 
 
bc7ed47
1953b4d
f8d271e
1953b4d
 
f8d271e
1953b4d
70ac8d0
 
f8d271e
1953b4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52f607b
1953b4d
 
 
 
 
52f607b
1953b4d
 
 
f8d271e
1953b4d
 
52f607b
a78ea5d
 
 
 
52f607b
f8d271e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from flask import Flask, request, jsonify, Response
import requests
import json
import time
import random
import logging

app = Flask(__name__)

# Log lines are emitted as "<timestamp> - <message>" at INFO level and above.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

@app.route('/')
def index():
    """Landing endpoint: reply with a short plain-text banner and HTTP 200."""
    banner = "text-to-image with siliconflow"
    status_ok = 200
    return banner, status_ok

# Seconds to wait for the SiliconFlow API before giving up. Kept generous
# because image generation may take a while; a bound is still required so a
# stalled upstream cannot hang the worker forever.
UPSTREAM_TIMEOUT = 120


def _parse_tokens(authorization_header):
    """Return the comma-separated tokens from ``<scheme> tok1,tok2,...``.

    Any scheme word is accepted (as before). Blank entries are dropped, and an
    empty list is returned when the header has no credentials part.
    """
    parts = authorization_header.split(None, 1)
    if len(parts) < 2:
        return []
    return [token.strip() for token in parts[1].split(',') if token.strip()]


def _mask_token(token):
    """Return a log-safe form of *token*, revealing at most 4 chars per end."""
    if not token:
        return '-'
    if len(token) <= 8:
        return '***'
    return f'{token[:4]}...{token[-4:]}'


def _extract_prompt(message):
    """Return the text of a chat message, or '' if it has no text content.

    Supports plain string content as well as the list-of-parts form
    (``[{"type": "text", "text": "..."}, ...]``), joining the text parts.
    """
    if not isinstance(message, dict):
        return ''
    content = message.get('content')
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = [
            part.get('text', '')
            for part in content
            if isinstance(part, dict) and part.get('type') == 'text'
        ]
        return ' '.join(t for t in texts if isinstance(t, str)).strip()
    return ''


def _extract_image_url(response_body):
    """Return ``images[0].url`` from a SiliconFlow response, or None if absent."""
    if not isinstance(response_body, dict):
        return None
    images = response_body.get('images')
    if not isinstance(images, list) or not images or not isinstance(images[0], dict):
        return None
    url = images[0].get('url')
    return url if isinstance(url, str) and url else None


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Proxy an OpenAI-style chat completion to SiliconFlow text-to-image.

    The text of the last message is sent as the image prompt to
    ``https://api.siliconflow.cn/v1/<model>/text-to-image``. The generated
    image URL is returned as Markdown (``![](url)``) in one of two forms:
    a single SSE chunk followed by ``data: [DONE]`` when ``stream`` is truthy,
    or a ``chat.completion`` JSON object otherwise.

    The ``Authorization`` header carries one or more comma-separated tokens
    (``Bearer tok1,tok2``); one is chosen at random for the upstream call.

    Returns:
        400 for a malformed body, 401 for a missing/malformed Authorization
        header, the upstream status for upstream HTTP errors, 502 when the
        upstream is unreachable or returns no image, 500 for anything else.
    """
    # Bound up front so the ``finally`` log line never hits an unbound name
    # on early-return or error paths (which previously turned every 400/401
    # into an unhandled 500).
    model = None
    selected_token = None
    upstream_status = None
    try:
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Body must be a JSON object"}), 400

        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)

        if not model or not messages or not isinstance(messages, list):
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        authorization_header = request.headers.get('Authorization')
        if not authorization_header:
            return jsonify({"error": "Unauthorized: Missing Authorization header"}), 401

        tokens = _parse_tokens(authorization_header)
        if not tokens:
            return jsonify({"error": "Unauthorized: Malformed Authorization header"}), 401
        selected_token = random.choice(tokens)

        prompt = _extract_prompt(messages[-1])
        if not prompt:
            return jsonify({"error": "Bad Request: Last message has no text content"}), 400

        new_url = f'https://api.siliconflow.cn/v1/{model}/text-to-image'
        new_request_body = {
            "prompt": prompt,
            "image_size": "1024x1024",
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1
        }
        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': f'Bearer {selected_token}'
        }

        response = requests.post(
            new_url, headers=headers, json=new_request_body, timeout=UPSTREAM_TIMEOUT
        )
        upstream_status = response.status_code
        try:
            response_body = response.json()
        except ValueError:
            response_body = None

        if not response.ok:
            # Surface the upstream failure (e.g. 401 bad token) instead of an
            # opaque KeyError-driven 500.
            detail = response_body if response_body is not None else response.text
            return jsonify({"error": "Upstream Error", "detail": detail}), response.status_code

        image_url = _extract_image_url(response_body)
        if not image_url:
            return jsonify({"error": "Bad Gateway: Upstream response contained no image"}), 502

        unique_id = str(int(time.time() * 1000))
        current_timestamp = int(unique_id) // 1000
        markdown_image = f"![]({image_url})"

        if stream:
            response_payload = {
                "id": unique_id,
                "object": "chat.completion.chunk",
                "created": current_timestamp,
                "model": model,
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                            "content": markdown_image
                        },
                        "finish_reason": "stop"
                    }
                ]
            }
            data_string = json.dumps(response_payload)
            # OpenAI-compatible clients wait for the [DONE] sentinel to end the stream.
            return Response(
                f"data: {data_string}\n\ndata: [DONE]\n\n",
                content_type='text/event-stream',
            )

        response_payload = {
            "id": unique_id,
            "object": "chat.completion",
            "created": current_timestamp,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": markdown_image
                    },
                    "logprobs": None,
                    # Generation completed normally; matches the stream branch.
                    "finish_reason": "stop"
                }
            ],
            # Character counts, not real token counts (no tokenizer available).
            "usage": {
                "prompt_tokens": len(prompt),
                "completion_tokens": len(image_url),
                "total_tokens": len(prompt) + len(image_url)
            }
        }
        # A non-streamed completion is plain JSON, not an event stream.
        return Response(json.dumps(response_payload), content_type='application/json')

    except requests.RequestException as e:
        logger.warning('Upstream request failed: %s', e)
        return jsonify({"error": f"Bad Gateway: {str(e)}"}), 502

    except Exception as e:
        logger.exception('Unhandled error in /ai/v1/chat/completions')
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500

    finally:
        # Record the requested model and which token was used (masked so that
        # API keys are never written to logs in plaintext).
        logger.info(
            '"POST /ai/v1/chat/completions HTTP/1.1" "model: %s" "token: %s" "status: %s" -',
            model,
            _mask_token(selected_token),
            upstream_status if upstream_status is not None else '-',
        )

if __name__ == '__main__':
    # Serve on every network interface, port 8000.
    app.run(port=8000, host='0.0.0.0')