projects
/
picoclvr.git
/ commitdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
| commitdiff |
tree
raw
|
patch
|
inline
| side by side (from parent 1:
f29d0fa
)
Update.
author
François Fleuret
<francois@fleuret.org>
Tue, 4 Jul 2023 16:08:55 +0000
(18:08 +0200)
committer
François Fleuret
<francois@fleuret.org>
Tue, 4 Jul 2023 16:08:55 +0000
(18:08 +0200)
main.py
patch
|
blob
|
history
diff --git
a/main.py
b/main.py
index
beafc19
..
b907e60
100755
(executable)
--- a/
main.py
+++ b/
main.py
@@
-1091,7
+1091,7
@@
class TaskExpr(Task):
result = input.clone()
filler, space = self.char2id["#"], self.char2id[" "]
ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
result = input.clone()
filler, space = self.char2id["#"], self.char2id[" "]
ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result +
filler * ar_mask
+ result = (1 - ar_mask) * result +
ar_mask * filler
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
@@
-1113,16
+1113,19
@@
class TaskExpr(Task):
result = input.clone()
filler, space = self.char2id["#"], self.char2id[" "]
ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
result = input.clone()
filler, space = self.char2id["#"], self.char2id[" "]
ar_mask = (result == space).long().cumsum(dim=1).clamp(max=1)
- result = (1 - ar_mask) * result +
filler * ar_mask
+ result = (1 - ar_mask) * result +
ar_mask * filler
for n in range(result.size(0)):
s = "".join([self.id2char[k.item()] for k in result[n]])
log_string(f"test_before {s}")
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
for n in range(result.size(0)):
s = "".join([self.id2char[k.item()] for k in result[n]])
log_string(f"test_before {s}")
masked_inplace_autoregression(
model, self.batch_size, result, ar_mask, device=self.device
)
+ correct = (1 - ar_mask) * space + ar_mask * input
for n in range(result.size(0)):
s = "".join([self.id2char[k.item()] for k in result[n]])
log_string(f"test_after {s}")
for n in range(result.size(0)):
s = "".join([self.id2char[k.item()] for k in result[n]])
log_string(f"test_after {s}")
+ s = "".join([self.id2char[k.item()] for k in correct[n]])
+ log_string(f"correct {s}")
##############################################################
model.train(t)
##############################################################
model.train(t)