ctheodoris madhavanvenkatesh commited on
Commit
22bf20f
1 Parent(s): fe1640b

comment out "def save_model_without_heads(original_model_save_directory)"; redundant for ISP/Emb extractor (#382)

Browse files

- comment out "def save_model_without_heads(original_model_save_directory)"; redundant for ISP/Emb extractor (0a4e8a4ba076513934876cf83b1d3e727c26d46e)


Co-authored-by: Madhavan Venkatesh <[email protected]>

Files changed (1) hide show
  1. geneformer/mtl/utils.py +38 -38
geneformer/mtl/utils.py CHANGED
@@ -73,44 +73,44 @@ def calculate_combined_f1(combined_labels, combined_preds):
73
  return f1, accuracy
74
 
75
 
76
- def save_model_without_heads(original_model_save_directory):
77
- # Create a new directory for the model without heads
78
- new_model_save_directory = original_model_save_directory + "_No_Heads"
79
- if not os.path.exists(new_model_save_directory):
80
- os.makedirs(new_model_save_directory)
81
-
82
- # Load the model state dictionary
83
- model_state_dict = torch.load(
84
- os.path.join(original_model_save_directory, "pytorch_model.bin")
85
- )
86
-
87
- # Initialize a new BERT model without the classification heads
88
- config = BertConfig.from_pretrained(
89
- os.path.join(original_model_save_directory, "config.json")
90
- )
91
- model_without_heads = BertModel(config)
92
-
93
- # Filter the state dict to exclude classification heads
94
- model_without_heads_state_dict = {
95
- k: v
96
- for k, v in model_state_dict.items()
97
- if not k.startswith("classification_heads")
98
- }
99
-
100
- # Load the filtered state dict into the model
101
- model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
102
-
103
- # Save the model without heads
104
- model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
105
- torch.save(model_without_heads.state_dict(), model_save_path)
106
-
107
- # Copy the configuration file
108
- shutil.copy(
109
- os.path.join(original_model_save_directory, "config.json"),
110
- new_model_save_directory,
111
- )
112
-
113
- print(f"Model without classification heads saved to {new_model_save_directory}")
114
 
115
 
116
  def get_layer_freeze_range(pretrained_path):
 
73
  return f1, accuracy
74
 
75
 
76
+ # def save_model_without_heads(original_model_save_directory):
77
+ # # Create a new directory for the model without heads
78
+ # new_model_save_directory = original_model_save_directory + "_No_Heads"
79
+ # if not os.path.exists(new_model_save_directory):
80
+ # os.makedirs(new_model_save_directory)
81
+
82
+ # # Load the model state dictionary
83
+ # model_state_dict = torch.load(
84
+ # os.path.join(original_model_save_directory, "pytorch_model.bin")
85
+ # )
86
+
87
+ # # Initialize a new BERT model without the classification heads
88
+ # config = BertConfig.from_pretrained(
89
+ # os.path.join(original_model_save_directory, "config.json")
90
+ # )
91
+ # model_without_heads = BertModel(config)
92
+
93
+ # # Filter the state dict to exclude classification heads
94
+ # model_without_heads_state_dict = {
95
+ # k: v
96
+ # for k, v in model_state_dict.items()
97
+ # if not k.startswith("classification_heads")
98
+ # }
99
+
100
+ # # Load the filtered state dict into the model
101
+ # model_without_heads.load_state_dict(model_without_heads_state_dict, strict=False)
102
+
103
+ # # Save the model without heads
104
+ # model_save_path = os.path.join(new_model_save_directory, "pytorch_model.bin")
105
+ # torch.save(model_without_heads.state_dict(), model_save_path)
106
+
107
+ # # Copy the configuration file
108
+ # shutil.copy(
109
+ # os.path.join(original_model_save_directory, "config.json"),
110
+ # new_model_save_directory,
111
+ # )
112
+
113
+ # print(f"Model without classification heads saved to {new_model_save_directory}")
114
 
115
 
116
  def get_layer_freeze_range(pretrained_path):