Revisiting Multilevel Regression and Poststratification with PyMC
A little over eight years ago, I published a post entitled MRPyMC3 - Multilevel Regression and Poststratification with PyMC3, showing how to perform multilevel regression with poststratification (MRP) in Python with PyMC. I periodically enjoy revisiting old posts after both technology and my understanding of the problem advance. This post revisits MRP with a few notable changes:
- PyMC is now on major version 5 instead of version 3,
- we use nutpie for sampling from the model instead of PyMC's built-in sampler, and
- we use Polars instead of pandas.
We will not repeat the previous post's full exposition of MRP and will rather focus on the mechanics of its implementation.
First we import the necessary packages and do a bit of light configuration.
%matplotlib inline
%config InlineBackend.figure_format = "retina"
from itertools import zip_longest
import os
from urllib import request
import us
from zipfile import ZipFile
import arviz as az
from matplotlib import cm, pyplot as plt, ticker
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Normalize
import numpy as np
import nutpie
import polars as pl
import pymc as pm
from pyreadstat import read_dta
from pytensor import tensor as pt
import seaborn as sns
from seaborn import objects as so
from scipy.special import logit
# Use seaborn's "darkgrid" theme, overriding it so every axes gets a thin
# black border.
sns.set_style("darkgrid", {"axes.linewidth": 1, "axes.edgecolor": "black"})
Load and transform the data¶
As in the previous post, we follow Jonathan Kastellec's excellent MRP Primer, which focuses on estimating state-level opinions of gay marriage in 2005/2006 from polling data.
First we download and decompress the data.
# Local directory holding the downloaded archive and its extracted contents.
DATA_PATH = "./data"
DATA_URI = "https://jkastellec.scholar.princeton.edu/sites/g/files/toruqf3871/files/jkastellec/files/mrp_primer_replication_files.zip"
# Sent as the User-agent header of the download request in place of urllib's
# default.
USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36"

# exist_ok=True replaces the separate isdir() check (no check-then-create
# race) and also creates any missing parent directories.
os.makedirs(DATA_PATH, exist_ok=True)

dest_path = os.path.join(DATA_PATH, os.path.basename(DATA_URI))

# Only hit the network when the archive is not already cached locally.
if not os.path.exists(dest_path):
    opener = request.build_opener()
    opener.addheaders = [("User-agent", USER_AGENT)]
    request.install_opener(opener)

    # Download under a temporary name and rename on success, so that an
    # interrupted download is never mistaken for a complete archive on the
    # next run.
    tmp_path = f"{dest_path}.part"
    request.urlretrieve(DATA_URI, tmp_path)
    os.replace(tmp_path, dest_path)

# Extraction simply overwrites any previously extracted files.
with ZipFile(dest_path) as src:
    src.extractall(DATA_PATH)
Poll data¶
Next we load and do some light feature engineering on the polling data necessary to build the multilevel model.
# Name of the directory, under DATA_PATH, that holds the replication files.
UNZIPPED_DIR = "MRP_Primer_Replication_Files"
POLL_PATH = os.path.join(DATA_PATH, UNZIPPED_DIR, "gay_marriage_megapoll.dta")
# Columns kept from the poll file. "yes_of_all" must stay last: the
# aggregation of the poll data groups by POLL_COLS[:-1].
POLL_COLS = [
    "race_wbh",
    "age_cat",
    "edu_cat",
    "female",
    "region",
    "state",
    "poll",
    "yes_of_all",
]
# Poll rows with a null in any of these columns are dropped.
NN_COLS = ["race_wbh", "age_cat", "edu_cat"]
def to_zero_indexed(name):
    """Return an expression that shifts column *name* so its minimum is zero."""
    return pl.col(name) - pl.col(name).min()
# Category-code columns that are shifted so their smallest code becomes 0.
CAT_COLS = ["age_cat", "edu_cat", "race_wbh"]
cat_col_transforms = [to_zero_indexed(name) for name in CAT_COLS]
# Enum dtypes for the categorical columns.
GENDER = pl.Enum(["Male", "Female"])
POLL = pl.Enum(
    [
        "ABC 2004Jan15",
        "Gall2004Mar05",
        "Gall2005Aug22",
        "Pew 2004Dec01",
        "Pew 2004Feb11",
    ]
)
RACE = pl.Enum(["White", "Black", "Hispanic"])
REGION = pl.Enum(["dc", "midwest", "northeast", "south", "west"])
# State abbreviations from the `us` package plus "DC", in sorted order.
STATE = pl.Enum(sorted([state.abbr for state in us.states.STATES] + ["DC"]))
# Maps a source column name to (enum dtype, new column name); a new name of
# None means the column keeps its original name.
ENUM_CASTS = {
    "female": (GENDER, "gender"),
    "poll": (POLL, None),
    "race_wbh": (RACE, "race"),
    "region": (REGION, None),
    "state": (STATE, None),
}
def cast_enum_cols(df):
    """Cast the columns listed in ENUM_CASTS to their enum dtypes.

    Columns named in ENUM_CASTS but absent from *df* are skipped. When an
    entry specifies a new name, the cast column is added under that name and
    the source column is dropped; otherwise the column is cast in place.
    Returns the transformed frame.
    """
    for src_name, (dtype, dst_name) in ENUM_CASTS.items():
        if src_name not in df.columns:
            continue

        cast_expr = pl.col(src_name).cast(dtype).alias(dst_name or src_name)
        df = df.with_columns(cast_expr)

        if dst_name is not None:
            df = df.drop(src_name)

    return df
poll_pdf, _ = read_dta(POLL_PATH)
# Every selected column except the trailing "yes_of_all" identifies a group.
poll_group_keys = POLL_COLS[:-1]

poll_cells_df = (
    pl.from_pandas(poll_pdf)
    .select(POLL_COLS)
    .drop_nulls(NN_COLS)
    .with_columns(*cat_col_transforms)
    .group_by(poll_group_keys)
    # One row per group: supporters and total respondents.
    .agg(pl.col("yes_of_all").sum(), poll_pop=pl.len().cast(pl.Int32))
    # Discard groups whose state is the empty string.
    .filter(pl.col("state") != "")
)
poll_df = cast_enum_cols(poll_cells_df)

poll_df
Each row represents a group of people polled:
- age_cat represents the age category of that group,
- edu_cat represents the education category of that group,
- region represents the region of the state in which that group resides,
- state represents the state in which that group resides,
- poll represents which poll surveyed that group,
- yes_of_all represents the number of respondents that indicated support for gay marriage,
- poll_pop represents the number of respondents,
- gender represents the gender of that group, and
- race represents the race of that group.
Census data¶
Next we load and do some light feature engineering on the census data necessary for poststratification.
CENSUS_PATH = os.path.join(DATA_PATH, UNZIPPED_DIR, "poststratification 2000.dta")
CENSUS_COLS = [
    "race_wbh",
    "age_cat",
    "edu_cat",
    "female",
    "state",
    "_freq",
    "region",
]


def _normalize_census_name(raw_name):
    # NOTE: lstrip("c") removes *every* leading "c", not just a single prefix
    # character, before the name is lower-cased.
    return raw_name.lstrip("c").lower()


census_pdf, _ = read_dta(CENSUS_PATH)

census_cells_df = (
    pl.from_pandas(census_pdf)
    .rename(_normalize_census_name)
    .select(CENSUS_COLS)
    .rename({"_freq": "pop"})
    .with_columns(*cat_col_transforms)
)
census_df = cast_enum_cols(census_cells_df).sort("state")

census_df
Each row represents a group of people with the given age, education, state, gender, and race combination. All columns have the same meaning as in the poll data. The pop column contains the census population of that group.
State data¶
Finally, we load and do some light feature engineering on data about each state, which we will use, in addition to the poll data, to build our multilevel model.
STATE_PATH = os.path.join(DATA_PATH, UNZIPPED_DIR, "state_level_update.dta")
# Columns kept from the state-level file; "sstate" is renamed to "state"
# when the file is loaded.
STATE_COLS = ["sstate", "p_evang", "p_mormon", "kerry_04"]
def to_percentage(name):
    """Return an expression dividing column *name* by 100."""
    column = pl.col(name)
    return column / 100
state_pdf, _ = read_dta(STATE_PATH)
# Region of each state, taken from the first census row for that state.
state_region_df = census_df.group_by("state").agg(pl.col("region").first())

state_features_df = (
    pl.from_pandas(state_pdf)
    .select(STATE_COLS)
    # Of the selected columns only "sstate" contains "ss".
    .rename({"sstate": "state"})
    .with_columns(
        *(to_percentage(name) for name in ("p_evang", "p_mormon", "kerry_04"))
    )
    .with_columns(p_relig=pl.col("p_evang") + pl.col("p_mormon"))
    .drop("p_evang", "p_mormon")
)
state_df = (
    cast_enum_cols(state_features_df)
    .join(state_region_df, on="state")
    .sort("state")
)

state_df
There is a row for each state, plus an additional one for the District of Columbia, with the following columns:
- state - the state represented by that row,
- kerry_04 - the proportion of the state's voters that voted for John Kerry in the 2004 presidential election,
- p_relig - the proportion of the state's population that identifies as evangelical Christian or Mormon, and
- region - the region of the country the state is in.
Exploratory data analysis¶
Now that we have all the necessary data, we visualize it in order to become familiar with it before building the model.
First we define a few functions that will facilitate intuitive plotting of state-level statistics.
# Rough geographic layout of the states (plus DC) on a grid of subplots.
# Each entry is a state abbreviation or None for an empty cell; rows shorter
# than the widest row are treated as padded with empty cells on the right.
# Every state and DC must appear exactly once.
STATE_GRID = [
    ["AK", None, None, None, None, None, None, None, None, None, "ME"],
    [None, None, None, None, None, None, None, None, None, "VT", "NH"],
    ["WA", "ID", "MT", "ND", "MN", None, "MI", None, "NY", "MA", "RI"],
    # South Dakota sits directly below North Dakota (this cell was previously
    # a duplicated "ND", which left SD off the map entirely).
    ["OR", "UT", "WY", "SD", "IA", "WI", "OH", "PA", "NJ", "CT"],
    ["CA", "NV", "CO", "NE", "IL", "IN", "WV", "VA", "MD", "DE"],
    [None, "AZ", "NM", "KS", "MO", "KY", "TN", "SC", "NC"],
    [None, None, None, "OK", "LA", "AR", "MS", "AL", "GA", None, "DC"],
    ["HI", None, None, None, "TX", None, None, None, "FL"],
]
def plot_facecolor(df, col, state, *, ax, norm, cmap, default="gray"):
    """Set the background of *ax* from the value of *col* for *state*.

    The value is passed through *norm* and then *cmap*. If *df* has no row
    for *state*, the *default* color is used instead.
    """
    state_rows = df.filter(pl.col("state") == state)
    has_state = state_rows.height > 0
    color = cmap(norm(state_rows[col])) if has_state else default

    ax.set_facecolor(color)
def plot_state_grid(
    data, *, col, norm, cmap, default="gray", cbar=True, ax_plotter=plot_facecolor
):
    """Draw one small axes per state, arranged according to STATE_GRID.

    *ax_plotter* is called for every state cell as
    ``ax_plotter(data, col, state, ax=..., norm=..., cmap=..., default=...)``
    and is responsible for coloring that cell. When *cbar* is true, a
    colorbar for *cmap*/*norm* is added on the right of the figure.

    Returns ``(fig, axes, cbar_ax)``; ``cbar_ax`` is None when *cbar* is false.
    """
    n_rows = len(STATE_GRID)
    n_cols = max(map(len, STATE_GRID))
    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)

    for grid_row, axes_row in zip(STATE_GRID, axes):
        # zip_longest pads short grid rows with None, so the axes beyond the
        # end of a row are switched off like any other empty cell.
        for abbr, cell_ax in zip_longest(grid_row, axes_row):
            if abbr is None:
                cell_ax.axis("off")
                continue

            cell_ax.set_xticks([])
            cell_ax.set_yticks([])
            cell_ax.annotate(abbr, xy=(0.1, 0.1), xycoords="axes fraction")
            ax_plotter(
                data, col, abbr, ax=cell_ax, norm=norm, cmap=cmap, default=default
            )

    cbar_ax = None
    if cbar:
        cbar_ax = fig.add_axes([0.925, 0.25, 0.02, 0.5])
        ColorbarBase(cbar_ax, cmap=cmap, norm=norm)

    return fig, axes, cbar_ax
Poll data¶
With the data and these tools in hand, we begin by plotting the number of poll respondents in each state.
# Total number of respondents in each state.
responses_by_state = poll_df.group_by("state").agg(pl.col("poll_pop").sum())

fig, _, _ = plot_state_grid(
    responses_by_state,
    col="poll_pop",
    norm=Normalize(vmin=0, vmax=750),
    cmap=cm.binary,
    default="white",
)
fig.suptitle("Number of responses");
We see immediately that some states have very few respondents, and indeed that Alaska and Hawaii had no respondents.
# Respondent totals per state. The right join onto state_df keeps states with
# no poll rows at all; their null totals are filled with zero before sorting,
# so the least-polled states come first.
(
    poll_df.group_by("state")
    .agg(pl.sum("poll_pop"))
    .join(state_df, how="right", on="state")
    .with_columns(pl.col("poll_pop").fill_null(0))
    .sort("poll_pop")
    .select("state", "poll_pop")
    .head()
)
One of the benefits of MRP is the ability to use census data to predict the support for gay marriage in these states with little to no data based on similar respondents in other states.
Next we visualize the empirical support for gay marriage in each state.
# Reversed diverging palette used for the probability maps.
PROB_CMAP = sns.diverging_palette(220, 10, as_cmap=True).reversed()
# Ticks are formatted as percentages with one decimal and placed every 0.2.
PROB_FORMATTER = ticker.StrMethodFormatter("{x:.1%}")
PROB_LOCATOR = ticker.MultipleLocator(0.2)
PROB_MIN, PROB_MAX = 0.1, 0.70
# NOTE(review): PROB_MIN is not used by any code shown here; the color scale
# starts at 0 rather than PROB_MIN. Confirm whether
# Normalize(PROB_MIN, PROB_MAX) was intended.
PROB_NORM = Normalize(0, PROB_MAX)
def make_prob_ax(ax):
    """Format the y-axis of *ax* as a probability scale.

    Applies PROB_LOCATOR and PROB_FORMATTER to the y-axis of the axes that is
    passed in. Previously the parameter was ignored and the global ``cbar_ax``
    was formatted instead, which only worked because every caller happened to
    pass that same global.
    """
    ax.yaxis.set_major_locator(PROB_LOCATOR)
    ax.yaxis.set_major_formatter(PROB_FORMATTER)
def rate(num, denom):
    """Return an expression for sum(*num*) / sum(*denom*)."""
    numerator = pl.col(num).sum()
    denominator = pl.col(denom).sum()
    return numerator / denominator
# Share of each state's respondents that support gay marriage.
disagg_df = poll_df.group_by("state").agg(rate("yes_of_all", "poll_pop"))

disagg_plot_kwargs = {"col": "yes_of_all", "norm": PROB_NORM, "cmap": PROB_CMAP}
fig, _, cbar_ax = plot_state_grid(disagg_df, **disagg_plot_kwargs)

make_prob_ax(cbar_ax)
fig.suptitle("Disaggregation estimate of\nsupport for gay marriage");
Note that the MRP literature, including the Kastellec primer we are following in this post, refers to this quantity as the "disaggregation estimate." We will use that language throughout the rest of this post for consistency. At the end of the post we will see how the MRP estimates of support for gay marriage differ from these disaggregation estimates. The difference between these estimates is largely due to the fact that the polled population is not (and cannot practically be) precisely reflective of the population of each state.
This challenge of representative sampling is further illustrated in the following plot.
shared_cols = set(poll_df.columns) & set(census_df.columns)

# One flag per (state, gender, race): True when that census subgroup has no
# poll respondents (the right join keeps census cells absent from the poll).
unpolled_cells_df = (
    poll_df.join(census_df, on=shared_cols, how="right")
    .group_by("state", "gender", "race")
    .agg(pl.col("poll_pop").fill_null(0).sum() == 0)
)
# Summing the boolean flags counts the states lacking each subgroup.
unpolled_state_ct_df = unpolled_cells_df.group_by("gender", "race").agg(
    pl.sum("poll_pop").alias("state_ct")
)

(
    so.Plot(unpolled_state_ct_df, x="race", y="state_ct", color="gender")
    .add(so.Bar(), so.Dodge())
    .scale(y=so.Continuous().tick(every=5))
    .label(
        x="Race",
        y="Number of states",
        color="Gender",
        title="States with no respondents",
    )
)
While only two states have zero polled respondents, significantly more states have zero respondents from minority race/gender subpopulations. As in states with zero respondents, a benefit of MRP is the ability to predict support for gay marriage among unsampled subpopulations in these states based on attitudes from similar respondents in other states.
Modeling¶
With this basic understanding, we now turn to building the multilevel model necessary for MRP.
First we define coordinates that will make our model easier to work with.
n_age = poll_df["age_cat"].n_unique()
n_edu = poll_df["edu_cat"].n_unique()

# Age and education are labelled by their integer codes ...
COORDS = {"age": np.arange(n_age), "edu": np.arange(n_edu)}
# ... while the remaining dimensions are labelled by their enum categories.
COORDS.update(
    {
        dim: enum.categories.to_numpy()
        for dim, enum in (
            ("gender", GENDER),
            ("poll", POLL),
            ("race", RACE),
            ("region", REGION),
            ("state", STATE),
        )
    }
)
The model includes an intercept with a normal prior,
$$\beta_0 \sim N(0, 2.5^2).$$
model = pm.Model(coords=COORDS)

with model:
    # Intercept with a N(0, 2.5^2) prior.
    β0 = pm.Normal("β0", mu=0, sigma=2.5)
Most of the terms of the model will be hierarchical normal. For example, the prior for the age effect is
$$ \begin{align} \sigma_{\text{age}} & \sim \text{Half-}N(2.5^2), \\ \beta_{\text{age}} & \sim N(0, \sigma_{\text{age}}^2). \end{align} $$
For sampling efficiency, we actually use an equivalent noncentered parameterization of the above priors. The priors for the effects of education, poll, age-education interaction, and gender-race interaction are defined similarly.
# A half-normal with scale s has standard deviation s * sqrt(1 - 2 / pi);
# multiplying a target standard deviation by this factor gives the scale
# parameter that achieves it.
HALFNORMAL_SCALE = 1.0 / np.sqrt(1.0 - 2.0 / np.pi)


def noncentered_normal(name, *, dims, μ=0):
    """Add a hierarchical normal effect in noncentered form to the model.

    Creates a standard normal offset ``Δ_<name>`` with the given *dims* and a
    scalar half-normal scale ``σ_<name>``, and returns the deterministic
    ``μ + Δ * σ`` registered under *name*. Must be called inside a model
    context.
    """
    offset = pm.Normal(f"Δ_{name}", 0, 1, dims=dims)
    scale = pm.HalfNormal(f"σ_{name}", 2.5 * HALFNORMAL_SCALE)
    effect = μ + offset * scale
    return pm.Deterministic(name, effect, dims=dims)
with model:
    # Main effects, one per single-dimension grouping.
    β_age, β_edu, β_poll = (
        noncentered_normal(f"β_{dim}", dims=dim) for dim in ("age", "edu", "poll")
    )
    # Two-way interaction effects.
    β_age_edu = noncentered_normal("β_age_edu", dims=("age", "edu"))
    β_gender_race = noncentered_normal("β_gender_race", dims=("gender", "race"))
The state-level effects are slightly more complex, combining a hierarchical normal region-level effect with the effect of the proportion of its residents that identify as religious and that of the state's support for John Kerry in 2004. We have
$$ \begin{align} \sigma_{\text{region}} & \sim \text{Half-}N(2.5^2), \\ \alpha_{\text{region}} & \sim N(0, \sigma_{\text{region}}^2), \\ \alpha_{\text{relig}}, \alpha_{\text{kerry}} & \sim N(0, 2.5^2), \\ \beta_{\text{state}} & = \alpha_{\text{region}} + \alpha_{\text{relig}} \cdot \operatorname{logit}(x_{\text{relig}}) + \alpha_{\text{kerry}} \cdot \operatorname{logit}(x_{\text{kerry}}). \end{align} $$
Here $x_{\text{relig}}$ is the proportion of the state's residents that identify as religious and $x_{\text{kerry}}$ is the proportion that voted for John Kerry in 2004.
kerry_04 = state_df.get_column("kerry_04").to_numpy()
p_relig = state_df.get_column("p_relig").to_numpy()
# NOTE(review): every other enum column is converted with .to_physical()
# directly; confirm that the codes produced after casting to Categorical
# follow the same order as the "region" coordinate.
state_region = state_df["region"].cast(pl.Categorical).to_physical().to_numpy()

with model:
    α_region = noncentered_normal("α_region", dims="region")
    α_relig = pm.Normal("α_relig", 0, 2.5)
    α_kerry = pm.Normal("α_kerry", 0, 2.5)

    # Both state-level proportions enter the model on the logit scale.
    region_term = α_region[state_region]
    relig_term = α_relig * logit(p_relig)
    kerry_term = α_kerry * logit(kerry_04)
    β_state = region_term + relig_term + kerry_term
Before using these effects to define the likelihood of the observed polling data, we build a number of data containers that will facilitate the posterior predictive sampling that will be necessary to perform poststratification.
# Integer codes of the poll's enum columns, used to index the effect vectors.
poll_enum_codes = {
    name: poll_df[name].to_physical().to_numpy().astype(np.int_)
    for name in ("gender", "poll", "race")
}

with model:
    age = pm.Data("age", poll_df["age_cat"].to_numpy())
    edu = pm.Data("edu", poll_df["edu_cat"].to_numpy())
    gender = pm.Data("gender", poll_enum_codes["gender"])
    poll = pm.Data("poll", poll_enum_codes["poll"])
    pop = pm.Data("pop", poll_df["poll_pop"].to_numpy())
    race = pm.Data("race", poll_enum_codes["race"])
    state = pm.Data("state", poll_df["state"].to_physical().to_numpy())
    yes_of_all = pm.Data("yes_of_all", poll_df["yes_of_all"].to_numpy())
    # Switch that turns the poll effect on or off in the linear predictor.
    use_poll = pm.Data("use_poll", True)
Finally we define
$$\eta = \beta_0 + \beta_{\text{age}} + \beta_{\text{edu}} + \beta_{\text{state}} + \beta_{\text{age, edu}} + \beta_{\text{gender, race}} + \beta_{\text{poll}}.$$
Note that to support poststratification, we include the ability to turn off the effect of the poll on the model.
The likelihood is binomial with log-odds $\eta$.
def adv_index(tensor, indices, shape):
    """Index *tensor* at the multi-dimensional *indices*.

    The index arrays in *indices* are converted to flat positions for an
    array of the given *shape*, and those positions are looked up in the
    flattened tensor.
    """
    flat_positions = pt.ravel_multi_index(indices, shape)
    flattened = tensor.ravel()
    return flattened[flat_positions]
n_gender = len(GENDER.categories)
n_race = len(RACE.categories)

with model:
    age_edu_term = adv_index(β_age_edu, (age, edu), (n_age, n_edu))
    gender_race_term = adv_index(β_gender_race, (gender, race), (n_gender, n_race))
    # The poll effect contributes only while the use_poll switch is on.
    poll_term = pt.switch(use_poll, β_poll[poll], 0)

    linear_terms = [
        β0,
        β_age[age],
        β_edu[edu],
        β_state[state],
        age_edu_term,
        gender_race_term,
        poll_term,
    ]
    η = sum(linear_terms)

    # Binomial likelihood whose success probability is sigmoid(η).
    pm.Binomial("response", n=pop, p=pt.sigmoid(η), observed=yes_of_all)
We now sample from the posterior distribution of this model.
SEED = 123456789  # for reproducibility
CHAINS = 8

# One core per chain.
SAMPLER_KWARGS = dict(seed=SEED, chains=CHAINS, cores=CHAINS, target_accept=0.99)

compiled_model = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled_model, **SAMPLER_KWARGS)
The Gelman-Rubin statistic, $\hat{R}$, shows no cause for concern.
# Largest R-hat over every variable (and every entry of each variable).
az.rhat(trace).max().to_array().max()
Poststratification¶
We now sample from the model's posterior predictive distribution in order to perform poststratification. To do so, we set the data containers we created above to the values from the census. This step is what allows us to adjust for the non-representativeness of the poll's respondents.
n_census = census_df.height

# Replace the poll observations with one entry per census cell. The poll
# index and the observed counts are placeholders: the poll effect is
# switched off via use_poll.
census_data = {
    "age": census_df["age_cat"].to_numpy(),
    "edu": census_df["edu_cat"].to_numpy(),
    "gender": census_df["gender"].to_physical().to_numpy().astype(np.int_),
    "poll": np.zeros(n_census, dtype=np.int_),
    "pop": census_df["pop"].to_numpy(),
    "race": census_df["race"].to_physical().to_numpy().astype(np.int_),
    "state": census_df["state"].to_physical().to_numpy(),
    "yes_of_all": np.zeros(n_census, dtype=np.int_),
    "use_poll": False,
}

with model:
    pm.set_data(census_data)
    pm.sample_posterior_predictive(trace, extend_inferencedata=True)
We use these posterior predictive samples to construct the poststratification estimates of each state's support for gay marriage.
# Posterior predictive mean number of supporters in each census cell.
pp_mean_yes = (
    trace.posterior_predictive["response"].mean(dim=("chain", "draw")).to_numpy()
)

# Per state: predicted supporters divided by census population.
poststrat_df = (
    census_df.with_columns(pp_yes_of_all=pp_mean_yes)
    .group_by("state")
    .agg(rate("pp_yes_of_all", "pop").alias("yes_of_all"))
)
Now compare the disaggregation and MRP estimates of support for gay marriage. The disaggregation estimate colors the upper left corner of each state's rectangle, and the MRP estimate colors the lower right corner.
def plot_split_color(dfs, col, state, *, ax, norm, cmap, default="gray"):
    """Fill *ax* with two triangles colored from two data frames.

    *dfs* is a pair ``(ul_df, lr_df)``. The upper-left triangle is colored by
    the value of *col* for *state* in ``ul_df`` and the lower-right triangle
    by the value in ``lr_df``; each value is passed through *norm* and then
    *cmap*. A frame with no row for *state* gets the *default* color.
    """
    ul_df, lr_df = dfs
    is_state = pl.col("state") == state

    # Span the triangles over the axes' actual limits instead of assuming the
    # lower-left corner is at (0, 0).
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()

    def state_color(df):
        # Colormapped value for this state, or the default when it is absent.
        if df.select(is_state.any()).item():
            return cmap(norm(df.filter(is_state)[col]))
        return default

    # Both triangles share the diagonal from the lower-left to the
    # upper-right corner.
    upper_left = [[xmin, ymin], [xmax, ymax], [xmin, ymax]]
    lower_right = [[xmin, ymin], [xmax, ymax], [xmax, ymin]]

    ax.add_patch(plt.Polygon(upper_left, color=state_color(ul_df)))
    ax.add_patch(plt.Polygon(lower_right, color=state_color(lr_df)))
# Upper-left triangles use the disaggregation estimates (the same per-state
# aggregation as disagg_df); lower-right triangles use the MRP estimates.
estimate_dfs = (disagg_df, poststrat_df)

fig, _, cbar_ax = plot_state_grid(
    estimate_dfs,
    col="yes_of_all",
    norm=PROB_NORM,
    cmap=PROB_CMAP,
    ax_plotter=plot_split_color,
)

make_prob_ax(cbar_ax)
fig.suptitle("Disaggregation and MRP estimates\nof support for gay marriage");
We see that most states' support has moved towards the average with the MRP estimate. Also, MRP produces estimates for Alaska and Hawaii based on their census data, whereas disaggregation cannot.
Finally, we produce another comparison of the two estimates state-by-state.
# Total census population of each state, sorted by state.
state_pop = census_df.group_by("state").agg(pl.sum("pop")).sort("state")
state_pop_ = state_pop["pop"].to_numpy().squeeze()
# For every posterior predictive draw, sum the predicted supporters over the
# census cells of each state and divide by the state's population.
# NOTE(review): the division is positional, so it assumes the grouped result
# lists the states in the same order as state_pop — confirm.
poststrat_state_df = (
    trace.posterior_predictive["response"]
    .rename(response_dim_2="state")
    .assign_coords(state=census_df["state"])
    .groupby("state")
    .sum()
    .pipe(lambda x: x / state_pop_)
)
# States ordered by increasing posterior mean support.
sorted_states = STATE.categories[
    poststrat_state_df.mean(dim=("chain", "draw")).argsort().to_numpy()
]
# Forest plot of the per-state estimates, passing the states in reversed
# (decreasing) order.
(ax,) = az.plot_forest(
    poststrat_state_df,
    coords={"state": sorted_states[::-1]},
    combined=True,
    labeller=az.labels.NoVarLabeller(),
    figsize=(6, 12),
)
# Attach a y position to each state's disaggregation estimate so it can be
# drawn on the forest plot; the inner join drops states that have no
# disaggregation estimate.
# NOTE(review): assumes ax.get_yticks() returns one tick per state, with the
# i-th tick belonging to sorted_states[i] — confirm.
disagg_scatter_df = (
    disagg_df.join(
        pl.DataFrame(
            {
                "state": sorted_states.cast(STATE),
                "sort_key": np.arange(sorted_states.shape[0]),
                "y": ax.get_yticks(),
            }
        ),
        on="state",
    )
    .join(state_pop, on="state")
    .sort("sort_key")
)
# Marker area scales linearly with state population (largest state -> 200).
pop_norm = Normalize(0, disagg_scatter_df["pop"].max())
ax.scatter(
    disagg_scatter_df["yes_of_all"],
    disagg_scatter_df["y"],
    s=200 * pop_norm(disagg_scatter_df["pop"]),
    c="k",
    label="Disaggregation\nestimate",
    zorder=5,
)
# Vertical line at the population-weighted mean of the per-state posterior
# mean estimates.
ax.axvline(
    (poststrat_state_df.mean(dim=("chain", "draw")) * state_pop_).sum()
    / state_pop_.sum(),
    ls="--",
    c="k",
    label="National average",
)
ax.xaxis.set_major_formatter(PROB_FORMATTER)
ax.set_xlabel("MRP estimate of support for gay marriage")
ax.legend(loc="upper left");
Here the size of each black circle marking a disaggregation estimate corresponds to the state's population. We see, by and large, that states with higher population have similar disaggregation and MRP estimates, and states with smaller population have MRP estimates closer to the national (disaggregation) mean than their disaggregation estimates. Notable exceptions to this trend are New York, Florida, and Pennsylvania, among others. These states' deviation from the trend indicates that the respondent population was particularly unrepresentative of the state's demographics.
This post is available as a Jupyter notebook here.
%load_ext watermark
%watermark -n -u -v -iv