Source code for histocartography.ml.layers.dense_gin_layer

"""
Implementation of a Dense GIN (Graph Isomorphism Network) layer. This implementation should be used
when the input graph(s) can only be represented as an adjacency matrix (typically when dealing with
dense adjacency matrices).

Original paper:
    - How Powerful are Graph Neural Networks: https://arxiv.org/abs/1810.00826
    - Author's public implementation: https://github.com/weihua916/powerful-gnns
"""

import torch
import torch.nn.functional as F
import dgl
from torch import nn

from .mlp import MLP


class DenseGINLayer(nn.Module):
    """
    GIN layer that propagates node features with a dense adjacency matrix
    (a single matrix product) and then updates them with an MLP.
    """

    def __init__(
        self,
        node_dim: int,
        out_dim: int,
        act: str = 'relu',
        agg_type: str = 'mean',
        hidden_dim: int = 32,
        batch_norm: bool = True,
        graph_norm: bool = False,
        with_lrp: bool = False,
        dropout: float = 0.,
        verbose: bool = False
    ) -> None:
        """
        Dense GIN Layer constructor

        Args:
            node_dim (int): Input dimension of each node.
            out_dim (int): Output dimension of each node.
            act (str): Activation function of the update function
                (forwarded to the MLP).
            agg_type (str): Aggregation function. 'mean' enables the degree
                normalization of the adjacency in `forward`; any other value
                leaves the adjacency un-normalized. Defaults to 'mean'.
            hidden_dim (int): Hidden dimension of the GIN MLP. Defaults to 32.
            batch_norm (bool): If we should use batch normalization
                (forwarded to the MLP). Defaults to True.
            graph_norm (bool): Accepted for interface compatibility; not used
                by this layer. Defaults to False.
            with_lrp (bool): Accepted for interface compatibility; not used
                by this layer. Defaults to False.
            dropout (float): Accepted for interface compatibility; not used
                by this layer. Defaults to 0.
            verbose (bool): Verbosity. Defaults to False.
        """
        super().__init__()

        if verbose:
            print('Creating dense GIN layer:')

        self.out_dim = out_dim
        # The self-loop (identity) term is always added in `forward`.
        self.add_self = True
        self.mean = agg_type == 'mean'
        # The positional `2` is presumably the number of MLP layers
        # -- TODO confirm against the MLP signature.
        self.mlp = MLP(
            node_dim,
            hidden_dim,
            out_dim,
            2,
            act,
            batch_norm,
            verbose=verbose)

    def forward(self, adj, h):
        """
        Forward-pass of a Dense GIN layer.

        Args:
            adj: Either a `dgl.DGLGraph` containing exactly one graph (it is
                converted to a dense adjacency with a leading batch dimension
                of 1), or an already dense adjacency tensor.
            h (torch.Tensor): Node features. A tensor with fewer than 3
                dimensions gets a leading batch dimension added.

        Returns:
            torch.Tensor: Updated node features after propagation, MLP and
                ReLU. Only a leading batch dimension of size 1 is removed, so
                the node and feature dimensions are always kept.

        Raises:
            AssertionError: If `adj` is a batched DGLGraph holding more than
                one graph.
        """
        if isinstance(adj, dgl.DGLGraph):
            graphs = dgl.unbatch(adj)
            assert len(graphs) == 1, "Batch size must be equal to 1 for processing Dense GIN Layers"
            dense_adj = graphs[0].adjacency_matrix().to_dense()
            adj = dense_adj.unsqueeze(dim=0).to(h.device)

        if self.mean:
            # NOTE(review): for a 3-D adjacency, dim 1 is the row axis, so this
            # divides each entry by its column sum; for a 2-D adjacency it is
            # the row sum. Kept as-is to preserve existing numerical behavior
            # -- confirm which normalization is intended.
            degree = adj.sum(1, keepdim=True)
            # Isolated nodes: avoid a division by zero.
            degree[degree == 0.] = 1.
            adj = adj / degree

        if self.add_self:
            # Build the identity directly on the target device instead of
            # allocating it on CPU and copying it over.
            adj = adj.float() + torch.eye(adj.size(1), device=adj.device)

        # adjust h dim
        if h.dim() < 3:
            h = h.unsqueeze(dim=0)

        h_k_N = torch.matmul(adj, h)
        bs, n_nodes, dim = h_k_N.shape

        # The MLP works on a flat (bs * n_nodes, dim) batch of node vectors.
        h_k = self.mlp(h_k_N.view(bs * n_nodes, dim))
        h_k = h_k.view(bs, n_nodes, self.out_dim)

        # Squeeze only the batch axis: an unconditional squeeze() would also
        # drop the node or feature axis whenever n_nodes == 1 or out_dim == 1.
        return F.relu(h_k).squeeze(dim=0)