모바일/안드로이드

안드로이드에서 소리 분류 custom data를 학습시키기 (tflite, yamnet)

jinmc 2022. 8. 30. 15:37
반응형

안드로이드에서 tflite를 이용해서 모델을 이용한 소리 분류 앱을 만드는 튜토리얼을 진행해 보았습니다.

대부분의 코드는 이 튜토리얼 참조 하였습니다.

 

참조:  선행학습된 커스텀 오디오 분류 모델 빌드

 

튜토리얼에서는 새 데이터를 사용하여서 5종류의 새 울음소리를 학습하였습니다.

 

기본적인 앱은 앞선 튜토리얼과 상당히 비슷합니다. 오디오 분류를 위한 기본 앱 만들기

MainActivity 안에 모든 코드가 있으며, 다음과 같습니다.

 

class MainActivity : AppCompatActivity() {
    var TAG = "MainActivity"

    // TODO 2.1: defines the model to be used
     var modelPath = "this_model.tflite"

    // TODO 2.2: defining the minimum threshold
     var probabilityThreshold: Float = 0.3f

    lateinit var textView: TextView

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        val REQUEST_RECORD_AUDIO = 1337
        requestPermissions(arrayOf(Manifest.permission.RECORD_AUDIO), REQUEST_RECORD_AUDIO)

        textView = findViewById<TextView>(R.id.output)
        val recorderSpecsTextView = findViewById<TextView>(R.id.textViewAudioRecorderSpecs)

        // TODO 2.3: Loading the model from the assets folder
         val classifier = AudioClassifier.createFromFile(this, modelPath)

        // TODO 3.1: Creating an audio recorder
         val tensor = classifier.createInputTensorAudio()

        // TODO 3.2: showing the audio recorder specification
         val format = classifier.requiredTensorAudioFormat
         val recorderSpecs = "Number Of Channels: ${format.channels}\n" +
                "Sample Rate: ${format.sampleRate}"
         recorderSpecsTextView.text = recorderSpecs

        // TODO 3.3: Creating
         val record = classifier.createAudioRecord()
         record.startRecording()


        Timer().scheduleAtFixedRate(object: TimerTask() {
            override fun run() {
                // TODO 4.1: Classifing audio data
                val numberOfSamples = tensor.load(record)
                val output = classifier.classify(tensor)

                var filteredModelOutput = output[0].categories.filter {
                  it.label.contains("Bird") && it.score > .3
                }


                if (filteredModelOutput.isNotEmpty()) {
                    Log.i("Yamnet", "Bird sound detected!")
                        filteredModelOutput = output[1].categories.filter {
                            it.score > probabilityThreshold
                        }
                    }
                // TODO 4.3: Creating a multiline string with the filtered results
                val outputStr =
                    filteredModelOutput.sortedBy { -it.score }
                        .joinToString(separator = "\n") { "${it.label} -> ${it.score}" }

                // TODO 4.4: Updating the UI
                if (outputStr.isNotEmpty())
                    runOnUiThread {
                        textView.text = outputStr
                    }
            }
        },1, 2000)

    }
}

 

새 모델에는 2개의 출력이 있다고 하는데, 

첫 번째로는 사용한 기본 모델의 일반적인 원본 출력 (YAMNET) 이고,

두 번째로는 훈련에 사용한 새에만 해당하는 보조 출력이라고 합니다 (어떤 새인지를 알려줌)

 

이렇게 함으로서, 다른 카테고리에 나오는 false positive를 거를 수 있습니다.

문제는, yamnet에 있는 category에 없는 custom data 를 학습시키고, inference할 때의 일입니다.

이런 문제에 대해서는 yamnet에 대해서 많이 알고 있어야 할 것으로 생각됩니다.

반응형