- A = torch.randn(1, 5, dtype=torch.float64).requires_grad_()
- X = torch.randn(1, 5, 3, dtype=torch.float64).requires_grad_()
- Y0 = torch.randn(1, 3, dtype=torch.float64).requires_grad_()
+ A = torch.randn(N, T, dtype=torch.float64).requires_grad_()
+ X = torch.randn(N, T, D, dtype=torch.float64).requires_grad_()
+ Y0 = torch.randn(N, D, dtype=torch.float64).requires_grad_()