+
def the_dot_products(value1, value2, params):
    """Compute the three gradient dot products between two scalar objectives.

    Differentiates ``value1`` and ``value2`` with respect to every tensor in
    ``params`` and accumulates, over all parameters, the inner products
    <g1, g1>, <g1, g2>, <g2, g2> of the two flattened gradients.

    Args:
        value1: scalar tensor, first objective (must be on the autograd graph
            of every tensor in ``params``).
        value2: scalar tensor, second objective (same requirement).
        params: iterable of tensors with ``requires_grad=True``.

    Returns:
        A 1-D tensor of shape (3,): ``[<g1,g1>, <g1,g2>, <g2,g2>]``.

    Note:
        The graph is retained (``retain_graph=True``), so callers may
        backpropagate through ``value1``/``value2`` again afterwards.
    """
    params = list(params)
    # One backward traversal per objective instead of one per parameter:
    # torch.autograd.grad accepts a sequence of inputs and returns the
    # matching tuple of gradients in a single pass.
    grads1 = torch.autograd.grad(value1, params, retain_graph=True)
    grads2 = torch.autograd.grad(value2, params, retain_graph=True)
    g1g1 = sum((g1 * g1).sum() for g1 in grads1)
    g2g2 = sum((g2 * g2).sum() for g2 in grads2)
    g1g2 = sum((g1 * g2).sum() for g1, g2 in zip(grads1, grads2))
    # stack the three scalars into a (3,) tensor, matching the original
    # cat-of-(1,)-tensors layout.
    return torch.stack([g1g1, g1g2, g2g2])
+
+
def update_ave_grad(value, params, name, eps=1e-3):
    """Update an exponential moving average of grad(value) on each parameter.

    For every tensor ``p`` in ``params``, computes ``g = d(value)/dp`` and
    stores ``(1 - eps) * old + eps * g`` as attribute ``name`` on ``p``
    (starting from 0 on the first call).

    Args:
        value: scalar tensor to differentiate (graph is retained, so this
            can be called repeatedly on the same graph).
        params: iterable of tensors with ``requires_grad=True``.
        name: attribute name under which the running average is stored.
        eps: smoothing factor of the moving average (default 1e-3).
    """
    for p in params:
        g = torch.autograd.grad(value, p, retain_graph=True)[0]
        # getattr with a default replaces the hasattr/getattr double lookup;
        # 0 seeds the average on the first call.
        ag = getattr(p, name, 0)
        setattr(p, name, (1 - eps) * ag + eps * g)
+
+
def norm(params, name):
    """Sum of squared entries of the ``name`` tensor stored on each parameter.

    Note: despite the name, this returns the *squared* L2 norm of the
    concatenation of all ``getattr(p, name)`` tensors, not its square root.
    """
    return sum(getattr(p, name).pow(2).sum() for p in params)
+
+