Использование нескольких функций потерь в pytorch

Я работал над задачей восстановления изображения и рассматривал несколько функций потерь. В моих планах было рассмотреть 3 маршрута:

1: использовать несколько потерь для мониторинга, но использовать только несколько для самого обучения 2: из тех функций потерь, которые используются для тренировки, мне нужно было присвоить каждой вес — в настоящее время я указываю вес. Я бы хотел сделать этот параметр адаптивным. 3: Если между тренировками — если я наблюдаю насыщение, я бы хотел изменить функцию потерь. или его компоненты. В настоящее время я рассматривал возможность повторного обучения сети (если при первом обучении модель была насыщена) так, чтобы она тренировалась с определенной функцией потерь для первых, скажем, M эпох, после которых я изменяю потери.

  1. За исключением последнего случая, я разработал код, который вычисляет эти потери, но я не уверен, будет ли он работать. — т.е. будет ли это обратное распространение? (код приведен ниже)

  2. возможно ли присвоить веса адаптивно при использовании комбинации функций потерь — то есть можем ли мы обучить сеть так, чтобы эти веса также были изучены?

  3. может ли эта реализация использоваться для вышеупомянутого случая 3 изменения функций потерь

Извините, если что-то здесь не ясно или неверно. Пожалуйста, дайте мне знать, если мне нужно улучшить вопрос. (Я новичок в PyTorch)

criterion = _criterion
#--training
prediction = model(input)
loss = criterion(prediction, target)
loss.backward()



class _criterion(nn.Module):

    def __init__(self, model_type="CNN"):

        super(_criterion).__init__()    

        self.model_type = model_type

        

    def forward(self, pred, ref):

        loss_1 = lambda x,y : nn.MSELoss(size_average=False)(x,y)       

        loss_2 = lambda x,y : nn.L1Loss(size_average=False)(x,y)        

        loss_3 = lambda x,y : nn.SmoothL1Loss(size_average=False)(x,y)  

        loss_4 = lambda x,y : L1_Charbonnier_loss_()(x,y)     #user-defined         


        if opt.loss_function_order == 1:

            loss_function_1 = get_loss_function(opt.loss_function_1)

            loss = lambda x,y: 1*loss_function_1(x,y)  

        
        elif opt.loss_function_order == 2:

            loss_function_1 = get_loss_function(opt.loss_function_1)

            loss_function_2 = get_loss_function(opt.loss_function_2)

            weight_1 = opt.loss_function_1_weight

            weight_2 = opt.loss_function_2_weight

            loss = lambda x,y: weight_1*loss_function_1(x,y) + weight_2*loss_function_2(x,y)

        elif opt.loss_function_order == 3:

            loss_function_1 = get_loss_function(opt.loss_function_1)

            loss_function_2 = get_loss_function(opt.loss_function_2)

            loss_function_3 = get_loss_function(opt.loss_function_3)

        

            weight_1 = opt.loss_function_1_weight

            weight_2 = opt.loss_function_2_weight

            weight_3 = opt.loss_function_3_weight

        

            loss = lambda x,y: weight_1*loss_function_1(x,y) + weight_2*loss_function_2(x,y) +  weight_3*loss_function_3(x,y)    

        elif opt.loss_function_order == 4:

            loss_function_1 = get_loss_function(opt.loss_function_1)

            loss_function_2 = get_loss_function(opt.loss_function_2)

            loss_function_3 = get_loss_function(opt.loss_function_3)

            loss_function_4 = get_loss_function(opt.loss_function_4)

                

            weight_1 = opt.loss_function_1_weight

            weight_2 = opt.loss_function_2_weight

            weight_3 = opt.loss_function_3_weight

            weight_4 = opt.loss_function_4_weight     

           

            loss = lambda x,y: weight_1*loss_function_1(x,y) + weight_2*loss_function_2(x,y) +  weight_3*loss_function_3(x,y)  +  weight_4*loss_function_4(x,y)       

        else:

            raise Exception("_criterion : unable to interpret loss_function_order")

        return loss(ref,pred), loss_1(ref,pred), loss_2(ref,pred), loss_3(ref,pred), loss_4(ref,pred)



def get_loss_function(loss):    

    if loss == "MSE":

        criterion = nn.MSELoss(size_average=False)

    elif loss == "MAE":

        criterion = nn.L1Loss(size_average=False) 

    elif loss == "Smooth-L1":

        criterion = nn.SmoothL1Loss(size_average=False) 

    elif loss == "Charbonnier":

        criterion = L1_Charbonnier_loss_()
    else:

        raise Exception("not implemented")
    return criterion


class L1_Charbonnier_loss_(nn.Module):

    def __init__(self):

        super(L1_Charbonnier_loss_, self).__init__()

        self.eps = 1e-6 

    def forward(self, X, Y):

        diff = torch.add(X, -Y) 

        error = self.eps*((torch.sqrt(1+((diff * diff)/self.eps)))-1)

        loss = torch.sum(error) 

        return loss

См. также:  Ошибка Selenium (не связана с DevTools)
Понравилась статья? Поделиться с друзьями:
IT Шеф
Комментарии: 1
  1. AMA2403

    Тогда как я понимаю ваш вопрос. Ваша функция вычисления ошибок будет выполнять обратное распространение, но вы должны быть осторожны при использовании функций ошибок, поскольку они работают по-разному для каждой ситуации.
    Что касается весов, вам нужно сохранить веса этой сети, а затем снова загрузить их в другую, используя обучение передачи pytorch, чтобы вы могли использовать веса других исполнений.
    Вот ссылка на то, как использовать обучающую передачу от pythorch.

Добавить комментарий

;-) :| :x :twisted: :smile: :shock: :sad: :roll: :razz: :oops: :o :mrgreen: :lol: :idea: :grin: :evil: :cry: :cool: :arrow: :???: :?: :!: