閩侯福州網(wǎng)站建設(shè)招代理最好的推廣方式
關(guān)于混淆矩陣,各位可以在這里了解:混淆矩陣細致理解_夏天是冰紅茶的博客-CSDN博客
上一篇中我們了解了混淆矩陣,并且進行了類定義,那么在這一節(jié)中我們將要對其進行擴展,在多分類中,如何去計算TP,TN,FN,FP。
原理推導
這里以三分類為例,這里來看看TP,TN,FN,FP是怎么分布的。
類別1的標簽:
類別2的標簽:
類別3的標簽:
這樣我們就能知道了混淆矩陣的對角線就是TP
TP = torch.diag(h)
?假正例(FP)是模型錯誤地將負類別樣本分類為正類別的數(shù)量
FP = torch.sum(h, dim=1) - TP
假負例(FN)是模型錯誤地將正類別樣本分類為負類別的數(shù)量
FN = torch.sum(h, dim=0) - TP
最后用總數(shù)減去除了 TP 的其他三個元素之和得到 TN
TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)
邏輯驗證
這里借用上一篇的例子,假如我們這個混淆矩陣是這樣的:
tensor([[2, 0, 0],
? ? ? ? ? ? [0, 1, 1],
? ? ? ? ? ? [0, 2, 0]])
為了方便講解,這里我們對其進行一個簡單的編號,即0—8:
0 | 1 | 2 |
3 | 4 | 5 |
6 | 7 | 8 |
torch.sum(h, dim=1) 可得 tensor([2., 2., 2.]) , torch.sum(h, dim=0)?可得?tensor([2., 3., 1.]) 。
- ?TP:? ?tensor([2., 1., 0.])?
- ?FP:? ?tensor([0., 1., 2.])?
- ?TN:? ?tensor([4., 2., 3.])?
- ?FN:? ?tensor([0., 2., 1.])
我們先來看看TP的構(gòu)成,對應著矩陣的對角線2,1,0;FP在類別1中占3,6號位,在類別2中占1,7號位,在類別3中占2,5號位,加起來即為0,1,2;TN在類別1中占4,5,7,8號位,在類別2中占邊角位,在類別3中占0,1,3,4號位,加起來即為4,2,3;FN在類別1中占1,2號位,在類別2中占3,5號位,在類別3中占6,7號位,加起來即為0,2,1。
補充類定義
import torch
import numpy as npclass ConfusionMatrix(object):def __init__(self, num_classes):self.num_classes = num_classesself.mat = Nonedef update(self, t, p):n = self.num_classesif self.mat is None:# 創(chuàng)建混淆矩陣self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)with torch.no_grad():# 尋找GT中為目標的像素索引k = (t >= 0) & (t < n)# 統(tǒng)計像素真實類別t[k]被預測成類別p[k]的個數(shù)inds = n * t[k].to(torch.int64) + p[k]self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)def reset(self):if self.mat is not None:self.mat.zero_()@propertydef ravel(self):"""計算混淆矩陣的TN, FP, FN, TP"""h = self.mat.float()n = self.num_classesif n == 2:TP, FN, FP, TN = h.flatten()return TP, FN, FP, TNif n > 2:TP = h.diag()FN = h.sum(dim=1) - TPFP = h.sum(dim=0) - TPTN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)return TP, FN, FP, TNdef compute(self):"""主要在eval的時候使用,你可以調(diào)用ravel獲得TN, FP, FN, TP, 進行其他指標的計算計算全局預測準確率(混淆矩陣的對角線為預測正確的個數(shù))計算每個類別的準確率計算每個類別預測與真實目標的iou,IoU = TP / (TP + FP + FN)"""h = self.mat.float()acc_global = torch.diag(h).sum() / h.sum()acc = torch.diag(h) / h.sum(1)iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))return acc_global, acc, iudef __str__(self):acc_global, acc, iu = self.compute()return ('global correct: {:.1f}\n''average row correct: {}\n''IoU: {}\n''mean IoU: {:.1f}').format(acc_global.item() * 100,['{:.1f}'.format(i) for i in (acc * 100).tolist()],['{:.1f}'.format(i) for i in (iu * 100).tolist()],iu.mean().item() * 100)
我在代碼中添加了屬性修飾器,以便我們可以直接的進行調(diào)用,并且也考慮到了二分類與多分類不同的情況。
性能指標
關(guān)于這些指標在網(wǎng)上有很多介紹,這里就不細講了
class ModelIndex():def __init__(self,TP, FN, FP, TN, e=1e-5):self.TN = TNself.FP = FPself.FN = FNself.TP = TPself.e = edef Precision(self):"""精確度衡量了正類別預測的準確性"""return self.TP / (self.TP + self.FP + self.e)def Recall(self):"""召回率衡量了模型對正類別樣本的識別能力"""return self.TP / (self.TP + self.FN + self.e)def IOU(self):"""表示模型預測的區(qū)域與真實區(qū)域之間的重疊程度"""return self.TP / (self.TP + self.FP + self.FN + self.e)def F1Score(self):"""F1分數(shù)是精確度和召回率的調(diào)和平均數(shù)"""p = self.Precision()r = self.Recall()return 2*p*r / (p + r + self.e)def Specificity(self):"""特異性是指模型在負類別樣本中的識別能力"""return self.TN / (self.TN + self.FP + self.e)def Accuracy(self):"""準確度是模型正確分類的樣本數(shù)量與總樣本數(shù)量之比"""return self.TP + self.TN / (self.TP + self.TN + self.FP + self.FN + self.e)def FP_rate(self):"""False Positive Rate,假陽率是模型將負類別樣本錯誤分類為正類別的比例"""return self.FP / (self.FP + self.TN + self.e)def FN_rate(self):"""False Negative Rate,假陰率是模型將正類別樣本錯誤分類為負類別的比例"""return self.FN / (self.FN + self.TP + self.e)def Qualityfactor(self):"""品質(zhì)因子綜合考慮了召回率和特異性"""r = self.Recall()s = self.Specificity()return r+s-1
參考文章:多分類中TP/TN/FP/FN的計算_Hello_Chan的博客-CSDN博客?