projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
b3b56ad
..
43f7d53
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-22,6
+22,7
@@
def masked_inplace_autoregression(
batch_size,
input,
ar_mask,
batch_size,
input,
ar_mask,
+ summed_logits,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
temperature,
deterministic_synthesis,
forbidden_tokens=None,
@@
-41,16
+42,15
@@
def masked_inplace_autoregression(
total=(input.size(0) + batch_size - 1) // batch_size,
)
total=(input.size(0) + batch_size - 1) // batch_size,
)
- sum_logits = 0
-
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
with torch.autograd.no_grad():
t = model.training
model.eval()
for input, ar_mask in batches:
-
sum_logits +=
model.masked_inplace_autoregression(
+ model.masked_inplace_autoregression(
input=input,
ar_mask=ar_mask,
input=input,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
temperature=temperature,
deterministic_synthesis=deterministic_synthesis,
forbidden_tokens=forbidden_tokens,
@@
-59,8
+59,6
@@
def masked_inplace_autoregression(
model.train(t)
model.train(t)
- return sum_logits
-
######################################################################
######################################################################
@@
-90,6
+88,9
@@
class World(Task):
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
+ def save_quizzes(self, input, result_dir, filename_prefix, logger):
+ self.save_image(input, result_dir, filename_prefix + ".png", logger)
+
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
@@
-110,49
+111,52
@@
class World(Task):
self.height = 6
self.width = 8
self.height = 6
self.width = 8
- self.train_
input
= world.generate_seq(
+ self.train_
w_quizzes
= world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_
input
= world.generate_seq(
+ self.test_
w_quizzes
= world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
nb_test_samples, height=self.height, width=self.width
).to(device)
- self.nb_codes = max(self.train_
input.max(), self.test_input
.max()) + 1
+ self.nb_codes = max(self.train_
w_quizzes.max(), self.test_w_quizzes
.max()) + 1
- self.train_quizzes = []
- self.test_quizzes = []
+ self.train_
c_
quizzes = []
+ self.test_
c_
quizzes = []
if result_dir is not None:
if result_dir is not None:
- self.save_
image
(
- self.train_
input[:72], result_dir, f"world_train.png
", logger
+ self.save_
quizzes
(
+ self.train_
w_quizzes[:72], result_dir, f"culture_w_quizzes
", logger
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- input = self.train_input
-
quizzes = self.train
_quizzes
+ w_quizzes = self.train_w_quizzes
+
c_quizzes = self.train_c
_quizzes
else:
else:
- input = self.test_input
-
quizzes = self.test
_quizzes
+ w_quizzes = self.test_w_quizzes
+
c_quizzes = self.test_c
_quizzes
- if len(quizzes) > 0:
-
quizzes = torch.cat(
quizzes, dim=0)
- if
quizzes.size(0) > input
.size(0) // 2:
- i = torch.randperm(
input.size(0))[: input
.size(0) // 2]
-
quizzes =
quizzes[i]
+ if len(
c_
quizzes) > 0:
+
c_quizzes = torch.cat(c_
quizzes, dim=0)
+ if
c_quizzes.size(0) > w_quizzes
.size(0) // 2:
+ i = torch.randperm(
w_quizzes.size(0))[: w_quizzes
.size(0) // 2]
+
c_quizzes = c_
quizzes[i]
- i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
- input = input[i]
+ i = torch.randperm(w_quizzes.size(0))[
+ : w_quizzes.size(0) - c_quizzes.size(0)
+ ]
+ w_quizzes = w_quizzes[i]
- self.nb_batch_
samples_world = input
.size(0)
- self.nb_batch_
samples_quizzes =
quizzes.size(0)
+ self.nb_batch_
w_quizzes = w_quizzes
.size(0)
+ self.nb_batch_
c_quizzes = c_
quizzes.size(0)
- input = torch.cat([
input,
quizzes], dim=0)
+ input = torch.cat([
w_quizzes, c_
quizzes], dim=0)
else:
else:
- self.nb_batch_samples_world = input.size(0)
- self.nb_batch_samples_quizzes = 0
+ input = w_quizzes
+ self.nb_batch_w_quizzes = w_quizzes.size(0)
+ self.nb_batch_c_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
# Shuffle
input = input[torch.randperm(input.size(0))]
@@
-180,6
+184,7
@@
class World(Task):
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
@@
-193,13
+198,13
@@
class World(Task):
return nb_total, nb_correct
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_
input
)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_
w_quizzes
)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_
input
, logger)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_
w_quizzes
, logger)
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@
-210,7
+215,7
@@
class World(Task):
##############################
##############################
- input = self.test_
input
[:96]
+ input = self.test_
w_quizzes
[:96]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
@@
-219,36
+224,37
@@
class World(Task):
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
temperature=1.0,
deterministic_synthesis=deterministic_synthesis,
progress_bar_desc=None,
device=self.device,
)
- self.save_
image
(
+ self.save_
quizzes
(
result[:72],
result_dir,
result[:72],
result_dir,
- f"
world_prediction_{n_epoch:04d}_{model.id:02d}.png
",
+ f"
culture_prediction_{n_epoch:04d}_{model.id:02d}
",
logger,
)
return main_test_accuracy
logger,
)
return main_test_accuracy
- def renew_
sampl
es(self, nb, for_train=True):
- input = self.train_
input if for_train else self.test_input
+ def renew_
w_quizz
es(self, nb, for_train=True):
+ input = self.train_
w_quizzes if for_train else self.test_w_quizzes
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
- def store_
new_quizzes(self, new
_quizzes, for_train=True):
+ def store_
c_quizzes(self, new_c
_quizzes, for_train=True):
if for_train:
if for_train:
- self.train_
quizzes.append(new
_quizzes)
+ self.train_
c_quizzes.append(new_c
_quizzes)
else:
else:
- self.test_
quizzes.append(new
_quizzes)
+ self.test_
c_quizzes.append(new_c
_quizzes)
- def create_
new
_quizzes(
+ def create_
c
_quizzes(
self,
n_epoch,
result_dir,
self,
n_epoch,
result_dir,
@@
-261,28
+267,32
@@
class World(Task):
###############################################################
# Generate quizzes with model
###############################################################
# Generate quizzes with model
- quizzes = torch.empty(
+
c_
quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(quizzes.size(), 1, device=self.device)
+ ar_mask = torch.full(c_quizzes.size(), 1, device=self.device)
+ summed_logits = torch.empty(nb, device=self.device)
temperature = 1
d_temperature = 1
while True:
temperature = 1
d_temperature = 1
while True:
- sum_logits = masked_inplace_autoregression(
+ summed_logits[...] = 0
+
+ masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
model=model,
batch_size=self.batch_size,
- input=quizzes,
+ input=
c_
quizzes,
ar_mask=ar_mask,
ar_mask=ar_mask,
+ summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
temperature=temperature,
deterministic_synthesis=False,
- progress_bar_desc="
creating
quizzes",
+ progress_bar_desc="
sampling c_
quizzes",
device=self.device,
)
device=self.device,
)
- average_logits = sum
_logits / quizzes.size(0
)
+ average_logits = sum
med_logits.mean(
)
logger(f"{average_logits=} {desired_average_logits=}")
logger(f"{average_logits=} {desired_average_logits=}")
@@
-290,29
+300,32
@@
class World(Task):
break
# Oh man that's ugly
break
# Oh man that's ugly
- if average_logits
> desired_average_logits
:
- if d_temperature
<
0:
+ if average_logits
< desired_average_logits * 1.1
:
+ if d_temperature
>
0:
d_temperature *= -0.5
temperature += d_temperature
d_temperature *= -0.5
temperature += d_temperature
- el
se
:
- if d_temperature
>
0:
+ el
if average_logits > desired_average_logits
:
+ if d_temperature
<
0:
d_temperature *= -0.5
temperature += d_temperature
d_temperature *= -0.5
temperature += d_temperature
+ else:
+ break
+
logger(f"chaging temperature to {temperature}")
###############################################################
# Create the reverse quizzes
l = self.height * self.width
logger(f"chaging temperature to {temperature}")
###############################################################
# Create the reverse quizzes
l = self.height * self.width
- direction = quizzes[:, l : l + 1]
+ direction =
c_
quizzes[:, l : l + 1]
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
- reverse_quizzes = torch.cat(
- [
quizzes[:, l + 1 :], direction,
quizzes[:, :l]], dim=1
+ reverse_
c_
quizzes = torch.cat(
+ [
c_quizzes[:, l + 1 :], direction, c_
quizzes[:, :l]], dim=1
)
)
- ar_mask = self.make_ar_mask(quizzes)
+ ar_mask = self.make_ar_mask(
c_
quizzes)
###############################################################
# Check how many of the other models can solve them in both
###############################################################
# Check how many of the other models can solve them in both
@@
-321,45
+334,47
@@
class World(Task):
nb_correct = []
for m in other_models:
nb_correct = []
for m in other_models:
- result = quizzes.clone()
+ result =
c_
quizzes.clone()
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving quizzes",
+ progress_bar_desc="solving
c_
quizzes",
device=self.device,
)
device=self.device,
)
- correct = (quizzes == result).long().min(dim=-1).values
+ correct = (
c_
quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_quizzes.clone()
+ reverse_result = reverse_
c_
quizzes.clone()
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
masked_inplace_autoregression(
model=m,
batch_size=self.batch_size,
input=reverse_result,
ar_mask=ar_mask,
+ summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed quizzes",
+ progress_bar_desc="solving reversed
c_
quizzes",
device=self.device,
)
reverse_correct = (
device=self.device,
)
reverse_correct = (
- (reverse_quizzes == reverse_result).long().min(dim=-1).values
+ (reverse_
c_
quizzes == reverse_result).long().min(dim=-1).values
)
nb_correct.append((correct * reverse_correct)[None, :])
)
nb_correct.append((correct * reverse_correct)[None, :])
- nb_correct = torch.cat(nb_correct, dim=0)
+ nb_correct = torch.cat(nb_correct, dim=0)
.sum(dim=0)
# filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
# with open(filename, "w") as f:
# for k in nb_correct:
# f.write(f"{k}\n")
# filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
# with open(filename, "w") as f:
# for k in nb_correct:
# f.write(f"{k}\n")
- return
quizzes, nb_correct.sum(dim=0), sum_logits
+ return
c_quizzes, nb_correct, summed_logits.mean()