## ----include=FALSE------------------------------------------------------------
# Global knitr chunk defaults applied to every chunk of this vignette.
knitr::opts_chunk$set(
  error      = TRUE,  # keep knitting after an error and show it in the output
  collapse   = TRUE,  # merge source and output into a single block
  comment    = "#>",  # prefix for each line of printed output
  fig.height = 5,
  fig.width  = 7
)

## -----------------------------------------------------------------------------
library(fairGATE)
library(dplyr)
library(readxl)

# Load the UCI Adult dataset subset bundled with the fairGATE package.
data("adult_ready_small", package = "fairGATE")
adult_data <- adult_ready_small

# Trim surrounding whitespace in every character column and coerce the
# outcome column `income` to integer.
adult <- mutate(
  adult_data,
  across(where(is.character), trimws),
  income = as.integer(income)
)

## -----------------------------------------------------------------------------

# Columns to drop before modelling (e.g. unwanted numeric columns and those
# with high multicollinearity).
cols_to_drop <- c("subjectid", "Row.names")

# Perform any other preprocessing (such as one-hot encoding) before this
# point: the data passed below should already be fully prepared.
prepared <- fairGATE::prepare_data(
  data           = adult,
  cols_to_remove = cols_to_drop,
  group_var      = "sex",
  outcome_var    = "income"
)

## ----include = FALSE----------------------------------------------------------
# --- Safety block: clean any non-finite or zero-variance columns before training ---
# Work on a local copy of the feature data returned by prepare_data(); the
# cleaned version is written back to `prepared` at the end of this chunk.
# NOTE(review): the code below indexes X as a matrix (colSums, X[, j]) --
# presumably prepare_data() returns a numeric matrix; confirm.

X <- prepared$X

# Make a single column safe for training.
#
# Every non-finite entry (NA, NaN, Inf, -Inf) is replaced by the median of the
# column's finite entries. A column with no finite entries at all is returned
# as a vector of zeros of the same length.
#
# x: a vector (one column of the feature matrix).
# Returns: a vector of the same length with no non-finite entries.
fix_col <- function(x) {
  finite_mask <- is.finite(x)

  # Nothing usable to impute from: fall back to an all-zero column.
  if (!any(finite_mask)) {
    return(rep(0, length(x)))
  }

  x[!finite_mask] <- stats::median(x[finite_mask])
  x
}

# Replace any non-finite values: each affected column is passed through
# fix_col(), which imputes with the median of the column's finite entries
# (or zeros when the column has none).
bad <- colSums(!is.finite(X)) > 0
if (any(bad)) {
  X[, bad] <- apply(X[, bad, drop = FALSE], 2, fix_col)
}

# Drop zero-variance columns (these can cause NaNs on scaling).
# A column counts as constant when it has fewer than two distinct non-missing
# values. Counting distinct values avoids `sd(v) == 0`, which is an exact
# floating-point comparison and evaluates to NA (making `if (any(zv))` fail)
# when a column has fewer than two observations.
zv <- vapply(
  seq_len(ncol(X)),
  function(j) {
    v <- X[, j]
    length(unique(v[!is.na(v)])) < 2L
  },
  logical(1)
)
if (any(zv)) {
  X <- X[, !zv, drop = FALSE]
}

# Update prepared object so later steps see the cleaned features
prepared$X <- X
prepared$feature_names <- colnames(X)

# Quick sanity check: every value is finite and at least one feature remains
stopifnot(all(is.finite(prepared$X)), ncol(prepared$X) > 0)

## ----train-demo, results='hide', message=FALSE, warning=FALSE-----------------
# Fixed hyper-parameters for the demo run (used because tuning is skipped).
demo_params <- list(
  lr           = 0.01,
  hidden_dim   = 16,
  dropout_rate = 0.1,
  lambda       = 0.0,
  temperature  = 1.0
)

# Train a small Gated Neural Network
trained_model <- fairGATE::train_gnn(
  prepared_data = prepared,
  best_params   = demo_params,
  run_tuning    = FALSE,  # skip tuning for speed
  epochs        = 20,     # fast CRAN-safe runtime
  num_repeats   = 2,      # very short repeated split
  verbose       = FALSE
)

## -----------------------------------------------------------------------------

# Run basic analysis
basic_analyses <- analyse_gnn_results(
  prepared_data = prepared,
  gnn_results   = trained_model
)

# --- View all plots from the basic analysis ---
# Each entry maps an element of `basic_analyses` to the heading printed
# immediately before that plot.
plot_headings <- c(
  roc_plot             = "## ROC Curve\n",
  calibration_plot     = "\n## Calibration Plot\n",
  gate_density_plot    = "\n## Gate Weight Distribution\n",
  entropy_density_plot = "\n## Gate Entropy Distribution\n"
)

for (plot_name in names(plot_headings)) {
  cat(plot_headings[[plot_name]])
  print(basic_analyses[[plot_name]])
}

## -----------------------------------------------------------------------------
# Analyse the trained experts
exp_res <- analyse_experts(
  prepared_data   = prepared,        # from prepare_data()
  gnn_results     = trained_model,   # from train_gnn()
  verbose         = TRUE,
  top_n_features  = 15               # number of top features to visualise
)

# View the main objects returned
names(exp_res)
#> [1] "all_weights"          "means_by_group_wide" 
#> [3] "pairwise_differences" "difference_plot"     
#> [5] "multi_group_plot"     "top_features_multi"

# View first few feature importances
head(exp_res$means_by_group_wide)

# Example: view one pairwise difference table
names(exp_res$pairwise_differences)
#> [1] "Female_vs_Male"
head(exp_res$pairwise_differences[[1]])

# Visualise feature specialisation (each plot is printed only if present)
for (plot_name in c("difference_plot", "multi_group_plot")) {
  fig <- exp_res[[plot_name]]
  if (!is.null(fig)) {
    print(fig)
  }
}

## -----------------------------------------------------------------------------
# Build the Sankey plot from the prepared data, the trained model and the
# expert analysis, then display it.
p <- plot_sankey(
  gnn_results    = trained_model,  # from train_gnn()
  expert_results = exp_res,        # from analyse_experts()
  prepared_data  = prepared,       # from prepare_data()
  verbose        = TRUE
)

print(p)


## ----f360_export, eval = FALSE, message = FALSE-------------------------------
# export_f360_csv(
#   gnn_results       = trained_model,   # from train_gnn()
#   prepared_data     = prepared,        # from prepare_data()
#   path              = "outputs/fairness360_input.csv",
#   include_gate_cols = TRUE,            # include expert routing probabilities
#   threshold         = 0.5,             # classification threshold for binary outcome
#   verbose           = TRUE
# )

