# Spaces: Runtime error
import json | |
import asyncio | |
import logging | |
import time | |
import requests | |
from tqdm.asyncio import tqdm_asyncio | |
from huggingface_hub import get_inference_endpoint | |
from models import env_config, embed_config | |
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Module-level handle to the TEI Inference Endpoint, resolved once when this
# module is imported (presumably a Hub API lookup -- confirm) and shared by
# the embedding coroutines and wake_up_endpoint.
endpoint = get_inference_endpoint(
    env_config.tei_name,
    token=env_config.hf_token,
)
async def embed_chunk(sentence, semaphore, tmp_file):
    """Embed a single text with the TEI endpoint and append it to ``tmp_file``.

    The request is issued while holding ``semaphore`` so callers can bound the
    number of in-flight requests. Exactly one JSON line of the form
    ``{"vector": ..., <env_config.input_text_col>: sentence}`` is written per call.

    Args:
        sentence: Text to embed; sent as the ``inputs`` field with
            ``truncate`` enabled.
        semaphore: asyncio semaphore limiting concurrent requests.
        tmp_file: Writable text file handle that receives the JSON line.

    Raises:
        RuntimeError: If the request to the endpoint fails; the original
            exception is attached as ``__cause__``.
    """
    async with semaphore:
        payload = {
            "inputs": sentence,
            "truncate": True,
        }
        try:
            # NOTE(review): ``AsyncInferenceClient.post`` is deprecated in newer
            # huggingface_hub releases -- confirm the pinned version provides it.
            resp = await endpoint.async_client.post(json=payload)
        except Exception as e:
            # Chain explicitly so the underlying client/HTTP error is reported
            # as the cause rather than as an unrelated secondary failure.
            raise RuntimeError(str(e)) from e
        # Presumably a list holding one vector for a single string input
        # (TODO confirm against the TEI /embed contract), hence result[0].
        result = json.loads(resp)
        # No await between parsing and writing, so concurrent tasks cannot
        # interleave partial lines in tmp_file.
        tmp_file.write(
            json.dumps({"vector": result[0], env_config.input_text_col: sentence}) + "\n"
        )
async def embed_wrapper(input_ds, temp_file):
    """Embed every non-blank text in ``input_ds`` concurrently.

    One ``embed_chunk`` task is scheduled per row whose
    ``env_config.input_text_col`` value is a non-blank string. Concurrency is
    capped by ``embed_config.semaphore_bound``. Each task appends its JSON
    line to ``temp_file`` when it finishes, so output is in completion order,
    not input order.

    Args:
        input_ds: Iterable of rows indexable by ``env_config.input_text_col``.
        temp_file: Writable text file handle shared by all tasks.

    Raises:
        The first exception raised by any embedding task (``RuntimeError``
        for failed requests). The remaining tasks are cancelled before the
        exception propagates.
    """
    text_col = env_config.input_text_col
    semaphore = asyncio.BoundedSemaphore(embed_config.semaphore_bound)
    texts = (row[text_col] for row in input_ds)
    jobs = [
        asyncio.create_task(embed_chunk(text, semaphore, temp_file))
        # Skip blank and missing (None / non-string) values instead of
        # crashing on .strip() or sending empty requests to the endpoint.
        for text in texts
        if isinstance(text, str) and text.strip()
    ]
    logger.info("num chunks to embed: %d", len(jobs))
    tic = time.perf_counter()  # monotonic clock for elapsed-time measurement
    try:
        await tqdm_asyncio.gather(*jobs)
    except BaseException:
        # Like asyncio.gather, this does not cancel sibling tasks when one
        # fails (or when this coroutine itself is cancelled). Cancel them so
        # they stop issuing requests and writing to temp_file after control
        # has left this function. cancel() on an already-finished task is a no-op.
        for job in jobs:
            job.cancel()
        raise
    logger.info("embed time: %s", time.perf_counter() - tic)
def wake_up_endpoint(timeout=None):
    """Ensure the TEI Inference Endpoint is running, resuming it if needed.

    Args:
        timeout: Maximum number of seconds to wait for the endpoint to come
            up, forwarded to ``InferenceEndpoint.wait``. ``None`` (the
            default) waits indefinitely, matching the previous behaviour.

    Raises:
        Whatever ``InferenceEndpoint.resume`` / ``wait`` raise -- e.g.
        huggingface_hub's ``InferenceEndpointTimeoutError`` when ``timeout``
        elapses, or ``InferenceEndpointError`` if the deployment fails.
    """
    # Refresh the cached status from the Hub before deciding whether to resume.
    endpoint.fetch()
    if endpoint.status != 'running':
        logger.info("Starting up TEI endpoint")
        # wait() polls the server and updates ``endpoint`` in place, so no
        # extra fetch() is needed afterwards.
        endpoint.resume().wait(timeout=timeout)
        logger.info("TEI endpoint is up")