patrickligardes commited on
Commit
1086f36
β€’
1 Parent(s): 92118b2

Update utils_mask.py

Browse files
Files changed (1) hide show
  1. utils_mask.py +4 -3
utils_mask.py CHANGED
@@ -88,7 +88,7 @@ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint:
88
  (parse_array == 12).astype(np.float32) + \
89
  (parse_array == 13).astype(np.float32) + \
90
  (parse_array == 5).astype(np.float32)
91
- parse_mask_legs = cv2.dilate(parse_mask_legs.astype(np.uint8), np.ones((6, 6), np.uint8), iterations=6) # Dilate legs
92
 
93
  # Fixed and changeable masks
94
  parser_mask_fixed = (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
@@ -97,8 +97,9 @@ def get_mask_location(model_type, category, model_parse: Image.Image, keypoint:
97
  parse_mask_fixed_lower_cloth # Add lower cloth to fixed mask
98
  parser_mask_changeable = np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
99
 
100
- # Combine masks (upper body + legs)
101
- parse_mask = np.maximum(parse_mask_upper, parse_mask_legs)
 
102
 
103
  elif category == 'upper_body':
104
  parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)
 
88
  (parse_array == 12).astype(np.float32) + \
89
  (parse_array == 13).astype(np.float32) + \
90
  (parse_array == 5).astype(np.float32)
91
+ parse_mask_legs_dilated = cv2.dilate(parse_mask_legs.astype(np.uint8), np.ones((6, 6), np.uint8), iterations=6) # Dilate legs
92
 
93
  # Fixed and changeable masks
94
  parser_mask_fixed = (parse_array == label_map["upper_clothes"]).astype(np.float32) + \
 
97
  parse_mask_fixed_lower_cloth # Add lower cloth to fixed mask
98
  parser_mask_changeable = np.logical_and(parse_array, np.logical_not(parser_mask_fixed))
99
 
100
+ # Combine masks (upper body + dilated legs + changeable areas)
101
+ parse_mask = np.maximum.reduce([parse_mask_upper, parse_mask_legs_dilated, parser_mask_changeable.astype(np.float32)])
102
+
103
 
104
  elif category == 'upper_body':
105
  parse_mask = (parse_array == 4).astype(np.float32) + (parse_array == 7).astype(np.float32)