Pyro Reference
February 3, 2021
- Built on a PyTorch backend
- Usually involves specifying a generative process (the model) as a Python function
- Inference usually uses either Stochastic Variational Inference (SVI), an optimization method, or Markov chain Monte Carlo (MCMC) sampling methods
For SVI
- Define a model and a guide (variational distribution)
- The guide declares the parameters to be learnt (with pyro.param); SVI optimizes them so that the guide approximates the posterior
Example model and guide code
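import torch
import pyro
import pyro.distributions as dist
from pyro import optim
from pyro.infer import SVI, Trace_ELBO
from torch.distributions import constraints

def model(is_cont_africa, ruggedness, log_gdp):
    # Prior over the intercept "a"
    a = pyro.sample("a", dist.Normal(0., 10.))
    mean = a
    # Observations are conditionally independent given "a"
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, 0.05), obs=log_gdp)

def guide(is_cont_africa, ruggedness, log_gdp):
    # Learnable variational parameters for q(a) = Normal(a_loc, a_scale)
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    # Same site name "a" as in the model, so SVI pairs the two sites
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))

svi = SVI(model,
          guide,
          optim.Adam({"lr": .05}),
          loss=Trace_ELBO())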
For MCMC
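from pyro.infer import MCMC, NUTS
# No-U-Turn Sampler (an adaptive HMC kernel) targeting the model's posterior; no guide is needed
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
# Arguments are passed through to the model
mcmc.run(is_cont_africa, ruggedness, log_gdp)
# Dict mapping each latent site name (e.g. "a") to a NumPy array of its posterior draws
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}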
For sampling the posterior predictive distribution
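from pyro.infer import Predictive
num_samples = 1000
# Draws latent values from the (trained) guide, then runs the model forward with them
predictive = Predictive(model, guide=guide, num_samples=num_samples)
# Pass None for log_gdp so the "obs" site is sampled rather than fixed to the data
samples = predictive(is_cont_africa, ruggedness, None)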
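- pyro.sample declares a random variable (a sample site) drawn from a given distribution; passing obs= conditions the site on observed data
- pyro.param declares a learnable parameter, stored in Pyro's global parameter store
- PyroModule is Pyro's counterpart of torch.nn.Module, letting module attributes be Pyro parameters or random variables
- Poutine is Pyro's effect-handling library
- poutine.trace records the execution trace of a program; the trace is a graph data structure of its sites
- poutine.replay conditions the sample sites in the model on the values sampled in the guide's trace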