projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
main.py
diff --git
a/main.py
b/main.py
index
305bd3c
..
69ee58f
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-34,8
+34,8
@@
parser = argparse.ArgumentParser(
parser.add_argument(
"--task",
type=str,
parser.add_argument(
"--task",
type=str,
- default="
picoclvr
",
- help="picoclvr, mnist, maze, snake, stack, expr, world",
+ default="
sandbox
",
+ help="
sandbox,
picoclvr, mnist, maze, snake, stack, expr, world",
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
)
parser.add_argument("--log_filename", type=str, default="train.log", help=" ")
@@
-150,6
+150,12
@@
if args.result_dir is None:
######################################################################
default_args = {
######################################################################
default_args = {
+ "sandbox": {
+ "nb_epochs": 10,
+ "batch_size": 25,
+ "nb_train_samples": 25000,
+ "nb_test_samples": 10000,
+ },
"picoclvr": {
"nb_epochs": 25,
"batch_size": 25,
"picoclvr": {
"nb_epochs": 25,
"batch_size": 25,
@@
-189,7
+195,7
@@
default_args = {
"world": {
"nb_epochs": 10,
"batch_size": 25,
"world": {
"nb_epochs": 10,
"batch_size": 25,
- "nb_train_samples":
1
25000,
+ "nb_train_samples": 25000,
"nb_test_samples": 1000,
},
}
"nb_test_samples": 1000,
},
}
@@
-257,7
+263,16
@@
picoclvr_pruner_eval = (
######################################################################
######################################################################
-if args.task == "picoclvr":
+if args.task == "sandbox":
+ task = tasks.SandBox(
+ nb_train_samples=args.nb_train_samples,
+ nb_test_samples=args.nb_test_samples,
+ batch_size=args.batch_size,
+ logger=log_string,
+ device=device,
+ )
+
+elif args.task == "picoclvr":
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
task = tasks.PicoCLVR(
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,