import argparse
def parse_args(namespace=None):
    """Build the StepVideo inference CLI parser and parse the command line.

    Every ``add_*_args`` helper in this module registers its option group on
    a single shared parser before parsing.

    Args:
        namespace: Optional object forwarded to
            ``ArgumentParser.parse_args(namespace=...)``; parsed values are
            set as attributes on it. When ``None``, argparse creates a new
            ``argparse.Namespace``.

    Returns:
        The populated namespace returned by ``ArgumentParser.parse_args``.
    """
    cli = argparse.ArgumentParser(description="StepVideo inference script")
    registrars = (
        add_extra_models_args,
        add_denoise_schedule_args,
        add_inference_args,
        add_parallel_args,
    )
    # Each registrar hands the parser back, so thread it through in order.
    for register in registrars:
        cli = register(cli)
    return cli.parse_args(namespace=namespace)
def add_extra_models_args(parser: argparse.ArgumentParser):
    """Register the options locating the auxiliary model services.

    Adds ``--vae_url`` and ``--caption_url`` (both strings defaulting to
    ``'127.0.0.1'``) in their own argument group.

    Args:
        parser: The parser to extend; it is modified in place.

    Returns:
        The same ``parser`` instance, to allow chaining.
    """
    group = parser.add_argument_group(
        # The original title ended with a stray, unbalanced ")" that leaked
        # into the --help output.
        title="Extra models args, including vae, text encoders and tokenizers"
    )
    group.add_argument(
        "--vae_url",
        type=str,
        default='127.0.0.1',
        help="vae url.",
    )
    group.add_argument(
        "--caption_url",
        type=str,
        default='127.0.0.1',
        help="caption url.",
    )
    return parser
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
    """Register the denoise-schedule options on ``parser`` and return it.

    Options added (all in the "Denoise schedule args" group):
        --time_shift    float, default 7.0
        --flow_reverse  boolean switch, off unless given
        --flow_solver   str, default "euler"

    Args:
        parser: The parser to extend; it is modified in place.

    Returns:
        The same ``parser`` instance, to allow chaining.
    """
    schedule_group = parser.add_argument_group(title="Denoise schedule args")
    # (flag, add_argument keyword arguments), registered in declaration order.
    option_specs = (
        (
            "--time_shift",
            dict(
                type=float,
                default=7.0,
                help="Shift factor for flow matching schedulers.",
            ),
        ),
        (
            "--flow_reverse",
            dict(
                action="store_true",
                help="If reverse, learning/sampling from t=1 -> t=0.",
            ),
        ),
        (
            "--flow_solver",
            dict(
                type=str,
                default="euler",
                help="Solver for flow matching.",
            ),
        ),
    )
    for flag, options in option_specs:
        schedule_group.add_argument(flag, **options)
    return parser
def add_inference_args(parser: argparse.ArgumentParser):
    """Register the inference options on ``parser`` and return it.

    Covers model location, sampling sizes (batch, steps, frames, height,
    width), output naming, prompts and guidance settings. All options live in
    the "Inference args" group.

    Args:
        parser: The parser to extend; it is modified in place.

    Returns:
        The same ``parser`` instance, to allow chaining.
    """
    group = parser.add_argument_group(title="Inference args")
    group.add_argument(
        "--model_dir",
        type=str,
        default="./ckpts",
        help="Root path of all the models, including t2v models and extra models.",
    )
    group.add_argument(
        "--model_resolution",
        type=str,
        default="540p",
        choices=["540p"],
        # Previously carried the --model_dir help text by copy-paste.
        help="Model resolution variant; only 540p is currently supported.",
    )
    group.add_argument(
        # Every other flag in this file is snake_case; accept that spelling as
        # an alias while keeping the original hyphenated flag working. The
        # destination stays ``use_cpu_offload`` (derived from the first flag).
        "--use-cpu-offload",
        "--use_cpu_offload",
        action="store_true",
        help="Use CPU offload for the model load.",
    )
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size for inference and evaluation.",
    )
    group.add_argument(
        "--infer_steps",
        type=int,
        default=50,
        help="Number of denoising steps for inference.",
    )
    group.add_argument(
        "--save_path",
        type=str,
        default="./results",
        help="Path to save the generated samples.",
    )
    group.add_argument(
        "--output_file_name",
        type=str,
        default="",
        help="Name to save the generated samples.",
    )
    group.add_argument(
        "--name_suffix",
        type=str,
        default="",
        help="Suffix for the names of saved samples.",
    )
    group.add_argument(
        "--num_videos",
        type=int,
        default=1,
        help="Number of videos to generate for each prompt.",
    )
    group.add_argument(
        "--num_frames",
        type=int,
        default=102,
        help="How many frames to sample from a video. ",
    )
    group.add_argument(
        "--height",
        type=int,
        default=544,
        help="The height of video sample",
    )
    group.add_argument(
        "--width",
        type=int,
        default=992,
        help="The width of video sample",
    )
    group.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="Prompt for sampling during evaluation.",
    )
    group.add_argument(
        "--first_image_path",
        type=str,
        default='./assets/demo.png',
        help="The reference image path for image-to-video task.",
    )
    group.add_argument("--seed", type=int, default=1234, help="Seed for evaluation.")
    group.add_argument(
        "--pos_magic",
        type=str,
        default="画面中的主体动作表现生动自然、画面流畅、生动细节、光线统一柔和、超真实动态捕捉、大师级运镜、整体不变形、超高清、画面稳定、逼真的细节、专业级构图、超细节、清晰。",
        help="Positive magic prompt for sampling.",
    )
    group.add_argument(
        "--neg_magic",
        type=str,
        default="动画、模糊、变形、毁容、低质量、拼贴、粒状、标志、抽象、插图、计算机生成、扭曲、动作不流畅、面部有褶皱、表情僵硬、畸形手指",
        help="Negative magic prompt for sampling.",
    )
    group.add_argument(
        "--cfg_scale", type=float, default=9.0, help="Classifier free guidance scale."
    )
    group.add_argument(
        "--motion_score",
        type=float,
        # argparse applies ``type`` only to string defaults, so the default
        # must already be a float to match values given on the command line.
        default=5.0,
        help="Score to control the motion level of the video.",
    )
    return parser
def add_parallel_args(parser: argparse.ArgumentParser):
    """Register the parallelism options on ``parser`` and return it.

    Options added (all in the "Parallel args" group):
        --ulysses_degree          int, default 8
        --ring_degree             int, default 1
        --tensor_parallel_degree  int, default 1
        --use_dit_cache           boolean switch, off unless given

    Args:
        parser: The parser to extend; it is modified in place.

    Returns:
        The same ``parser`` instance, to allow chaining.
    """
    group = parser.add_argument_group(title="Parallel args")
    group.add_argument(
        "--ulysses_degree",
        type=int,
        default=8,
        help="Ulysses degree.",
    )
    group.add_argument(
        "--ring_degree",
        type=int,
        default=1,
        # Previously said "Ulysses degree." (copy-paste from the flag above).
        help="Ring degree.",
    )
    group.add_argument(
        "--tensor_parallel_degree",
        type=int,
        default=1,
        help="Tensor parallel degree.",
    )
    group.add_argument(
        "--use_dit_cache",
        action='store_true',
        help="Use dit cache.",
    )
    return parser