X-Git-Url: https://fleuret.org/cgi-bin/gitweb/gitweb.cgi?a=blobdiff_plain;ds=sidebyside;f=sky.py;h=040ec67e0ed194687d28cbb7f4d1d8c2808736b4;hb=64abc9f3a07a8211f308271fde7d8f876a968ab5;hp=2183cf1f7e61f0f28dfb71afcb8f9b136a85e891;hpb=d283cd3d46a6323fec4c6a0970ac71e553e4a486;p=culture.git diff --git a/sky.py b/sky.py index 2183cf1..040ec67 100755 --- a/sky.py +++ b/sky.py @@ -42,9 +42,6 @@ class Sky(problem.Problem): "_" + "".join([chr(ord("A") + n) for n in range(len(colors) - 1)]) + "><" ) - def nb_token_values(self): - return len(self.colors) - def __init__( self, height=6, @@ -155,15 +152,6 @@ class Sky(problem.Problem): ###################################################################### - def generate_prompts_and_answers(self, nb): - frame_sequences = self.generate_frame_sequences(nb) - frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0) - prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) - answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) - return prompts, answers - - ###################################################################### - def frame2img(self, x, scale=15): x = x.reshape(x.size(0), self.height, -1) m = torch.logical_and( @@ -250,6 +238,18 @@ class Sky(problem.Problem): img.float() / 255.0, image_name, nrow=6, padding=margin * 2, pad_value=1.0 ) + ###################################################################### + + def nb_token_values(self): + return len(self.colors) + + def generate_prompts_and_answers(self, nb): + frame_sequences = self.generate_frame_sequences(nb) + frame_sequences = torch.cat([x[None] for x in frame_sequences], dim=0) + prompts = frame_sequences[:, : frame_sequences.size(1) // 2].flatten(1) + answers = frame_sequences[:, frame_sequences.size(1) // 2 :].flatten(1) + return prompts, answers + def save_quizzes( self, result_dir,