Hervé BREDIN committed on
Commit 57604a5
1 Parent(s): 98df54b

feat: update to latest pyannote and wavesurfer (#3)

Files changed (3)
  1. app.py +47 -46
  2. assets/template.html +39 -44
  3. requirements.txt +1 -7
app.py CHANGED
@@ -23,10 +23,10 @@
 
 import io
 import base64
+import torch
 import numpy as np
 import scipy.io.wavfile
 from typing import Text
-from huggingface_hub import HfApi
 import streamlit as st
 from pyannote.audio import Pipeline
 from pyannote.audio import Audio
@@ -49,32 +49,47 @@ def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
 PYANNOTE_LOGO = "https://avatars.githubusercontent.com/u/7559051?s=400&v=4"
 EXCERPT = 30.0
 
-st.set_page_config(
-    page_title="pyannote.audio pretrained pipelines", page_icon=PYANNOTE_LOGO
-)
+st.set_page_config(page_title="pyannote pretrained pipelines", page_icon=PYANNOTE_LOGO)
 
-st.sidebar.image(PYANNOTE_LOGO)
+col1, col2 = st.columns([0.2, 0.8], gap="small")
 
-st.markdown("""# 🎹 Pretrained pipelines
-""")
+with col1:
+    st.image(PYANNOTE_LOGO)
+
+with col2:
+    st.markdown(
+        """
+# pretrained pipelines
+Make the most of [pyannote](https://github.com/pyannote) thanks to our [consulting services](https://herve.niderb.fr/consulting.html)
+"""
+    )
 
 PIPELINES = [
-    p.modelId
-    for p in HfApi().list_models(filter="pyannote-audio-pipeline")
-    if p.modelId.startswith("pyannote/")
+    "pyannote/speaker-diarization-3.0",
 ]
 
 audio = Audio(sample_rate=16000, mono=True)
 
-selected_pipeline = st.selectbox("Select a pipeline", PIPELINES, index=0)
+selected_pipeline = st.selectbox("Select a pretrained pipeline", PIPELINES, index=0)
+
 
 with st.spinner("Loading pipeline..."):
-    pipeline = Pipeline.from_pretrained(selected_pipeline, use_auth_token=st.secrets["PYANNOTE_TOKEN"])
+    try:
+        use_auth_token = st.secrets["PYANNOTE_TOKEN"]
+    except FileNotFoundError:
+        use_auth_token = None
+    except KeyError:
+        use_auth_token = None
+
+    pipeline = Pipeline.from_pretrained(
+        selected_pipeline, use_auth_token=use_auth_token
+    )
+    if torch.cuda.is_available():
+        pipeline.to(torch.device("cuda"))
 
-uploaded_file = st.file_uploader("Choose an audio file")
+uploaded_file = st.file_uploader("Upload an audio file")
 if uploaded_file is not None:
-
     try:
         duration = audio.get_duration(uploaded_file)
     except RuntimeError as e:
@@ -86,12 +101,12 @@ if uploaded_file is not None:
     uri = "".join(uploaded_file.name.split())
     file = {"waveform": waveform, "sample_rate": sample_rate, "uri": uri}
 
-    with st.spinner(f"Processing first {EXCERPT:g} seconds..."):
+    with st.spinner(f"Processing {EXCERPT:g} seconds..."):
        output = pipeline(file)
 
-    with open('assets/template.html') as html, open('assets/style.css') as css:
+    with open("assets/template.html") as html, open("assets/style.css") as css:
        html_template = html.read()
-        st.markdown('<style>{}</style>'.format(css.read()), unsafe_allow_html=True)
+        st.markdown("<style>{}</style>".format(css.read()), unsafe_allow_html=True)
 
     colors = [
         "#ffd70033",
@@ -105,50 +120,36 @@
     ]
     num_colors = len(colors)
 
-    label2color = {label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))}
+    label2color = {
+        label: colors[k % num_colors] for k, label in enumerate(sorted(output.labels()))
+    }
 
     BASE64 = to_base64(waveform.numpy().T)
 
     REGIONS = ""
-    LEGENDS = ""
-    labels=[]
     for segment, _, label in output.itertracks(yield_label=True):
-        REGIONS += f"var re = wavesurfer.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});"
-        if not label in labels:
-            LEGENDS += f"<li><span style='background-color:{label2color[label]}'></span>{label}</li>"
-            labels.append(label)
+        REGIONS += f"regions.addRegion({{start: {segment.start:g}, end: {segment.end:g}, color: '{label2color[label]}', resize : false, drag : false}});"
 
     html = html_template.replace("BASE64", BASE64).replace("REGIONS", REGIONS)
     components.html(html, height=250, scrolling=True)
-    st.markdown("<div style='overflow : auto'><ul class='legend'>"+LEGENDS+"</ul></div>", unsafe_allow_html=True)
-
-    st.markdown("---")
 
     with io.StringIO() as fp:
         output.write_rttm(fp)
         content = fp.getvalue()
-
     b64 = base64.b64encode(content.encode()).decode()
-    href = f'Download as <a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">RTTM</a> or run it on the whole {int(duration):d}s file:'
+    href = f'<a download="{output.uri}.rttm" href="data:file/text;base64,{b64}">Download</a> result in RTTM file format or run it locally:'
     st.markdown(href, unsafe_allow_html=True)
 
     code = f"""
-from pyannote.audio import Pipeline
-pipeline = Pipeline.from_pretrained("{selected_pipeline}")
-output = pipeline("{uploaded_file.name}")
-"""
-    st.code(code, language='python')
-
-
-
-    st.sidebar.markdown(
-        """
--------------------
-
-To use these pipelines on more and longer files on your own (GPU, hence much faster) servers, check the [documentation](https://github.com/pyannote/pyannote-audio).
-
-For [technical questions](https://github.com/pyannote/pyannote-audio/discussions) and [bug reports](https://github.com/pyannote/pyannote-audio/issues), please check [pyannote.audio](https://github.com/pyannote/pyannote-audio) Github repository.
-
-For commercial enquiries and scientific consulting, please contact [me](mailto:[email protected]).
-"""
-    )
+# load pretrained pipeline
+from pyannote.audio import Pipeline
+pipeline = Pipeline.from_pretrained("{selected_pipeline}",
+                                    use_auth_token=HUGGINGFACE_TOKEN)
+
+# (optional) send pipeline to GPU
+import torch
+pipeline.to(torch.device("cuda"))
+
+# process audio file
+output = pipeline("audio.wav")"""
+    st.code(code, language="python")
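
Note: the `to_base64` helper is only visible in the hunk header above (`def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text`) and its body is untouched by this commit. As context, here is a minimal sketch of what such a helper might look like, assuming it serializes the 30-second excerpt as a WAV data URI that the template's `wavesurfer.load('BASE64')` call can consume; the data-URI prefix and dtype handling are assumptions, not part of the commit.

# Hypothetical sketch of the to_base64 helper referenced in the hunk header above.
import base64
import io
from typing import Text

import numpy as np
import scipy.io.wavfile


def to_base64(waveform: np.ndarray, sample_rate: int = 16000) -> Text:
    # waveform is expected as (num_samples, num_channels), e.g. waveform.numpy().T
    with io.BytesIO() as buffer:
        # write the excerpt as an in-memory WAV file
        scipy.io.wavfile.write(buffer, sample_rate, waveform.astype(np.float32))
        data = base64.b64encode(buffer.getvalue()).decode("utf-8")
    # assumed data-URI prefix; any audio MIME type wavesurfer can decode would do
    return f"data:audio/x-wav;base64,{data}"
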
assets/template.html CHANGED
@@ -1,46 +1,41 @@
-<script src="https://unpkg.com/wavesurfer.js"></script>
-<script src="https://unpkg.com/wavesurfer.js/dist/plugin/wavesurfer.regions.min.js"></script>
-<script src="https://unpkg.com/wavesurfer.js/dist/plugin/wavesurfer.timeline.min.js"></script>
-<br>
-<div id="waveform"></div>
-<div id="timeline"></div>
-<br>
-<div><button onclick="play()" id="ppb">Play</button><div>
-<script type="text/javascript">
-  var labels=[];
-  var wavesurfer = WaveSurfer.create({
-    container: '#waveform',
-    barGap: 2,
-    barHeight: 3,
-    barWidth: 3,
-    barRadius: 2,
-    plugins: [
-      WaveSurfer.regions.create({}),
-      WaveSurfer.timeline.create({
-        container: "#timeline",
-        notchPercentHeight: 40,
-        primaryColor: "#444",
-        primaryFontColor: "#444"
-      })
-    ]
-  });
-  wavesurfer.load('BASE64');
-  wavesurfer.on('ready', function () {
-    wavesurfer.play();
-  });
-  wavesurfer.on('play',function() {
-    document.getElementById('ppb').innerHTML = "Pause";
-  });
-  wavesurfer.on('pause',function() {
-    document.getElementById('ppb').innerHTML = "Play";
-  });
+<script type="module">
+  import WaveSurfer from 'https://unpkg.com/wavesurfer.js@7/dist/wavesurfer.esm.js'
+  import RegionsPlugin from 'https://unpkg.com/wavesurfer.js@7/dist/plugins/regions.esm.js'
+
+  var labels=[];
+  const wavesurfer = WaveSurfer.create({
+    container: '#waveform',
+    barGap: 2,
+    barHeight: 3,
+    barWidth: 3,
+    barRadius: 2,
+  });
+
+  const regions = wavesurfer.registerPlugin(RegionsPlugin.create())
+
+  wavesurfer.load('BASE64');
+  wavesurfer.on('ready', function () {
+    wavesurfer.play();
+  });
+
+  wavesurfer.on('decode', function () {
+
   REGIONS
-  document.addEventListener('keyup', event => {
-    if (event.code === 'Space') {
-      play();
-    }
-  })
-  function play(){
-    wavesurfer.isPlaying() ? wavesurfer.pause() : wavesurfer.play();
-  }
+
+    wavesurfer.play();
+
+  });
+
+  wavesurfer.on('click', () => {
+    play();
+  });
+
+  function play(){
+    wavesurfer.isPlaying() ? wavesurfer.pause() : wavesurfer.play();
+  }
+
 </script>
+<div id="waveform"></div>
requirements.txt CHANGED
@@ -1,7 +1 @@
-torch==1.11.0
-torchvision==0.12.0
-torchaudio==0.11.0
-torchtext==0.12.0
-speechbrain==0.5.12
-pyannote-audio>=2.1
-
+pyannote-audio==3.0.1
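
With the requirement pinned to pyannote-audio 3.0.1, the same pipeline can be run outside the Space much like the snippet the app now prints. A minimal offline sketch, assuming a local "audio.wav" and your own Hugging Face access token in place of the HUGGINGFACE_TOKEN placeholder, and saving the result in the RTTM format the app offers for download:

# Offline counterpart of the snippet the updated app displays.
import torch
from pyannote.audio import Pipeline

# HUGGINGFACE_TOKEN is a placeholder for your own access token
pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.0", use_auth_token="HUGGINGFACE_TOKEN"
)

# (optional) send pipeline to GPU, as the app does when CUDA is available
if torch.cuda.is_available():
    pipeline.to(torch.device("cuda"))

# process the whole file and save the result as RTTM, mirroring the app's download link
output = pipeline("audio.wav")
with open("audio.rttm", "w") as rttm:
    output.write_rttm(rttm)
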