-a, b = torch.rand(d1, d2).to(device), torch.rand(d2, d3).to(device)
+for t in [ torch.float32, torch.float16 ]:
+ try:
+ a = torch.rand(d1, d2, device = device, dtype = t)
+ b = torch.rand(d2, d3, device = device, dtype = t)
+ nb_runs = 0
+
+ sync()
+ start_time = time.perf_counter()
+ while time.perf_counter() - start_time < max_duration:
+ c = a @ b
+ nb_runs += 1
+ sync()
+ duration = time.perf_counter() - start_time