3.5기(200104~)/3팀

3.5기 3팀 파이토치 MNIST (CNN)

KAU-Deeperent 2020. 2. 21. 20:26

2020/02/21

3.5기 3팀 최웅준,송근영,김정민

장소: 능곡역 지노스 까페

합성곱을 이용한 신경망을 구성하여 Mnist를 학습하였다.

28 x 28 사이즈의 이미지셋으로 총 60000장을 라이브러리 'torchvision'에서제공해준다.

모델구조 

구글에 있는 MNIST 모델을 참조하였습니다.

 

 

import torch

import torch.nn as nn

import torchvision.datasets as dsets

import torchvision.transforms as transforms

import matplotlib.pyplot as plt

import random

 

위와 같은 라이브러리를 import 하였습니다.

 

class MnistClassifier(nn.Module) :

  def __init__(self):

    super(MnistClassifier,self).__init__()

    self.conv1 = nn.Conv2d(1,6,3)

    self.pool1 = nn.MaxPool2d(2)

 

    self.conv2 = nn.Conv2d(6,16,3)

    self.pool2 = nn.MaxPool2d(2)

 

    self.Fc1 = nn.Linear(5*5*16,128)

    self.Fc2 =  nn.Linear(128,10)

 

    self.conv_model = nn.Sequential(

        self.conv1,

        nn.ReLU(),

        self.pool1,

        self.conv2,

        nn.ReLU(),

        self.pool2,

 

    )

    self.fc_model = nn.Sequential(

        self.Fc1,

        nn.ReLU(),

        self.Fc2

    )

    #Fc3

 

  def forward(self,x) :

    mnist_model = self.conv_model(x)

    dim = 1

    for d in mnist_model.size()[1:] :

      dim = dim *d

    mnist_model = mnist_model.view(-1,dim)

    mnist_model = self.fc_model(mnist_model)

    return torch.nn.functional.softmax(mnist_model,dim=1)

 

Mnist CNN network 클래스를 정의한 것입니다.

nn.Module을 상속받으면서 네트워크를 만들길래 그방법을 차용했습니다.

nn.Sequential 이라는 함수를 사용하면 자동적으로 layer를 이어서 모델로 만들어줍니다. 

cnn output이 차원이 3이므로 FC 에 input으로 주기 위해서 차원을 1로 변경해줍니다.

1~10의 숫자를 분류하는것이 이 네트워크의 목적이기 때문에 softmax를 사용해줍니다.

 

training_epoch = 10

batch_size = 100

 

mnist_train = dsets.MNIST(root='MNIST_data/',

                          train=True,

                          transform=transforms.ToTensor(),

                          download=True)

 

mnist_test = dsets.MNIST(root='MNIST_data/',

                         train=False,

                         transform=transforms.ToTensor(),

                         download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist_train,

                                          batch_size=batch_size,

                                          shuffle=True,

                                          drop_last=True)

 

DataLoader라는 클래스를 사용했는데  이는 모델 제작자가 data를 배치만큼 쪼개서 input으로 주고 shuffle해주는 과정을 DataLoader라는 클래스가 대신해주는 것입니다. 안써도 상관 없지만 , 편의를 위해서 사용해줍니다.

 

criterion = nn.CrossEntropyLoss()

# backpropagation method

learning_rate = 1e-3

mnist_train_model = MnistClassifier()

optimizer = torch.optim.Adam(mnist_train_model.parameters(), lr=learning_rate)

trn_loss_list=[]

for epoch in range(1,training_epoch) :

  trn_loss=0.0

  for i, x in enumerate(data_loader) :

    data ,label = x

    model_output = mnist_train_model(data)

    



    #gradient 초기화

    optimizer.zero_grad()

    #loss 함수

    loss = criterion(model_output,label)

    #back_propagation

    loss.backward()

    #weight_update

    optimizer.step()

    trn_loss += loss.item()

    # del (memory issue)

    del loss

    del model_output

 

    print("epoch: {}/{} | step: {}/{} | trn loss: {:.4f} ".format(

                epoch+1, training_epoch, i+1, batch_size, trn_loss / 100

            ))            

            

    trn_loss_list.append(trn_loss/100)

    trn_loss = 0.0

 

loss 함수로는 cross_entropy , optimizer 는 Adam을 사용했습니다.

weight update 는 mini gradient descent 방법으로 update했습니다.

 

test_loader = torch.utils.data.DataLoader(dataset=mnist_test,

                                          batch_size=batch_size,

                                          shuffle=True,

                                          drop_last=True)

 

correct = 0.0

with torch.no_grad():

  for i, x in enumerate(test_loader) :

    data, label  = x

    test_output = mnist_train_model(data)

    prediction = test_output.data.max(1)[1]

    correct += prediction.eq(label.data).sum()

print('Test set: Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))

 

Test 정확도는 98.27% 나왔습니다.

'3.5기(200104~) > 3팀' 카테고리의 다른 글

3.5기 3팀 파이토치를 이용하여 와인 분류하기  (0) 2020.02.21
3.5기 3팀 Selective search  (0) 2020.02.15
3.5기 3팀 Sliding Window  (0) 2020.02.07
3.5기 3팀 Inception(GoogLeNet)  (0) 2020.01.31
3.5기 3팀 스터디  (0) 2020.01.23