histocartography.interpretability.graph_pruning_explainer module

Summary

Classes:

ExplainerModel

GraphPruningExplainer

class GraphPruningExplainer(entropy_loss_weight: float = 1.0, size_loss_weight: float = 0.05, ce_loss_weight: float = 10.0, node_thresh: float = 0.05, mask_init_strategy: str = 'normal', mask_activation: str = 'sigmoid', num_epochs: int = 500, lr: float = 0.01, weight_decay: float = 0.0005, **kwargs)[source]

Bases: histocartography.interpretability.base_explainer.BaseExplainer

__init__(entropy_loss_weight: float = 1.0, size_loss_weight: float = 0.05, ce_loss_weight: float = 10.0, node_thresh: float = 0.05, mask_init_strategy: str = 'normal', mask_activation: str = 'sigmoid', num_epochs: int = 500, lr: float = 0.01, weight_decay: float = 0.0005, **kwargs) → None[source]

Graph Pruning Explainer (GNNExplainer) constructor

Parameters
  • entropy_loss_weight (float) – How much weight to put on the element-wise entropy loss term. Defaults to 1.0.

  • size_loss_weight (float) – How much weight to put on the mask-size loss term. Defaults to 0.05.

  • ce_loss_weight (float) – How much weight to put on the cross-entropy loss term. Defaults to 10.0.

  • node_thresh (float) – Threshold below which a node is deactivated. Defaults to 0.05.

  • mask_init_strategy (str) – Initialization strategy for the mask. Defaults to “normal” (i.e., all 1’s).

  • mask_activation (str) – Mask activation function. Defaults to “sigmoid”.

  • num_epochs (int) – Number of epochs used for training the mask. Defaults to 500.

  • lr (float) – Learning rate. Defaults to 0.01.

  • weight_decay (float) – Weight decay. Defaults to 5e-4.
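
A minimal usage sketch, assuming a pretrained histocartography graph model model and a DGL cell graph graph prepared beforehand; the process() call comes from the BaseExplainer interface, and the exact return values may differ between versions:

from histocartography.interpretability import GraphPruningExplainer

# `model` is a pretrained graph classifier and `graph` a DGL cell graph,
# both assumed to be available already.
explainer = GraphPruningExplainer(
    entropy_loss_weight=1.0,
    size_loss_weight=0.05,
    ce_loss_weight=10.0,
    num_epochs=500,
    lr=0.01,
    model=model,  # forwarded via **kwargs to BaseExplainer
)
# per-node importance scores; check the return signature for your version
node_importance = explainer.process(graph)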

class ExplainerModel(model: torch.nn.modules.module.Module, adj: torch.Tensor, x: torch.Tensor, init_probs: torch.Tensor, model_params: dict, train_params: dict, use_sigmoid: bool = True)[source]

Bases: torch.nn.modules.module.Module

__init__(model: torch.nn.modules.module.Module, adj: torch.Tensor, x: torch.Tensor, init_probs: torch.Tensor, model_params: dict, train_params: dict, use_sigmoid: bool = True)[source]

Explainer constructor.

Parameters
  • model (nn.Module) – Torch model.

  • adj (torch.tensor) – Adjacency matrix.

  • x (torch.tensor) – Node features.

  • init_probs (torch.tensor) – Prediction on the whole graph.

  • model_params (dict) – Model params for learning mask.

  • train_params (dict) – Training params for learning mask.

  • use_sigmoid (bool) – Whether to apply a sigmoid activation to the mask. Defaults to True.

static sigmoid(x, t=1)[source]

forward()[source]

Forward pass.
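
ExplainerModel wraps the frozen model passed to its constructor and re-runs it on a soft-masked graph, following the GNNExplainer pattern: only the mask receives gradients. The sketch below is a standalone, hypothetical illustration of that forward pass (masked_forward, node_mask_logits, and the exact masking scheme are assumptions, not the library's internals):

import torch

def masked_forward(model, adj, x, node_mask_logits, t=1.0):
    # tempered sigmoid turns raw mask logits into (0, 1) node weights,
    # in the spirit of the static sigmoid(x, t=1) helper above
    mask = torch.sigmoid(t * node_mask_logits)
    # soft-prune: scale node features and adjacency entries by the node weights
    masked_x = x * mask.unsqueeze(-1)
    masked_adj = adj * mask.unsqueeze(0) * mask.unsqueeze(1)
    # the wrapped model is frozen; only the mask logits are trainable
    return model(masked_adj, masked_x)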

distillation_loss(inner_logits)[source]

Compute distillation loss.

loss(pred: None._VariableFunctions.tensor)[source]

Compute new overall loss given current prediction.

Parameters
  • pred (torch.tensor) – Prediction made by current model.
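
Combining the weights documented in the GraphPruningExplainer constructor, the objective is a weighted sum of a distillation (cross-entropy) term, a mask-size term, and an element-wise mask-entropy term. A hedged sketch of that composition (function and argument names are hypothetical, and the exact terms may differ from the library's implementation):

import torch
import torch.nn.functional as F

def explainer_loss(logits, init_probs, mask_logits,
                   ce_loss_weight=10.0, size_loss_weight=0.05,
                   entropy_loss_weight=1.0, eps=1e-8):
    mask = torch.sigmoid(mask_logits)
    # distillation: keep the masked prediction close to the prediction
    # obtained on the whole (unmasked) graph (init_probs)
    ce_loss = -(init_probs * F.log_softmax(logits, dim=-1)).sum()
    # size: encourage a sparse mask, i.e. few active nodes
    size_loss = mask.sum()
    # element-wise entropy: push mask values towards either 0 or 1
    entropy = -mask * torch.log(mask + eps) - (1 - mask) * torch.log(1 - mask + eps)
    return (ce_loss_weight * ce_loss
            + size_loss_weight * size_loss
            + entropy_loss_weight * entropy.mean())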

Reference

If you use histocartography in your projects, please cite the following:

@inproceedings{pati2021,
    title = {Hierarchical Graph Representations in Digital Pathology},
    author = {Pushpak Pati, Guillaume Jaume, Antonio Foncubierta, Florinda Feroce, Anna Maria Anniciello, Giosuè Scognamiglio, Nadia Brancati, Maryse Fiche, Estelle Dubruc, Daniel Riccio, Maurizio Di Bonito, Giuseppe De Pietro, Gerardo Botti, Jean-Philippe Thiran, Maria Frucci, Orcun Goksel, Maria Gabrani},
    booktitle = {https://arxiv.org/pdf/2102.11057},
    year = {2021}
}