Batch Normalization
As with dropout, you must switch between model.train() and model.eval(): batch normalization normalizes with per-batch statistics during training, but with the accumulated running mean/variance during evaluation.
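A minimal sketch of that mode difference (a standalone toy example, not part of the experiment below):

import torch

bn = torch.nn.BatchNorm1d(4)
x = torch.randn(8, 4)

bn.train()
out_train = bn(x)  # normalized with this batch's mean/var; running stats are updated

bn.eval()
out_eval = bn(x)   # normalized with the stored running mean/var instead

print(torch.allclose(out_train, out_eval))  # False: the two modes disagree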
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mnist_train = dsets.MNIST(root='MNIST_data/', train=True, download=True,
                          transform=transforms.ToTensor())
mnist_test = dsets.MNIST(root='MNIST_data/', train=False, download=True,
                         transform=transforms.ToTensor())
epochs = 15
batch_size = 100
learning_rate = 0.01
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                          shuffle=True, drop_last=True)
bn_linear1 = torch.nn.Linear(28*28, 32, bias=True)
bn_linear2 = torch.nn.Linear(32, 32, bias=True)
bn_linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)
nn_linear1 = torch.nn.Linear(28*28, 32, bias=True)
nn_linear2 = torch.nn.Linear(32, 32, bias=True)
nn_linear3 = torch.nn.Linear(32, 10, bias=True)
# model with batch normalization between each linear layer and its activation
bn_model = torch.nn.Sequential(bn_linear1, bn1, relu,
                               bn_linear2, bn2, relu,
                               bn_linear3).to(device)
# baseline model of the same shape, without batch normalization
nn_model = torch.nn.Sequential(nn_linear1, relu,
                               nn_linear2, relu,
                               nn_linear3).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
bn_optimizer = torch.optim.Adam(bn_model.parameters(), lr=learning_rate)
nn_optimizer = torch.optim.Adam(nn_model.parameters(), lr=learning_rate)
batch_num = len(data_loader)
for epoch in range(epochs):
    bn_model.train()  # BN layers must be in training mode to use batch statistics
    bn_avg_loss = 0
    nn_avg_loss = 0
    for X, Y in data_loader:
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)
        bn_prediction = bn_model(X)
        bn_loss = criterion(bn_prediction, Y)
        bn_optimizer.zero_grad()
        bn_loss.backward()
        bn_optimizer.step()
        nn_prediction = nn_model(X)
        nn_loss = criterion(nn_prediction, Y)
        nn_optimizer.zero_grad()
        nn_loss.backward()
        nn_optimizer.step()
        # accumulate plain floats, not graph-attached tensors, to avoid keeping
        # every batch's computation graph alive
        bn_avg_loss += bn_loss.item() / batch_num
        nn_avg_loss += nn_loss.item() / batch_num
    print(epoch+1, bn_avg_loss, nn_avg_loss)
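The code above loads mnist_test but never uses it. A minimal sketch of the evaluation step, assuming a recent torchvision where the raw tensors are exposed as .data and .targets; eval() switches the BN layers over to their running statistics:

with torch.no_grad():
    bn_model.eval()
    nn_model.eval()
    # scale uint8 pixels to [0, 1] to match ToTensor() used during training
    X_test = mnist_test.data.view(-1, 28*28).float().to(device) / 255.
    Y_test = mnist_test.targets.to(device)
    bn_acc = (bn_model(X_test).argmax(dim=1) == Y_test).float().mean()
    nn_acc = (nn_model(X_test).argmax(dim=1) == Y_test).float().mean()
    print(bn_acc.item(), nn_acc.item())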