projects
/
picoclvr.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[picoclvr.git]
/
main.py
diff --git
a/main.py
b/main.py
index
55f2c2f
..
9198edc
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-89,7
+89,9
@@
parser.add_argument("--checkpoint_name", type=str, default="checkpoint.pth")
##############################
# filetask
##############################
# filetask
-parser.add_argument("--filetask_file", type=str, default=None)
+parser.add_argument("--filetask_train_file", type=str, default=None)
+
+parser.add_argument("--filetask_test_file", type=str, default=None)
##############################
# rpl options
##############################
# rpl options
@@
-403,10
+405,11
@@
picoclvr_pruner_eval = (
if args.task == "file":
assert (
if args.task == "file":
assert (
- args.filetask_file is not None
- ), "You have to specify the task
file with --filetask_file <filename>
"
+ args.filetask_
train_file is not None and args.filetask_test_
file is not None
+ ), "You have to specify the task
train and test files
"
task = tasks.TaskFromFile(
task = tasks.TaskFromFile(
- args.filetask_file,
+ args.filetask_train_file,
+ args.filetask_test_file,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,
nb_train_samples=args.nb_train_samples,
nb_test_samples=args.nb_test_samples,
batch_size=args.batch_size,