Я пытался воспроизвести код, написанный в тензорном потоке, с помощью pytorch. Я наткнулся на функцию потерь в тензорном потоке softmax_cross_entropy_with_logits. Я искал ее эквивалент в pytorch и нашел torch.nn.MultiLabelSoftMarginLoss, хотя я не совсем уверен, что это правильная функция. Также я не знаю, как измерить точность моей модели, когда я использую эту функцию потерь и без промежуточного слоя в конце сети, вот мой код:
# GRADED FUNCTION: compute_cost
def compute_cost(Z3, Y):
loss = torch.nn.MultiLabelSoftMarginLoss()
return loss(Z3,Y)
def model(net,X_train, y_train, X_test, y_test, learning_rate = 0.009,
num_epochs = 100, minibatch_size = 64, print_cost = True):
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
optimizer.zero_grad()
total_train_acc=0
for epoch in range(num_epochs):
for i, data in enumerate(train_loader, 0):
running_loss = 0.0
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
Z3 = net(inputs)
# Cost function
cost = compute_cost(Z3, labels)
# Backpropagation: Define the optimizer.
# Use an AdamOptimizer that minimizes the cost.
cost.backward()
optimizer.step()
running_loss += cost.item()
# Measuring the accuracy of minibatch
acc = (labels==Z3).sum()
total_train_acc += acc.item()
#Print every 10th batch of an epoch
if epoch%1 == 0:
print("Cost after epoch {} :
{:.3f}".format(epoch,running_loss/len(train_loader)))
Используйте
torch.nn.CrossEntropyLoss()
. Он сочетает в себе softmax и кросс-энтропию. Из документации:Пример:
Убедитесь, что типы
x
иy
правильные. Обычно преобразование выполняется так:loss_fn(y_pred.type(torch.FloatTensor), y.type(torch.LongTensor))
.Для измерения точности вы можете определить пользовательскую функцию:
И используйте оба вот так:
Если ваши входные данные имеют горячую кодировку, вы можете преобразовать их в обычную кодировку, прежде чем использовать
loss_fn
:tnx для вашего ответа, есть проблема с CrossEntropyLoss, и он все равно не принимает одну горячую закодированную метку, я могу использовать ее с одной горячей кодировкой? — person Moeinh77; 08.04.2019
Да, через минуту добавлю к ответу. — person Moeinh77; 08.04.2019