import argparse
import datetime
import logging
import os
import sys
import time
from typing import Optional
from PIL import Image
import ray
import torch
import torch.distributed as dist
import torch_npu
from pydantic import BaseModel
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, WAN_CONFIGS
from wan.distributed.parallel_mgr import ParallelConfig, init_parallel_env
from wan.distributed.tp_applicator import TensorParallelApplicator
from wan.utils.utils import save_video
from mindiesd import CacheConfig, CacheAgent
from request import GeneratorRequest
# Module-import side effects: the NPU backend is configured as soon as this
# file is imported, before any worker or pipeline is created.

# Turn off torch_npu's JIT compile mode.
# NOTE(review): presumably this makes operators run from pre-built binaries
# instead of being compiled on first use — confirm against torch_npu docs.
torch_npu.npu.set_compile_mode(jit_compile=False)
# Disallow NPU-internal tensor formats.
# NOTE(review): presumably keeps tensors in standard layouts — confirm.
torch.npu.config.allow_internal_format = False
# Cache method name passed to every mindiesd CacheConfig built in this file.
ATTENTION_CACHE_METHOD = 'attention_cache'
# Sample inputs keyed by task name. The "t2v-A14B" and "i2v-A14B" entries are
# used by GeneratorWorker for its warm-up generations; "ti2v-5B" is not
# referenced anywhere in the visible part of this file.
EXAMPLE_PROMPT = {
    "t2v-A14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "i2v-A14B": {
        "prompt": (
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. "
            "The fluffy-furred feline gazes directly at the camera with a relaxed expression. "
            "Blurred beach scenery forms the background featuring crystal-clear waters, "
            "distant green hills, and a blue sky dotted with white clouds. "
            "The cat assumes a naturally relaxed posture, as if savoring the sea breeze "
            "and warm sunlight. A close-up shot highlights the feline's intricate details "
            "and the refreshing atmosphere of the seaside."
        ),
        # Relative path: opened as-is, so it is resolved against the process's
        # current working directory.
        "image": "examples/i2v_input.JPG",
    },
    "ti2v-5B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
}
@ray.remote(resources={"NPU": 1})
class GeneratorWorker:
    """Ray actor holding one rank of a Wan video-generation pipeline.

    Each actor reserves one ``"NPU"`` Ray resource. Construction configures
    the process environment, sets up logging, joins the distributed group
    (when ``world_size > 1``), builds the T2V or I2V pipeline and runs a
    two-step warm-up generation. Afterwards ``generate`` can be called once
    per request.
    """

    def __init__(self, args, rank: int, world_size: int):
        """Configure the process environment and build the pipeline.

        Args:
            args: Parsed generation arguments (the namespace produced by the
                project's argument parser). Some fields are updated in place
                during initialization (``base_seed``, ``dit_fsdp``).
            rank: Global rank of this worker.
            world_size: Total number of workers.
        """
        # Rendezvous settings read by ``init_method="env://"``. The address is
        # fixed to localhost, so all ranks are expected to run on one host.
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        # NOTE(review): the next four look like Ascend / torch_npu runtime
        # tuning switches — confirm their meaning against the CANN docs.
        os.environ["ALGO"] = "1"
        os.environ["PYTORCH_NPU_ALLOC_CONF"] = 'expandable_segments:True'
        os.environ["TASK_QUEUE_ENABLE"] = "2"
        os.environ["CPU_AFFINITY_CONF"] = "1"
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.rank = rank
        self.world_size = world_size
        self._init_logging(rank)
        self.initialize_model(args)

    @classmethod
    def _init_logging(cls, rank):
        """Set up root logging: INFO to stdout on rank 0, ERROR elsewhere."""
        if rank == 0:
            logging.basicConfig(
                level=logging.INFO,
                format="[%(asctime)s] %(levelname)s: %(message)s",
                handlers=[logging.StreamHandler(stream=sys.stdout)])
        else:
            logging.basicConfig(level=logging.ERROR)

    def initialize_model(self, args):
        """Initialize the parallel environment and build the pipeline.

        Args:
            args: Parsed generation arguments. ``args.base_seed`` is
                overwritten with rank 0's value when a process group exists.
        """
        self.args = args
        cfg = self._init_parallel_env(args)
        rainfusion_config = self._build_rainfusion_config(
            args.sparsity, args.sparse_start_step, args.rainfusion_type)
        if dist.is_initialized():
            # Every rank adopts rank 0's seed so all ranks sample identically.
            base_seed = [args.base_seed] if self.rank == 0 else [None]
            dist.broadcast_object_list(base_seed, src=0)
            args.base_seed = base_seed[0]
        if "t2v" in args.task:
            self._init_t2v_pipeline(args, cfg, rainfusion_config)
        else:
            self._init_i2v_pipeline(args, cfg, rainfusion_config)
        # Logged only after the pipeline exists and has been warmed up; the
        # message used to be emitted before any model was created.
        logging.info("Model initialization completed")

    def generate(self, request: GeneratorRequest):
        """Run one generation request and, on rank 0, save the video.

        Args:
            request: The generation request. If ``save_disk_path`` is ``None``
                it is filled in (on rank 0 only) with a generated file name.

        Returns:
            A dict with ``"message"``, ``"elapsed_time"`` (formatted seconds)
            and ``"output"`` (the request's ``save_disk_path``, which stays
            ``None`` on non-zero ranks when the caller did not supply one).

        Raises:
            ValueError: If the task is not a t2v task and no image was given.
        """
        # Device-wide barrier so the timer does not include earlier work.
        # Synchronizing a freshly created stream (as before) does not wait
        # for work queued on other streams.
        torch.npu.synchronize()
        start_time = time.time()
        request_info = {
            "task": request.task,
            "prompt": request.prompt,
            "size": request.size,
            "steps": request.sample_steps,
            "frame_num": request.frame_num,
            "shift": request.sample_shift,
            "sample_solver": request.sample_solver,
            "sampling_steps": request.sample_steps,
            "guide_scale": request.sample_guide_scale,
            "seed": request.base_seed,
            "offload_model": request.offload_model
        }
        logging.info(f"request: {request_info}")

        img = None
        if request.image is not None:
            # ``convert`` returns an independent in-memory image, so the
            # source file can be closed right away.
            with Image.open(request.image) as source_image:
                img = source_image.convert("RGB")
            logging.info(f"Input image: {request.image}")

        if self.use_rainfusion:
            # Re-attach the rainfusion settings captured at initialization.
            self._apply_rainfusion_config(
                self.pipe,
                self.dit_fsdp,
                self._build_rainfusion_config(
                    self.sparsity, self.sparse_start_step, self.rainfusion_type))

        # NOTE(review): the reset below is applied on every request,
        # independent of the rainfusion/FSDP branches — confirm this nesting
        # against the upstream file.
        self.pipe.low_noise_model.freqs_list = None
        self.pipe.high_noise_model.freqs_list = None
        logging.info(f"freqs_list: {self.pipe.low_noise_model.freqs_list}")

        # Sampling arguments shared by the T2V and I2V calls.
        sampling_kwargs = {
            "frame_num": request.frame_num,
            "shift": request.sample_shift,
            "sample_solver": request.sample_solver,
            "sampling_steps": request.sample_steps,
            "guide_scale": request.sample_guide_scale,
            "seed": request.base_seed,
            "offload_model": request.offload_model,
        }
        if "t2v" in request.task:
            video = self.pipe.generate(
                request.prompt,
                size=SIZE_CONFIGS[request.size],
                **sampling_kwargs)
        else:
            if img is None:
                raise ValueError("Image is required for i2v generation.")
            video = self.pipe.generate(
                request.prompt,
                img,
                max_area=MAX_AREA_CONFIGS[request.size],
                **sampling_kwargs)

        # Wait for outstanding device work before stopping the timer.
        torch.npu.synchronize()
        elapsed_time = time.time() - start_time

        if self.rank == 0:
            if request.save_disk_path is None:
                formatted_time = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
                formatted_prompt = request.prompt.replace(" ", "_").replace("/", "_")[:50]
                suffix = '.mp4'
                # '*' is not allowed in Windows file names.
                size_format = request.size.replace('*', 'x') if sys.platform == 'win32' else request.size
                request.save_disk_path = f"{size_format}_{formatted_prompt}_{formatted_time}{suffix}"
            logging.info(f"Saving generated video to {request.save_disk_path}")
            save_video(
                tensor=video[None],
                save_file=request.save_disk_path,
                fps=request.sample_fps,
                nrow=1,
                normalize=True,
                value_range=(-1, 1))
        del video
        return {
            "message": "Video generated successfully",
            "elapsed_time": f"{elapsed_time:.2f} sec",
            "output": request.save_disk_path
        }

    def _build_rainfusion_config(self, sparsity, skip_timesteps, rainfusion_type):
        """Return the rainfusion settings dict attached to the transformers."""
        return {
            "sparsity": sparsity,
            "skip_timesteps": skip_timesteps,
            "grid_size": None,
            "atten_mask_all": None,
            "type": rainfusion_type
        }

    def _apply_rainfusion_config(self, pipe, dit_fsdp, rainfusion_config):
        """Attach ``rainfusion_config`` to both noise models of ``pipe``.

        With ``dit_fsdp`` the attribute is set on the FSDP-wrapped inner
        module; otherwise it is set on the model object itself.
        """
        if dit_fsdp:
            pipe.low_noise_model._fsdp_wrapped_module.rainfusion_config = rainfusion_config
            pipe.high_noise_model._fsdp_wrapped_module.rainfusion_config = rainfusion_config
        else:
            pipe.low_noise_model.rainfusion_config = rainfusion_config
            pipe.high_noise_model.rainfusion_config = rainfusion_config

    def _build_cache_agent(self, transformer, args):
        """Create the attention-cache agent for one transformer.

        The step window (start / interval / end) is passed only when
        ``args.use_attentioncache`` is set; otherwise the CacheConfig
        defaults apply.
        """
        cache_kwargs = {
            "method": ATTENTION_CACHE_METHOD,
            "blocks_count": len(transformer.blocks),
            "steps_count": args.sample_steps,
        }
        if args.use_attentioncache:
            cache_kwargs.update(
                step_start=args.start_step,
                step_interval=args.attentioncache_interval,
                step_end=args.end_step)
        return CacheAgent(CacheConfig(**cache_kwargs))

    def _attach_cache(self, transformer, cache, args, dit_fsdp):
        """Give every block of ``transformer`` the shared ``cache`` and ``args``."""
        if dit_fsdp:
            for block in transformer._fsdp_wrapped_module.blocks:
                block._fsdp_wrapped_module.cache = cache
                block._fsdp_wrapped_module.args = args
        else:
            for block in transformer.blocks:
                block.cache = cache
                block.args = args

    def _create_pipeline(self, pipeline_cls, args, cfg):
        """Instantiate ``pipeline_cls`` with the constructor arguments shared by T2V and I2V."""
        return pipeline_cls(
            config=cfg,
            checkpoint_dir=args.ckpt_dir,
            # Local device index is always 0 here.
            # NOTE(review): presumably each actor sees a single NPU — confirm.
            device_id=0,
            rank=self.rank,
            t5_fsdp=args.t5_fsdp,
            dit_fsdp=args.dit_fsdp,
            use_sp=(args.ulysses_size > 1 or args.ring_size > 1),
            t5_cpu=args.t5_cpu,
            convert_model_dtype=args.convert_model_dtype,
            use_vae_parallel=args.vae_parallel,
        )

    def _init_common_pipeline(self, args, pipe, rainfusion_config):
        """Apply the setup shared by both pipelines.

        Attaches the rainfusion config (if enabled), records the rainfusion
        settings on ``self`` for later requests, applies tensor parallelism
        (if ``tp_size > 1``) and installs an attention-cache agent on every
        block of both noise models.
        """
        transformer_low = pipe.low_noise_model
        transformer_high = pipe.high_noise_model
        if args.use_rainfusion:
            self._apply_rainfusion_config(pipe, args.dit_fsdp, rainfusion_config)
        # Kept so ``generate`` can rebuild the same config per request.
        self.use_rainfusion = args.use_rainfusion
        self.sparsity = args.sparsity
        self.sparse_start_step = args.sparse_start_step
        self.rainfusion_type = rainfusion_config["type"]

        if args.tp_size > 1:
            logging.info("Initializing Tensor Parallel ...")
            applicator = TensorParallelApplicator(args.tp_size, device_map="cpu")
            applicator.apply_to_model(transformer_low)
            applicator.apply_to_model(transformer_high)

        cache_high = self._build_cache_agent(transformer_high, args)
        cache_low = self._build_cache_agent(transformer_low, args)
        self._attach_cache(transformer_high, cache_high, args, args.dit_fsdp)
        self._attach_cache(transformer_low, cache_low, args, args.dit_fsdp)

    def _init_t2v_pipeline(self, args, cfg, rainfusion_config):
        """Build the text-to-video pipeline and run a two-step warm-up.

        The warm-up uses a fixed 1280*720, 81-frame, seed-0 request with the
        'unipc' solver rather than the values in ``args``.
        """
        logging.info("Creating WanT2V pipeline.")
        self.pipe = self._create_pipeline(wan.WanT2V, args, cfg)
        self._init_common_pipeline(args, self.pipe, rainfusion_config)
        logging.info("Warm up 2 steps ...")
        self.pipe.generate(
            EXAMPLE_PROMPT["t2v-A14B"]["prompt"],
            size=SIZE_CONFIGS["1280*720"],
            frame_num=81,
            shift=None,
            sample_solver='unipc',
            sampling_steps=2,
            guide_scale=args.sample_guide_scale,
            seed=0,
            offload_model=args.offload_model)
        logging.info("T2V warmup finished.")

    def _init_i2v_pipeline(self, args, cfg, rainfusion_config):
        """Build the image-to-video pipeline and run a two-step warm-up.

        The warm-up uses the bundled example image and prompt together with
        the size, frame count, shift, solver and seed from ``args``.
        """
        logging.info("Creating WanI2V pipeline.")
        self.pipe = self._create_pipeline(wan.WanI2V, args, cfg)
        self._init_common_pipeline(args, self.pipe, rainfusion_config)
        logging.info("Warm up 2 steps ...")
        img = Image.open(EXAMPLE_PROMPT["i2v-A14B"]["image"]).convert("RGB")
        self.pipe.generate(
            EXAMPLE_PROMPT["i2v-A14B"]["prompt"],
            img,
            max_area=MAX_AREA_CONFIGS[args.size],
            frame_num=args.frame_num,
            shift=args.sample_shift,
            sample_solver=args.sample_solver,
            sampling_steps=2,
            guide_scale=args.sample_guide_scale,
            seed=args.base_seed,
            offload_model=args.offload_model)
        logging.info("I2V warmup finished.")

    def _init_parallel_env(self, args):
        """Join the process group, set up model parallelism, return the model config.

        Args:
            args: Parsed generation arguments. ``args.dit_fsdp`` is forced to
                ``False`` when tensor parallelism is enabled.

        Returns:
            The ``WAN_CONFIGS`` entry for ``args.task``.

        Raises:
            ValueError: If the parallel degrees do not multiply to the world
                size, or the head count is not divisible by ``ulysses_size``.
        """
        if self.world_size > 1:
            torch.npu.set_device(0)
            dist.init_process_group(
                backend="hccl",
                init_method="env://",
                rank=self.rank,
                world_size=self.world_size)

        need_parallel_env = (
            args.cfg_size > 1 or
            args.ulysses_size > 1 or
            args.ring_size > 1 or
            args.tp_size > 1
        )
        if need_parallel_env:
            product = args.cfg_size * args.ulysses_size * args.ring_size * args.tp_size
            if product != self.world_size:
                # Adjacent literals instead of a backslash continuation, which
                # would embed the next line's indentation in the message.
                raise ValueError(
                    f"The number of cfg_size, ulysses_size, ring_size and tp_size should be equal to the world size. "
                    f"Got {args.cfg_size} * {args.ulysses_size} * {args.ring_size} * {args.tp_size} = {product}, "
                    f"expected {self.world_size}"
                )
            sp_degree = args.ulysses_size * args.ring_size
            parallel_config = ParallelConfig(
                sp_degree=sp_degree,
                ulysses_degree=args.ulysses_size,
                ring_degree=args.ring_size,
                tp_degree=args.tp_size,
                use_cfg_parallel=(args.cfg_size == 2),
                world_size=self.world_size,
            )
            init_parallel_env(parallel_config)

        # Tensor parallelism takes precedence over FSDP for the DiT.
        if args.tp_size > 1 and args.dit_fsdp:
            logging.info("DiT using Tensor Parallel, disabled dit_fsdp")
            args.dit_fsdp = False
        self.dit_fsdp = args.dit_fsdp

        cfg = WAN_CONFIGS[args.task]
        if args.ulysses_size > 1:
            if cfg.num_heads % args.ulysses_size != 0:
                raise ValueError(f"`{cfg.num_heads}` cannot be divided evenly by `{args.ulysses_size}`")
        logging.info(f"Generation job args: {args}")
        logging.info(f"Generation model config: {cfg}")
        return cfg
def _parse_args():
    """Return the result of ``generate._parse_args()``.

    The ``generate`` module is imported when this function is called rather
    than when this file is imported.
    """
    from generate import _parse_args as parse_cli_args

    parsed_args = parse_cli_args()
    return parsed_args
def _validate_args(request: GeneratorRequest):
    """Run ``generate._validate_args`` on ``request``.

    The ``generate`` module is imported at call time. Nothing is returned;
    whatever the delegate raises propagates to the caller.
    """
    from generate import _validate_args as validate_request

    validate_request(request)