+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+ for label, input in [
+ ("train", self.train_input[:32]),
+ ("test", self.test_input[:32]),
+ ]:
+ output = model(BracketedSequence(input)).x
+ output = output.log_softmax(dim=-1)
+ filename = os.path.join(
+ result_dir, f"stack_with_crossentropy_{n_epoch:04d}_{label}.txt"
+ )
+ with open(filename, "w") as f:
+ for n in range(input.size(0)):
+ s = stack.seq_to_str(
+ input[n], nb_stacks=self.nb_stacks, nb_digits=self.nb_digits
+ )
+ for t, k, w in zip(range(input[n].size(0)), input[n], s.split(" ")):
+ u = (
+ " " * (10 - len(w))
+ + w
+ + " "
+ + str(output[n][t][k].exp().item())
+ + "\n"
+ )
+ f.write(u)
+ f.write("\n")
+ logger(f"wrote {filename}")
+ #!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+