Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import numpy as np | |
from graph_decoder.diffusion_model import GraphDiT | |
# model_state = load_model() | |
# generate_graph(2.5, 15.4, 21.0, 1.5, 2.8, 2, 0, 1, model_state, 50) | |
def count_parameters(model):
    r"""
    Count a model's parameters.

    Iterates ``model.parameters()`` once, summing ``numel()`` for every
    parameter tensor, and separately for those with ``requires_grad`` set.

    Returns:
        A ``(trainable_params, all_param)`` tuple of ints; ``(0, 0)`` for a
        model with no parameters.
    """
    # Materialize (size, trainable?) pairs so the parameter iterator is consumed only once.
    sizes_and_flags = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total = sum(size for size, _ in sizes_and_flags)
    trainable = sum(size for size, needs_grad in sizes_and_flags if needs_grad)
    return trainable, total
def load_graph_decoder(path='model_labeled', *, model_dtype=torch.float32):
    """Build a GraphDiT decoder from a model directory and print its parameter stats.

    ``GraphDiT`` is constructed from ``{path}/config.yaml`` (model config) and
    ``{path}/data.meta.json`` (data info); ``init_model(path)`` and
    ``disable_grads()`` are then called on it.

    Args:
        path: Directory containing the model files. Defaults to ``'model_labeled'``.
        model_dtype: dtype forwarded to ``GraphDiT``. Defaults to
            ``torch.float32``; e.g. ``torch.float16`` may be passed instead.

    Returns:
        The ``GraphDiT`` instance after ``init_model`` and ``disable_grads``.
    """
    model_config_path = f"{path}/config.yaml"
    data_info_path = f"{path}/data.meta.json"
    model = GraphDiT(
        model_config_path=model_config_path,
        data_info_path=data_info_path,
        model_dtype=model_dtype,
    )
    # NOTE(review): presumably loads the weights stored under `path` — confirm in GraphDiT.init_model.
    model.init_model(path)
    # NOTE(review): presumably freezes all parameters (requires_grad=False) — confirm in GraphDiT.
    model.disable_grads()

    trainable_params, all_param = count_parameters(model)
    # A model with no parameters would otherwise raise ZeroDivisionError here.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
        path, trainable_params, all_param, trainable_pct
    )
    print(param_stats)
    return model