PyTorch Общее время CUDA

Профилировщик Autograd — удобный инструмент для измерения времени выполнения в PyTorch, как показано ниже:

import torch
import torchvision.models as models

model = models.densenet121(pretrained=True)
x = torch.randn((1, 3, 224, 224), requires_grad=True)

with torch.autograd.profiler.profile(use_cuda=True) as prof:
    model(x)
print(prof) 

Результат выглядит так:

-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------
Name                                        CPU time        CUDA time            Calls        CPU total       CUDA total
-----------------------------------  ---------------  ---------------  ---------------  ---------------  ---------------
conv2d                                    9976.544us       9972.736us                1       9976.544us       9972.736us
convolution                               9958.778us       9958.400us                1       9958.778us       9958.400us
_convolution                              9946.712us       9947.136us                1       9946.712us       9947.136us
contiguous                                   6.692us          6.976us                1          6.692us          6.976us
empty                                       11.927us         12.032us                1         11.927us         12.032us

Которая будет включать много строк. Мои вопросы:

1) Как я могу использовать профилировщик автограда, чтобы получить все время CUDA? (т.е. сумма столбца времени CUDA)

2) Есть ли какое-то решение для прагматического использования? Например, prof[0].CUDA_Time?

См. также:  Почему кеши Xcode такие огромные?
Понравилась статья? Поделиться с друзьями:
IT Шеф
Комментарии: 1
  1. MTMD
    [item.cuda_time for item in prof.function_events]
    

    даст вам список раз CUDA. Измените его в зависимости от ваших потребностей. Чтобы получить сумму времени CUDA, например:

    sum([item.cuda_time for item in prof.function_events])
    

    Однако будьте осторожны, время в списке указано в микросекундах, а время отображается в миллисекундах в print(prof).

Добавить комментарий

;-) :| :x :twisted: :smile: :shock: :sad: :roll: :razz: :oops: :o :mrgreen: :lol: :idea: :grin: :evil: :cry: :cool: :arrow: :???: :?: :!: