Spaces:
Runtime error
Runtime error
Joshua Lochner
commited on
Commit
•
3af0cd0
1
Parent(s):
b3b69aa
Use `itertools.islice` instead of custom slicing
Browse files- src/preprocess.py +21 -36
src/preprocess.py
CHANGED
@@ -302,9 +302,9 @@ class PreprocessArguments:
|
|
302 |
num_jobs: int = field(
|
303 |
default=4, metadata={'help': 'Number of transcripts to download in parallel'})
|
304 |
|
305 |
-
|
306 |
-
|
307 |
-
)
|
308 |
|
309 |
do_generate: bool = field(
|
310 |
default=False, metadata={'help': 'Generate labelled data.'}
|
@@ -325,8 +325,8 @@ class PreprocessArguments:
|
|
325 |
valid_split: float = field(
|
326 |
default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
|
327 |
|
328 |
-
|
329 |
-
'help': '
|
330 |
|
331 |
max_videos: int = field(default=None, metadata={
|
332 |
'help': 'Maximum number of videos to preprocess.'})
|
@@ -588,7 +588,8 @@ def main():
|
|
588 |
progress.set_description(f'Processing {task.args[0]}')
|
589 |
progress.update()
|
590 |
|
591 |
-
InterruptibleTaskPool(
|
|
|
592 |
|
593 |
final_path = os.path.join(
|
594 |
processed_args.processed_dir, processed_args.processed_file)
|
@@ -675,36 +676,20 @@ def main():
|
|
675 |
|
676 |
# TODO
|
677 |
# count_videos = 0
|
678 |
-
# count_segments = 0
|
679 |
-
|
680 |
-
write_mode = 'w' if preprocess_args.overwrite else 'a'
|
681 |
-
|
682 |
-
get_all = preprocess_args.max_videos is None
|
683 |
-
|
684 |
-
total = len(final_data) if get_all else preprocess_args.max_videos
|
685 |
|
686 |
-
index = 0
|
687 |
data = final_data.items()
|
688 |
-
if preprocess_args.skip_videos is not None:
|
689 |
-
print('Skipping first', preprocess_args.skip_videos, 'videos')
|
690 |
-
data = itertools.islice(data, preprocess_args.skip_videos, None)
|
691 |
-
index = preprocess_args.skip_videos
|
692 |
|
693 |
-
|
694 |
-
|
695 |
-
else:
|
696 |
-
total = min(len(final_data) -
|
697 |
-
preprocess_args.skip_videos, total)
|
698 |
|
699 |
-
|
700 |
-
open(negative_file, write_mode, encoding='utf-8') as negative, \
|
701 |
-
tqdm(total=total) as progress:
|
702 |
|
703 |
-
|
704 |
-
|
|
|
705 |
|
706 |
-
|
707 |
-
break
|
708 |
|
709 |
progress.set_description(f'Processing {video_id}')
|
710 |
progress.update()
|
@@ -735,22 +720,22 @@ def main():
|
|
735 |
if wps < preprocess_args.min_wps:
|
736 |
continue
|
737 |
|
738 |
-
segment_text = ' '.join((x['text'] for x in seg))
|
739 |
-
extracted_segments = extract_sponsors(seg)
|
740 |
d = {
|
741 |
-
'video_index':
|
742 |
'video_id': video_id,
|
743 |
-
'text': clean_text(
|
744 |
'words_per_second': round(wps, 3),
|
745 |
}
|
746 |
|
|
|
747 |
if extracted_segments:
|
748 |
extracted_texts = []
|
749 |
for s in extracted_segments:
|
750 |
-
w = ' '.join(
|
751 |
category = s['category'].upper()
|
752 |
extracted_texts.append(
|
753 |
-
f
|
|
|
754 |
|
755 |
extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
|
756 |
extracted_texts)
|
|
|
302 |
num_jobs: int = field(
|
303 |
default=4, metadata={'help': 'Number of transcripts to download in parallel'})
|
304 |
|
305 |
+
# append: bool = field(
|
306 |
+
# default=False, metadata={'help': 'Append to training, testing and validation data, if present.'}
|
307 |
+
# )
|
308 |
|
309 |
do_generate: bool = field(
|
310 |
default=False, metadata={'help': 'Generate labelled data.'}
|
|
|
325 |
valid_split: float = field(
|
326 |
default=0.05, metadata={'help': 'Ratio of validation data. Value between 0 and 1.'})
|
327 |
|
328 |
+
start_index: int = field(default=None, metadata={
|
329 |
+
'help': 'Video to start at.'})
|
330 |
|
331 |
max_videos: int = field(default=None, metadata={
|
332 |
'help': 'Maximum number of videos to preprocess.'})
|
|
|
588 |
progress.set_description(f'Processing {task.args[0]}')
|
589 |
progress.update()
|
590 |
|
591 |
+
InterruptibleTaskPool(
|
592 |
+
tasks, preprocess_args.num_jobs, callback).start()
|
593 |
|
594 |
final_path = os.path.join(
|
595 |
processed_args.processed_dir, processed_args.processed_file)
|
|
|
676 |
|
677 |
# TODO
|
678 |
# count_videos = 0
|
679 |
+
# count_segments = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
680 |
|
|
|
681 |
data = final_data.items()
|
|
|
|
|
|
|
|
|
682 |
|
683 |
+
start_index = preprocess_args.start_index or 0
|
684 |
+
end_index = (preprocess_args.max_videos or len(data)) + start_index
|
|
|
|
|
|
|
685 |
|
686 |
+
data = list(itertools.islice(data, start_index, end_index))
|
|
|
|
|
687 |
|
688 |
+
with open(positive_file, 'a', encoding='utf-8') as positive, \
|
689 |
+
open(negative_file, 'a', encoding='utf-8') as negative, \
|
690 |
+
tqdm(data) as progress:
|
691 |
|
692 |
+
for offset, (video_id, sponsor_segments) in enumerate(data):
|
|
|
693 |
|
694 |
progress.set_description(f'Processing {video_id}')
|
695 |
progress.update()
|
|
|
720 |
if wps < preprocess_args.min_wps:
|
721 |
continue
|
722 |
|
|
|
|
|
723 |
d = {
|
724 |
+
'video_index': offset + start_index,
|
725 |
'video_id': video_id,
|
726 |
+
'text': clean_text(' '.join(x['text'] for x in seg)),
|
727 |
'words_per_second': round(wps, 3),
|
728 |
}
|
729 |
|
730 |
+
extracted_segments = extract_sponsors(seg)
|
731 |
if extracted_segments:
|
732 |
extracted_texts = []
|
733 |
for s in extracted_segments:
|
734 |
+
w = ' '.join(q['text'] for q in s['words'])
|
735 |
category = s['category'].upper()
|
736 |
extracted_texts.append(
|
737 |
+
f'{START_SEGMENT_TEMPLATE.format(category)} {w} {END_SEGMENT_TEMPLATE.format(category)}'
|
738 |
+
)
|
739 |
|
740 |
extracted_text = f' {CustomTokens.BETWEEN_SEGMENTS.value} '.join(
|
741 |
extracted_texts)
|