"""Training script for a MAGVIT-2 ``VideoTokenizer``, set up to run on NPU via ``torch_npu``.

Command-line options are parsed by ``parse_args()``; ``main()`` builds the
tokenizer and runs ``VideoTokenizerTrainer`` on a folder of videos.
"""
import torch
import argparse
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)
# Imported only for its side effects (the name is never referenced below),
# so it must stay ahead of the NPU configuration line that follows.
# NOTE(review): presumably this redirects CUDA-targeted torch calls to the
# NPU backend — confirm against the torch_npu documentation.
from torch_npu.contrib import transfer_to_npu
# NOTE(review): presumably disables NPU-private internal tensor formats so
# tensors keep standard layouts — confirm the intent.
torch.npu.config.allow_internal_format = False
import torch.nn.functional as F
# Replacement for F.adaptive_avg_pool2d; main() assigns it over the
# torch.nn.functional attribute at runtime.
from npu_patch import adaptive_avg_pool2d
def _positive_int(value):
    """Parse *value* as an int >= 1; used as an argparse ``type`` hook."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {number}")
    return number


def parse_args(argv=None):
    """Parse command-line options for the video tokenizer training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``dataset_folder``, ``batch_size``,
        ``grad_accum_every``, ``learning_rate`` and ``num_train_steps``.

    Raises:
        SystemExit: if ``--dataset_folder`` is missing or an integer option is
            not a positive integer (argparse reports the error and exits).
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # Required: main() hands this straight to the trainer, so a missing value
    # would otherwise only fail later with a far less helpful error.
    parser.add_argument(
        "--dataset_folder", type=str, required=True, help="Path of training dataset."
    )
    parser.add_argument(
        "--batch_size", type=_positive_int, default=16,
        help="Batch size for the training dataloader.",
    )
    parser.add_argument(
        "--grad_accum_every",
        type=_positive_int,
        default=1,
        help="Number of batches to accumulate gradients over before each optimizer step.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--num_train_steps",
        type=_positive_int,
        default=5000,
        help="Total number of training steps to perform.",
    )
    return parser.parse_args(argv)
def main():
    """Build a MAGVIT-2 VideoTokenizer and train it on a folder of videos.

    Options come from ``parse_args()``. Before the model is built,
    ``torch.nn.functional.adaptive_avg_pool2d`` is replaced with the version
    from ``npu_patch``.
    """
    cli = parse_args()

    # Route F.adaptive_avg_pool2d through the npu_patch implementation.
    # NOTE(review): presumably the stock op is unsupported or slow on NPU for
    # the shapes used here — confirm.
    F.adaptive_avg_pool2d = adaptive_avg_pool2d

    # Layer schedule handed to VideoTokenizer: three 'compress_space' and two
    # 'compress_time' stages, interleaved with residual and attention stages.
    layer_schedule = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
        'compress_time',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
        'attend_time',
    )

    tokenizer = VideoTokenizer(
        image_size=128,
        init_dim=64,
        max_dim=512,
        codebook_size=1024,
        layers=layer_schedule,
        use_gan=False,
    )

    trainer_options = dict(
        dataset_folder=cli.dataset_folder,
        dataset_type='videos',
        batch_size=cli.batch_size,
        grad_accum_every=cli.grad_accum_every,
        learning_rate=cli.learning_rate,
        num_train_steps=cli.num_train_steps,
    )
    VideoTokenizerTrainer(tokenizer, **trainer_options).train()
# Start training only when run as a script, not when this module is imported.
if __name__ == "__main__":
    main()