Вопросы

Моя точность поезда остается на уровне 10%, когда я добавляю параметр weight_decay в свой оптимизатор в PyTorch. Я использую набор данных CIFAR10 и модель LeNet CNN

Я обучаю набор данных CIFAR10 на модели LeNet CNN. Я использую PyTorch в Google Colab. Код запускается только тогда, когда я использую оптимизатор Adam с единственным параметром model.parameters (). Но когда я меняю оптимизатор или использую параметр weight_decay, точность остается на уровне 10% на протяжении всех эпох. Я не могу понять причину, по которой это происходит.

# CNN Model - LeNet    
class LeNet_ReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_model = nn.Sequential(nn.Conv2d(3,6,5), 
                                       nn.ReLU(),
                                       nn.AvgPool2d(2, stride=2), 
                                       nn.Conv2d(6,16,5), 
                                       nn.ReLU(),
                                       nn.AvgPool2d(2, stride=2))  
        self.fc_model = nn.Sequential(nn.Linear(400, 120),   
                                      nn.ReLU(),
                                      nn.Linear(120,84),  
                                      nn.ReLU(),
                                      nn.Linear(84,10))

    def forward(self, x):
        x = self.cnn_model(x)
        x = x.view(x.size(0), -1)
        x = self.fc_model(x)
        return x

# Importing dataset and creating dataloader
batch_size = 128
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                    transform=transforms.ToTensor())
trainloader = utils_data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                    transform=transforms.ToTensor())
testloader = utils_data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# Creating instance of the model
net = LeNet_ReLU()

# Evaluation function
def evaluation(dataloader):
    total, correct = 0, 0
    for data in dataloader:
        inputs, labels = data

        outputs = net(inputs)
        _, pred = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (pred==labels).sum().item()
    return correct/total * 100

# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(net.parameters(), weight_decay = 0.9)

# Model training
loss_epoch_arr = []
max_epochs = 16

for epoch in range(max_epochs):
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        opt.step()

        opt.zero_grad()


    loss_epoch_arr.append(loss.item())

    print('Epoch: %d/%d, Test acc: %0.2f, Train acc: %0.2f'
    % (epoch,max_epochs, evaluation(testloader), evaluation(trainloader))) 

plt.plot(loss_epoch_arr)

Вы тренировали его без конфигурации снижения веса? Как это работает на тестовых данных? с точки зрения точности   —  person vivek goldsmith    schedule 06.05.2020

Читать:
ThreeJS Загрузить модель GLTF прямо из файла

Я тренировался с конфигурацией weight_decay. Точность оставалась на уровне 10% как для обучающих, так и для тестовых данных для всех эпох. Это не изменилось. Но когда я попробовал без него, точность поезда и теста увеличилась до 66% и 56% соответственно ..   —  person vivek goldsmith    schedule 06.05.2020

weight_decay=0.9 слишком высоко. По сути, это указывает оптимизатору, что малые веса намного важнее, чем низкие значения потерь. Обычное значение weight_decay=0.0005 или в пределах порядка этого значения.   —  person vivek goldsmith    schedule 06.05.2020

Сейчас он работает. Я добавлял слишком большое значение weight_decay. Спасибо.   —  person vivek goldsmith    schedule 07.05.2020

Похожие записи

Vapor 4 PostgreSQL CRUD без HTTP-запросов

admin

Расшифровать с помощью открытого ключа, используя openssl в командной строке

admin

Scala Spark Read from AWS S3 — com.amazonaws.SdkClientException: невозможно загрузить учетные данные из конечной точки службы

admin

Ошибка при запуске `gatsby build` с Kentico Kontent

admin

В связанном списке, почему бы нам не дать имя каждому узлу?

admin

Динамически добавлять новые записи WTForms FieldList из пользовательского интерфейса

admin