Что эквивалентно tf.nn.softmax_cross_entropy_with_logits в pytorch?

Я пытался воспроизвести код, написанный в тензорном потоке, с помощью pytorch. Я наткнулся на функцию потерь в тензорном потоке softmax_cross_entropy_with_logits. Я искал ее эквивалент в pytorch и нашел torch.nn.MultiLabelSoftMarginLoss, хотя я не совсем уверен, что это правильная функция. Также я не знаю, как измерить точность моей модели, когда я использую эту функцию потерь и без промежуточного слоя в конце сети, вот мой код:


# GRADED FUNCTION: compute_cost 

def compute_cost(Z3, Y):

    loss = torch.nn.MultiLabelSoftMarginLoss()    
    return loss(Z3,Y)


def model(net,X_train, y_train, X_test, y_test, learning_rate = 0.009,
          num_epochs = 100, minibatch_size = 64, print_cost = True):

    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    optimizer.zero_grad()

    total_train_acc=0

    for epoch in range(num_epochs):
        for i, data in enumerate(train_loader, 0):
            running_loss = 0.0

            inputs, labels = data

            inputs, labels = Variable(inputs), Variable(labels)

            Z3 = net(inputs)

            # Cost function
            cost = compute_cost(Z3, labels)

            # Backpropagation: Define the optimizer. 
            # Use an AdamOptimizer that minimizes the cost.

            cost.backward()
            optimizer.step()             

            running_loss += cost.item()

            # Measuring the accuracy of minibatch
            acc = (labels==Z3).sum()
            total_train_acc += acc.item()
            #Print every 10th batch of an epoch
            if epoch%1 == 0:
            print("Cost after epoch {} : 
            {:.3f}".format(epoch,running_loss/len(train_loader)))

См. также:  1 строка для встраивания слов ELECTRA с NLU в Python
Понравилась статья? Поделиться с друзьями:
IT Шеф
Комментарии: 1
  1. Moeinh77

    Используйте torch.nn.CrossEntropyLoss(). Он сочетает в себе softmax и кросс-энтропию. Из документации:

    Этот критерий объединяет nn.LogSoftmax () и nn.NLLLoss () в одном классе.

    Пример:

    # define loss function
    loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
    
    # during training
    for (x, y) in train_loader:
        model.train()
        y_pred = model(x) # your input `torch.FloatTensor`
        loss_val = loss_fn(y_pred, y)
        print(loss_val.item()) # prints numpy value
    
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()
    

    Убедитесь, что типы x и y правильные. Обычно преобразование выполняется так: loss_fn(y_pred.type(torch.FloatTensor), y.type(torch.LongTensor)).

    Для измерения точности вы можете определить пользовательскую функцию:

    def compute_accuracy(y_pred, y):
       if list(y_pred.size()) != list(y.size()):
          raise ValueError('Inputs have different shapes.',
                           list(y_pred.size()), 'and', list(y.size()))
    
      result = [1 if y1==y2 else 0 for y1, y2 in zip(y_pred, y)]
    
      return sum(result) / len(result)
    

    И используйте оба вот так:

    model.train()
    y_pred = model(x)
    
    loss_val = loss_fn(y_pred.type(torch.FloatTensor), y.type(torch.LongTensor))
    _, y_pred = torch.max(y_pred, 1)
    accuracy_val = compute_accuracy(y_pred, y)
    print(loss_val.item()) # print loss value
    print(accuracy_val) # print accuracy value
    # update step e.t.c
    

    Если ваши входные данные имеют горячую кодировку, вы можете преобразовать их в обычную кодировку, прежде чем использовать loss_fn:

    _, targets = y.max(dim=1)
    y_pred = model(x)
    loss_val = loss_fn(y_pred, targets)
    

    tnx для вашего ответа, есть проблема с CrossEntropyLoss, и он все равно не принимает одну горячую закодированную метку, я могу использовать ее с одной горячей кодировкой? person Moeinh77; 08.04.2019

    Да, через минуту добавлю к ответу. person Moeinh77; 08.04.2019

Добавить комментарий

;-) :| :x :twisted: :smile: :shock: :sad: :roll: :razz: :oops: :o :mrgreen: :lol: :idea: :grin: :evil: :cry: :cool: :arrow: :???: :?: :!: