"""OpenAI-style chat-completions endpoint backed by SiliconFlow text-to-image.

The service accepts a chat-completions request, takes the last message as an
image description, asks a chat model to rewrite it as an English Stable
Diffusion prompt, calls the text-to-image endpoint for the requested model and
returns the resulting image URL as a Markdown image inside a chat completion
(streamed or not).

An optional ``-s <ratio>`` flag inside the message selects the image size
(see ``RATIO_MAP``).
"""

import json
import logging
import random
import re
import sys
import time
from typing import Any, Dict, Iterator, Optional, Tuple

import requests
from flask import Flask, Response, jsonify, request, stream_with_context

app = Flask(__name__)

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# System prompt sent to the chat model that rewrites the user's description.
# The text is runtime data and is kept verbatim; each backslash at the end of
# a line is a line continuation, so the value contains no line breaks other
# than the final "\n".
# NOTE(review): the separators between paragraphs are single spaces; they look
# like flattened line breaks - confirm against the intended prompt.
SYSTEM_ASSISTANT = """作为 Stable Diffusion Prompt 提示词专家,您将从关键词中创建提示,通常来自 Danbooru 等数据库。\
 提示通常描述图像,使用常见词汇,按重要性排列,并用逗号分隔。避免使用"-"或".",但可以接受空格和自然语言。避免词汇重复。\
 为了强调关键词,请将其放在括号中以增加其权重。例如,"(flowers)"将'flowers'的权重增加1.1倍,而"(((flowers)))"将其增加1.331倍。使用"(flowers:1.5)"将'flowers'的权重增加1.5倍。只为重要的标签增加权重。\
 提示包括三个部分:**前缀** (质量标签+风格词+效果器)+ **主题** (图像的主要焦点)+ **场景** (背景、环境)。\
 * 前缀影响图像质量。像"masterpiece"、"best quality"、"4k"这样的标签可以提高图像的细节。像"illustration"、"lensflare"这样的风格词定义图像的风格。像"bestlighting"、"lensflare"、"depthoffield"这样的效果器会影响光照和深度。\
 * 主题是图像的主要焦点,如角色或场景。对主题进行详细描述可以确保图像丰富而详细。增加主题的权重以增强其清晰度。对于角色,描述面部、头发、身体、服装、姿势等特征。\
 * 场景描述环境。没有场景,图像的背景是平淡的,主题显得过大。某些主题本身包含场景(例如建筑物、风景)。像"花草草地"、"阳光"、"河流"这样的环境词可以丰富场景。你的任务是设计图像生成的提示。请按照以下步骤进行操作:\
 1. 我会发送给您一个图像场景。需要你生成详细的图像描述\
 2. 图像描述必须是英文,输出为Positive Prompt。\
 示例:\
 我发送:二战时期的护士。\
 您回复只回复:\
 A WWII-era nurse in a German uniform, holding a wine bottle and stethoscope, sitting at a table in white attire, with a table in the background, masterpiece, best quality, 4k, illustration style, best lighting, depth of field, detailed character, detailed environment. \n"""

# Aspect-ratio flag value -> "WIDTHxHEIGHT" string sent as `image_size`.
# NOTE(review): "4:3" maps to 1536x2048, which is a 3:4 (portrait) size. The
# mapping is kept as-is for backward compatibility; "3:4" is provided as the
# accurately named alias for the same size.
RATIO_MAP = {
    "1:1": "1024x1024",
    "1:2": "1024x2048",
    "3:2": "1536x1024",
    "4:3": "1536x2048",
    "3:4": "1536x2048",
    "16:9": "2048x1152",
    "9:16": "1152x2048",
}

# Ratio used when the prompt has no `-s` flag or names an unknown ratio.
DEFAULT_RATIO = "16:9"

# `-s <value>` size flag embedded in the prompt text.
_SIZE_FLAG_RE = re.compile(r'-s\s+(\S+)')

# Characters accepted in one `/`-separated segment of a model name.
_MODEL_SEGMENT_RE = re.compile(r'[A-Za-z0-9._:-]+')


def _redact_headers(headers: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of *headers* that is safe to log (credentials masked)."""
    return {
        name: '***' if name.lower() == 'authorization' else value
        for name, value in headers.items()
    }


def _is_safe_model_name(model: Any) -> bool:
    """Return True if *model* can be placed in the upstream URL path.

    The model name is client-supplied and is interpolated into the request
    path, so every `/`-separated segment must be non-empty, use only
    conservative characters and must not be a `.` / `..` traversal segment.
    """
    if not isinstance(model, str):
        return False
    return all(
        segment not in ('.', '..') and _MODEL_SEGMENT_RE.fullmatch(segment)
        for segment in model.split('/')
    )


def _extract_image_url(response_body: Any) -> Optional[str]:
    """Return the first image URL of the image API payload, or None.

    Every level is type-checked so a malformed payload yields None instead of
    raising (or instead of a substring match when an element is a string).
    """
    if not isinstance(response_body, dict):
        return None
    images = response_body.get('images')
    if not isinstance(images, list) or not images:
        return None
    first_image = images[0]
    if not isinstance(first_image, dict):
        return None
    url = first_image.get('url')
    return url if isinstance(url, str) and url else None


def get_random_token(auth_header):
    """Pick one token at random from a comma-separated Authorization header.

    Args:
        auth_header: Raw header value, optionally starting with ``Bearer ``,
            containing one or more comma-separated tokens.

    Returns:
        ``"Bearer <token>"`` for a randomly chosen token, or None when the
        header is missing or contains no token.
    """
    if not auth_header:
        return None
    if auth_header.startswith('Bearer '):
        auth_header = auth_header[7:]
    tokens = [token.strip() for token in auth_header.split(',') if token.strip()]
    if not tokens:
        return None
    return f"Bearer {random.choice(tokens)}"


def translate_and_enhance_prompt(prompt, auth_token):
    """Ask the chat model to turn *prompt* into an English image prompt.

    Args:
        prompt: The user's image description.
        auth_token: Full Authorization header value to send upstream.

    Returns:
        The content of the first choice of the chat completion.

    Raises:
        requests.exceptions.RequestException: On transport errors, timeouts
            or a non-2xx status (logged, then re-raised).
        KeyError, IndexError, TypeError, ValueError: If the response body
            does not have the expected chat-completion shape.
    """
    translate_url = 'https://api.siliconflow.cn/v1/chat/completions'
    translate_body = {
        'model': 'Qwen/Qwen2-72B-Instruct',
        'messages': [
            {'role': 'system', 'content': SYSTEM_ASSISTANT},
            {'role': 'user', 'content': prompt}
        ]
    }
    headers = {
        'Content-Type': 'application/json',
        'Authorization': auth_token
    }

    logger.info("Sending request to %s", translate_url)
    # Never write the bearer token to the log.
    logger.info("Request headers: %s", _redact_headers(headers))
    logger.info("Request body: %s", json.dumps(translate_body, ensure_ascii=False))

    try:
        response = requests.post(translate_url, headers=headers, json=translate_body, timeout=30)
        logger.info("Response status code: %s", response.status_code)
        logger.info("Response content: %s", response.text)
        response.raise_for_status()
        result = response.json()
        return result['choices'][0]['message']['content']
    except requests.exceptions.RequestException as e:
        logger.error("Error in translate_and_enhance_prompt: %s", e)
        raise


def extract_size_from_prompt(prompt):
    """Split the optional ``-s <ratio>`` flag off the prompt.

    Args:
        prompt: Prompt text, possibly containing ``-s <ratio>``.

    Returns:
        A ``(image_size, clean_prompt)`` tuple. ``image_size`` is the
        ``RATIO_MAP`` entry for the first flag found, or the default
        (16:9) size when there is no flag or the ratio is unknown.
        ``clean_prompt`` is the prompt with every flag occurrence removed and
        surrounding whitespace stripped; without a flag the prompt is
        returned unchanged.
    """
    size_match = _SIZE_FLAG_RE.search(prompt)
    if size_match is None:
        return RATIO_MAP[DEFAULT_RATIO], prompt
    clean_prompt = _SIZE_FLAG_RE.sub('', prompt).strip()
    return RATIO_MAP.get(size_match.group(1), RATIO_MAP[DEFAULT_RATIO]), clean_prompt


@app.route('/')
def index():
    """Liveness endpoint returning a short plain-text banner."""
    return "text-to-image with siliconflow", 200


def _parse_chat_request(body: Any) -> Tuple[Optional[Tuple[str, str, bool]], Optional[str]]:
    """Validate the request body.

    Returns:
        ``((model, prompt, stream), None)`` on success, or
        ``(None, error_message)`` when the body must be rejected with 400.
    """
    if not isinstance(body, dict):
        return None, "Bad Request: Body must be a JSON object"

    model = body.get('model')
    messages = body.get('messages')
    if not model or not messages:
        return None, "Bad Request: Missing required fields"
    if not _is_safe_model_name(model):
        return None, "Bad Request: Invalid model"
    if not isinstance(messages, list):
        return None, "Bad Request: messages must be a list"

    last_message = messages[-1]
    prompt = last_message.get('content') if isinstance(last_message, dict) else None
    if not isinstance(prompt, str) or not prompt.strip():
        return None, "Bad Request: Last message must have text content"

    return (model, prompt, bool(body.get('stream', False))), None


@app.route('/ai/v1/chat/completions', methods=['POST'])
def handle_request():
    """Handle a chat-completions request by generating an image.

    Responds with 400 for a malformed body, 401 when no usable token is in
    the Authorization header, 500 when an upstream call fails, and otherwise
    with a chat completion (server-sent events when ``stream`` is truthy)
    whose content embeds the generated image URL.
    """
    try:
        # silent=True: a missing/invalid JSON body becomes None (-> 400)
        # instead of an exception swallowed into a 500 below.
        parsed, error_message = _parse_chat_request(request.get_json(silent=True))
        if parsed is None:
            return jsonify({"error": error_message}), 400
        model, prompt, stream = parsed

        image_size, clean_prompt = extract_size_from_prompt(prompt)

        random_token = get_random_token(request.headers.get('Authorization'))
        if not random_token:
            return jsonify({"error": "Unauthorized: Invalid or missing Authorization header"}), 401

        try:
            enhanced_prompt = translate_and_enhance_prompt(clean_prompt, random_token)
        except Exception as e:
            logger.error("Error in translate_and_enhance_prompt: %s", e)
            return jsonify({"error": "Failed to enhance prompt"}), 500

        new_url = f'https://api.siliconflow.cn/v1/{model}/text-to-image'
        new_request_body = {
            "prompt": enhanced_prompt,
            "image_size": image_size,
            "batch_size": 1,
            "num_inference_steps": 4,
            "guidance_scale": 1
        }
        headers = {
            'accept': 'application/json',
            'content-type': 'application/json',
            'Authorization': random_token
        }

        logger.info("Sending request to %s", new_url)
        # Never write the bearer token to the log.
        logger.info("Request headers: %s", _redact_headers(headers))
        logger.info("Request body: %s", json.dumps(new_request_body, ensure_ascii=False))

        try:
            response = requests.post(new_url, headers=headers, json=new_request_body, timeout=60)
            logger.info("Response status code: %s", response.status_code)
            logger.info("Response content: %s", response.text)
            response.raise_for_status()
            response_body = response.json()
        except requests.exceptions.RequestException as e:
            logger.error("Error in image generation request: %s", e)
            return jsonify({"error": "Failed to generate image"}), 500
        except ValueError as e:
            logger.error("Error parsing image generation response: %s", e)
            return jsonify({"error": "Failed to parse image generation response"}), 500

        image_url = _extract_image_url(response_body)
        if image_url is None:
            logger.error("Unexpected response structure: %s", response_body)
            return jsonify({"error": "Unexpected response structure from image generation API"}), 500
        logger.info("Successfully retrieved image URL: %s", image_url)

        unique_id = str(int(time.time() * 1000))  # unique_id as a string
        current_timestamp = int(time.time())
        system_fingerprint = "fp_" + ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=9))

        image_data = {'data': [{'url': image_url}]}

        if stream:
            return stream_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint)
        return non_stream_response(unique_id, image_data, clean_prompt, enhanced_prompt, image_size, current_timestamp, model, system_fingerprint)
    except Exception:
        # Log the traceback server-side; do not echo exception text to the
        # remote caller.
        logger.exception("Unexpected error in handle_request")
        return jsonify({"error": "Internal Server Error"}), 500


def stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint):
    """Wrap ``generate_stream`` in a ``text/event-stream`` Flask response."""
    logger.debug("Starting stream response")
    event_stream = generate_stream(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint)
    return Response(stream_with_context(event_stream), content_type='text/event-stream')


def _completion_chunk(unique_id, created, model, system_fingerprint, delta, finish_reason) -> str:
    """Serialize one ``chat.completion.chunk`` object to JSON."""
    return json.dumps({
        "id": unique_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": model,
        "system_fingerprint": system_fingerprint,
        "choices": [{
            "index": 0,
            "delta": delta,
            "logprobs": None,
            "finish_reason": finish_reason
        }]
    })


def generate_stream(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint) -> Iterator[str]:
    """Yield the completion as server-sent events.

    Emits one event per progress line (pausing 0.5 s after each), then an
    empty-delta event with ``finish_reason`` ``"stop"``, then the
    ``data: [DONE]`` terminator used by the OpenAI streaming protocol.
    """
    chunks = [
        f"原始提示词:\n{original_prompt}\n",
        f"翻译后的提示词:\n{translated_prompt}\n",
        f"图像规格:{size}\n",
        "正在根据提示词生成图像...\n",
        "图像正在处理中...\n",
        "即将完成...\n",
        f"生成成功!\n图像生成完毕,以下是结果:\n\n![生成的图像]({image_data['data'][0]['url']})"
    ]

    for chunk in chunks:
        json_chunk = _completion_chunk(unique_id, created, model, system_fingerprint, {"content": chunk}, None)
        yield f"data: {json_chunk}\n\n"
        time.sleep(0.5)  # simulate generation time

    final_chunk = _completion_chunk(unique_id, created, model, system_fingerprint, {}, "stop")
    yield f"data: {final_chunk}\n\n"
    # Explicit end-of-stream marker expected by OpenAI-compatible clients.
    yield "data: [DONE]\n\n"


def non_stream_response(unique_id, image_data, original_prompt, translated_prompt, size, created, model, system_fingerprint):
    """Build the non-streaming ``chat.completion`` JSON response.

    The ``usage`` figures are character counts of the original prompt and of
    the generated content, not tokenizer-based token counts.
    """
    content = (
        f"原始提示词:{original_prompt}\n"
        f"翻译后的提示词:{translated_prompt}\n"
        f"图像规格:{size}\n"
        f"图像生成成功!\n"
        f"以下是结果:\n\n"
        f"![生成的图像]({image_data['data'][0]['url']})"
    )

    response = {
        'id': unique_id,
        'object': "chat.completion",
        'created': created,
        'model': model,
        'system_fingerprint': system_fingerprint,
        'choices': [{
            'index': 0,
            'message': {
                'role': "assistant",
                'content': content
            },
            'finish_reason': "stop"
        }],
        'usage': {
            'prompt_tokens': len(original_prompt),
            'completion_tokens': len(content),
            'total_tokens': len(original_prompt) + len(content)
        }
    }

    return jsonify(response)


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)