parser.add_argument("--attention", type=str, default=None)

# Memex augmentation: per-batch probability of also emitting a doubled
# <batch><sep><batch> sequence (see add_memex), and for how many initial
# epochs the augmentation stays active.
# NOTE: default was the int 0; made 0.0 so args.memex_proba is a float
# whether or not the flag is given, matching the other float options.
parser.add_argument("--memex_proba", type=float, default=0.0)

# Float type presumably allows a fractional number of epochs -- TODO confirm.
parser.add_argument("--memex_nb_epochs", type=float, default=1)

parser.add_argument("--dim_model", type=int, default=None)

parser.add_argument("--dim_keys", type=int, default=None)

parser.add_argument("--caterpillar_height", type=int, default=None)

# Gate-dropout options; a value of 0.0 disables the mechanism.
parser.add_argument("--gate_dropout_proba", type=float, default=0.0)

parser.add_argument("--gate_dropout_sync", type=str2bool, default=False)

parser.add_argument("--gate_dropout_replace", type=str2bool, default=False)

# Weight of the model's inner loss in the total loss; 0.0 skips the term.
parser.add_argument("--rho_inner_loss", type=float, default=0.0)

parser.add_argument("--nb_blocks", type=int, default=None)

# Grid-task geometry and content.
parser.add_argument("--grid_size", type=int, default=6)

parser.add_argument("--grid_nb_colors", type=int, default=6)

parser.add_argument("--grid_nb_shapes", type=int, default=6)
##############################
# picoclvr options
log_string(f"sha256sum {l.strip()}")
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar --ignore-failed-read zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
log_string(f"argv {' '.join(sys.argv)}")
# Record the full configuration in the log: every argparse value first,
# then every entry of the supplementary-arguments dict.
for name in vars(args):
    log_string(f"args.{name} {getattr(args, name)}")

for key in sup_args:
    log_string(f'sup_args["{key}"] "{sup_args[key]}"')
######################################################################
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
size=args.grid_size,
+ nb_shapes=args.grid_nb_shapes,
+ nb_colors=args.grid_nb_colors,
logger=log_string,
device=device_data,
)
vocabulary_size = task.vocabulary_size()

# Reserve one extra token id when the memex augmentation is enabled; that
# id (vocabulary_size - 1) is used as the separator token by add_memex.
if args.memex_proba > 0:
    vocabulary_size += 1

log_string(f"vocabulary_size {vocabulary_size}")
##############################
dropout=args.dropout,
attention_layer=args.attention,
logger=log_string,
- **sup_args,
+ args=args,
)
model.to(device)
##############################
if "calibrate" in sup_args:
    # Calibration is a standalone run: one full pass over the training
    # data so that any Calibrator modules inside the model accumulate
    # their statistics. The forward output itself is discarded.
    for input in task.batches(split="train", desc="calibrate"):
        input = input.to(device)
        output = model(mygpt.BracketedSequence(input)).x

    # Dump the accumulated moments of every Calibrator attribute found on
    # any module of the model.
    for n, m in model.named_modules():
        for a in dir(m):
            x = getattr(m, a)
            if isinstance(x, mygpt.Calibrator):
                # BUG FIX: the original printed "${n} | ${a}" -- in a
                # Python f-string the "$" are literal characters, almost
                # certainly a shell-style interpolation slip.
                print(f"####### {n} | {a} ########################")
                mean, std = x.moments()
                print("mean\n", mean, "\n")
                print("std\n", std, "\n")
                # No placeholders here, so no f-prefix needed.
                print("############################################\n\n")

    exit(0)

##############################
nb_samples_seen = 0
if nb_epochs_finished >= nb_epochs:
nb_train_samples, acc_train_loss, acc_train_inner_loss = 0, 0.0, 0.0
- for input in task.batches(split="train"):
def add_memex(batches, memex_proba, sep_value=None):
    """Yield every batch of `batches`, randomly interleaving memex variants.

    With probability `memex_proba` (drawn once per batch), a doubled
    sequence <batch> <sep> <batch> is yielded first, where <sep> is one
    separator token per row; the unmodified batch is always yielded
    afterward, so no training data is ever dropped.

    `sep_value` defaults to `vocabulary_size - 1`, the extra token id
    reserved when --memex_proba > 0; it is exposed as a parameter so the
    generator can be used (and tested) with an explicit separator id.
    """
    for input in batches:
        if torch.rand(1).item() < memex_proba:
            fill = vocabulary_size - 1 if sep_value is None else sep_value
            # One separator column, on the same device as the batch.
            sep = torch.full((input.size(0), 1), fill, device=input.device)
            yield torch.cat([input, sep, input], dim=1)
        yield input
+
+ train_batches = add_memex(
+ task.batches(split="train"),
+ args.memex_proba if n_epoch < args.memex_nb_epochs else 0.0,
+ )
+
+ for input in train_batches:
model.reset_inner_loss()
input = input.to(device)
nb_train_samples += input.size(0)
nb_samples_seen += input.size(0)
- total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
it += 1
lr = get_lr(n_epoch, it)