X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=5583fc89827d82be551db14ac9cb601f670b4233;hb=3dea181a5903a0e577e4830c66405b40f2a2df1d;hp=8b57cb2056462d8265bb38c1dffb49c4d59dc41a;hpb=e3a8032a070175ece08fc79c77312d5f2f59150e;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 8b57cb2..5583fc8 100755 --- a/tasks.py +++ b/tasks.py @@ -73,8 +73,12 @@ class Problem: class ProblemByheart(Problem): def __init__(self): - pass + nb_seq, len_prompt, len_result = 100, 5, 5 + self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) + self.seq[:,len_prompt]=-1 + def generate_sequences(self, nb): + return self.seq[torch.randint(self.seq.size(0), (nb,))] class SandBox(Task): def __init__( @@ -89,13 +93,23 @@ class SandBox(Task): self.batch_size = batch_size + problems = [ ProblemByheart() ] + nb_common_codes = 100 + def generate_sequences(nb_samples): problem_indexes = torch.randint(len(problems), (nb_samples,)) nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) print(f"{nb_samples_per_problem}") + all_seq = [] + for nb, p in zip(nb_samples_per_problem,problems): + all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) + return all_seq + + train_seq = generate_sequences(nb_train_samples) + test_seq = generate_sequences(nb_test_samples) - self.train_input = generate_sequences(nb_train_samples) - self.test_input = generate_sequences(nb_test_samples) + for strain, stest in zip(train_seq, test_seq): + s = torch.cat((strain,stest),0) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1