("gray", [128, 128, 128]),
]
- def check_structure(self, quizzes, struct):
+ def check_order(self, quizzes, quad_order):
S = self.height * self.width
return (
- (quizzes[:, 0 * (S + 1)] == self.l2tok[struct[0]])
- & (quizzes[:, 1 * (S + 1)] == self.l2tok[struct[1]])
- & (quizzes[:, 2 * (S + 1)] == self.l2tok[struct[2]])
- & (quizzes[:, 3 * (S + 1)] == self.l2tok[struct[3]])
+ (quizzes[:, 0 * (S + 1)] == self.l2tok[quad_order[0]])
+ & (quizzes[:, 1 * (S + 1)] == self.l2tok[quad_order[1]])
+ & (quizzes[:, 2 * (S + 1)] == self.l2tok[quad_order[2]])
+ & (quizzes[:, 3 * (S + 1)] == self.l2tok[quad_order[3]])
).all()
- def get_structure(self, quizzes):
+ def get_order(self, quizzes):
S = self.height * self.width
- struct = tuple(
+ quad_order = tuple(
self.tok2l[n.item()]
for n in quizzes.reshape(quizzes.size(0), 4, S + 1)[0, :, 0]
)
- self.check_structure(quizzes, struct)
- return struct
+ self.check_order(quizzes, quad_order)
+ return quad_order
- def inject_noise(self, quizzes, noise, struct, quad):
- assert self.check_structure(quizzes, struct=struct)
+ def inject_noise(self, quizzes, noise, quad_order, quad_noise):
+ assert self.check_order(quizzes, quad_order=quad_order)
S = self.height * self.width
- mask = torch.tensor(quad, device=quizzes.device)
+ mask = torch.tensor(quad_noise, device=quizzes.device)
mask = mask[None, :, None].expand(1, 4, S + 1).clone()
mask[:, :, 0] = 0
mask = mask.reshape(1, -1).expand_as(quizzes)
return quizzes
# What a mess
- def reconfigure(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+ def reconfigure(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
if torch.is_tensor(quizzes):
- return self.reconfigure([quizzes], struct=struct)[0]
+ return self.reconfigure([quizzes], quad_order=quad_order)[0]
S = self.height * self.width
result = [x.new(x.size()) for x in quizzes]
- struct_from = self.get_structure(quizzes[0][:1])
- i = self.indices_select(quizzes[0], struct_from)
+ quad_order_from = self.get_order(quizzes[0][:1])
+ i = self.indices_select(quizzes[0], quad_order_from)
- sf = dict((l, n) for n, l in enumerate(struct_from))
+ sf = dict((l, n) for n, l in enumerate(quad_order_from))
for q in range(4):
- k = sf[struct[q]]
+ k = sf[quad_order[q]]
for x, y in zip(quizzes, result):
l = x.size(1) // 4
y[i, q * l : (q + 1) * l] = x[i, k * l : (k + 1) * l]
if j.any():
for z, y in zip(
- self.reconfigure([x[j] for x in quizzes], struct=struct), result
+ self.reconfigure([x[j] for x in quizzes], quad_order=quad_order), result
):
y[j] = z
def trivial(self, quizzes):
S = self.height * self.width
- assert self.check_structure(quizzes, struct=("A", "f_A", "B", "f_B"))
+ assert self.check_order(quizzes, quad_order=("A", "f_A", "B", "f_B"))
a = quizzes.reshape(quizzes.size(0), 4, S + 1)[:, :, 1:]
return (a[:, 0] == a[:, 1]).min(dim=1).values | (a[:, 2] == a[:, 3]).min(
dim=1
).values
def make_quiz_mask(
- self, quizzes, struct=("A", "f_A", "B", "f_B"), quad=(0, 0, 0, 1)
+ self, quizzes, quad_order=("A", "f_A", "B", "f_B"), quad_mask=(0, 0, 0, 1)
):
- assert self.check_structure(quizzes, struct)
+ assert self.check_order(quizzes, quad_order)
ar_mask = quizzes.new_zeros(quizzes.size())
S = self.height * self.width
a = ar_mask.reshape(ar_mask.size(0), 4, S + 1)[:, :, 1:]
- a[:, 0, :] = quad[0]
- a[:, 1, :] = quad[1]
- a[:, 2, :] = quad[2]
- a[:, 3, :] = quad[3]
+ a[:, 0, :] = quad_mask[0]
+ a[:, 1, :] = quad_mask[1]
+ a[:, 2, :] = quad_mask[2]
+ a[:, 3, :] = quad_mask[3]
return ar_mask
- def indices_select(self, quizzes, struct=("A", "f_A", "B", "f_B")):
+ def indices_select(self, quizzes, quad_order=("A", "f_A", "B", "f_B")):
S = self.height * self.width
q = quizzes.reshape(quizzes.size(0), 4, S + 1)
return (
- (q[:, 0, 0] == self.l2tok[struct[0]])
- & (q[:, 1, 0] == self.l2tok[struct[1]])
- & (q[:, 2, 0] == self.l2tok[struct[2]])
- & (q[:, 3, 0] == self.l2tok[struct[3]])
+ (q[:, 0, 0] == self.l2tok[quad_order[0]])
+ & (q[:, 1, 0] == self.l2tok[quad_order[1]])
+ & (q[:, 2, 0] == self.l2tok[quad_order[2]])
+ & (q[:, 3, 0] == self.l2tok[quad_order[3]])
)
def __init__(
######################################################################
- def create_empty_quizzes(self, nb, struct=("A", "f_A", "B", "f_B")):
+ def create_empty_quizzes(self, nb, quad_order=("A", "f_A", "B", "f_B")):
S = self.height * self.width
quizzes = torch.zeros(nb, 4 * (S + 1), dtype=torch.int64)
- quizzes[:, 0 * (S + 1)] = self.l2tok[struct[0]]
- quizzes[:, 1 * (S + 1)] = self.l2tok[struct[1]]
- quizzes[:, 2 * (S + 1)] = self.l2tok[struct[2]]
- quizzes[:, 3 * (S + 1)] = self.l2tok[struct[3]]
+ quizzes[:, 0 * (S + 1)] = self.l2tok[quad_order[0]]
+ quizzes[:, 1 * (S + 1)] = self.l2tok[quad_order[1]]
+ quizzes[:, 2 * (S + 1)] = self.l2tok[quad_order[2]]
+ quizzes[:, 3 * (S + 1)] = self.l2tok[quad_order[3]]
return quizzes
# nb = 5
# quizzes = grids.generate_w_quizzes_(nb, tasks=[grids.task_fill])
# print(quizzes)
- # print(grids.get_structure(quizzes))
+ # print(grids.get_order(quizzes))
- # quizzes = grids.reconfigure(quizzes, struct=("A", "B", "f_A", "f_B"))
+ # quizzes = grids.reconfigure(quizzes, quad_order=("A", "B", "f_A", "f_B"))
# print("DEBUG2", quizzes)
- # print(grids.get_structure(quizzes))
+ # print(grids.get_order(quizzes))
# print(quizzes)
# i = torch.rand(quizzes.size(0)) < 0.5
# print(
# i.equal(j),
- # grids.get_structure(quizzes[j]),
- # grids.get_structure(quizzes[j == False]),
+ # grids.get_order(quizzes[j]),
+ # grids.get_order(quizzes[j == False]),
# )
# exit(0)
(solved_c_quizzes[:, model.id], _, _) = quiz_machine.predict(
model,
solved_c_quizzes[:, model.id],
- struct=("A", "f_A", "B", "f_B"),
- quad=(0, 0, 0, 1),
+ quad_order=("A", "f_A", "B", "f_B"),
+ quad_mask=(0, 0, 0, 1),
)
return bs
+def ae_batches(quiz_machine, nb, data_structures, local_device, desc=None):
+ full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
+ nb, data_structures=data_structures
+ )
+
+ src = zip(
+ full_input.split(args.batch_size),
+ full_mask_generate.split(args.batch_size),
+ full_mask_loss.split(args.batch_size),
+ )
+
+ if desc is not None:
+ src = tqdm.tqdm(
+ src,
+ dynamic_ncols=True,
+ desc=desc,
+ total=full_input.size(0) // args.batch_size,
+ )
+
+ for input, mask_generate, mask_loss in src:
+ yield (
+ input.to(local_device),
+ mask_generate.to(local_device),
+ mask_loss.to(local_device),
+ )
+
+
def test_ae(local_device=main_device):
model = MyAttentionVAE(
vocabulary_size=vocabulary_size,
dropout=args.dropout,
).to(main_device)
+ data_structures = [
+ (("A", "f_A", "B", "f_B"), (0, 0, 0, 1), (0, 0, 1, 0), (1, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 0, 1, 0), (0, 0, 0, 1), (1, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (0, 1, 0, 0), (1, 0, 0, 0), (1, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (1, 0, 0, 0), (0, 1, 0, 0), (1, 1, 1, 1)),
+ (("A", "f_A", "B", "f_B"), (1, 1, 1, 0), (0, 0, 0, 0), (1, 1, 1, 1)),
+ ]
+
model.optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
model.to(local_device).train()
model.train()
nb_train_samples, acc_train_loss = 0, 0.0
- full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
- args.nb_train_samples
- )
-
- src = zip(
- full_input.split(args.batch_size),
- full_mask_generate.split(args.batch_size),
- full_mask_loss.split(args.batch_size),
- )
-
- for input, mask_generate, mask_loss in tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="training",
- total=full_input.size(0) // args.batch_size,
+ for input, mask_generate, mask_loss in ae_batches(
+ quiz_machine,
+ args.nb_train_samples,
+ data_structures,
+ local_device,
+ "training",
):
- input = input.to(local_device)
- mask_generate = mask_generate.to(local_device)
- mask_loss = mask_loss.to(local_device)
-
if nb_train_samples % args.batch_size == 0:
model.optimizer.zero_grad()
nb_test_samples, acc_test_loss = 0, 0.0
- full_input, full_mask_generate, full_mask_loss = quiz_machine.data_input(
- args.nb_test_samples
- )
-
- src = zip(
- full_input.split(args.batch_size),
- full_mask_generate.split(args.batch_size),
- full_mask_loss.split(args.batch_size),
- )
-
- for input, mask_generate, mask_loss in tqdm.tqdm(
- src,
- dynamic_ncols=True,
- desc="testing",
- total=full_input.size(0) // args.batch_size,
+ for input, mask_generate, mask_loss in ae_batches(
+ quiz_machine,
+ args.nb_test_samples,
+ data_structures,
+ local_device,
+ "test",
):
- input = input.to(local_device)
- mask_generate = mask_generate.to(local_device)
- mask_loss = mask_loss.to(local_device)
-
targets = input
mask_noise = (mask_generate != 0) & (
log_string(f"test_loss {n_epoch} model AE {acc_test_loss/nb_test_samples}")
- input, mask_generate, mask_loss = quiz_machine.data_input(128)
- input = input.to(local_device)
- mask_generate = mask_generate.to(local_device)
- mask_loss = mask_loss.to(local_device)
+ input, mask_generate, mask_loss = next(
+ ae_batches(quiz_machine, 128, data_structures, local_device)
+ )
+
targets = input
pred_result = None
nb = 0
# We consider all the configurations that we train for
- for struct, quad_generate, _, _ in quiz_machine.test_structures:
- i = quiz_machine.problem.indices_select(quizzes=input, struct=struct)
+ for quad_order, quad_generate, _, _ in quiz_machine.test_structures:
+ i = quiz_machine.problem.indices_select(
+ quizzes=input, quad_order=quad_order
+ )
nb += i.long().sum()
predicted_parts[i] = torch.tensor(quad_generate, device=result.device)[
quizzes = quizzes[i]
self.randomize_configuations_inplace(
- quizzes, structs=[s for s, _, _, _ in data_structures]
+ quizzes, quad_orders=[s for s, _, _, _ in data_structures]
)
quiz_mask_generate = quizzes.new_full(quizzes.size(), 1)
quiz_mask_loss = quizzes.new_full(quizzes.size(), 1)
- for struct, quad_generate, quad_noise, quad_loss in data_structures:
- i = self.problem.indices_select(quizzes=quizzes, struct=struct)
+ for quad_order, quad_generate, quad_noise, quad_loss in data_structures:
+ i = self.problem.indices_select(quizzes=quizzes, quad_order=quad_order)
if i.any():
if self.prompt_noise > 0.0:
quizzes[i] = self.problem.inject_noise(
- quizzes[i], self.prompt_noise, struct=struct, quad=quad_noise
+ quizzes[i],
+ self.prompt_noise,
+ quad_order=quad_order,
+ quad_noise=quad_noise,
)
quiz_mask_generate[i] = self.make_quiz_mask(
- quizzes=quizzes[i], struct=struct, quad=quad_generate
+ quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_generate
)
quiz_mask_loss[i] = self.make_quiz_mask(
- quizzes=quizzes[i], struct=struct, quad=quad_loss
+ quizzes=quizzes[i], quad_order=quad_order, quad_mask=quad_loss
)
return quizzes, quiz_mask_generate, quiz_mask_loss
######################################################################
- def make_quiz_mask(self, quizzes, struct, quad):
- assert struct in [s for s, _, _, _ in self.train_structures]
- return self.problem.make_quiz_mask(quizzes, struct=struct, quad=quad)
+ def make_quiz_mask(self, quizzes, quad_order, quad_mask):
+ assert quad_order in [s for s, _, _, _ in self.train_structures]
+ return self.problem.make_quiz_mask(
+ quizzes, quad_order=quad_order, quad_mask=quad_mask
+ )
######################################################################
- def predict(self, model, quizzes, struct, quad):
+ def predict(self, model, quizzes, quad_order, quad_mask):
quizzes = quizzes.to(self.device)
- ar_mask = self.make_quiz_mask(quizzes=quizzes, struct=struct, quad=quad)
+ ar_mask = self.make_quiz_mask(
+ quizzes=quizzes, quad_order=quad_order, quad_mask=quad_mask
+ )
result = quizzes * (1 - ar_mask)
seq_logprobas = torch.zeros(quizzes.size(0), device=self.device)
nb = 0
# We consider all the configurations that we train for
- for struct, quad_generate, _, _ in self.test_structures:
- i = self.problem.indices_select(quizzes=input, struct=struct)
+ for quad_order, quad_generate, _, _ in self.test_structures:
+ i = self.problem.indices_select(quizzes=input, quad_order=quad_order)
nb += i.long().sum()
result[i], correct[i], _ = self.predict(
- model=model, quizzes=input[i], struct=struct, quad=quad_generate
+ model=model, quizzes=input[i], quad_order=quad_order, quad_mask=quad_generate
)
predicted_parts[i] = torch.tensor(quad_generate, device=self.device)[
######################################################################
- def randomize_configuations_inplace(self, quizzes, structs):
- r = torch.randint(len(structs), (quizzes.size(0),), device=quizzes.device)
- for c in range(len(structs)):
+ def randomize_configuations_inplace(self, quizzes, quad_orders):
+ r = torch.randint(len(quad_orders), (quizzes.size(0),), device=quizzes.device)
+ for c in range(len(quad_orders)):
quizzes[r == c] = self.problem.reconfigure(
- quizzes[r == c], struct=structs[c]
+ quizzes[r == c], quad_order=quad_orders[c]
)
######################################################################
self,
model,
c_quizzes,
- struct,
+ quad_order,
quad_loss,
quad_noise=None,
temperature=1.0,
if device is None:
device = self.device
- c_quizzes = self.problem.reconfigure(c_quizzes, struct)
+ c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
seq_logprobas = torch.zeros(
c_quizzes.size(0),
# if self.prompt_noise > 0.0 and quad_noise is not None:
# c_quizzes = self.problem.inject_noise(
- # c_quizzes, self.prompt_noise, struct=struct, quad=quad_noise
+ # c_quizzes, self.prompt_noise, quad_order=quad_order, quad_noise=quad_noise
# )
with torch.autograd.no_grad():
):
input = input.to(device)
quiz_mask_loss = self.make_quiz_mask(
- input, struct=struct, quad=quad_loss
+ input, quad_order=quad_order, quad_mask=quad_loss
)
output = model(mygpt.BracketedSequence(input)).x / temperature
l[...] = (
c_quizzes = None
for n_step, setup in enumerate(procedure):
- struct, quad_generate, model_modifier = setup
+ quad_order, quad_generate, model_modifier = setup
if c_quizzes is None:
- c_quizzes = self.problem.create_empty_quizzes(nb, struct)
+ c_quizzes = self.problem.create_empty_quizzes(nb, quad_order)
c_quizzes = c_quizzes.to(self.device)
- elif struct != pred_struct:
- c_quizzes = self.problem.reconfigure(c_quizzes, struct)
- pred_struct = struct
+ elif quad_order != pred_quad_order:
+ c_quizzes = self.problem.reconfigure(c_quizzes, quad_order)
+ pred_quad_order = quad_order
if model_modifier is not None:
model_modifier(model_for_generation)
self.autoregression(
model=model_for_generation,
input=c_quizzes,
- ar_mask=self.make_quiz_mask(c_quizzes, struct, quad_generate),
+ ar_mask=self.make_quiz_mask(
+ quizzes=c_quizzes, quad_order=quad_order, quad_mask=quad_generate
+ ),
seq_logprobas=seq_logprobas,
progress_bar_desc=f"autoregression {n_step+1}/{len(procedure)}",
)