336x280(권장), 300x250(권장), 250x250, 200x200 크기의 광고 코드만 넣을 수 있습니다.
[Logistic Regression]
활성화 함수(activation function): sigmoid
손실 함수(loss function): binary cross-entropy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
torch.manual_seed(1)  # fix parameter initialisation so training runs are reproducible


class BinaryClassifier(nn.Module):
    """Logistic-regression model: one linear layer followed by a sigmoid.

    The forward pass returns probabilities in (0, 1), which is the input
    form ``F.binary_cross_entropy`` expects.

    Args:
        in_features: Number of input features per sample. Defaults to 2,
            matching the two-column toy data used in this example.
    """

    def __init__(self, in_features=2):
        super().__init__()
        self.linear = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of a single linear score per sample.

        The last dimension of *x* must equal ``in_features``; the last
        dimension of the result is 1.
        """
        return self.sigmoid(self.linear(x))
# Toy dataset: six samples with two features each; the last three are labelled 1.
x_data = [[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]]
y_data = [[0], [0], [0], [1], [1], [1]]

model = BinaryClassifier()
optimizer = optim.SGD(model.parameters(), lr=1)

x_train = torch.FloatTensor(x_data)
# Targets are float and shaped (N, 1), like the model's prediction.
y_train = torch.FloatTensor(y_data)

nb_epochs = 1000
for epoch in range(nb_epochs + 1):
    # Forward pass: sigmoid output of BinaryClassifier for every sample.
    prediction = model(x_train)
    loss = F.binary_cross_entropy(prediction, y_train)

    # SGD step: clear stale gradients, backpropagate, update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log periodically rather than on all 1001 epochs.
    if epoch % 100 == 0:
        print('Epoch {:4d}/{} Cost: {:.6f}'.format(
            epoch, nb_epochs, loss.item()
        ))
[Softmax & Cross-Entropy]
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Softmax-classification toy data: eight samples with four features each,
# turned into tensors directly.
x_train = torch.FloatTensor([
    [1, 2, 1, 1],
    [2, 1, 3, 2],
    [3, 1, 3, 4],
    [4, 1, 5, 5],
    [1, 7, 5, 5],
    [1, 2, 5, 6],
    [1, 6, 6, 6],
    [1, 7, 7, 7],
])
# One class index (0, 1 or 2) per sample. Labels stay as integer indices
# (no one-hot encoding), hence LongTensor.
y_train = torch.LongTensor([2, 2, 2, 1, 1, 1, 0, 0])
class softmaxClassifierModel(nn.Module):
    """Multinomial (softmax) classifier: a single linear layer producing logits.

    No softmax is applied in ``forward``; the raw scores are meant to be fed
    to ``F.cross_entropy``, which applies log-softmax internally. Use
    ``F.softmax(logits, dim=1)`` if probabilities are needed.

    The lower-case class name is kept so existing references keep working.

    Args:
        in_features: Features per sample (default 4).
        num_classes: Number of output classes (default 3).
    """

    def __init__(self, in_features=4, num_classes=3):
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return unnormalised class scores; last dimension is ``num_classes``."""
        return self.linear(x)
torch.manual_seed(1)  # make this section's initialisation reproducible on its own

model = softmaxClassifierModel()
optimizer = optim.SGD(model.parameters(), lr=0.1)

nb_epochs = 1000
for epoch in range(nb_epochs + 1):
    # Forward pass: raw logits, one column per class.
    prediction = model(x_train)

    # F.cross_entropy combines log-softmax and negative log-likelihood and takes
    # class indices as targets, so no explicit softmax or one-hot encoding is needed.
    loss = F.cross_entropy(prediction, y_train)

    # Use the cost to improve the hypothesis H(x).
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log every 100 epochs.
    if epoch % 100 == 0:
        print('Epoch {:4d}/{} Cost: {:.6f}'.format(
            epoch, nb_epochs, loss.item()
        ))
출처 : https://youtu.be/HgPWRqtg254
'AI > PyTorch' 카테고리의 다른 글
[PyTorch] MNIST with Dropout (0) | 2019.10.28 |
---|---|
[PyTorch] MNIST with ReLU and Weight Initialization (0) | 2019.10.28 |
[PyTorch] MNIST Introduction (0) | 2019.10.28 |
[PyTorch] Minibatch Gradient Descent (0) | 2019.10.27 |
[PyTorch] Linear Regression Model (0) | 2019.10.27 |