import os
import torch
import numpy as np
import os
import argparse
def parse_args(argv=None):
    """Parse the command-line options of the CGAN preprocessing script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv[1:]``,
            which is what the original zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``input_dim``, ``output_dim``,
        ``input_size``, ``class_num``, ``pth_path``, ``onnx_path`` and
        ``save_path``.
    """
    desc = "Pytorch implementation of CGAN collections"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--input_dim', type=int, default=62, help="The input_dim")
    parser.add_argument('--output_dim', type=int, default=3, help="The output_dim")
    parser.add_argument('--input_size', type=int, default=28, help="The image size of MNIST")
    parser.add_argument('--class_num', type=int, default=10, help="The num of classes of MNIST")
    parser.add_argument('--pth_path', type=str, default='CGAN_G.pth', help='pth model path')
    parser.add_argument('--onnx_path', type=str, default="CGAN.onnx", help='onnx model path')
    parser.add_argument('--save_path', type=str, default="data", help='processed data path')
    # Forwarding argv lets callers and tests parse an explicit argument list
    # instead of whatever happens to be in sys.argv.
    return parser.parse_args(argv)
def prep_preocess(args):
    """Build the generator input batch and write it to ``<save_path>/input.bin``.

    The batch has ``class_num ** 2`` rows, laid out as ``class_num``
    consecutive groups of ``class_num`` rows each:

    * every row of a group shares one noise vector of length ``input_dim``
      drawn with ``torch.rand`` (uniform on [0, 1));
    * within a group the one-hot label walks through the classes
      ``0 .. class_num - 1`` in order.

    Each row is the noise vector followed by the one-hot label, so the array
    has shape ``(class_num ** 2, input_dim + class_num)``. It is written as
    raw float32 values with ``numpy.ndarray.tofile``.

    Args:
        args: Object exposing ``class_num``, ``input_dim`` and ``save_path``
            attributes (e.g. the namespace returned by ``parse_args``).
    """
    class_num = args.class_num
    z_dim = args.input_dim

    # One noise vector per group. Drawn with one torch.rand(1, z_dim) call per
    # group, i.e. the same sequence of RNG calls as the original loop.
    group_noise = torch.zeros((class_num, z_dim))
    for group in range(class_num):
        group_noise[group:group + 1] = torch.rand(1, z_dim)

    # Repeat each group's vector class_num times back to back:
    # rows [g * class_num, (g + 1) * class_num) all equal group_noise[g].
    sample_z = group_noise.repeat_interleave(class_num, dim=0)

    # Stacking the identity matrix class_num times gives row k a one-hot
    # label at column k % class_num.
    sample_y = torch.eye(class_num).repeat(class_num, 1)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_path, exist_ok=True)

    batch = torch.cat([sample_z, sample_y], 1).numpy().astype(np.float32, copy=False)
    batch.tofile(os.path.join(args.save_path, 'input.bin'))
if __name__ == "__main__":
    # Script entry point: read the CLI options, generate the input file,
    # then report the directory it was written to.
    cli_options = parse_args()
    prep_preocess(cli_options)
    output_dir = cli_options.save_path
    print("data preprocessed stored in", output_dir)