X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;f=tasks.py;h=9cd06ae054ae7e1adee634a9361adb8680d1356c;hb=0f580d4facb4b4b485d0a38d62d06c0639715b77;hp=5583fc89827d82be551db14ac9cb601f670b4233;hpb=3dea181a5903a0e577e4830c66405b40f2a2df1d;p=picoclvr.git diff --git a/tasks.py b/tasks.py index 5583fc8..9cd06ae 100755 --- a/tasks.py +++ b/tasks.py @@ -75,11 +75,12 @@ class ProblemByheart(Problem): def __init__(self): nb_seq, len_prompt, len_result = 100, 5, 5 self.seq = torch.randint(10, (nb_seq, len_prompt + 1 + len_result)) - self.seq[:,len_prompt]=-1 + self.seq[:, len_prompt] = -1 def generate_sequences(self, nb): return self.seq[torch.randint(self.seq.size(0), (nb,))] + class SandBox(Task): def __init__( self, @@ -93,7 +94,7 @@ class SandBox(Task): self.batch_size = batch_size - problems = [ ProblemByheart() ] + problems = [ProblemByheart()] nb_common_codes = 100 def generate_sequences(nb_samples): @@ -101,7 +102,7 @@ class SandBox(Task): nb_samples_per_problem = torch.one_hot(problem_indexes).sum(0) print(f"{nb_samples_per_problem}") all_seq = [] - for nb, p in zip(nb_samples_per_problem,problems): + for nb, p in zip(nb_samples_per_problem, problems): all_seq.append(p.generate_sequences(nb_samples_per_problem[nb])) return all_seq @@ -109,7 +110,7 @@ class SandBox(Task): test_seq = generate_sequences(nb_test_samples) for strain, stest in zip(train_seq, test_seq): - s = torch.cat((strain,stest),0) + s = torch.cat((strain, stest), 0) self.nb_codes = max(self.train_input.max(), self.test_input.max()) + 1