######################################################################
-def create_model():
- hidden_dim = 32
-
- model = nn.Sequential(
- nn.Linear(2, hidden_dim),
- nn.ReLU(),
- nn.Linear(hidden_dim, hidden_dim),
- nn.ReLU(),
- nn.Linear(hidden_dim, 2),
- )
-
- return model
-
-
-######################################################################
def generate_sets_and_params(
- nb_mlps,
+ batch_nb_mlps,
nb_samples,
batch_size,
nb_epochs,
device=torch.device("cpu"),
print_log=False,
):
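    # build, for each MLP, a balanced binary classification problem on
    # points in [-1, 1]^2, and return the quantized train set, test set,
    # and trained MLP parameters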
- data_input = torch.zeros(nb_mlps, 2 * nb_samples, 2, device=device)
+ data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
data_targets = torch.zeros(
- nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
+ batch_nb_mlps, 2 * nb_samples, dtype=torch.int64, device=device
)
while (data_targets.float().mean(-1) - 0.5).abs().max() > 0.1:
i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
nb = i.sum()
- print(f"{nb=}")
nb_rec = 2
support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
    train_input, train_targets = data_input[:, :nb_samples], data_targets[:, :nb_samples]
    test_input, test_targets = data_input[:, nb_samples:], data_targets[:, nb_samples:]
    q_train_input = quantize(train_input, -1, 1)
    q_test_input = quantize(test_input, -1, 1)
hidden_dim = 32
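    # one small two-layer MLP per problem, stored as batched weight
    # tensors so that all the MLPs can be trained in parallel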
- w1 = torch.randn(nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
- b1 = torch.zeros(nb_mlps, hidden_dim, device=device)
- w2 = torch.randn(nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
- b2 = torch.zeros(nb_mlps, 2, device=device)
+ w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
+ b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
+ w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
+ b2 = torch.zeros(batch_nb_mlps, 2, device=device)
    w1.requires_grad_()
    b1.requires_grad_()
    w2.requires_grad_()
    b2.requires_grad_()
# print(f"{k=} {acc_train_loss=} {train_error=}")
q_params = torch.cat(
- [quantize(p.view(nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
+ [quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
)
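    # each sample is encoded as three integers: the two quantized
    # coordinates followed by the class label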
q_train_set = torch.cat([q_train_input, train_targets[:, :, None]], -1).reshape(
- nb_mlps, -1
+ batch_nb_mlps, -1
)
q_test_set = torch.cat([q_test_input, test_targets[:, :, None]], -1).reshape(
- nb_mlps, -1
+ batch_nb_mlps, -1
)
return q_train_set, q_test_set, q_params
######################################################################
-def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu")):
+def evaluate_q_params(
+    q_params,
+    q_set,
+    batch_size=25,
+    device=torch.device("cpu"),
+    nb_mlps_per_batch=1024,
+):
+
+ errors = []
nb_mlps = q_params.size(0)
- hidden_dim = 32
- w1 = torch.empty(nb_mlps, hidden_dim, 2, device=device)
- b1 = torch.empty(nb_mlps, hidden_dim, device=device)
- w2 = torch.empty(nb_mlps, 2, hidden_dim, device=device)
- b2 = torch.empty(nb_mlps, 2, device=device)
-
- with torch.no_grad():
- k = 0
- for p in [w1, b1, w2, b2]:
- print(f"{p.size()=}")
- x = dequantize(q_params[:, k : k + p.numel() // nb_mlps], -2, 2).view(
- p.size()
- )
- p.copy_(x)
- k += p.numel() // nb_mlps
- q_set = q_set.view(nb_mlps, -1, 3)
- data_input = dequantize(q_set[:, :, :2], -1, 1).to(device)
- data_targets = q_set[:, :, 2].to(device)
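+    # evaluate the MLPs in chunks of nb_mlps_per_batch to bound memory use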
+    for n in range(0, nb_mlps, nb_mlps_per_batch):
+        batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
+        batch_q_params = q_params[n : n + batch_nb_mlps]
+        batch_q_set = q_set[n : n + batch_nb_mlps]
+ hidden_dim = 32
+ w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
+ b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
+ w2 = torch.empty(batch_nb_mlps, 2, hidden_dim, device=device)
+ b2 = torch.empty(batch_nb_mlps, 2, device=device)
- print(f"{data_input.size()=} {data_targets.size()=}")
+ with torch.no_grad():
+ k = 0
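+            # slice each parameter block out of the flat quantized
+            # vector, dequantize it, and copy it into the weight tensor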
+ for p in [w1, b1, w2, b2]:
+ print(f"{p.size()=}")
+ x = dequantize(batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2).view(
+ p.size()
+ )
+ p.copy_(x)
+ k += p.numel() // batch_nb_mlps
- criterion = nn.CrossEntropyLoss()
- criterion.to(device)
+ batch_q_set = batch_q_set.view(batch_nb_mlps, -1, 3)
+ data_input = dequantize(batch_q_set[:, :, :2], -1, 1).to(device)
+ data_targets = batch_q_set[:, :, 2].to(device)
+
+ print(f"{data_input.size()=} {data_targets.size()=}")
+
+ criterion = nn.CrossEntropyLoss()
+ criterion.to(device)
+
+ acc_loss = 0.0
+ nb_errors = 0
- acc_loss = 0.0
- nb_errors = 0
+ for input, targets in zip(
+ data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
+ ):
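+            # batched forward pass: in the einsums, "m" indexes the MLP
+            # and "n" the sample, so each MLP processes its own minibatch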
+ h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+ h = F.relu(h)
+ output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+ loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+ acc_loss += loss.item() * input.size(0)
+ wta = output.argmax(-1)
+ nb_errors += (wta != targets).long().sum(-1)
- for input, targets in zip(
- data_input.split(batch_size, dim=1), data_targets.split(batch_size, dim=1)
- ):
- h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
- h = F.relu(h)
- output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
- loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
- acc_loss += loss.item() * input.size(0)
- wta = output.argmax(-1)
- nb_errors += (wta != targets).long().sum(-1)
+ errors.append(nb_errors / data_input.size(1))
+ acc_loss = acc_loss / data_input.size(1)
- error = nb_errors / data_input.size(1)
- acc_loss = acc_loss / data_input.size(1)
- return error
+ return torch.cat(errors)
######################################################################
def generate_sequence_and_test_set(
    nb_mlps,
    nb_samples,
    batch_size,
nb_epochs,
device,
+ nb_mlps_per_batch=1024,
):
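    # each MLP yields one token sequence: its quantized train set, a
    # separator token, then its quantized parameters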
- q_train_set, q_test_set, q_params = generate_sets_and_params(
- nb_mlps,
- nb_samples,
- batch_size,
- nb_epochs,
- device=device,
- )
- input = torch.cat(
- [
- q_train_set,
- q_train_set.new_full(
- (
- q_train_set.size(0),
- 1,
+    inputs, q_test_sets = [], []
+
+    for n in range(0, nb_mlps, nb_mlps_per_batch):
+        q_train_set, q_test_set, q_params = generate_sets_and_params(
+            batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
+            nb_samples=nb_samples,
+            batch_size=batch_size,
+            nb_epochs=nb_epochs,
+            device=device,
+        )
+
+ inputs.append(torch.cat(
+ [
+ q_train_set,
+ q_train_set.new_full(
+ (
+ q_train_set.size(0),
+ 1,
+ ),
+ nb_quantization_levels,
),
- nb_quantization_levels,
- ),
- q_params,
- ],
- dim=-1,
- )
+ q_params,
+ ],
+ dim=-1,
+ ))
- print(f"SANITY #1 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
+ q_test_sets.append(q_test_set)
- ar_mask = (
- (torch.arange(input.size(0), device=input.device) > q_train_set.size(0) + 1)
- .long()
- .view(1, -1)
- .reshape(nb_mlps, -1)
- )
+ input = torch.cat(inputs)
+ q_test_set = torch.cat(q_test_sets)
- return input, ar_mask, q_test_set
+ return input, q_test_set
######################################################################
if __name__ == "__main__":
import time
-    nb_mlps, nb_samples = 128, 200
+    nb_mlps, nb_samples = 128, 500
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
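    # quick benchmark: generate sequences, then recover the parameters
    # from them and check the train/test error they achieve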
    start_time = time.perf_counter()

-    data = []
-    for n in range(2):
- data.append(
- generate_sequence_and_test_set(
- nb_mlps=nb_mlps,
- nb_samples=nb_samples,
- device=device,
- batch_size=25,
- nb_epochs=250,
- )
- )
+ input, q_test_set = generate_sequence_and_test_set(
+        nb_mlps=nb_mlps,
+ nb_samples=nb_samples,
+ device=device,
+ batch_size=25,
+ nb_epochs=250,
+        nb_mlps_per_batch=17,
+ )
end_time = time.perf_counter()
- nb = sum([i.size(0) for i, _, _ in data])
- print(f"{nb / (end_time - start_time):.02f} samples per second")
-
- for input, ar_mask, q_test_set in data:
- q_train_set = input[:, : nb_samples * 3]
- q_params = input[:, nb_samples * 3 + 1 :]
- print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
- error_train = evaluate_q_params(q_params, q_train_set)
- print(f"train {error_train*100}%")
- error_test = evaluate_q_params(q_params, q_test_set)
- print(f"test {error_test*100}%")
+ print(f"{input.size(0) / (end_time - start_time):.02f} samples per second")
+
+ q_train_set = input[:, : nb_samples * 3]
+ q_params = input[:, nb_samples * 3 + 1 :]
+ print(f"SANITY #2 {q_train_set.size()=} {q_params.size()=} {input.size()=}")
+ error_train = evaluate_q_params(q_params, q_train_set, nb_mlps_per_batch=17)
+ print(f"train {error_train*100}%")
+ error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
+ print(f"test {error_test*100}%")
######################################################################
+
+import qmlp
+
+
+class QMLP(Task):
+
+ ######################
+
+ def __init__(
+ self,
+ nb_train_samples,
+ nb_test_samples,
+ batch_size,
+ logger=None,
+ device=torch.device("cpu"),
+ ):
+ super().__init__()
+
+ self.device = device
+ self.batch_size = batch_size
+
+ if logger is not None:
+ logger(
+ f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
+ )
+
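+        # sample descriptions are strings of space-separated tokens; the
+        # generator (here self.grid_factory) is assumed to be set up
+        # elsewhere in the class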
+ self.train_descr = self.grid_factory.generate_samples(
+ nb_train_samples, lambda r: tqdm.tqdm(r)
+ )
+ self.test_descr = self.grid_factory.generate_samples(
+ nb_test_samples, lambda r: tqdm.tqdm(r)
+ )
+
+ # Build the tokenizer
+ tokens = set()
+ for d in [self.train_descr, self.test_descr]:
+ for s in d:
+ for t in s.strip().split(" "):
+ tokens.add(t)
+ # make this set a sorted list to get the same tensors given
+ # the same descr
+ tokens = list(tokens)
+ tokens.sort()
+ tokens = ["#"] + tokens
+ self.token2id = dict([(t, n) for n, t in enumerate(tokens)])
+ self.id2token = dict([(n, t) for n, t in enumerate(tokens)])
+ self.t_nul = self.token2id["#"]
+ self.t_true = self.token2id["true"]
+ self.t_false = self.token2id["false"]
+
+ # Tokenize the train and test sets
+ self.train_input = self.str2tensor(self.train_descr)
+ self.test_input = self.str2tensor(self.test_descr)
+
+ def batches(self, split="train"):
+ assert split in {"train", "test"}
+ input = self.train_input if split == "train" else self.test_input
+ for batch in tqdm.tqdm(
+ input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
+ ):
+ yield self.trim(batch)
+
+ def vocabulary_size(self):
+ return len(self.token2id)
+
+ def produce_results(
+ self, n_epoch, model, result_dir, logger, deterministic_synthesis
+ ):
+ correct = self.test_input[:1000]
+ result = correct.clone()
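+        # mask the true/false answer tokens so that the model has to
+        # regenerate them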
+ ar_mask = torch.logical_or(result == self.t_true, result == self.t_false).long()
+ result *= 1 - ar_mask # paraaaaanoiaaaaaaa
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"test_before {e}")
+
+ masked_inplace_autoregression(
+ model,
+ self.batch_size,
+ result,
+ ar_mask,
+ deterministic_synthesis,
+ device=self.device,
+ )
+
+ logger(f"----------------------------------------------------------")
+
+ for e in self.tensor2str(result[:10]):
+ logger(f"test_after {e}")
+
+ logger(f"----------------------------------------------------------")
+
+ nb_total = ar_mask.sum().item()
+ nb_correct = ((correct == result).long() * ar_mask).sum().item()
+
+ logger(f"test_performance {n_epoch} {nb_total=} {nb_correct=}")
+ logger(f"main_test_accuracy {n_epoch} {nb_correct / nb_total}")
+
+
+######################################################################