# X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=problems.py;h=7b1d69859f6598fd6aee1cf832838678e2e5d377;hb=b718ef527d4bfb014a9ad564bb5199c7d0780aa9;hp=516158795d047002ef70163bea7e03b68b9b9181;hpb=16cb07f99cf770fb4e97824f874a68cbddd4c1cf;p=picoclvr.git
# diff --git a/problems.py b/problems.py, index 5161587..7b1d698 100755
# Classes added by hunk "@@ -21,6 +21,68 @@ class Problem:".

####################


def _generate_len_id_sequences(nb, seq_len, len_max):
    """Sample ``nb`` token rows laid out as ``x..x | y..y > y..y _.._``.

    For every row two segment lengths ``l0, l1`` are drawn uniformly in
    ``[1, len_max]`` and two digits ``d0, d1`` uniformly in ``[0, 9]``.
    The row is then filled, left to right, with

    * ``l0`` copies of ``d0``,
    * the token 10,
    * ``l1`` copies of ``d1``,
    * the token 11,
    * ``l0`` copies of ``d1``,
    * the token 12 for every remaining position.

    Positions at or beyond ``seq_len`` are simply cut off, so a row may
    contain neither token 10 nor token 11 when ``seq_len`` is too small.

    Args:
        nb: number of rows to generate.
        seq_len: width of each generated row.
        len_max: largest segment length that may be drawn.

    Returns:
        A pair ``(sequences, ar_mask)`` of integer tensors of shape
        ``(nb, seq_len)``. ``ar_mask`` is 1 at the positions strictly
        after token 11 and 0 everywhere else (all zero for a row that
        does not contain token 11).
    """
    bar, arrow, pad = 10, 11, 12

    # (1, seq_len) row of positions, broadcast against (nb, 1) columns.
    pos = torch.arange(seq_len)[None, :]

    # Draw the lengths before the digits: this keeps the random stream
    # consumption order of the original code.
    lengths = torch.randint(len_max, (2, nb))[:, :, None] + 1
    digits = torch.randint(10, (2, nb))[:, :, None]

    first_len, second_len = lengths[0], lengths[1]
    first_digit, second_digit = digits[0], digits[1]

    bar_at = first_len
    arrow_at = bar_at + 1 + second_len
    pad_from = arrow_at + 1 + first_len

    # The six masks are mutually exclusive, so their weighted sum picks
    # exactly one token per position.
    sequences = (
        (pos < bar_at) * first_digit
        + (pos == bar_at) * bar
        + ((pos > bar_at) & (pos < arrow_at)) * second_digit
        + (pos == arrow_at) * arrow
        + ((pos > arrow_at) & (pos < pad_from)) * second_digit
        + (pos >= pad_from) * pad
    )

    # Exclusive cumulative sum of the marker indicator: 0 up to and
    # including the marker, 1 after it.
    is_arrow = (sequences == arrow).long()
    ar_mask = (is_arrow.cumsum(1) - is_arrow).clamp(max=1)

    return sequences, ar_mask


class ProblemTwoTargets(Problem):
    """Fixed-width variant of the length/identity sequence task.

    NOTE(review): ``generate_sequences`` currently produces exactly the
    same layout as ``ProblemLenId`` with ``len_total`` used both as the
    row width and as the maximum segment length, and ``len_target`` is
    stored but never read. Because segment lengths can be as large as
    ``len_total``, the ``|`` / ``>`` markers frequently land outside the
    row and the returned ``ar_mask`` is then all zero. This looks like
    work in progress -- confirm the intended generator.
    """

    def __init__(self, len_total=10, len_target=2):
        # NOTE(review): the default arguments do not satisfy this check
        # (3 * (2 + 2) - 1 == 11 > 10), so ``ProblemTwoTargets()`` fails
        # unless assertions are disabled. Left unchanged because the
        # intended bound is not visible here.
        assert len_total >= 3 * (2 + len_target) - 1, (
            "len_total must be at least 3 * (2 + len_target) - 1"
        )
        self.len_total = len_total
        self.len_target = len_target

    def generate_sequences(self, nb):
        """Return ``(sequences, ar_mask)``, each of shape ``(nb, len_total)``."""
        return _generate_len_id_sequences(
            nb, seq_len=self.len_total, len_max=self.len_total
        )

    def seq2str(self, seq):
        """Render a token sequence with the alphabet ``0-9``, ``|``, ``>``, ``_``."""
        symbols = "0123456789|>_"
        return "".join(symbols[token.item()] for token in seq)


####################


class ProblemLenId(Problem):
    """Length/identity task: ``x..x | y..y > y..y`` padded with ``_``.

    The part after ``>`` repeats the digit of the second segment as many
    times as the first segment is long. Rows are ``3 * len_max + 3`` wide,
    which leaves room for the longest possible sample plus at least one
    padding token.
    """

    def __init__(self, len_max=10):
        self.len_max = len_max

    def generate_sequences(self, nb):
        """Return ``(sequences, ar_mask)``, each of shape ``(nb, 3 * len_max + 3)``."""
        return _generate_len_id_sequences(
            nb, seq_len=self.len_max * 3 + 3, len_max=self.len_max
        )

    def seq2str(self, seq):
        """Render a token sequence with the alphabet ``0-9``, ``|``, ``>``, ``_``."""
        symbols = "0123456789|>_"
        return "".join(symbols[token.item()] for token in seq)


####################


class ProblemLevel0(Problem): def __init__(self, nb_sentences=100, len_prompt=5, len_result=5): self.seq = torch.randint(10, (nb_sentences, len_prompt + 1 + len_result)) @@ -32,6 +94,12 @@ class ProblemLevel0(Problem): ar_mask = (ar_mask.cumsum(1) - ar_mask).clamp(max=1) return sequences, ar_mask + def seq2str(self, seq): + return "".join("0123456789|"[x.item()] for x in seq) + + +#################### + class ProblemLevel1(Problem): def __init__(self, nb_operators=100, len_source=5, len_result=8): @@ -64,6 +132,9 @@ class ProblemLevel1(Problem): return "".join("0123456789|>"[x.item()] for x in seq) +#################### + + class ProblemLevel2(Problem): def __init__(self, len_source=5, len_result=8): self.len_source = len_source