Usage
```python
from explainableai import XAIWrapper
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
from sklearn.neural_network import MLPClassifier

# Initialize your models
models = {
    'Random Forest': RandomForestClassifier(n_estimators=100, random_state=42),
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'XGBoost': XGBClassifier(n_estimators=100, random_state=42),
    'Neural Network': MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=1000, random_state=42),
}

# Create XAIWrapper instance
xai = XAIWrapper()

# Fit the models and run analysis
# (X_train, y_train: your training features and labels, defined beforehand)
xai.fit(models, X_train, y_train)
results = xai.analyze()

# Generate report
xai.generate_report()

# Make and explain predictions (input_data: the sample you want explained)
prediction, probabilities, explanation = xai.explain_prediction(input_data)
```
Breaking Changes
- The `fit` method now requires a dictionary mapping model names to model instances instead of a single model.
- Some visualization function signatures have been updated to accommodate multiple models.
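If existing code passes a single estimator to `fit`, one way to migrate is to wrap it in a one-entry dictionary first. The helper below, `as_model_dict`, is a hypothetical convenience function and not part of `explainableai`; the default name `"Model"` is likewise an assumption.

```python
from typing import Any, Dict, Mapping


def as_model_dict(models: Any, default_name: str = "Model") -> Dict[str, Any]:
    """Return a {name: estimator} dict for the dictionary-based fit().

    Hypothetical helper (not part of explainableai): a mapping is copied
    as-is, and a single estimator is wrapped under `default_name`.
    """
    if isinstance(models, Mapping):
        return dict(models)
    return {default_name: models}


# Usage (assumes xai, X_train and y_train are already defined):
# xai.fit(as_model_dict(RandomForestClassifier()), X_train, y_train)
```

Code that already builds a dictionary passes through unchanged, so the helper can be applied unconditionally during a migration.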
We encourage users to update to this version for access to these new features and improvements. As always, please report any issues or suggestions through our GitHub issue tracker.
Updates
ExplainableAI
ExplainableAI is a Python package that provides a comprehensive suite of tools for explainable machine learning. It wraps around popular machine learning models and offers various techniques to interpret and explain their predictions.
Features
1. **Model Agnostic**: Works with any scikit-learn compatible model.
2. **Automated EDA**: Performs exploratory data analysis on the input dataset.
3. **Feature Importance**: Calculates and visualizes feature importance.
4. **SHAP Values**: Computes SHAP (SHapley Additive exPlanations) values for in-depth feature impact analysis.
5. **Partial Dependence Plots**: Generates partial dependence plots for the most important features.
6. **Learning Curve**: Plots learning curves to show how model performance changes as the training set size grows.
7. **ROC and Precision-Recall Curves**: For classification tasks, generates ROC and Precision-Recall curves.
8. **Correlation Heatmap**: Visualizes feature correlations.
9. **Cross-Validation**: Performs k-fold cross-validation.
10. **LLM-powered Explanations**: Utilizes Google's Gemini model to provide natural language explanations of model results and individual predictions.
11. **PDF Report Generation**: Automatically generates a comprehensive PDF report of all analyses.
12. **Interactive Predictions**: Allows users to input data and receive explained predictions.
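This README does not show how `explainableai` computes these analyses internally. As a rough, library-independent sketch of what several of them measure (cross-validation, feature importance, ROC AUC, and the correlation matrix behind the heatmap), the snippet below calls scikit-learn directly rather than the package's API. Permutation importance is used as one possible importance measure (the package may use a different one), and the breast-cancer dataset is an arbitrary stand-in.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_score, train_test_split

# Built-in binary classification dataset, used here only for illustration
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

model = RandomForestClassifier(n_estimators=100, random_state=42)

# Cross-validation: 5-fold accuracy on the training split
cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring="accuracy")

model.fit(X_train, y_train)

# Feature importance: mean drop in test accuracy when each feature is shuffled
perm = permutation_importance(model, X_test, y_test, n_repeats=5, random_state=42)
top_features = X.columns[np.argsort(perm.importances_mean)[::-1][:5]]

# ROC AUC from the predicted probability of the positive class
roc_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])

# Pairwise feature correlations, i.e. the matrix a heatmap would visualize
corr = X.corr()

print(f"CV accuracy: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}")
print(f"Test ROC AUC: {roc_auc:.3f}")
print("Top features by permutation importance:", list(top_features))
```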
Implementation Details
Core Components
1. **XAIWrapper**: The main class that encapsulates all functionality.
- Handles data preprocessing, model fitting, and various analyses.
- Integrates all explainability techniques.
2. **ReportGenerator**: Generates PDF reports using ReportLab.
- Creates structured reports with text, tables, and visualizations.
3. **Visualization Module**: Contains functions for creating various plots and visualizations.
- Uses matplotlib and seaborn for static visualizations.
4. **Model Evaluation**: Includes functions for assessing model performance.
- Computes metrics such as accuracy, F1-score, MSE, and R².
5. **Feature Analysis**: Implements feature importance and SHAP value calculations.
6. **LLM Integration**: Uses Google's Gemini model for natural language explanations.
- Interprets model results and individual predictions.
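To make the Model Evaluation component concrete, here is a minimal sketch of computing the metrics it lists with scikit-learn. The function name `evaluate_model`, its `task` argument, and the dict it returns are hypothetical and not taken from `model_evaluation.py`; weighted F1 is one assumed choice for multiclass targets.

```python
from sklearn.metrics import accuracy_score, f1_score, mean_squared_error, r2_score


def evaluate_model(y_true, y_pred, task="classification"):
    """Return a dict of metrics for the given task.

    Hypothetical helper: the metric set mirrors the one listed above,
    but the name and return format are illustrative only.
    """
    if task == "classification":
        return {
            "accuracy": accuracy_score(y_true, y_pred),
            # Weighted F1 averages per-class F1 by class support (works for multiclass)
            "f1": f1_score(y_true, y_pred, average="weighted"),
        }
    if task == "regression":
        return {
            "mse": mean_squared_error(y_true, y_pred),
            "r2": r2_score(y_true, y_pred),
        }
    raise ValueError(f"Unknown task: {task!r}")


print(evaluate_model([0, 1, 1, 0], [0, 1, 0, 0]))
print(evaluate_model([3.0, 2.5, 4.0], [2.8, 2.7, 4.1], task="regression"))
```

Branching on the task keeps classification metrics from being computed on continuous targets, where they would fail or be meaningless.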
Key Files
- `core.py`: Contains the XAIWrapper class.
- `report_generation.py`: Implements the ReportGenerator class.
- `visualizations.py`: Houses all visualization functions.
- `model_evaluation.py`: Contains model evaluation metrics.
- `feature_analysis.py`: Implements feature importance and SHAP calculations.
- `llm_explanations.py`: Handles integration with the Gemini model.
Workflow
1. **Data Preprocessing**:
- Handles categorical and numerical features.
- Performs imputation and scaling.
2. **Model Fitting**:
- Fits each provided model to the preprocessed data.
3. **Analysis**:
- Calculates feature importance.
- Generates various visualizations.
- Computes SHAP values.
- Performs cross-validation.
4. **LLM Explanation**:
- Sends analysis results to Gemini for interpretation.
- Generates natural language explanations.
5. **Report Generation**:
- Compiles all analyses and visualizations into a PDF report.
6. **Interactive Predictions**:
- Allows users to input data for new predictions.
- Provides explanations for individual predictions.
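The README does not spell out the exact preprocessing choices. The snippet below is an illustrative plain scikit-learn sketch of steps 1 to 3 and the prediction half of step 6. It assumes median imputation plus standard scaling for numeric columns, most-frequent imputation plus one-hot encoding for categorical columns, and a toy dataset with made-up columns (`age`, `income`, `city`). The package's real defaults may differ. Steps 4 and 5 (Gemini and PDF generation) are omitted.

```python
import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# Toy dataset (hypothetical columns) with gaps in a numeric and a categorical column
rng = np.random.default_rng(0)
n = 200
X = pd.DataFrame({
    "age": rng.integers(18, 70, n).astype(float),
    "income": rng.normal(50_000, 15_000, n),
    "city": pd.Series(rng.choice(["A", "B", "C"], n), dtype=object),
})
X.loc[::10, "age"] = np.nan
X.loc[::15, "city"] = np.nan
y = (X["income"] > 50_000).astype(int)

numeric_cols = X.select_dtypes(include="number").columns.tolist()
categorical_cols = X.select_dtypes(exclude="number").columns.tolist()

# 1. Data preprocessing: impute + scale numeric, impute + one-hot encode categorical
preprocessor = ColumnTransformer([
    ("num", Pipeline([
        ("impute", SimpleImputer(strategy="median")),
        ("scale", StandardScaler()),
    ]), numeric_cols),
    ("cat", Pipeline([
        ("impute", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ]), categorical_cols),
])

# 2. Model fitting: every model in the dict gets its own copy of the preprocessor
models = {"Logistic Regression": LogisticRegression(max_iter=1000)}
pipelines = {
    name: Pipeline([("prep", clone(preprocessor)), ("model", model)])
    for name, model in models.items()
}

# 3. Analysis (cross-validation part): 5-fold scores for each full pipeline
cv_results = {name: cross_val_score(pipe, X, y, cv=5) for name, pipe in pipelines.items()}
for pipe in pipelines.values():
    pipe.fit(X, y)

# 6. Prediction: a new row passes through the same fitted preprocessing
new_row = pd.DataFrame({
    "age": [35.0],
    "income": [62_000.0],
    "city": pd.Series(["B"], dtype=object),
})
pipe = pipelines["Logistic Regression"]
prediction = pipe.predict(new_row)
probabilities = pipe.predict_proba(new_row)
print(prediction, probabilities)
```

Bundling preprocessing and model in one `Pipeline` means cross-validation refits the imputers and scaler inside each fold, so no statistics leak from the held-out data.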
Installation
```bash
pip install explainableai
```
Usage
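```python
from explainableai import XAIWrapper
from sklearn.ensemble import RandomForestClassifier

# Initialize your model
model = RandomForestClassifier()

# Create XAIWrapper instance
xai = XAIWrapper()

# Fit the model and run analysis
# fit() expects a dictionary of models, so wrap a single model in one
# (X_train, y_train: your training features and labels, defined beforehand)
xai.fit({'Random Forest': model}, X_train, y_train)
results = xai.analyze()

# Generate report
xai.generate_report()

# Make and explain predictions (input_data: the sample you want explained)
prediction, probabilities, explanation = xai.explain_prediction(input_data)
```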