projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Added figures
[culture.git]
/
tasks.py
diff --git
a/tasks.py
b/tasks.py
index
43f7d53
..
80ffdbb
100755
(executable)
--- a/
tasks.py
+++ b/
tasks.py
@@
-88,9
+88,6
@@
class World(Task):
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
torchvision.utils.save_image(img.float() / 255.0, image_name, nrow=6, padding=4)
logger(f"wrote {image_name}")
- def save_quizzes(self, input, result_dir, filename_prefix, logger):
- self.save_image(input, result_dir, filename_prefix + ".png", logger)
-
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
def make_ar_mask(self, input):
b = torch.arange(input.size(1), device=input.device) > input.size(1) // 2
return b.long()[None, :].expand_as(input)
@@
-111,52
+108,49
@@
class World(Task):
self.height = 6
self.width = 8
self.height = 6
self.width = 8
- self.train_
w_quizzes
= world.generate_seq(
+ self.train_
input
= world.generate_seq(
nb_train_samples, height=self.height, width=self.width
).to(device)
nb_train_samples, height=self.height, width=self.width
).to(device)
- self.test_
w_quizzes
= world.generate_seq(
+ self.test_
input
= world.generate_seq(
nb_test_samples, height=self.height, width=self.width
).to(device)
nb_test_samples, height=self.height, width=self.width
).to(device)
- self.nb_codes = max(self.train_
w_quizzes.max(), self.test_w_quizzes
.max()) + 1
+ self.nb_codes = max(self.train_
input.max(), self.test_input
.max()) + 1
- self.train_
c_
quizzes = []
- self.test_
c_
quizzes = []
+ self.train_quizzes = []
+ self.test_quizzes = []
if result_dir is not None:
if result_dir is not None:
- self.save_
quizzes
(
- self.train_
w_quizzes[:72], result_dir, f"culture_w_quizzes
", logger
+ self.save_
image
(
+ self.train_
input[:72], result_dir, f"world_train.png
", logger
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
)
def batches(self, split="train", desc=None):
assert split in {"train", "test"}
if split == "train":
- w_quizzes = self.train_w_quizzes
-
c_quizzes = self.train_c
_quizzes
+ input = self.train_input
+
quizzes = self.train
_quizzes
else:
else:
- w_quizzes = self.test_w_quizzes
-
c_quizzes = self.test_c
_quizzes
+ input = self.test_input
+
quizzes = self.test
_quizzes
- if len(
c_
quizzes) > 0:
-
c_quizzes = torch.cat(c_
quizzes, dim=0)
- if
c_quizzes.size(0) > w_quizzes
.size(0) // 2:
- i = torch.randperm(
w_quizzes.size(0))[: w_quizzes
.size(0) // 2]
-
c_quizzes = c_
quizzes[i]
+ if len(quizzes) > 0:
+
quizzes = torch.cat(
quizzes, dim=0)
+ if
quizzes.size(0) > input
.size(0) // 2:
+ i = torch.randperm(
input.size(0))[: input
.size(0) // 2]
+
quizzes =
quizzes[i]
- i = torch.randperm(w_quizzes.size(0))[
- : w_quizzes.size(0) - c_quizzes.size(0)
- ]
- w_quizzes = w_quizzes[i]
+ i = torch.randperm(input.size(0))[: input.size(0) - quizzes.size(0)]
+ input = input[i]
- self.nb_batch_
w_quizzes = w_quizzes
.size(0)
- self.nb_batch_
c_quizzes = c_
quizzes.size(0)
+ self.nb_batch_
samples_world = input
.size(0)
+ self.nb_batch_
samples_quizzes =
quizzes.size(0)
- input = torch.cat([
w_quizzes, c_
quizzes], dim=0)
+ input = torch.cat([
input,
quizzes], dim=0)
else:
else:
- input = w_quizzes
- self.nb_batch_w_quizzes = w_quizzes.size(0)
- self.nb_batch_c_quizzes = 0
+ self.nb_batch_samples_world = input.size(0)
+ self.nb_batch_samples_quizzes = 0
# Shuffle
input = input[torch.randperm(input.size(0))]
# Shuffle
input = input[torch.randperm(input.size(0))]
@@
-198,13
+192,13
@@
class World(Task):
return nb_total, nb_correct
return nb_total, nb_correct
- train_nb_total, train_nb_correct = compute_accuracy(self.train_
w_quizzes
)
+ train_nb_total, train_nb_correct = compute_accuracy(self.train_
input
)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
logger(
f"accuracy_train {n_epoch} nb_total {train_nb_total} nb_correct {train_nb_correct} accuracy {(100.0*train_nb_correct)/train_nb_total:.02f}%"
)
- test_nb_total, test_nb_correct = compute_accuracy(self.test_
w_quizzes
, logger)
+ test_nb_total, test_nb_correct = compute_accuracy(self.test_
input
, logger)
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
logger(
f"accuracy_test {n_epoch} nb_total {test_nb_total} nb_correct {test_nb_correct} accuracy {(100.0*test_nb_correct)/test_nb_total:.02f}%"
@@
-215,7
+209,7
@@
class World(Task):
##############################
##############################
- input = self.test_
w_quizzes
[:96]
+ input = self.test_
input
[:96]
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
ar_mask = self.make_ar_mask(input)
result = input.clone() * (1 - ar_mask)
@@
-231,30
+225,30
@@
class World(Task):
device=self.device,
)
device=self.device,
)
- self.save_
quizzes
(
+ self.save_
image
(
result[:72],
result_dir,
result[:72],
result_dir,
- f"
culture_prediction_{n_epoch:04d}_{model.id:02d}
",
+ f"
world_prediction_{n_epoch:04d}_{model.id:02d}.png
",
logger,
)
return main_test_accuracy
logger,
)
return main_test_accuracy
- def renew_
w_quizz
es(self, nb, for_train=True):
- input = self.train_
w_quizzes if for_train else self.test_w_quizzes
+ def renew_
sampl
es(self, nb, for_train=True):
+ input = self.train_
input if for_train else self.test_input
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
nb = min(nb, input.size(0))
input[:-nb] = input[nb:].clone()
input[-nb:] = world.generate_seq(nb, height=self.height, width=self.width).to(
self.device
)
- def store_
c_quizzes(self, new_c
_quizzes, for_train=True):
+ def store_
new_quizzes(self, new
_quizzes, for_train=True):
if for_train:
if for_train:
- self.train_
c_quizzes.append(new_c
_quizzes)
+ self.train_
quizzes.append(new
_quizzes)
else:
else:
- self.test_
c_quizzes.append(new_c
_quizzes)
+ self.test_
quizzes.append(new
_quizzes)
- def create_
c
_quizzes(
+ def create_
new
_quizzes(
self,
n_epoch,
result_dir,
self,
n_epoch,
result_dir,
@@
-267,11
+261,11
@@
class World(Task):
###############################################################
# Generate quizzes with model
###############################################################
# Generate quizzes with model
-
c_
quizzes = torch.empty(
+ quizzes = torch.empty(
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
nb, self.height * self.width * 2 + 1, device=self.device, dtype=torch.int64
)
- ar_mask = torch.full(
c_
quizzes.size(), 1, device=self.device)
+ ar_mask = torch.full(quizzes.size(), 1, device=self.device)
summed_logits = torch.empty(nb, device=self.device)
temperature = 1
summed_logits = torch.empty(nb, device=self.device)
temperature = 1
@@
-283,12
+277,12
@@
class World(Task):
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
masked_inplace_autoregression(
model=model,
batch_size=self.batch_size,
- input=
c_
quizzes,
+ input=quizzes,
ar_mask=ar_mask,
summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
ar_mask=ar_mask,
summed_logits=summed_logits,
temperature=temperature,
deterministic_synthesis=False,
- progress_bar_desc="
sampling c_
quizzes",
+ progress_bar_desc="
creating
quizzes",
device=self.device,
)
device=self.device,
)
@@
-311,21
+305,21
@@
class World(Task):
else:
break
else:
break
- logger(f"chaging temperature to {temperature}")
+ logger(f"cha
n
ging temperature to {temperature}")
###############################################################
# Create the reverse quizzes
l = self.height * self.width
###############################################################
# Create the reverse quizzes
l = self.height * self.width
- direction =
c_
quizzes[:, l : l + 1]
+ direction = quizzes[:, l : l + 1]
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
direction = world.token_forward * (
direction == world.token_backward
) + world.token_backward * (direction == world.token_forward)
- reverse_
c_
quizzes = torch.cat(
- [
c_quizzes[:, l + 1 :], direction, c_
quizzes[:, :l]], dim=1
+ reverse_quizzes = torch.cat(
+ [
quizzes[:, l + 1 :], direction,
quizzes[:, :l]], dim=1
)
)
- ar_mask = self.make_ar_mask(
c_
quizzes)
+ ar_mask = self.make_ar_mask(quizzes)
###############################################################
# Check how many of the other models can solve them in both
###############################################################
# Check how many of the other models can solve them in both
@@
-334,7
+328,7
@@
class World(Task):
nb_correct = []
for m in other_models:
nb_correct = []
for m in other_models:
- result =
c_
quizzes.clone()
+ result = quizzes.clone()
masked_inplace_autoregression(
model=m,
masked_inplace_autoregression(
model=m,
@@
-344,13
+338,13
@@
class World(Task):
summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving
c_
quizzes",
+ progress_bar_desc="solving quizzes",
device=self.device,
)
device=self.device,
)
- correct = (
c_
quizzes == result).long().min(dim=-1).values
+ correct = (quizzes == result).long().min(dim=-1).values
- reverse_result = reverse_
c_
quizzes.clone()
+ reverse_result = reverse_quizzes.clone()
masked_inplace_autoregression(
model=m,
masked_inplace_autoregression(
model=m,
@@
-360,21
+354,21
@@
class World(Task):
summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
summed_logits=None,
temperature=1.0,
deterministic_synthesis=True,
- progress_bar_desc="solving reversed
c_
quizzes",
+ progress_bar_desc="solving reversed quizzes",
device=self.device,
)
reverse_correct = (
device=self.device,
)
reverse_correct = (
- (reverse_
c_
quizzes == reverse_result).long().min(dim=-1).values
+ (reverse_quizzes == reverse_result).long().min(dim=-1).values
)
nb_correct.append((correct * reverse_correct)[None, :])
)
nb_correct.append((correct * reverse_correct)[None, :])
- nb_correct = torch.cat(nb_correct, dim=0)
.sum(dim=0)
+ nb_correct = torch.cat(nb_correct, dim=0)
# filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
# with open(filename, "w") as f:
# for k in nb_correct:
# f.write(f"{k}\n")
# filename = os.path.join(result_dir, "correct_{n_epoch:04d}.dat")
# with open(filename, "w") as f:
# for k in nb_correct:
# f.write(f"{k}\n")
- return
c_quizzes, nb_correct
, summed_logits.mean()
+ return
quizzes, nb_correct.sum(dim=0)
, summed_logits.mean()