Update.
[culture.git] / quizz_machine.py
index d8ebad8..43fd868 100755 (executable)
@@ -64,28 +64,12 @@ def masked_inplace_autoregression(
         model.train(t)
 
 
-######################################################################
-
-
-class Task:
-    def batches(self, split="train", nb_to_use=-1, desc=None):
-        pass
-
-    def vocabulary_size(self):
-        pass
-
-    def produce_results(
-        self, n_epoch, model, result_dir, logger, deterministic_synthesis
-    ):
-        pass
-
-
 ######################################################################
 
 import sky
 
 
-class QuizzMachine(Task):
+class QuizzMachine:
     def save_image(self, input, result_dir, filename, logger):
         img = sky.seq2img(input.to("cpu"), self.height, self.width)
         image_name = os.path.join(result_dir, filename)
@@ -281,7 +265,7 @@ class QuizzMachine(Task):
         seq_logproba = torch.empty(ar_mask.size(0), device=self.device)
 
         temperature = 1
-        d_temperature = 1
+        d_temperature = 1 / 3
 
         while True:
             seq_logproba[...] = 0