######################################################################
+def create_trunk(dim_model, dim_keys, dim_hidden, nb_heads, nb_blocks, dropout=0.0):
+ trunk_blocks = []
+
+ for b in range(nb_blocks):
+ trunk_blocks += [
+ WithResidual(
+ nn.LayerNorm((dim_model,)),
+ MHAttention(
+ dim_model=dim_model,
+ dim_qk=dim_keys,
+ dim_v=dim_model // nb_heads,
+ nb_heads=nb_heads,
+ attention=vanilla_attention,
+ attention_dropout=dropout,
+ ),
+ ),
+ WithResidual(
+ nn.LayerNorm((dim_model,)),
+ nn.Linear(in_features=dim_model, out_features=dim_hidden),
+ nn.ReLU(),
+ nn.Linear(in_features=dim_hidden, out_features=dim_model),
+ nn.Dropout(dropout),
+ ),
+ ]
+
+ return nn.Sequential(*trunk_blocks)
+
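+# A minimal sanity check for create_trunk (a sketch with made-up sizes,
+# assuming WithResidual, MHAttention, and vanilla_attention defined
+# elsewhere in this file): the trunk maps NxTxD to NxTxD.
+#
+#   trunk = create_trunk(
+#       dim_model=16, dim_keys=4, dim_hidden=32, nb_heads=4, nb_blocks=2
+#   )
+#   x = torch.randn(2, 10, 16)
+#   assert trunk(x).size() == x.size()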
+
+######################################################################
+
+
class AttentionAE(nn.Module):
def __init__(
self,
- vocabulary_size,
+ vocabulary_size_in,
+ vocabulary_size_out,
dim_model,
dim_keys,
dim_hidden,
assert dim_model % nb_heads == 0
self.embedding = nn.Sequential(
- nn.Embedding(2 * vocabulary_size, dim_model),
+ nn.Embedding(vocabulary_size_in, dim_model),
nn.Dropout(dropout),
)
self.positional_encoding = VaswaniPositionalEncoding(len_max)
- trunk_blocks = []
-
- for b in range(nb_blocks):
- trunk_blocks += [
- WithResidual(
- nn.LayerNorm((dim_model,)),
- MHAttention(
- dim_model=dim_model,
- dim_qk=dim_keys,
- dim_v=dim_model // nb_heads,
- nb_heads=nb_heads,
- attention=vanilla_attention,
- attention_dropout=dropout,
- ),
- ),
- WithResidual(
- nn.LayerNorm((dim_model,)),
- nn.Linear(in_features=dim_model, out_features=dim_hidden),
- nn.ReLU(),
- nn.Linear(in_features=dim_hidden, out_features=dim_model),
- nn.Dropout(dropout),
- ),
- ]
-
- self.trunk = nn.Sequential(*trunk_blocks)
+ self.trunk = create_trunk(
+ dim_model=dim_model,
+ dim_keys=dim_keys,
+ dim_hidden=dim_hidden,
+ nb_heads=nb_heads,
+ nb_blocks=nb_blocks,
+ dropout=dropout,
+ )
- self.readout = nn.Linear(in_features=dim_model, out_features=vocabulary_size)
+ self.readout = nn.Linear(
+ in_features=dim_model, out_features=vocabulary_size_out
+ )
with torch.no_grad():
for m in self.modules():
if __name__ == "__main__":
model = FunctionalAttentionAE(
- vocabulary_size=100,
+ vocabulary_size_in=100,
+ vocabulary_size_out=100,
dim_model=16,
dim_keys=64,
dim_hidden=32,
torch.set_float32_matmul_precision("high")
-# torch.set_default_dtype(torch.bfloat16)
+torch.set_default_dtype(torch.bfloat16)
######################################################################
######################################################################
-if args.gpus == "all":
- gpus_idx = range(torch.cuda.device_count())
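+# args.gpus is assumed to be "none" (force CPU), "all", or a
+# comma-separated list of CUDA device indices such as "0,2"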
+if args.gpus == "none" or not torch.cuda.is_available():
+ gpus = [torch.device("cpu")]
else:
- gpus_idx = [int(k) for k in args.gpus.split(",")]
+ if args.gpus == "all":
+ gpus_idx = range(torch.cuda.device_count())
+ else:
+ gpus_idx = [int(k) for k in args.gpus.split(",")]
-gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
+ gpus = [torch.device(f"cuda:{n}") for n in gpus_idx]
-if torch.cuda.is_available():
- main_device = gpus[0]
-else:
- assert len(gpus) == 0
- main_device = torch.device("cpu")
+main_device = gpus[0]
if args.train_batch_size is None:
args.train_batch_size = args.batch_size
######################################################################
+# IMT stands for image/mask/target
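+#
+# An IMT sample stacks three rows of length T: the image (the token
+# sequence), the mask (1 at the positions to predict), and the target.
+# A batch is thus Nx3xT; a sketch of the layout assumed below:
+#
+#   imt = torch.stack([input, masks, targets], dim=1)  # Nx3xT
+#   x = imt[:, 0] + imt[:, 1] * vocabulary_size  # what the model reads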
+
def add_hints_imt(imt_set, proba_hints):
"""Set every component of the mask to zero with probability proba,
# Prediction
-def samples_for_prediction_imt(input):
+def make_imt_samples_for_prediction(input):
nb = input.size(0)
masks = input.new_zeros(input.size())
u = F.one_hot(torch.randint(4, (nb,), device=masks.device), num_classes=4)
imt[:, 0] = imt[:, 0] * (1 - imt[:, 1])
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(imt[:, 0] * 2 + imt[:, 1])
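+            # Masked positions are offset by vocabulary_size, so the
+            # embedding sees distinct tokens for a value that is visible
+            # and one to be predicted, hence vocabulary_size_in being
+            # 2 * vocabulary_size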
+ logits = model(imt[:, 0] + imt[:, 1] * vocabulary_size)
dist = torch.distributions.categorical.Categorical(logits=logits)
result = (1 - imt[:, 1]) * imt[:, 0] + imt[:, 1] * dist.sample()
record.append(result)
######################################################################
-def samples_for_generation_imt(input):
+def make_imt_samples_for_generation(input):
nb = input.size(0)
probs_iterations = 0.1 ** torch.linspace(
0, 1, args.diffusion_nb_iterations, device=input.device
for input, masks, changed in src:
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(input * 2 + masks)
+ logits = model(input + masks * vocabulary_size)
dist = torch.distributions.categorical.Categorical(logits=logits)
output = dist.sample()
r = prioritized_rand(input != output)
q_p, q_g = quizzes.to(local_device).chunk(2)
# Half of the samples are used to train the prediction.
- b_p = samples_for_prediction_imt(q_p)
+ b_p = make_imt_samples_for_prediction(q_p)
    # We inject noise in all of them to avoid the culture drifting
    # toward "finding Waldo" types of complexity
b_p = add_input_noise_imt(b_p, args.proba_input_noise)
# The other half are denoising examples to train the generative
# process.
- b_g = samples_for_generation_imt(q_g)
+ b_g = make_imt_samples_for_generation(q_g)
imt_set = torch.cat([b_p, b_g])
imt_set = imt_set[torch.randperm(imt_set.size(0), device=imt_set.device)]
model.optimizer.zero_grad()
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
- logits = model(input * 2 + masks)
+ logits = model(input + masks * vocabulary_size)
loss_per_token = F.cross_entropy(
logits.transpose(1, 2), targets, reduction="none"
# Save some images of the prediction results
quizzes = generate_quiz_set(150, c_quizzes, args.c_quiz_multiplier)
- imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ imt_set = make_imt_samples_for_prediction(quizzes.to(local_device))
result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
masks = imt_set[:, 1].to("cpu")
# c_quizzes=test_c_quizzes,
c_quiz_multiplier=args.c_quiz_multiplier,
)
- imt_set = samples_for_prediction_imt(quizzes.to(local_device))
+ imt_set = make_imt_samples_for_prediction(quizzes.to(local_device))
result = ae_predict(model, imt_set, local_device=local_device).to("cpu")
correct = (quizzes == result).min(dim=1).values.long()
raise ValueError(f"Unknown model type {args.model_type}")
model = model_constructor(
- vocabulary_size=vocabulary_size * 2,
+ vocabulary_size_in=vocabulary_size * 2,
+ vocabulary_size_out=vocabulary_size,
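+        # vocabulary_size_in is doubled so that each input token also
+        # encodes its mask bit, while the readout predicts plain tokens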
dim_model=args.dim_model,
dim_keys=args.dim_keys,
dim_hidden=args.dim_hidden,
######################################################################
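+# A single-model training run, assumed to be selected with --test aebn;
+# it trains one model with one_complete_epoch and no c_quizzes
+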
+if args.test == "aebn":
+ model = new_model()
+
+ # model.trunk = (
+ # model.trunk[: len(model.trunk) // 2] + model.trunk[len(model.trunk) // 2 :]
+ # )
+
+ model.id = 0
+ model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
+ model.test_accuracy = 0.0
+ model.nb_epochs = 0
+
+ for n_epoch in range(args.nb_epochs):
+ one_complete_epoch(
+ model,
+ n_epoch,
+ train_c_quizzes=None,
+ test_c_quizzes=None,
+ local_device=main_device,
+ )
+
+######################################################################
+
train_c_quizzes, test_c_quizzes = None, None
models = []
result_dir, prefix + t.__name__ + ".png", quizzes, nrow=nrow, delta=True
)
- def detect_rectangles(self, q1, q2):
+ def oracle(self, q1, q2):
c = torch.arange(self.nb_colors)
- I = torch.arange(self.height)[None, :, None]
- J = torch.arange(self.width)[None, :, None]
+ all_i = torch.arange(self.height)[None, :, None]
+ all_j = torch.arange(self.width)[None, :, None]
def corners(q):
q = q.reshape(-1, self.height, self.width)
a = (q[:, :, :, None] == c[None, None, None, :]).long()
mi = a.max(dim=2).values
- i = mi * I
+ i = mi * all_i
i1 = (i + (1 - mi) * q.size(1)).min(dim=1).values
i2 = (i + (1 - mi) * (-1)).max(dim=1).values + 1
mj = a.max(dim=1).values
- j = mj * J
+ j = mj * all_j
j1 = (j + (1 - mj) * q.size(2)).min(dim=1).values
j2 = (j + (1 - mj) * (-1)).max(dim=1).values + 1
m = (
- ((I > i1[:, None, :]) & (I < i2[:, None, :] - 1))[:, :, None, :]
- & ((J > j1[:, None, :]) & (J < j2[:, None, :] - 1))[:, None, :, :]
+ ((all_i > i1[:, None, :]) & (all_i < i2[:, None, :] - 1))[:, :, None, :]
+ & ((all_j > j1[:, None, :]) & (all_j < j2[:, None, :] - 1))[
+ :, None, :, :
+ ]
).long()
f = ((a * m).long().sum(dim=(1, 2)) > 0).long()
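+            # Color 0 is the background: force an empty (inverted) box for
+            # it so that it never counts as a rectangle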
+ i1[:, 0], i2[:, 0], j1[:, 0], j2[:, 0] = self.height, 0, self.width, 0
return i1, i2, j1, j2, f
+        # Bounding-box coordinates and frame indicator per grid and per
+        # color, each of shape NxC
+        #
q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1)
q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2)
+
u1, u2 = 0, 0
for _ in range(10):
            r1 = q1.new_zeros(q1.size(0), self.height, self.width)
m1 = (
- ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :]))[:, :, None, :]
- & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[:, None, :, :]
+ ((all_i >= q1_i1[:, None, :]) & (all_i < q1_i2[:, None, :]))[
+ :, :, None, :
+ ]
+ & ((all_j >= q1_j1[:, None, :]) & (all_j < q1_j2[:, None, :]))[
+ :, None, :, :
+ ]
).long()
f1 = (
- (
- ((I == q1_i1[:, None, :]) | (I == q1_i2[:, None, :] - 1))[
- :, :, None, :
- ]
- & ((J >= q1_j1[:, None, :]) & (J < q1_j2[:, None, :]))[
- :, None, :, :
- ]
- )
- | (
- ((I >= q1_i1[:, None, :]) & (I < q1_i2[:, None, :] - 1))[
+ m1
+ * (
+ ((all_i == q1_i1[:, None, :]) | (all_i == q1_i2[:, None, :] - 1))[
:, :, None, :
]
- & ((J == q1_j1[:, None, :]) | (J == q1_j2[:, None, :] - 1))[
+ | ((all_j == q1_j1[:, None, :]) | (all_j == q1_j2[:, None, :] - 1))[
:, None, :, :
]
- )
- ).long()
+ ).long()
+ )
            r2 = q2.new_zeros(q2.size(0), self.height, self.width)
m2 = (
- ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :]))[:, :, None, :]
- & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[:, None, :, :]
+ ((all_i >= q2_i1[:, None, :]) & (all_i < q2_i2[:, None, :]))[
+ :, :, None, :
+ ]
+ & ((all_j >= q2_j1[:, None, :]) & (all_j < q2_j2[:, None, :]))[
+ :, None, :, :
+ ]
).long()
f2 = (
- (
- ((I == q2_i1[:, None, :]) | (I == q2_i2[:, None, :] - 1))[
- :, :, None, :
- ]
- & ((J >= q2_j1[:, None, :]) & (J < q2_j2[:, None, :]))[
- :, None, :, :
- ]
- )
- | (
- ((I >= q2_i1[:, None, :]) & (I < q2_i2[:, None, :] - 1))[
+ m2
+ * (
+ ((all_i == q2_i1[:, None, :]) | (all_i == q2_i2[:, None, :] - 1))[
:, :, None, :
]
- & ((J == q2_j1[:, None, :]) | (J == q2_j2[:, None, :] - 1))[
+ | ((all_j == q2_j1[:, None, :]) | (all_j == q2_j2[:, None, :] - 1))[
:, None, :, :
]
- )
- ).long()
+ ).long()
+ )
for c in torch.randperm(self.nb_colors - 1) + 1:
r1[...] = q1_f[:, None, None, c] * (
u1 = (1 - match) * u1 + match * r1
u2 = (1 - match) * u2 + match * r2
- return u1.flatten(1), u2.flatten(1)
+ ok = (u1.flatten(1) == q1).min(dim=1).values & (u2.flatten(1) == q2).min(
+ dim=1
+ ).values
- # o = F.one_hot(q * (1 - m)).sum(dim=1)
- # print(o)
- # print(o.sort(dim=1, descending=True))
- # c = N x nb_col x 4
+ # q1_i1, q1_i2, q1_j1, q1_j2, q1_f = corners(q1)
+ # q2_i1, q2_i2, q2_j1, q2_j2, q2_f = corners(q2)
+ # NxC
+
+ q1i1, q1i2, q1j1, q1j2, q1f = (
+ q1_i1[:, :, None],
+ q1_i2[:, :, None],
+ q1_j1[:, :, None],
+ q1_j2[:, :, None],
+ q1_f[:, :, None],
+ )
+ q2i1, q2i2, q2j1, q2j2, q2f = (
+ q2_i1[:, None, :],
+ q2_i2[:, None, :],
+ q2_j1[:, None, :],
+ q2_j2[:, None, :],
+ q2_f[:, None, :],
+ )
+
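+        # match[n, c, d] == 1 iff the color-c box of q1 and the color-d box
+        # of q2 are both non-empty and occupy exactly the same cells;
+        # translate and grow test one-cell shifts and one-cell dilations
+        # or erosions of the box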
+ match = (
+ (q1i1 < q1i2)
+ & (q2i1 < q2i2)
+ & (q1i1 == q2i1)
+ & (q1i2 == q2i2)
+ & (q1j1 == q2j1)
+ & (q1j2 == q2j2)
+ ).long()
+ translate = (
+ ((q1i1 - q2i1).abs() <= 1)
+ & (q1i1 - q2i1 == q1i2 - q2i2)
+ & ((q1j1 - q2j1).abs() <= 1)
+ & (q1j1 - q2j1 == q1j2 - q2j2)
+ & ((q1i1 - q2i1).abs() + (q1j1 - q2j1).abs() > 0)
+ ).long()
+ grow = (
+ (
+ (q2i1 == q1i1 - 1)
+ & (q2i2 == q1i2 + 1)
+ & (q2j1 == q1j1 - 1)
+ & (q2j2 == q1j2 + 1)
+ )
+ | (
+ (q2i1 == q1i1 + 1)
+ & (q2i2 == q1i2 - 1)
+ & (q2j1 == q1j1 + 1)
+ & (q2j2 == q1j2 - 1)
+ )
+ ).long()
+
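+        # einsum("ncc->n", x) sums x over the diagonal c == d (same color
+        # in q1 and q2) and einsum("ncd->n", x) over all color pairs, so
+        # their difference counts the color-changing cases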
+ nb_same_color_not_frame = torch.einsum("ncc->n", match * q2f)
+ nb_change_color = torch.einsum("ncd->n", match * q2f) - nb_same_color_not_frame
+ nb_frame = torch.einsum("ncc->n", match * (1 - q2f))
+ nb_translate = torch.einsum("ncc->n", translate)
+ nb_translate_change_color = torch.einsum("ncd->n", translate) - nb_translate
+ nb_grow = torch.einsum("ncc->n", grow)
+ nb_grow_change_color = torch.einsum("ncd->n", grow) - nb_grow
+
+ print("-------------------------")
+ print("nb_same_color_not_frame", nb_same_color_not_frame)
+ print("nb_change_color", nb_change_color)
+ print("nb_frame", nb_frame)
+ print("nb_translate", nb_translate)
+ print("nb_translate_change_color", nb_translate_change_color)
+ print("nb_grow", nb_grow)
+ print("nb_grow_change_color", nb_grow_change_color)
+
+ # ok = ok & ( <= 1) & (translate.sum(dim=(1,2) == 3)
+
+ # print("match", match, "\n\n")
+
+ # print("translate", translate, "\n\n")
+
+ # print("grow", grow, "\n\n")
+
+ return u1.flatten(1), u2.flatten(1)
######################################################################
+def recenv(a):
+ s_row = a.sum(dim=2, keepdim=True)
+ c_row = s_row.cumsum(dim=1)
+ s_col = a.sum(dim=1, keepdim=True)
+ c_col = s_col.cumsum(dim=2)
+ env_row = ((c_row > 0) & ((c_row < c_row[:, -1:, :]) | (s_row > 0))).long()
+ env_col = ((c_col > 0) & ((c_col < c_col[:, :, -1:]) | (s_col > 0))).long()
+ return env_row * env_col
+
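+# What recenv computes (a hedged reading of the code above): it keeps
+# exactly the cells lying between the first and last nonzero rows and
+# columns, i.e. the bounding-box envelope of the nonzero cells, e.g.
+#
+#   recenv(torch.tensor([[[0, 0, 0], [0, 1, 0], [0, 0, 1]]]))
+#
+# is 1 on the bottom-right 2x2 block and 0 elsewhere.
+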
+def valid(q1, q2, m1=None, m2=None):
+    if m1 is None:
+        m1 = q1.new_zeros(q1.size())
+    if m2 is None:
+        m2 = q2.new_zeros(q2.size())
+
+def valid_exact_match(q1, q2, m1, m2, nb_colors):
+    # q1, q2, m1, m2 are NxHxW
+    q1, m1 = F.one_hot(q1, num_classes=nb_colors), m1[:, :, :, None]
+    q2, m2 = F.one_hot(q2, num_classes=nb_colors), m2[:, :, :, None]
+    a1 = (1 - m1) * q1
+    b1 = (1 - m1) * (1 - q1)
+    a2 = (1 - m2) * q2
+    b2 = (1 - m2) * (1 - q2)
+    rec_a1 = recenv(a1)
+    rec_a2 = recenv(a2)
+    rec = recenv(1 - (1 - rec_a1) * (1 - rec_a2))
+    ok = rec_a1 * rec * (1 - m1)
+    # WIP: return the per-cell agreement computed so far
+    return ok
+
if __name__ == "__main__":
+
import time
grids = Grids()
nb, nrow = 64, 4
nb_rows = 12
+ q = grids.generate_w_quizzes_(
+ 1,
+ tasks=[
+ grids.task_replace_color,
+ # grids.task_translate,
+ # grids.task_grow,
+ # grids.task_frame,
+ ],
+ )
+
+    q = q.reshape(q.size(0), 4, q.size(1) // 4)
+
+    print(q)
+
+    # reshape the flattened quizzes to NxHxW and use empty masks
+    q1 = q[:, 0].reshape(-1, grids.height, grids.width)
+    q2 = q[:, 1].reshape(-1, grids.height, grids.width)
+    m = q1.new_zeros(q1.size())
+    print(valid_exact_match(q1, q2, m, m, grids.nb_colors))
+
+    exit(0)
+
# c_quizzes = torch.load("/home/fleuret/state.pth")["train_c_quizzes"]
# c_quizzes = c_quizzes[torch.randperm(c_quizzes.size(0))[: nrow * nb_rows]]
# )
w_quizzes = grids.generate_w_quizzes_(
- 16,
+ 1,
tasks=[
grids.task_replace_color,
grids.task_translate,
q = w_quizzes.reshape(-1, 4, w_quizzes.size(1) // 4)
r = q.new_zeros(q.size())
- r[:, 0], r[:, 1] = grids.detect_rectangles(q[:, 0], q[:, 1])
- r[:, 2], r[:, 3] = grids.detect_rectangles(q[:, 2], q[:, 3])
+ r[:, 0], r[:, 1] = grids.oracle(q[:, 0], q[:, 1])
+ r[:, 2], r[:, 3] = grids.oracle(q[:, 2], q[:, 3])
grids.save_quizzes_as_image(
"/tmp",