@@ -78,7 +78,7 @@ network_G:
#### path
path:
- pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
+# pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
strict_load: true
resume_state: auto
@@ -47,10 +47,10 @@ class SRFlowModel(BaseModel):
# define network and load pretrained models
self.netG = networks.define_Flow(opt, step).to(self.device)
- if opt['dist']:
- self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
- else:
- self.netG = DataParallel(self.netG)
+ # if opt['dist']:
+ # self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
+ # else:
+ # self.netG = DataParallel(self.netG)
# print network
self.print_network()
@@ -34,13 +34,13 @@ class InvertibleConv1x1(nn.Module):
def get_weight(self, input, reverse):
w_shape = self.w_shape
pixels = thops.pixels(input)
- dlogdet = torch.slogdet(self.weight)[1] * pixels
+ dlogdet = torch.log(abs(torch.det(self.weight))) * pixels
if not reverse:
weight = self.weight.view(w_shape[0], w_shape[1], 1, 1)
else:
- weight = torch.inverse(self.weight.double()).float() \
- .view(w_shape[0], w_shape[1], 1, 1)
+ weight = self.weight.float().view(w_shape[0], w_shape[1], 1, 1)
return weight, dlogdet
+
def forward(self, input, logdet=None, reverse=False):
"""
log-det = log|abs(|W|)| * pixels
@@ -57,7 +57,7 @@ class SRFlowNet(nn.Module):
return True
return False
- def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
+ def forward(self, lr=None, gt=None, z=None, eps_std=0.9, reverse=True, epses=None, reverse_with_grad=False,
lr_enc=None,
add_gt_noise=False, step=None, y_label=None):
if not reverse:
@@ -143,6 +143,8 @@ class SRFlowNet(nn.Module):
return -score_real
def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
+ if z == None:
+ z = thops.normal(mean=0, std=eps_std, size=(lr.size(0), 192, lr.size(2)//2, lr.size(3)//2)).to(lr.device)
logdet = torch.zeros_like(lr[:, 0, 0, 0])
pixels = thops.pixels(lr) * self.opt['scale'] ** 2
@@ -51,7 +51,7 @@ class Conv2d(nn.Conv2d):
super().__init__(in_channels, out_channels, kernel_size, stride,
padding, bias=(not do_actnorm))
# init weight with std
- self.weight.data.normal_(mean=0.0, std=weight_std)
+ # self.weight.data.normal_(mean=0.0, std=weight_std)
if not do_actnorm:
self.bias.data.zero_()
else:
@@ -114,8 +114,8 @@ class GaussianDiag:
def sample_eps(shape, eps_std, seed=None):
if seed is not None:
torch.manual_seed(seed)
- eps = torch.normal(mean=torch.zeros(shape),
- std=torch.ones(shape) * eps_std)
+ eps = thops.normal(mean=0.,
+ std=1. * eps_std,size =shape)
return eps
@@ -15,6 +15,10 @@
# This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
import torch
+import numpy as np
+def normal(size=None,mean=0.0, std=1):
+ a = np.random.normal(loc=mean, scale=std, size=size).astype(np.float32)
+ return torch.from_numpy(a)
def sum(tensor, dim=None, keepdim=False):
--
2.29.2.windows.2