PyTorch — одна из переменных, необходимых для вычисления градиента, была изменена операцией на месте.

Я использую метод градиента политики в PyTorch. Хотел переместить сетевое обновление в цикл и оно перестало работать. Я все еще новичок в PyTorch, так что извините, если объяснение очевидно.

Вот оригинальный код, который работает:

self.policy.optimizer.zero_grad()
G = T.tensor(G, dtype=T.float).to(self.policy.device) 

loss = 0
for g, logprob in zip(G, self.action_memory):
    loss += -g * logprob
                                 
loss.backward()
self.policy.optimizer.step()

И после изменения:

G = T.tensor(G, dtype=T.float).to(self.policy.device) 

loss = 0
for g, logprob in zip(G, self.action_memory):
    loss = -g * logprob
    self.policy.optimizer.zero_grad()
                                 
    loss.backward()
    self.policy.optimizer.step()

Я получаю сообщение об ошибке:

File "g:\VScode_projects\pytorch_shenanigans\policy_gradient.py", line 86, in learn
    loss.backward()
  File "G:\Anaconda3\envs\pytorch_env\lib\site-packages\torch\tensor.py", line 185, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "G:\Anaconda3\envs\pytorch_env\lib\site-packages\torch\autograd\__init__.py", line 127, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128, 4]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

Я читал, что эта ошибка RuntimeError часто связана с необходимостью клонирования чего-либо, потому что мы используем тот же тензор для вычислить себя, но я не могу понять, что в моем случае не так .

См. также:  Сохраненное состояние Vuex не удаляет состояние после закрытия вкладки
Понравилась статья? Поделиться с друзьями:
IT Шеф
Комментарии: 1
  1. graboonie

    Эта строка loss += -g * logprob — это то, что в вашем случае не так.

    Измените его на это:

    loss = loss + (-g * logprob)
    

    И да, они разные. Они выполняют одни и те же операции, но по-разному.

    Но код с этой строчкой работает. Другой проблемный фрагмент ниже. person graboonie; 13.02.2021

Добавить комментарий

;-) :| :x :twisted: :smile: :shock: :sad: :roll: :razz: :oops: :o :mrgreen: :lol: :idea: :grin: :evil: :cry: :cool: :arrow: :???: :?: :!: