import os
import torch
import torch.nn as nn
from ..utils.config_utils import ConfigMixin
from .model_load_utils import load_state_dict
# Legacy (pickle-based) PyTorch weights file name.
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
# Safetensors weights file name; this is the default looked up by
# DiffusionModel.from_pretrained.
DIFFUSER_SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
class DiffusionModel(nn.Module):
    """Base ``nn.Module`` for models loaded from a local pretrained directory.

    Subclasses set ``config_class`` to a ``ConfigMixin`` subclass describing
    their configuration and may override ``weigths_name`` to point at a
    different weights file inside the model directory.
    """

    # Configuration class used by ``from_pretrained``; must subclass ConfigMixin.
    config_class = ConfigMixin
    # Weights file name looked up inside the model directory.
    # NOTE: the attribute name is misspelled ("weigths"); it is kept as-is
    # because subclasses may already override it under this name.
    weigths_name = DIFFUSER_SAFETENSORS_WEIGHTS_NAME

    def __init__(self, config):
        """Store the configuration object on the module.

        Args:
            config: The configuration instance (an instance of
                ``config_class`` when built through ``from_pretrained``).
        """
        super().__init__()
        self.config = config

    @classmethod
    def from_pretrained(cls, model_path, **kwargs):
        """Build a model from a local directory holding its config and weights.

        Args:
            model_path: Path to a directory containing the model config and
                the ``cls.weigths_name`` weights file.
            **kwargs: ``dtype`` (optional; ``torch.bfloat16`` or
                ``torch.float16``) is consumed here. All remaining keyword
                arguments are forwarded to ``cls.config_class.load_config``.

        Returns:
            A new ``cls`` instance with the weights loaded, cast to ``dtype``
            when one was given.

        Raises:
            ValueError: If ``dtype`` is given but is not bfloat16/float16,
                ``model_path`` is not a directory, ``config_class`` is not a
                ``ConfigMixin`` subclass, ``weigths_name`` is unset, or the
                weights file is missing.
        """
        dtype = kwargs.pop('dtype', None)
        # dtype is optional: None leaves the model in whatever dtype
        # construction and weight loading produced. Previously None was
        # rejected here, which also made the cast guard below unreachable.
        if dtype is not None and dtype not in {torch.bfloat16, torch.float16}:
            raise ValueError("dtype should be a torch.bfloat16 or torch.float16")

        real_path = os.path.abspath(model_path)
        # isdir() already implies exists().
        if not os.path.isdir(real_path):
            raise ValueError(f"{real_path} is invalid!")

        # Guard with isinstance(..., type) so a non-class config_class yields
        # the documented ValueError instead of a TypeError from issubclass().
        config_class = cls.config_class
        if not (isinstance(config_class, type)
                and issubclass(config_class, ConfigMixin)):
            raise ValueError("config_class is not subclass of ConfigMixin.")

        if cls.weigths_name is None:
            raise ValueError("weigths_name is not defined.")
        weights_path = os.path.join(real_path, cls.weigths_name)
        # isfile() already implies exists().
        if not os.path.isfile(weights_path):
            raise ValueError(f"'{cls.weigths_name}' is not found in '{model_path}'!")

        init_dict, _ = config_class.load_config(real_path, **kwargs)
        config = config_class(**init_dict)
        model = cls(config)

        # Module-level loader imported from model_load_utils (not the
        # nn.Module method of the same name).
        state_dict = load_state_dict(weights_path)
        model._load_weights(state_dict)

        if dtype is not None:
            # nn.Module.to casts parameters/buffers in place.
            model.to(dtype)
        return model

    def _load_weights(self, state_dict):
        """Copy ``state_dict`` into this module's parameters and buffers.

        Delegates to ``nn.Module.load_state_dict`` with its default (strict)
        key matching, under ``torch.no_grad()``.
        """
        with torch.no_grad():
            self.load_state_dict(state_dict)