본문 바로가기

ML

[ML] Fashion MNIST 데이터셋 분류모델(CNN) 생성 - Jungyu Ko

문제 정의

Fashion MNIST 데이터셋을 이용하여 분류 모델을 만들어보도록 하겠습니다.

해당 글에서는 단순한 CNN (Convolutional Neural Network)을 사용하여 분류 모델을 만들어 볼 예정입니다.

 

실습 github 라이브러리를 공유합니다.

https://github.com/jungyuko/FashionMNIST


먼저 사용할 라이브러리를 불러옵니다.

import logging		# 코드의 출력 결과를 기록하기 위한 라이브러리
import argparse		# hyper parameter를 조절하기 위한 라이브러리
import random		

import torch					# 신경망 학습을 위한 라이브러리
import torch.nn as nn				# 신경망 학습을 위한 라이브러리						
import torch.optim as optim			# Optimizer를 사용하기 위한 라이브러리
from torchvision import datasets, transforms	# 데이터셋을 다운로드 및 변환하기 위한 라이브러리
import numpy as np				# 리스트 조작을 위한 라이브러리

from matplotlib import pyplot as plt	# 시각화를 위한 라이브러리

# utils 함수를 생성하여 함수 초기화, logging 등을 도와주는 파일입니다.
from utils import AverageMeter, config_logging

추가로 hyper parameter와 GPU 동작 여부를 결정하는 코드를 추가합니다.

# argument 세팅
parser = argparse.ArgumentParser(description='Fashion MNIST Classfication')

parser.add_argument('--batch_size', default=64, type=int,
                    help='Dataset batch_size')
parser.add_argument('--num_epochs', default=50, type=int,
                    help='num epochs')
parser.add_argument('--learning_rate', default=1e-3, type=float,
                    help='learning rate')

parser.add_argument('--comment', type=str, default='')

# args = parser.parse_args('')	# jupyter notebook 환경에서 () 안에 ''를 추가해야 작동합니다.
args = parser.parse_args()	# 단순 python 파일로 실행할 때는 () 안에 ''를 넣지 않습니다.
comment = args.comment

config_logging(comment)
logging.info('args: {}'.format(args))


### GPU 설정
device = torch.device("cuda:0")
logging.info("device: {}".format(device))

hyper parameter 등은 위와 같이 설정을 할 경우, 

학습을 진행하고 세부 파라미터를 조절하여 재 학습을 시킬 때 굉장히 편리합니다.

 

데이터셋

Fashion MNIST 데이터셋

Fashion MNIST 데이터셋은 티셔츠, 샌들, 가방과 같은 10가지의 카테고리의 이미지들의 모음입니다.

기존 손글씨 데이터셋인 MNIST 데이터셋보다 이미지의 구성이 복잡하기 때문에 손글씨 MNIST 데이터셋보다 모델의 성능을 판단하기 쉬울 것입니다.

 

Fashion MNIST 데이터셋의 자료구조

Fashion MNIST 데이터셋은 [T-Shirts, Trouser, Pullover, Dress, Coat, Sandal, Shirt, Sneaker, Bag, Ankle boot]와 같이 10가지의 카테고리로 이루어져 있습니다.

한 장의 이미지는 28x28 픽셀의 이미지로 이루어져 있으며, train 데이터로 60,000장, test 데이터로 10,000장으로 나뉘어 있습니다.

 

데이터셋을 불러옵니다.

### transform 설정
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
                                
root = './MNIST_Fashion'	# 데어터셋을 저장하는 경로설정입니다.

# 훈련 데이터셋을 다운로드합니다.
trainset = datasets.FashionMNIST(root=root, 	
                                 download=True, 
                                 train=True, 	
                                 transform=transform)
# 테스트 데이터셋을 다운로드합니다.
testset = datasets.FashionMNIST(root=root,
                                 download=True,  
                                 train=False, 
                                 transform=transform)

# 모델의 입력으로 넣어줄 train dataloader를 정의합니다.
train_loader = torch.utils.data.DataLoader(trainset, 
                                           batch_size=args.batch_size,
                                           shuffle=True)
# 모델의 입력으로 넣어줄 test dataloader를 정의합니다.
test_loader = torch.utils.data.DataLoader(testset, 
                                          batch_size=args.batch_size, 
                                          shuffle=False)

# train data와 test data의 갯수를 확인합니다.
logging.info("train data length: {}, test data length: {}".format(len(trainset), len(testset)))

torchvision에 Fashion MNIST 데이터셋이 저장되어 있기 때문에 손쉽게 데이터셋을 불러오고 다운로드할 수 있습니다.

모델

해당 글에서는 단순한 DNN 모델을 설계합니다.

[28x28] → Flatten → [784] → ReLU → [512] → ReLU → [256] → ReLU → [128] → ReLU → [64] → ReLU → [10]의 DNN을 설계합니다. 마지막 layer가 10인 이유는 Fashion MNIST 데이터셋이 10개의 카테고리를 가지고 있기 때문입니다.

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
                                                            # batchx1x28x28
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,       # batchx16x28x28
                      kernel_size=3, stride=1, padding=1),   
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.MaxPool2d(2,2),                             
            
            nn.Conv2d(in_channels=16, out_channels=32,      # batchx32x14x14 
                      kernel_size=3, stride=1, padding=1),   
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2,2),

            nn.Conv2d(in_channels=32, out_channels=64,      # batchx64x7x7    
                      kernel_size=3, stride=1, padding=1),      
            nn.ReLU(),

            nn.Flatten(),
            nn.Linear(64*7*7, 784),
            nn.ReLU(),
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 10)
        )
    
    def forward(self, x):
        x = self.layer(x)
        
        return x

손실 함수

모델을 학습하기 위한 손실 함수로는 Cross Entropy Loss를 사용합니다.

분류 문제에서는 Cross Entropy Loss가 주로 사용됩니다. 상황에 따라 다른 손실 함수가 사용될 때도 있지만 해당 문제에서는 Cross Entropy Loss를 사용합니다. Cross Entropy와 관련해서는 따로 설명하진 않겠습니다. 

학습

model = CNN()
epochs = args.num_epochs

total_loss = []
for epoch in range(1, epochs+1):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    logging.info('Train Phase, Epochs: {}'.format(epoch))
    model.train()
    train_loss = AverageMeter()

    for batch_num, data in enumerate(train_loader):
        images, labels = data
        
	# DNN과는 달리 CNN에서는 이미지 원본이 모델의 input으로 들어갑니다.
        images = images.view(-1,1,28,28).to(device)	
        labels = labels.to(device)
     
        output = model(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), images.shape[0])
        
        if batch_num % 100 == 0:
            logging.info(
                "[{}/{}] # {}/{} loss: {:.4f}".format(epoch, epochs, 
                                                     batch_num, len(train_loader), train_loss.val)
            )
    total_loss.append(train_loss.avg)

 

학습 결과

50 epoch의 loss 결과

loss가 안정적으로 떨어지는 것을 확인할 수 있습니다.

loss가 안정적으로 떨어지고 해당 모델을 통해 test 데이터에서도 작동이 제대로 되는지 확인해야 합니다.

 

검증

수능을 잘 봐야지 좋은 결과를 얻지, 모의고사를 잘 본다고 좋은 대학을 가는 것이 아니듯이, 테스트 데이터를 통해 일반화가 제대로 이뤄질 수 있는지 살펴봐야 합니다.

def output_label(label):
    output_mapping = {
                 0: "T-shirt/Top",
                 1: "Trouser",
                 2: "Pullover",
                 3: "Dress",
                 4: "Coat", 
                 5: "Sandal", 
                 6: "Shirt",
                 7: "Sneaker",
                 8: "Bag",
                 9: "Ankle Boot"
                 }
    input = (label.item() if type(label) == torch.Tensor else label)
    return output_mapping[input]

count = 0
ans = 0

class_correct = [0. for _ in range(10)]
total_correct = [0. for _ in range(10)]

logging.info('Test Phase...')
with torch.no_grad():
    model.eval()

    for batch_num, data in enumerate(test_loader):
        images, labels = data
        
	# DNN과는 달리 CNN에서는 이미지 원본이 모델의 input으로 들어갑니다.
        images = images.view(-1,1,28,28).to(device)
        labels = labels.to(device)

        output = model(images)
        predict = torch.max(output, 1)[1].to(device)
        is_correct = (predict == labels).squeeze()

        for i in range(len(is_correct)):
            label = labels[i]
            
            ans += is_correct[i].item()   
            count += 1
            
            class_correct[label] += is_correct[i].item()
            total_correct[label] += 1

logging.info('Total Accuracy: {:.4f}%'.format((ans/count)*100))

for i in range(10):
    logging.info("Accuracy of class {}: {:.4f}%".format(output_label(i), class_correct[i] * 100 / total_correct[i]))
2022-03-05 17:56:30,704 [INFO ] Total Accuracy: 90.6900%
2022-03-05 17:56:30,705 [INFO ] Accuracy of class T-shirt/Top: 88.5000%
2022-03-05 17:56:30,705 [INFO ] Accuracy of class Trouser: 98.3000%
2022-03-05 17:56:30,706 [INFO ] Accuracy of class Pullover: 87.1000%
2022-03-05 17:56:30,706 [INFO ] Accuracy of class Dress: 90.8000%
2022-03-05 17:56:30,706 [INFO ] Accuracy of class Coat: 91.0000%
2022-03-05 17:56:30,707 [INFO ] Accuracy of class Sandal: 98.0000%
2022-03-05 17:56:30,707 [INFO ] Accuracy of class Shirt: 62.7000%
2022-03-05 17:56:30,707 [INFO ] Accuracy of class Sneaker: 97.5000%
2022-03-05 17:56:30,708 [INFO ] Accuracy of class Bag: 98.4000%
2022-03-05 17:56:30,708 [INFO ] Accuracy of class Ankle Boot: 94.6000%

DNN에 비해 전체적으로 약 2%정도의 성능이 오른 것을 확인할 수 있습니다.

DNN에서는 T-shirt와 shirt의 분류 정확도가 다른 카테고리에 비해 낮았던 것을 확인할 수 있었는데 반해,

CNN에서는 T-shirt는 DNN에 비해 정확하게 분류하는 것을 확인할 수 있으며, Shirt의 분류 정확도가 다른 카테고리에 비해 낮은 것을 확인할 수 있습니다. 

실제 예측 결과와 정답 데이터와의 시각적인 비교를 통해 원인을 분석해봅니다.

 

테스트 데이터셋 시각화

label_tags = {
    0: 'T-Shirt', 
    1: 'Trouser', 
    2: 'Pullover', 
    3: 'Dress', 
    4: 'Coat', 
    5: 'Sandal', 
    6: 'Shirt',
    7: 'Sneaker', 
    8: 'Bag', 
    9: 'Ankle Boot'
}
columns = 6
rows = 6
fig = plt.figure(figsize=(10,10))
 
model.eval()
for i in range(1, columns*rows+1):
    data_idx = np.random.randint(len(testset))
    input_img = testset[data_idx][0].unsqueeze(dim=0).to(device) 
 
    output = model(input_img)
    _, argmax = torch.max(output, 1)
    pred = label_tags[argmax.item()]
    label = label_tags[testset[data_idx][1]]
    
    fig.add_subplot(rows, columns, i)
    if pred == label:
        plt.title(pred)
        cmap = 'Blues'
    else:
        plt.title('Not ' + pred + ' but ' +  label)
        cmap = 'Reds'
    plot_img = testset[data_idx][0][0,:,:]
    plt.imshow(plot_img, cmap=cmap)
    plt.axis('off')
    
plt.show()

모델의 테스트 데이터셋의 시각화 결과

Shirt의 경우, Pullover, Coat와 같은 카테고리에서 혼동하는 것을 확인할 수 있습니다.

사실 Shirt의 경우 Pullover, Coat와 유사한 형태를 띠고 있으며 간단하게 구현된 현재의 모델은 Shirt, Pullover, Coat에서의 특징을 다른 카테고리들에 비해 학습이 제대로 되지 않았다고도 볼 수 있습니다.

실험 결과로 예측해보면, Coat는 Coat로 올바르게 예측하지만 Shirt를 Coat, Pullover로 분류하는 경우들이 많아서 Shirt가 다른 결과에 비해 낮은 결과를 얻었다고 생각할 수 있습니다.


마무리

머신러닝을 위한 단계는 크게 3가지로 나눌 수 있습니다.

  • 데이터셋
  • 모델
  • 손실 함수

위의 3가지가 올바르게 설정되었다면, 모델은 올바르게 학습을 할 수 있으며, test 데이터셋을 통해 모델의 성능을 검증할 수 있습니다. 본문에서는 간단한 CNN을 설계하여 구현했음에도 90.6%라는 좋은 성능을 보이는 것을 확인할 수 있습니다. Fashion MNIST 데이터셋은 머신러닝의 입문자들이 실습해보기 좋은 데이터셋이라고 생각합니다.

 

긴 글 읽어주셔서 감사드리며, 머신러닝 학습에 도움이 되었으면 좋겠습니다.

감사합니다.