swahili / main.py
Jayem-11
main_code
07f39f7
raw
history blame
2.38 kB
import functools
import os
import tempfile

from fastapi import FastAPI, UploadFile
from moviepy.editor import *
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import WhisperFeatureExtractor, WhisperTokenizer
import librosa
import numpy as np
import torch
app = FastAPI()

# Run inference on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Whisper front end: the feature extractor turns the 16 kHz waveform into model
# input features; the tokenizer decodes generated token ids back into text.
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Swahili", task="transcribe")
processor = WhisperProcessor.from_pretrained("Jayem-11/whisper-small-swahili-3")
# Place the ASR model on `device`: input features are moved to `device` before
# generate(), and a model left on the CPU would fail on a CUDA host.
asr_model = WhisperForConditionalGeneration.from_pretrained("Jayem-11/whisper-small-swahili-3").to(device)
# NOTE(review): computed but not passed to generate() anywhere visible here —
# confirm whether language/task forcing was intended for transcription.
forced_decoder_ids = processor.get_decoder_prompt_ids(language="sw", task="transcribe")

# Summarisation model (fine-tuned mT5) and its tokenizer.
t5_tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
summary_model = AutoModelForSeq2SeqLM.from_pretrained("Jayem-11/mt5-summarize-sw")
@app.get("/")
async def read_root():
    """Handle ``GET /`` by returning a fixed ``"Successful"`` marker.

    The value returned is a one-element set.
    NOTE(review): FastAPI presumably serialises this as the JSON list
    ``["Successful"]``; confirm whether a mapping such as
    ``{"status": "Successful"}`` was intended.
    """
    status = "Successful"
    return {status}
def extract_and_resample_audio(file, target_sr=16000):
    """Extract the audio track from raw video bytes and resample it.

    The video bytes are written into a private temporary directory, so that
    workers sharing a working directory cannot overwrite each other's files
    and nothing is left behind afterwards. The audio track is exported to WAV
    with moviepy. The WAV is then decoded by librosa directly at
    ``target_sr``, in a single resampling step.

    Args:
        file: Raw bytes of the uploaded video.
        target_sr: Sample rate, in Hz, of the returned waveform. Defaults to
            16000, the rate the Whisper feature extractor is called with.

    Returns:
        The waveform as returned by ``librosa.load`` (mono by default),
        sampled at ``target_sr``.

    Raises:
        ValueError: If the video contains no audio track.
    """
    with tempfile.TemporaryDirectory() as workdir:
        video_path = os.path.join(workdir, "vid.mp4")
        audio_path = os.path.join(workdir, "temp_audio.wav")

        with open(video_path, "wb") as f:
            f.write(file)

        # Closing the clip releases its underlying readers once the audio is exported.
        with VideoFileClip(video_path) as video:
            audio = video.audio
            if audio is None:
                raise ValueError("Uploaded video has no audio track")
            audio.write_audiofile(audio_path)

        # Decode and resample in one pass instead of loading at librosa's
        # default rate and then resampling a second time.
        audio_resampled, _ = librosa.load(audio_path, sr=target_sr)

    print("Done resampling")
    return audio_resampled
@functools.lru_cache(maxsize=1)
def _get_summarizer():
    """Build the summarisation pipeline once and reuse it for every request."""
    return pipeline("summarization", model=summary_model, tokenizer=t5_tokenizer)


@app.post("/predict")
async def predict(file: UploadFile):
    """Transcribe the speech in an uploaded video and summarise the transcript.

    The handler runs four steps:
    1. Extract 16 kHz audio from the video.
    2. Run the Whisper ASR model on the audio.
    3. Decode the generated token ids to text.
    4. Summarise that text with the mT5 summarisation pipeline.

    Args:
        file: The uploaded video file.

    Returns:
        ``{"summary": ...}``, where the value is the summarisation pipeline's
        output for the single transcript.

    NOTE(review): model inference runs directly inside this ``async``
    handler, so it blocks the event loop for the whole request.
    Presumably the Whisper feature extractor only covers the first 30 s of
    audio, so longer videos may be truncated. Confirm both.
    """
    audio_resampled = extract_and_resample_audio(await file.read())

    # A batch of one transcription request, already as a torch tensor.
    input_feats = feature_extractor(
        audio_resampled, sampling_rate=16000, return_tensors="pt"
    ).input_features

    # Send the features to wherever the model actually lives; using the global
    # `device` breaks on a CUDA host when the model was loaded on the CPU.
    output = asr_model.generate(
        input_features=input_feats.to(asr_model.device),
        max_new_tokens=255,
    ).cpu().numpy()
    sample_text = tokenizer.batch_decode(output, skip_special_tokens=True)

    summary = _get_summarizer()(sample_text, max_length=215)
    return {"summary": summary}