histocartography.interpretability.graph_pruning_explainer module
Summary
Classes:
- class GraphPruningExplainer(entropy_loss_weight: float = 1.0, size_loss_weight: float = 0.05, ce_loss_weight: float = 10.0, node_thresh: float = 0.05, mask_init_strategy: str = 'normal', mask_activation: str = 'sigmoid', num_epochs: int = 500, lr: float = 0.01, weight_decay: float = 0.0005, **kwargs)
Bases:
histocartography.interpretability.base_explainer.BaseExplainer
- __init__(entropy_loss_weight: float = 1.0, size_loss_weight: float = 0.05, ce_loss_weight: float = 10.0, node_thresh: float = 0.05, mask_init_strategy: str = 'normal', mask_activation: str = 'sigmoid', num_epochs: int = 500, lr: float = 0.01, weight_decay: float = 0.0005, **kwargs) → None
Graph Pruning Explainer (GNNExplainer) constructor.
- Parameters
entropy_loss_weight (float) – Weight of the element-wise entropy loss term. Defaults to 1.0.
size_loss_weight (float) – Weight of the mask size loss term. Defaults to 0.05.
ce_loss_weight (float) – Weight of the cross-entropy loss term. Defaults to 10.0.
node_thresh (float) – Threshold value used to deactivate nodes. Defaults to 0.05.
mask_init_strategy (str) – Initialization strategy for the mask. Defaults to “normal” (i.e., all 1’s).
mask_activation (str) – Activation function applied to the mask. Defaults to “sigmoid”.
num_epochs (int) – Number of epochs used to train the mask. Defaults to 500.
lr (float) – Learning rate. Defaults to 0.01.
weight_decay (float) – Weight decay. Defaults to 5e-4.
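To make the roles of the three loss weights and node_thresh more concrete, here is a minimal pure-Python sketch. It assumes a GNNExplainer-style objective, i.e. a weighted sum of the cross-entropy of the prediction, the mask size, and the element-wise mask entropy, and it assumes nodes are pruned when their sigmoid-activated mask value falls below node_thresh. The helpers explainer_loss and prune_nodes are hypothetical illustrations, not part of histocartography.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    # Map an unconstrained mask logit into (0, 1), as mask_activation="sigmoid" suggests.
    return 1.0 / (1.0 + math.exp(-z))


def explainer_loss(mask_logits: List[float], class_probs: List[float], label: int,
                   entropy_loss_weight: float = 1.0, size_loss_weight: float = 0.05,
                   ce_loss_weight: float = 10.0) -> float:
    # Hypothetical GNNExplainer-style objective (assumption, not the library's code).
    mask = [sigmoid(z) for z in mask_logits]
    # Cross-entropy of the masked-graph prediction w.r.t. the target label.
    ce = -math.log(class_probs[label])
    # Size term: the sum of mask values, which rewards small explanations.
    size = sum(mask)
    # Mean element-wise binary entropy: pushes each mask value toward 0 or 1.
    eps = 1e-12
    ent = sum(-m * math.log(m + eps) - (1 - m) * math.log(1 - m + eps)
              for m in mask) / len(mask)
    return ce_loss_weight * ce + size_loss_weight * size + entropy_loss_weight * ent


def prune_nodes(mask_logits: List[float], node_thresh: float = 0.05) -> List[int]:
    # Keep the indices of nodes whose activated mask value reaches the threshold.
    return [i for i, z in enumerate(mask_logits) if sigmoid(z) >= node_thresh]
```

For example, prune_nodes([3.0, -4.0, 0.0]) keeps nodes 0 and 2, because sigmoid(-4.0) ≈ 0.018 is below the default threshold of 0.05. The large default ce_loss_weight relative to size_loss_weight reflects that preserving the prediction is meant to dominate the push toward a small mask.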
- class ExplainerModel(model: torch.nn.modules.module.Module, adj: torch.Tensor, x: torch.Tensor, init_probs: torch.Tensor, model_params: dict, train_params: dict, use_sigmoid: bool = True)
Bases:
torch.nn.modules.module.Module
- __init__(model: torch.nn.modules.module.Module, adj: torch.Tensor, x: torch.Tensor, init_probs: torch.Tensor, model_params: dict, train_params: dict, use_sigmoid: bool = True)
Explainer constructor.
- Parameters
model (nn.Module) – Torch model.
adj (torch.Tensor) – Adjacency matrix.
x (torch.Tensor) – Node features.
init_probs (torch.Tensor) – Prediction on the whole graph.
model_params (dict) – Model parameters for learning the mask.
train_params (dict) – Training parameters for learning the mask.
use_sigmoid (bool) – Whether to apply a sigmoid activation to the mask. Defaults to True.
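The exact masking performed inside ExplainerModel is not described on this page. As a hypothetical sketch, assuming a GNNExplainer-style node mask m that scales each node's feature row and gates every edge by both of its endpoints (A'_ij = m_i · A_ij · m_j), the inputs adj, x and use_sigmoid could interact as follows. The function apply_node_mask is illustrative only and not part of histocartography.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def apply_node_mask(adj: Matrix, x: Matrix, node_mask: List[float],
                    use_sigmoid: bool = True) -> Tuple[Matrix, Matrix]:
    # Hypothetical node masking (assumption, not the library's implementation).
    # Optionally squash raw mask parameters into (0, 1), mirroring use_sigmoid.
    if use_sigmoid:
        m = [1.0 / (1.0 + math.exp(-v)) for v in node_mask]
    else:
        m = list(node_mask)
    n = len(adj)
    # Scale each node's feature row by that node's mask value.
    masked_x = [[m[i] * f for f in x[i]] for i in range(n)]
    # An edge survives in proportion to both endpoint masks: A'_ij = m_i * A_ij * m_j.
    masked_adj = [[m[i] * adj[i][j] * m[j] for j in range(n)] for i in range(n)]
    return masked_adj, masked_x
```

With a hard mask [1.0, 0.0] and use_sigmoid=False, node 1 loses its features and every edge incident to it, which is the "deactivated node" behavior that node_thresh produces after training.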
Reference
If you use histocartography in your projects, please cite the following:
@inproceedings{pati2021,
title = {Hierarchical Graph Representations for Digital Pathology},
author = {Pushpak Pati and Guillaume Jaume and Antonio Foncubierta and Florinda Feroce and Anna Maria Anniciello and Giosuè Scognamiglio and Nadia Brancati and Maryse Fiche and Estelle Dubruc and Daniel Riccio and Maurizio Di Bonito and Giuseppe De Pietro and Gerardo Botti and Jean-Philippe Thiran and Maria Frucci and Orcun Goksel and Maria Gabrani},
booktitle = {https://arxiv.org/pdf/2102.11057},
year = {2021}
}