-elif args.task == "snake":
- task = tasks.Snake(
- nb_train_samples=args.nb_train_samples,
- nb_test_samples=args.nb_test_samples,
- batch_size=args.batch_size,
- height=args.snake_height,
- width=args.snake_width,
- nb_colors=args.snake_nb_colors,
- length=args.snake_length,
- prompt_length=args.snake_length // 2,
- device=device,
+if args.max_percents_of_test_in_train >= 0:
+
+    def subsets_as_tuples(batches, cs):
+        """Yield the samples of `batches` grouped into sets of distinct tuples.
+
+        Every sample `x` of every batch is turned into a tuple of Python
+        scalars (by calling `.item()` on each of its elements) so that it
+        can be stored in a set.  Samples are accumulated in a set that is
+        yielded as soon as it holds `cs` distinct tuples; whatever remains
+        when `batches` is exhausted is yielded last.
+
+        Args:
+            batches: iterable of batches, each an iterable of samples whose
+                elements expose `.item()` (presumably 1-D torch tensors --
+                TODO confirm against `quizz_machine.batches`).
+            cs: number of distinct tuples per yielded set.
+
+        Yields:
+            Non-empty sets of tuples; all but the last have exactly `cs`
+            elements.
+        """
+        s = set()
+        for batch in batches:
+            for x in batch:
+                s.add(tuple(v.item() for v in x))
+                if len(s) == cs:
+                    yield s
+                    s = set()
+        # Do not emit an empty remainder (input exhausted right after a full
+        # chunk, or no input at all): it carries no samples and would only
+        # make the consumer do a useless pass for it.
+        if s:
+            yield s
+
+    # Count how many test samples also appear in the train split.  Both
+    # splits are consumed in chunks of 25000 distinct samples, and the train
+    # split is re-read from scratch for every test chunk (presumably to keep
+    # memory bounded -- only one chunk of each split is held at a time).
+    nb_test = 0
+    nb_in_train = 0
+    for test_chunk in subsets_as_tuples(
+        quizz_machine.batches(split="test", desc="test-check"), 25000
+    ):
+        # Test samples of this chunk found in at least one train chunk.
+        leaked = set()
+        for train_chunk in subsets_as_tuples(
+            quizz_machine.batches(split="train", desc="train-check"), 25000
+        ):
+            leaked |= test_chunk & train_chunk
+        nb_test += len(test_chunk)
+        nb_in_train += len(leaked)
+
+ log_string(
+ f"data_check {nb_in_train*100/nb_test:.02f}% ({nb_in_train}/{nb_test}) of test samples are in the train set"