Deep Learning/Computer Vision

keras를 이용한 image classification 구현하고 저장하기 (mobilenet, tflite)

jinmc 2023. 6. 20. 08:44
반응형

안녕하세요

 

저번 포스트에서 tflite model maker를 통해서 안드로이드에서 tflite model maker를 통해 모델을 만드는 법을 포스팅하였습니다.

2023.06.02 - [모바일/안드로이드] - 안드로이드에서 Image classification 모델 만들기

 

하지만 tflite_model_maker의 경우 여러가지 제약상황이 많았습니다.

그 중 하나는, tflite_model_maker로 나온 결과는 uint8의 데이터타입으로 나온다는 점입니다.

딱히 이를 고칠 수 있는 방법이 없는 것 같습니다.

 

하지만 keras api를 사용하면 훨씬 쉽고 커스텀하기 편하게 만들 수 있습니다.

 

예를들어 이 코드를 봅시다.

 

import numpy as np
from tensorflow.keras.applications.mobilenet import MobileNet, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image
 
model = MobileNet(weights='imagenet')
 
img_path = 'img2.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
 
preds = model.predict(x)
print('Predicted:', decode_predictions(preds, top=3)[0])

model.save('my_model') #  폴더 형태로 저장

model.save('my_model.h5') # h5 file 형태로 저장

 

그 이후 load_model 을 이용해서 모델을 로드할 수 있습니다.

load_model = keras.models.load_model('my_model.h5')

그 담에 model을 convert 하는 코드입니다.

 

import tensorflow as tf

saved_model_dir = "my_model"

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) # path to the SavedModel directory
tflite_model = converter.convert()

# Save the model.
with open('my_model.tflite', 'wb') as f:
  f.write(tflite_model)

또한, 다음을 통해 output의 datatype에 대해서 알아볼 수 있습니다.

 

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="your_model.tflite")
interpreter.allocate_tensors()

# Print input shape and type
inputs = interpreter.get_input_details()
print('{} input(s):'.format(len(inputs)))
for i in range(0, len(inputs)):
    print('{} {}'.format(inputs[i]['shape'], inputs[i]['dtype']))

# Print output shape and type
outputs = interpreter.get_output_details()
print('\n{} output(s):'.format(len(outputs)))
for i in range(0, len(outputs)):
    print('{} {}'.format(outputs[i]['shape'], outputs[i]['dtype']))

를 했을 때 다음과 같이 나옴을 볼 수 있습니다.

1 input(s):
[  1 224 224   3] <class 'numpy.float32'>

1 output(s):
[1 1000] <class 'numpy.float32'>

 

 

reference :

https://keras.io/api/applications/mobilenet/

https://www.tensorflow.org/guide/keras/save_and_serialize?hl=ko 

https://www.tensorflow.org/lite/models/convert/convert_models

https://firebase.google.com/docs/ml/android/use-custom-models#java

반응형