"--task",
type=str,
-default="sandbox",
+default="byheart",
- help="sandbox, picoclvr, mnist, maze, snake, stack, expr, rpl, world",
+ help="byheart, learnop, guessop, twotargets, addition, picoclvr, mnist, maze, snake, stack, expr, rpl",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
parser.add_argument("--rpl_no_prog", action="store_true", default=False)
-##############################
-# sandbox options
-
-parser.add_argument("--sandbox_level", type=int, default=0)
-
-parser.add_argument("--sandbox_levels_nb_items", type=int, default=25)
-
-parser.add_argument("--sandbox_levels_len_source", type=int, default=6)
-
-parser.add_argument("--sandbox_levels_len_result", type=int, default=8)
-
##############################
# picoclvr options
######################################################################
-if args.task == "sandbox":
- if args.sandbox_level == 0:
- problem = problems.ProblemLevel0(
- nb_sentences=args.sandbox_levels_nb_items,
- len_prompt=args.sandbox_levels_len_source,
- len_result=args.sandbox_levels_len_result,
- )
- elif args.sandbox_level == 1:
- problem = problems.ProblemLevel1(
- nb_operators=args.sandbox_levels_nb_items,
- len_source=args.sandbox_levels_len_source,
- len_result=args.sandbox_levels_len_result,
- )
- elif args.sandbox_level == 2:
- problem = problems.ProblemLevel2(
- len_source=args.sandbox_levels_len_source,
- len_result=args.sandbox_levels_len_result,
- )
- else:
- raise ValueError(f"Unknown sandbox level {args.sandbox_level}")
+if args.task == "byheart":
+ task = tasks.SandBox(
+ problem=problems.ProblemByHeart(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+
+elif args.task == "learnop":
+ task = tasks.SandBox(
+ problem=problems.ProblemLearnOperator(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+
+elif args.task == "guessop":
+ task = tasks.SandBox(
+ problem=problems.ProblemGuessOperator(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+
+elif args.task == "twotargets":
+ task = tasks.SandBox(
+ problem=problems.ProblemTwoTargets(),
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+elif args.task == "addition":
task = tasks.SandBox(
- # problem,
- # problems.ProblemAddition(zero_padded=False, inverted_result=False),
- # problems.ProblemLenId(len_max=args.sandbox_levels_len_source),
- problems.ProblemTwoTargets(len_total=16, len_targets=4),
+ problem=problems.ProblemAddition(),
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
####################
-class ProblemLenId(Problem):
- def __init__(self, len_max=10):
- self.len_max = len_max
-
- def generate_sequences(self, nb):
- k = torch.arange(self.len_max * 3 + 3)[None, :]
- l = torch.randint(self.len_max, (2, nb))[:, :, None] + 1
- i = torch.randint(10, (2, nb))[:, :, None]
- a = l[0]
- b = l[0] + 1 + l[1]
- c = l[0] + 1 + l[1] + 1 + l[0]
- sequences = (
- (k < a) * i[0]
- + (k == a) * 10
- + (k > a) * (k < b) * i[1]
- + (k == b) * 11
- + (k > b) * (k < c) * i[1]
- + (k >= c) * 12
- )
- ar_mask = (sequences == 11).long()
- ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1)
- return sequences, ar_mask
-
- def seq2str(self, seq):
- return "".join("0123456789|>_"[x.item()] for x in seq)
-
-
-####################
-
-
-class ProblemLevel0(Problem):
- def __init__(self, nb_sentences=100, len_prompt=5, len_result=5):
+class ProblemByHeart(Problem):
+ def __init__(self, nb_sentences=100, len_prompt=8, len_result=8):
self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result))
self.seq[:, len_prompt] = 10
####################
-class ProblemLevel1(Problem):
+class ProblemLearnOperator(Problem):
def __init__(self, nb_operators=100, len_source=5, len_result=8):
self.len_source = len_source
self.len_result = len_result
// 10 ** torch.arange(self.len_nb_operator - 1, -1, -1)
) % 10
marker1 = torch.full((nb, 1), 10)
- # source = torch.randint(10, (nb, self.len_source))
source = torch.rand(nb, 10).sort(dim=1).indices[:, : self.len_source]
marker2 = torch.full((nb, 1), 11)
result = operators.bmm(source[:, :, None]).squeeze(-1)
####################
-class ProblemLevel2(Problem):
+class ProblemGuessOperator(Problem):
def __init__(self, len_source=5, len_result=8):
self.len_source = len_source
self.len_result = len_result
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
if save_attention_image is None:
logger("no save_attention_image (is pycairo installed?)")
else:
f"property_{prefix}miss {n_epoch} {100*nb_missing_properties/nb_requested_properties:.02f}%"
)
+ logger(
+ f"main_test_accuracy {n_epoch} {1-nb_missing_properties/nb_requested_properties}"
+ )
+
######################################################################
def produce_results(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
if count is not None:
proportion_optimal = count.diagonal().sum().float() / count.sum()
logger(f"proportion_optimal_test {proportion_optimal*100:.02f}%")
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
######################################################################
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
##############################################################
# Log a few generated sequences
input = self.test_input[:10, : 12 * (1 + self.nb_digits)]
f"accuracy_prog_test {n_epoch} nb_total {test_nb_total} nb_errors {test_nb_errors} accuracy {100.0*(1-test_nb_errors/test_nb_total):.02f}%"
)
+ logger(f"main_test_accuracy {n_epoch} {1-test_nb_errors/test_nb_total}")
+
test_nb_total, test_nb_errors = compute_nb_errors_output(
self.test_input[:1000].to(self.device), nb_to_log=10
)
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
)
+ logger(f"main_test_accuracy {n_epoch} {test_nb_correct/test_nb_total}")
+
nb_total = test_nb_delta.sum() + test_nb_missed
for d in range(test_nb_delta.size(0)):
logger(