今天給大家介紹的是來自中國科學技術大學計算機科學與技術學院的劉淇教授團隊和騰訊美國分公司的Cheekong Lee聯合發表在AAAI 2022上的文章“ProtGNN: Towards Self-Explaining Graph Neural Networks”。 圖神經網路作為一種黑盒模型,無法提供易於人類理解的解釋,這一弊端使其可信度大打折扣。本文提出面向自解釋的圖神經網路模型ProtGNN / ProtGNN+,將原型學習與GNN結合,在進行分類任務的同時為人類提供視覺化解釋。實驗證明,ProtGNN / ProtGNN+具有與現有GNN相當的分類能力,並能準確高效地進行模型的內在解釋。
1 摘要
儘管圖形神經網路(GNNs)的研究近來取得較大進展,但如何解釋GNN的預測仍是一個具有挑戰性的難題。現有的解釋方法主要集中在事後解釋上,無法揭示GNN的原始推理過程,因此,建立具有內在可解釋性的GNN是非常有必要的。本文中,作者提出了原型圖神經網路模型(ProtGNN),它將原型學習與GNN相結合,為GNN的解釋提供了一個新的視角。在ProtGNN中,解釋是由基於案例的推理過程自然產生的,並在分類過程中實際使用。ProtGNN的分類預測是透過將輸入與潛在空間中學習到的一些原型進行相似度比較而得到的;此外,為了獲得更好的可解釋性和更高的效率,在ProtGNN+中加入了一個新的條件子圖取樣模組,以表明輸入圖的哪一部分與每個原型最相似。最後,作者在廣泛的資料集上評估上述模型,並進行了具體的案例研究,結果表明,ProtGNN / ProtGNN+可以提供較好的內在解釋性,同時達到與現有的不可解釋的GNN相當的分類精度。
2 模型介紹
ProtGNN/ProtGNN+的整體架構如下:
2。1 核心模組
模型主要由三個核心模組構成:GNN Encoder,Prototype Layer和Fully Connected Layer。
GNN Encoder:給定輸入圖
,圖編碼層
將
對映為固定長度的圖嵌入
。
Prototype Layer:在原型層
中,模型為每個類別分配預定數量的原型,對於輸入圖
的嵌入向量
,計算
與原型之間的相似性分數:
其中,
為某一類別的第
個原型嵌入,
設定為一較小值(e。g。, 1e-4)以保證結果始終大於0。
Fully Connected Layer:基於原型層
得到的相似性分數,全連線層
透過Softmax函式計算每個類的輸出機率。
2。2 學習目標
為確保模型準確性,作者在訓練集上採用交叉熵作為損失函式:
其中,
為真實類別。
為確保模型的內在解釋性,作者在訓練集中引入3種約束條件:
1。 The Cluster Cost(Clst):每個圖嵌入應該至少接近其所屬類的一個原型:
2。 The Separation Cost (Sep):每個圖嵌入應該遠離不屬於其所屬類的原型:
3。 The Diversity Loss (Div):促進原型多樣性,避免原型之間過於接近:
其中,
為類
對應的原型集,
為設定的餘弦相似度閾值。
最終得到目標函式:
2。3 原型投影
由於原型也是嵌入向量,不能直接解釋,為了更好的解釋和視覺化,作者設計了一個在訓練階段執行的投影程式——將每個原型
投影到與
相同類別中的最近的訓練子圖上:
其中,子圖
通過蒙特卡洛樹搜尋演算法(MCTS)得到。
2。4 條件子圖取樣模組
為提供更好的解釋,作者進一步提出了帶有條件子圖取樣模組的ProtGNN+模型,該模型不僅顯示了相似度得分,還確定了輸入圖的哪一部分子圖與每個原型最相似。
考慮到呈指數增長的時間複雜度以及並行化和泛化的難度,作者未採用MTCS方法,而是提出一種引數化方法進行子圖搜尋。
作者假設解釋圖為Gilbert隨機圖,以保證每條邊的狀態彼此獨立,此時,節點
和
之間的邊
為:
其中,
為
函式,
為引數為
的多層感知器,
為拼接操作。
此時,子圖搜尋的目標函式為:
其中,
為預算正則化,(
為鄰接矩陣,
為規定的最大子圖規格)。
3 實驗
本文提出5種資料集:MUTAG, BBBP, Graph-SST2, Graph-Twitter和BA-Shape。
3。1 與基準模型比較:
將ProtGNN / ProtGNN+與GCN, GAT,以及GIN進行比較,結果表明,ProtGNN / ProtGNN+在分類任務上,具有與基準模型相當的分類能力:
3。2 案例研究
在MUTAG和Graph-SST2資料集上視覺化模型的推理過程,結果表明,ProtGNN / ProtGNN+能夠準確視覺化原型並識別出相似子圖,具有較好的內在可解釋性:
3。3 原型的t-SNE視覺化
利用t-SNE方法對 BBBP 資料集上的圖嵌入和原型嵌入進行降維視覺化,結果表明,原型可以佔據圖嵌入的中心,這驗證了原型學習的有效性:
3。4 效率分析
在BBBP資料集上比較不同模型的執行時間(其中,ProtGNN+*是採用MCTS作為條件子圖取樣方法的模型),結果表明:
1。相比於MCTS,ProtGNN+提出的引數化條件子圖取樣方法能有效降低時間成本;
2。儘管ProtGNN / ProtGNN+與GCN相比具有更大的時間成本(主要是由採用MCTS進行原型投影造成的),但考慮到前兩者提供的內在可解釋性,該時間成本仍是可以接受的。
4 總結
雖然已有多種模型從不同的角度解釋GNN,但它們都不能為GNN提供模型內在的解釋。本文中,作者提出了ProtGNN / ProtGNN+模型,為GNN的解釋提供了一個新的視角:
1。ProtGNN的分類預測是透過將輸入與原型層中的一些學習原型進行相似性比較來獲得的;
2。為了更好的可解釋性和更高的效率,ProtGNN+提出了一種新的條件子圖取樣模組來指示與原型最相似的子圖。
實驗結果表明,ProtGNN / ProtGNN+可以提供人類可接受的分類精度、時間複雜度和人類可理解的推理過程。
文章地址
https://arxiv。org/pdf/2112。00911。pdf
程式碼地址
https://github。com/zaixizhang/ProtGNN
作者 | 王郅巍
稽核 | 付海濤
————- End ————-