X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=22b651795b0501331c69fd24c3ea499779576e48;hb=0d86d8ca945722438d3c85cd01b3740269ed3546;hp=7aa59bea856c205d232ec1bf49c80f18ef20bed1;hpb=16e7952b7cc32ca21498fa3a12fb79f679ea8c21;p=picoclvr.git diff --git a/problems.py b/problems.py index 7aa59be..22b6517 100755 --- a/problems.py +++ b/problems.py @@ -22,37 +22,51 @@ class Problem: nb_correct = ((result == input).long() * ar_mask).sum().item() return nb_total, nb_correct + #################### class ProblemDegradation(Problem): - def __init__(self, nb_state_tokens=7, nb_time_steps=10, value_max=100, hard=False): + def __init__(self, nb_state_tokens=5, nb_time_steps=12, value_max=25, hard=False): + assert value_max // nb_state_tokens >= 2 self.nb_state_tokens = nb_state_tokens self.nb_time_steps = nb_time_steps self.value_max = value_max self.hard = hard - def generate_sequences(self,nb): - - x = (torch.rand(nb,self.nb_state_tokens).sort(dim=-1).indices == 0).long() * self.value_max + def generate_sequences(self, nb): + x = ( + torch.rand(nb, self.nb_state_tokens).sort(dim=-1).indices == 0 + ).long() * self.value_max seq = [x] - for t in range(self.nb_time_steps-1): - v = torch.rand(x.size()) * (x > 0).float() - u = (v.max(dim=-1,keepdim=True).values == v).long() - n = (u*x*torch.rand(x.size())).long().sum(dim=-1,keepdim=True) // 2 - x = x + n * (u.roll(shifts=-1,dims=-1) - 2 * u + u.roll(shifts=1,dims=-1)) + for t in range(self.nb_time_steps - 1): + v = (torch.rand(x.size()).sort(dim=-1).indices + 1) * (x >= 2).long() + u = (v.max(dim=-1, keepdim=True).values == v).long() + n = ( + (u * x) + .minimum(2 + torch.randint(self.value_max // 4 - 2, x.size())) + .sum(dim=-1, keepdim=True) + ) + m = 1 + ((n - 1) * torch.rand(n.size())).long() + x = ( + x + + m * u.roll(shifts=-1, dims=-1) + - n * u + + (n - m) * u.roll(shifts=1, dims=-1) + ) seq.append(x) - if self.hard: seq.reverse() + if self.hard: + seq.reverse() - seq = torch.cat(seq,dim=1) - return seq,seq.new_full(seq.size(), 1, dtype=torch.int64) + seq = torch.cat(seq, dim=1) + return seq, seq.new_full(seq.size(), 1, dtype=torch.int64) def compute_nb_correct(self, input, ar_mask, result): nb_total = result.size(0) nb_correct = 0 - e=result.new_zeros(self.nb_state_tokens) + e = result.new_zeros(self.nb_state_tokens) for seq in result: states = list(seq.split(self.nb_state_tokens)) @@ -60,27 +74,38 @@ class ProblemDegradation(Problem): states.reverse() d = states[0] - j=d.sort(descending=True).indices[0] + j = d.sort(descending=True).indices[0] e.zero_() - e[j]=self.value_max - if (d-e).abs().sum() == 0: + e[j] = self.value_max + if (d - e).abs().sum() == 0: nb_errors = 0 - for k in range(len(states)-1): - d=states[k]-states[k+1] - j=d.sort(descending=True).indices[0] - e.zero_() - e[j]=d[j] - e[(j+1)%e.size(0)]=-d[j]//2 - e[(j-1)%e.size(0)]=-d[j]//2 - if (d-e).abs().sum() > 0: + for k in range(len(states) - 1): + d = states[k + 1] - states[k] + j = d.sort(descending=False).indices[0] + if ( + d[j] == 0 + or d[j] > self.value_max // 4 + or d[(j + 1) % e.size(0)] <= 0 + or d[(j + 1) % e.size(0)] >= -d[j] + ): nb_errors += 1 + else: + e.zero_() + e[j] = d[j] + e[(j + 1) % e.size(0)] = d[(j + 1) % e.size(0)] + e[(j - 1) % e.size(0)] = -d[(j + 1) % e.size(0)] - d[j] + if (d - e).abs().sum() > 0: + nb_errors += 1 if nb_errors == 0: nb_correct += 1 return nb_total, nb_correct def seq2str(self, seq): - return " | ".join( [ " ".join([f"{x:02d}" for x in s ]) for s in seq.split(self.nb_state_tokens) ] ) + return " | ".join( + [" ".join([f"{x:02d}" for x in s]) for s in seq.split(self.nb_state_tokens)] + ) + ####################