projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
mygpt.py
diff --git
a/mygpt.py
b/mygpt.py
index
0400b48
..
77c29ce
100755
(executable)
--- a/
mygpt.py
+++ b/
mygpt.py
@@
-46,7
+46,7
@@
class BracketedSequence:
return self.x[:, self.first : self.first + self.nb]
def complete(self):
return self.x[:, self.first : self.first + self.nb]
def complete(self):
- return self.first == 0 and self.nb == x.size(1)
+ return self.first == 0 and self.nb ==
self.
x.size(1)
######################################################################
######################################################################
@@
-275,7
+275,12
@@
class MyGPT(nn.Module):
# unchanged.
def masked_inplace_autoregression(
# unchanged.
def masked_inplace_autoregression(
- self, input, ar_mask, forbidden_tokens=None, deterministic_synthesis=False
+ self,
+ input,
+ ar_mask,
+ deterministic_synthesis=False,
+ forbidden_tokens=None,
+ forced_biases=None,
):
to_generate = (ar_mask.sum(0) > 0).nonzero()
if to_generate.min() > 0:
):
to_generate = (ar_mask.sum(0) > 0).nonzero()
if to_generate.min() > 0:
@@
-287,6
+292,8
@@
class MyGPT(nn.Module):
logits = output[:, s]
if forbidden_tokens is not None:
logits = logits.masked_fill(forbidden_tokens, float("-inf"))
logits = output[:, s]
if forbidden_tokens is not None:
logits = logits.masked_fill(forbidden_tokens, float("-inf"))
+ if forced_biases is not None:
+ logits = logits + forced_biases[None, :]
if deterministic_synthesis:
t_next = logits.argmax(1)
else:
if deterministic_synthesis:
t_next = logits.argmax(1)
else: