모바일/안드로이드

안드로이드 앱 tflite 모델 로드하던 중 에러 해결(metadata)

jinmc 2022. 8. 23. 17:18
반응형

안드로이드에서 머신러닝을 몇 가지 해 보던 중, 

object detection 어플리케이션을 해 봤는데, 

audio classification 앱은 실행되지 않는 것을 발견했습니다.

오랜 디버깅 끝에, LogCat에서 다음 로그를 발견했습니다.

 

E/AndroidRuntime: FATAL EXCEPTION: main
    Process: com.example.mysoundclassification, PID: 31078
    java.lang.AssertionError: Error occurred when initializing AudioClassifier: Models are assumed to have the ModelMetadata and SubGraphMetadata.

알아보니, model을 asset 폴더에 넣고, 다음과 같이 부르는 것 뿐만 아니라 

var modelPath = "yamnet_metadata.tflite"

val classifier = AudioClassifier.createFromFile(this, modelPath)

model에 metadata를 추가해 줘야 한다고 합니다.

관련된 documentation은 여기에서 볼 수 있습니다.

 

TensorFlwo Lite 모델에 메타데이터 추가

 

여기서 주의할 점은, 이미지 분류, 객체 탐지, nlp, 오디오 분류 등 각각의 메타데이터가 다르다는 점입니다.

다음 링크를 참고하면 좋을 듯 합니다.

 

tensorflow lite metadata writer api

 

저는 오디오 분류를 했기 때문에 다음과 같이 코드를 작성하였습니다.

 

from tflite_support.metadata_writers import audio_classifier
from tflite_support.metadata_writers import metadata_info
from tflite_support.metadata_writers import writer_utils

AudioClassifierWriter = audio_classifier.MetadataWriter
_MODEL_PATH = "yamnet.tflite"
# Task Library expects label files that are in the same format as the one below.
_LABEL_FILE = "yamnet_labels.txt"
# Expected sampling rate of the input audio buffer.
_SAMPLE_RATE = 16000
# Expected number of channels of the input audio buffer. Note, Task library only
# support single channel so far.
_CHANNELS = 1
_SAVE_TO_PATH = "yamnet_metadata.tflite"

# Create the metadata writer.
writer = AudioClassifierWriter.create_for_inference(
    writer_utils.load_file(_MODEL_PATH), _SAMPLE_RATE, _CHANNELS, [_LABEL_FILE])

# Verify the metadata generated by metadata writer.
print(writer.get_metadata_json())

# Populate the metadata into the model.
writer_utils.save_file(writer.populate(), _SAVE_TO_PATH)

Python 을 통해서 metadata를 볼 수도 있습니다.

다음과 같은 스크립트를 통해서 metadata를 볼 수 있을 것 같습니다.

 

import tensorflow as tf
import numpy as np
import zipfile

# Download the model to yamnet-classification.tflite
interpreter = tf.lite.Interpreter('yamnet_metadata.tflite')

input_details = interpreter.get_input_details()
waveform_input_index = input_details[0]['index']
output_details = interpreter.get_output_details()
scores_output_index = output_details[0]['index']

# Input: 0.975 seconds of silence as mono 16 kHz waveform samples.
waveform = np.zeros(int(round(0.975 * 16000)), dtype=np.float32)
print(waveform.shape)  # Should print (15600,)

interpreter.resize_tensor_input(waveform_input_index, [waveform.size], strict=True)
interpreter.allocate_tensors()
interpreter.set_tensor(waveform_input_index, waveform)
interpreter.invoke()
scores = interpreter.get_tensor(scores_output_index)
print(scores.shape)  # Should print (1, 521)

top_class_index = scores.argmax()
labels_file = open("yamnet_labels.txt")
# labels_file = zipfile.ZipFile('yamnet-classification.tflite').open('yamnet_label_list.txt')
labels = [l.strip() for l in labels_file.readlines()]
print(len(labels))  # Should print 521
print(labels[top_class_index])  # Should print 'Silence'.

print()
print(labels)
반응형