projects
/
mygptrnn.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary | shortlog | log | commit | commitdiff | tree
raw | inline | side by side
Update.
[mygptrnn.git]
/
main.py
diff --git a/main.py b/main.py
index 79841f3..3aa696b 100755 (executable)
--- a/main.py
+++ b/main.py
@@ -99,7 +99,11 @@
parser.add_argument("--nb_lines", type=int, default=None)
parser.add_argument("--caterpillar_height", type=int, default=None)
parser.add_argument("--caterpillar_height", type=int, default=None)
-parser.add_argument("--rho", type=float, default=0.0)
+parser.add_argument("--gate_dropout_proba", type=float, default=0.0)
+
+parser.add_argument("--gate_dropout_sync", type=bool, default=False)
+
+parser.add_argument("--rho_inner_loss", type=float, default=0.0)
parser.add_argument("--nb_blocks", type=int, default=None)
parser.add_argument("--nb_blocks", type=int, default=None)
@@ -747,7 +751,7 @@
model = mygpt.MyGPT(
dropout=args.dropout,
attention_layer=args.attention,
logger=log_string,
dropout=args.dropout,
attention_layer=args.attention,
logger=log_string,
- **sup_args,
+ args=args,
)
model.to(device)
)
model.to(device)
@@ -905,7 +909,9 @@
for n_epoch in range(nb_epochs_finished, nb_epochs):
nb_train_samples += input.size(0)
nb_samples_seen += input.size(0)
nb_train_samples += input.size(0)
nb_samples_seen += input.size(0)
- total_loss = loss + (args.rho * inner_loss if args.rho > 0 else 0.0)
+ total_loss = loss + (
+ args.rho_inner_loss * inner_loss if args.rho_inner_loss > 0 else 0.0
+ )
it += 1
lr = get_lr(n_epoch, it)
it += 1
lr = get_lr(n_epoch, it)