gheinrich commited on
Commit
31f7840
1 Parent(s): c1fddb0

Upload model

Browse files
cls_token.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import torch
2
  from torch import nn
3
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
  import torch
10
  from torch import nn
11
 
config.json CHANGED
@@ -229,6 +229,8 @@
229
  "AutoConfig": "hf_model.RADIOConfig",
230
  "AutoModel": "hf_model.RADIOModel"
231
  },
 
 
232
  "torch_dtype": "float32",
233
  "transformers_version": "4.29.0",
234
  "version": "v1"
 
229
  "AutoConfig": "hf_model.RADIOConfig",
230
  "AutoModel": "hf_model.RADIOModel"
231
  },
232
+ "return_spatial_features": true,
233
+ "return_summary": true,
234
  "torch_dtype": "float32",
235
  "transformers_version": "4.29.0",
236
  "version": "v1"
enable_cpe_support.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  from typing import Union, Tuple
2
  from types import MethodType
3
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
  from typing import Union, Tuple
10
  from types import MethodType
11
 
hf_model.py CHANGED
@@ -35,10 +35,14 @@ class RADIOConfig(PretrainedConfig):
35
  self,
36
  args: Optional[dict] = None,
37
  version: Optional[str]="v1",
 
 
38
  **kwargs,
39
  ):
40
  self.args = args
41
  self.version = version
 
 
42
  super().__init__(**kwargs)
43
 
44
 
@@ -52,12 +56,10 @@ class RADIOModel(PreTrainedModel):
52
 
53
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
54
  args = RADIOArgs(**config.args)
 
55
  self.model = create_model_from_args(args)
56
-
57
  self.input_conditioner: InputConditioner = get_default_conditioner()
58
 
59
- #return RADIOModel(mod, conditioner, return_summary=return_summary, return_spatial_features=return_spatial_features)
60
-
61
  def forward(self, x: torch.Tensor):
62
  x = self.input_conditioner(x)
63
 
@@ -79,8 +81,8 @@ class RADIOModel(PreTrainedModel):
79
  else:
80
  raise ValueError("Unsupported model type")
81
 
82
- if self.return_summary and self.return_spatial_features:
83
  return summary, all_feat
84
- elif self.return_summary:
85
  return summary
86
  return all_feat
 
35
  self,
36
  args: Optional[dict] = None,
37
  version: Optional[str]="v1",
38
+ return_summary: Optional[bool] = True,
39
+ return_spatial_features: Optional[bool] = True,
40
  **kwargs,
41
  ):
42
  self.args = args
43
  self.version = version
44
+ self.return_summary = return_summary
45
+ self.return_spatial_features = return_spatial_features
46
  super().__init__(**kwargs)
47
 
48
 
 
56
 
57
  RADIOArgs = namedtuple("RADIOArgs", config.args.keys())
58
  args = RADIOArgs(**config.args)
59
+ self.config = config
60
  self.model = create_model_from_args(args)
 
61
  self.input_conditioner: InputConditioner = get_default_conditioner()
62
 
 
 
63
  def forward(self, x: torch.Tensor):
64
  x = self.input_conditioner(x)
65
 
 
81
  else:
82
  raise ValueError("Unsupported model type")
83
 
84
+ if self.config.return_summary and self.config.return_spatial_features:
85
  return summary, all_feat
86
+ elif self.config.return_summary:
87
  return summary
88
  return all_feat
input_conditioner.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  from typing import Union, Tuple
2
 
3
  import torch
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
  from typing import Union, Tuple
10
 
11
  import torch
model.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  from torch import nn
2
 
3
  from timm.models import create_model
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
  from torch import nn
10
 
11
  from timm.models import create_model
vit_patch_generator.py CHANGED
@@ -1,3 +1,11 @@
 
 
 
 
 
 
 
 
1
  import math
2
  from typing import Union, Tuple, Optional
3
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
  import math
10
  from typing import Union, Tuple, Optional
11