Sabbah13 committed on
Commit
09b358f
1 Parent(s): fa7c624

added openai

Browse files
Files changed (1) hide show
  1. app.py +43 -24
app.py CHANGED
@@ -4,6 +4,7 @@ import whisperx
4
  import torch
5
  from utils import convert_segments_object_to_text, check_password
6
  from gigiachat_requests import get_access_token, get_completion_from_gigachat, get_number_of_tokens
 
7
 
8
  if check_password():
9
  st.title('Audio Transcription App')
@@ -13,54 +14,69 @@ if check_password():
13
  batch_size = int(os.getenv('BATCH_SIZE'))
14
  compute_type = os.getenv('COMPUTE_TYPE')
15
 
16
- initial_giga_base_prompt = os.getenv('GIGA_BASE_PROMPT')
17
- initial_giga_processing_prompt = os.getenv('GIGA_PROCCESS_PROMPT')
18
 
19
- giga_base_prompt = st.sidebar.text_area("Промпт для резюмирования", value=initial_giga_base_prompt)
20
- giga_max_tokens = st.sidebar.number_input("Максимальное количество токенов при резюмировании", min_value=1, value=1024)
 
21
 
22
  enable_summarization = st.sidebar.checkbox("Добавить обработку транскрибации", value=False)
23
- giga_processing_prompt = st.sidebar.text_area("Промпт для обработки транскрибации", value=initial_giga_processing_prompt)
24
 
25
  ACCESS_TOKEN = st.secrets["HF_TOKEN"]
26
 
27
  uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
28
 
29
  if uploaded_file is not None:
 
 
 
 
 
 
30
  st.audio(uploaded_file)
31
  file_extension = uploaded_file.name.split(".")[-1] # Получаем расширение файла
32
  temp_file_path = f"temp_file.{file_extension}" # Создаем временное имя файла с правильным расширением
33
 
34
  with open(temp_file_path, "wb") as f:
35
  f.write(uploaded_file.getbuffer())
 
 
36
 
37
- with st.spinner('Транскрибируем...'):
38
- # Load model
39
- model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
40
- # Load and transcribe audio
41
- audio = whisperx.load_audio(temp_file_path)
42
- result = model.transcribe(audio, batch_size=batch_size, language="ru")
43
- print('Transcribed, now aligning')
44
 
45
- model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
46
- result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
47
- print('Aligned, now diarizing')
48
 
49
- diarize_model = whisperx.DiarizationPipeline(use_auth_token=st.secrets["HF_TOKEN"], device=device)
50
- diarize_segments = diarize_model(audio)
51
- result_diar = whisperx.assign_word_speakers(diarize_segments, result)
52
 
53
- st.write("Результат транскрибации:")
54
- transcript = convert_segments_object_to_text(result_diar)
 
 
 
55
  st.text(transcript)
56
 
57
  access_token = get_access_token()
58
 
59
  if (enable_summarization):
60
  with st.spinner('Обрабатываем транскрибацию...'):
61
- number_of_tokens = get_number_of_tokens(transcript, access_token)
62
- print('Количество токенов в транскрибации: ' + str(number_of_tokens))
63
- transcript = get_completion_from_gigachat(giga_processing_prompt + transcript, number_of_tokens + 500, access_token)
 
 
 
64
 
65
  st.write("Результат обработки:")
66
  st.text(transcript)
@@ -68,7 +84,10 @@ if check_password():
68
 
69
 
70
  with st.spinner('Резюмируем...'):
71
- summary_answer = get_completion_from_gigachat(giga_base_prompt + transcript, giga_max_tokens, access_token)
 
 
 
72
 
73
  st.write("Результат резюмирования:")
74
  st.text(summary_answer)
 
4
  import torch
5
  from utils import convert_segments_object_to_text, check_password
6
  from gigiachat_requests import get_access_token, get_completion_from_gigachat, get_number_of_tokens
7
+ from openai_requests import get_completion_from_openai
8
 
9
  if check_password():
10
  st.title('Audio Transcription App')
 
14
  batch_size = int(os.getenv('BATCH_SIZE'))
15
  compute_type = os.getenv('COMPUTE_TYPE')
16
 
17
+ initial_base_prompt = os.getenv('BASE_PROMPT')
18
+ initial_processing_prompt = os.getenv('PROCCESS_PROMPT')
19
 
20
+ llm = st.sidebar.selectbox("LLM", ["GigaChat", "Chat GPT"], index=0)
21
+ base_prompt = st.sidebar.text_area("Промпт для резюмирования", value=initial_base_prompt)
22
+ max_tokens_summary = st.sidebar.number_input("Максимальное количество токенов при резюмировании", min_value=1, value=1024)
23
 
24
  enable_summarization = st.sidebar.checkbox("Добавить обработку транскрибации", value=False)
25
+ processing_prompt = st.sidebar.text_area("Промпт для обработки транскрибации", value=initial_processing_prompt)
26
 
27
  ACCESS_TOKEN = st.secrets["HF_TOKEN"]
28
 
29
  uploaded_file = st.file_uploader("Загрузите аудиофайл", type=["mp4", "wav", "m4a"])
30
 
31
  if uploaded_file is not None:
32
+ file_name = uploaded_file.name
33
+
34
+ if 'file_name' not in st.session_state or st.session_state.file_name != file_name:
35
+ st.session_state.transcript = ''
36
+ st.session_state.file_name = file_name
37
+
38
  st.audio(uploaded_file)
39
  file_extension = uploaded_file.name.split(".")[-1] # Получаем расширение файла
40
  temp_file_path = f"temp_file.{file_extension}" # Создаем временное имя файла с правильным расширением
41
 
42
  with open(temp_file_path, "wb") as f:
43
  f.write(uploaded_file.getbuffer())
44
+
45
+ if 'transcript' not in st.session_state or st.session_state.transcript == '':
46
 
47
+ with st.spinner('Транскрибируем...'):
48
+ # Load model
49
+ model = whisperx.load_model(os.getenv('WHISPER_MODEL_SIZE'), device, compute_type=compute_type)
50
+ # Load and transcribe audio
51
+ audio = whisperx.load_audio(temp_file_path)
52
+ result = model.transcribe(audio, batch_size=batch_size, language="ru")
53
+ print('Transcribed, now aligning')
54
 
55
+ model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
56
+ result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
57
+ print('Aligned, now diarizing')
58
 
59
+ diarize_model = whisperx.DiarizationPipeline(use_auth_token=st.secrets["HF_TOKEN"], device=device)
60
+ diarize_segments = diarize_model(audio)
61
+ result_diar = whisperx.assign_word_speakers(diarize_segments, result)
62
 
63
+ st.write("Результат транскрибации:")
64
+ transcript = convert_segments_object_to_text(result_diar)
65
+ else:
66
+ transcript = st.session_state.transcript
67
+
68
  st.text(transcript)
69
 
70
  access_token = get_access_token()
71
 
72
  if (enable_summarization):
73
  with st.spinner('Обрабатываем транскрибацию...'):
74
+ if (llm == 'GigaChat'):
75
+ number_of_tokens = get_number_of_tokens(transcript, access_token)
76
+ print('Количество токенов в транскрибации: ' + str(number_of_tokens))
77
+ transcript = get_completion_from_gigachat(processing_prompt + transcript, number_of_tokens + 500, access_token)
78
+ elif (llm == 'Chat GPT'):
79
+ transcript = get_completion_from_openai(processing_prompt + transcript)
80
 
81
  st.write("Результат обработки:")
82
  st.text(transcript)
 
84
 
85
 
86
  with st.spinner('Резюмируем...'):
87
+ if (llm == 'GigaChat'):
88
+ summary_answer = get_completion_from_gigachat(base_prompt + transcript, max_tokens_summary, access_token)
89
+ elif (llm == 'Chat GPT'):
90
+ summary_answer = get_completion_from_openai(base_prompt + transcript)
91
 
92
  st.write("Результат резюмирования:")
93
  st.text(summary_answer)