5 from setuptools import setup
6 from torch.utils.cpp_extension import BuildExtension, CppExtension
9 std::vector<torch::Tensor> greedy_lines_allocation(torch::Tensor load_start, float decay, torch::Tensor line_requests) {
10 auto nb_lines = load_start.size(1);
11 auto batch_size = line_requests.size(0);
12 auto nb_heads = line_requests.size(1);
13 auto T = line_requests.size(2);
15 auto load_start_a = load_start.accessor<float,2>();
16 auto line_requests_a = line_requests.accessor<float,3>();
18 auto load = torch::empty({batch_size, nb_lines, T});
19 auto load_a = load.accessor<float,3>();
21 auto allocation_result = torch::empty({batch_size,nb_heads,T},torch::TensorOptions().dtype(torch::kInt64));
22 auto allocation_result_a = allocation_result.accessor<long,3>();
24 for(int n = 0; n < batch_size; n++) {
25 for(int t = 0; t < T; t++) {
26 for(int l = 0; l < nb_lines; l++) {
28 load[n][l][t] = decay * load_start_a[n][l];
30 load[n][l][t] = decay * load[n][l][t-1];
33 for(int h = 0; h < nb_heads; h++) {
34 if(line_requests_a[n][h][t] > 0) {
36 for(int l = 0; l < nb_lines; l++) {
37 if(l == 0 || load_a[n][l][t]<load_a[n][l_lowest_load][t]) l_lowest_load=l;
39 if(load_a[n][l_lowest_load][t] < line_requests_a[n][h][t]) {
40 allocation_result_a[n][h][t] = l_lowest_load;
41 load_a[n][l_lowest_load][t] = line_requests_a[n][h][t];
43 allocation_result_a[n][h][t] = -1;
46 allocation_result_a[n][h][t] = -1;
52 return {allocation_result,load};
56 ######################################################################
# Compile the embedded C++ source as a PyTorch extension and expose its
# single exported function. NOTE(review): a fixed, world-writable build dir
# ("/tmp/") is shared between users/runs — consider a per-user temp dir.
allocator_module = torch.utils.cpp_extension.load_inline(
    name="allocator_module",
    cpp_sources=[cpp_source],
    functions=["greedy_lines_allocation"],
    build_directory="/tmp/",
)

# Module-level alias used by the rest of the script.
lines_allocation = allocator_module.greedy_lines_allocation
68 ######################################################################
if __name__ == "__main__":
    # Tiny smoke demo: one batch, one head, three lines, twenty steps.
    N, H, L, T = 1, 1, 3, 20

    load_start = torch.rand(N, L)
    # Uniform in [-1, 1), clamped so roughly half the steps carry no request.
    requests = (2 * torch.rand(N, H, T) - 1).clamp(min=0)

    print("load_start", load_start)
    print("requests", requests)

    alloc, load = lines_allocation(load_start, 0.99, requests)