|
59 | 59 | "from sklearn.decomposition import PCA\n", |
60 | 60 | "from sklearn.ensemble import IsolationForest, RandomForestClassifier\n", |
61 | 61 | "\n", |
| 62 | + "import shap # Explainable AI tool\n", |
| 63 | + "\n", |
62 | 64 | "import matplotlib.pyplot as plot" |
63 | 65 | ] |
64 | 66 | }, |
|
410 | 412 | " anomaly_label_column: str = 'anomalyLabel',\n", |
411 | 413 | " anomaly_score_column: str = 'anomalyScore',\n", |
412 | 414 | ") -> pd.DataFrame:\n", |
413 | | - " isolation_forest = IsolationForest(n_estimators=200, contamination=0.05, random_state=42)\n", |
| 415 | + " isolation_forest = IsolationForest(n_estimators=200, contamination='auto', random_state=42)\n", |
414 | 416 | " anomaly_score = isolation_forest.fit_predict(prepared_features)\n", |
415 | 417 | "\n", |
416 | | - " original_features[anomaly_label_column] = anomaly_score * -1 # 1 = anomaly, 0 = no anomaly\n", |
| 418 | + " original_features[anomaly_label_column] = (anomaly_score == -1).astype(int) # 1 = anomaly, 0 = normal\n", |
417 | 419 | " original_features[anomaly_score_column] = isolation_forest.decision_function(prepared_features) * -1 # higher = more anomalous\n", |
418 | 420 | " return original_features" |
419 | 421 | ] |
|
440 | 442 | " anomaly_label_column: str = \"anomalyLabel\",\n", |
441 | 443 | " anomaly_score_column: str = \"anomalyScore\"\n", |
442 | 444 | ") -> pd.DataFrame:\n", |
443 | | - " anomalies = anomaly_detected_features[anomaly_detected_features[anomaly_label_column] == -1]\n", |
| 445 | + " anomalies = anomaly_detected_features[anomaly_detected_features[anomaly_label_column] == 1]\n", |
444 | 446 | " return anomalies.sort_values(by=anomaly_score_column, ascending=False).reset_index(drop=True).head(10)" |
445 | 447 | ] |
446 | 448 | }, |
|
456 | 458 | }, |
457 | 459 | { |
458 | 460 | "cell_type": "markdown", |
459 | | - "id": "efa822ca", |
| 461 | + "id": "a3936d79", |
460 | 462 | "metadata": {}, |
461 | 463 | "source": [ |
462 | | - "### 1.4 Plot the 20 most influential features\n", |
463 | | - "\n", |
464 | | - "Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score." |
465 | | - ] |
466 | | - }, |
467 | | - { |
468 | | - "cell_type": "code", |
469 | | - "execution_count": null, |
470 | | - "id": "24427977", |
471 | | - "metadata": {}, |
472 | | - "outputs": [], |
473 | | - "source": [ |
474 | | - "def get_feature_importances(\n", |
475 | | - " anomaly_detected_features: pd.DataFrame, \n", |
476 | | - " prepared_features: numpy_typing.NDArray,\n", |
477 | | - " anomaly_label_column: str = \"anomalyLabel\",\n", |
478 | | - ") -> numpy_typing.NDArray:\n", |
479 | | - " \"\"\"\n", |
480 | | - " Use Random Forest as a proxy model to find out which are the most important features for the anomaly detection model (Isolation Forest).\n", |
481 | | - " This helps to see if embedding components dominate (top 10 filled with them), and then tune accordingly.\n", |
482 | | - " \"\"\"\n", |
483 | | - " # Use IsolationForest labels as a \"pseudo ground truth\"\n", |
484 | | - " y_pseudo = (anomaly_detected_features[anomaly_label_column] == -1).astype(int)\n", |
485 | | - "\n", |
486 | | - " # Fit classifier to match the IF model\n", |
487 | | - " proxy_random_forest = RandomForestClassifier(n_estimators=100, random_state=42)\n", |
488 | | - " proxy_random_forest.fit(prepared_features, y_pseudo)\n", |
489 | | - "\n", |
490 | | - " return proxy_random_forest.feature_importances_" |
491 | | - ] |
492 | | - }, |
493 | | - { |
494 | | - "cell_type": "code", |
495 | | - "execution_count": null, |
496 | | - "id": "97b21d49", |
497 | | - "metadata": {}, |
498 | | - "outputs": [], |
499 | | - "source": [ |
500 | | - "java_package_anomaly_detection_importances = get_feature_importances(java_package_anomaly_detection_features, java_package_anomaly_detection_features_prepared)\n", |
501 | | - "java_package_anomaly_detection_importances_series = pd.Series(java_package_anomaly_detection_importances, index=java_package_anomaly_detection_feature_names).sort_values(ascending=False)\n", |
502 | | - "#display(java_type_anomaly_detection_importances_series.head(10))" |
503 | | - ] |
504 | | - }, |
505 | | - { |
506 | | - "cell_type": "code", |
507 | | - "execution_count": null, |
508 | | - "id": "14d0b03e", |
509 | | - "metadata": {}, |
510 | | - "outputs": [], |
511 | | - "source": [ |
512 | | - "def plot_feature_importances(feature_importances_series: pd.Series, title_prefix: str) -> None:\n", |
513 | | - " feature_importances_series.head(20).plot(\n", |
514 | | - " kind='barh',\n", |
515 | | - " figsize=(10, 6),\n", |
516 | | - " color='skyblue',\n", |
517 | | - " title=f\"{title_prefix}: Top 20 Feature Importances (Random Forest Proxy)\",\n", |
518 | | - " xlabel=\"Importance\"\n", |
519 | | - " )\n", |
520 | | - " plot.gca().invert_yaxis() # Most important feature at the top\n", |
521 | | - " plot.tight_layout()\n", |
522 | | - " plot.show()" |
523 | | - ] |
524 | | - }, |
525 | | - { |
526 | | - "cell_type": "code", |
527 | | - "execution_count": null, |
528 | | - "id": "974a2bae", |
529 | | - "metadata": {}, |
530 | | - "outputs": [], |
531 | | - "source": [ |
532 | | - "plot_feature_importances(java_package_anomaly_detection_importances_series, title_prefix='Java Packages')" |
533 | | - ] |
534 | | - }, |
535 | | - { |
536 | | - "cell_type": "markdown", |
537 | | - "id": "c9dd6246", |
538 | | - "metadata": {}, |
539 | | - "source": [ |
540 | | - "### 1.5. Plot anomalies\n", |
| 464 | + "### 1.4. Plot anomalies\n", |
541 | 465 | "\n", |
542 | 466 | "Plots clustered nodes and highlights anomalies." |
543 | 467 | ] |
544 | 468 | }, |
545 | 469 | { |
546 | 470 | "cell_type": "code", |
547 | 471 | "execution_count": null, |
548 | | - "id": "ab1e76ab", |
| 472 | + "id": "c5604735", |
549 | 473 | "metadata": {}, |
550 | 474 | "outputs": [], |
551 | 475 | "source": [ |
|
640 | 564 | { |
641 | 565 | "cell_type": "code", |
642 | 566 | "execution_count": null, |
643 | | - "id": "aea29887", |
| 567 | + "id": "61ec7904", |
644 | 568 | "metadata": {}, |
645 | 569 | "outputs": [], |
646 | 570 | "source": [ |
647 | 571 | "plot_anomalies(java_package_anomaly_detection_features, title_prefix=\"Java Package Anomalies\")" |
648 | 572 | ] |
649 | 573 | }, |
| 574 | + { |
| 575 | + "cell_type": "markdown", |
| 576 | + "id": "efa822ca", |
| 577 | + "metadata": {}, |
| 578 | + "source": [ |
| 579 | + "### 1.5 Print the 20 most influential features\n", |
| 580 | + "\n", |
| 581 | + "Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score." |
| 582 | + ] |
| 583 | + }, |
| 584 | + { |
| 585 | + "cell_type": "code", |
| 586 | + "execution_count": null, |
| 587 | + "id": "24427977", |
| 588 | + "metadata": {}, |
| 589 | + "outputs": [], |
| 590 | + "source": [ |
| 591 | + "def get_proxy_random_forest(\n", |
| 592 | + " anomaly_detected_features: pd.DataFrame, \n", |
| 593 | + " prepared_features: numpy_typing.NDArray,\n", |
| 594 | + " anomaly_label_column: str = \"anomalyLabel\",\n", |
| 595 | + ") -> RandomForestClassifier:\n", |
| 596 | + " \"\"\"\n", |
| 597 | + " Use Random Forest as a proxy model to find out which are the most important features for the anomaly detection model (Isolation Forest).\n", |
| 598 | + " This helps to see if embedding components dominate (top 10 filled with them), and then tune accordingly.\n", |
| 599 | + " \"\"\"\n", |
| 600 | + " # Use IsolationForest labels as a \"pseudo ground truth\"\n", |
| 601 | + " y_pseudo = anomaly_detected_features[anomaly_label_column]\n", |
| 602 | + "\n", |
| 603 | + " # Fit classifier to match the IF model\n", |
| 604 | + " proxy_random_forest = RandomForestClassifier(n_estimators=100, random_state=42)\n", |
| 605 | + " proxy_random_forest.fit(prepared_features, y_pseudo)\n", |
| 606 | + "\n", |
| 607 | + " return proxy_random_forest" |
| 608 | + ] |
| 609 | + }, |
| 610 | + { |
| 611 | + "cell_type": "code", |
| 612 | + "execution_count": null, |
| 613 | + "id": "97b21d49", |
| 614 | + "metadata": {}, |
| 615 | + "outputs": [], |
| 616 | + "source": [ |
| 617 | + "java_package_proxy_random_forest = get_proxy_random_forest(java_package_anomaly_detection_features, java_package_anomaly_detection_features_prepared)\n", |
| 618 | + "java_package_anomaly_detection_importances = java_package_proxy_random_forest.feature_importances_\n", |
| 619 | + "java_package_anomaly_detection_importances_series = pd.Series(java_package_anomaly_detection_importances, index=java_package_anomaly_detection_feature_names).sort_values(ascending=False)\n", |
| 620 | + "print(java_package_anomaly_detection_importances_series.head(10))" |
| 621 | + ] |
| 622 | + }, |
| 623 | + { |
| 624 | + "cell_type": "raw", |
| 625 | + "id": "14d0b03e", |
| 626 | + "metadata": { |
| 627 | + "vscode": { |
| 628 | + "languageId": "raw" |
| 629 | + } |
| 630 | + }, |
| 631 | + "source": [ |
| 632 | + "# TODO Remove if not used anymore because of a better plot using SHAP\n", |
| 633 | + "def plot_feature_importances(feature_importances_series: pd.Series, title_prefix: str) -> None:\n", |
| 634 | + " feature_importances_series.head(20).plot(\n", |
| 635 | + " kind='barh',\n", |
| 636 | + " figsize=(10, 6),\n", |
| 637 | + " color='skyblue',\n", |
| 638 | + " title=f\"{title_prefix}: Top 20 Feature Importances (Random Forest Proxy)\",\n", |
| 639 | + " xlabel=\"Importance\"\n", |
| 640 | + " )\n", |
| 641 | + " plot.gca().invert_yaxis() # Most important feature at the top\n", |
| 642 | + " plot.tight_layout()\n", |
| 643 | + " plot.show()\n", |
| 644 | + "\n", |
| 645 | + "plot_feature_importances(java_package_anomaly_detection_importances_series, title_prefix='Java Packages')" |
| 646 | + ] |
| 647 | + }, |
| 648 | + { |
| 649 | + "cell_type": "markdown", |
| 650 | + "id": "db03216e", |
| 651 | + "metadata": {}, |
| 652 | + "source": [ |
| 653 | + "### 1.6 Use SHAP to explain the Isolation Forest Model" |
| 654 | + ] |
| 655 | + }, |
| 656 | + { |
| 657 | + "cell_type": "code", |
| 658 | + "execution_count": null, |
| 659 | + "id": "e8c5905d", |
| 660 | + "metadata": {}, |
| 661 | + "outputs": [], |
| 662 | + "source": [ |
| 663 | + "def explain_anomalies_with_shap(\n", |
| 664 | + " random_forest_model: RandomForestClassifier,\n", |
| 665 | + " anomaly_detected_features: pd.DataFrame,\n", |
| 666 | + " prepared_features: numpy_typing.NDArray,\n", |
| 667 | + " feature_names: list[str],\n", |
| 668 | + " title_prefix: str = \"\",\n", |
| 669 | + " anomaly_label_column: str = \"anomalyLabel\",\n", |
| 670 | + ") -> None:\n", |
| 671 | + " \"\"\"\n", |
| 672 | + " Explain anomalies using SHAP values.\n", |
| 673 | + " \"\"\"\n", |
| 674 | + "\n", |
| 675 | + " # Use TreeExplainer for Random Forest\n", |
| 676 | + " explainer = shap.TreeExplainer(random_forest_model)\n", |
| 677 | + " \n", |
| 678 | + " shap_values = explainer.shap_values(prepared_features)\n", |
| 679 | + " print(f\"Input shape: {anomaly_detected_features.shape}\")\n", |
| 680 | + " print(f\"SHAP shape: {np.shape(shap_values)}\")\n", |
| 681 | + "\n", |
| 682 | + " anomaly_rows = anomaly_detected_features[anomaly_label_column] == 1 # Filter anomalies\n", |
| 683 | + " shap.summary_plot(\n", |
| 684 | + " shap_values[anomaly_rows, :, 1], # Class 1 = anomaly\n", |
| 685 | + " prepared_features[anomaly_rows],\n", |
| 686 | + " feature_names=feature_names,\n", |
| 687 | + " plot_type=\"bar\",\n", |
| 688 | + " title=f\"{title_prefix} Anomalies explained using SHAP\",\n", |
| 689 | + " max_display=20,\n", |
| 690 | + " plot_size=(12, 6) # (width, height) in inches\n", |
| 691 | + " )\n", |
| 692 | + "\n", |
| 693 | + " # Create DataFrame of SHAP values for class 1 (anomaly)\n", |
| 694 | + " shap_df = pd.DataFrame(\n", |
| 695 | + " shap_values[:, :, 1], # select SHAP values for class 1\n", |
| 696 | + " columns=feature_names\n", |
| 697 | + " )\n", |
| 698 | + "\n", |
| 699 | + " # Add anomaly label to shap_df\n", |
| 700 | + " shap_df[\"anomalyLabel\"] = anomaly_detected_features[\"anomalyLabel\"].values\n", |
| 701 | + "\n", |
| 702 | + " # Filter to only anomalies using the boolean mask\n", |
| 703 | + " anomaly_shap_df = shap_df[anomaly_rows].drop(columns=[\"anomalyLabel\"])\n", |
| 704 | + "\n", |
| 705 | + " # Get top 3 features per anomaly (by absolute SHAP value)\n", |
| 706 | + " top3_per_anomaly = anomaly_shap_df.apply(\n", |
| 707 | + " lambda row: list(\n", |
| 708 | + " row.abs().sort_values(ascending=False).head(3).index\n", |
| 709 | + " ),\n", |
| 710 | + " axis=1\n", |
| 711 | + " )\n", |
| 712 | + "\n", |
| 713 | + " # Add top 3 influential features to every anomaly row\n", |
| 714 | + " anomaly_detected_features[\"anomalyLabelInfluentialFeatures\"] = None\n", |
| 715 | + " anomaly_detected_features.loc[\n", |
| 716 | + " anomaly_rows, \"anomalyLabelInfluentialFeatures\"\n", |
| 717 | + " ] = top3_per_anomaly.values\n", |
| 718 | + "\n", |
| 719 | + " display(anomaly_detected_features[anomaly_detected_features[\"anomalyLabel\"] == 1].sort_values(by='anomalyScore', ascending=False).head(10))\n" |
| 720 | + ] |
| 721 | + }, |
| 722 | + { |
| 723 | + "cell_type": "code", |
| 724 | + "execution_count": null, |
| 725 | + "id": "7d671e71", |
| 726 | + "metadata": {}, |
| 727 | + "outputs": [], |
| 728 | + "source": [ |
| 729 | + "explain_anomalies_with_shap(\n", |
| 730 | + " random_forest_model=java_package_proxy_random_forest,\n", |
| 731 | + " anomaly_detected_features=java_package_anomaly_detection_features, \n", |
| 732 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 733 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 734 | + " title_prefix=\"Java Package\"\n", |
| 735 | + ")" |
| 736 | + ] |
| 737 | + }, |
650 | 738 | { |
651 | 739 | "cell_type": "markdown", |
652 | 740 | "id": "5682bb64", |
|
767 | 855 | "display(get_top_10_anomalies(java_type_anomaly_detection_features))" |
768 | 856 | ] |
769 | 857 | }, |
| 858 | + { |
| 859 | + "cell_type": "markdown", |
| 860 | + "id": "68a00628", |
| 861 | + "metadata": {}, |
| 862 | + "source": [ |
| 863 | + "### 2.4. Plot anomalies\n", |
| 864 | + "\n", |
| 865 | + "Plots clustered nodes and highlights anomalies." |
| 866 | + ] |
| 867 | + }, |
| 868 | + { |
| 869 | + "cell_type": "code", |
| 870 | + "execution_count": null, |
| 871 | + "id": "4ecc9fb4", |
| 872 | + "metadata": {}, |
| 873 | + "outputs": [], |
| 874 | + "source": [ |
| 875 | + "plot_anomalies(java_type_anomaly_detection_features, title_prefix=\"Java Type Anomalies\")" |
| 876 | + ] |
| 877 | + }, |
770 | 878 | { |
771 | 879 | "cell_type": "markdown", |
772 | 880 | "id": "4e565f84", |
773 | 881 | "metadata": {}, |
774 | 882 | "source": [ |
775 | | - "### 2.4 Plot the 20 most influential features\n", |
| 883 | + "### 2.5 Print the 20 most influential features\n", |
776 | 884 | "\n", |
777 | 885 | "Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score." |
778 | 886 | ] |
779 | 887 | }, |
780 | 888 | { |
781 | 889 | "cell_type": "code", |
782 | 890 | "execution_count": null, |
783 | | - "id": "1b97f299", |
| 891 | + "id": "86945e66", |
784 | 892 | "metadata": {}, |
785 | 893 | "outputs": [], |
786 | 894 | "source": [ |
787 | | - "java_type_anomaly_detection_importances = get_feature_importances(java_type_anomaly_detection_features, java_type_anomaly_detection_features_prepared)\n", |
| 895 | + "java_type_proxy_random_forest = get_proxy_random_forest(java_type_anomaly_detection_features, java_type_anomaly_detection_features_prepared)\n", |
| 896 | + "java_type_anomaly_detection_importances = java_type_proxy_random_forest.feature_importances_\n", |
788 | 897 | "java_type_anomaly_detection_importances_series = pd.Series(java_type_anomaly_detection_importances, index=java_type_anomaly_detection_feature_names).sort_values(ascending=False)\n", |
789 | | - "#display(java_type_anomaly_detection_importances_series.head(10))\n", |
790 | | - "\n", |
791 | | - "plot_feature_importances(java_type_anomaly_detection_importances_series, title_prefix='Java Types')" |
| 898 | + "print(java_type_anomaly_detection_importances_series.head(10))" |
792 | 899 | ] |
793 | 900 | }, |
794 | 901 | { |
795 | 902 | "cell_type": "markdown", |
796 | | - "id": "68a00628", |
| 903 | + "id": "b12a0379", |
797 | 904 | "metadata": {}, |
798 | 905 | "source": [ |
799 | | - "### 2.5. Plot anomalies\n", |
800 | | - "\n", |
801 | | - "Plots clustered nodes and highlights anomalies." |
| 906 | + "### 2.6 Use SHAP to explain the Isolation Forest Model" |
802 | 907 | ] |
803 | 908 | }, |
804 | 909 | { |
805 | 910 | "cell_type": "code", |
806 | 911 | "execution_count": null, |
807 | | - "id": "4ecc9fb4", |
| 912 | + "id": "2d4b35c6", |
808 | 913 | "metadata": {}, |
809 | 914 | "outputs": [], |
810 | 915 | "source": [ |
811 | | - "plot_anomalies(java_type_anomaly_detection_features, title_prefix=\"Java Type Anomalies\")" |
| 916 | + "explain_anomalies_with_shap(\n", |
| 917 | + " random_forest_model=java_type_proxy_random_forest,\n", |
| 918 | + " anomaly_detected_features=java_type_anomaly_detection_features, \n", |
| 919 | + " prepared_features=java_type_anomaly_detection_features_prepared,\n", |
| 920 | + " feature_names=java_type_anomaly_detection_feature_names,\n", |
| 921 | + " title_prefix=\"Java Type\"\n", |
| 922 | + ")" |
812 | 923 | ] |
813 | 924 | } |
814 | 925 | ], |
|
0 commit comments