File size: 1,457 Bytes
150d962
 
08aed96
150d962
 
 
 
08aed96
150d962
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import argparse
from ormbg import ORMBG


def export_to_onnx(model_path, onnx_path, *, input_size=1024, opset_version=10):
    """Load ORMBG weights from *model_path* and export the network to ONNX.

    The network is traced with one random tensor of shape
    ``(1, 3, input_size, input_size)``. The resulting graph, including the
    trained parameters, is written to *onnx_path*.

    Args:
        model_path: Path to a checkpoint whose contents are passed directly
            to ``ORMBG.load_state_dict``.
        onnx_path: Destination path of the ``.onnx`` file.
        input_size: Height and width of the dummy tracing input. Defaults to
            1024, the value previously hard-coded here.
        opset_version: ONNX opset to target. Defaults to 10, the value
            previously hard-coded here.
    """
    # Decide the target device once so the model and the dummy input agree.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = ORMBG()
    # Always deserialize onto the CPU first. Without map_location, tensors
    # are restored onto the device they were saved from, so a checkpoint
    # saved on e.g. "cuda:1" would fail on a host that lacks that device.
    # The module is moved to the target device afterwards.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(model_path, map_location="cpu")
    net.load_state_dict(state_dict)
    net = net.to(device)
    net.eval()

    # Dummy tracing input. The batch of 1 and 3 channels are carried over
    # from the original script.
    # NOTE(review): assumes ORMBG expects a 3-channel square image; confirm.
    dummy_input = torch.randn(1, 3, input_size, input_size, device=device)

    torch.onnx.export(
        net,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        # NOTE(review): output_names labels graph outputs in order. If
        # ORMBG.forward returns several tensors, only the first is named
        # "output" — confirm this is intended.
        output_names=["output"],
    )


if __name__ == "__main__":
    # Command-line entry point: collect the checkpoint and output paths,
    # then run the export.
    cli = argparse.ArgumentParser(
        description="Export a trained model to ONNX format."
    )

    # (flag, default value, help text) for each string-typed option.
    cli_options = (
        (
            "--model_path",
            "./models/ormbg.pth",
            "The path to the trained model file.",
        ),
        (
            "--onnx_path",
            "./models/example.onnx",
            "The path where the ONNX model will be saved.",
        ),
    )
    for flag, default_value, help_text in cli_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    opts = cli.parse_args()
    export_to_onnx(model_path=opts.model_path, onnx_path=opts.onnx_path)