모바일

tflite를 사용했을 때 output tensor 모양 확인하기

jinmc 2025. 3. 4. 14:58
반응형

모바일에서 돌아갈수 있는 tflite모델을 돌리는데 있어서 output tensor의 갯수를 맞추는 것도 중요하다고 생각합니다.

 

import tensorflow as tf

# 모델 파일 로드
interpreter = tf.lite.Interpreter(model_path="efficientdet-lite0.tflite")
interpreter.allocate_tensors()

# 입력 및 출력 텐서 정보 가져오기
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input Details:", input_details)
print("Output Details:", output_details)

 

그럼 결과가 이런식으로 나오게 됩니다.

 

Input Details: [{'name': 'serving_default_inputs:0', 'index': 0, 
			  'shape': array([  1, 256, 256,   3], dtype=int32), 
              'shape_signature': array([  1, 256, 256,   3], dtype=int32), 
              'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 
              'quantization_parameters': {'scales': array([], dtype=float32), 
              'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 
              'sparsity_parameters': {}}]
Output Details: [{'name': 'StatefulPartitionedCall:0', 'index': 431, 
				  'shape': array([    1, 12276,     4], dtype=int32), 
                  'shape_signature': array([    1, 12276,     4], dtype=int32), 
                  'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 
                  'quantization_parameters': {'scales': array([], dtype=float32), 
                  'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 
                  'sparsity_parameters': {}}, {'name': 'StatefulPartitionedCall:1', 
                  'index': 429, 'shape': array([    1, 12276,    19], dtype=int32), 
                  'shape_signature': array([    1, 12276,    19], dtype=int32), 
                  'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 
                  'quantization_parameters': {'scales': array([], dtype=float32), 
                  'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 
                  'sparsity_parameters': {}}]

 

보면, output shape이 [1, 12276, 19] 이렇게 나오는데, 20이 class output length입니다.

반응형