Petr Tsvetkov commited on
Commit
7dd87f6
1 Parent(s): 4317849

Try new workaround

Browse files
Files changed (1) hide show
  1. hf_dataset_saver_builder.py +55 -56
hf_dataset_saver_builder.py CHANGED
@@ -5,67 +5,66 @@ from typing import Any
5
  import gradio as gr
6
 
7
 
8
- def _deserialize_components_fix(
9
- self,
10
- data_dir: Path,
11
- flag_data: list[Any],
12
- flag_option: str = "",
13
- username: str = "",
14
- ) -> tuple[dict[Any, Any], list[Any]]:
15
- """Deserialize components and return the corresponding row for the flagged sample.
 
16
 
17
- Images/audio are saved to disk as individual files.
18
- """
19
- # Components that can have a preview on dataset repos
20
- file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
21
 
22
- # Generate the row corresponding to the flagged sample
23
- features = OrderedDict()
24
- row = []
25
- for component, sample in zip(self.components, flag_data):
26
- # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
27
- label = component.label or ""
28
- save_dir = data_dir / gr.flagging.client_utils.strip_invalid_filename_characters(label)
29
- save_dir.mkdir(exist_ok=True, parents=True)
30
- deserialized = component.flag(sample, save_dir)
31
 
32
- # Add deserialized object to row
33
- features[label] = {"dtype": "string", "_type": "Value"}
34
- try:
35
- assert Path(deserialized).exists()
36
- row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
37
- except (AssertionError, TypeError, ValueError, OSError):
38
- deserialized = "" if deserialized is None else str(deserialized)
39
- row.append(deserialized)
40
 
41
- # If component is eligible for a preview, add the URL of the file
42
- # Be mindful that images and audio can be None
43
- if isinstance(component, tuple(file_preview_types)): # type: ignore
44
- for _component, _type in file_preview_types.items():
45
- if isinstance(component, _component):
46
- features[label + " file"] = {"_type": _type}
47
- break
48
- if deserialized:
49
- path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
50
- Path(deserialized).relative_to(self.dataset_dir)
51
- ).replace("\\", "/")
52
- row.append(
53
- gr.flagging.huggingface_hub.hf_hub_url(
54
- repo_id=self.dataset_id,
55
- filename=path_in_repo,
56
- repo_type="dataset",
 
57
  )
58
- )
59
- else:
60
- row.append("")
61
- features["flag"] = {"dtype": "string", "_type": "Value"}
62
- features["username"] = {"dtype": "string", "_type": "Value"}
63
- row.append(flag_option)
64
- row.append(username)
65
- return features, row
66
 
67
 
68
  def get_dataset_saver(*args, **kwargs):
69
- saver = gr.HuggingFaceDatasetSaver(*args, **kwargs)
70
- saver._deserialize_components = _deserialize_components_fix
71
- return saver
 
5
  import gradio as gr
6
 
7
 
8
+ class HFDatasetSaverFixed(gr.HuggingFaceDatasetSaver):
9
+ def _deserialize_components(
10
+ self,
11
+ data_dir: Path,
12
+ flag_data: list[Any],
13
+ flag_option: str = "",
14
+ username: str = "",
15
+ ) -> tuple[dict[Any, Any], list[Any]]:
16
+ """Deserialize components and return the corresponding row for the flagged sample.
17
 
18
+ Images/audio are saved to disk as individual files.
19
+ """
20
+ # Components that can have a preview on dataset repos
21
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
22
 
23
+ # Generate the row corresponding to the flagged sample
24
+ features = OrderedDict()
25
+ row = []
26
+ for component, sample in zip(self.components, flag_data):
27
+ # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
28
+ label = component.label or ""
29
+ save_dir = data_dir / gr.flagging.client_utils.strip_invalid_filename_characters(label)
30
+ save_dir.mkdir(exist_ok=True, parents=True)
31
+ deserialized = component.flag(sample, save_dir)
32
 
33
+ # Add deserialized object to row
34
+ features[label] = {"dtype": "string", "_type": "Value"}
35
+ try:
36
+ assert Path(deserialized).exists()
37
+ row.append(str(Path(deserialized).relative_to(self.dataset_dir)))
38
+ except (AssertionError, TypeError, ValueError, OSError):
39
+ deserialized = "" if deserialized is None else str(deserialized)
40
+ row.append(deserialized)
41
 
42
+ # If component is eligible for a preview, add the URL of the file
43
+ # Be mindful that images and audio can be None
44
+ if isinstance(component, tuple(file_preview_types)): # type: ignore
45
+ for _component, _type in file_preview_types.items():
46
+ if isinstance(component, _component):
47
+ features[label + " file"] = {"_type": _type}
48
+ break
49
+ if deserialized:
50
+ path_in_repo = str( # returned filepath is absolute, we want it relative to compute URL
51
+ Path(deserialized).relative_to(self.dataset_dir)
52
+ ).replace("\\", "/")
53
+ row.append(
54
+ gr.flagging.huggingface_hub.hf_hub_url(
55
+ repo_id=self.dataset_id,
56
+ filename=path_in_repo,
57
+ repo_type="dataset",
58
+ )
59
  )
60
+ else:
61
+ row.append("")
62
+ features["flag"] = {"dtype": "string", "_type": "Value"}
63
+ features["username"] = {"dtype": "string", "_type": "Value"}
64
+ row.append(flag_option)
65
+ row.append(username)
66
+ return features, row
 
67
 
68
 
69
  def get_dataset_saver(*args, **kwargs):
70
+ return HFDatasetSaverFixed(*args, **kwargs)