projects
/
culture.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
Update.
[culture.git]
/
sky.py
diff --git
a/sky.py
b/sky.py
index
abcd394
..
1164185
100755
(executable)
--- a/
sky.py
+++ b/
sky.py
@@
-112,11
+112,20
@@
class Sky(problem.Problem):
break
result = torch.zeros(
break
result = torch.zeros(
- self.nb_iterations, self.height, self.width, dtype=torch.int64
+ self.nb_iterations * self.speed,
+ self.height,
+ self.width,
+ dtype=torch.int64,
)
)
- for l in range(self.nb_iterations):
- fine = collision_okay()
+ fine = torch.empty(self.nb_iterations * self.speed)
+
+ t_to_keep = (
+ torch.arange(self.nb_iterations, device=result.device) * self.speed
+ )
+
+ for l in range(self.nb_iterations * self.speed):
+ fine[l] = collision_okay()
for n in range(self.nb_birds):
c = col[n]
result[l, i[n], j[n]] = c
for n in range(self.nb_birds):
c = col[n]
result[l, i[n], j[n]] = c
@@
-136,7
+145,10
@@
class Sky(problem.Problem):
i[n] += vi[n]
j[n] += vj[n]
i[n] += vi[n]
j[n] += vj[n]
- if fine:
+ result = result[t_to_keep]
+ fine = fine[t_to_keep]
+
+ if fine[-1]:
break
frame_sequences.append(result)
break
frame_sequences.append(result)
@@
-289,7
+301,7
@@
class Sky(problem.Problem):
if __name__ == "__main__":
import time
if __name__ == "__main__":
import time
- sky = Sky(height=6, width=8, speed=
2
, nb_iterations=2)
+ sky = Sky(height=6, width=8, speed=
4
, nb_iterations=2)
start_time = time.perf_counter()
token_sequences = sky.generate_token_sequences(nb=64)
start_time = time.perf_counter()
token_sequences = sky.generate_token_sequences(nb=64)