Yhhxhfh committed on
Commit
3487d09
1 Parent(s): f60c02f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -23
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import sys
3
  import uvicorn
4
- from fastapi import FastAPI, Query
5
  from fastapi.responses import HTMLResponse
6
  from starlette.middleware.cors import CORSMiddleware
7
  from datasets import load_dataset, list_datasets
@@ -12,19 +12,23 @@ import psutil
12
  import asyncio
13
  import torch
14
  from tenacity import retry, stop_after_attempt, wait_fixed
15
- from huggingface_hub import HfApi
16
  from dotenv import load_dotenv
17
 
 
18
  load_dotenv()
19
 
 
20
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
21
  if not HUGGINGFACE_TOKEN:
22
  logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
23
  sys.exit(1)
24
 
 
25
  datasets_dict = {}
26
  example_usage_list = []
27
 
 
28
  CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
29
  os.makedirs(CACHE_DIR, exist_ok=True)
30
  os.environ["HF_HOME"] = CACHE_DIR
@@ -32,6 +36,9 @@ os.environ["HF_TOKEN"] = HUGGINGFACE_TOKEN
32
 
33
  pipeline_instance = None # Solo un pipeline
34
 
 
 
 
35
  def initialize_model():
36
  global pipeline_instance
37
  try:
@@ -40,7 +47,7 @@ def initialize_model():
40
  model = AutoModelForCausalLM.from_pretrained(
41
  base_model_repo,
42
  cache_dir=CACHE_DIR,
43
- ignore_mismatched_sizes=True # Añadir este parámetro
44
  )
45
  tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
46
  if tokenizer.pad_token is None:
@@ -60,7 +67,8 @@ def initialize_model():
60
  def download_dataset(dataset_name):
61
  try:
62
  logger.info(f"Starting download for dataset: {dataset_name}")
63
- datasets_dict[dataset_name] = load_dataset(dataset_name, trust_remote_code=True, cache_dir=CACHE_DIR)
 
64
  create_example_usage(dataset_name)
65
  except Exception as e:
66
  logger.error(f"Error loading dataset {dataset_name}: {e}", exc_info=True)
@@ -116,6 +124,7 @@ def unify_datasets():
116
  except Exception as e:
117
  logger.error(f"Error unifying datasets: {e}", exc_info=True)
118
 
 
119
  cpu_count = psutil.cpu_count(logical=False) or 1
120
  memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
121
  memory_per_download_mb = 100
@@ -130,24 +139,30 @@ logger.info(f"Using up to {max_concurrent_downloads} concurrent workers for down
130
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
131
 
132
  async def download_and_process_datasets():
133
- dataset_names = list_datasets()
134
- logger.info(f"Found {len(dataset_names)} datasets to download.")
135
- loop = asyncio.get_event_loop()
136
- tasks = []
137
- for dataset_name in dataset_names:
138
- task = loop.run_in_executor(executor, download_dataset, dataset_name)
139
- tasks.append(task)
140
- await asyncio.gather(*tasks)
141
- unify_datasets()
142
- upload_model_to_hub()
143
-
144
- # Elimina la llamada a asyncio.run(main()) y mueve la inicialización al evento de inicio de FastAPI
 
 
 
 
145
 
 
146
  app = FastAPI()
147
 
 
148
  app.add_middleware(
149
  CORSMiddleware,
150
- allow_origins=["*"],
151
  allow_credentials=True,
152
  allow_methods=["*"],
153
  allow_headers=["*"]
@@ -159,9 +174,19 @@ message_history = []
159
  async def startup_event():
160
  logger.info("Application startup initiated.")
161
  loop = asyncio.get_event_loop()
162
- await loop.run_in_executor(None, initialize_model)
163
- await download_and_process_datasets()
164
- logger.info("Application startup completed.")
 
 
 
 
 
 
 
 
 
 
165
 
166
  @app.get('/')
167
  async def index():
@@ -321,7 +346,12 @@ async def index():
321
  userInput.value = '';
322
 
323
  fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
324
- .then(response => response.json())
 
 
 
 
 
325
  .then(data => {
326
  const botMessages = data.result;
327
  botMessages.forEach(message => {
@@ -331,6 +361,7 @@ async def index():
331
  })
332
  .catch(error => {
333
  console.error('Error:', error);
 
334
  });
335
  }
336
 
@@ -349,15 +380,20 @@ async def index():
349
 
350
  @app.get('/autocomplete')
351
  async def autocomplete(q: str = Query(..., title='query')):
352
- global message_history
353
  message_history.append(('user', q))
 
 
 
 
 
354
  try:
355
  response = pipeline_instance(q, max_length=50, num_return_sequences=1)[0]['generated_text']
356
  logger.debug(f"Successfully autocomplete, q:{q}, res:{response}")
357
  return {"result": [response]}
358
  except Exception as e:
359
  logger.error(f"Ignored error in autocomplete: {e}", exc_info=True)
360
- return {"result": []}
361
 
362
  if __name__ == '__main__':
363
  port = int(os.getenv("PORT", 443))
 
1
  import os
2
  import sys
3
  import uvicorn
4
+ from fastapi import FastAPI, Query, HTTPException, BackgroundTasks
5
  from fastapi.responses import HTMLResponse
6
  from starlette.middleware.cors import CORSMiddleware
7
  from datasets import load_dataset, list_datasets
 
12
  import asyncio
13
  import torch
14
  from tenacity import retry, stop_after_attempt, wait_fixed
15
+ from huggingface_hub import HfApi, RepositoryNotFoundError
16
  from dotenv import load_dotenv
17
 
18
+ # Cargar variables de entorno
19
  load_dotenv()
20
 
21
+ # Obtener el token de Hugging Face
22
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
23
  if not HUGGINGFACE_TOKEN:
24
  logger.error("Hugging Face token not found. Please set the HUGGINGFACE_TOKEN environment variable.")
25
  sys.exit(1)
26
 
27
+ # Inicializar diccionarios para datasets y ejemplos
28
  datasets_dict = {}
29
  example_usage_list = []
30
 
31
+ # Configuración de caché
32
  CACHE_DIR = os.path.expanduser("~/.cache/huggingface")
33
  os.makedirs(CACHE_DIR, exist_ok=True)
34
  os.environ["HF_HOME"] = CACHE_DIR
 
36
 
37
  pipeline_instance = None # Solo un pipeline
38
 
39
+ # Flag para indicar si la inicialización está completa
40
+ initialization_complete = False
41
+
42
  def initialize_model():
43
  global pipeline_instance
44
  try:
 
47
  model = AutoModelForCausalLM.from_pretrained(
48
  base_model_repo,
49
  cache_dir=CACHE_DIR,
50
+ ignore_mismatched_sizes=True # Ignorar discrepancias de tamaño
51
  )
52
  tokenizer = AutoTokenizer.from_pretrained(base_model_repo, cache_dir=CACHE_DIR)
53
  if tokenizer.pad_token is None:
 
67
  def download_dataset(dataset_name):
68
  try:
69
  logger.info(f"Starting download for dataset: {dataset_name}")
70
+ # Eliminado 'trust_remote_code=True' para evitar el error con ParquetConfig
71
+ datasets_dict[dataset_name] = load_dataset(dataset_name, cache_dir=CACHE_DIR)
72
  create_example_usage(dataset_name)
73
  except Exception as e:
74
  logger.error(f"Error loading dataset {dataset_name}: {e}", exc_info=True)
 
124
  except Exception as e:
125
  logger.error(f"Error unifying datasets: {e}", exc_info=True)
126
 
127
+ # Configuración de concurrencia
128
  cpu_count = psutil.cpu_count(logical=False) or 1
129
  memory_available_mb = psutil.virtual_memory().available / (1024 * 1024)
130
  memory_per_download_mb = 100
 
139
  executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_downloads)
140
 
141
  async def download_and_process_datasets():
142
+ global initialization_complete
143
+ try:
144
+ dataset_names = list_datasets()
145
+ logger.info(f"Found {len(dataset_names)} datasets to download.")
146
+ loop = asyncio.get_event_loop()
147
+ tasks = []
148
+ for dataset_name in dataset_names:
149
+ task = loop.run_in_executor(executor, download_dataset, dataset_name)
150
+ tasks.append(task)
151
+ await asyncio.gather(*tasks)
152
+ unify_datasets()
153
+ upload_model_to_hub()
154
+ initialization_complete = True
155
+ logger.info("All initialization tasks completed successfully.")
156
+ except Exception as e:
157
+ logger.error(f"Error during dataset processing: {e}", exc_info=True)
158
 
159
+ # Inicializar FastAPI
160
  app = FastAPI()
161
 
162
+ # Configuración de CORS
163
  app.add_middleware(
164
  CORSMiddleware,
165
+ allow_origins=["*"], # Cambia esto según tus necesidades
166
  allow_credentials=True,
167
  allow_methods=["*"],
168
  allow_headers=["*"]
 
174
  async def startup_event():
175
  logger.info("Application startup initiated.")
176
  loop = asyncio.get_event_loop()
177
+ # Crear una tarea en segundo plano para inicializar el modelo y descargar datasets
178
+ loop.create_task(run_initialization(loop))
179
+ logger.info("Background initialization tasks started.")
180
+
181
async def run_initialization(loop):
    """Run the startup work: load the model, then process the datasets.

    Args:
        loop: Event loop whose default executor runs ``initialize_model``.

    ``initialize_model`` is executed in a worker thread via
    ``loop.run_in_executor``; once it returns,
    ``download_and_process_datasets()`` is awaited. Any exception from
    either step is logged with its traceback and not re-raised.
    """
    try:
        # Model loading is a synchronous call, so it runs in a worker thread.
        await loop.run_in_executor(None, initialize_model)
        # Dataset download and processing follows once the model step returns.
        await download_and_process_datasets()
    except Exception as e:
        logger.error(f"Error during startup tasks: {e}", exc_info=True)
190
 
191
  @app.get('/')
192
  async def index():
 
346
  userInput.value = '';
347
 
348
  fetch(`/autocomplete?q=${encodeURIComponent(userMessage)}`)
349
+ .then(response => {
350
+ if (response.status === 503) {
351
+ return response.json().then(data => { throw new Error(data.detail); });
352
+ }
353
+ return response.json();
354
+ })
355
  .then(data => {
356
  const botMessages = data.result;
357
  botMessages.forEach(message => {
 
361
  })
362
  .catch(error => {
363
  console.error('Error:', error);
364
+ appendMessage('bot', error.message || 'An error occurred. Please try again later.');
365
  });
366
  }
367
 
 
380
 
381
  @app.get('/autocomplete')
382
  async def autocomplete(q: str = Query(..., title='query')):
383
+ global message_history, pipeline_instance, initialization_complete
384
  message_history.append(('user', q))
385
+
386
+ if not initialization_complete:
387
+ logger.warning("Model is not initialized yet.")
388
+ raise HTTPException(status_code=503, detail="Model is not initialized yet. Please try again later.")
389
+
390
  try:
391
  response = pipeline_instance(q, max_length=50, num_return_sequences=1)[0]['generated_text']
392
  logger.debug(f"Successfully autocomplete, q:{q}, res:{response}")
393
  return {"result": [response]}
394
  except Exception as e:
395
  logger.error(f"Ignored error in autocomplete: {e}", exc_info=True)
396
+ raise HTTPException(status_code=500, detail="An error occurred while processing your request.")
397
 
398
  if __name__ == '__main__':
399
  port = int(os.getenv("PORT", 443))