- a = torch.rand(d1, d2, device = device, dtype = t)
- b = torch.rand(d2, d3, device = device, dtype = t)
+ try:
+ a = torch.rand(d1, d2, device = device, dtype = t)
+ b = torch.rand(d2, d3, device = device, dtype = t)
+ nb_runs = 0
+
+ sync()
+ start_time = time.perf_counter()
+ while time.perf_counter() - start_time < max_duration:
+ c = a @ b
+ nb_runs += 1
+ sync()
+ duration = time.perf_counter() - start_time
+
+ nb_flop = float(nb_runs * d1 * d2 * d3 * 2) # 1 multiply-and-add is 2 ops
+ speed = nb_flop / duration