+
+ if args.oneshot_input == "head":
+ dim_in = args.dim_model
+ elif args.oneshot_input == "deep":
+ dim_in = args.dim_model * args.nb_blocks * 2
+ else:
+ raise ValueError(f"{args.oneshot_input=}")
+
+ if args.oneshot_output == "policy":
+ dim_out = 4
+ compute_loss = oneshot_policy_loss
+ elif args.oneshot_output == "trace":
+ dim_out = 1
+ else:
+ raise ValueError(f"{args.oneshot_output=}")
+