|
import requests |
|
import base64 |
|
import os |
|
import json |
|
import streamlit as st |
|
import whisperx |
|
import torch |
|
from utils import convert_segments_object_to_text |
|
|
|
def get_completion_from_gigachat(prompt, max_tokens, access_token): |
|
url_completion = os.getenv('GIGA_COMPLETION_URL') |
|
|
|
data_copm = json.dumps({ |
|
"model": os.getenv('GIGA_MODEL'), |
|
"messages": [ |
|
{ |
|
"role": "user", |
|
"content": prompt |
|
} |
|
], |
|
"stream": False, |
|
"max_tokens": max_tokens, |
|
}) |
|
|
|
headers_comp = { |
|
'Content-Type': 'application/json', |
|
'Accept': 'application/json', |
|
'Authorization': 'Bearer ' + access_token |
|
} |
|
|
|
response = requests.post(url_completion, headers=headers_comp, data=data_copm, verify=False) |
|
response_data = response.json() |
|
answer_from_llm = response_data['choices'][0]['message']['content'] |
|
|
|
return answer_from_llm |
|
|
|
|
|
st.title('Audio Transcription App') |
|
st.sidebar.title("Settings") |
|
|
|
device = st.sidebar.selectbox("Device", ["cpu", "cuda"], index=1) |
|
batch_size = st.sidebar.number_input("Batch Size", min_value=1, value=16) |
|
compute_type = st.sidebar.selectbox("Compute Type", ["float16", "int8"], index=0) |
|
|
|
initial_giga_base_prompt = "Напиши резюме транскрибации звонка, текст которого приложен в ниже. Выдели самостоятельно цель встречи, потом описать ключевые моменты всей встречи. Потом выделить отдельные темы звонка и выделить ключевые моменты в них. Напиши итоги того, о чем договорились говорящие, если такое возможно выделить из текста.\nТранскрибация: " |
|
initial_giga_processing_prompt = "Обработай транкрибацию звонка. Убедись, что каждое слово назначено правильному спикеру. Если заметишь, что слово или фраза ошибочно приписаны другому спикеру, исправь это. Постарайся понять имена говорящих из контекста разговора и замени «SPEAKER_00», «SPEAKER_01» и т.д. на их реальные имена. Если чье-то имя понять невозможно, то не меняй его. Приложи в ответе обработанную транскрибацию\nТранскрибация: " |
|
|
|
giga_base_prompt = st.sidebar.text_area("Промпт ГигаЧата для резюмирования", value=initial_giga_base_prompt) |
|
giga_max_tokens = st.sidebar.number_input("Максимальное количество токенов при резюмировании", min_value=1, value=1024) |
|
|
|
enable_summarization = st.sidebar.checkbox("Добавить обработку транскрибации", value=False) |
|
giga_processing_prompt = st.sidebar.text_area("Промпт ГигаЧата для обработки транскрибации", value=initial_giga_processing_prompt) |
|
|
|
ACCESS_TOKEN = st.secrets["HF_TOKEN"] |
|
|
|
uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"]) |
|
|
|
if uploaded_file is not None: |
|
st.audio(uploaded_file) |
|
file_extension = uploaded_file.name.split(".")[-1] |
|
temp_file_path = f"temp_file.{file_extension}" |
|
|
|
with open(temp_file_path, "wb") as f: |
|
f.write(uploaded_file.getbuffer()) |
|
|
|
with st.spinner('Транскрибируем...'): |
|
|
|
model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type) |
|
|
|
audio = whisperx.load_audio(temp_file_path) |
|
result = model.transcribe(audio, batch_size=batch_size, language="ru") |
|
print('Transcribed, now aligning') |
|
|
|
model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device) |
|
result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False) |
|
print('Aligned, now diarizing') |
|
|
|
diarize_model = whisperx.DiarizationPipeline(use_auth_token=st.secrets["HF_TOKEN"], device=device) |
|
diarize_segments = diarize_model(audio) |
|
result_diar = whisperx.assign_word_speakers(diarize_segments, result) |
|
|
|
st.write("Результат транскрибации:") |
|
transcript = convert_segments_object_to_text(result_diar) |
|
st.text(transcript) |
|
|
|
with st.spinner('Обрабатываем транскрибацию...'): |
|
username = st.secrets["GIGA_USERNAME"] |
|
password = st.secrets["GIGA_SECRET"] |
|
|
|
|
|
auth_str = f'{username}:{password}' |
|
auth_bytes = auth_str.encode('utf-8') |
|
auth_base64 = base64.b64encode(auth_bytes).decode('utf-8') |
|
url = os.getenv('GIGA_AUTH_URL') |
|
|
|
headers = { |
|
'Authorization': f'Basic {auth_base64}', |
|
'RqUID': os.getenv('GIGA_rquid'), |
|
'Content-Type': 'application/x-www-form-urlencoded', |
|
'Accept': 'application/json' |
|
} |
|
|
|
data = { |
|
'scope': os.getenv('GIGA_SCOPE') |
|
} |
|
|
|
response = requests.post(url, headers=headers, data=data, verify=False) |
|
access_token = response.json()['access_token'] |
|
print('Got access token') |
|
|
|
transcribe_answer = get_completion_from_gigachat(giga_processing_prompt + transcript, 32768, access_token) |
|
|
|
st.write("Результат обработки:") |
|
st.text(transcribe_answer) |
|
|
|
|
|
|
|
with st.spinner('Резюмируем...'): |
|
summary_answer = get_completion_from_gigachat(giga_base_prompt + transcribe_answer, giga_max_tokens, access_token) |
|
|
|
st.write("Результат резюмирования:") |
|
st.text(summary_answer) |