######################################################################
-
-
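# quantize / dequantize and nb_quantization_levels are defined earlier in this
# file. From their use below (quantize(x, -1, 1) on data, quantize(p, -2, 2)
# on parameters) they map floats in [xmin, xmax] to integer bins in
# [0, nb_quantization_levels) and back. A minimal sketch of what is assumed
# here; the exact definitions may differ:
#
#     def quantize(x, xmin, xmax):
#         return (
#             ((x - xmin) / (xmax - xmin) * nb_quantization_levels)
#             .long()
#             .clamp(min=0, max=nb_quantization_levels - 1)
#         )
#
#     def dequantize(q, xmin, xmax):
#         return q / nb_quantization_levels * (xmax - xmin) + xmin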
def generate_sets_and_params(
batch_nb_mlps,
nb_samples,
batch_size,
nb_epochs,
device=torch.device("cpu"),
print_log=False,
+ save_as_examples=False,
):
data_input = torch.zeros(batch_nb_mlps, 2 * nb_samples, 2, device=device)
data_targets = torch.zeros(
i = (data_targets.float().mean(-1) - 0.5).abs() > 0.1
nb = i.sum()
- nb_rec = 2
- support = torch.rand(nb, nb_rec, 2, 3, device=device) * 2 - 1
+ nb_rec = 8
+ nb_values = 2 # more increases the min-max gap
+ support = torch.rand(nb, nb_rec, 2, nb_values, device=device) * 2 - 1
support = support.sort(-1).values
- support = support[:, :, :, torch.tensor([0, 2])].view(nb, nb_rec, 4)
+ support = support[:, :, :, torch.tensor([0, nb_values - 1])].view(nb, nb_rec, 4)
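+ # Each MLP gets its own task: nb_rec random axis-aligned boxes, each built by
+ # drawing nb_values coordinates per axis in [-1, 1], sorting them, and keeping
+ # the two extremes as that axis's min/max, hence "more increases the min-max gap".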
x = torch.rand(nb, 2 * nb_samples, 2, device=device) * 2 - 1
y = (
test_input = dequantize(q_test_input, -1, 1)
test_targets = test_targets
+ if save_as_examples:
+ for k in range(q_train_input.size(0)):
+ with open(f"example_{k:04d}.dat", "w") as f:
+ for u, c in zip(train_input[k], train_targets[k]):
+ f.write(f"{c} {u[0].item()} {u[1].item()}\n")
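+ # one file per MLP; each line of example_KKKK.dat is "<class> <x> <y>"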
+
hidden_dim = 32
w1 = torch.randn(batch_nb_mlps, hidden_dim, 2, device=device) / math.sqrt(2)
b1 = torch.zeros(batch_nb_mlps, hidden_dim, device=device)
- w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(hidden_dim)
+ w2 = torch.randn(batch_nb_mlps, 2, hidden_dim, device=device) / math.sqrt(
+ hidden_dim
+ )
b2 = torch.zeros(batch_nb_mlps, 2, device=device)
w1.requires_grad_()
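# The MLPs are plain batched tensors rather than nn.Modules: one
# (hidden_dim, 2) input layer and one (2, hidden_dim) output layer per MLP,
# initialized with 1/sqrt(fan_in) scaling, with gradients tracked on the
# leaf tensors directly.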
# print(f"{k=} {acc_train_loss=} {train_error=}")
+ acc_test_loss = 0
+ nb_test_errors = 0
+
+ for input, targets in zip(
+ test_input.split(batch_size, dim=1), test_targets.split(batch_size, dim=1)
+ ):
+ h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
+ h = F.relu(h)
+ output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
+ loss = F.cross_entropy(
+ output.reshape(-1, output.size(-1)), targets.reshape(-1)
+ )
+ acc_test_loss += loss.item() * input.size(0)
+
+ wta = output.argmax(-1)
+ nb_test_errors += (wta != targets).long().sum(-1)
+
+ test_error = nb_test_errors / test_input.size(1)
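+ # The einsum "mij,mnj->mni" is a batched affine map: each of the m MLPs
+ # applies its own (i, j) weight matrix to its n test points, so all MLPs are
+ # evaluated in a single call; test_error holds one error rate per MLP.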
q_params = torch.cat(
[quantize(p.view(batch_nb_mlps, -1), -2, 2) for p in [w1, b1, w2, b2]], dim=1
)
batch_nb_mlps, -1
)
- return q_train_set, q_test_set, q_params
+ return q_train_set, q_test_set, q_params, test_error
######################################################################
-def evaluate_q_params(q_params, q_set, batch_size=25, device=torch.device("cpu"), nb_mlps_per_batch=1024):
-
+def evaluate_q_params(
+ q_params,
+ q_set,
+ batch_size=25,
+ device=torch.device("cpu"),
+ nb_mlps_per_batch=1024,
+ save_as_examples=False,
+):
errors = []
nb_mlps = q_params.size(0)
- for n in range(0,nb_mlps,nb_mlps_per_batch):
- batch_nb_mlps = min(nb_mlps_per_batch,nb_mlps-n)
- batch_q_params = q_params[n:n+batch_nb_mlps]
- batch_q_set = q_set[n:n+batch_nb_mlps]
+ for n in range(0, nb_mlps, nb_mlps_per_batch):
+ batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n)
+ batch_q_params = q_params[n : n + batch_nb_mlps]
+ batch_q_set = q_set[n : n + batch_nb_mlps]
hidden_dim = 32
w1 = torch.empty(batch_nb_mlps, hidden_dim, 2, device=device)
b1 = torch.empty(batch_nb_mlps, hidden_dim, device=device)
k = 0
for p in [w1, b1, w2, b2]:
- print(f"{p.size()=}")
- x = dequantize(batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2).view(
- p.size()
- )
+ x = dequantize(
+ batch_q_params[:, k : k + p.numel() // batch_nb_mlps], -2, 2
+ ).view(p.size())
p.copy_(x)
k += p.numel() // batch_nb_mlps
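# With hidden_dim = 32 the flattened layout per MLP is w1 (64 values), b1 (32),
# w2 (64), b2 (2), i.e. 162 quantized parameters per row of q_params.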
h = torch.einsum("mij,mnj->mni", w1, input) + b1[:, None, :]
h = F.relu(h)
output = torch.einsum("mij,mnj->mni", w2, h) + b2[:, None, :]
- loss = F.cross_entropy(output.reshape(-1, output.size(-1)), targets.reshape(-1))
+ loss = F.cross_entropy(
+ output.reshape(-1, output.size(-1)), targets.reshape(-1)
+ )
acc_loss += loss.item() * input.size(0)
wta = output.argmax(-1)
nb_errors += (wta != targets).long().sum(-1)
errors.append(nb_errors / data_input.size(1))
acc_loss = acc_loss / data_input.size(1)
-
return torch.cat(errors)
device,
nb_mlps_per_batch=1024,
):
- seqs, q_test_sets = [],[]
+ seqs, q_test_sets, test_errors = [], [], []
-
- for n in range(0,nb_mlps,nb_mlps_per_batch):
- q_train_set, q_test_set, q_params = generate_sets_and_params(
- batch_nb_mlps = min(nb_mlps_per_batch, nb_mlps - n),
+ for n in range(0, nb_mlps, nb_mlps_per_batch):
+ q_train_set, q_test_set, q_params, test_error = generate_sets_and_params(
+ batch_nb_mlps=min(nb_mlps_per_batch, nb_mlps - n),
nb_samples=nb_samples,
batch_size=batch_size,
nb_epochs=nb_epochs,
device=device,
)
- seqs.append(torch.cat(
- [
- q_train_set,
- q_train_set.new_full(
- (
- q_train_set.size(0),
- 1,
- ),
- nb_quantization_levels,
- ),
- q_params,
- ],
- dim=-1,
- ))
+ seqs.append(
+ torch.cat(
+ [
+ q_train_set,
+ q_train_set.new_full(
+ (
+ q_train_set.size(0),
+ 1,
+ ),
+ nb_quantization_levels,
+ ),
+ q_params,
+ ],
+ dim=-1,
+ )
+ )
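+ # Each row of seqs is therefore [train-set tokens | separator | parameter
+ # tokens], the separator being nb_quantization_levels, one past the largest
+ # quantization level.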
q_test_sets.append(q_test_set)
+ test_errors.append(test_error)
seq = torch.cat(seqs)
q_test_set = torch.cat(q_test_sets)
- return seq, q_test_set
+ test_error = torch.cat(test_errors)
+ return seq, q_test_set, test_error
######################################################################
if __name__ == "__main__":
import time
- batch_nb_mlps, nb_samples = 128, 500
+ batch_nb_mlps, nb_samples = 128, 2500
+
+ generate_sets_and_params(
+ batch_nb_mlps=10,
+ nb_samples=nb_samples,
+ batch_size=25,
+ nb_epochs=100,
+ device=torch.device("cpu"),
+ print_log=False,
+ save_as_examples=True,
+ )
+
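+ # stop after dumping the example files; the timing run below is skipped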
+ exit(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = []
- seq, q_test_set = generate_sequence_and_test_set(
+ seq, q_test_set, test_error = generate_sequence_and_test_set(
nb_mlps=batch_nb_mlps,
nb_samples=nb_samples,
device=device,
batch_size=25,
nb_epochs=250,
- nb_mlps_per_batch=17
+ nb_mlps_per_batch=17,
)
end_time = time.perf_counter()
class QMLP(Task):
-
######################
def __init__(
self,
nb_train_samples,
nb_test_samples,
batch_size,
+ result_dir,
logger=None,
device=torch.device("cpu"),
):
f"generating {nb_train_samples+nb_test_samples} samples (can take some time)"
)
- seq, q_test_set = generate_sequence_and_test_set(
- nb_mlps=nb_train_samples+nb_test_samples,
+ seq, q_test_set, test_error = qmlp.generate_sequence_and_test_set(
+ nb_mlps=nb_train_samples + nb_test_samples,
nb_samples=self.nb_samples_per_mlp,
device=self.device,
batch_size=64,
nb_epochs=250,
- nb_mlps_per_batch=1024
+ nb_mlps_per_batch=1024,
)
self.train_input = seq[:nb_train_samples]
self.train_q_test_set = q_test_set[:nb_train_samples]
self.test_input = seq[nb_train_samples:]
self.test_q_test_set = q_test_set[nb_train_samples:]
+ self.ref_test_errors = test_error
+
+ filename = os.path.join(result_dir, "test_errors_ref.dat")
+ with open(filename, "w") as f:
+ for e in self.ref_test_errors:
+ f.write(f"{e}\n")
self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1
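# +1 so that the largest token, the separator nb_quantization_levels, fits in
# the vocabulary alongside the quantization levels themselves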
for batch in tqdm.tqdm(
input.split(self.batch_size), dynamic_ncols=True, desc=f"epoch-{split}"
):
- yield self.trim(batch)
+ yield batch
def vocabulary_size(self):
return self.nb_codes
):
correct = self.test_input[:1000]
result = correct.clone()
- ar_mask = torch.arange(result.size(1)) > self.nb_samples_per_mlp * 3 + 1
+ ar_mask = (
+ torch.arange(result.size(1), device=result.device)
+ > self.nb_samples_per_mlp * 3 + 1
+ ).long()[None, :]
+ ar_mask = ar_mask.expand_as(result)
result *= 1 - ar_mask # paraaaaanoiaaaaaaa
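# Everything past the train set and its separator is masked and regenerated
# token by token by the model; zeroing it first guarantees the ground-truth
# parameters cannot leak into the generation.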
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"test_before {e}")
-
masked_inplace_autoregression(
model,
self.batch_size,
device=self.device,
)
- logger(f"----------------------------------------------------------")
-
- for e in self.tensor2str(result[:10]):
- logger(f"test_after {e}")
-
- logger(f"----------------------------------------------------------")
-
- q_train_set = result[:, : nb_samples * 3]
- q_params = result[:, nb_samples * 3 + 1 :]
- error_test = evaluate_q_params(q_params, q_test_set, nb_mlps_per_batch=17)
+ q_train_set = result[:, : self.nb_samples_per_mlp * 3]
+ q_params = result[:, self.nb_samples_per_mlp * 3 + 1 :]
+ error_test = qmlp.evaluate_q_params(q_params, self.test_q_test_set)
- logger(f"{error_test=}")
+ filename = os.path.join(result_dir, f"test_errors_{n_epoch:04d}.dat")
+ with open(filename, "w") as f:
+ for e in error_test:
+ f.write(f"{e}\n")
######################################################################