Deep Learning/Computer Vision
keras를 이용한 image classification 구현하고 저장하기 (mobilenet, tflite)
jinmc
2023. 6. 20. 08:44
반응형
안녕하세요
저번 포스트에서 tflite model maker를 통해서 안드로이드에서 tflite model maker를 통해 모델을 만드는 법을 포스팅하였습니다.
2023.06.02 - [모바일/안드로이드] - 안드로이드에서 Image classification 모델 만들기
하지만 tflite_model_maker의 경우 여러가지 제약상황이 많았습니다.
그 중 하나는, tflite_model_maker로 나온 결과는 uint8의 데이터타입으로 나온다는 점입니다.
딱히 이를 고칠 수 있는 방법이 없는 것 같습니다.
하지만 keras api를 사용하면 훨씬 쉽고 커스텀하기 편하게 만들 수 있습니다.
예를들어 이 코드를 봅시다.
import numpy as np
from tensorflow.keras.applications.mobilenet import MobileNet, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
model = MobileNet(weights='imagenet')
img_path = 'img2.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
preds = model.predict(x)
print('Predicted:', decode_predictions(preds, top=3)[0])
model.save('my_model') # 폴더 형태로 저장
model.save('my_model.h5') # h5 file 형태로 저장
그 이후 load_model 을 이용해서 모델을 로드할 수 있습니다.
load_model = keras.models.load_model('my_model.h5')
그 담에 model을 convert 하는 코드입니다.
import tensorflow as tf
saved_model_dir = "my_model"
# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()
# Save the model.
with open('my_model.tflite', 'wb') as f:
f.write(tflite_model)
또한, 다음을 통해 output의 datatype에 대해서 알아볼 수 있습니다.
import tensorflow as tf
interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()
# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))
# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))
를 했을 때 다음과 같이 나옴을 볼 수 있습니다.
1 input(s):
[ 1 224 224 3] <class 'numpy.float32'>
1 output(s):
[1 1000] <class 'numpy.float32'>
reference :
https://keras.io/api/applications/mobilenet/
https://www.tensorflow.org/guide/keras/save_and_serialize?hl=ko
https://www.tensorflow.org/lite/models/convert/convert_models
https://firebase.google.com/docs/ml/android/use-custom-models#java
반응형