- for input, targets in task.policy_batches(split="train"):
- output_gpt = gpt(mygpt.BracketedSequence(input), with_readout=False).x
+ for input, policies in task.policy_batches(split="train"):
+ ####
+ # print(f'{input.size()=} {policies.size()=}')
+ # s = maze.stationary_densities(
+ # exit(0)
+ ####
+ mask = input.unsqueeze(-1) == maze.v_empty
+ output_gpt = gpt(mygpt.BracketedSequence(input), mode=args.oneshot_mode).x