import json
import gradio as gr
import os
from PIL import Image
import plotly.graph_objects as go
import plotly.express as px
import operator
TITLE = "Identity Representation in Diffusion Models"
_INTRO = """
# Identity Representation in Diffusion Models
Explore the data generated from [DiffusionBiasExplorer](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)!
This demo showcases patterns in images generated by Stable Diffusion and Dall-E 2 systems.
Specifically, images obtained from prompt inputs that span various gender- and ethnicity-related terms are clustered to show how those shape visual representations (more details below).
We encourage users to take advantage of this app to explore those trends, for example through the lens of the following questions:
- Find the cluster that has the most prompts denoting a gender or ethnicity that you identify with. Do you think the generated images look like you?
- Find two clusters that have a similar distribution of gender terms but different distributions of ethnicity terms. Do you see any meaningful differences in how gender is visually represented?
- Do you find that some ethnicity terms lead to more stereotypical visual representations than others?
- Do you find that some gender terms lead to more stereotypical visual representations than others?
These questions only scratch the surface of what we can learn from demos like this one.
Let us know what you find [in the discussions tab](https://huggingface.co/spaces/society-ethics/DiffusionFaceClustering/discussions),
or if you think of other relevant questions!
"""
_CONTEXT = """
##### How do diffusion-based models represent gender and ethnicity?
In order to evaluate the *social biases* that Text-to-Image (TTI) systems may reproduce or exacerbate,
we need to first understand how the visual representations they generate relate to notions of gender and ethnicity.
These two aspects of a person's identity, however, are known as **socially constructed characteristics**:
that is to say, gender and ethnicity only exist in interactions between people; they do not have an independent existence based solely on physical (or visual) attributes.
This means that while we can characterize trends in how the models associate visual features with specific *identity terms in the generation prompts*,
we should not assign a specific gender or ethnicity to a synthetic figure generated by an ML model.
In this app, we instead take a 2-step clustering-based approach. First, we generate 680 images for each model by varying mentions of terms that denote gender or ethnicity in the prompts.
Then, we use a [VQA-based model](https://huggingface.co/Salesforce/blip-vqa-base) to cluster these images at different granularities (12, 24, or 48 clusters).
Exploring these clusters allows us to examine trends in the models' associations between visual features and textual representation of social attributes.
**Note:** this demo was developed with a limited set of gender- and ethnicity-related terms that are more relevant to the US context as a first approach,
so users may not always find themselves represented.
If you have suggestions for additional categories you would particularly like to see in the next version,
please tell us about them [in the discussions tab](https://huggingface.co/spaces/society-ethics/DiffusionFaceClustering/discussions)!
"""
with open("clusters/id_all_blip_clusters_12.json") as f:
    clusters_12 = json.load(f)
with open("clusters/id_all_blip_clusters_24.json") as f:
    clusters_24 = json.load(f)
with open("clusters/id_all_blip_clusters_48.json") as f:
    clusters_48 = json.load(f)
clusters_by_size = {
12: clusters_12,
24: clusters_24,
48: clusters_48,
}
def to_string(label):
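    """Expand short labels from the data files into display-friendly strings."""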
if label == "SD_2":
label = "Stable Diffusion 2.0"
elif label == "SD_14":
label = "Stable Diffusion 1.4"
elif label == "DallE":
label = "Dall-E 2"
elif label == "non-binary":
label = "non-binary person"
elif label == "person":
label = "<i>unmarked</i> (person)"
elif label == "":
label = "<i>unmarked</i> ()"
elif label == "gender":
label = "gender term"
return label
def summarize_clusters(clusters_list, max_terms=3):
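    """Attach a one-line dropdown description ("sentence_desc") and a short text
    summary ("summary_desc") to each cluster dictionary, modifying it in place."""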
for cl_id, cl_dict in enumerate(clusters_list):
total = len(cl_dict["img_path_list"])
gdr_list = cl_dict["labels_gender"]
eth_list = cl_dict["labels_ethnicity"]
cl_dict["sentence_desc"] = (
f"Cluster {cl_id} | \t"
+ f"gender terms incl.: {gdr_list[0][0].replace('person', 'unmarked(gender)')}"
+ (
f" - {gdr_list[1][0].replace('person', 'unmarked(gender)')} | "
if len(gdr_list) > 1
else " | "
)
+ f"ethnicity terms incl.: {'unmarked(ethnicity)' if eth_list[0][0] == '' else eth_list[0][0]}"
+ (
f" - {'unmarked(ethnicity)' if eth_list[1][0] == '' else eth_list[1][0]}"
if len(eth_list) > 1
else ""
)
)
cl_dict["summary_desc"] = (
f"Cluster {cl_id} has {total} images.\n"
+ f"- The most represented gender terms are {gdr_list[0][0].replace('person', 'unmarked')} ({gdr_list[0][1]})"
+ (
f" and {gdr_list[1][0].replace('person', 'unmarked')} ({gdr_list[1][1]}).\n"
if len(gdr_list) > 1
else ".\n"
)
+ f"- The most represented ethnicity terms are {'unmarked' if eth_list[0][0] == '' else eth_list[0][0]} ({eth_list[0][1]})"
+ (
f" and {'unmarked' if eth_list[1][0] == '' else eth_list[1][0]} ({eth_list[1][1]}).\n"
if len(eth_list) > 1
else ".\n"
)
+ "See below for a more detailed description."
)
for clusters_list in clusters_by_size.values():
summarize_clusters(clusters_list)
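# Dropdown options: one descriptive sentence per cluster, keyed by cluster count.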
dropdown_descs = dict(
(num_clusters, [cl_dct["sentence_desc"] for cl_dct in clusters_list])
for num_clusters, clusters_list in clusters_by_size.items()
)
def describe_cluster(cl_dict, block="label", max_items=4):
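    """Return an HTML snippet describing a cluster's label counts: the top label's
    share, followed by up to `max_items` runner-up labels."""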
    labels_values = sorted(cl_dict.items(), key=operator.itemgetter(1), reverse=True)
total = float(sum(cl_dict.values()))
lv_prcnt = list(
(item[0], round(item[1] * 100 / total, 0)) for item in labels_values
)
top_label = lv_prcnt[0][0]
description_string = (
"<span>The most represented %s is <b>%s</b>, making up about <b>%d%%</b> of the cluster.</span>"
% (to_string(block), to_string(top_label), lv_prcnt[0][1])
)
description_string += "<p>This is followed by: "
for lv in lv_prcnt[1 : min(len(lv_prcnt), 1 + max_items)]:
description_string += "<BR/><b>%s:</b> %d%%" % (to_string(lv[0]), lv[1])
if len(lv_prcnt) > max_items + 1:
description_string += "<BR/><b> - Other terms:</b> %d%%" % (
sum(lv[1] for lv in lv_prcnt[max_items + 1 :]),
)
description_string += "</p>"
return description_string
def show_cluster(cl_id, num_clusters):
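    """Gather everything displayed for one cluster: its text summary, the
    system/gender/ethnicity plots and descriptions, a sample image gallery,
    and refreshed dropdown choices for the selected cluster count."""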
    if not num_clusters:
        num_clusters = 12
    if not cl_id:
        cl_id = 0
    else:
        cl_id = (
            dropdown_descs[num_clusters].index(cl_id)
            if cl_id in dropdown_descs[num_clusters]
            else 0
        )
cl_dct = clusters_by_size[num_clusters][cl_id]
images = []
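    # Show the first 8 images of the cluster; captions are rebuilt from the file
    # path (generating system plus the identity terms used in the prompt).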
for i in range(8):
img_path = "/".join(
[st.replace("/", "") for st in cl_dct["img_path_list"][i].split("//")][3:]
)
im = Image.open(os.path.join("identities-images", img_path))
# .resize((256, 256))
caption = (
"_".join([img_path.split("/")[0], img_path.split("/")[-1]])
.replace("Photo_portrait_of_an_", "")
.replace("Photo_portrait_of_a_", "")
.replace("SD_v2_random_seeds_identity_", "(SD v.2) ")
.replace("dataset-identities-dalle2_", "(Dall-E 2) ")
.replace("SD_v1.4_random_seeds_identity_", "(SD v.1.4) ")
.replace("_", " ")
)
images.append((im, caption))
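    # Pie charts for the system and gender term makeup, a bar chart for ethnicity terms.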
model_fig = go.Figure()
model_fig.add_trace(
go.Pie(
labels=list(dict(cl_dct["labels_model"]).keys()),
values=list(dict(cl_dct["labels_model"]).values()),
)
)
model_description = describe_cluster(dict(cl_dct["labels_model"]), "system")
gender_fig = go.Figure()
gender_fig.add_trace(
go.Pie(
labels=list(dict(cl_dct["labels_gender"]).keys()),
values=list(dict(cl_dct["labels_gender"]).values()),
)
)
gender_description = describe_cluster(dict(cl_dct["labels_gender"]), "gender")
ethnicity_fig = go.Figure()
ethnicity_fig.add_trace(
go.Bar(
x=list(dict(cl_dct["labels_ethnicity"]).keys()),
y=list(dict(cl_dct["labels_ethnicity"]).values()),
marker_color=px.colors.qualitative.G10,
)
)
ethnicity_description = describe_cluster(
dict(cl_dct["labels_ethnicity"]), "ethnicity"
)
return (
clusters_by_size[num_clusters][cl_id]["summary_desc"],
gender_fig,
gender_description,
model_fig,
model_description,
ethnicity_fig,
ethnicity_description,
images,
gr.update(choices=dropdown_descs[num_clusters]),
# gr.update(choices=[i for i in range(num_clusters)]),
)
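# Gradio interface: cluster selection controls, text summary, image gallery,
# and one plot (with an HTML description) per attribute.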
with gr.Blocks(title=TITLE) as demo:
gr.Markdown(_INTRO)
with gr.Accordion(
"How do diffusion-based models represent gender and ethnicity?", open =False
):
gr.Markdown(_CONTEXT)
gr.HTML(
"""<span style="color:red" font-size:smaller>⚠️ DISCLAIMER: the images displayed by this tool were generated by text-to-image systems and may depict offensive stereotypes or contain explicit content.</span>"""
)
num_clusters = gr.Radio(
[12, 24, 48],
value=12,
label="How many clusters do you want to make from the data?",
)
with gr.Row():
with gr.Column():
cluster_id = gr.Dropdown(
choices=dropdown_descs[num_clusters.value],
value=0,
label="Select cluster to visualize:",
)
a = gr.Text(label="Cluster summary")
with gr.Column():
gallery = gr.Gallery(label="Most representative images in cluster").style(
grid=[2, 4], height="auto"
)
with gr.Row():
with gr.Column():
c = gr.Plot(label="How many images from each system?")
c_desc = gr.HTML(label="")
with gr.Column(scale=1):
b = gr.Plot(label="Which gender terms are represented?")
b_desc = gr.HTML(label="")
with gr.Column(scale=2):
d = gr.Plot(label="Which ethnicity terms are present?")
d_desc = gr.HTML(label="")
gr.Markdown(
"### Plot Descriptions \n\n"
+ " The **System makeup** plot (*left*) corresponds to the number of images from the cluster that come from each of the TTI systems that we are comparing: Dall-E 2, Stable Diffusion v.1.4. and Stable Diffusion v.2.\n\n"
+ " The **Gender term makeup** plot (*middle*) shows the number of images based on the input prompts that used the phrases man, woman, non-binary person, and person (unmarked) to describe the figure's gender.\n\n"
+ " The **Ethnicity label makeup** plot (*right*) corresponds to the number of images from each of the 18 ethnicity descriptions used in the prompts. A blank value denotes unmarked ethnicity.\n\n"
)
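    # Initial page load, changing the cluster count, and selecting a cluster all
    # re-render the outputs through the same show_cluster callback.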
demo.load(
fn=show_cluster,
inputs=[cluster_id, num_clusters],
outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery, cluster_id],
)
num_clusters.change(
fn=show_cluster,
inputs=[cluster_id, num_clusters],
outputs=[
a,
b,
b_desc,
c,
c_desc,
d,
d_desc,
gallery,
cluster_id,
],
)
cluster_id.change(
fn=show_cluster,
inputs=[cluster_id, num_clusters],
outputs=[a, b, b_desc, c, c_desc, d, d_desc, gallery, cluster_id],
)
if __name__ == "__main__":
demo.queue().launch(debug=True)