projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Added figures
[culture.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
b3b56ad
..
80ffdbb
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-22,6
+22,7
@@
def masked_inplace_autoregression(
batch_size,
input,
ar_mask,
batch_size,
input,
ar_mask,
+ summed_logits,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
@@
-41,16
+42,15
@@
def masked_inplace_autoregression(
total=(input.size(0) + batch_size - 1) // batch_size,
)
total=(input.size(0) + batch_size - 1) // batch_size,
)
- sum_logits = 0
-
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
- sum_logits += model.masked_inplace_autoregression(
+ model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
input=input,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
@@
-59,8
+59,6
@@
def masked_inplace_autoregression(
model.train(t)
model.train(t)
- return sum_logits
-
######################################################################
######################################################################
@@
-180,6
+178,7
@@
class World(Task):
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
@@
-219,6
+218,7
@@
class World(Task):
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
@@
-266,23
+266,27
@@
class World(Task):
)
ar_mask = torch.full(quizzes.size(), 1, device=self.device)
)
ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
temperature = 1
d_temperature = 1
while True:
temperature = 1
d_temperature = 1
while True:
- sum_logits = masked_inplace_autoregression(
+ summed_logits[...] = 0
+
+ masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
input=quizzes,
ar_mask=ar_mask,
model=model,
batch_size=self.batch_size,
input=quizzes,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
progress_bar_desc="creating quizzes",
device=self.device,
)
temperature=temperature,
deterministic_synthesis=False,
progress_bar_desc="creating quizzes",
device=self.device,
)
- average_logits = sum_logits / quizzes.size(0)
+ average_logits = summed_logits.mean()
logger(f"{average_logits=} {desired_average_logits=}")
logger(f"{average_logits=} {desired_average_logits=}")
@@
-290,15
+294,18
@@
class World(Task):
break
# Oh man that's ugly
break
# Oh man that's ugly
- if average_logits > desired_average_logits:
- if d_temperature < 0:
+ if average_logits < desired_average_logits * 1.1:
+ if d_temperature > 0:
d_temperature *= -0.5
temperature += d_temperature
d_temperature *= -0.5
temperature += d_temperature
- else:
- if d_temperature > 0:
+ elif average_logits > desired_average_logits:
+ if d_temperature < 0:
d_temperature *= -0.5
temperature += d_temperature
d_temperature *= -0.5
temperature += d_temperature
- logger(f"chaging temperature to {temperature}")
+ else:
+ break
+
+ logger(f"changing temperature to {temperature}")
###############################################################
# Create the reverse quizzes
###############################################################
# Create the reverse quizzes
@@
-328,6
+335,7
@@
class World(Task):
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving quizzes",
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving quizzes",
@@
-343,6
+351,7
@@
class World(Task):
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed quizzes",
temperature=1.0,
deterministic_synthesis=True,
progress_bar_desc="solving reversed quizzes",
@@
-362,4
+371,4
@@
class World(Task):
# for k in nb_correct:
# f.write(f"{k}\n")
# for k in nb_correct:
# f.write(f"{k}\n")
- return quizzes, nb_correct.sum(dim=0), sum_logits
+ return quizzes, nb_correct.sum(dim=0), summed_logits.mean()