Deep Learning/Computer Vision

augmentation을 통해서 data imbalance 맞춰주기

jinmc 2023. 11. 6. 16:58
반응형

학습할 때, 가장 중요한 것 중 하나가 data balance를 맞춰 주는 것일 겁니다. 만약 두 개의 클래스가 한 개의 클래스는 100장, 한 개의 클래스는 1000장이 있다고 하면,  두 개의 클래스의 밸런스를 맞춰주기 위해서는 1000장을 100장으로 맞춰주기보다는 100장을 1000장으로 맞춰주는게  나을 것입니다.

 

import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
import numpy as np
from PIL import Image, ImageEnhance
​
# Directory containing the images to augment
input_dir = 'class1'
​
# Directory to save the augmented images
output_dir = 'class1_augmented'
​
# Ensure the output directory exists
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
​
# Initialize the image data generator with brightness and color space transformations
datagen = ImageDataGenerator(
    brightness_range=[0.8, 1.2],  # Vary brightness between 80% and 120% of the original image
    channel_shift_range=20,       # Random channel shifts for color transformations
    horizontal_flip=True         # Randomly flip images horizontally
)
​
# Custom preprocessing function to adjust contrast
def random_contrast(image, intensity_range=(0.8, 1.2)):
    enhancer = ImageEnhance.Contrast(image)
    intensity = np.random.uniform(intensity_range[0], intensity_range[1])
    image_enhanced = enhancer.enhance(intensity)
    return image_enhanced
​
# The augmentation process
total_images = 1000  # Total images desired
generated_count = 0  # Counter for generated images
​
while generated_count < total_images:
    for filename in os.listdir(input_dir):
        if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
            img_path = os.path.join(input_dir, filename)
            img = Image.open(img_path)
            img = load_img(img_path)  # Load image
​
            # Apply custom contrast function
            img = random_contrast(img)

            # Save the augmented image
            augmented_image_path = os.path.join(output_dir, f'aug_{generated_count}.jpg')
            img.save(augmented_image_path)
            generated_count += 1
            if generated_count >= total_images:
                break
​
print(f"Image augmentation complete. Generated {generated_count} images.")

 

물론 augmentation 에는 여러가지가 있지만, 일단 horizontal flip과 채도 명도만 조절하는 정도의 augmentation만 하였습니다. keras library의 datagenerator를 사용하려고 해 보았지만, 한 장의 이미지만 계속 augment되는 현상이 생겨서 뺐습니다. 

반응형