본문으로 바로가기

[PyTorch] MNIST Introduction

category AI/PyTorch 2019. 10. 28. 11:56
336x280(권장), 300x250(권장), 250x250, 200x200 크기의 광고 코드만 넣을 수 있습니다.

 

 

 

 

 

MNIST : 손글씨 숫자 분류. 이미지 크기는 28*28임.

 

import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import random
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.get_device_name('cuda') else 'cpu'

random.seed(777)
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed(777)

# parameters 설정
epochs = 15
batch_size = 100

# torchvision.datasets 에는 수많은 데이터들이 존재.
# train = True (train dataset임) / transforms.ToTensor() : 이미지 형태를 tensor 형태로 바꿈.
mnist_train = dsets.MNIST(root='MNIST_data/', 
                          train=True, 
                          transform=transforms.ToTensor(), 
                          download=True)
mnist_test = dsets.MNIST(root='MNIST_data/', 
                         train=False, 
                         transform=transforms.ToTensor(), 
                         download=True)


# DataLoader
# drop_last = False (마지막에 데이터가 부족하거나하면 버리지 않음.)
data_loader = torch.utils.data.DataLoader(dataset=mnist_train, 
                                          batch_size=batch_size, 
                                          shuffle=True,
                                          drop_last=False)


# MNIST 이미지 형태를 (batch * 784) 형태로 바꿔서 사용함.
linear = torch.nn.Linear(28*28, 10, bias=True).to(device)


criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(linear.parameters(), lr = 0.1)


for epoch in range(epochs):
    avg_loss = 0
    total_batch = len(data_loader)
    
    for X, Y in data_loader:
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)
        
        optimizer.zero_grad()
        prediction = linear(X)
        loss = criterion(prediction, Y)
        loss.backward()
        optimizer.step()
        
        avg_loss += (loss / total_batch)
        
    print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.9f}'.format(avg_loss))
    
print('training finished')    

 

 

# gradient update를 하지않겠다는 의미.
with torch.no_grad():

    # test dataset의 데이터 형태를 (batch*784)로 바꿔줌.
    X_test = mnist_test.test_data.view(-1, 28*28).float().to(device)
    Y_test = mnist_test.test_labels.to(device)
    
    prediction = linear(X_test)
    
    # 각 배치별로 가장 높은 가능성의 숫자 클래스를 뽑아줌.
    predicted_classes = torch.argmax(prediction, 1)
    correct_count = (predicted_classes == Y_test)
    
    # 맞는 개수의 평균을 내면 정확도가 나옴.
    accuracy = correct_count.float().mean()
    print(accuracy.item())
    
    
    # 하나의 그림만 뽑아서 보기 위해 랜덤으로 인덱스 설정.
    r = random.randint(1, len(Y_test)-1)
    single_x = mnist_test.test_data[r:r+1].view(-1, 28*28).float().to(device)
    single_y = mnist_test.test_labels[r:r+1].to(device)
    
    prediction = linear(single_x)
    print(single_y.item())
    print(torch.argmax(prediction,1).item())

    plt.imshow(mnist_test.test_data[r:r+1].view(28,28), cmap='Greys', interpolation='nearest')
    plt.show()