모바일/안드로이드
안드로이드에서 소리 분류 custom data를 학습시키기 (tflite, yamnet)
jinmc
2022. 8. 30. 15:37
반응형
안드로이드에서 tflite를 이용해서 모델을 이용한 소리 분류 앱을 만드는 튜토리얼을 진행해 보았습니다.
대부분의 코드는 이 튜토리얼 참조 하였습니다.
튜토리얼에서는 새 데이터를 사용하여서 5종류의 새 울음소리를 학습하였습니다.
기본적인 앱은 앞선 튜토리얼과 상당히 비슷합니다. 오디오 분류를 위한 기본 앱 만들기
MainActivity 안에 모든 코드가 있으며, 다음과 같습니다.
class MainActivity : AppCompatActivity() {
var TAG = "MainActivity"
// TODO 2.1: defines the model to be used
var modelPath = "this_model.tflite"
// TODO 2.2: defining the minimum threshold
var probabilityThreshold: Float = 0.3f
lateinit var textView: TextView
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
val REQUEST_RECORD_AUDIO = 1337
requestPermissions(arrayOf(Manifest.permission.RECORD_AUDIO), REQUEST_RECORD_AUDIO)
textView = findViewById<TextView>(R.id.output)
val recorderSpecsTextView = findViewById<TextView>(R.id.textViewAudioRecorderSpecs)
// TODO 2.3: Loading the model from the assets folder
val classifier = AudioClassifier.createFromFile(this, modelPath)
// TODO 3.1: Creating an audio recorder
val tensor = classifier.createInputTensorAudio()
// TODO 3.2: showing the audio recorder specification
val format = classifier.requiredTensorAudioFormat
val recorderSpecs = "Number Of Channels: ${format.channels}\n" +
"Sample Rate: ${format.sampleRate}"
recorderSpecsTextView.text = recorderSpecs
// TODO 3.3: Creating
val record = classifier.createAudioRecord()
record.startRecording()
Timer().scheduleAtFixedRate(object: TimerTask() {
override fun run() {
// TODO 4.1: Classifing audio data
val numberOfSamples = tensor.load(record)
val output = classifier.classify(tensor)
var filteredModelOutput = output[0].categories.filter {
it.label.contains("Bird") && it.score > .3
}
if (filteredModelOutput.isNotEmpty()) {
Log.i("Yamnet", "Bird sound detected!")
filteredModelOutput = output[1].categories.filter {
it.score > probabilityThreshold
}
}
// TODO 4.3: Creating a multiline string with the filtered results
val outputStr =
filteredModelOutput.sortedBy { -it.score }
.joinToString(separator = "\n") { "${it.label} -> ${it.score}" }
// TODO 4.4: Updating the UI
if (outputStr.isNotEmpty())
runOnUiThread {
textView.text = outputStr
}
}
},1, 2000)
}
}
새 모델에는 2개의 출력이 있다고 하는데,
첫 번째로는 사용한 기본 모델의 일반적인 원본 출력 (YAMNET) 이고,
두 번째로는 훈련에 사용한 새에만 해당하는 보조 출력이라고 합니다 (어떤 새인지를 알려줌)
이렇게 함으로서, 다른 카테고리에 나오는 false positive를 거를 수 있습니다.
문제는, yamnet에 있는 category에 없는 custom data 를 학습시키고, inference할 때의 일입니다.
이런 문제에 대해서는 yamnet에 대해서 많이 알고 있어야 할 것으로 생각됩니다.
반응형