!!!本次模拟训练的时长在没有下载的基础上且使用cuda加速的情况下是4min多。 (需要保证体验的话,需要使用cuda或者mps进行加速,且提前下载数据集)
在上一篇实验中,我们看到 MLP 在 MNIST 手写数字识别上可以达到接近 97.5% 的准确率。这说明 MLP 具备了较强的拟合能力,只要样本量足够并且合理调节超参数,就能在像手写体这样结构相对简单的数据集上取得非常不错的表现。
但这里也埋下了一个重要的问题:这样的高准确率是否意味着 MLP 在其他任务上也能同样泛化?
MNIST 数据集本身有几个特点:
1:图片分辨率低(28×28 灰度),输入维度相对较小;
2:样本居中、背景干净、噪声少,模式相对统一;
3:任务目标简单:10 类数字分类,类间差异明显。
这就解释了为什么 MLP 在 MNIST 上可以轻松达到很高的准确率——因为它不需要建模复杂的局部结构,只要把像素展平后学习全局模式,就足以区分数字。然而,如果我们把手写图片做一些简单的扰动,比如:
a:微小平移(数字偏移几像素);b:轻微旋转(±15°);c:添加噪声(模糊或随机点);d:换成其他的数据源,接下来我们通过设计实验来观察MLP在其他物体识别泛化能力怎样。
[code]import torchimport torch.nn as nnimport torch.optim as optimimport torchvisionimport torchvision.transforms as transformsfrom torchvision.transforms import functional as TFimport numpy as npimport matplotlib.pyplot as pltfrom matplotlib import rcParamsimport matplotlib.font_manager as fmimport pandas as pd# 尝试多个中文字体(按优先级)chinese_fonts = [ 'PingFang SC', # macOS 默认 'Heiti TC', # macOS 黑体 'STHeiti', # 华文黑体 'Arial Unicode MS', # 支持中文的Arial 'SimHei', # 黑体 'Microsoft YaHei', # 微软雅黑]# 查找可用的中文字体available_fonts = [f.name for f in fm.fontManager.ttflist]font_found = Nonefor font in chinese_fonts: if font in available_fonts: font_found = font breakif font_found: rcParams['font.sans-serif'] = [font_found] rcParams['axes.unicode_minus'] = False print(f"Using font: {font_found}")else: print("Warning: No Chinese font found, using English labels")# 设置设备device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(f"Using device: {device}\n")# ==================== MLP模型定义 ====================class MLP(nn.Module): def __init__(self, input_size=784, hidden_sizes=[512, 256], num_classes=10): super().__init__() layers = [] # 构建网络 prev_size = input_size for hidden_size in hidden_sizes: layers.append(nn.Linear(prev_size, hidden_size)) layers.append(nn.ReLU()) layers.append(nn.Dropout(0.2)) prev_size = hidden_size layers.append(nn.Linear(prev_size, num_classes)) self.network = nn.Sequential(*layers) def forward(self, x): x = x.view(x.size(0), -1) # 展平 return self.network(x)# ==================== 训练和评估函数 ====================def train_model(model, train_loader, criterion, optimizer, epochs=10, device=device): """训练模型""" model = model.to(device) model.train() for epoch in range(epochs): running_loss = 0.0 for images, labels in train_loader: images, labels = images.to(device), labels.to(device) optimizer.zero_grad() outputs = model(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if (epoch + 1) % 2 == 0: print(f" Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")def evaluate_model(model, test_loader, device=device): """评估模型准确率""" model.eval() correct, total = 0, 0 with torch.no_grad(): for images, labels in test_loader: images, labels = images.to(device), labels.to(device) outputs = model(images) _, predicted = torch.max(outputs, 1) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total return accuracy# ==================== 数据增强和扰动 ====================class AddGaussianNoise: """添加高斯噪声""" def __init__(self, mean=0., std=0.1): self.mean = mean self.std = std def __call__(self, tensor): return tensor + torch.randn(tensor.size()) * self.std + self.meanclass RandomShift: """随机平移""" def __init__(self, shift_range=4): self.shift_range = shift_range def __call__(self, img): shift_x = np.random.randint(-self.shift_range, self.shift_range + 1) shift_y = np.random.randint(-self.shift_range, self.shift_range + 1) return TF.affine(img, angle=0, translate=(shift_x, shift_y), scale=1.0, shear=0)# ==================== 实验1: 基准测试 ====================def experiment_baseline(): """基准实验:原始MNIST""" print("="*60) print("实验1: 基准测试 - 原始MNIST") print("="*60) # 准备数据 transform = transforms.Compose([transforms.ToTensor()]) train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True) test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) # 训练模型 model = MLP(input_size=784, hidden_sizes=[512, 256], num_classes=10) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) print("开始训练...") train_model(model, train_loader, criterion, optimizer, epochs=10) # 评估 accuracy = evaluate_model(model, test_loader) print(f"✅ 基准准确率: {accuracy:.2f}%\n") return model, accuracy# ==================== 实验2: 平移扰动 ====================def experiment_translation(trained_model): """实验2: 测试平移不变性""" print("="*60) print("实验2: 平移扰动测试") print("="*60) shift_ranges = [0, 2, 4, 6, 8] accuracies = [] for shift in shift_ranges: if shift == 0: transform = transforms.Compose([transforms.ToTensor()]) else: transform = transforms.Compose([ RandomShift(shift_range=shift), transforms.ToTensor() ]) test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) accuracy = evaluate_model(trained_model, test_loader) accuracies.append(accuracy) print(f" 平移范围 ±{shift}px: {accuracy:.2f}%") print() return shift_ranges, accuracies# ==================== 实验3: 旋转扰动 ====================def experiment_rotation(trained_model): """实验3: 测试旋转不变性""" print("="*60) print("实验3: 旋转扰动测试") print("="*60) rotation_angles = [0, 5, 10, 15, 20, 30] accuracies = [] for angle in rotation_angles: if angle == 0: transform = transforms.Compose([transforms.ToTensor()]) else: transform = transforms.Compose([ transforms.RandomRotation(degrees=(angle, angle)), transforms.ToTensor() ]) test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) accuracy = evaluate_model(trained_model, test_loader) accuracies.append(accuracy) print(f" 旋转角度 {angle}°: {accuracy:.2f}%") print() return rotation_angles, accuracies# ==================== 实验4: 噪声扰动 ====================def experiment_noise(trained_model): """实验4: 测试噪声鲁棒性""" print("="*60) print("实验4: 噪声扰动测试") print("="*60) noise_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5] accuracies = [] for noise_std in noise_levels: if noise_std == 0.0: transform = transforms.Compose([transforms.ToTensor()]) else: transform = transforms.Compose([ transforms.ToTensor(), AddGaussianNoise(mean=0., std=noise_std) ]) test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) accuracy = evaluate_model(trained_model, test_loader) accuracies.append(accuracy) print(f" 噪声标准差 {noise_std:.1f}: {accuracy:.2f}%") print() return noise_levels, accuracies# ==================== 实验5: 组合扰动 ====================def experiment_combined(trained_model): """实验5: 组合扰动测试""" print("="*60) print("实验5: 组合扰动测试") print("="*60) test_cases = [ ("原始", transforms.Compose([transforms.ToTensor()])), ("平移+旋转", transforms.Compose([ RandomShift(shift_range=4), transforms.RandomRotation(degrees=10), transforms.ToTensor() ])), ("平移+噪声", transforms.Compose([ RandomShift(shift_range=4), transforms.ToTensor(), AddGaussianNoise(std=0.2) ])), ("旋转+噪声", transforms.Compose([ transforms.RandomRotation(degrees=10), transforms.ToTensor(), AddGaussianNoise(std=0.2) ])), ("全部扰动", transforms.Compose([ RandomShift(shift_range=4), transforms.RandomRotation(degrees=10), transforms.ToTensor(), AddGaussianNoise(std=0.2) ])) ] case_names = [] accuracies = [] for name, transform in test_cases: test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) accuracy = evaluate_model(trained_model, test_loader) case_names.append(name) accuracies.append(accuracy) print(f" {name}: {accuracy:.2f}%") print() return case_names, accuracies# ==================== 实验6: Fashion-MNIST ====================def experiment_fashion_mnist(): """实验6: Fashion-MNIST数据集""" print("="*60) print("实验6: Fashion-MNIST 泛化测试") print("="*60) # 准备数据 transform = transforms.Compose([transforms.ToTensor()]) train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True) test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) # 训练模型 model = MLP(input_size=784, hidden_sizes=[512, 256], num_classes=10) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) print("开始训练...") train_model(model, train_loader, criterion, optimizer, epochs=10) # 评估 accuracy = evaluate_model(model, test_loader) print(f"✅ Fashion-MNIST准确率: {accuracy:.2f}%\n") return accuracy# ==================== 实验7: CIFAR-10 ====================def experiment_cifar10(): """实验7: CIFAR-10数据集(彩色图像)""" print("="*60) print("实验7: CIFAR-10 泛化测试") print("="*60) # 准备数据 transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True) test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0) # 训练模型(输入是32x32x3=3072维) model = MLP(input_size=3072, hidden_sizes=[1024, 512, 256], num_classes=10) criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) print("开始训练...") train_model(model, train_loader, criterion, optimizer, epochs=15) # 评估 accuracy = evaluate_model(model, test_loader) print(f"✅ CIFAR-10准确率: {accuracy:.2f}%\n") return accuracy# ==================== 运行所有实验 ====================print("
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作! |