Pyro Reference

February 3, 2021

  • Built on a PyTorch backend
  • Usually involves specifying a generative process (the model) as a Python function
  • Inference usually uses either Stochastic Variational Inference (SVI), an optimization method, or Markov Chain Monte Carlo (MCMC) sampling methods

For SVI

  • Define a model and a guide (the variational distribution)
  • The guide holds the learnable variational parameters (declared with pyro.param); SVI adjusts them to maximize the ELBO

Example model and guide code

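import torch
import pyro
import pyro.distributions as dist
import pyro.optim as optim
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO

def model(is_cont_africa, ruggedness, log_gdp):
    # simplified: intercept only; is_cont_africa is unused and
    # ruggedness only sets the plate size
    a = pyro.sample("a", dist.Normal(0., 10.))
    mean = a
    # observations are conditionally independent given "a"
    with pyro.plate("data", len(ruggedness)):
        pyro.sample("obs", dist.Normal(mean, 0.05), obs=log_gdp)

def guide(is_cont_africa, ruggedness, log_gdp):
    # variational parameters of q(a) = Normal(a_loc, a_scale)
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    # the site name must match the model's latent site "a"
    pyro.sample("a", dist.Normal(a_loc, a_scale))

svi = SVI(model,
          guide,
          optim.Adam({"lr": .05}),
          loss=Trace_ELBO())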

For MCMC

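from pyro.infer import MCMC, NUTS

# NUTS: the No-U-Turn Sampler, an adaptive form of Hamiltonian Monte Carlo (HMC)
nuts_kernel = NUTS(model)

mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(is_cont_africa, ruggedness, log_gdp)

# get_samples() returns a dict mapping site names to tensors of draws
hmc_samples = {k: v.detach().cpu().numpy() for k, v in mcmc.get_samples().items()}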

For sampling the posterior predictive distribution

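from pyro.infer import Predictive

num_samples = 1000
# draws latent values from the guide, then runs the model forward with them
predictive = Predictive(model, guide=guide, num_samples=num_samples)
# pass log_gdp=None so that "obs" is sampled instead of fixed to the data
samples = predictive(is_cont_africa, ruggedness, None)
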
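  • pyro.sample declares a named random variable (a sample site) drawn from a distribution; passing obs= conditions it on observed data

  • pyro.param declares a named learnable parameter, kept in Pyro's global parameter store

  • PyroModule is Pyro's counterpart to torch.nn.Module (it subclasses it) and lets module attributes be Pyro parameters or random variables

  • Poutine is Pyro's effect-handling library

    • trace records the execution trace of a model or guide as a graph data structure whose nodes are its sample and param sites (names, values, distributions)
    • replay makes the model's sample sites reuse the values recorded in a given trace, e.g. values sampled from the guide, instead of drawing new ones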