---
title: "Choosing Weights and Validating ML"
author: "Maciej Nasinski"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Choosing Weights and Validating ML}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(
  echo = TRUE, message = FALSE, warning = FALSE,
  collapse = TRUE, comment = "#>"
)
```

This vignette is a decision guide for choosing and checking weights in `cat2cat()`.

Read it when you want to answer one of these questions:

- Are naive and frequency weights telling the same story?
- Is ML worth trying at all?
- If ML is used, does it improve on the frequency baseline?
- What should I do when different weight methods disagree?
- How should I handle failed ML predictions?

If you only need the basic two-period workflow, go back to [Get Started](cat2cat.html). If you need multi-period, panel, aggregated, or regression workflows, continue to [Advanced Workflows](cat2cat_advanced.html).

```{r load-data}
library(cat2cat)
library(dplyr)
library(tidyr)
library(e1071)
library(randomForest)

data(occup, package = "cat2cat")
data(trans, package = "cat2cat")

occup_2008 <- occup[occup$year == 2008, ]
occup_2010 <- occup[occup$year == 2010, ]
occup_2012 <- occup[occup$year == 2012, ]
```

## Step 1: Understand the competing weight assumptions

`cat2cat` offers several ways to assign probability weights to replicated observations. Each method encodes a different **distributional assumption** about how ambiguous observations split across candidate categories. When a downstream estimand depends on the mapped category, that assumption is the identifying assumption for the estimand, so always check how sensitive your results are to it.

**Naive weights** (`wei_naive_c2c`) are always computed. Each replicated observation gets uniform probability $1/k$, where $k$ is the number of candidate categories.

- *Assumption*: All candidates equally likely (maximum entropy / uninformative prior)
- *Requires*: Only the mapping table - no data from either period
- *Use when*: No information favoring any candidate, or as an uninformative benchmark in robustness checks
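A minimal base-R sketch of the naive rule, assuming a toy mapping in which each old code links to one or more new codes. The `toy_map` object and its codes are hypothetical and do not reflect the internal structure `cat2cat()` uses; the point is only the $1/k$ arithmetic.

```{r sketch-naive-weights}
# Hypothetical toy mapping: each old code and its candidate new codes
toy_map <- list(
  "A" = c("A1", "A2", "A3"),
  "B" = c("B1", "B2"),
  "C" = "C1"
)

# One replicated row per (old code, candidate) pair, each with weight 1/k
toy_naive <- do.call(rbind, lapply(names(toy_map), function(old) {
  k <- length(toy_map[[old]])
  data.frame(old = old, candidate = toy_map[[old]], wei_naive = 1 / k)
}))
toy_naive
```

Within each old code the weights sum to one, and a code with a single candidate keeps weight 1.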

**Frequency-based weights** (`wei_freq_c2c`) are the default. They use category counts from the base period.

- *Assumption*: Ambiguous observations distribute like the base period population
- *Requires*: Observed counts in the base period (falls back to naive weights if all candidate counts are zero)
- *Use when*: Base period is large and representative; ambiguous cases resemble the general population
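The frequency rule can be sketched the same way, assuming hypothetical base-period counts for the candidate codes. The fallback to uniform weights when every candidate has a zero count mirrors the *Requires* line above; the helper name `freq_weights()` is illustrative only and is not part of `cat2cat`.

```{r sketch-freq-weights}
# Hypothetical helper: frequency weights for one set of candidate codes
freq_weights <- function(candidates, base_counts) {
  # Base-period count for each candidate (0 if the code never appears)
  n <- base_counts[candidates]
  n[is.na(n)] <- 0
  if (sum(n) == 0) {
    # Fall back to uniform (naive) weights when all counts are zero
    return(setNames(rep(1 / length(candidates), length(candidates)), candidates))
  }
  setNames(as.numeric(n) / sum(n), candidates)
}

# Hypothetical base-period counts
toy_counts <- c(A1 = 60, A2 = 30, A3 = 10, B1 = 0, B2 = 0)

freq_weights(c("A1", "A2", "A3"), toy_counts)
freq_weights(c("B1", "B2"), toy_counts)
```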

**ML weights** (`wei_knn_c2c`, `wei_lda_c2c`, `wei_rf_c2c`, `wei_nb_c2c`) use individual features to predict category membership.

- *Assumption*: Features (age, education, etc.) predict the true category: $P(j \mid X, g)$
- *Requires*: Training data with both category labels and predictive features
- *Use when*: Features are informative - verify with `cat2cat_ml_run()`

Available ML methods:

- **knn**: k-Nearest Neighbours. A non-parametric method that handles non-linear boundaries. Sensitive to the choice of `k`.
- **lda**: Linear Discriminant Analysis. Fast, interpretable. Assumes multivariate normality and equal covariance.
- **rf**: Random Forest. Handles interactions well. Slower; `ntree` may need tuning.
- **nb**: Naive Bayes via `e1071`. Fast; works on the numeric, logical, or factor features described below. Assumes conditional independence of features.

ML features must be numeric, logical, or factor columns. Factor columns are
one-hot encoded automatically using levels observed in the training data and
the target period. Character columns are not encoded automatically; convert
them to factors first if they represent categories.
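The factor rule can be illustrated with base R's `model.matrix()`. This sketch assumes that "levels observed in the training data and the target period" means the union of both level sets; it shows the idea only and is not the encoding code `cat2cat` runs internally. The `edu` values are made up.

```{r sketch-one-hot}
# Hypothetical training and target-period values of a categorical feature
train_edu  <- factor(c("primary", "secondary", "secondary"))
target_edu <- factor(c("secondary", "tertiary"))

# Use the union of levels seen in either period
all_levels <- union(levels(train_edu), levels(target_edu))

one_hot <- function(x, lev) {
  x <- factor(as.character(x), levels = lev)
  # "- 1" drops the intercept so every level gets its own 0/1 column
  model.matrix(~ x - 1)
}

train_mat  <- one_hot(train_edu, all_levels)
target_mat <- one_hot(target_edu, all_levels)

# Both matrices share the same columns, so a model fitted on one
# can be applied to the other
identical(colnames(train_mat), colnames(target_mat))
```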

You can run multiple methods at once and compare or combine them:

```{r mixed-ml-weights}
occup_2_mix <- cat2cat(
  data = list(
    old = occup_2008, new = occup_2010,
    cat_var = "code", time_var = "year"
  ),
  mappings = list(trans = trans, direction = "backward"),
  ml = list(
    data = occup_2010,
    cat_var = "code",
    method = c("knn", "rf", "lda", "nb"),
    features = c("age", "sex", "edu", "exp", "parttime", "salary"),
    args = list(k = 10, ntree = 50),
    on_fail = "na"
  )
)
```

Correlations between weight methods:

```{r weight-correlations}
# on_fail = "na" above can leave NA weights, so use pairwise-complete rows
occup_2_mix$old %>%
  select(wei_knn_c2c, wei_rf_c2c, wei_lda_c2c, wei_nb_c2c,
         wei_freq_c2c, wei_naive_c2c) %>%
  cor(use = "pairwise.complete.obs")
```

### If ML fails on some rows: `on_fail` and `fail_warn`

Sometimes ML probabilities cannot be produced for a subset of replicated rows
(for example, when target-period features are incomplete or a method fails to
predict for some rows). `cat2cat()` exposes explicit policy controls in `ml`:

- `on_fail = "freq"` (default): failed ML rows are filled with `wei_freq_c2c`.
- `on_fail = "naive"`: failed ML rows are filled with `wei_naive_c2c`.
- `on_fail = "na"`: failed ML rows are kept as `NA`.
- `on_fail = "error"`: stop immediately when failed rows are detected.
- `fail_warn = TRUE` (default): warn with the number of affected rows and observations per method.
- `fail_warn = FALSE`: suppress these warnings.
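The four `on_fail` policies amount to a simple fill rule for the rows where an ML weight is missing. The helper `fill_failed()` below is hypothetical and written only to make that rule concrete; inside `cat2cat()` the policy is applied for you through the `ml` list. In the toy vectors, the last two rows stand for one observation whose ML prediction failed.

```{r sketch-on-fail}
# Hypothetical illustration of the on_fail fill rule
fill_failed <- function(ml, freq, naive,
                        on_fail = c("freq", "naive", "na", "error")) {
  on_fail <- match.arg(on_fail)
  failed <- is.na(ml)
  if (!any(failed)) return(ml)
  switch(on_fail,
    freq  = { ml[failed] <- freq[failed]; ml },   # copy frequency weights
    naive = { ml[failed] <- naive[failed]; ml },  # copy naive weights
    na    = ml,                                   # keep NA as-is
    error = stop(sum(failed), " rows have failed ML weights")
  )
}

ml_w    <- c(0.7, 0.3, NA, NA)
freq_w  <- c(0.6, 0.4, 0.8, 0.2)
naive_w <- c(0.5, 0.5, 0.5, 0.5)

fill_failed(ml_w, freq_w, naive_w, "freq")
fill_failed(ml_w, freq_w, naive_w, "naive")
```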

Important: this failure accounting is specific to `cat2cat()` and the constructed
weight columns (`wei_*_c2c`). It is distinct from the "SKIPPED GROUPS" section
printed by `cat2cat_ml_run()`, which reports mapping groups that were not
evaluated in the holdout diagnostics (single category, too few observations, or
a method fit/predict error for that group).

```{r ml-failure-policy, eval=FALSE}
ml_setup <- list(
  data = bind_rows(occup_2010, occup_2012),
  cat_var = "code",
  method = c("knn", "rf", "lda"),
  features = c("age", "sex", "edu", "exp", "parttime", "salary"),
  args = list(k = 10, ntree = 50),
  on_fail = "freq",   # default policy
  fail_warn = TRUE     # default reporting
)

# strict mode for QA pipelines
ml_strict <- ml_setup
ml_strict$on_fail <- "error"

# diagnostic mode to inspect failures directly
ml_diag <- ml_setup
ml_diag$on_fail <- "na"
ml_diag$fail_warn <- FALSE
```

Ensemble weights with `cross_c2c()` and pruning with `prune_c2c()`:

```{r ensemble-prune}
occup_old_mix <- occup_2_mix$old %>%
  cross_c2c(.) %>%
  prune_c2c(., column = "wei_cross_c2c", method = "nonzero")
```
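Conceptually, an ensemble column is a row-wise weighted average of the chosen weight columns, with mixing weights that sum to one. This base-R sketch assumes that reading of `cross_c2c()` and uses a made-up data frame; consult `?cross_c2c` for how the function itself handles missing values and default columns.

```{r sketch-cross-average}
# Hypothetical replicated rows for one observation with two candidates
toy_rows <- data.frame(
  wei_freq_c2c = c(0.8, 0.2),
  wei_knn_c2c  = c(0.4, 0.6)
)

# Mixing weights, here 3:1 in favour of the frequency weights
mix <- c(3, 1) / 4

# Row-wise weighted average of the two weight columns
toy_rows$wei_cross <- as.vector(
  as.matrix(toy_rows[, c("wei_freq_c2c", "wei_knn_c2c")]) %*% mix
)
toy_rows
```

Because each input column sums to one within an observation and the mixing weights sum to one, the averaged column also sums to one.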

## Step 2: Check whether conclusions are sensitive to the weight choice

Different weight methods can change regression coefficients once you filter to a specific occupation group and combine both periods. This gives a direct sensitivity analysis: subjects from the base period (new, no replication) are combined with subjects from the target period (old, weighted by their probability of belonging to this group).

### Compare weight methods on the same mapped data

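Run the backward mapping with all four ML methods. The default `on_fail = "freq"` applies here, so any failed ML rows fall back to `wei_freq_c2c`: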

```{r sensitivity-result}
result_all <- cat2cat(
  data = list(old = occup_2008, new = occup_2010,
              cat_var = "code", time_var = "year"),
  mappings = list(trans = trans, direction = "backward"),
  ml = list(
    data = occup_2010, cat_var = "code",
    method = c("knn", "rf", "lda", "nb"),
    features = c("age", "sex", "edu", "exp", "parttime", "salary"),
    args = list(k = 10, ntree = 50)
  )
)
```

**Weighted counts per group** - compare how weight methods redistribute observations:
 
```{r sensitivity-counts}
weight_cols <- c("wei_naive_c2c", "wei_freq_c2c", "wei_knn_c2c", "wei_rf_c2c", "wei_lda_c2c", "wei_nb_c2c")

# Pick groups with high replication
top_groups <- result_all$old %>%
  filter(rep_c2c > 1) %>%
  count(g_new_c2c, sort = TRUE) %>%
  head(6) %>%
  pull(g_new_c2c)

# Weighted counts from OLD period (replicated)
old_counts <- lapply(weight_cols, function(wcol) {
  result_all$old %>%
    filter(g_new_c2c %in% top_groups) %>%
    group_by(g_new_c2c) %>%
    summarise(n = sum(.data[[wcol]]), .groups = "drop")
}) %>%
  setNames(gsub("wei_|_c2c", "", weight_cols)) %>%
  bind_rows(.id = "method") %>%
  tidyr::pivot_wider(names_from = method, values_from = n)

# Counts from NEW period (no replication, exact)
new_counts <- result_all$new %>%
  filter(code %in% top_groups) %>%
  count(code, name = "new_period") %>%
  rename(g_new_c2c = code)

# Combine for comparison
left_join(old_counts, new_counts, by = "g_new_c2c")
```

The `new_period` column shows the actual counts in 2010. The other columns show how the 2008 observations are redistributed under each weight method: `naive` assigns uniform probability $1/k$ across the $k$ candidates, `freq` uses base-period frequencies, and the ML methods (`knn`, `rf`, `lda`, `nb`) use predicted probabilities.

Pick a specific group for regression analysis:

```{r sensitivity-weights}
# New-period counts per category (no replication, so plain tally)
new_counts_all <- result_all$new %>%
  count(code, name = "n_new") %>%
  rename(g_new_c2c = code)

# Old-period weighted counts, joined to new-period counts
group_sizes <- result_all$old %>%
  group_by(g_new_c2c) %>%
  summarise(n_old = sum(wei_freq_c2c), .groups = "drop") %>%
  left_join(new_counts_all, by = "g_new_c2c") %>%
  filter(n_old >= 10, n_new >= 10) %>%
  arrange(desc(n_old))

# Pick a group for regression analysis
target_group <- group_sizes$g_new_c2c[1]
cat("Analysing occupation group:", target_group, "\n")
```

**Regression within a single occupation group** - combine both periods and compare coefficients:

```{r sensitivity-group-reg}
# Subset old period to target group (with weights)
old_subset <- result_all$old %>%
  filter(g_new_c2c == target_group)

# Subset new period to target group (no replication, weight = 1)
new_subset <- result_all$new %>%
  filter(code == target_group) %>%
  mutate(
    wei_naive_c2c = 1, wei_freq_c2c = 1, wei_knn_c2c = 1,
    wei_rf_c2c = 1, wei_lda_c2c = 1, wei_nb_c2c = 1
  )

# Combine both periods
d <- bind_rows(old_subset, new_subset)

# Compare all regression coefficients across weight methods
f <- I(log(salary)) ~ age + sex + factor(edu) + exp + parttime

coefs <- sapply(weight_cols, function(wcol) {
  d$w <- d$multiplier * d[[wcol]]
  coef(lm(f, data = d, weights = w))
})
colnames(coefs) <- gsub("wei_|_c2c", "", weight_cols)
round(coefs, 4)
```

Any coefficient can shift, because the weight methods change how much each old-period subject contributes to this occupation group.

### Compare pruning strategies only after comparing full weights

> **Note**: Pruning discards probability information, so apply it only after you have run the analysis with full weights.
> Prefer `prune_c2c(method = "nonzero")`, which removes impossible (zero-weight) candidates while preserving the probability distribution.
> More aggressive pruning (`highest1`) is appropriate only for descriptive tables or when you need exactly one category per observation.

```{r sensitivity-pruning}
# Compare regression coefficients under different pruning strategies
prune_methods <- c("nonzero", "highest", "highest1")

prune_coefs <- sapply(prune_methods, function(pm) {
  old_pruned <- result_all$old %>%
    prune_c2c(method = pm) %>%
    filter(g_new_c2c == target_group)
  
  d <- bind_rows(old_pruned, new_subset)
  d$w <- d$multiplier * d$wei_freq_c2c
  coef(lm(f, data = d, weights = w))
})
round(prune_coefs, 4)
```
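To see what two of these strategies do to a single observation, here is a base-R sketch on made-up weights: dropping zero-weight candidates (the idea behind `nonzero`) and keeping only the largest weight, reset to 1 (the idea behind `highest1`). It assumes the single kept candidate gets weight 1; check `?prune_c2c` for the exact tie-breaking and renormalisation rules.

```{r sketch-prune}
# Hypothetical weights for one observation replicated across four candidates
toy_w <- c(A1 = 0.5, A2 = 0.3, A3 = 0.2, A4 = 0)

# "nonzero" idea: drop impossible candidates; the rest are unchanged
keep_nonzero <- toy_w[toy_w > 0]

# "highest1" idea: keep exactly one candidate and give it weight 1
keep_top1 <- toy_w[which.max(toy_w)]
keep_top1[] <- 1

keep_nonzero
keep_top1
```

The `highest1` result discards the 0.3 and 0.2 candidates entirely, which is why such pruning changes regression inputs far more than `nonzero` does.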

### Compare ensemble compositions when no single method dominates

`cross_c2c()` creates a weighted average of multiple weight columns. Vary the mix:

```{r sensitivity-ensemble}
configs <- list(
  equal      = c(1, 1) / 2,
  freq_heavy = c(3, 1) / 4,
  ml_heavy   = c(1, 3) / 4
)

ens_coefs <- sapply(names(configs), function(nm) {
  old_ens <- result_all$old %>%
    cross_c2c(c("wei_freq_c2c", "wei_knn_c2c"), configs[[nm]]) %>%
    filter(g_new_c2c == target_group)
  
  new_ens <- new_subset %>% mutate(wei_cross_c2c = 1)
  d <- bind_rows(old_ens, new_ens)
  d$w <- d$multiplier * d$wei_cross_c2c
  coef(lm(f, data = d, weights = w))
})
round(ens_coefs, 4)
```

When regression coefficients are stable across weight methods, pruning strategies, and ensemble compositions, you can report them with more confidence. When they diverge, the mapping itself introduces uncertainty: report the range of estimates or investigate its source.
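One compact way to report that range is a per-coefficient minimum, maximum, and spread across all the weighting schemes you compared. The helper `coef_spread()` is hypothetical, and the toy matrix (rows are terms, columns are methods) stands in for the coefficient matrices computed above.

```{r sketch-coef-range}
# Hypothetical helper: range of each coefficient across weighting schemes
coef_spread <- function(coef_mat) {
  data.frame(
    term   = rownames(coef_mat),
    min    = apply(coef_mat, 1, min),
    max    = apply(coef_mat, 1, max),
    spread = apply(coef_mat, 1, function(x) diff(range(x))),
    row.names = NULL
  )
}

# Made-up coefficients: rows are terms, columns are weight methods
toy_coefs <- rbind(
  age = c(naive = 0.010, freq = 0.012, knn = 0.011),
  exp = c(naive = 0.020, freq = 0.015, knn = 0.030)
)
coef_spread(toy_coefs)
```

Applied to the vignette's objects, for example `coef_spread(cbind(coefs, prune_coefs, ens_coefs))`, this assumes all three matrices have the same terms in the same row order.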

## Step 3: Validate whether ML actually improves on simpler baselines

The `ml` argument in `cat2cat()` adds ML-based probability weights, but ML is not guaranteed to improve over simpler baselines. `cat2cat_ml_run()` provides per-group holdout (single train/test split) diagnostics to answer this question *before* committing to a method.

### What `cat2cat_ml_run()` is doing

For each mapping group (a set of candidate categories linked by the transition table), `cat2cat_ml_run()`:

1. Collects all observations from `ml$data` whose category belongs to the group.
2. Randomly splits them into training (`1 - test_prop`) and test (`test_prop`) sets.
3. Computes two baselines on the test set:
   - **naive** - expected accuracy of a uniform random guess ($1 / k$, where $k$ is the number of candidate categories).
   - **freq** - accuracy of always predicting the most frequent training-set category.
4. Trains each specified ML method on the training set and records test-set model performance.

Groups with fewer than 5 observations or only one candidate category are skipped.
Also note that `cat2cat_ml_run()` does not use `on_fail`; it is a diagnostic
tool and reports skipped groups instead of applying row-level fallback weights.
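The split and the two baselines (steps 2 and 3) can be reproduced by hand for one group. This sketch uses made-up labels and a fixed, deterministic split in place of the random one, so it only mirrors the logic described above and is not the package's code. The skip rule is included as a guard.

```{r sketch-holdout}
# Made-up category labels for one mapping group with k = 3 candidates
toy_y <- factor(c("A1", "A1", "A2", "A1", "A3", "A2", "A1", "A1", "A2", "A1"),
                levels = c("A1", "A2", "A3"))
toy_k <- nlevels(toy_y)

# Skip rule described above: at least 5 observations and 2+ candidates
stopifnot(length(toy_y) >= 5, toy_k > 1)

# Deterministic stand-in for the random split: rows 2, 5, 8 are held out
test_idx <- seq(2, length(toy_y), by = 3)
train_y  <- toy_y[-test_idx]
test_y   <- toy_y[test_idx]

# naive baseline: expected accuracy of a uniform random guess
acc_naive <- 1 / toy_k

# freq baseline: always predict the most frequent training category
majority <- names(which.max(table(train_y)))
acc_freq <- mean(test_y == majority)

c(naive = acc_naive, freq = acc_freq)
```

An ML method earns its place in this group only if it beats `acc_freq`, not just `acc_naive`.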

### Minimal validation workflow

```{r cv-basic}
cv_knn <- cat2cat_ml_run(
  mappings = list(trans = trans, direction = "backward"),
  ml = list(
    data = bind_rows(occup_2010, occup_2012),
    cat_var = "code",
    method = "knn",
    features = c("age", "sex", "edu", "exp", "parttime", "salary"),
    args = list(k = 10)
  )
)
print(cv_knn)
```

The `print()` summary reports:

- **ACCURACY** - average held-out classification accuracy across non-skipped groups. `naive (1/k)` is the random-guess baseline, `freq` is the majority-class baseline, and each ML line reports top-class accuracy for that method.
- **BRIER SCORE** - average full-vector probability error across non-skipped groups. Lower is better. This matters because `cat2cat` ultimately uses probability weights, not just hard classifications.
- **MEAN P(TRUE CLASS)** - average probability assigned to the true category. Higher is better. This is often the most directly relevant metric for `cat2cat`, because it measures the quality of the probability weights themselves.
- **ACCURACY: ML vs BASELINES** - the share of groups in which the ML method beats `naive` or beats `freq` on accuracy. This is a win-rate summary, not an average accuracy gap.
- **SKIPPED GROUPS** - the percentage of mapping groups for which that ML method has no reported result because the group had only one candidate category, fewer than 5 observations, or the model could not be fit for that group.
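The three per-observation metrics can be computed from a matrix of predicted probabilities (one row per test observation, one column per candidate) and the true labels. This sketch uses the common multi-class Brier score, the mean over observations of $\sum_j (p_{ij} - y_{ij})^2$; whether `cat2cat_ml_run()` applies exactly this scaling is an assumption, so compare methods using its own printout rather than these toy numbers.

```{r sketch-brier}
# Made-up predicted probabilities for 3 test rows and 3 candidates
toy_p <- rbind(
  c(A1 = 0.6, A2 = 0.3, A3 = 0.1),
  c(A1 = 0.2, A2 = 0.7, A3 = 0.1),
  c(A1 = 0.5, A2 = 0.2, A3 = 0.3)
)
toy_true <- c("A1", "A2", "A3")

# One-hot matrix of the true categories
toy_onehot <- t(sapply(toy_true, function(cl) as.numeric(colnames(toy_p) == cl)))

# Top-class accuracy
toy_pred <- colnames(toy_p)[max.col(toy_p, ties.method = "first")]
toy_acc  <- mean(toy_pred == toy_true)

# Brier score: squared error over the full probability vector (lower is better)
toy_brier <- mean(rowSums((toy_p - toy_onehot)^2))

# Mean probability assigned to the true category (higher is better)
toy_mean_p <- mean(toy_p[cbind(seq_len(nrow(toy_p)), match(toy_true, colnames(toy_p)))])

c(accuracy = toy_acc, brier = toy_brier, mean_p_true = toy_mean_p)
```

The third row is misclassified, yet it still contributes 0.3 to the mean P(true class), which is the kind of partial credit that matters when predictions become weights.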

So for output like:

- `knn > naive: 87.7%`
- `knn > freq: 18.0%`
- `knn: accuracy = 0.5108` vs `freq (most common): 0.5366`

the right reading is: kNN clearly beats the naive baseline, but it does **not** beat the frequency baseline on top-class accuracy overall. In that case, `wei_freq_c2c` remains the default choice if your only goal is classification accuracy.

At the same time, if kNN has a slightly lower Brier score and a higher mean P(true class) than `freq`, then it may still be producing better-calibrated probability weights even though its top prediction is less often correct. That distinction matters in `cat2cat`, because the mapped weights are probabilities distributed across candidate categories rather than single-class assignments.
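Because `knn > freq` is a win rate, it can diverge from the gap between average accuracies. The made-up per-group accuracies below illustrate this: kNN wins in only one of four groups, yet its average accuracy is higher because that one win is large. Ties count as non-wins here, which is an assumption about how the summary treats them.

```{r sketch-win-rate}
# Made-up held-out accuracies for four mapping groups
grp_acc_knn  <- c(g1 = 0.90, g2 = 0.48, g3 = 0.50, g4 = 0.55)
grp_acc_freq <- c(g1 = 0.40, g2 = 0.50, g3 = 0.52, g4 = 0.56)

win_rate <- mean(grp_acc_knn > grp_acc_freq)           # share of groups won
mean_gap <- mean(grp_acc_knn) - mean(grp_acc_freq)     # difference in averages

c(win_rate = win_rate, mean_gap = mean_gap)
```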

### Compare multiple ML methods in one run

```{r cv-multiple}
cv_all <- cat2cat_ml_run(
  mappings = list(trans = trans, direction = "backward"),
  ml = list(
    data = bind_rows(occup_2010, occup_2012),
    cat_var = "code",
    method = c("knn", "lda", "rf", "nb"),
    features = c("age", "sex", "edu", "exp", "parttime", "salary"),
    args = list(k = 10, ntree = 50)
  )
)
print(cv_all)
```

Interpretation tip for mixed outputs:

- It is possible for a method to have 0 failed rows in `cat2cat()` but a non-zero
  skipped-group rate in `cat2cat_ml_run()`.
- This is not a contradiction: the first is row-level weight construction, the
  second is group-level holdout evaluation coverage.

### Inspect per-group diagnostics when methods disagree

The returned object is a named list. Each element corresponds to one mapping group:

```{r cv-inspect}
# Pick a group with multiple candidates
group_names <- names(cv_all)
example_group <- group_names[
  which(vapply(cv_all, function(g) !is.na(g$freq) && g$naive < 1, logical(1)))[1]
]
cv_all[[example_group]]
```

Each group entry contains the group-level diagnostics behind the printed summary:

- `$naive` - $1/k$ random-guess accuracy for that group.
- `$freq` - majority-class accuracy for that group.
- `$acc` - named numeric vector with ML accuracy by method.
- `$naive_brier` and `$freq_brier` - baseline Brier scores.
- `$brier` - named numeric vector with ML Brier scores by method.
- `$naive_mean_prob` and `$freq_mean_prob` - baseline mean P(true class).
- `$mean_prob` - named numeric vector with ML mean P(true class) by method.
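To see where methods disagree, it can help to flatten the per-group entries into one data frame. The sketch below builds a small list that only mimics some of the fields listed above, so the helper `groups_to_df()` is hypothetical; it assumes every element is a group entry whose `$acc` vector names the same methods. Real entries may carry more fields, and skipped groups may hold `NA` values.

```{r sketch-per-group}
# Hypothetical entries mimicking the per-group fields described above
toy_cv <- list(
  grp1 = list(naive = 1 / 3, freq = 0.55, acc = c(knn = 0.60, lda = 0.50)),
  grp2 = list(naive = 1 / 4, freq = 0.70, acc = c(knn = 0.65, lda = 0.72))
)

# Flatten to one row per group: baselines plus one column per ML method
groups_to_df <- function(cv) {
  rows <- lapply(names(cv), function(nm) {
    g <- cv[[nm]]
    data.frame(group = nm, naive = g$naive, freq = g$freq, t(g$acc))
  })
  do.call(rbind, rows)
}

toy_df <- groups_to_df(toy_cv)
# Groups where knn falls short of the frequency baseline
toy_df[toy_df$knn < toy_df$freq, "group"]
```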

### Decision rules for interpreting the output

**Understanding model performance in context**: This is **multi-class classification** - each mapping group can have 3-10+ candidate categories.
A naive random guess yields only ~18% accuracy on average across groups (1/k, where k is the number of candidates).
Reaching 50%+ is therefore a substantial improvement over random; do not compare these numbers with binary classification benchmarks, where random guessing already reaches 50%.
The key question is whether ML beats the *frequency* baseline, not whether it reaches some absolute threshold.

| Scenario | Recommendation |
|----------|---------------|
| ML model performance >> freq across most groups | ML weights add genuine signal; use them |
| ML model performance $\approx$ freq | ML is no better than frequency; prefer `wei_freq_c2c` (simpler, faster) |
| ML model performance < freq for many groups | ML is adding noise; do **not** use ML weights |
| High skipped-group rate (>20%) | Features may have too many missing values, groups are too small, or method fitting is unstable |
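The table can be encoded as a rough rule applied to summary scores such as average accuracy or mean P(true class). The helper `recommend_weights()` is hypothetical, and the tolerance `tol = 0.02` that separates "about equal" from "better" is an arbitrary assumption for illustration; only the 20% skipped-group threshold comes from the table.

```{r sketch-decision-rule}
# Hypothetical helper mapping ML-vs-freq performance to a recommendation
recommend_weights <- function(ml_score, freq_score, skipped_rate, tol = 0.02) {
  if (skipped_rate > 0.20) return("investigate: high skipped-group rate")
  gap <- ml_score - freq_score
  if (gap > tol)  return("use ML weights")
  if (gap < -tol) return("do not use ML weights")
  "prefer wei_freq_c2c"
}

recommend_weights(ml_score = 0.62, freq_score = 0.54, skipped_rate = 0.05)
recommend_weights(ml_score = 0.51, freq_score = 0.54, skipped_rate = 0.05)
recommend_weights(ml_score = 0.55, freq_score = 0.54, skipped_rate = 0.05)
```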

Because the train/test split is random, results vary between runs. For more stable estimates, pool more data into `ml$data` (e.g. multiple survey waves) or run `cat2cat_ml_run()` several times and average the summaries.
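Why averaging helps can be seen with the frequency baseline alone. The labels below are made up and each split holds out 30% of rows, mimicking a `test_prop` of 0.3; with real data you would repeat `cat2cat_ml_run()` itself, which this toy does not do.

```{r sketch-repeat-holdout}
set.seed(2024)
# Made-up labels for one mapping group
toy_labels <- factor(sample(c("A1", "A2", "A3"), 60, replace = TRUE,
                            prob = c(0.5, 0.3, 0.2)))

# One random split: majority-class accuracy on the held-out 30%
one_split <- function(y, test_prop = 0.3) {
  test_idx <- sample(seq_along(y), size = round(test_prop * length(y)))
  majority <- names(which.max(table(y[-test_idx])))
  mean(y[test_idx] == majority)
}

accs <- replicate(50, one_split(toy_labels))
# A single split can land far from the average over many splits
c(first_split = accs[1], mean = mean(accs), sd = sd(accs))
```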

> **Caveat**: high `cat2cat_ml_run()` model performance means the model discriminates well *within* mapping groups. 
It does not validate the mapping table itself. A perfect model with a wrong transition table will still produce wrong results.
