
**********************************************
CELL 1: Fit the MCMC Hierarchical Bayes Model
**********************************************


import os, numpy as np, pandas as pd, jax, jax.numpy as jnp
import statsmodels.api as sm
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
import arviz as az
from pathlib import Path

# ------------------------
# paths
# ------------------------
# DF_3 is defined outside this cell; .with_name() below requires it to be a
# pathlib.Path (presumably the input CSV — confirm).
infile = DF_3
outfile = infile.with_name("hb_trace_wpg_only.nc")

# ------------------------
# env and sampler settings
# ------------------------
# NOTE(review): these lines run *after* `import numpy` / `import jax` at the
# top of this cell. BLAS thread pools typically read OMP_NUM_THREADS when
# numpy is first imported, and JAX's documentation asks for XLA_FLAGS to be
# set before jax is imported — move these two assignments above the imports
# (in a fresh kernel) if they must take effect.
# Only "--"-prefixed XLA options belong in XLA_FLAGS; this one keeps XLA's
# CPU Eigen kernels single-threaded.
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false"
os.environ["OMP_NUM_THREADS"] = "1"

CHAINS = 6            # number of chains (run with chain_method="vectorized")
WARMUP = 1500         # warmup (adaptation) iterations per chain
DRAWS  = 2500         # posterior draws kept per chain
TARGET_ACCEPT = 0.9   # NUTS target acceptance probability
SEED = 1234           # seed for jax.random.PRNGKey

# ------------------------
# load and clean
# ------------------------
df = pd.read_csv(infile)
req = ["w","games","z_wpg","mgrID"]  # columns the model and GLM need
df = df.dropna(subset=req)
# Keep rows with at least one game and a feasible win count 0 <= w <= games.
# .copy() yields an independent frame so the column assignment further down
# does not raise pandas' SettingWithCopyWarning on a filtered slice.
df = df[(df.games > 0) & (df.w >= 0) & (df.w <= df.games)].copy()

# Drop all-win / all-loss records.
# NOTE(review): the Binomial likelihood used in the model is finite at
# w == 0 and w == games, so this exclusion is a modelling choice rather than
# a numerical necessity — confirm it is intended.
mask = (df.w == 0) | (df.w == df.games)
if mask.any():
    print(f"Removing {mask.sum()} perfect records")
    df = df.loc[~mask].copy()

# Fail early with a clear message instead of inside the GLM / sampler.
if df.empty:
    raise ValueError(f"No usable rows in {infile} after cleaning")

# Integer codes 0..J-1 (sorted by mgrID); these index the manager effects.
df["mgr_idx"], mgr_codes = pd.factorize(df.mgrID, sort=True)
J = len(mgr_codes)

# Plain numpy arrays that are later converted to jax arrays for the sampler.
w = df.w.to_numpy(np.int32)
games = df.games.to_numpy(np.int32)
z_wpg = df.z_wpg.to_numpy(np.float64)
idx = df.mgr_idx.to_numpy(np.int32)

print(f"rows={len(df)}, managers={J}")

# ------------------------
# glm start values
# ------------------------
# Non-hierarchical binomial GLM whose coefficients seed the sampler.
# Adding half a win and one game keeps every proportion strictly inside
# (0, 1); var_weights weights each row's proportion by its game count.
y = (df.w + 0.5) / (df.games + 1)
X = sm.add_constant(df[["z_wpg"]])
glm = sm.GLM(
    y,
    X,
    family=sm.families.Binomial(),
    var_weights=df.games,
).fit()
# Keyed by design-matrix column name ("const" from add_constant, "z_wpg").
start = glm.params.to_dict()

print(f"GLM starting values: {start}")

def batched_init(chains, J, start, tau0=0.1):
    """Build per-chain starting values for NUTS, in unconstrained space.

    NumPyro's ``MCMC.run(init_params=...)`` interprets values as latent
    parameters in the model's *unconstrained* space.
    ``beta0``, ``beta_wpg`` and ``u_raw`` have Normal priors with real
    support, so their unconstrained value equals the constrained one.
    ``tau`` has a HalfNormal prior with positive support, so its
    unconstrained value is ``log(tau)``.

    NOTE(review): every chain starts from the same point, which weakens
    R-hat as a convergence check; consider jittering the starts.

    Args:
        chains: number of chains; each array gets a leading axis of this
            length (one starting value per chain).
        J: number of managers (length of ``u_raw`` for each chain).
        start: mapping of GLM coefficients. Keys ``"const"`` and
            ``"z_wpg"`` are read; a missing key falls back to 0.
        tau0: desired initial ``tau`` on its natural (positive) scale.

    Returns:
        dict of jax arrays keyed by latent site name.

    Raises:
        ValueError: if ``tau0`` is not strictly positive (or is NaN).
    """
    if not tau0 > 0:
        raise ValueError(f"tau0 must be > 0, got {tau0!r}")
    return {
        "beta0": jnp.full((chains,), float(start.get("const", 0))),
        "beta_wpg": jnp.full((chains,), float(start.get("z_wpg", 0))),
        # log maps the positive tau onto the real line NUTS samples on.
        "tau": jnp.full((chains,), float(np.log(tau0))),
        "u_raw": jnp.zeros((chains, J)),
    }

# One shared starting point per chain, seeded from the GLM coefficients.
init_params = batched_init(chains=CHAINS, J=J, start=start)

# ------------------------
# model - z_wpg only
# ------------------------
def model(w, games, z_wpg, idx, J):
    """Hierarchical binomial-logit model of wins with a single covariate.

    For row i:
        logit(p_i) = beta0 + beta_wpg * z_wpg[i] + u[idx[i]]
        w[i] ~ Binomial(games[i], p_i)

    Manager effects are non-centred: u = tau * u_raw with
    u_raw ~ Normal(0, 1), so the sampler works on standard-normal latents.

    Args:
        w: observed wins per row (``obs`` of the ``w_obs`` site).
        games: games per row (Binomial ``total_count``).
        z_wpg: per-row covariate multiplied by ``beta_wpg``.
        idx: per-row manager index used to gather ``u``.
        J: number of managers; size of the ``managers`` plate.
    """
    # Population-level priors.
    intercept = numpyro.sample("beta0", dist.Normal(0, 2))
    slope = numpyro.sample("beta_wpg", dist.Normal(0, 2))
    spread = numpyro.sample("tau", dist.HalfNormal(1))

    # One standard-normal latent per manager.
    with numpyro.plate("managers", J):
        raw_effect = numpyro.sample("u_raw", dist.Normal(0, 1))

    # Scaled manager effects, recorded in the trace under the name "u".
    mgr_effect = numpyro.deterministic("u", spread * raw_effect)

    logits = intercept + slope * z_wpg
    logits = logits + mgr_effect[idx]
    likelihood = dist.Binomial(total_count=games, logits=logits)
    numpyro.sample("w_obs", likelihood, obs=w)

# ------------------------
# run
# ------------------------
print("\n🚀 Running MCMC with z_wpg only...")
nuts = NUTS(model, target_accept_prob=TARGET_ACCEPT)
mcmc = MCMC(
    nuts,
    num_warmup=WARMUP,
    num_samples=DRAWS,
    num_chains=CHAINS,
    chain_method="vectorized",
    progress_bar=True,
)

rng = jax.random.PRNGKey(SEED)
# Observed data for model(); J is kept as a plain Python int because it sets
# the size of the "managers" plate.
_model_data = dict(
    w=jnp.array(w),
    games=jnp.array(games),
    z_wpg=jnp.array(z_wpg),
    idx=jnp.array(idx),
)
mcmc.run(rng, J=J, init_params=init_params, **_model_data)

mcmc.print_summary()

# ------------------------
# save results
# ------------------------
print("\n💾 Saving results...")
# Label the per-manager axis of u and u_raw with the original mgrID codes.
idata = az.from_numpyro(
    mcmc,
    coords={"manager": mgr_codes},
    dims={site: ["manager"] for site in ("u", "u_raw")},
)
az.to_netcdf(idata, outfile)

summ = az.summary(idata, var_names=["beta0", "beta_wpg", "tau"], round_to=4)
summ.to_csv(infile.with_name("hb_summary_wpg_only.csv"))

# Posterior mean and SD of each manager effect, pooled over chains and draws.
u = idata.posterior["u"]
u_mean = u.mean(dim=("chain", "draw")).values
u_sd = u.std(dim=("chain", "draw")).values

_effects = pd.DataFrame({"mgrID": mgr_codes, "u_mean": u_mean, "u_sd": u_sd})
_effects.to_csv(infile.with_name("hb_mgr_effects_wpg_only.csv"), index=False)

print("✅ done - z_wpg only model complete")









**********************************************
CELL 2: Extract the draws
**********************************************





import arviz as az
import numpy as np
import pandas as pd
from scipy.special import expit

# ========================================
# LOAD TRACE
# ========================================
# NetCDF trace written by Cell 1 ("hb_trace_wpg_only.nc").
# Replace the "[your_path]" placeholder with the folder holding that file.
TRACE_PATH = r'[your_path]\hb_trace_wpg_only.nc'
trace = az.from_netcdf(TRACE_PATH)

# Quick overview
print(trace)
print("\n" + "="*50)
print("Posterior variables:")
print(list(trace.posterior.data_vars))
print("\n" + "="*50)
print("Dimensions:")
# .sizes is xarray's documented name -> length mapping; Dataset.dims is
# being changed to return only the dimension names.
print(trace.posterior.sizes)

# ========================================
# CHECK CONVERGENCE
# ========================================
print("\n" + "=" * 50, "CONVERGENCE DIAGNOSTICS", "=" * 50, sep="\n")

# Summary table for the population-level parameters.
print(az.summary(trace, var_names=["beta0", "beta_wpg", "tau"]))

# Per-manager R-hat: report its range and how many managers exceed 1.01.
rhat_u = az.rhat(trace, var_names=["u"])
_rhat_by_mgr = rhat_u["u"]
_rhat_lo = _rhat_by_mgr.min().values
_rhat_hi = _rhat_by_mgr.max().values
print(f"\nManager effects R-hat range: {_rhat_lo:.3f} to {_rhat_hi:.3f}")
print(f"Managers with R-hat > 1.01: {(_rhat_by_mgr > 1.01).sum().values}")

# ========================================
# EXTRACT DRAWS
# ========================================
print("\n" + "=" * 50, "EXTRACTING POSTERIOR DRAWS", "=" * 50, sep="\n")

# Manager labels stored as the "manager" coordinate of the trace.
mgr_ids = trace.posterior.coords['manager'].values

# Pool chains and draws into one sample axis: rows are chain-major
# (all draws of chain 0, then chain 1, ...), columns are managers.
# NOTE(review): assumes "u" is stored with dims ordered
# (chain, draw, manager) — e.g. (6, 2500, J) with Cell 1's settings; confirm.
u_draws = trace.posterior['u'].values
u_flat = u_draws.reshape(-1, len(mgr_ids))

print(f"Manager IDs: {mgr_ids[:5]} ... {mgr_ids[-5:]}")
print("\nFirst 10 draws for manager 0:")
print(u_flat[:10, 0])

# Win probability at the reference point z_wpg = 0, where logit(p) = beta0 + u.
# beta0 is flattened in the same chain-major order as u_flat, so row k of
# both arrays belongs to the same posterior draw.
beta0_draws = trace.posterior['beta0'].values.reshape(-1)
probs = expit(beta0_draws[:, None] + u_flat)

print("\nProbabilities for manager 0 (at reference point):")
print(probs[:10, 0])

# ========================================
# CREATE LONG FORMAT DATASET
# ========================================
print("\n" + "=" * 50, "CREATING LONG FORMAT DATASET", "=" * 50, sep="\n")

# One row per (posterior draw, manager). u_flat and probs are both
# (n_draws, n_managers) and ravel row-major, so the manager index cycles
# fastest — matching np.tile for mgrID and np.repeat for draw.
n_draws, n_managers = u_flat.shape

df_long = pd.DataFrame(
    {
        'draw': np.repeat(np.arange(n_draws), n_managers),
        'mgrID': np.tile(mgr_ids, n_draws),
        'u': u_flat.ravel(),
        'prob': probs.ravel(),
    }
)

# Wins above .500 over a 162-game season at the reference point z_wpg = 0,
# i.e. (prob - 0.5) * 162.
df_long['w162'] = 162 * (df_long['prob'] - 0.5)

print(f"Dataset shape: {df_long.shape}")
print("\nFirst 10 rows:")
print(df_long.head(10))

# ========================================
# SAVE FILES
# ========================================
print("\n" + "=" * 50, "SAVING FILES", "=" * 50, sep="\n")

# Save the long-format draws as CSV; replace the "[your_path]" placeholder.
_long_csv = r'[your_path]\manager_posterior_draws_long_wpg_only.csv'
df_long.to_csv(_long_csv, index=False)

print("✓ Saved manager_posterior_draws_long_wpg_only.csv")
_n_rows, _n_cols = df_long.shape
print(f"\n✓ Complete! {_n_rows:,} rows × {_n_cols} columns")