|
import base64
import json
import os
import tempfile

import requests
import streamlit as st
import torch
import whisperx

from utils import convert_segments_object_to_text
|
|
|
# Page header and sidebar controls. The values chosen here (device,
# batch_size, compute_type) are passed to whisperx further down the script.
st.title('Audio Transcription App')
st.sidebar.title("Settings")

# Pre-select the GPU only when torch reports a usable CUDA device; a fixed
# "cuda" default makes model loading fail out of the box on CPU-only hosts.
cuda_available = torch.cuda.is_available()
device = st.sidebar.selectbox(
    "Device",
    ["cpu", "cuda"],
    index=1 if cuda_available else 0,
)
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16)

# Keep float16 as the default for GPU runs, but default to int8 when the CPU
# is selected. NOTE(review): float16 is generally not accepted by the CPU
# backend whisperx uses -- confirm against the installed version.
compute_type = st.sidebar.selectbox(
    "Compute Type",
    ["float16", "int8"],
    index=0 if device == "cuda" else 1,
)

# Hugging Face token read from Streamlit secrets.
ACCESS_TOKEN = st.secrets["HF_TOKEN"]

# Upload widget; returns None until the user attaches a file.
uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
|
|
|
if uploaded_file is not None:
    # Main pipeline: play back the upload, transcribe + align + diarize it with
    # whisperx, show the transcript, then ask the LLM endpoint for a summary.
    #
    # NOTE(review): Streamlit re-executes the script on every widget
    # interaction, so this whole pipeline reruns while a file is attached.
    # Consider caching the models/results if that becomes a problem.
    st.audio(uploaded_file)

    # whisperx.load_audio is given a filesystem path, so the upload has to be
    # written to disk first. A unique temporary file (rather than a fixed
    # "temp_file.<ext>" in the working directory) keeps concurrent sessions
    # from overwriting each other's audio.
    file_suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(suffix=file_suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_file_path = tmp.name

    try:
        with st.spinner('Транскрибируем...'):
            model = whisperx.load_model(
                os.getenv('WHISPER_MODEL_SIZE'),
                device,
                compute_type=compute_type,
            )

            audio = whisperx.load_audio(temp_file_path)
            result = model.transcribe(audio, batch_size=batch_size, language="ru")
            print('Transcribed, now aligning')

            # Align the transcribed segments against the audio.
            model_a, metadata = whisperx.load_align_model(
                language_code=result["language"],
                device=device,
            )
            result = whisperx.align(
                result["segments"],
                model_a,
                metadata,
                audio,
                device,
                return_char_alignments=False,
            )
            print('Aligned, now diarizing')

            # Diarize and attach speaker labels to the aligned result.
            diarize_model = whisperx.DiarizationPipeline(
                use_auth_token=st.secrets["HF_TOKEN"],
                device=device,
            )
            diarize_segments = diarize_model(audio)
            result_diar = whisperx.assign_word_speakers(diarize_segments, result)
    finally:
        # Remove the temporary copy whether or not the pipeline succeeded, so
        # uploads do not accumulate on disk.
        os.remove(temp_file_path)

    st.write("Результат транскрибации:")
    transcript = convert_segments_object_to_text(result_diar)
    st.text(transcript)

    with st.spinner('Резюмируем...'):
        # (connect, read) timeouts in seconds, so a stalled endpoint cannot
        # hang the session forever.
        request_timeout = (10, 300)

        # Indexing os.environ (instead of os.getenv) makes a missing variable
        # fail with a KeyError that names it, rather than a TypeError on None.
        base_prompt = os.environ['GIGA_BASE_PROMPT']
        max_tokens = int(os.environ['GIGA_MAX_TOKENS'])

        # --- Step 1: exchange the client credentials for an access token. ---
        username = st.secrets["GIGA_USERNAME"]
        password = st.secrets["GIGA_SECRET"]
        credentials = f'{username}:{password}'.encode('utf-8')
        auth_base64 = base64.b64encode(credentials).decode('utf-8')

        auth_headers = {
            'Authorization': f'Basic {auth_base64}',
            'RqUID': os.getenv('GIGA_rquid'),
            'Content-Type': 'application/x-www-form-urlencoded',
            'Accept': 'application/json',
        }
        auth_form = {'scope': os.getenv('GIGA_SCOPE')}

        # SECURITY NOTE(review): verify=False disables TLS certificate checks
        # while credentials are being sent. It is kept to preserve current
        # behaviour; prefer passing the provider's CA bundle path via `verify`.
        response = requests.post(
            os.getenv('GIGA_AUTH_URL'),
            headers=auth_headers,
            data=auth_form,
            verify=False,
            timeout=request_timeout,
        )
        # Surface HTTP errors directly instead of a KeyError on the body below.
        response.raise_for_status()
        access_token = response.json()['access_token']
        print('Got access token')

        # --- Step 2: send the transcript to the chat-completion endpoint. ---
        completion_payload = json.dumps({
            "model": os.getenv('GIGA_MODEL'),
            "messages": [
                {
                    "role": "user",
                    "content": base_prompt + transcript,
                },
            ],
            "stream": False,
            "max_tokens": max_tokens,
        })
        completion_headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': 'Bearer ' + access_token,
        }

        # SECURITY NOTE(review): same verify=False caveat as the auth request.
        response = requests.post(
            os.getenv('GIGA_COMPLETION_URL'),
            headers=completion_headers,
            data=completion_payload,
            verify=False,
            timeout=request_timeout,
        )
        response.raise_for_status()
        response_data = response.json()
        answer_from_llm = response_data['choices'][0]['message']['content']

    st.write("Результат резюмирования:")
    st.text(answer_from_llm)