######################################################################
-args = parser.parse_args()
+# args = parser.parse_args()
-assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+args, sup_args = parser.parse_known_args()
+
+sup_args = dict([x.removeprefix("--").split("=") for x in sup_args])
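######################################################################
# Illustration only (a minimal, self-contained sketch, never called by the
# program): how parse_known_args() separates the recognized flags from the
# left-over "--key=value" strings that the dict([...]) line above turns into
# a plain string-to-string mapping. The flags used here, --task and
# --default_b_G, are just examples.


def _demo_sup_args_parsing():
    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--task", type=str, default="picoclvr")
    known, extra = p.parse_known_args(["--task=mnist", "--default_b_G=-2.0"])
    extra = dict([x.removeprefix("--").split("=") for x in extra])
    assert known.task == "mnist"
    assert extra == {"default_b_G": "-2.0"}  # the values remain strings
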
if args.result_dir is None:
args.result_dir = f"results_{args.task}_{args.model}"
print(f"result directory {args.result_dir} already exists")
exit(1)
+loss_file = open(os.path.join(args.result_dir, "loss.dat"), "a")
+
log_file = open(os.path.join(args.result_dir, args.log_filename), "a")
if args.seed >= 0:
    torch.manual_seed(args.seed)

for l in os.popen("sha256sum *.py"):
    log_string(f"sha256sum {l.strip()}")
now = time.strftime("%Y%m%d-%H%M%S", time.localtime())
-os.system(f"tar zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
+os.system(f"tar --ignore-failed-read zcvf {args.result_dir}/src-{now}.tgz *.py *.sh")
log_string(f"argv {' '.join(sys.argv)}")
for n in vars(args):
log_string(f"args.{n} {getattr(args, n)}")
+for n in sup_args:
+    log_string(f"sup_args.{n} {sup_args[n]}")
+
######################################################################
######################################################################
+assert args.picocvlr_prune_properties in {"none", "train+eval", "eval"}
+
+
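+# returns False exactly for the properties that mention "green" together
+# with "left" or "right"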
def picoclvr_pruner_horizontal_green(p):
return not ("green" in p and ("left" in p or "right" in p))
causal=True,
dropout=args.dropout,
attention_layer=args.attention,
+ logger=log_string,
+ **sup_args,
)
model.to(device)
it = 0
+n_batch = 0
+
for n_epoch in range(nb_epochs_finished, nb_epochs):
if args.optim == "sgd":
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
total_loss.backward()
optimizer.step()
+ grad_norm = sum([p.grad.pow(2).sum() for p in model.parameters() if p.grad is not None]).sqrt()
+
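+ # one line per batch in loss.dat: "n_epoch n_batch loss grad_norm", with
+ # n_batch a counter that is never reset across epochs, so column 2 can be
+ # used directly as the x axis of a training curve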
+ loss_file.write(f"{n_epoch} {n_batch} {loss.item()} {grad_norm.item()}\n")
+
+ n_batch += 1
+
with torch.autograd.no_grad():
model.eval()
nb_lines,
attention_dropout=0.0,
len_max=1e5,
+ logger=print,
+ **kwargs,
):
super().__init__()
nb_lines,
attention_dropout=0.0,
len_max=1e5,
+ logger=print,
+ **kwargs,
):
super().__init__()
caterpillar_height,
attention_dropout=0.0,
len_max=1e5,
+ logger=print,
+ **kwargs,
):
super().__init__()
self.proba_gate_dropout = 0.0
+ default_b_G = kwargs.get("default_b_G")
+ if default_b_G is None:
+     default_b_G = -math.log(caterpillar_height - 1)
+ else:
+     # values forwarded through sup_args arrive as strings
+     default_b_G = float(default_b_G)
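+ # the default -log(H - 1), with H = caterpillar_height, gives
+ # sigmoid(b_G) = 1 / (1 + (H - 1)) = 1 / H, so every gate starts close to
+ # 1/H and each head's gating summed over the H rows starts close to 1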
+
+ logger(f"default_b_G {default_b_G}")
+
self.w_G = randw(nb_heads, caterpillar_height, dim_model)
- self.b_G = nn.Parameter(
- torch.full(
- (nb_heads, caterpillar_height), -math.log(caterpillar_height - 1)
- )
- )
+ self.b_G = nn.Parameter(torch.full((nb_heads, caterpillar_height), default_b_G))
self.w_K = randw(nb_heads, dim_qk, dim_model)
self.w_V = randw(nb_heads, dim_v, dim_model)
torch.einsum("ntc,hrc->nhrt", X, self.w_G) + self.b_G[None, :, :, None]
).sigmoid()
+ # Clip the gating to avoid values greater than 1 when several
+ # heads hit the same row
+
+ G = G / G.sum(1, keepdim=True).clamp(min=1)
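+ # e.g. if two heads open the same row at the same time step with gates
+ # 0.8 and 0.6, the sum 1.4 exceeds 1 and both are rescaled to 0.8 / 1.4
+ # and 0.6 / 1.4, while rows whose summed gating is already <= 1 are left
+ # untouched by the clamp(min=1)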
+
######################################################################
# Roll the gating indexes
- warnings.warn("rotating barrel", RuntimeWarning)
+ # warnings.warn("rotating barrel", RuntimeWarning)
- r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
- t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
- r_barrel = (r_barrel + (t_barrel + t0) // L) % R
- G = G.gather(dim=2, index=r_barrel.expand_as(G))
+ # r_barrel = torch.arange(R, device=G.device)[None, None, :, None]
+ # t_barrel = torch.arange(t1 - t0, device=G.device)[None, None, None, :]
+ # r_barrel = (r_barrel + (t_barrel + t0) // L) % R
+ # G = G.gather(dim=2, index=r_barrel.expand_as(G))
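+ # (when enabled, the block above rotates which of the R rows each gate
+ # value is routed to, shifting the row index by one every L time steps,
+ # r <- (r + (t + t0) // L) % R, through a gather along dim 2)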
######################################################################
# The "flashbacks"
# We prepare the arguments for the parallel scan
- # Clip the gating to avoid values greater than 1 when several
- # heads hit the same row
-
- G = G / G.sum(1, keepdim=True).clamp(min=1)
-
A = 1 - G.sum(1)
# warnings.warn("harmonic recurrence", RuntimeWarning)
nb_heads=1,
causal=False,
attention_dropout=0.0,
+ logger=print,
+ **kwargs,
):
super().__init__()
dropout=0.0,
len_max=1e5,
attention_layer="kvrec",
+ logger=print,
+ **kwargs,
):
super().__init__()
nb_heads=nb_heads,
causal=causal,
attention_dropout=dropout,
+ logger=logger,
+ **kwargs,
)
elif attention_layer == "dumbrec":
return DumbRec(
nb_heads=nb_heads,
nb_lines=nb_lines,
attention_dropout=dropout,
+ logger=logger,
+ **kwargs,
)
elif attention_layer == "kvrec":
return KVRec(
nb_heads=nb_heads,
nb_lines=nb_lines,
attention_dropout=dropout,
+ logger=logger,
+ **kwargs,
)
elif attention_layer == "caterpillar":
return Caterpillar(
caterpillar_length=self.caterpillar_length,
caterpillar_height=self.caterpillar_height,
attention_dropout=dropout,
+ logger=logger,
+ **kwargs,
)
else:
raise ValueError(f"Unknown attention type {attention_layer}.")