Spaces:
Build error
Build error
import matplotlib as mpl | |
mpl.use('Agg') | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import matplotlib.patches as patches | |
import numpy as np | |
from PIL import Image | |
from zipfile import ZipFile | |
import gradio as gr | |
class SampleClass:
    """Browse per-sample box predictions of two models ('baseline' and 'extended').

    The prediction tables for the 'test' and 'val' splits and the image zip
    archive are loaded once.  Each sample is rendered with its ground-truth
    box and both models' predicted boxes drawn on top of the image.
    """

    # Annotators; the filtered samples are split into len(USERS) near-equal
    # contiguous chunks, assigned in this order, by ``explorateSamples``.
    USERS = ('luciana', 'mauri', 'jorge', 'nano')

    # 0/1 (truthy) category columns, listed in this order in the info text.
    CATEGORIES = ('intrinsic', 'spatial', 'ordinal', 'relational', 'plural')

    def __init__(self):
        self.test_df = pd.read_json("data/full_pred_test_w_plurals_w_iou.json")
        self.val_df = pd.read_json("data/full_pred_val_w_plurals_w_iou.json")
        self.zip_file = ZipFile("data/saiapr_tc-12.zip", 'r')
        # Last table filtered by explorateSamples.  Kept for backward
        # compatibility; rendering now receives the table explicitly so that
        # concurrent requests cannot overwrite each other's selection.
        self.filtered_df = None

    def __get(self, img_path):
        """Return the image stored at ``img_path`` inside the zip archive.

        Pixels are loaded eagerly so the archive member can be closed
        immediately instead of leaking one open handle per call.
        """
        with self.zip_file.open(img_path) as fh:
            img = Image.open(fh)
            img.load()  # read pixel data before ``fh`` is closed
        return img

    def __loadPredictions(self, split, model):
        """Return the ``split`` table with columns renamed from ``model``'s view.

        The selected model's ``*_hit``/``*_pred``/``*_iou`` columns become
        ``hit``/``predictions``/``iou``; the other model's become
        ``hit_other``/``predictions_other``/``iou_other``.

        Raises:
            ValueError: if ``split`` is not 'test'/'val' or ``model`` is not
                'baseline'/'extended'.  (Explicit checks rather than
                ``assert``, which is stripped under ``python -O``.)
        """
        tables = {'test': self.test_df, 'val': self.val_df}
        if split not in tables:
            raise ValueError(f"Unknown split {split!r}; expected 'test' or 'val'")
        if model not in ('baseline', 'extended'):
            raise ValueError(
                f"Unknown model {model!r}; expected 'baseline' or 'extended'")
        other = 'extended' if model == 'baseline' else 'baseline'
        return tables[split].rename(columns={
            f'{model}_hit': 'hit',
            f'{model}_pred': 'predictions',
            f'{model}_iou': 'iou',
            f'{other}_hit': 'hit_other',
            f'{other}_pred': 'predictions_other',
            f'{other}_iou': 'iou_other',
        })

    @staticmethod
    def _bbox_patch(bbox, color):
        """Build an unfilled rectangle patch for an ``(x1, y1, x2, y2)`` box."""
        x1, y1, x2, y2 = bbox
        return patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor=color, facecolor='none')

    def __getSample(self, sample_id, df=None):
        """Render one sample and collect its textual details.

        Args:
            sample_id: value of the ``sample_idx`` column to show.
            df: table (as produced by ``__loadPredictions``) to look the
                sample up in; defaults to ``self.filtered_df``.

        Returns:
            ``(info, img, img_seg)``: ``info`` is an ordered dict of display
            fields (its first entry, 'Expresion', is the sentence); ``img`` is
            the rendered figure as a ``(height, width, 3)`` uint8 array;
            ``img_seg`` is the matching image from the ``segmented_images``
            path of the archive.

        Raises:
            ValueError: if ``sample_id`` is not present in ``df``.
        """
        if df is None:
            df = self.filtered_df
        rows = df[df.sample_idx == sample_id]
        if rows.empty:
            raise ValueError(f"Sample {sample_id} not found in the filtered table")
        row = rows.iloc[0]

        prediction = {0: ' FAIL ', 1: ' CORRECT '}
        info = {
            'Expresion': row['sent'],
            'Idx Sample': str(sample_id),
            'IoU': str(round(row['iou'], 2)) + "(" + prediction[row['hit']] + ")",
            'IoU other': (str(round(row['iou_other'], 2))
                          + "(" + prediction[row['hit_other']] + ")"),
            'Pos Tags': str(row['pos_tags']),
            'PluralTks ': row['plural_tks'],
            'Categories': ",".join(c for c in self.CATEGORIES if row[c]),
        }

        img_path = "saiapr_tc-12" + row['file_path'].split("saiapr_tc-12")[1]
        img_seg_path = img_path.replace("images", "segmented_images")

        fig, ax = plt.subplots(1)
        try:
            ax.imshow(self.__get(img_path), interpolation='bilinear')
            # blue: ground truth, light green: selected model, red: other model
            ax.add_patch(self._bbox_patch(row['bbox'], 'blue'))
            ax.add_patch(self._bbox_patch(row['predictions'], 'lightgreen'))
            ax.add_patch(self._bbox_patch(row['predictions_other'], 'red'))
            ax.axis('off')
            # Figure/Axes methods instead of pyplot's global "current figure",
            # which is shared between concurrent requests.
            ax.set_title(info['Expresion'], fontsize=12)
            fig.tight_layout()
            fig.canvas.draw()
            # tostring_rgb() was deprecated in Matplotlib 3.8 and later removed.
            # The RGBA buffer already has the rendered (height, width, 4) shape;
            # drop alpha and copy before the figure (and its buffer) is closed.
            img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            plt.close(fig)
        return info, img, self.__get(img_seg_path)

    def explorateSamples(self,
                         username,
                         predictions,
                         category,
                         model,
                         split,
                         next_idx_sample):
        """Gradio callback: show the requested sample from a user's share.

        Args:
            username: one of ``USERS``; selects that user's chunk of samples.
            predictions: 'fail' or 'correct' (filters on the ``hit`` column).
            category: name of a category column; rows where it equals 1 are kept.
            model: 'baseline' or 'extended'.
            split: 'test' or 'val'.
            next_idx_sample: ``sample_idx`` to display; if it is not in the
                user's chunk, the chunk's first sample is shown instead.

        Returns:
            ``(gr.Number update with the following sample_idx, progress dict,
            rendered image, info text, segmented image)``.

        Raises:
            KeyError: unknown ``username``.
            ValueError: the user's chunk is empty for this selection, or an
                invalid ``split``/``model`` was given.
        """
        next_idx_sample = int(next_idx_sample)
        hit = {'fail': 0, 'correct': 1}
        df = self.__loadPredictions(split, model)
        filtered = df[(df[category] == 1) & (df.hit == hit[predictions])]
        self.filtered_df = filtered

        # Near-equal contiguous chunks; .tolist() turns numpy scalars back
        # into plain Python values.
        parts = np.array_split(np.asarray(filtered.sample_idx.to_list()),
                               len(self.USERS))
        user_ids = {name: part.tolist() for name, part in zip(self.USERS, parts)}
        if username not in user_ids:
            raise KeyError(f"Unknown username: {username!r}")
        samples = user_ids[username]
        if not samples:
            raise ValueError(
                f"No samples assigned to {username!r} for this selection")

        try:
            pos = samples.index(next_idx_sample)
        except ValueError:  # requested sample not in this chunk: start over
            pos = 0
        last = len(samples) - 1
        next_sample = samples[min(pos + 1, last)]
        # A single-sample chunk has last == 0; report it as complete instead
        # of dividing by zero.
        progress = {f"{pos}/{last}": pos / last if last else 1.0}

        info, img, img_seg = self.__getSample(samples[pos], filtered)
        # Skip the first entry ('Expresion'): it is already the figure title.
        info_text = "\n".join(
            f"{k}:\t{v}" for k, v in list(info.items())[1:]).strip()
        return (gr.Number.update(value=next_sample), progress, img, info_text, img_seg)