from fastapi import FastAPI, UploadFile
from moviepy.editor import VideoFileClip
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from transformers import WhisperFeatureExtractor, WhisperTokenizer
import librosa
import numpy as np
import torch
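# Note: FastAPI needs the python-multipart package to parse file uploads.
# A sketch of the dependency set (inferred from the imports above, not a
# pinned requirements file):
#   pip install fastapi uvicorn moviepy transformers librosa torch numpy python-multipart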



app = FastAPI()

# Run inference on the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Whisper components: the feature extractor turns raw audio into log-mel
# features, the tokenizer decodes model output, and the processor supplies
# the Swahili transcription prompt for the fine-tuned ASR model
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Swahili", task="transcribe")
processor = WhisperProcessor.from_pretrained("Jayem-11/whisper-small-swahili-3")
asr_model = WhisperForConditionalGeneration.from_pretrained("Jayem-11/whisper-small-swahili-3").to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="sw", task="transcribe")

# Swahili summarization model; build the pipeline once at startup rather
# than on every request
t5_tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
summary_model = AutoModelForSeq2SeqLM.from_pretrained("Jayem-11/mt5-summarize-sw")
summarizer = pipeline("summarization", model=summary_model, tokenizer=t5_tokenizer)


@app.get("/")
async def read_root():
    # Simple health-check endpoint
    return {"status": "Successful"}



def extract_and_resample_audio(file):
    # Persist the uploaded bytes so MoviePy can open them as a video file
    with open("vid.mp4", "wb") as f:
        f.write(file)

    video = VideoFileClip("vid.mp4")

    # Extract the audio track and save it to a temporary file
    audio = video.audio
    audio.write_audiofile("temp_audio.wav")

    # Load the temporary audio file and resample it to the 16 kHz rate
    # that Whisper expects
    audio_data, sr = librosa.load("temp_audio.wav")
    audio_resampled = librosa.resample(audio_data, orig_sr=sr, target_sr=16000)
    print("Done resampling")

    return audio_resampled

@app.post("/predict")
async def predict(file: UploadFile):
    # Extract 16 kHz audio from the uploaded video
    audio_resampled = extract_and_resample_audio(await file.read())

    # Convert the waveform to log-mel features and add a batch dimension
    input_feats = feature_extractor(audio_resampled, sampling_rate=16000).input_features[0]
    input_feats = torch.from_numpy(np.expand_dims(input_feats, axis=0))

    # Transcribe with the Swahili decoder prompt, then decode token ids to text
    output = asr_model.generate(
        input_features=input_feats.to(device),
        forced_decoder_ids=forced_decoder_ids,
        max_new_tokens=255,
    ).cpu().numpy()
    sample_text = tokenizer.batch_decode(output, skip_special_tokens=True)

    # Summarize the transcript
    summary = summarizer(sample_text, max_length=215)

    return {"summary": summary}
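
# Example usage (a sketch; assumes this file is saved as main.py and that
# uvicorn is installed):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST "http://localhost:8000/predict" -F "file=@sample_video.mp4"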