X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=inline;f=sky.py;h=040ec67e0ed194687d28cbb7f4d1d8c2808736b4;hb=eaed6307836d88abe7c0f4be733a38364ba20e2f;hp=d2a4568919c31f097966979988914a529652880a;hpb=3b41e2797fc340fd11cb35015b57c3cae1e8447b;p=culture.git diff --git a/sky.py b/sky.py index d2a4568..040ec67 100755 --- a/sky.py +++ b/sky.py @@ -42,9 +42,6 @@ class Sky(problem.Problem): "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" ) - def nb_token_values(self): - return len(self.colors) - def __init__( self, height=6, @@ -155,17 +152,8 @@ class Sky(problem.Problem): ###################################################################### - def generate_prompts_and_answers(self, nb): - frame_sequences = self.generate_frame_sequences(nb) - frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0) - prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) - answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) - return prompts, answers - - ###################################################################### - def frame2img(self, x, scale=15): - x = x.reshape(-1, self.height, self.width) + x = x.reshape(x.size(0), self.height, -1) m = torch.logical_and( x >= 0, x < self.first_bird_token + self.nb_bird_tokens ).long() @@ -250,6 +238,18 @@ class Sky(problem.Problem): img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0 ) + ###################################################################### + + def nb_token_values(self): + return len(self.colors) + + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0) + prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) + return prompts, answers + def save_quizzes( self, result_dir, @@ -274,7 +274,7 @@ class Sky(problem.Problem): if __name__ == "__main__": import time - sky = Sky(height=6, width=8, speed=4, nb_iterations=2) + sky = Sky(height=6, width=8, speed=1, nb_iterations=4) prompts, answers = sky.generate_prompts_and_answers(4)