img = img.view((1,) + img.size())
ref_input = 0.5 + 0.5 * (img - img.mean()) / img.std()
mse_loss = torch.nn.MSELoss()
edge_energy = MultiScaleEdgeEnergy()
img = img.view((1,) + img.size())
ref_input = 0.5 + 0.5 * (img - img.mean()) / img.std()
mse_loss = torch.nn.MSELoss()
edge_energy = MultiScaleEdgeEnergy()