File size: 3,830 Bytes
1953b4d
 
52f607b
1953b4d
d3d52a6
a5853e1
1953b4d
52f607b
a5d1c4e
 
 
 
 
1953b4d
 
 
52f607b
 
 
 
 
1953b4d
52f607b
 
1953b4d
 
52f607b
 
 
 
 
 
 
 
d3d52a6
 
 
 
 
 
 
 
 
 
 
 
 
 
1953b4d
 
 
d3d52a6
1953b4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52f607b
1953b4d
 
 
 
 
52f607b
1953b4d
 
 
29dfbd3
1953b4d
 
52f607b
 
1953b4d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import random
import time
from urllib.parse import quote

import requests
from flask import Flask, request, jsonify, Response

# Flask application object; the @app.route decorators in this module register on it.
app = Flask(__name__)

@app.route('/')
def index():
    """Landing endpoint: reply with a short plain-text banner and HTTP 200."""
    banner = "text-to-image with siliconflow"
    status_code = 200
    return banner, status_code

# Generation parameters sent to the SiliconFlow text-to-image endpoint with every prompt.
_DEFAULT_IMAGE_PARAMS = {
    "image_size": "1024x1024",
    "batch_size": 1,
    "num_inference_steps": 4,
    "guidance_scale": 1,
}

# (connect, read) timeout in seconds for the upstream call; without a timeout
# a stalled upstream would block this worker indefinitely.
_UPSTREAM_TIMEOUT = (10, 120)


def _extract_prompt(messages):
    """Return the text of the last chat message, or None if it has no usable text.

    ``content`` may be a plain string or an OpenAI-style list of parts
    (``[{"type": "text", "text": "..."}, ...]``); only text parts are kept.
    """
    last = messages[-1]
    if not isinstance(last, dict):
        return None
    content = last.get("content")
    if isinstance(content, list):
        content = "\n".join(
            part["text"]
            for part in content
            if isinstance(part, dict)
            and part.get("type") == "text"
            and isinstance(part.get("text"), str)
        )
    if not isinstance(content, str) or not content.strip():
        return None
    return content


def _select_token(authorization_header):
    """Pick one API key at random from a comma-separated Authorization header.

    Each comma-separated entry may carry its own (case-insensitive) ``Bearer``
    prefix. Returns ``"Bearer <key>"`` for the chosen key, or None when the
    header contains no non-empty key.
    """
    keys = []
    for entry in authorization_header.split(","):
        key = entry.strip()
        scheme, _, rest = key.partition(" ")
        if scheme.lower() == "bearer":
            key = rest.strip()
        if key:
            keys.append(key)
    if not keys:
        return None
    # Random choice spreads load across the supplied keys; it is not a security decision.
    return f"Bearer {random.choice(keys)}"


def _completion_payload(model, prompt, image_url, stream):
    """Build an OpenAI-style chat completion (or a single stream chunk) embedding the image."""
    unique_id = int(time.time() * 1000)
    created = unique_id // 1000
    content = f"![]({image_url})"
    if stream:
        return {
            "id": unique_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {"role": "assistant", "content": content},
                    "finish_reason": "stop",
                }
            ],
        }
    return {
        "id": unique_id,
        "object": "chat.completion",
        "created": created,
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        # NOTE: these are character counts, not real tokenizer token counts.
        "usage": {
            "prompt_tokens": len(prompt),
            "completion_tokens": len(image_url),
            "total_tokens": len(prompt) + len(image_url),
        },
    }


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Proxy an OpenAI-style chat completion request to SiliconFlow text-to-image.

    The last message's text is used as the image prompt. The generated image URL is
    returned as a Markdown image, either as a JSON completion or, when ``stream`` is
    truthy, as a single SSE chunk followed by ``data: [DONE]``.

    Status codes:
        400 - body is not a JSON object, or model/messages/prompt text is missing.
        401 - no Authorization header, or it contains no API key.
        4xx - upstream client error (e.g. invalid key, rate limit), passed through.
        502 - upstream unreachable, 5xx, or returned an unexpected body.
        500 - any other unexpected error.
    """
    try:
        # get_json(silent=True) yields None instead of raising on a non-JSON body.
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            return jsonify({"error": "Bad Request: Request body must be a JSON object"}), 400

        model = body.get('model')
        messages = body.get('messages')
        stream = body.get('stream', False)

        if not isinstance(model, str) or not model or not isinstance(messages, list) or not messages:
            return jsonify({"error": "Bad Request: Missing required fields"}), 400

        prompt = _extract_prompt(messages)
        if prompt is None:
            return jsonify({"error": "Bad Request: Last message has no text content"}), 400

        # The Authorization header may carry several comma-separated keys; use one at random.
        authorization_header = request.headers.get('Authorization')
        if not authorization_header:
            return jsonify({"error": "Unauthorized: Missing Authorization header"}), 401
        selected_token = _select_token(authorization_header)
        if selected_token is None:
            return jsonify({"error": "Unauthorized: No API key in Authorization header"}), 401

        # Escape the client-supplied model so '?', '#', etc. cannot rewrite the URL;
        # '/' stays literal because model ids look like "org/name".
        new_url = f"https://api.siliconflow.cn/v1/{quote(model, safe='/')}/text-to-image"
        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': selected_token,
        }
        new_request_body = {"prompt": prompt, **_DEFAULT_IMAGE_PARAMS}

        try:
            response = requests.post(
                new_url, headers=headers, json=new_request_body, timeout=_UPSTREAM_TIMEOUT
            )
        except requests.RequestException as e:
            return jsonify({"error": f"Bad Gateway: {e}"}), 502

        if not response.ok:
            # Pass upstream client errors (bad key, rate limit) through; map the rest to 502.
            status = response.status_code if 400 <= response.status_code < 500 else 502
            return jsonify({
                "error": f"Upstream error {response.status_code}: {response.text[:500]}"
            }), status

        try:
            image_url = response.json()['images'][0]['url']
        except (ValueError, KeyError, IndexError, TypeError):
            return jsonify({"error": "Bad Gateway: Unexpected response from image service"}), 502

        payload = _completion_payload(model, prompt, image_url, stream)
        if stream:
            data_string = json.dumps(payload)
            # OpenAI-compatible stream clients wait for the [DONE] sentinel.
            return Response(
                f"data: {data_string}\n\ndata: [DONE]\n\n",
                content_type='text/event-stream',
            )
        return jsonify(payload)

    except Exception as e:
        app.logger.exception("Unhandled error in handle_request")
        return jsonify({"error": f"Internal Server Error: {str(e)}"}), 500

if __name__ == '__main__':
    # Start Flask's built-in server listening on all interfaces, port 8000.
    app.run(port=8000, host='0.0.0.0')