from typing import Optional, List
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from tensorfn.config import (
MainConfig,
Config,
Optimizer,
Scheduler,
DataLoader,
Instance,
)
import diffusion
import model
from models.unet_autoenc import BeatGANsAutoencConfig
class Diffusion(Config):
    """Config section describing the diffusion process."""

    # Declared as a tensorfn ``Instance`` spec.
    # NOTE(review): presumably resolved into the beta-schedule object by
    # tensorfn at load time -- confirm against the config files.
    beta_schedule: Instance
class Dataset(Config):
    """Config section identifying a dataset.

    All fields use pydantic strict types, so values are not coerced from
    other types during validation.
    """

    # Dataset identifier (strict string).
    name: StrictStr

    # Location of the dataset (strict string).
    path: StrictStr

    # Resolution value (strict integer).
    # NOTE(review): presumably the image side length in pixels -- confirm.
    resolution: StrictInt
class Training(Config):
    """Config section grouping the training-time settings."""

    # Checkpoint path (strict string).
    ckpt_path: StrictStr

    # tensorfn optimizer specification.
    optimizer: Optimizer

    # tensorfn scheduler specification; typed as optional, so ``None`` is an
    # accepted value.
    scheduler: Optional[Scheduler]

    # tensorfn data-loader specification.
    dataloader: DataLoader
class Eval(Config):
    """Config section for logging / checkpointing / validation cadence.

    NOTE(review): the unit of the ``*_every`` intervals (steps vs. epochs)
    is not visible in this module -- confirm against the training loop.
    """

    # Boolean switch named after Weights & Biases (strict bool).
    wandb: StrictBool

    # Interval between saves (strict integer).
    save_every: StrictInt

    # Interval between validation runs (strict integer).
    valid_every: StrictInt

    # Interval between log emissions (strict integer).
    log_every: StrictInt
class DiffusionConfig(MainConfig):
    """Top-level configuration: a diffusion section plus a training section.

    NOTE(review): the ``Dataset`` and ``Eval`` sections declared in this
    module are not fields of this class -- confirm whether that is intended.
    """

    # Diffusion-process settings.
    diffusion: Diffusion

    # Training settings (checkpoint path, optimizer, scheduler, dataloader).
    training: Training
def get_model_conf():
    """Build the hard-coded ``BeatGANsAutoencConfig`` for this project.

    Takes no arguments: every hyperparameter is fixed in the body below.

    Returns:
        A ``BeatGANsAutoencConfig`` constructed with the listed keyword
        arguments.
    """
    # Options without the ``enc_`` prefix.
    base_kwargs = dict(
        image_size=256,
        # NOTE(review): presumably 3 image channels plus 20 conditioning
        # channels -- confirm what the extra 20 channels carry.
        in_channels=3 + 20,
        model_channels=128,
        # NOTE(review): twice the 3 image channels; presumably the model
        # predicts two quantities per channel -- confirm.
        out_channels=3 * 2,
        num_res_blocks=2,
        num_input_res_blocks=None,
        embed_channels=512,
        attention_resolutions=(32, 16, 8),
        time_embed_channels=None,
        dropout=0.1,
        channel_mult=(1, 1, 2, 2, 4, 4),
        input_channel_mult=None,
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        resblock_updown=True,
        use_new_attention_order=False,
        resnet_two_cond=True,
        resnet_cond_channels=None,
        resnet_use_zero_module=True,
        attn_checkpoint=False,
    )

    # Options carrying the ``enc_`` prefix.
    enc_kwargs = dict(
        enc_out_channels=512,
        enc_attn_resolutions=None,
        enc_pool='adaptivenonzero',
        enc_num_res_block=2,
        enc_channel_mult=(1, 1, 2, 2, 4, 4, 4),
        enc_grad_checkpoint=False,
    )

    return BeatGANsAutoencConfig(
        **base_kwargs,
        **enc_kwargs,
        latent_net_conf=None,
    )