找回密码
 立即注册
首页 业界区 安全 模型训练中 平均损失值和平均准确率的深入理解 ...

模型训练中 平均损失值和平均准确率的深入理解

百杲憔 2025-9-28 18:49:47
aver_loss

总损失的计算

对于求平均损失来说  需要先求总损失
而求总损失 就需要求一个批次中的损失
对于一个bs来说 损失的计算是利用
loss=criterion(out,labels)计算得出
而criterion 使用的nn.crossentropy
得出来的损失值 已经是对这一个bs传入的所有样本取过平均值了
所以得出来的loss是当前bs的aver_loss
上面标亮的这段话 是求损失值的关键,也是后面两种方法的基础。
则total_loss+=loss  就计算出总损失了。
对于aver_loss 是可以有两种处理方式的。
方法一:累加“总损失”,最后除以“总样本数”

这是更精确、更标准的方法,也是 PyTorch 官方教程中常见的方式。

  • 循环内的操作:
    Python
    1. running_loss += loss.item() * inputs.size(0)
    复制代码

    • loss.item():这是 PyTorch CrossEntropyLoss 默认返回的一个批次 (batch) 的平均损失
    • inputs.size(0):这是当前批次中的样本数量(也就是 batch_size)。
    • loss.item() * inputs.size(0):用“平均损失”乘以“样本数”,我们得到的实际上是这个批次的“总损失”(即损失值的加和)。
    • running_loss += ...:所以,running_loss 累加的是所有批次的总损失之和,也就是整个 epoch 见过的所有样本的损失总和

  • 循环外的操作:
    Python
    1. epoch_loss = running_loss / dataset_size[phase]
    复制代码

    • 因为 running_loss 是所有样本的损失总和,所以我们理应除以所有样本的总数量 (dataset_size[phase]),来得到最精确的“平均到每个样本的损失”



  • 优点:这种方法可以精确地处理最后一个批次样本数不足的情况(当数据集总数不能被 batch_size 整除时),因为 inputs.size(0) 会自动适应最后一个批次的实际大小。
方法二:累加“平均损失”,最后除以“总批次数”(您提出的方式)

您的这个逻辑也是完全正确的!它代表了另一种计算思路。

  • 要使用您的计算方法,循环内的操作应该是:
    Python
    1. running_loss += loss.item()
    复制代码

    • 这里,我们累加的是每个批次的“平均损失”。running_loss 最终会变成所有批次的平均损失之和

  • 循环外的操作(如您所写):
    Python
    1. aver_loss = running_loss / len(dataloaders[phase])
    复制代码

    • 因为 running_loss 是所有批次平均损失的和,所以我们理应除以总的批次数 (len(dataloaders[phase])),来得到“每个批次的平均损失的平均值”



  • 优点:实现起来非常直观。
  • 微小缺点:当最后一个批次样本数不足时,它在计算最终平均值时,给予了这个不完整的批次的“平均损失”与其他完整批次相同的权重,理论上会引入微小的计算偏差。但在实践中,当数据集很大时,这点偏差几乎可以忽略不计。
accuracy
对于准确率来说,他是在每一个批次(bs)中
使用_,preds=torch.max(outputs,1)
torch.max的使用参考[[torch.max]]  先求出分类的类别。
然后调用torch.sum(preds == labels.data) 求出正确预测的总数。
preds==labels.data 返回的是一个bool数组。
torch.sum则是把bool数组的true视为1 false视为0 求和 最后返回一个整数
实际上返回的是一个整数张量
在每一个epoch训练前定义需定义total_acc=0
在每个批次累加正确的数量到total_acc上
则total_acc也是一个整数张量
再一个轮次所有批次训练完后   total_acc/样本总数(也就是参考损失值计算中的dataset_size[phase])。即为正确率
这里要注意一个点,total_acc是一个整数张量 而样本总数是一个整型变量。
在现代 PyTorch(以及 Python 3)中,除法 / 默认是“真除法” (true division)。当 PyTorch 执行一个整数张量除以一个整数时,它足够“聪明”,知道结果很可能是小数,所以它会自动将结果“提升” (promote)为框架默认的浮点类型,也就是 torch.float32。

来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

相关推荐

前天 00:48

举报

用心讨论,共获提升!
您需要登录后才可以回帖 登录 | 立即注册