|
import os |
|
import sys |
|
import uvicorn |
|
from fastapi import FastAPI, Query, HTTPException, BackgroundTasks |
|
from fastapi.responses import HTMLResponse |
|
from starlette.middleware.cors import CORSMiddleware |
|
from datasets import load_dataset, list_datasets |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
from loguru import logger |
|
import concurrent.futures |
|
import psutil |
|
import asyncio |
|
import torch |
|
from tenacity import retry, stop_after_attempt, wait_fixed |
|
from huggingface_hub import HfApi, RepositoryNotFoundError |
|
from dotenv import load_dotenv |
|
|
|
|
|
# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()

# The Hub token is mandatory: abort at import time when it is unset or empty.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN is None or HUGGINGFACE_TOKEN == "":
    logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
    sys.exit(1)

# Module-level state shared by the initialization helpers and the endpoints.
datasets_dict = dict()        # dataset name -> object returned by load_dataset()
example_usage_list = list()   # one {"dataset_name", "examples"} record per dataset
pipeline_instance = None      # text-generation pipeline, set by initialize_model()
initialization_complete = False  # flipped to True once background init finishes

# Cache directory handed to every from_pretrained()/load_dataset() call.
CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
os.makedirs(CACHE_DIR, exist_ok=True)

# NOTE(review): these variables are exported after the Hugging Face libraries
# have already been imported -- confirm those libraries still pick up HF_HOME.
os.environ.update(HF_HOME=CACHE_DIR, HF_TOKEN=HUGGINGFACE_TOKEN)
|
|
|
def initialize_model():
    """Load the base model and tokenizer and build the text-generation pipeline.

    The pipeline is stored in the module-level ``pipeline_instance``. It is
    placed on CUDA device 0 when ``torch.cuda.is_available()`` and on the CPU
    (device ``-1``) otherwise.

    Raises:
        SystemExit: with status 1 when any step of the initialization fails.
    """
    global pipeline_instance
    try:
        logger.info("Initializing the base model and tokenizer.")
        base_model_repo = "meta-llama/Llama-3.2-1B"
        model = AutoModelForCausalLM.from_pretrained(
            base_model_repo,
            cache_dir=CACHE_DIR,
            ignore_mismatched_sizes=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
        if tokenizer.pad_token is None:
            # Fall back to EOS for padding when the tokenizer defines no pad token.
            tokenizer.pad_token = tokenizer.eos_token
        pipeline_instance = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=0 if torch.cuda.is_available() else -1,
        )
        logger.info("Model and tokenizer initialized successfully.")
    except Exception as e:
        # loguru has no ``exc_info`` option: extra keyword arguments are fed to
        # str.format() on the message instead, which logs no traceback and can
        # itself raise when the error text contains braces.
        # logger.exception() records the active traceback.
        logger.exception(f"Error initializing model and tokenizer: {e}")
        sys.exit(1)
|
|
|
@retry(stop=stop_after_attempt(3), wait=wait_fixed(5))
def download_dataset(dataset_name):
    """Download one dataset into ``datasets_dict`` and generate its examples.

    The tenacity decorator retries the whole function up to three times with
    a fixed five-second pause between attempts.

    Args:
        dataset_name: Identifier passed to ``load_dataset``; also used as the
            key in ``datasets_dict``.

    Raises:
        Exception: any error from ``load_dataset`` is logged and re-raised so
            that the retry decorator can see the failure.
    """
    try:
        logger.info(f"Starting download for dataset: {dataset_name}")
        datasets_dict[dataset_name] = load_dataset(dataset_name, cache_dir=CACHE_DIR)
        create_example_usage(dataset_name)
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Error loading dataset {dataset_name}: {e}")
        raise
|
|
|
def upload_model_to_hub():
    """Push the pipeline's model and tokenizer to the ``Yhhxhfh/test`` repo.

    The repository is created (public) when ``repo_info`` reports that it
    does not exist. Every failure is logged and swallowed, so this function
    never raises to its caller.
    """
    try:
        api = HfApi()
        model_repo = "Yhhxhfh/test"
        try:
            api.repo_info(repo_id=model_repo)
            logger.info(f"Model repository {model_repo} already exists.")
        except RepositoryNotFoundError:
            api.create_repo(repo_id=model_repo, private=False, token=HUGGINGFACE_TOKEN)
            logger.info(f"Created model repository {model_repo}.")
        logger.info(f"Pushing the model and tokenizer to {model_repo}.")
        # NOTE(review): recent transformers releases deprecate ``use_auth_token``
        # in favour of ``token`` -- confirm against the pinned version.
        pipeline_instance.model.push_to_hub(model_repo, use_auth_token=HUGGINGFACE_TOKEN)
        pipeline_instance.tokenizer.push_to_hub(model_repo, use_auth_token=HUGGINGFACE_TOKEN)
        logger.info(f"Successfully pushed the model and tokenizer to {model_repo}.")
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Error uploading model to Hugging Face Hub: {e}")
|
|
|
def create_example_usage(dataset_name):
    """Generate sample completions and record them under ``dataset_name``.

    A fixed list of catering-themed prompts is run through
    ``pipeline_instance``; the prompts do not depend on the dataset, which is
    only used to label the record appended to ``example_usage_list``.
    Failures are logged and swallowed.

    Args:
        dataset_name: Label stored in the ``"dataset_name"`` field of the
            appended record.
    """
    try:
        logger.info(f"Creating example usage for dataset {dataset_name}")
        example_prompts = [
            "Translate the following catering menu from English to French:",
            "Generate a catering menu for a wedding with vegetarian options:",
            "Convert the following catering menu to a gluten-free version:",
            "Provide a detailed catering menu for a corporate event including desserts:",
            "Generate a children's birthday party catering menu with allergen-free items:",
        ]
        examples = [
            {
                "prompt": prompt,
                "response": pipeline_instance(
                    prompt, max_length=50, num_return_sequences=1
                )[0]['generated_text'],
            }
            for prompt in example_prompts
        ]
        example_usage_list.append({"dataset_name": dataset_name, "examples": examples})
        logger.info(f"Example usage created for dataset {dataset_name}")
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Error creating example usage for dataset {dataset_name}: {e}")
|
|
|
def unify_datasets():
    """Concatenate every downloaded dataset into ``datasets_dict['unified']``.

    Entries that are mappings (``load_dataset`` called without a split
    returns a mapping of split name to dataset) contribute each of their
    splits; other entries are used as-is. When nothing has been downloaded,
    ``'unified'`` is set to ``None``. Errors, such as datasets whose features
    do not line up, are logged and swallowed.
    """
    # Imported here because the file's top-level ``datasets`` import does not
    # bring this name into scope.
    from datasets import concatenate_datasets

    try:
        logger.info("Starting to unify datasets")
        parts = []
        for name, entry in datasets_dict.items():
            if name == 'unified':
                # Do not fold the result of a previous run back into itself.
                continue
            if isinstance(entry, dict):
                parts.extend(entry.values())
            else:
                parts.append(entry)
        # The objects returned by load_dataset() have no ``.concatenate()``
        # method; ``concatenate_datasets`` is the library's API for this.
        datasets_dict['unified'] = concatenate_datasets(parts) if parts else None
        logger.info("Datasets successfully unified.")
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Error unifying datasets: {e}")
|
|
|
|
|
# Physical core count; fall back to a single core when psutil reports none.
cpu_count = psutil.cpu_count(logical=False) or 1

# Memory budget: allow one download per 100 MB of currently available RAM.
memory_per_download_mb = 100
memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
memory_available = int(memory_available_mb / memory_per_download_mb)

gpu_count = torch.cuda.device_count()

# Take the tightest of the CPU, memory and GPU limits; without a GPU the third
# limit simply repeats the CPU count.
max_concurrent_downloads = min(
    cpu_count,
    memory_available,
    cpu_count if not gpu_count else gpu_count * 2,
)
# Clamp the result into the range [1, 10].
max_concurrent_downloads = min(10, max(1, max_concurrent_downloads))

logger.info(f"Using up to {max_concurrent_downloads} concurrent workers for downloading datasets.")

# Thread pool shared by all dataset downloads.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
|
|
|
async def download_and_process_datasets():
    """Download all listed datasets, unify them, upload the model, mark init done.

    Steps, in order:
      1. ``list_datasets()`` supplies the dataset names.
      2. Each name is downloaded via ``download_dataset`` on the shared
         ``executor`` thread pool.
      3. ``unify_datasets`` and ``upload_model_to_hub`` run in worker threads.
      4. ``initialization_complete`` is set to ``True``.

    Individual download failures are counted and reported but do not abort
    the remaining steps. Any other error is logged and swallowed, leaving
    ``initialization_complete`` unchanged.
    """
    global initialization_complete
    try:
        loop = asyncio.get_running_loop()
        # NOTE(review): ``datasets.list_datasets`` is deprecated in newer
        # ``datasets`` releases -- confirm the pinned version still ships it.
        # Run it off the event loop so the server stays responsive.
        dataset_names = await loop.run_in_executor(None, list_datasets)
        logger.info(f"Found {len(dataset_names)} datasets to download.")

        tasks = [
            loop.run_in_executor(executor, download_dataset, dataset_name)
            for dataset_name in dataset_names
        ]
        # return_exceptions=True: a single dataset that cannot be downloaded
        # must not abort the whole initialization and leave the API on 503.
        results = await asyncio.gather(*tasks, return_exceptions=True)
        failed = sum(1 for result in results if isinstance(result, BaseException))
        if failed:
            logger.warning(
                f"{failed} of {len(dataset_names)} datasets could not be downloaded and were skipped."
            )

        # Both helpers block, so keep them off the event loop as well.
        await loop.run_in_executor(None, unify_datasets)
        await loop.run_in_executor(None, upload_model_to_hub)

        initialization_complete = True
        logger.info("All initialization tasks completed successfully.")
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Error during dataset processing: {e}")
|
|
|
|
|
# The ASGI application object served by uvicorn.
app = FastAPI()

# Permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    # Credentialed requests must not be combined with wildcard origins: the
    # FastAPI CORS documentation requires explicit origins, methods and
    # headers whenever allow_credentials is True. Nothing in this file relies
    # on cookies or auth headers, so credentials are switched off.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory log of ('user', query) tuples appended by the /autocomplete endpoint.
# NOTE(review): nothing in this file trims the list, so it grows per request.
message_history = []
|
|
|
@app.on_event("startup")
async def startup_event():
    """Kick off model and dataset initialization in the background.

    The handler returns immediately so the server can start accepting
    requests while ``run_initialization`` is still running.
    """
    logger.info("Application startup initiated.")
    loop = asyncio.get_running_loop()
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected before it
    # finishes.
    app.state.initialization_task = loop.create_task(run_initialization(loop))
    logger.info("Background initialization tasks started.")
|
|
|
async def run_initialization(loop):
    """Run the full startup sequence: load the model, then process datasets.

    Args:
        loop: Event loop whose default executor runs the blocking
            ``initialize_model`` call.

    Exceptions derived from ``Exception`` are logged and swallowed.
    """
    try:
        # Model loading blocks, so it runs in a worker thread.
        await loop.run_in_executor(None, initialize_model)
        await download_and_process_datasets()
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Error during startup tasks: {e}")
|
|
|
@app.get('/')
async def index():
    """Serve the single-page chat UI.

    The page sends the typed message to ``/autocomplete?q=...`` and renders
    every entry of the ``result`` array in the JSON reply as a bot message.
    Any non-2xx reply is shown as an error message, using the ``detail``
    field when the server provides one.

    Returns:
        HTMLResponse: the static HTML document with status 200.
    """
    html_code = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <!-- Existing head content -->
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>ChatGPT Chatbot</title>
        <style>
            /* Existing styles */
            /* Add styles for the model selector */
            .model-selector {
                margin-bottom: 10px;
            }
            body {
                font-family: Arial, sans-serif;
                margin: 0;
                padding: 0;
                background-color: #f4f4f4;
            }
            .container {
                max-width: 800px;
                margin: auto;
                padding: 20px;
            }
            .chat-container {
                background-color: #fff;
                border-radius: 8px;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
                overflow: hidden;
                margin-bottom: 20px;
                animation: fadeInUp 0.5s ease forwards;
                display: flex;
                flex-direction: column;
            }
            .chat-box {
                flex: 1;
                overflow-y: auto;
                padding: 10px;
            }
            .chat-input {
                width: calc(100% - 20px);
                border: none;
                border-top: 1px solid #ddd;
                padding: 10px;
                font-size: 16px;
                outline: none;
            }
            .chat-input:focus {
                border-top: 1px solid #007bff;
            }
            .user-message {
                margin-bottom: 10px;
                padding: 8px 12px;
                border-radius: 8px;
                background-color: #007bff;
                color: #fff;
                max-width: 70%;
                word-wrap: break-word;
                align-self: flex-end;
            }
            .bot-message {
                margin-bottom: 10px;
                padding: 8px 12px;
                border-radius: 8px;
                background-color: #4CAF50;
                color: #fff;
                max-width: 70%;
                word-wrap: break-word;
            }
            .toggle-history {
                text-align: center;
                cursor: pointer;
                color: #007bff;
                margin-bottom: 10px;
            }
            .history-container {
                display: none;
            }
            .history-container.show {
                display: block;
            }
            .history-container .history-content {
                max-height: 200px;
                overflow-y: auto;
            }
            @keyframes fadeInUp {
                from {
                    opacity: 0;
                    transform: translateY(20px);
                }
                to {
                    opacity: 1;
                    transform: translateY(0);
                }
            }
        </style>
    </head>
    <body>
        <div class="container">
            <h1 style="text-align: center;">ChatGPT Chatbot</h1>
            <div class="chat-container" id="chat-container">
                <div class="chat-box" id="chat-box">
                </div>
                <input type="text" class="chat-input" id="user-input" placeholder="Type your message...">
                <button onclick="retryLastMessage()">Retry Last Message</button>
            </div>
            <div class="toggle-history" onclick="toggleHistory()">Toggle History</div>
            <div class="history-container" id="history-container">
                <h2>Chat History</h2>
                <div class="history-content" id="history-content"></div>
            </div>
        </div>
        <script>
            function toggleHistory() {
                const historyContainer = document.getElementById('history-container');
                historyContainer.classList.toggle('show');
            }

            function saveMessage(sender, message) {
                const historyContent = document.getElementById('history-content');
                const messageElement = document.createElement('div');
                messageElement.className = `${sender}-message`;
                messageElement.innerText = message;
                historyContent.appendChild(messageElement);
            }

            function appendMessage(sender, message) {
                const chatBox = document.getElementById('chat-box');
                const messageElement = document.createElement('div');
                messageElement.className = `${sender}-message`;
                messageElement.innerText = message;
                chatBox.appendChild(messageElement);
                chatBox.scrollTop = chatBox.scrollHeight;
            }

            const chatContainer = document.getElementById('chat-container');
            const chatBox = document.getElementById('chat-box');
            const userInput = document.getElementById('user-input');
            const genericError = 'An error occurred. Please try again later.';

            userInput.addEventListener('keyup', function(event) {
                if (event.key === 'Enter') {
                    event.preventDefault();
                    sendMessage();
                }
            });

            function sendMessage() {
                const userMessage = userInput.value.trim();
                if (userMessage === '') return;

                saveMessage('user', userMessage);
                appendMessage('user', userMessage);
                userInput.value = '';

                fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
                    .then(response => {
                        if (!response.ok) {
                            // Any error status: surface the server's "detail" text when it is a string.
                            return response.json()
                                .catch(() => ({}))
                                .then(data => {
                                    throw new Error(typeof data.detail === 'string' ? data.detail : genericError);
                                });
                        }
                        return response.json();
                    })
                    .then(data => {
                        const botMessages = data.result;
                        botMessages.forEach(message => {
                            saveMessage('bot', message);
                            appendMessage('bot', message);
                        });
                    })
                    .catch(error => {
                        console.error('Error:', error);
                        appendMessage('bot', error.message || genericError);
                    });
            }

            function retryLastMessage() {
                // ":last-of-type" matches by element type, not class, so it finds nothing
                // once a bot reply is the last <div>; pick the last user message explicitly.
                const userMessages = chatBox.querySelectorAll('.user-message');
                const lastUserMessage = userMessages[userMessages.length - 1];
                if (lastUserMessage) {
                    userInput.value = lastUserMessage.innerText;
                    sendMessage();
                }
            }
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_code, status_code=200)
|
|
|
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query')):
    """Generate a completion for ``q`` with the text-generation pipeline.

    Args:
        q: Required query-string parameter holding the user's prompt.

    Returns:
        dict: ``{"result": [generated_text]}``.

    Raises:
        HTTPException: 503 while ``initialization_complete`` is still false;
            500 when generation fails.
    """
    # Recorded before the readiness check, so rejected requests are kept too.
    message_history.append(('user', q))

    if not initialization_complete:
        logger.warning("Model is not initialized yet.")
        raise HTTPException(status_code=503, detail="Model is not initialized yet. Please try again later.")

    try:
        loop = asyncio.get_running_loop()
        # Inference blocks, so run it in a worker thread instead of stalling
        # the event loop for every other request.
        # ``max_new_tokens`` bounds only the generated part; ``max_length``
        # also counts the prompt and left long queries no room to generate.
        outputs = await loop.run_in_executor(
            None,
            lambda: pipeline_instance(q, max_new_tokens=50, num_return_sequences=1),
        )
        response = outputs[0]['generated_text']
        logger.debug(f"Successfully autocomplete, q:{q}, res:{response}")
        return {"result": [response]}
    except Exception as e:
        # loguru ignores ``exc_info=True`` (the kwarg is only used to
        # str.format() the message); logger.exception() logs the traceback.
        logger.exception(f"Ignored error in autocomplete: {e}")
        raise HTTPException(
            status_code=500,
            detail="An error occurred while processing your request.",
        ) from e
|
|
|
if __name__ == '__main__':
    # Script entry point: listen on all interfaces; the PORT environment
    # variable overrides the default port 443.
    uvicorn.run(app=app, host='0.0.0.0', port=int(os.environ.get("PORT", 443)))
|
|