CounterNet: End-to-End Training for Prediction-Aware Counterfactual Explanations


CounterNet is an end-to-end framework that integrates predictive model training and counterfactual explanation generation into a single pipeline. Training the predictive model and the counterfactual generator together improves the validity of the explanations at a lower cost of change. A block-wise coordinate descent procedure improves convergence, and the slides report 100% validity, a 9.8% improvement in cost of change over baselines, and 3x faster generation.


Uploaded on Apr 02, 2024 | 0 Views



Presentation Transcript


  1. CounterNet: End-to-End Training of Prediction Aware Counterfactual Explanations (ACM KDD 2023)

  2. Intro to counterfactual explanation: A counterfactual explanation explains a model by offering a hypothetical sample $x'$ that is very close to an original sample $x$ but belongs to a different class.
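
One common way to write this definition as an optimization problem is sketched below. The slide gives only the verbal definition, so the symbols are assumptions: $f$ stands for the predictive model and $d$ for some distance function between samples.

```latex
x' = \arg\min_{\tilde{x}} \; d(x, \tilde{x})
\quad \text{subject to} \quad f(\tilde{x}) \neq f(x)
```

Under this reading, "very close" corresponds to a small $d(x, x')$ and "belongs to a different class" corresponds to the constraint.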

  3. Motivation: Most counterfactual explanation (CF) generation techniques are post-hoc: they are additional algorithms, either non-parametric or parametric (generative models), that generate CFs for an already trained predictive model and are therefore model-agnostic. Because the process of generating CFs is uninformed by the training of the predictive model, the CF generator does not know the decision boundary, which leads to misalignment between the predictions and the explanations. Most CF generation techniques are also time intensive, which negatively impacts runtime.

  4. Motivation

  5. Contribution: CounterNet is an end-to-end learning framework that integrates model training and CF generation into a single pipeline, ensuring a better cost-invalidity tradeoff. The predictive model and the CF generator are trained together. A feedback system creates valid CFs. A block-wise coordinate descent procedure improves convergence. CounterNet generates CF explanations with 100% validity and low cost of change (9.8% improvement over baselines) at 3x speed.

  6. Proposed architecture (diagram). Callouts: one component is kept during training and removed during testing; categorical features are one-hot encoded; the design is meant to ensure that input $x$ and CF example $x'$ are opposite (in predicted class).

  7. Optimization: Training balances three objectives. (i) Predictive accuracy: the predictor network should output accurate predictions $\hat{y}_x$. (ii) Counterfactual validity: CF examples $x'$ produced by the CF generator network should be valid, i.e., they must change the class ($\hat{y}_x + \hat{y}_{x'} = 1$). (iii) Minimizing cost of change: minimal modifications should be required to change input instance $x$ into CF example $x'$. MSE loss-based training seems to be less sensitive to randomness in initialization, more robust to noise, and less prone to overfitting on a wide variety of learning tasks (as compared to cross-entropy loss).
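
The three objectives on this slide can be sketched as mean-squared-error terms. The transcript does not include the exact formulas, so the forms below are an assumption derived from the slide's wording (MSE-based training, validity meaning that the two predictions sum to 1). The function names are hypothetical.

```python
def mse(a, b):
    """Mean squared error between two equal-length sequences of numbers."""
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)


def prediction_loss(y, y_hat_x):
    # (i) Predictive accuracy: predictions on x should match the labels.
    return mse(y, y_hat_x)


def validity_loss(y_hat_x, y_hat_cf):
    # (ii) Validity: push y_hat_x + y_hat_cf toward 1, i.e. the prediction
    # on x should equal one minus the prediction on the CF example.
    return mse(y_hat_x, [1.0 - p for p in y_hat_cf])


def cost_loss(x, x_cf):
    # (iii) Cost of change: squared difference between each instance and
    # its CF example, averaged over features and then over instances.
    return sum(mse(row, row_cf) for row, row_cf in zip(x, x_cf)) / len(x)


# Tiny illustrative batch (made-up numbers, two instances, two features).
y = [1.0, 0.0]
y_hat_x = [0.9, 0.2]
y_hat_cf = [0.1, 0.8]
x = [[0.5, 1.0], [0.2, 0.0]]
x_cf = [[0.5, 0.0], [0.4, 0.0]]

losses = (
    prediction_loss(y, y_hat_x),
    validity_loss(y_hat_x, y_hat_cf),
    cost_loss(x, x_cf),
)
```

In this batch the validity term is essentially zero because each pair of predictions already sums to 1, while the cost term is dominated by the first instance, whose second feature changes from 1.0 to 0.0.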

  8. Optimization: The conventional way of solving the optimization problem is gradient descent with backpropagation over all objectives at once. This leads to two issues. (1) Poor convergence in training: the gradient across all three loss objectives fluctuates drastically. (2) Proneness to adversarial examples: optimizing $\mathcal{L}_2$ with respect to the predictor weights $\theta_f$ decreases the robustness of the predictor $f(\cdot)$, which increases its vulnerability to adversarial examples.

  9. Optimization: To avoid the two issues of conventional gradient descent with backpropagation (poor convergence in training and proneness to adversarial examples), the authors propose a block-wise coordinate descent procedure. For each minibatch of $m$ data points $\{x^{(i)}, y^{(i)}\}^m$, two gradient updates are applied to the network through backpropagation: (i) one optimizing predictive accuracy ($\mathcal{L}_1$), and (ii) one optimizing the validity and proximity of CF generation ($\mathcal{L}_2$ and $\mathcal{L}_3$).
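
The two-update procedure can be illustrated with a deliberately tiny sketch. Everything here is an assumption made for illustration: a scalar logistic predictor, a scalar CF generator with no shared encoder, MSE-style losses, hypothetical loss weights `lam2` and `lam3`, and finite-difference gradients standing in for backpropagation. The predictor is held fixed during the second update, in line with the "freezing the predictor at the second stage" wording used in the ablation list.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def predict(theta_f, x):
    # Toy predictor f: logistic model on a single feature.
    w, b = theta_f
    return sigmoid(w * x + b)


def generate_cf(theta_g, x, y_hat):
    # Toy CF generator g: sees the input and the predictor's output.
    a, c, d = theta_g
    return sigmoid(a * x + c * y_hat + d)


def mse(u, v):
    return sum((p - q) ** 2 for p, q in zip(u, v)) / len(u)


def loss_pred(theta_f, xs, ys):
    # Objective (i): predictive accuracy.
    return mse(ys, [predict(theta_f, x) for x in xs])


def loss_cf(theta_g, theta_f, xs, lam2=1.0, lam3=0.1):
    # Objectives (ii) and (iii): validity plus cost of change.
    y_hat = [predict(theta_f, x) for x in xs]
    x_cf = [generate_cf(theta_g, x, p) for x, p in zip(xs, y_hat)]
    y_cf = [predict(theta_f, x) for x in x_cf]
    validity = mse(y_hat, [1.0 - p for p in y_cf])
    cost = mse(xs, x_cf)
    return lam2 * validity + lam3 * cost


def num_grad(fn, params, eps=1e-5):
    # Central finite differences; a stand-in for backpropagation.
    grad = []
    for i in range(len(params)):
        hi = params[:i] + [params[i] + eps] + params[i + 1:]
        lo = params[:i] + [params[i] - eps] + params[i + 1:]
        grad.append((fn(hi) - fn(lo)) / (2 * eps))
    return grad


def train_step(theta_f, theta_g, xs, ys, lr=0.5):
    # Update 1: predictor parameters only, on the prediction loss.
    g_f = num_grad(lambda t: loss_pred(t, xs, ys), theta_f)
    theta_f = [p - lr * g for p, g in zip(theta_f, g_f)]
    # Update 2: generator parameters only; theta_f is held fixed here.
    g_g = num_grad(lambda t: loss_cf(t, theta_f, xs), theta_g)
    theta_g = [p - lr * g for p, g in zip(theta_g, g_g)]
    return theta_f, theta_g


# One "minibatch" of made-up data, reused for a few hundred steps.
xs = [0.1, 0.2, 0.8, 0.9]
ys = [0.0, 0.0, 1.0, 1.0]
theta_f, theta_g = [0.0, 0.0], [0.0, 0.0, 0.0]
for _ in range(200):
    theta_f, theta_g = train_step(theta_f, theta_g, xs, ys)
```

Splitting the step this way means the validity and cost gradients never reach the predictor's weights, which is the property the slide ties to better convergence and robustness.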

  10. Experiment, Comparison: CounterNet is compared against eight state-of-the-art CF explanation methods. VanillaCF: a non-parametric post-hoc method that generates CF examples by optimizing CF validity and proximity. DiverseCF, ProtoCF, and UncertainCF: non-parametric methods that optimize for diversity, consistency, and uncertainty. VAE-CF, CounteRGAN, C-CHVAE, and VCNet: parametric methods that use generative models (i.e., VAE or GAN) to generate CF examples. Post-hoc methods require a trained predictive model as input. For a fair comparison, only the encoder and the predictor have been optimized for the $\mathcal{L}_1$ loss (to improve accuracy).

  11. Experiment, Datasets: CounterNet is extensively tested using four datasets. Adult: predicts whether an individual's income reaches $50K (Y=1) or not (Y=0) using demographic data. Credit: uses historical payments to predict default of payment (Y=1) or not (Y=0). HELOC: predicts whether a homeowner qualifies for a line of credit (Y=1) or not (Y=0). OULAD: predicts whether MOOC students drop out (Y=1) or not (Y=0), based on their online learning logs.

  12. Experiment, Evaluation Metrics: Five evaluation metrics are used to validate performance. Validity: the fraction of input data points for which $\hat{y}_x + \hat{y}_{x'} = 1$; high validity is desirable and indicates fidelity and effectiveness. Proximity: the $\ell_1$ norm distance between $x$ and $x'$ divided by the number of features; lower is better. Sparsity: the number of feature changes (i.e., the $\ell_0$ norm) between $x$ and $x'$. Proximity and sparsity are the cost measures. Manifold distance: the $\ell_1$ distance to the $k$-nearest neighbor of $x'$ ($k = 1$); a low manifold distance is desirable, as closeness to the training data manifold indicates realistic CF explanations. Runtime: the time to generate CF examples, not the training time.
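
The four non-runtime metrics on this slide can be computed with a short sketch like the one below. Several details are assumptions rather than facts from the slides: predictions are binarized at a 0.5 threshold before checking validity, per-instance scores are averaged over the batch, and the function names are hypothetical.

```python
def validity(y_hat_x, y_hat_cf, threshold=0.5):
    """Fraction of instances whose binarized predictions sum to 1."""
    flips = [
        int(p >= threshold) + int(q >= threshold) == 1
        for p, q in zip(y_hat_x, y_hat_cf)
    ]
    return sum(flips) / len(flips)


def l1(a, b):
    """L1 distance between two feature vectors."""
    return sum(abs(u - v) for u, v in zip(a, b))


def proximity(x, x_cf):
    """L1 distance divided by the number of features, averaged over instances."""
    scores = [l1(row, row_cf) / len(row) for row, row_cf in zip(x, x_cf)]
    return sum(scores) / len(scores)


def sparsity(x, x_cf, tol=1e-9):
    """Number of changed features (L0 norm), averaged over instances."""
    counts = [
        sum(abs(u - v) > tol for u, v in zip(row, row_cf))
        for row, row_cf in zip(x, x_cf)
    ]
    return sum(counts) / len(counts)


def manifold_distance(x_cf, train, k=1):
    """Mean L1 distance from each CF example to its k-th nearest training point."""
    dists = [sorted(l1(row_cf, t) for t in train)[k - 1] for row_cf in x_cf]
    return sum(dists) / len(dists)


# Made-up batch: two instances with two features each.
x = [[0.0, 1.0], [1.0, 1.0]]
x_cf = [[0.5, 1.0], [0.0, 0.0]]
y_hat_x = [0.9, 0.2]
y_hat_cf = [0.3, 0.4]
train = [[0.5, 0.5], [0.0, 0.0]]

report = {
    "validity": validity(y_hat_x, y_hat_cf),            # 0.5
    "proximity": proximity(x, x_cf),                    # 0.625
    "sparsity": sparsity(x, x_cf),                      # 1.5
    "manifold_distance": manifold_distance(x_cf, train),  # 0.25
}
```

Only the first instance counts as valid here: its prediction flips from class 1 to class 0, while the second instance stays in class 0 on both sides.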

  13. Results: Comparison of validity, proximity, sparsity, and manifold distance.

  14. Results: Cost-invalidity tradeoff and runtime comparison (inference time only). Invalidity is the percentage of invalid CFs (1 - validity); the reported values are averages.

  15. Additional Results: Five ablation studies. (1) CounterNet-BCE: replaces the MSE-based $\mathcal{L}_1$ and $\mathcal{L}_2$ losses with binary cross-entropy loss. (2) CounterNet-SingleBP: no block-wise backpropagation. (3) CounterNet-Separate: a separate predictor $f: X \to Y$ and CF generator $g: X \to X$, such that $f$ and $g$ share no identical components. (4) CounterNet-NoPass-$\hat{y}_x$: excludes passing $\hat{y}_x$. (5) CounterNet-Posthoc: trains the predictor on the entire training dataset, then optimizes the CF generator while the trained predictor is frozen.

  16. Additional Results

  17. Additional Results: More ablation studies. (1) CounterNet-NoFreeze: without freezing the predictor at the second stage of the optimization. (2) Base model: just a predictor.
