## ----setup, include = FALSE---------------------------------------------------
knitr::opts_chunk$set(
  # Output layout: merge code and results, prefix result lines with "#>".
  collapse = TRUE, comment = "#>",
  # Default figure size (inches) for every plot chunk in the vignette.
  fig.width = 10, fig.height = 5.5,
  # Keep warnings and package messages out of the rendered document.
  warning = FALSE, message = FALSE
)

## ----load-packages------------------------------------------------------------
library(HHBayes)
library(dplyr)
library(rstan)
library(ggpubr)

## ----setup-dates--------------------------------------------------------------
# Study window as ISO-8601 date strings; both end dates are included.
study_start <- "2024-07-01"
study_end   <- "2025-06-30"

# Inclusive duration: count every calendar day from start to end.
cat("Study duration:",
    length(seq(as.Date(study_start), as.Date(study_end), by = "day")),
    "days\n")

## ----setup-surveillance-------------------------------------------------------
# Synthetic weekly surveillance curve: a Gaussian-shaped peak centred on the
# midpoint of the study window, a 0.1 floor and half-normal noise
# (abs(rnorm())), so every weekly count is strictly positive.
# Seeding the noise makes the curve -- and everything derived from it --
# reproducible across vignette builds; the household simulation further down
# passes its own `seed` argument.
set.seed(2024)

dates_weekly <- seq(from = as.Date(study_start),
                    to = as.Date(study_end), by = "week")

surveillance_data <- data.frame(
  date  = dates_weekly,
  # Offset from the window midpoint is measured in days (Date arithmetic).
  cases = 0.1 +
    100 * exp(-0.0002 * (as.numeric(dates_weekly - mean(dates_weekly)))^2) +
    abs(rnorm(length(dates_weekly), mean = 0, sd = 10))
)

plot(surveillance_data$date, surveillance_data$cases,
     type = "l", lwd = 2, col = "steelblue",
     xlab = "Date", ylab = "Weekly cases",
     main = "Synthetic Surveillance Curve")

## ----setup-contact------------------------------------------------------------
# Role-to-role mixing weights, used both by the simulator and by the Stan
# data preparation. Rows are the source role and columns the target role.
# Infant-infant and elderly-elderly entries are zero.
role_mixing_weights <- rbind(
  infant  = c(infant = 0.0, toddler = 0.5, adult = 1.0, elderly = 0.5),
  toddler = c(infant = 0.5, toddler = 0.9, adult = 0.7, elderly = 0.5),
  adult   = c(infant = 1.0, toddler = 0.7, adult = 0.6, elderly = 0.7),
  elderly = c(infant = 0.5, toddler = 0.5, adult = 0.7, elderly = 0.0)
)

role_mixing_weights

## ----setup-household----------------------------------------------------------
# Household composition. Each vector holds P(0), P(1), P(2) members of that
# role; `prob_infant` is the probability that an infant is present.
household_profile <- list(
  # Adults: every household has exactly two parents.
  prob_adults = c(0, 0, 1),
  # Infant: always present.
  prob_infant = 1.0,
  # Toddlers: 80% have one, 20% have two.
  prob_siblings = c(0, 0.8, 0.2),
  # Elderly: 70% none, 10% one, 20% two.
  prob_elderly = c(0.7, 0.1, 0.2)
)

## ----setup-household-examples, eval = FALSE-----------------------------------
# # Nuclear Western family
# nuclear <- list(
#   prob_adults   = c(0.05, 0.15, 0.80),
#   prob_infant   = 0.5,
#   prob_siblings = c(0.40, 0.45, 0.15),
#   prob_elderly  = c(0.95, 0.04, 0.01)
# )
# 
# # Multi-generational Asian family
# multigenerational <- list(
#   prob_adults   = c(0, 0, 1),
#   prob_infant   = 1.0,
#   prob_siblings = c(0.05, 0.30, 0.65),
#   prob_elderly  = c(0.20, 0.50, 0.30)
# )

## ----setup-intervention-------------------------------------------------------
# Baseline intervention configuration: a single "vacc_status" covariate with
# zero efficacy and zero coverage in every role, so it has no effect.
sim_config <- list(
  list(
    name      = "vacc_status",
    efficacy  = 0,
    effect_on = "both",
    # Zero coverage for every role, in infant/toddler/adult/elderly order.
    coverage  = setNames(as.list(rep(0, 4)),
                         c("infant", "toddler", "adult", "elderly"))
  )
)

## ----setup-intervention-real, eval = FALSE------------------------------------
# # Real vaccination scenario
# vacc_config <- list(
#   list(
#     name      = "vacc_status",
#     efficacy  = 0.5,           # 50% reduction in susceptibility and infectivity
#     effect_on = "both",
#     coverage  = list(
#       infant  = 0.00,          # Not eligible
#       toddler = 0.30,
#       adult   = 0.80,
#       elderly = 0.90
#     )
#   ),
#   # You can add multiple interventions — they stack multiplicatively
#   list(
#     name      = "masked",
#     efficacy  = 0.3,
#     effect_on = "both",
#     coverage  = list(infant = 0, toddler = 0.1, adult = 0.7, elderly = 0.6)
#   )
# )

## ----simulate-----------------------------------------------------------------
# Simulate 50 households over the study window using the objects defined
# above: the surveillance curve, the intervention config, the household
# composition profile and the role mixing matrix. `seed = 123` is passed so
# the simulation itself is repeatable.
# The result is used below through `sim_res$hh_df` (person-episodes) and
# `sim_res$diagnostic_df` (test records).
# NOTE(review): the infectious_* and waning_* shape/scale pairs look like
# gamma-distribution parameters for the infectious period and for waning
# immunity, and surveillance_interval presumably is the spacing between
# scheduled tests -- confirm against ?simulate_multiple_households_comm.
sim_res <- simulate_multiple_households_comm(
  n_households    = 50,
  viral_testing   = "viral load",
  model_type      = "ODE",
  infectious_shape = 10,
  infectious_scale = 1,
  waning_shape    = 6,
  waning_scale    = 10,
  surveillance_interval = 4,
  start_date      = study_start,
  end_date        = study_end,
  surveillance_df = surveillance_data,
  covariates_config      = sim_config,
  household_profile_list = household_profile,
  role_mixing_matrix     = role_mixing_weights,
  seed = 123
)

## ----inspect-hh---------------------------------------------------------------
head(sim_res$hh_df)

## ----inspect-hh-str-----------------------------------------------------------
# hh_df holds one row per person-episode, so a person can appear more than
# once; count unique (household, person) pairs separately.
cat("Columns:", paste(names(sim_res$hh_df), collapse = ", "), "\n")
cat("Total person-episodes:", nrow(sim_res$hh_df), "\n")
cat("Unique people:",
    nrow(dplyr::distinct(sim_res$hh_df, hh_id, person_id)), "\n")

## ----inspect-diag-------------------------------------------------------------
head(sim_res$diagnostic_df)

## ----inspect-diag-str---------------------------------------------------------
cat("Columns:", paste(names(sim_res$diagnostic_df), collapse = ", "), "\n")
cat("Total test records:", nrow(sim_res$diagnostic_df), "\n")
# Assumes test_result is 0/1 (or logical). na.rm = TRUE keeps a single
# missing result from turning the whole count into NA.
cat("Positive tests:",
    sum(sim_res$diagnostic_df$test_result, na.rm = TRUE), "\n")

## ----attack-rates-------------------------------------------------------------
# Summarise attack rates from the simulated cohort; the returned list is
# printed piece by piece in the chunks below.
# NOTE(review): the element names suggest primary-infection vs reinfection
# attack rates, overall and by household role -- confirm against
# ?summarize_attack_rates.
rates <- summarize_attack_rates(sim_res)

## ----ar-overall---------------------------------------------------------------
print(rates$primary_overall)

## ----ar-by-role---------------------------------------------------------------
print(rates$primary_by_role)

## ----ar-reinf-----------------------------------------------------------------
print(rates$reinf_overall)
print(rates$reinf_by_role)

## ----epi-curve, fig.width=10, fig.height=6------------------------------------
# Plot the simulated epidemic curve together with the surveillance input.
# `bin_width = 7` aggregates into weekly bins, matching the weekly spacing of
# `surveillance_data`. The plot object is stored, then printed explicitly.
my_plot <- plot_epidemic_curve(
  sim_res,
  surveillance_data,
  start_date_str = study_start,
  bin_width = 7   # Weekly bins
)
print(my_plot)

## ----stan-join----------------------------------------------------------------
# Extract a 1-row-per-person covariate table from the person-episode data.
person_covariates <- sim_res$hh_df %>%
  dplyr::select(hh_id, person_id, vacc_status) %>%
  dplyr::distinct()

# Attach each person's covariates to every one of their diagnostic records.
# `relationship = "many-to-one"` (dplyr >= 1.1.0) makes the join fail loudly
# if a person still has more than one covariate row, for example a covariate
# that differs between episodes. Without it, the join would silently
# duplicate that person's test records in the Stan input.
df_for_stan <- sim_res$diagnostic_df %>%
  dplyr::left_join(person_covariates,
                   by = c("hh_id", "person_id"),
                   relationship = "many-to-one")

cat("Rows in df_for_stan:", nrow(df_for_stan), "\n")
cat("Columns:", paste(names(df_for_stan), collapse = ", "), "\n")

## ----stan-priors--------------------------------------------------------------
# Prior specification for prepare_stan_data(): each entry is
# list(dist = <name>, params = <numeric vector>).
# NOTE(review): params presumably are (mean, sd) for "normal" and
# (meanlog, sdlog) for "lognormal" -- confirm in ?prepare_stan_data.
my_priors <- local({
  # Small constructors so each prior reads as "distribution(params)".
  normal    <- function(...) list(dist = "normal",    params = c(...))
  lognormal <- function(...) list(dist = "lognormal", params = c(...))

  list(
    beta1      = normal(-5, 1),
    beta2      = normal(-5, 1),
    alpha      = normal(-4, 1),
    covariates = normal(0, 2),
    gen_shape  = lognormal(1.5, 0.5),
    gen_rate   = lognormal(0.0, 0.5),
    ct50       = normal(3, 1),
    slope      = lognormal(0.4, 0.5)
  )
})

## ----stan-vl-params-----------------------------------------------------------
# Per-role viral-load parameters, passed to prepare_stan_data() as
# `imputation_params`. Infants and toddlers share one set of values.
# NOTE(review): the names suggest peak load (v_p), time to peak (t_p) and
# growth/decline rates (lambda_g/lambda_d) -- confirm in ?prepare_stan_data.
VL_params_list <- local({
  child <- list(v_p = 5.84, t_p = 4.09, lambda_g = 2.82, lambda_d = 1.01)
  list(
    adult   = list(v_p = 4.14, t_p = 5.09, lambda_g = 2.31, lambda_d = 2.71),
    infant  = child,
    toddler = child,
    elderly = list(v_p = 2.95, t_p = 5.10, lambda_g = 3.15, lambda_d = 0.87)
  )
})

## ----stan-prepare-------------------------------------------------------------
# Build the input list for the Stan model from the merged diagnostic data.
# This baseline uses no covariates (both covariate arguments are NULL),
# includes viral-load data (use_vl_data = 1) and reuses the role mixing
# matrix from the simulation. `imputation_params` supplies the per-role
# viral-load parameters and `priors` the prior list defined above.
# NOTE(review): ODE_params_list = NULL together with model_type = "ODE"
# presumably falls back to package defaults -- confirm in ?prepare_stan_data.
stan_input <- prepare_stan_data(
  df_clean          = df_for_stan,
  surveillance_df   = surveillance_data,
  study_start_date  = as.Date(study_start),
  study_end_date    = as.Date(study_end),
  use_vl_data       = 1,
  model_type        = "ODE",
  ODE_params_list   = NULL,
  covariates_susceptibility = NULL,  # No covariates in this baseline
  covariates_infectivity    = NULL,
  imputation_params = VL_params_list,
  priors            = my_priors,
  role_mixing_matrix = role_mixing_weights
)

# Report the model dimensions stored in `stan_input`.
# `[[` is used instead of `$` because `$` partially matches list names: if an
# exact "T" (or "N", "H", "R") element were absent, `$` could silently print
# a different element whose name merely starts with that letter.
cat("Stan data prepared successfully.\n")
cat("Number of individuals (N):", stan_input[["N"]], "\n")
cat("Number of time steps (T):", stan_input[["T"]], "\n")
cat("Number of households (H):", stan_input[["H"]], "\n")
cat("Number of roles (R):", stan_input[["R"]], "\n")

## ----fit-model, eval = FALSE--------------------------------------------------
# options(mc.cores = parallel::detectCores())
# 
# fit <- fit_household_model(
#   stan_input,
#   pars    = c("log_phi_by_role_raw",
#               "log_kappa_by_role_raw",
#               "log_beta1",
#               "log_beta2",
#               "log_alpha_comm",
#               "g_curve_est",
#               "V_term_calc"),
#   include = FALSE,    # Exclude these internal parameters from saved output
#   iter    = 1000,     # For testing; use 2000+ for real analysis
#   warmup  = 500,      # For testing; use 1000+ for real analysis
#   chains  = 1         # For testing; use 4 for real analysis
# )

## ----fit-model-real, eval = FALSE---------------------------------------------
# fit <- fit_household_model(
#   stan_input,
#   pars    = c("log_phi_by_role_raw", "log_kappa_by_role_raw",
#               "log_beta1", "log_beta2", "log_alpha_comm",
#               "g_curve_est", "V_term_calc"),
#   include = FALSE,
#   iter    = 2000,
#   warmup  = 1000,
#   chains  = 4
# )

## ----fit-summary, eval = FALSE------------------------------------------------
# print(fit, probs = c(0.025, 0.5, 0.975))

## ----fit-diagnostics, eval = FALSE--------------------------------------------
# # Quick convergence check: Rhat should be close to 1 and n_eff reasonably large
# fit_summary <- rstan::summary(fit)$summary
# cat("Max Rhat:", max(fit_summary[, "Rhat"], na.rm = TRUE), "\n")
# cat("Min n_eff:", min(fit_summary[, "n_eff"], na.rm = TRUE), "\n")
# 
# # Trace plots (requires bayesplot)
# if (requireNamespace("bayesplot", quietly = TRUE)) {
#   bayesplot::mcmc_trace(fit, pars = c("beta1", "beta2", "alpha_comm"))
# }

## ----plot-posteriors, eval = FALSE--------------------------------------------
# p_post <- plot_posterior_distributions(fit)
# print(p_post)

## ----plot-covariates, eval = FALSE--------------------------------------------
# plot_covariate_effects(fit, stan_input)

## ----plot-chains, eval = FALSE------------------------------------------------
# chains <- reconstruct_transmission_chains(
#   fit,
#   stan_input,
#   min_prob_threshold = 0.001   # Keep links with >= 0.1% probability
# )
# 
# head(chains)

## ----plot-timeline, eval = FALSE----------------------------------------------
# p_hh <- plot_household_timeline(
#   chains,
#   stan_input,
#   target_hh_id = 11     # Plot household #11
# )
# print(p_hh)

## ----plot-timeline-filtered, eval = FALSE-------------------------------------
# p_hh_filtered <- plot_household_timeline(
#   chains, stan_input,
#   target_hh_id = 11,
#   prob_cutoff  = 0.05,
#   plot_width   = 11,
#   plot_height  = 7
# )
# print(p_hh_filtered)

## ----full-script, eval = FALSE------------------------------------------------
# library(HHBayes)
# library(dplyr)
# library(rstan)
# library(ggpubr)
# 
# # ==============================================================================
# # 1. SETUP
# # ==============================================================================
# 
# study_start <- "2024-07-01"
# study_end   <- "2025-06-30"
# 
# # Surveillance data (seeded so the random noise is reproducible)
# set.seed(2024)
# dates_weekly <- seq(as.Date(study_start), as.Date(study_end), by = "week")
# surveillance_data <- data.frame(
#   date  = dates_weekly,
#   cases = 0.1 + 100 * exp(-0.0002 * (as.numeric(dates_weekly -
#     mean(dates_weekly)))^2) + abs(rnorm(length(dates_weekly), 0, 10))
# )
# 
# # Contact matrix
# role_mixing_weights <- matrix(c(
#   0.0, 0.5, 1.0, 0.5,
#   0.5, 0.9, 0.7, 0.5,
#   1.0, 0.7, 0.6, 0.7,
#   0.5, 0.5, 0.7, 0.0
# ), nrow = 4, byrow = TRUE,
# dimnames = list(
#   c("infant", "toddler", "adult", "elderly"),
#   c("infant", "toddler", "adult", "elderly")))
# 
# # Household profile
# household_profile <- list(
#   prob_adults   = c(0, 0, 1),
#   prob_infant   = 1.0,
#   prob_siblings = c(0, 0.8, 0.2),
#   prob_elderly  = c(0.7, 0.1, 0.2)
# )
# 
# # Intervention (baseline: no effect)
# sim_config <- list(
#   list(name = "vacc_status", efficacy = 0, effect_on = "both",
#        coverage = list(infant = 0, toddler = 0, adult = 0, elderly = 0))
# )
# 
# # ==============================================================================
# # 2. SIMULATE
# # ==============================================================================
# 
# sim_res <- simulate_multiple_households_comm(
#   n_households = 50, viral_testing = "viral load", model_type = "ODE",
#   infectious_shape = 10, infectious_scale = 1,
#   waning_shape = 6, waning_scale = 10,
#   surveillance_interval = 4,
#   start_date = study_start, end_date = study_end,
#   surveillance_df = surveillance_data,
#   covariates_config = sim_config,
#   household_profile_list = household_profile,
#   role_mixing_matrix = role_mixing_weights,
#   seed = 123
# )
# 
# rates <- summarize_attack_rates(sim_res)
# print(rates$primary_by_role)
# 
# plot_epidemic_curve(sim_res, surveillance_data,
#                     start_date_str = study_start, bin_width = 7)
# 
# # ==============================================================================
# # 3. PREPARE FOR STAN
# # ==============================================================================
# 
# person_covariates <- sim_res$hh_df %>%
#   select(hh_id, person_id, vacc_status) %>% distinct()
# 
# df_for_stan <- sim_res$diagnostic_df %>%
#   left_join(person_covariates, by = c("hh_id", "person_id"))
# 
# my_priors <- list(
#   beta1 = list(dist = "normal", params = c(-5, 1)),
#   beta2 = list(dist = "normal", params = c(-5, 1)),
#   alpha = list(dist = "normal", params = c(-4, 1)),
#   covariates = list(dist = "normal", params = c(0, 2)),
#   gen_shape = list(dist = "lognormal", params = c(1.5, 0.5)),
#   gen_rate  = list(dist = "lognormal", params = c(0.0, 0.5)),
#   ct50  = list(dist = "normal", params = c(3, 1)),
#   slope = list(dist = "lognormal", params = c(0.4, 0.5))
# )
# 
# VL_params_list <- list(
#   adult   = list(v_p = 4.14, t_p = 5.09, lambda_g = 2.31, lambda_d = 2.71),
#   infant  = list(v_p = 5.84, t_p = 4.09, lambda_g = 2.82, lambda_d = 1.01),
#   toddler = list(v_p = 5.84, t_p = 4.09, lambda_g = 2.82, lambda_d = 1.01),
#   elderly = list(v_p = 2.95, t_p = 5.10, lambda_g = 3.15, lambda_d = 0.87)
# )
# 
# stan_input <- prepare_stan_data(
#   df_clean = df_for_stan, surveillance_df = surveillance_data,
#   study_start_date = as.Date(study_start),
#   study_end_date = as.Date(study_end),
#   use_vl_data = 1, model_type = "ODE",
#   imputation_params = VL_params_list, priors = my_priors,
#   role_mixing_matrix = role_mixing_weights
# )
# 
# # ==============================================================================
# # 4. FIT MODEL
# # ==============================================================================
# 
# options(mc.cores = parallel::detectCores())
# 
# fit <- fit_household_model(
#   stan_input,
#   pars = c("log_phi_by_role_raw", "log_kappa_by_role_raw",
#            "log_beta1", "log_beta2", "log_alpha_comm",
#            "g_curve_est", "V_term_calc"),
#   include = FALSE,
#   iter = 1000, warmup = 500, chains = 1
# )
# 
# # ==============================================================================
# # 5. RESULTS
# # ==============================================================================
# 
# print(fit, probs = c(0.025, 0.5, 0.975))
# plot_posterior_distributions(fit)
# 
# chains <- reconstruct_transmission_chains(fit, stan_input,
#                                            min_prob_threshold = 0.001)
# plot_household_timeline(chains, stan_input, target_hh_id = 11)

## ----session-info-------------------------------------------------------------
# Record the R version and attached packages used to build this vignette.
utils::sessionInfo()

