5 import torch, torchvision
8 from torch.nn import functional as F
# Blanket: custom torch.autograd.Function whose backward pass normalizes a
# clone of the incoming gradient with Blanket.normalize (the forward is
# likewise marked "Normalize the forward").
# NOTE(review): math.sqrt is used below, but `import math` is not among the
# visible imports -- confirm the module imports it.
11 class Blanket(torch.autograd.Function):
# In place: divide each vector along dim 1 by its L2 norm; the 1e-6 term
# guards against division by zero for all-zero vectors.
15 y /= y.pow(2).sum(dim=1, keepdim=True).sqrt() + 1e-6
# In place: rescale by sqrt(numel / size(0)). For a 2-D (N, D) tensor each
# row ends up with L2 norm ~sqrt(D), i.e. mean square ~1 per element.
16 y *= math.sqrt(y.numel() / y.size(0))
21 # Normalize the forward
# Backward pass: normalize a clone of the gradient received from autograd.
26 def backward(ctx, grad_output):
# Clone first, presumably because normalize mutates its argument in place
# (it uses /= and *=) and the incoming gradient tensor should not be modified.
27 grad_output = grad_output.clone()
28 # Normalize the gradient
# NOTE(review): any return value of normalize is discarded here, so this
# relies on normalize modifying grad_output in place. Confirm normalize does
# not rebind y to a new tensor before the in-place /= and *= ops; otherwise
# the gradient is left unnormalized.
29 Blanket.normalize(grad_output)
# Module-level alias so the Function can be invoked as blanket(...) rather
# than Blanket.apply(...).
33 blanket = Blanket.apply
35 ######################################################################
# Script entry: quick gradient check (presumably of Blanket -- confirm).
37 if __name__ == "__main__":
# Random 2x3 input that tracks gradients.
38 x = torch.rand(2, 3).requires_grad_()
# Gradient of z with respect to x.
# NOTE(review): z is not defined in the visible lines; presumably it is
# derived from blanket(x) -- confirm.
42 g = torch.autograd.grad(z, x)[0]