Ошибка с полосатым срезом в tensorflow lite

У меня проблема с tenorflow-lite. Я получаю такую ​​ошибку:

Тип INT32 (2) не поддерживается. Узлу STRIDED_SLICE (номер 2) не удалось вызвать со статусом 1

Что я сделал:

Я обучил модель с данными MNIST.

  model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(10)
])

Я преобразовал модель, используя квантование только целых чисел. Однако, когда я вызываю модель, она выдает эту ошибку.

Я просматривал striced_slice.cc и обнаружил следующее:

      switch (output->type) {
        case kTfLiteFloat32:
          reference_ops::StridedSlice(op_params,
                                      tflite::micro::GetTensorShape(input),
                                      tflite::micro::GetTensorData<float>(input),
                                      tflite::micro::GetTensorShape(output),
                                      tflite::micro::GetTensorData<float>(output));
          break;
        case kTfLiteUInt8:
          reference_ops::StridedSlice(
              op_params, tflite::micro::GetTensorShape(input),
              tflite::micro::GetTensorData<uint8_t>(input),
              tflite::micro::GetTensorShape(output),
              tflite::micro::GetTensorData<uint8_t>(output));
          break;
        case kTfLiteInt8:
          reference_ops::StridedSlice(op_params,
                                      tflite::micro::GetTensorShape(input),
                                      tflite::micro::GetTensorData<int8_t>(input),
                                      tflite::micro::GetTensorShape(output),
                                      tflite::micro::GetTensorData<int8_t>(output));
          break;
        default:
          TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                             TfLiteTypeGetName(input->type), input->type);

Так что поддержки int32 нет. Я не совсем уверен, как справиться с такой проблемой. Есть ли способ изменить поведение на этом узле? Должен ли я сделать что-то другое на этапе квантования?

Что я сделал:

def representative_data_gen():
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(100):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
# Ensure that if any ops can't be quantized, the converter throws an error
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

tflite_model = converter.convert()
open("model_int8.tflite", "wb").write(tflite_model)

PD: Я работаю с tenorflow-lite для использования в stm32.

Заранее спасибо.

Подскажите пожалуйста модель tflite?   —  person r142431    schedule 16.11.2020

wetransfer.com/downloads/ здесь модель tflite   —  person r142431    schedule 16.11.2020

Где произошла ошибка int32 и используете ли вы TFLite или TFLite для микроконтроллеров? Как @AlexK. упомянул, что полное целочисленное квантование использует int8, а не int32, поэтому предполагается, что вы не видите int32 в коде   —  person r142431    schedule 17.11.2020

Проверил вашу модель tflite, есть проблема. Форма ветки — ›полосатый ломтик -› упаковка кажется мне странной, и я не вижу соответствующей ее части в вашей модели keras. Похоже, вы не установили форму для изменения формы на постоянное значение. Возможно, попробуйте использовать изменение формы в качестве первого слоя, указав input_shape   —  person r142431    schedule 17.11.2020

См. также:  Переименование файла Git. История доступна в cmd-строке, но не в интерфейсе github?

Вы можете самостоятельно проверить модель tflite с помощью такого инструмента, как Netron. Для модели квантования full-int8 допустима только int8 input-output-ops   —  person r142431    schedule 17.11.2020

Привет @Tiezhen. Я использую tflite для микроконтроллеров. Я новичок в этой теме. Я только что нашел этот колаб, и я получил его. Итак, я проверяю модель в NetronApp и вижу, что есть несколько слоев с int32 в качестве входных данных. Итак, вы сказали, что все слои должны быть в int8, не так ли?   —  person r142431    schedule 17.11.2020

@ r142431 другой вариант, попробуйте явно указать пакет в Keras, скажем, с 1.   —  person r142431    schedule 19.11.2020

Понравилась статья? Поделиться с друзьями:
IT Шеф
Комментарии: 2
  1. r142431

    Когда вы выполняете полностью целочисленное квантование, ваши входы и выходы должны быть длиной 1 байт (в вашем случае int8). Укажите в качестве входных данных значения int8, и вы сможете вызвать свою модель.

    Извините, но я не понимаю. Я использую ввод длиной в 1 байт. Кажется, проблема в другом слое. person r142431; 16.11.2020

    Если модель полностью квантована, мы не найдем внутри другого типа, кроме int8, даже если используются 16-битные веса. Таким образом, единственная проблема, связанная с типом, может быть при загрузке объектов ввода и вывода. Пожалуйста, внимательно проверьте свои объекты ввода и вывода. person r142431; 16.11.2020

  2. r142431

    Я решил эту проблему, просто добавив поддержку INT32 в striced_slice.cc.

    case kTfLiteFloat32:
      reference_ops::StridedSlice(op_params,
                                  tflite::micro::GetTensorShape(input),
                                  tflite::micro::GetTensorData<float>(input),
                                  tflite::micro::GetTensorShape(output),
                                  tflite::micro::GetTensorData<float>(output));
      break;
    case kTfLiteUInt8:
      reference_ops::StridedSlice(
          op_params, tflite::micro::GetTensorShape(input),
          tflite::micro::GetTensorData<uint8_t>(input),
          tflite::micro::GetTensorShape(output),
          tflite::micro::GetTensorData<uint8_t>(output));
      break;
    case kTfLiteInt8:
      reference_ops::StridedSlice(op_params,
                                  tflite::micro::GetTensorShape(input),
                                  tflite::micro::GetTensorData<int8_t>(input),
                                  tflite::micro::GetTensorShape(output),
                                  tflite::micro::GetTensorData<int8_t>(output));
      break;
    case kTfLiteInt32:
      reference_ops::StridedSlice(op_params,
                                  tflite::micro::GetTensorShape(input),
                                  tflite::micro::GetTensorData<int8_t>(input),
                                  tflite::micro::GetTensorShape(output),
                                  tflite::micro::GetTensorData<int8_t>(output));
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                         TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
    

    Я скопировал кейс kTfLiteInt8 и создал кейс kTfLiteInt32.

    Идея в том, что, поскольку я знаю, что входные данные на самом деле относятся к типу int8, я просто приводил их к int8.

    Я тестировал его на микроконтроллере ESP32. Я не проводил полный набор тестов, но с несколькими образцами все заработало, как ожидалось.

    Это обходной путь. Настоящее исправление должно быть выполнено в конвертере, где он полностью квантует модель с помощью TFLITE_BUILTINS_INT8. Каким-то образом он квантует типы float32, но int32 остается в одном из слоев.

Добавить комментарий

;-) :| :x :twisted: :smile: :shock: :sad: :roll: :razz: :oops: :o :mrgreen: :lol: :idea: :grin: :evil: :cry: :cool: :arrow: :???: :?: :!: