- i, j, vi, vj = (
- torch.empty(nb_birds, dtype=torch.int64),
- torch.empty(nb_birds, dtype=torch.int64),
- torch.empty(nb_birds, dtype=torch.int64),
- torch.empty(nb_birds, dtype=torch.int64),
- )
-
- col = torch.randperm(colors.size(0) - 1)[:nb_birds].sort().values + 1
-
- for n in range(nb_birds):
- c = col[n]
-
- while True:
- i[n], j[n] = (
- torch.randint(height, (1,))[0],
- torch.randint(width, (1,))[0],
- )
- vm = torch.randint(4, (1,))[0]
- vi[n], vj[n] = (vm % 2) * 2 - 1, (vm // 2) * 2 - 1
- if (
- i[n] - vi[n] >= 0
- and i[n] - vi[n] < height
- and j[n] - vj[n] >= 0
- and j[n] - vj[n] < width
- and f_start[i[n], j[n]] == 0
- and f_start[i[n] - vi[n], j[n]] == 0
- and f_start[i[n], j[n] - vj[n]] == 0
- ):
- break
-
- f_start[i[n], j[n]] = c
- f_start[i[n] - vi[n], j[n]] = c
- f_start[i[n], j[n] - vj[n]] = c
-
- f_end = f_start.clone()
-
- for l in range(nb_iterations):
- for n in range(nb_birds):
- c = col[n]
- f_end[i[n], j[n]] = 0
- f_end[i[n] - vi[n], j[n]] = 0
- f_end[i[n], j[n] - vj[n]] = 0