ksvmuralidhar's picture
Update api.py
dcf4c1e verified
import re
import os
from transformers import (BartTokenizerFast,
TFAutoModelForSeq2SeqLM)
import tensorflow as tf
from scraper import scrape_text
from fastapi import FastAPI, Response, Request
from typing import List
from pydantic import BaseModel, Field
from fastapi.exceptions import RequestValidationError
import uvicorn
import json
import logging
import multiprocessing
# NOTE(review): this env var is assigned after `tensorflow` and `transformers`
# are imported at the top of the file; it only has an effect if it is read
# later than import time — confirm, or move it above those imports.
os.environ['TF_USE_LEGACY_KERAS'] = "1"
# Hugging Face checkpoint providing the tokenizer and model architecture.
SUMM_CHECKPOINT = "facebook/bart-base"
# Token length requested for summarizer inputs (passed as `n_tokens` when tokenizing).
SUMM_INPUT_N_TOKENS = 400
# Upper bound on newly generated tokens per summary (`max_new_tokens` for generate()).
SUMM_TARGET_N_TOKENS = 300
def load_summarizer_models():
    """Create the summarizer tokenizer and model and load the local weights.

    The tokenizer and TF seq2seq architecture are built from ``SUMM_CHECKPOINT``.
    The weights are then loaded by layer name from
    ``models/bart_en_summarizer.h5``, a path resolved against the current
    working directory.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    weights_file = os.path.join("models", "bart_en_summarizer.h5")
    model.load_weights(weights_file, by_name=True)
    logging.warning('Loaded summarizer models')
    return tokenizer, model
# Cleanup patterns for summ_preprocess, compiled once at import time.
_BYLINE_RE = re.compile(r'^By \. [\w\s]+ \. ')              # By . Ellie Zolfagharifard .
_CLOCK_TIME_RE = re.compile(r'\d{1,2}\:\d\d [a-zA-Z]{3}')  # 10:30 EST
_LONG_DATE_RE = re.compile(r'\d{1,2} [a-zA-Z]+ \d{4}')     # 10 November 1990
_SPACED_PUNCT_RE = re.compile(r' [\,\.\:\'\;\|] ')          # punctuation with a space on both sides
_LEADING_PUNCT_RE = re.compile(r'^\s*[\,\.\:\'\;\|]')       # punctuation opening the text
_NEWLINES_RE = re.compile(r'\n+')


def summ_preprocess(txt):
    """Strip bylines, timestamps, dates and stray punctuation from article text.

    Every removed fragment is replaced by a single space, and the result is
    finally collapsed to single-space-separated words. The steps run in a fixed
    order. The spaced-punctuation pass runs twice because ``re.sub`` matches
    do not overlap, so adjacent fragments such as ``" , , "`` need a second
    sweep.

    Args:
        txt: Raw scraped article text.

    Returns:
        The cleaned text; ``""`` for empty input.
    """
    for pattern in (_BYLINE_RE, _CLOCK_TIME_RE, _LONG_DATE_RE):
        txt = pattern.sub(' ', txt)
    for marker in ('PUBLISHED:', 'UPDATED'):
        txt = txt.replace(marker, ' ')
    txt = _SPACED_PUNCT_RE.sub(' ', txt)
    for marker in (' : ', '(CNN)', '--'):
        txt = txt.replace(marker, ' ')
    txt = _LEADING_PUNCT_RE.sub(' ', txt)
    txt = _SPACED_PUNCT_RE.sub(' ', txt)
    txt = _NEWLINES_RE.sub(' ', txt)
    words = txt.split()
    return " ".join(words)
async def summ_inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of texts for the summarizer model.

    Args:
        input_: Texts to tokenize, one per article.
        n_tokens: Length every sequence is truncated or padded to.

    Returns:
        ``(tokenizer, tokenized_data)``: the module-level ``summ_tokenizer`` and
        the batch it produced as TensorFlow tensors.
    """
    # Use the caller's requested input length (n_tokens). The generation budget
    # SUMM_TARGET_N_TOKENS is unrelated to how much input the model sees.
    tokenized_data = summ_tokenizer(
        text=input_,
        max_length=n_tokens,
        truncation=True,
        padding="max_length",
        return_tensors="tf",
    )
    return summ_tokenizer, tokenized_data
async def summ_inference(txts: List[str]):
    """Generate one summary per input text.

    Each text is cleaned with ``summ_preprocess`` and tokenized via
    ``summ_inference_tokenize`` (requesting ``SUMM_INPUT_N_TOKENS``). Summaries
    are then generated with ``summ_model.generate``, capped at
    ``SUMM_TARGET_N_TOKENS`` new tokens.

    Args:
        txts: Article texts; empty strings (e.g. failed scrapes) are allowed.

    Returns:
        A list of summaries aligned with ``txts``. An entry whose text is empty
        after preprocessing gets ``""`` instead of a decoded prediction.
    """
    logging.warning("Entering summ_inference()")
    if not txts:
        # Nothing to tokenize; avoid calling the tokenizer/model with an empty batch.
        return []
    cleaned = [summ_preprocess(text) for text in txts]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=cleaned, n_tokens=SUMM_INPUT_N_TOKENS)
    # NOTE(review): generate() is a blocking call inside an async function, so
    # it stalls the event loop while it runs.
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    return [
        "" if text == "" else inference_tokenizer.decode(tokens, skip_special_tokens=True).strip()
        for text, tokens in zip(cleaned, pred)
    ]
async def scrape_urls(urls, timeout=120):
    """Scrape every URL in parallel worker processes.

    Args:
        urls: URLs to scrape.
        timeout: Seconds to wait for each individual result (default 120).

    Returns:
        ``(scraped_texts, scrape_errors)``: two lists aligned with ``urls``. They
        are built from the ``(text, error)`` pair ``scrape_text`` returns for
        each URL.

    Raises:
        multiprocessing.TimeoutError: if a result is not ready within ``timeout``.
        Any exception raised by ``scrape_text`` in a worker is re-raised here.
    """
    logging.warning('Entering scrape_urls()')
    scraped_texts = []
    scrape_errors = []
    if urls:
        # Never start more workers than there are URLs to scrape.
        pool = multiprocessing.Pool(processes=min(multiprocessing.cpu_count(), len(urls)))
        try:
            # Submit everything first so the workers scrape concurrently.
            pending = [pool.apply_async(scrape_text, [url]) for url in urls]
            for async_result in pending:
                text, error = async_result.get(timeout=timeout)
                scraped_texts.append(text)
                scrape_errors.append(error)
        except BaseException:
            # A timeout or worker failure must not leak the pool's processes.
            pool.terminate()
            raise
        else:
            pool.close()
        finally:
            pool.join()
    logging.warning('Exiting scrape_urls()')
    return scraped_texts, scrape_errors
description = "API to generate summaries of news articles from their URLs."
# Metadata shown on the generated OpenAPI / Swagger UI page.
_api_contact = {"name": "Author: KSV Muralidhar", "url": "https://ksvmuralidhar.in"}
_api_license = {"name": "License: MIT", "identifier": "MIT"}
app = FastAPI(
    title='News Summarizer API',
    description=description,
    version="0.0.1",
    contact=_api_contact,
    license_info=_api_license,
    # Per the Swagger UI configuration docs, a depth of -1 hides the models/schemas section.
    swagger_ui_parameters={"defaultModelsExpandDepth": -1},
)
# Loaded once at import time; the inference helpers read these module globals.
summ_tokenizer, summ_model = load_summarizer_models()
# Request body of /generate_summary/: the article URLs plus the API key.
# (Comments rather than a docstring: a model docstring would change the OpenAPI schema.)
class URLList(BaseModel):
    urls: List[str] = Field(default=..., description="List of URLs of news articles to generate summaries")
    key: str = Field(default=..., description="Authentication Key")
# Response body documented for HTTP 200: per-URL lists plus an empty error string.
class SuccessfulResponse(BaseModel):
    urls: List[str] = Field(default=..., description="List of URLs of news articles inputted by the user")
    scraped_texts: List[str] = Field(default=..., description="List of scraped text from input URLs")
    scrape_errors: List[str] = Field(default=..., description="List of errors raised during scraping. One item for corresponding URL")
    summaries: List[str] = Field(default=..., description="List of generated summaries of news articles")
    summarizer_error: str = Field(default="", description="Empty string as the response code is 200")
# Response body documented for HTTP 401: only `urls` is echoed back.
class AuthenticationError(BaseModel):
    urls: List[str] = Field(default=..., description="List of URLs of news articles inputted by the user")
    scraped_texts: str = Field(default="", description="Empty string as authentication failed")
    scrape_errors: str = Field(default="", description="Empty string as authentication failed")
    summaries: str = Field(default="", description="Empty string as authentication failed")
    summarizer_error: str = Field(default="Error: Authentication error: Invalid API key.")
# Response body documented for HTTP 500: scrape results are kept, summaries are empty.
class SummaryError(BaseModel):
    urls: List[str] = Field(default=..., description="List of URLs of news articles inputted by the user")
    scraped_texts: List[str] = Field(default=..., description="List of scraped text from input URLs")
    scrape_errors: List[str] = Field(default=..., description="List of errors raised during scraping. One item for corresponding URL")
    summaries: str = Field(default="", description="Empty string as summarizer encountered an error")
    summarizer_error: str = Field(default="Error: Summarizer Error with a message describing the error")
# Response body documented for HTTP 422 (request payload failed validation).
class InputValidationError(BaseModel):
    urls: List[str] = Field(default=..., description="List of URLs of news articles inputted by the user")
    scraped_texts: str = Field(default="", description="Empty string as validation failed")
    scrape_errors: str = Field(default="", description="Empty string as validation failed")
    summaries: str = Field(default="", description="Empty string as validation failed")
    summarizer_error: str = Field(default="Validation Error with a message describing the error")
class NewsSummarizerAPIAuthenticationError(Exception):
    """Raised by ``authenticate_key`` when the supplied API key is rejected."""
class NewsSummarizerAPIScrapingError(Exception):
    """Raised when no text could be scraped from any of the requested URLs."""
def authenticate_key(api_key: str):
    """Check ``api_key`` against the ``API_KEY`` environment variable.

    Args:
        api_key: Key supplied by the client.

    Raises:
        NewsSummarizerAPIAuthenticationError: if the key does not match. Also
            raised when ``API_KEY`` is unset or empty, so a misconfigured server
            rejects every request instead of accepting a missing key.
    """
    import hmac  # local import: constant-time comparison helper

    expected_key = os.getenv('API_KEY')
    is_valid = (
        bool(expected_key)
        and isinstance(api_key, str)
        # Constant-time comparison, so response timing does not reveal how much of
        # the key was correct. Bytes are compared because compare_digest rejects
        # non-ASCII str arguments.
        and hmac.compare_digest(api_key.encode('utf-8'), expected_key.encode('utf-8'))
    )
    if not is_valid:
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")
def _urls_from_validation_body(body):
    """Return the ``urls`` value from a rejected JSON request body, or None.

    A list is returned as-is and a bare string is wrapped in a one-element list.
    Anything else (non-dict body, missing key, other types) yields None.
    """
    if not isinstance(body, dict):
        return None
    urls = body.get('urls')
    if isinstance(urls, list):
        return urls
    if isinstance(urls, str):
        return [urls]
    return None


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Report request-validation failures in the endpoint's JSON envelope.

    Responds with HTTP 422. ``summarizer_error`` lists each failure as
    ``"<field>: <message>"``, joined by ``"; "``.

    Args:
        request: The rejected request; its ``urls`` query parameters are used
            only when the body does not carry ``urls``.
        exc: The validation error raised by FastAPI.
    """
    # The endpoint takes `urls` in the JSON body, so echo them from the body
    # FastAPI attaches to the error; the query string is only a fallback.
    urls = _urls_from_validation_body(getattr(exc, "body", None))
    if urls is None:
        urls = request.query_params.getlist("urls")
    error_messages = []
    for error in exc.errors():
        loc = error.get('loc') or ()
        # Report only the innermost location element (typically the field name).
        field = str(loc[-1]) if loc else ""
        error_messages.append(f"{field}: {error['msg']}")
    error_message = "; ".join(error_messages)
    response_json = {'urls': urls, 'scraped_texts': '', 'scrape_errors': '', 'summaries': "", 'summarizer_error': f'Validation Error: {error_message}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=422)
@app.post("/generate_summary/", tags=["Generate Summary"], response_model=SuccessfulResponse,
          responses={
              401: {"model": AuthenticationError, "description": "Authentication Error: Returned when the entered API key is incorrect"},
              500: {"model": SummaryError, "description": "Summarizer Error: Returned when the API couldn't generate the summary of even a single article"},
              422: {"model": InputValidationError, "description": "Validation Error: Returned when the payload data doesn't match the data type requirements"}
          })
async def generate_summary(q: URLList):
    """
    Get summaries of news articles by passing the list of URLs as input.
    - **urls**: List of URLs (required)
    - **key**: Authentication key (required)
    """
    # FastAPI renders the docstring above as the endpoint description, so
    # implementation notes live in comments. The success body is one object
    # (not a list), hence response_model=SuccessfulResponse. Because a raw
    # Response is returned, that setting only affects the generated docs.
    logging.warning("Entering generate_summary()")
    # Pre-set so the error response can echo whatever was gathered before a failure.
    urls = q.urls
    scraped_texts = ""
    scrape_errors = ""
    try:
        authenticate_key(q.key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        # any() is False both for an empty URL list and when every scrape came back empty.
        if not any(scraped_texts):
            raise NewsSummarizerAPIScrapingError("Scrape Error: Couldn't scrape text from any of the URLs")
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        if isinstance(e, NewsSummarizerAPIAuthenticationError):
            status_code = 401
        else:
            status_code = 500
            # Keep the traceback in the server log; the client only sees the message.
            logging.exception("generate_summary() failed")
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'Error: {e}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)
if __name__ == '__main__':
    # uvicorn only supports workers > 1 when the app is given as a
    # "module:attribute" import string. Passed the app object with workers=3, it
    # logs "You must pass the application as an import string ..." and exits.
    # The module name is derived from this file so it stays correct if renamed.
    # Each worker imports this module, so each loads its own copy of the models.
    _module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{_module_name}:app", host='0.0.0.0', port=7860, workers=3)