torch.nn.init.normal_(tensor, mean=0.0, std=1.0)
Fills the input Tensor with values drawn from the normal distribution N(mean, std^2)
Parameters
- tensor – an n-dimensional torch.Tensor
- mean – the mean of the normal distribution
- std – the standard deviation of the normal distribution
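As a quick check of the signature above, here is a minimal sketch that fills a tensor in place and compares the sample statistics with the requested ones. The tensor shape and the value std=0.02 are arbitrary choices for illustration, not values used later in this post.

```python
import torch

torch.manual_seed(0)

# An uninitialized tensor; normal_ overwrites its contents in place.
w = torch.empty(1000, 1000)
returned = torch.nn.init.normal_(w, mean=0.0, std=0.02)

# With one million samples the statistics should sit close to the requested values.
sample_mean = w.mean().item()
sample_std = w.std().item()
print(f"mean={sample_mean:.5f}, std={sample_std:.5f}")

# The trailing underscore marks an in-place operation: the returned tensor
# shares storage with the one that was passed in.
shares_storage = returned.data_ptr() == w.data_ptr()
print(shares_storage)
```

Called without mean and std, as in the training script, the defaults mean=0.0 and std=1.0 apply.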
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import random
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# for reproducibility
random.seed(111)
torch.manual_seed(777)
if device == 'cuda':
    torch.cuda.manual_seed_all(777)
# MNIST DATASET
mnist_train = dsets.MNIST(root='MNIST_data/', train=True,
                          transform=transforms.ToTensor(), download=True)
mnist_test = dsets.MNIST(root='MNIST_data/', train=False,
                         transform=transforms.ToTensor(), download=True)
# parameters
epochs = 15
batch_size = 100
learning_rate = 0.001
# dataloader
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                          shuffle=True, drop_last=True)
# model
linear1 = torch.nn.Linear(28*28, 256, bias=True)
linear2 = torch.nn.Linear(256, 256, bias=True)
linear3 = torch.nn.Linear(256, 10, bias=True)
relu = torch.nn.ReLU()
torch.nn.init.normal_(linear1.weight)
torch.nn.init.normal_(linear2.weight)
torch.nn.init.normal_(linear3.weight)
# No ReLU after the last layer: CrossEntropyLoss applies softmax internally, so the model outputs raw logits.
model = torch.nn.Sequential(linear1, relu, linear2, relu, linear3).to(device)
loss_function = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
print('start')
for epoch in range(epochs):
    avg_loss = 0
    batch_count = len(data_loader)
    for X, Y in data_loader:
        # flatten each 1x28x28 image into a 784-dimensional vector
        X = X.view(-1, 28*28).to(device)
        Y = Y.to(device)
        prediction = model(X)
        loss = loss_function(prediction, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # accumulate a plain float so the running mean does not keep autograd graphs alive
        avg_loss += loss.item() / batch_count
    print(epoch, avg_loss)
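The script above calls normal_ with the default std=1.0 on every layer. The following sketch, which is not part of the original tutorial, shows what that does to the raw logits and to the cross-entropy loss before any training. It assumes a stand-in batch of uniform noise with random labels instead of MNIST, a hypothetical helper build_mlp, and an arbitrary comparison value of std=0.01.

```python
import math
import torch

torch.manual_seed(0)

def build_mlp(std):
    """Hypothetical helper: the 784-256-256-10 layout from the script,
    with every weight matrix drawn from N(0, std^2)."""
    layers = [torch.nn.Linear(28 * 28, 256), torch.nn.ReLU(),
              torch.nn.Linear(256, 256), torch.nn.ReLU(),
              torch.nn.Linear(256, 10)]
    for layer in layers:
        if isinstance(layer, torch.nn.Linear):
            torch.nn.init.normal_(layer.weight, mean=0.0, std=std)
    return torch.nn.Sequential(*layers)

# Stand-in batch (assumption): uniform noise in [0, 1) and random labels, not MNIST.
X = torch.rand(100, 28 * 28)
Y = torch.randint(0, 10, (100,))
loss_function = torch.nn.CrossEntropyLoss()

initial_loss = {}
with torch.no_grad():
    for std in (1.0, 0.01):
        logits = build_mlp(std)(X)  # raw scores; the loss applies softmax itself
        initial_loss[std] = loss_function(logits, Y).item()
        print(f"std={std}: logit std={logits.std().item():.3f}, "
              f"initial loss={initial_loss[std]:.3f}")

# A model that spreads probability evenly over 10 classes has loss ln(10).
print("ln(10) =", math.log(10))
```

With std=1.0 the logits grow with each layer's width, so the untrained loss is far above ln(10); with the smaller std it starts near ln(10). This is one reason the first epochs of the script can report very large loss values.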
Reference: https://pytorch.org/docs/stable/nn.init.html
Reference: https://github.com/deeplearningzerotoall/PyTorch