5기(210102~)/B팀

ImageFolder [폐렴 분류해보기]

KAU 2021. 2. 18. 22:00

오랜만에 즐거운 machine learning 시간

 

데이터 셋이 필요하기 때문에 

Kaggle에서 X-ray 데이터 셋을 다운로드 받아옵니다.

www.kaggle.com/paultimothymooney/chest-xray-pneumonia

 

Chest X-Ray Images (Pneumonia)

5,863 images, 2 categories

www.kaggle.com


트레인 데이터셋에는 NORMAL과 PNEUMONIA
정상적인 X-ray의 모습
폐렴  데이터

우한 폐렴 데이터도 준비했습니다.


우리는 colab을 사용할것이기 때문에 데이터 셋을 구글 드라이브에 올려주도록 합시다.

 

 

코랩 환경에서 구글 드라이브 데이터를 가져 올 수 있다.
클릭 몇번으로 임포트 완료! 세상 편해졌네요?


실습 

import torchvision
from torchvision import transforms

from torch.utils.data import DataLoader

토치 비전 라이브러리와 데이터 로더 라이브러리를 임포트 해주도록 합시다.

from matplotlib.pyplot import imshow
%matplotlib inline

시각화 라이브러리인 matplot 라이브러리도 임포트해줍시다.

 

데이터 앞까지 명령어를 사용해서 도달해줍시다.

코랩에서도 기본적으로 리눅스 명령어를 사용하여 경로를 이동할 수 있습니다.

 

trans = transforms.Compose([
    transforms.Resize((64,128))
])

train_data = torchvision.datasets.ImageFolder(root='custom_data/origin_data', transform=trans)

64x128로 데이터셋을 변환시켜줍시다.

 

for num, value in enumerate(train_data):
    data, label = value
    print(num, data, label)
    
    if(label == 0):
        data.save('custom_data/train_data/gray/%d_%d.jpeg'%(num, label))
    else:
        data.save('custom_data/train_data/red/%d_%d.jpeg'%(num, label))

라벨을 붙여주는 코드라고 생각하면 된다.

 

 

 

데이터 랜덤으로 뽑아 봤을 때


CNN을 이용해서 학습시켜보도록 합시다

각종 라이브러리를 임포트 시켜주도록 합시다. 

import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
device = 'cuda' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(777)
if device =='cuda':
    torch.cuda.manual_seed_all(777)
data_loader = DataLoader(dataset = train_data, batch_size = 8, shuffle = True, num_workers=2)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3,6,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6,16,5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(16*13*29, 120),
            nn.ReLU(),
            nn.Linear(120,2)
        )
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.shape[0], -1)
        out = self.layer3(out)
        return out

CNN 구조 정의가 궁금하시다면 다음 게시글을 참고해주세요

metar.tistory.com/entry/CNNConvolutional-Neural-Network

 

CNN[Convolutional Neural Network]

convolution이란 무엇인가? •2D Convolution •주어진 filter로 입력 영상에 Convolution하여 출력 영상을 얻어내는 과정 •딥러닝에서는 해당 filter를 하나의 ‘가중치’로 보고 학습시키는 대상이 됨.

metar.tistory.com

optimizer = optim.Adam(net.parameters(), lr=0.00005)
loss_func = nn.CrossEntropyLoss().to(device)

라벨 맥이는중... 너무 데이터셋이 크다.


데이터셋을 400장으로 줄였습니다.
각각 200장에 라벨을 붙여봅시다.

 

드디어 학습 성공.. 
원래 데이터는 삭제해주셔야 합니다.

원래 데이터 삭제하고 라벨된 데이터만 남겨야합니다.

(처음부터 다른 폴더에 저장하면 되는데..)


torch.save(net.state_dict(), "/content/drive/MyDrive/Colab Notebooks/xray/model/model.pth")

new_net = CNN().to(device)
new_net.load_state_dict(torch.load('/content/drive/MyDrive/Colab Notebooks/xray/model/model.pth'))

모델 저장 및 불러오기 완료!

print(net.layer1[0])
print(new_net.layer1[0])

print(net.layer1[0].weight[0][0][0])
print(new_net.layer1[0].weight[0][0][0])

net.layer1[0].weight[0] == new_net.layer1[0].weight[0]


테스트 

trans=torchvision.transforms.Compose([
    transforms.Resize((64,128)),
    transforms.ToTensor()
])
test_data = torchvision.datasets.ImageFolder(root='/content/drive/MyDrive/Colab Notebooks/xray/test2', transform=trans)

test_set = DataLoader(dataset = test_data, batch_size = len(test_data))
with torch.no_grad():
    for num, data in enumerate(test_set):
        imgs, label = data
        imgs = imgs.to(device)
        label = label.to(device)
        
        prediction = net(imgs)
        
        correct_prediction = torch.argmax(prediction, 1) == label
        
        accuracy = correct_prediction.float().mean()
        print('Accuracy:', accuracy.item())

처참한 성능.. 왜..?

아마도 간단한 CNN으로 학습시켜서 분류하기 힘든것 아닌지.. 

아니면 데이터셋이 부족한것일 수도 있다. 

(원래 폐렴 데이터가 5000장인데 200장만 사용했습니다)