Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import sys
|
3 |
import uvicorn
|
4 |
-
from fastapi import FastAPI, Query
|
5 |
from fastapi.responses import HTMLResponse
|
6 |
from starlette.middleware.cors import CORSMiddleware
|
7 |
from datasets import load_dataset, list_datasets
|
@@ -12,19 +12,23 @@ import psutil
|
|
12 |
import asyncio
|
13 |
import torch
|
14 |
from tenacity import retry, stop_after_attempt, wait_fixed
|
15 |
-
from huggingface_hub import HfApi
|
16 |
from dotenv import load_dotenv
|
17 |
|
|
|
18 |
load_dotenv()
|
19 |
|
|
|
20 |
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
21 |
if not HUGGINGFACE_TOKEN:
|
22 |
logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
|
23 |
sys.exit(1)
|
24 |
|
|
|
25 |
datasets_dict = {}
|
26 |
example_usage_list = []
|
27 |
|
|
|
28 |
CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
|
29 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
30 |
os.environ["HF_HOME"] = CACHE_DIR
|
@@ -32,6 +36,9 @@ os.environ["HF_TOKEN"] = HUGGINGFACE_TOKEN
|
|
32 |
|
33 |
pipeline_instance = None # Solo un pipeline
|
34 |
|
|
|
|
|
|
|
35 |
def initialize_model():
|
36 |
global pipeline_instance
|
37 |
try:
|
@@ -40,7 +47,7 @@ def initialize_model():
|
|
40 |
model = AutoModelForCausalLM.from_pretrained(
|
41 |
base_model_repo,
|
42 |
cache_dir=CACHE_DIR,
|
43 |
-
ignore_mismatched_sizes=True #
|
44 |
)
|
45 |
tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
|
46 |
if tokenizer.pad_token is None:
|
@@ -60,7 +67,8 @@ def initialize_model():
|
|
60 |
def download_dataset(dataset_name):
|
61 |
try:
|
62 |
logger.info(f"Starting download for dataset: {dataset_name}")
|
63 |
-
|
|
|
64 |
create_example_usage(dataset_name)
|
65 |
except Exception as e:
|
66 |
logger.error(f"Error loading dataset {dataset_name}: {e}", exc_info=True)
|
@@ -116,6 +124,7 @@ def unify_datasets():
|
|
116 |
except Exception as e:
|
117 |
logger.error(f"Error unifying datasets: {e}", exc_info=True)
|
118 |
|
|
|
119 |
cpu_count = psutil.cpu_count(logical=False) or 1
|
120 |
memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
|
121 |
memory_per_download_mb = 100
|
@@ -130,24 +139,30 @@ logger.info(f"Using up to {max_concurrent_downloads} concurrent workers for down
|
|
130 |
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
|
131 |
|
132 |
async def download_and_process_datasets():
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
145 |
|
|
|
146 |
app = FastAPI()
|
147 |
|
|
|
148 |
app.add_middleware(
|
149 |
CORSMiddleware,
|
150 |
-
allow_origins=["*"],
|
151 |
allow_credentials=True,
|
152 |
allow_methods=["*"],
|
153 |
allow_headers=["*"]
|
@@ -159,9 +174,19 @@ message_history = []
|
|
159 |
async def startup_event():
|
160 |
logger.info("Application startup initiated.")
|
161 |
loop = asyncio.get_event_loop()
|
162 |
-
|
163 |
-
|
164 |
-
logger.info("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
@app.get('/')
|
167 |
async def index():
|
@@ -321,7 +346,12 @@ async def index():
|
|
321 |
userInput.value = '';
|
322 |
|
323 |
fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
|
324 |
-
.then(response =>
|
|
|
|
|
|
|
|
|
|
|
325 |
.then(data => {
|
326 |
const botMessages = data.result;
|
327 |
botMessages.forEach(message => {
|
@@ -331,6 +361,7 @@ async def index():
|
|
331 |
})
|
332 |
.catch(error => {
|
333 |
console.error('Error:', error);
|
|
|
334 |
});
|
335 |
}
|
336 |
|
@@ -349,15 +380,20 @@ async def index():
|
|
349 |
|
350 |
@app.get('/autocomplete')
|
351 |
async def autocomplete(q: str = Query(..., title='query')):
|
352 |
-
global message_history
|
353 |
message_history.append(('user', q))
|
|
|
|
|
|
|
|
|
|
|
354 |
try:
|
355 |
response = pipeline_instance(q, max_length=50, num_return_sequences=1)[0]['generated_text']
|
356 |
logger.debug(f"Successfully autocomplete, q:{q}, res:{response}")
|
357 |
return {"result": [response]}
|
358 |
except Exception as e:
|
359 |
logger.error(f"Ignored error in autocomplete: {e}", exc_info=True)
|
360 |
-
|
361 |
|
362 |
if __name__ == '__main__':
|
363 |
port = int(os.getenv("PORT", 443))
|
|
|
1 |
import os
|
2 |
import sys
|
3 |
import uvicorn
|
4 |
+
from fastapi import FastAPI, Query, HTTPException, BackgroundTasks
|
5 |
from fastapi.responses import HTMLResponse
|
6 |
from starlette.middleware.cors import CORSMiddleware
|
7 |
from datasets import load_dataset, list_datasets
|
|
|
12 |
import asyncio
|
13 |
import torch
|
14 |
from tenacity import retry, stop_after_attempt, wait_fixed
|
15 |
+
from huggingface_hub import HfApi, RepositoryNotFoundError
|
16 |
from dotenv import load_dotenv
|
17 |
|
18 |
+
# Load environment variables
|
19 |
load_dotenv()
|
20 |
|
21 |
+
# Get the Hugging Face token
|
22 |
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
|
23 |
if not HUGGINGFACE_TOKEN:
|
24 |
logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
|
25 |
sys.exit(1)
|
26 |
|
27 |
+
# Initialize the containers for datasets and usage examples
|
28 |
datasets_dict = {}
|
29 |
example_usage_list = []
|
30 |
|
31 |
+
# Cache configuration
|
32 |
CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
|
33 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
34 |
os.environ["HF_HOME"] = CACHE_DIR
|
|
|
36 |
|
37 |
pipeline_instance = None # Solo un pipeline
|
38 |
|
39 |
+
# Flag indicating whether initialization has completed
|
40 |
+
initialization_complete = False
|
41 |
+
|
42 |
def initialize_model():
|
43 |
global pipeline_instance
|
44 |
try:
|
|
|
47 |
model = AutoModelForCausalLM.from_pretrained(
|
48 |
base_model_repo,
|
49 |
cache_dir=CACHE_DIR,
|
50 |
+
ignore_mismatched_sizes=True # Ignorar discrepancias de tamaño
|
51 |
)
|
52 |
tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
|
53 |
if tokenizer.pad_token is None:
|
|
|
67 |
def download_dataset(dataset_name):
    """Load one dataset into ``datasets_dict`` and build its usage example.

    Args:
        dataset_name: Identifier passed straight to ``load_dataset``.

    Errors derived from ``Exception`` are logged with their traceback and
    are not propagated to the caller.
    """
    try:
        logger.info(f"Starting download for dataset: {dataset_name}")
        # 'trust_remote_code=True' was removed to avoid the ParquetConfig error.
        loaded = load_dataset(dataset_name, cache_dir=CACHE_DIR)
        datasets_dict[dataset_name] = loaded
        create_example_usage(dataset_name)
    except Exception as err:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception(f"Error loading dataset {dataset_name}: {err}")
|
|
|
124 |
except Exception as e:
|
125 |
logger.error(f"Error unifying datasets: {e}", exc_info=True)
|
126 |
|
127 |
+
# Concurrency configuration
|
128 |
cpu_count = psutil.cpu_count(logical=False) or 1
|
129 |
memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
|
130 |
memory_per_download_mb = 100
|
|
|
139 |
executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
|
140 |
|
141 |
async def download_and_process_datasets():
    """Download all listed datasets, unify them, and upload the model.

    Steps, in order:
      1. Fetch the dataset names with ``list_datasets``.
      2. Run ``download_dataset`` for every name on the shared ``executor``.
      3. Call ``unify_datasets`` and then ``upload_model_to_hub``.
      4. Set the module-level ``initialization_complete`` flag to ``True``.

    Every synchronous step runs in a worker thread so the event loop stays
    free to serve requests while initialization is in progress. Any
    exception is logged and swallowed; in that case the flag stays unset.
    """
    global initialization_complete
    try:
        # We are inside a coroutine, so a running loop is guaranteed.
        loop = asyncio.get_running_loop()

        # Calling this directly would hold the event loop until it returns.
        dataset_names = await loop.run_in_executor(None, list_datasets)
        logger.info(f"Found {len(dataset_names)} datasets to download.")

        downloads = [
            loop.run_in_executor(executor, download_dataset, dataset_name)
            for dataset_name in dataset_names
        ]
        await asyncio.gather(*downloads)

        # Post-processing also runs off the event loop thread.
        # NOTE(review): upload_model_to_hub is not visible in this diff; it
        # is assumed to be a plain synchronous function - confirm.
        await loop.run_in_executor(None, unify_datasets)
        await loop.run_in_executor(None, upload_model_to_hub)

        initialization_complete = True
        logger.info("All initialization tasks completed successfully.")
    except Exception as e:
        logger.error(f"Error during dataset processing: {e}", exc_info=True)
|
158 |
|
159 |
+
# Initialize FastAPI
|
160 |
app = FastAPI()
|
161 |
|
162 |
+
# CORS configuration
|
163 |
app.add_middleware(
|
164 |
CORSMiddleware,
|
165 |
+
allow_origins=["*"], # Cambia esto según tus necesidades
|
166 |
allow_credentials=True,
|
167 |
allow_methods=["*"],
|
168 |
allow_headers=["*"]
|
|
|
174 |
async def startup_event():
    """Start model and dataset initialization in the background.

    Schedules ``run_initialization`` as a task and returns immediately, so
    this coroutine does not wait for initialization to finish.
    """
    logger.info("Application startup initiated.")
    loop = asyncio.get_running_loop()
    # The event loop keeps only weak references to tasks. Store a strong
    # reference on the application state so the background task cannot be
    # garbage-collected before it finishes.
    app.state.initialization_task = loop.create_task(run_initialization(loop))
    logger.info("Background initialization tasks started.")
|
180 |
+
|
181 |
+
async def run_initialization(loop):
    """Initialize the model, then download and process the datasets.

    Args:
        loop: Event loop whose default executor runs ``initialize_model``.

    Exceptions raised by either step are logged and not propagated.
    """
    try:
        # Run model initialization in a worker thread rather than on the
        # event loop thread.
        await loop.run_in_executor(None, initialize_model)
        # Download and process the datasets.
        await download_and_process_datasets()
    except Exception as e:
        logger.error(f"Error during startup tasks: {e}", exc_info=True)
|
190 |
|
191 |
@app.get('/')
|
192 |
async def index():
|
|
|
346 |
userInput.value = '';
|
347 |
|
348 |
fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
|
349 |
+
.then(response => {
|
350 |
+
if (response.status === 503) {
|
351 |
+
return response.json().then(data => { throw new Error(data.detail); });
|
352 |
+
}
|
353 |
+
return response.json();
|
354 |
+
})
|
355 |
.then(data => {
|
356 |
const botMessages = data.result;
|
357 |
botMessages.forEach(message => {
|
|
|
361 |
})
|
362 |
.catch(error => {
|
363 |
console.error('Error:', error);
|
364 |
+
appendMessage('bot', error.message || 'An error occurred. Please try again later.');
|
365 |
});
|
366 |
}
|
367 |
|
|
|
380 |
|
381 |
@app.get('/autocomplete')
async def autocomplete(q: str = Query(..., title='query')):
    """Generate a completion for ``q`` with the shared text pipeline.

    Args:
        q: Required query-string parameter holding the user's text.

    Returns:
        ``{"result": [generated_text]}``.

    Raises:
        HTTPException: 503 while initialization is unfinished or no
            pipeline is available; 500 if generation fails.
    """
    # No ``global`` statement is needed: the module-level names are only
    # read or mutated in place, never rebound.
    # NOTE(review): nothing visible here trims message_history, so it grows
    # with every request - confirm it is bounded elsewhere.
    message_history.append(('user', q))

    generator = pipeline_instance
    # The completion flag can be set while the pipeline is still None,
    # e.g. if model loading failed without raising. Report "not ready"
    # instead of failing with a TypeError turned into a generic 500.
    if not initialization_complete or generator is None:
        logger.warning("Model is not initialized yet.")
        raise HTTPException(status_code=503, detail="Model is not initialized yet. Please try again later.")

    try:
        # NOTE(review): this call is synchronous, so the event loop is held
        # for its whole duration - consider offloading to a worker.
        response = generator(q, max_length=50, num_return_sequences=1)[0]['generated_text']
        logger.debug(f"Successfully autocomplete, q:{q}, res:{response}")
        return {"result": [response]}
    except Exception as e:
        logger.error(f"Ignored error in autocomplete: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="An error occurred while processing your request.") from e
|
397 |
|
398 |
if __name__ == '__main__':
|
399 |
port = int(os.getenv("PORT", 443))
|