histocartography.ml.layers.pna_layer module

PNA: Principal Neighbourhood Aggregation. Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic. https://arxiv.org/abs/2004.05718

Summary

Classes:

PNALayer

PNATower

class PNALayer(node_dim, out_dim, aggregators: str = 'mean max min std', scalers: str = 'identity amplification attenuation', avg_d: int = 4, dropout: float = 0.0, graph_norm: bool = False, batch_norm: bool = False, towers: int = 1, pretrans_layers: int = 1, posttrans_layers: int = 1, divide_input: bool = True, residual: bool = True, verbose=False)

Bases: torch.nn.modules.module.Module

__init__(node_dim, out_dim, aggregators: str = 'mean max min std', scalers: str = 'identity amplification attenuation', avg_d: int = 4, dropout: float = 0.0, graph_norm: bool = False, batch_norm: bool = False, towers: int = 1, pretrans_layers: int = 1, posttrans_layers: int = 1, divide_input: bool = True, residual: bool = True, verbose=False)

PNA layer constructor.

Parameters
  • node_dim (int) – Input dimension of each node.

  • out_dim (int) – Output dimension of each node.

  • aggregators (str) – Space-separated aggregation function identifiers. Defaults to “mean max min std”.

  • scalers (str) – Space-separated scaling function identifiers. Defaults to “identity amplification attenuation”.

  • avg_d (int) – Average node degree in the training set, used by the scalers for normalisation. Defaults to 4.

  • dropout (float) – Dropout rate. Defaults to 0.0.

  • graph_norm (bool) – Whether to use graph normalisation. Defaults to False.

  • batch_norm (bool) – Whether to use batch normalisation. Defaults to False.

  • towers (int) – Number of towers to use. Defaults to 1.

  • pretrans_layers (int) – Number of layers in the transformation applied before aggregation. Defaults to 1.

  • posttrans_layers (int) – Number of layers in the transformation applied after aggregation. Defaults to 1.

  • divide_input (bool) – Whether to split the input features between the towers. Defaults to True.

  • residual (bool) – Whether to add a residual connection. Defaults to True.

  • verbose (bool) – Verbosity. Defaults to False.
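To make the aggregators, scalers and avg_d parameters concrete, here is a minimal pure-Python sketch of PNA-style aggregation for a single node with scalar neighbour features, following the scaler formula S(d) = (log(d + 1) / δ)^α of the PNA paper (α = 0, 1, −1 for identity, amplification, attenuation). It is an illustration, not this module's code: the helper names are hypothetical, std uses the population variance, and δ is assumed to be log(avg_d + 1), whereas the paper averages log(d + 1) over the training nodes.

```python
import math
from statistics import fmean


def agg_std(xs):
    # Population standard deviation, clamping tiny negative variances to 0.
    m = fmean(xs)
    return math.sqrt(max(fmean(x * x for x in xs) - m * m, 0.0))


# Identifiers from the default ``aggregators`` string (hypothetical table).
AGGREGATORS = {"mean": fmean, "max": max, "min": min, "std": agg_std}

# Exponent alpha of the degree scaler S(d) = (log(d + 1) / delta) ** alpha.
SCALER_ALPHA = {"identity": 0, "amplification": 1, "attenuation": -1}


def pna_aggregate(neigh_feats, aggregators="mean max min std",
                  scalers="identity amplification attenuation", avg_d=4):
    """Hypothetical helper: PNA-style aggregation of one node's scalar
    neighbour features. Assumes at least one neighbour (degree d >= 1)."""
    d = len(neigh_feats)
    # ASSUMPTION: delta = log(avg_d + 1); the paper instead averages
    # log(d + 1) over the training nodes.
    delta = math.log(avg_d + 1)
    # The option strings are space-separated identifier lists.
    aggregated = [AGGREGATORS[name](neigh_feats) for name in aggregators.split()]
    out = []
    for name in scalers.split():
        scale = (math.log(d + 1) / delta) ** SCALER_ALPHA[name]
        out.extend(scale * a for a in aggregated)
    return out


feats = pna_aggregate([1.0, 2.0, 3.0, 6.0])
print(len(feats))  # 12: 4 aggregators x 3 scalers
```

With avg_d=4, a node of degree 4 gets all three scalers equal to 1; nodes with a higher degree are boosted by amplification and damped by attenuation, and the reverse holds for lower degrees.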

forward(g, h)

Defines the computation performed at every call: applies the layer to the graph g with node features h.

Note

Although the forward pass is defined in this method, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently ignores them.
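The towers, divide_input and residual parameters determine how a PNALayer could combine its towers in forward. The plain-Python sketch below works on per-node feature lists and assumes the structure of the reference PNA implementation: with divide_input=True, tower i receives the i-th contiguous slice of node_dim // towers input features, the tower outputs are concatenated feature-wise, and the residual is added only when node_dim equals out_dim. The name pna_layer_forward is hypothetical, each tower is any callable standing in for a PNATower, and the mixing layer, dropout and normalisation are left out.

```python
from typing import Callable, List

Matrix = List[List[float]]  # one row of features per node
Tower = Callable[[Matrix], Matrix]


def pna_layer_forward(h: Matrix, towers: List[Tower], node_dim: int, out_dim: int,
                      divide_input: bool = True, residual: bool = True) -> Matrix:
    """Hypothetical sketch of how a PNA layer could combine its towers.
    The mixing layer, dropout and normalisation are omitted."""
    n_towers = len(towers)
    assert not divide_input or node_dim % n_towers == 0, "towers must divide node_dim"
    assert out_dim % n_towers == 0, "towers must divide out_dim"
    in_tower = node_dim // n_towers if divide_input else node_dim

    tower_outputs = []
    for i, tower in enumerate(towers):
        # With divide_input, tower i only sees its own contiguous feature slice.
        h_in = [row[i * in_tower:(i + 1) * in_tower] for row in h] if divide_input else h
        tower_outputs.append(tower(h_in))

    # Concatenate the tower outputs feature-wise, node by node.
    h_out = [[x for out in tower_outputs for x in out[n]] for n in range(len(h))]

    # Residual connection, only possible when input and output sizes match.
    if residual and node_dim == out_dim:
        h_out = [[a + b for a, b in zip(r_in, r_out)] for r_in, r_out in zip(h, h_out)]
    return h_out


# Toy tower (hypothetical): doubles every feature it receives.
def double(m: Matrix) -> Matrix:
    return [[2 * x for x in row] for row in m]


print(pna_layer_forward([[1.0, 2.0, 3.0, 4.0]], [double, double], node_dim=4, out_dim=4))
# [[3.0, 6.0, 9.0, 12.0]]
```

Splitting the input keeps each tower small (node_dim // towers inputs instead of node_dim), which is why the number of towers has to divide node_dim when divide_input is set.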

set_rlp(with_rlp)

class PNATower(in_dim, out_dim, dropout, graph_norm, batch_norm, aggregators, scalers, avg_d, pretrans_layers, posttrans_layers)

Bases: torch.nn.modules.module.Module

pretrans_edges(edges)
message_func(edges)
reduce_func(nodes)
posttrans_nodes(nodes)
forward(g, h)

Defines the computation performed at every call: applies the tower to the graph g with node features h.

Note

Although the forward pass is defined in this method, call the Module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently ignores them.
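The PNATower methods read like the stages of a DGL-style message-passing update. The pure-Python sketch below walks through one such step on an explicit edge list, assuming the structure of the reference PNA implementation: pretrans_edges transforms the concatenated source and destination features of each edge, message_func forwards that result, reduce_func aggregates each node's incoming messages, and posttrans_nodes transforms the node's own features concatenated with the aggregate. The function name pna_tower_step and the pretrans/posttrans callables are hypothetical stand-ins for the tower's learned MLPs, and the reduce step is cut down to mean and max without degree scalers.

```python
from collections import defaultdict
from statistics import fmean
from typing import Callable, Dict, List, Tuple

Vec = List[float]


def pna_tower_step(h: Dict[int, Vec], edges: List[Tuple[int, int]],
                   pretrans: Callable[[Vec], Vec],
                   posttrans: Callable[[Vec], Vec]) -> Dict[int, Vec]:
    """Hypothetical sketch of one PNA tower update on an explicit edge list."""
    # 1. pretrans_edges: transform [h_src, h_dst] for every (src, dst) edge.
    edge_msgs = [(dst, pretrans(h[src] + h[dst])) for src, dst in edges]

    # 2. message_func: deliver each edge result to its destination's mailbox.
    mailbox: Dict[int, List[Vec]] = defaultdict(list)
    for dst, msg in edge_msgs:
        mailbox[dst].append(msg)

    # 3. reduce_func: aggregate each mailbox feature-wise. Simplified to
    #    mean and max; degree scalers are omitted here.
    agg = {}
    for node, msgs in mailbox.items():
        cols = list(zip(*msgs))
        agg[node] = [fmean(c) for c in cols] + [max(c) for c in cols]

    # 4. posttrans_nodes: transform [h_node, aggregate]. Nodes without
    #    incoming edges are simply skipped in this sketch.
    return {node: posttrans(h[node] + agg[node]) for node in agg}


def identity(v: Vec) -> Vec:
    return v


print(pna_tower_step({0: [1.0], 1: [2.0], 2: [4.0]}, [(0, 2), (1, 2), (2, 0)],
                     identity, identity))
# {2: [4.0, 1.5, 4.0, 2.0, 4.0], 0: [1.0, 4.0, 1.0, 4.0, 1.0]}
```

Concatenating the node's own features with the aggregate before the post-transformation lets the tower keep information about the node itself, not only about its neighbourhood.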

Reference

If you use histocartography in your projects, please cite the following:

@article{pati2021,
    title = {Hierarchical Graph Representations for Digital Pathology},
    author = {Pushpak Pati and Guillaume Jaume and Antonio Foncubierta and Florinda Feroce and Anna Maria Anniciello and Giosuè Scognamiglio and Nadia Brancati and Maryse Fiche and Estelle Dubruc and Daniel Riccio and Maurizio Di Bonito and Giuseppe De Pietro and Gerardo Botti and Jean-Philippe Thiran and Maria Frucci and Orcun Goksel and Maria Gabrani},
    journal = {arXiv preprint arXiv:2102.11057},
    url = {https://arxiv.org/pdf/2102.11057},
    year = {2021}
}