Woziii commited on
Commit
03dd785
1 Parent(s): db99484

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -58
app.py CHANGED
@@ -1,63 +1,177 @@
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  ],
 
 
 
59
  )
60
 
61
-
62
- if __name__ == "__main__":
63
- demo.launch()
 
1
+ import spaces
2
  import gradio as gr
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
4
+ import torch
5
+ from threading import Thread
6
+ from typing import Iterator
7
+ from datasets import load_dataset
8
+ import soundfile as sf
9
+ import numpy as np
10
+
11
+ # Constantes
12
+ DEFAULT_MAX_NEW_TOKENS = 100
13
+ MAX_INPUT_TOKEN_LENGTH = 4096
14
+ MODEL_NAME = "openai/whisper-small"
15
+ FILE_LIMIT_MB = 1000
16
+ YT_LENGTH_LIMIT_S = 3600
17
+
18
+ # Chargement des modèles
19
+ device = 0 if torch.cuda.is_available() else "cpu"
20
+ stt_pipeline = pipeline(
21
+ task="automatic-speech-recognition",
22
+ model=MODEL_NAME,
23
+ device=device,
24
+ model_kwargs={"low_cpu_mem_usage": True},
25
+ )
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
28
+ lm_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True, device_map="cpu")
29
+
30
+ tts_pipeline = pipeline("text-to-speech", "microsoft/speecht5_tts")
31
+ embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
32
+ speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
33
+
34
+ # System prompt
35
+ SYSTEM_PROMPT = """Tu es un assistant IA français nommé Lucas, conçu pour aider les utilisateurs de manière amicale et professionnelle. Tu dois toujours répondre en français. Voici quelques règles à suivre :
36
+
37
+ 1. Sois poli, respectueux et bienveillant dans toutes tes interactions.
38
+ 2. Fournis des informations précises et à jour, en citant des sources fiables si nécessaire.
39
+ 3. Si tu ne connais pas la réponse à une question, admets-le honnêtement.
40
+ 4. Adapte ton langage en fonction du contexte de la conversation.
41
+ 5. Respecte la vie privée des utilisateurs et ne demande pas d'informations personnelles.
42
+ 6. Encourage la réflexion critique et l'apprentissage.
43
+ 7. Évite tout contenu inapproprié, offensant ou discriminatoire.
44
+
45
+ Ton objectif est d'assister l'utilisateur de la meilleure façon possible tout en respectant ces principes."""
46
+
47
+ # Variables globales
48
+ is_first_interaction = True
49
+
50
+ # Fonctions utilitaires
51
+ def transcribe_audio(audio):
52
+ return stt_pipeline(audio, generate_kwargs={"language": "french"})["text"]
53
+
54
+ def text_to_speech(text):
55
+ speech = tts_pipeline(text, forward_params={"speaker_embeddings": speaker_embedding})
56
+ return (speech["audio"], speech["sampling_rate"])
57
+
58
+ def determine_response_type(message):
59
+ if len(message.split()) < 10:
60
+ return "short"
61
+ elif len(message.split()) > 30:
62
+ return "long"
63
+ else:
64
+ return "medium"
65
+
66
+ def post_process_response(response, is_short):
67
+ # Implémentez ici la logique de post-traitement si nécessaire
68
+ return response
69
+
70
+ def check_coherence(response):
71
+ # Implémentez ici la vérification de cohérence si nécessaire
72
+ return True
73
+
74
+ def early_stopping(text):
75
+ # Implémentez ici la logique d'arrêt anticipé si nécessaire
76
+ return text
77
+
78
+ # Fonction principale
79
+ @spaces.GPU(duration=120)
80
+ def speech_to_speech_pipeline(
81
+ audio_input,
82
+ chat_history: list[tuple[str, str]],
83
+ max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
84
+ temperature: float = 0.7,
85
+ top_p: float = 0.95,
86
+ ) -> Iterator[str]:
87
+ global is_first_interaction
88
+
89
+ # Transcription de l'audio en texte
90
+ message = transcribe_audio(audio_input)
91
+
92
+ if is_first_interaction:
93
+ warning_message = """⚠️ Attention : Je suis un modèle en version alpha (V.0.0.5) et je peux générer des réponses incohérentes ou inexactes. Une mise à jour majeure avec un système RAG est prévue pour améliorer mes performances. Merci de votre compréhension ! 😊"""
94
+ yield warning_message
95
+ is_first_interaction = False
96
+
97
+ response_type = determine_response_type(message)
98
+ if response_type == "short":
99
+ max_new_tokens = max(70, max_new_tokens)
100
+ elif response_type == "long":
101
+ max_new_tokens = min(max(120, max_new_tokens), 200)
102
+ else: # medium
103
+ max_new_tokens = min(max(70, max_new_tokens), 120)
104
+
105
+ conversation = []
106
+ conversation.append({"role": "system", "content": SYSTEM_PROMPT})
107
+
108
+ for user, assistant in chat_history[-3:]:
109
+ conversation.append({"role": "user", "content": user})
110
+ if assistant:
111
+ conversation.append({"role": "assistant", "content": assistant})
112
+
113
+ conversation.append({"role": "user", "content": message})
114
+
115
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
116
+ attention_mask = torch.ones_like(input_ids)
117
+
118
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
119
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
120
+ attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
121
+ gr.Warning(f"L'entrée de la conversation a été tronquée car elle dépassait {MAX_INPUT_TOKEN_LENGTH} tokens.")
122
+
123
+ input_ids = input_ids.to(lm_model.device)
124
+ attention_mask = attention_mask.to(lm_model.device)
125
+
126
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
127
+ generate_kwargs = dict(
128
+ input_ids=input_ids,
129
+ attention_mask=attention_mask,
130
+ streamer=streamer,
131
+ max_new_tokens=max_new_tokens,
132
+ do_sample=True,
133
  top_p=top_p,
134
+ temperature=temperature,
135
+ num_beams=1,
136
+ eos_token_id=tokenizer.eos_token_id,
137
+ pad_token_id=tokenizer.pad_token_id,
138
+ )
139
+
140
+ t = Thread(target=lm_model.generate, kwargs=generate_kwargs)
141
+ t.start()
142
+
143
+ outputs = []
144
+ for text in streamer:
145
+ outputs.append(text)
146
+ partial_output = early_stopping("".join(outputs))
147
+ processed_output = post_process_response(partial_output, response_type == "short")
148
+ if not check_coherence(processed_output):
149
+ yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
150
+ return
151
+ yield processed_output
152
+
153
+ final_output = early_stopping("".join(outputs))
154
+ final_processed_output = post_process_response(final_output, response_type == "short")
155
+
156
+ if check_coherence(final_processed_output):
157
+ audio_output, sample_rate = text_to_speech(final_processed_output)
158
+ yield (sample_rate, audio_output)
159
+ else:
160
+ yield "Je m'excuse, ma réponse finale manquait de cohérence. Pouvez-vous reformuler votre question ?"
161
+
162
+ # Interface Gradio
163
+ iface = gr.Interface(
164
+ fn=speech_to_speech_pipeline,
165
+ inputs=[
166
+ gr.Audio(source="microphone", type="numpy"),
167
+ gr.State([]), # pour chat_history
168
+ gr.Slider(minimum=1, maximum=500, value=DEFAULT_MAX_NEW_TOKENS, label="Max New Tokens"),
169
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
170
+ gr.Slider(minimum=0.05, maximum=1.0, value=0.95, label="Top P"),
171
  ],
172
+ outputs="audio",
173
+ title="Assistant IA Lucas - Speech-to-Speech en Français",
174
+ description="Parlez dans le microphone et obtenez une réponse audio générée par l'IA en français."
175
  )
176
 
177
+ iface.launch()