projects
/
mygpt.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[mygpt.git]
/
main.py
diff --git
a/main.py
b/main.py
index
85cf4cf
..
11cf0a3
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-25,7
+25,7
@@
parser.add_argument('--log_filename',
type = str, default = 'train.log')
parser.add_argument('--download',
type = str, default = 'train.log')
parser.add_argument('--download',
-
type = bool
, default = False)
+
action='store_true'
, default = False)
parser.add_argument('--seed',
type = int, default = 0)
parser.add_argument('--seed',
type = int, default = 0)
@@
-67,11
+67,14
@@
parser.add_argument('--dropout',
type = float, default = 0.1)
parser.add_argument('--synthesis_sampling',
type = float, default = 0.1)
parser.add_argument('--synthesis_sampling',
-
type = bool
, default = True)
+
action='store_true'
, default = True)
parser.add_argument('--checkpoint_name',
type = str, default = 'checkpoint.pth')
parser.add_argument('--checkpoint_name',
type = str, default = 'checkpoint.pth')
+parser.add_argument('--picoclvr_many_colors',
+ action='store_true', default = False)
+
######################################################################
args = parser.parse_args()
######################################################################
args = parser.parse_args()
@@
-353,7
+356,7
@@
if args.data == 'wiki103':
elif args.data == 'mnist':
task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
elif args.data == 'mnist':
task = TaskMNIST(batch_size = args.batch_size, device = device)
elif args.data == 'picoclvr':
- task = TaskPicoCLVR(batch_size = args.batch_size, device = device)
+ task = TaskPicoCLVR(batch_size = args.batch_size,
many_colors = args.picoclvr_many_colors,
device = device)
else:
raise ValueError(f'Unknown dataset {args.data}.')
else:
raise ValueError(f'Unknown dataset {args.data}.')