Linoy Tsaban commited on
Commit
6908973
1 Parent(s): b9a325a

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +114 -0
utils.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import PIL
2
+ from PIL import Image, ImageDraw ,ImageFont
3
+ from matplotlib import pyplot as plt
4
+ import torchvision.transforms as T
5
+ import os
6
+ import torch
7
+ import yaml
8
+
9
+ def show_torch_img(img):
10
+ img = to_np_image(img)
11
+ plt.imshow(img)
12
+ plt.axis("off")
13
+
14
+ def to_np_image(all_images):
15
+ all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0]
16
+ return all_images
17
+
18
+ def tensor_to_pil(tensor_imgs):
19
+ if type(tensor_imgs) == list:
20
+ tensor_imgs = torch.cat(tensor_imgs)
21
+ tensor_imgs = (tensor_imgs / 2 + 0.5).clamp(0, 1)
22
+ to_pil = T.ToPILImage()
23
+ pil_imgs = [to_pil(img) for img in tensor_imgs]
24
+ return pil_imgs
25
+
26
+ def pil_to_tensor(pil_imgs):
27
+ to_torch = T.ToTensor()
28
+ if type(pil_imgs) == PIL.Image.Image:
29
+ tensor_imgs = to_torch(pil_imgs).unsqueeze(0)*2-1
30
+ elif type(pil_imgs) == list:
31
+ tensor_imgs = torch.cat([to_torch(pil_imgs).unsqueeze(0)*2-1 for img in pil_imgs]).to(device)
32
+ else:
33
+ raise Exception("Input need to be PIL.Image or list of PIL.Image")
34
+ return tensor_imgs
35
+
36
+
37
+ ## TODO implement this
38
+ # n = 10
39
+ # num_rows = 4
40
+ # num_col = n // num_rows
41
+ # num_col = num_col + 1 if n % num_rows else num_col
42
+ # num_col
43
+ def add_margin(pil_img, top = 0, right = 0, bottom = 0,
44
+ left = 0, color = (255,255,255)):
45
+ width, height = pil_img.size
46
+ new_width = width + right + left
47
+ new_height = height + top + bottom
48
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
49
+
50
+ result.paste(pil_img, (left, top))
51
+ return result
52
+
53
+ def image_grid(imgs, rows = 1, cols = None,
54
+ size = None,
55
+ titles = None, text_pos = (0, 0)):
56
+ if type(imgs) == list and type(imgs[0]) == torch.Tensor:
57
+ imgs = torch.cat(imgs)
58
+ if type(imgs) == torch.Tensor:
59
+ imgs = tensor_to_pil(imgs)
60
+
61
+ if not size is None:
62
+ imgs = [img.resize((size,size)) for img in imgs]
63
+ if cols is None:
64
+ cols = len(imgs)
65
+ assert len(imgs) >= rows*cols
66
+
67
+ top=20
68
+ w, h = imgs[0].size
69
+ delta = 0
70
+ if len(imgs)> 1 and not imgs[1].size[1] == h:
71
+ delta = top
72
+ h = imgs[1].size[1]
73
+ if not titles is None:
74
+ font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeMono.ttf",
75
+ size = 20, encoding="unic")
76
+ h = top + h
77
+ grid = Image.new('RGB', size=(cols*w, rows*h+delta))
78
+ for i, img in enumerate(imgs):
79
+
80
+ if not titles is None:
81
+ img = add_margin(img, top = top, bottom = 0,left=0)
82
+ draw = ImageDraw.Draw(img)
83
+ draw.text(text_pos, titles[i],(0,0,0),
84
+ font = font)
85
+ if not delta == 0 and i > 0:
86
+ grid.paste(img, box=(i%cols*w, i//cols*h+delta))
87
+ else:
88
+ grid.paste(img, box=(i%cols*w, i//cols*h))
89
+
90
+ return grid
91
+
92
+
93
+ """
94
+ input_folder - dataset folder
95
+ """
96
+ def load_dataset(input_folder):
97
+ # full_file_names = glob.glob(input_folder)
98
+ # class_names = [x[0] for x in os.walk(input_folder)]
99
+ class_names = next(os.walk(input_folder))[1]
100
+ class_names[:] = [d for d in class_names if not d[0] == '.']
101
+ file_names=[]
102
+ for class_name in class_names:
103
+ cur_path = os.path.join(input_folder, class_name)
104
+ filenames = next(os.walk(cur_path), (None, None, []))[2]
105
+ filenames = [f for f in filenames if not f[0] == '.']
106
+ file_names.append(filenames)
107
+ return class_names, file_names
108
+
109
+
110
+ def dataset_from_yaml(yaml_location):
111
+ with open(yaml_location, 'r') as stream:
112
+ data_loaded = yaml.safe_load(stream)
113
+
114
+ return data_loaded