Convolution layer output size:
$o = \lfloor \frac{n + 2p - f}{s} \rfloor + 1$
Pooling layer output size (pooling also floors; its padding is usually 0):
$o = \lfloor \frac{n - f}{s} \rfloor + 1$
Here $n$ is the input size, $p$ the padding, $f$ the kernel (filter) size, $s$ the stride, and $o$ the output size.
In words: pooling output size = floor((input size - filter size) / stride) + 1.
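As a quick sanity check, the formula can be applied to every layer of the LeNet defined below (a minimal sketch; out_size is a hypothetical helper, not used by the later code):

def out_size(n, f, s=1, p=0):
    """Conv/pool output size: floor((n + 2p - f) / s) + 1."""
    return (n + 2 * p - f) // s + 1

# LeNet on 28x28 MNIST inputs:
print(out_size(28, f=5, p=2))   # conv1 -> 28
print(out_size(28, f=2, s=2))   # pool1 -> 14
print(out_size(14, f=5))        # conv2 -> 10
print(out_size(10, f=2, s=2))   # pool2 -> 5
print(out_size(5, f=5))         # conv3 -> 1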
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
device = ('cuda' if torch.cuda.is_available() else 'cpu')
# Load the data
train_data = torchvision.datasets.MNIST(
    root='../data/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=False)  # set download=True if the dataset is not yet under ../data/
test_data = torchvision.datasets.MNIST(
    root='../data/',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False)
# Wrap the data in batched loaders for training
batch_size = 64
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
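Note that these loaders keep the sample order fixed. Shuffling the training set each epoch is common practice and usually helps SGD; a one-line tweak if you want it:

# Reshuffle the training samples at the start of every epoch
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)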
# Define the model
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),  # 1*28*28 -> 6*28*28
            # padding=2 keeps the output at 28*28: (28 + 2*2 - 5)/1 + 1 = 28
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # -> 6*14*14
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),  # -> 16*10*10
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),  # -> 16*5*5
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(16, 120, kernel_size=5, stride=1, padding=0),  # -> 120*1*1
            nn.ReLU()
        )
        self.fc1 = nn.Sequential(
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.view(x.size(0), -1)  # flatten 120*1*1 into a 120-dim vector
        x = self.fc1(x)
        return x
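To confirm the per-layer shapes annotated in the comments, a dummy batch can be pushed through layer by layer (a quick sketch, not part of the original notebook):

model_check = LeNet()
x = torch.randn(1, 1, 28, 28)  # one fake MNIST image
for layer in (model_check.layer1, model_check.layer2, model_check.layer3):
    x = layer(x)
    print(x.shape)
# torch.Size([1, 6, 14, 14])
# torch.Size([1, 16, 5, 5])
# torch.Size([1, 120, 1, 1])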
def evaluate_accuracy(data, model):
    """
    Measure classification accuracy on the test set.
    """
    acc_sum, n = 0.0, 0
    model.eval()
    with torch.no_grad():
        for x, y in data:
            x, y = x.to(device), y.to(device)
            acc_sum += (model(x).argmax(1) == y).float().sum().item()  # count correct predictions
            n += y.shape[0]  # count all samples
    return acc_sum / n
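Once model is instantiated in the next cell, the helper can be called directly:

acc = evaluate_accuracy(test_dataloader, model)
print(f'test accuracy: {acc:.4f}')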
# Define the model, the loss function, and the optimizer
model = LeNet().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
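As the logs below show, plain SGD at lr=1e-3 converges slowly here (test accuracy is still climbing after 21 epochs). Swapping in Adam is a common alternative worth trying (an aside, not what the original run used):

# Adam adapts per-parameter step sizes and typically converges faster on MNIST
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)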
def train(data, model, loss_fn, optimizer):
    size = len(data.dataset)
    model.train()
    for batch, (x, y) in enumerate(data):
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            # loss is a tensor like tensor(127.4510, device='cuda:0', grad_fn=<DivBackward1>),
            # so call item() to extract the plain Python number
            loss, current = loss.item(), (batch + 1) * len(x)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test(data, model, loss_fn):
    size = len(data.dataset)
    num_batches = len(data)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in data:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
epochs = 21
for t in range(epochs):
    # the original guard `if epochs/10 == 0:` was never true; the header should print every epoch
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1
-------------------------------
loss: 2.272300  [   64/60000]
loss: 2.269048  [ 6464/60000]
...
loss: 2.255527  [57664/60000]
Test Error:
 Accuracy: 30.1%, Avg loss: 2.258554

Per-epoch test results for the rest of the run (per-batch losses omitted):
Epoch  2: 38.5% (avg loss 2.220405)   Epoch  3: 44.2% (2.098208)   Epoch  4: 54.8% (1.677555)
Epoch  5: 70.4% (1.018695)   Epoch  6: 79.8% (0.670018)   Epoch  7: 84.1% (0.528252)
Epoch  8: 86.5% (0.452470)   Epoch  9: 87.9% (0.403727)   Epoch 10: 89.2% (0.367917)
Epoch 11: 90.0% (0.340112)   Epoch 12: 90.5% (0.317997)   Epoch 13: 90.9% (0.299733)
Epoch 14: 91.4% (0.284142)   Epoch 15: 91.8% (0.270155)   Epoch 16: 92.2% (0.257590)
Epoch 17: 92.5% (0.246200)   Epoch 18: 92.9% (0.235646)   Epoch 19: 93.3% (0.225980)
Epoch 20: 93.6% (0.216964)   Epoch 21: 93.7% (0.208597)   Epoch 22: 94.0% (0.200504)
Epoch 23: 94.2% (0.193056)   Epoch 24: 94.4% (0.186250)   Epoch 25: 94.6% (0.179679)
Epoch 26: 94.7% (0.173709)

The run was left going past the planned 21 epochs and then stopped manually partway through epoch 27; the resulting KeyboardInterrupt traceback (raised inside F.conv2d during the model's forward pass) is omitted here.
# Redefine test() with a simpler signature: it returns accuracy only, for the inline loop below
def test(data, model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, y in data:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    correct /= len(data.dataset)
    return correct
epochs = 21
# The model itself is left unoptimized; possible improvements include regularization
# such as dropout, or ResNet-style skip connections (see the sketch after the log below).
for epoch in range(epochs):
    print('Epoch:{}\n'.format(epoch + 1))
    size = len(train_dataloader.dataset)
    model.train()
    for batch, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(x)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
    correct = test(test_dataloader, model)
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}% \n")
Epoch:1
loss: 0.098646  [   64/60000]
...
loss: 0.196144  [57664/60000]
Test Error:
 Accuracy: 95.9%

Test accuracy over the remaining epochs (per-batch losses omitted):
96.0%, 96.1%, 96.2%, 96.4%, 96.5%, 96.5%, 96.6%, 96.6%, 96.7% (epoch 10),
96.8%, 96.9%, 96.9%, 96.9%, 97.0%, 97.0%, 97.0%, 97.1%, 97.2%, 97.2%, 97.2% (epoch 21)
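As flagged in the comment above the loop, an easy regularization experiment is dropout in the classifier head. A minimal sketch of one way to do it (a hypothetical variant, not part of the original model):

class LeNetDropout(LeNet):
    """LeNet with dropout before the hidden FC layer."""
    def __init__(self, p=0.5):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Dropout(p),        # randomly zeroes activations while model.train() is on
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10)
        )

model.train() and model.eval() already toggle dropout on and off, so the training and test loops above work unchanged.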
torch.save(model.state_dict(), "./LeNet.pth")
model = LeNet().to(device)
model.load_state_dict(torch.load("./LeNet.pth"))
<All keys matched successfully>
model.eval()
x, y = test_data[180][0].view(1, 1, 28, 28), test_data[180][1]  # add a batch dimension
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = pred[0].argmax(0), y
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
Predicted: "1", Actual: "1"