
Fun with the proto package: building an MCMC sampler for Bayesian regression

The proto package is my latest favourite R goodie. It brings prototype-based programming to the R language - a style of programming that lets you do many of the things you can do with classes, but with a lot less up-front work. Louis Kates and Thomas Petzoldt provide an excellent introduction to using proto in the package vignette.
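To give a feel for the style, here is a minimal sketch of a trait and two child objects. The `Counter` trait and its `increment` method are made up for illustration; only `proto()` and the `$` operator come from the package.

```r
library(proto)

# A hypothetical trait: it holds shared behaviour but no data of its own
Counter <- proto()

Counter$new <- function(., start = 0) {
  # Create a child object that delegates to this trait
  proto(., count = start)
}

Counter$increment <- function(., by = 1) {
  # Assigning through . modifies the calling object in place
  .$count <- .$count + by
  invisible(.)
}

a <- Counter$new()
b <- Counter$new(10)
a$increment()
b$increment(5)

a$count   # 1
b$count   # 15
```

Each child keeps its own `count`, while `increment` is looked up in the parent, so there is no class definition to write before creating objects.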

As a learning exercise I concocted the example below involving Bayesian logistic regression. It was inspired by an article on Matt Shotwell's blog about using R environments rather than lists to store the state of a Markov Chain Monte Carlo sampler. Here I use proto to create a parent class-like object (or trait in proto-ese) to contain the regression functions and create child objects to hold both data and results for individual analyses.
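The difference between the two storage options can be sketched in base R. The function and variable names here are invented for illustration: a list is copied when a function modifies it, whereas an environment is modified in place, which is what makes environments (and proto objects, which are built on them) convenient for sampler state.

```r
# A list is copied on modification, so the caller's copy is unchanged
# unless the result is assigned back
bump_list <- function(state) {
  state$n <- state$n + 1
  state
}

# An environment is modified in place, so nothing needs to be returned
bump_env <- function(state) {
  state$n <- state$n + 1
  invisible(NULL)
}

s_list <- list(n = 0)
bump_list(s_list)      # result discarded
s_list$n               # still 0

s_env <- new.env()
s_env$n <- 0
bump_env(s_env)
s_env$n                # now 1
```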

First here's an example session...

# Make up some data with a continuous predictor and binary response
nrec <- 500
x <- rnorm(nrec)
y <- rbinom(nrec, 1, plogis(2 - 4*x))

# Predictor matrix with a col of 1s for intercept
pred <- matrix(c(rep(1, nrec), x), ncol=2)
colnames(pred) <- c("intercept", "X")

# Load the proto package
library(proto)

# Use the Logistic parent object to create a child object which will
# hold the data and run the regression (the $ operator references
# functions and data within a proto object)
lr <- Logistic$new(pred, y)
lr$run(5000, 1000)

# lr now contains both data and results
str(lr)

proto object
 $ cov      : num [1:2, 1:2] 0.05 -0.0667 -0.0667 0.1621
  ..- attr(*, "dimnames")=List of 2
 $ prior.cov: num [1:2, 1:2] 100 0 0 100
 $ prior.mu : num [1:2] 0 0
 $ beta     : num [1:5000, 1:2] 2.09 2.09 2.09 2.21 2.21 ...
  ..- attr(*, "dimnames")=List of 2
 $ adapt    : num 1000
 $ y        : num [1:500] 0 1 1 1 1 1 1 1 1 1 ...
 $ x        : num [1:500, 1:2] 1 1 1 1 1 1 1 1 1 1 ...
  ..- attr(*, "dimnames")=List of 2
 parent: proto object

# Use the Logistic summary function to tabulate and plot results
lr$summary()

From 5000 samples after 1000 iterations burning in
   intercept           X
 Min.   :1.420   Min.   :-5.296
 1st Qu.:1.840   1st Qu.:-3.915
 Median :2.000   Median :-3.668
 Mean   :1.994   Mean   :-3.693
 3rd Qu.:2.128   3rd Qu.:-3.455
 Max.   :2.744   Max.   :-2.437
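As a rough sanity check, the posterior summaries can be compared with an ordinary maximum-likelihood fit. With priors as vague as the ones used here, the two should usually land close together. This sketch simulates its own data in the same way as the session above; the `set.seed` call is an addition for reproducibility, so the numbers will not match the output shown.

```r
# Simulate data as in the example session (the seed is an assumption
# added here so the snippet is reproducible)
set.seed(42)
nrec <- 500
x <- rnorm(nrec)
y <- rbinom(nrec, 1, plogis(2 - 4*x))

# Maximum-likelihood logistic regression for comparison
fit <- glm(y ~ x, family = binomial)

coef(fit)             # point estimates for intercept and slope
confint.default(fit)  # Wald 95% intervals
```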



And here's the code for the Logistic trait...

Logistic <- proto()

Logistic$new <- function(., x, y) {
  # Creates a child object to hold data and results
  #
  # x - a design matrix (i.e. predictors)
  # y - a binary response vector

  proto(., x=x, y=y)
}

Logistic$run <- function(., niter, adapt=1000) {
  # Perform the regression by running the MCMC
  # sampler
  #
  # niter - number of iterations to sample
  # adapt - number of prior iterations to run
  #         for the 'burning in' period

  # rmvnorm and dmvnorm come from mvtnorm; library() stops with an
  # error if the package is missing, whereas require() only warns
  library(mvtnorm)

  # Set up variables used by the sampler
  .$adapt <- adapt
  total.iter <- niter + adapt
  # Carry the predictor names over so summaries and plots are labelled
  .$beta <- matrix(0, nrow=total.iter, ncol=ncol(.$x),
                   dimnames=list(NULL, colnames(.$x)))
  .$prior.mu <- rep(0, ncol(.$x))
  .$prior.cov <- diag(100, ncol(.$x))
  .$cov <- diag(ncol(.$x))

  # Run the sampler
  b <- rep(0, ncol(.$x))
  for (i in seq_len(total.iter)) {
    b <- .$update(i, b)
    .$beta[i,] <- b
  }

  # Trim the results matrix to remove the burn-in
  # period (drop=FALSE keeps it a matrix with one predictor)
  if (.$adapt > 0) {
    .$beta <- .$beta[(.$adapt + 1):total.iter, , drop=FALSE]
  }
}

Logistic$update <- function(., it, beta.old) {
  # Perform a single iteration of the MCMC sampler using
  # the Metropolis-Hastings algorithm.
  # Adapted from code by Brian Neelon published at:
  # http://www.duke.edu/~neelo003/r/
  #
  # it - iteration number
  # beta.old - vector of coefficient values from
  #            the previous iteration

  # Update the coefficient covariance if we are far
  # enough through the sampling
  if (.$adapt > 0 && it > 2 * .$adapt) {
    .$cov <- cov(.$beta[(it - .$adapt):(it - 1), , drop=FALSE])
  }

  # Generate proposed new coefficient values
  beta.new <- c(beta.old + rmvnorm(1, sigma=.$cov))

  # Calculate prior and current probabilities and log-likelihood.
  # The values for the current state are cached on the object and
  # only need computing on the first iteration.
  if (it == 1) {
    .$..log.prior.old <- dmvnorm(beta.old, .$prior.mu, .$prior.cov, log=TRUE)
    .$..probs.old <- plogis(.$x %*% beta.old)
    .$..LL.old <- sum(log(ifelse(.$y, .$..probs.old, 1 - .$..probs.old)))
  }
  log.prior.new <- dmvnorm(beta.new, .$prior.mu, .$prior.cov, log=TRUE)
  probs.new <- plogis(.$x %*% beta.new)
  LL.new <- sum(log(ifelse(.$y, probs.new, 1 - probs.new)))

  # Metropolis-Hastings acceptance ratio (log scale)
  ratio <- LL.new + log.prior.new - .$..LL.old - .$..log.prior.old

  if (log(runif(1)) < ratio) {
    # Accept: the proposal becomes the current state
    .$..log.prior.old <- log.prior.new
    .$..probs.old <- probs.new
    .$..LL.old <- LL.new
    return(beta.new)
  } else {
    # Reject: stay at the previous values
    return(beta.old)
  }
}

Logistic$summary <- function(., show.plot=TRUE) {
  # Summarize the results
  #
  # show.plot - if TRUE, draw a density plot for each coefficient

  cat("From", nrow(.$beta), "samples after", .$adapt, "iterations burning in\n")
  base::print(base::summary(.$beta))

  if (show.plot) {
    # Restore the caller's graphics settings on exit
    op <- par(mfrow=c(1, ncol(.$beta)))
    on.exit(par(op))
    for (i in seq_len(ncol(.$beta))) {
      plot(density(.$beta[,i]), main=colnames(.$beta)[i])
    }
  }
}
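
For reference, the accept/reject step in Logistic$update can be written out as below. This is a sketch of the usual random-walk Metropolis rule, assuming the proposal is symmetric for a fixed covariance so that the proposal densities cancel. The symbols $\ell$, $\pi$, $\Sigma$ and $u$ are notation introduced here; they correspond to LL, the log prior, .$cov and runif(1) in the code.

```latex
\begin{align*}
\beta' &= \beta + \varepsilon, \qquad \varepsilon \sim N(0, \Sigma) \\
\ell(\beta) &= \sum_{i} \left[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \right],
  \qquad p_i = \operatorname{logit}^{-1}(x_i^{\top} \beta) \\
\log r &= \ell(\beta') + \log \pi(\beta') - \ell(\beta) - \log \pi(\beta) \\
\text{accept } \beta' &\text{ if } \log u < \log r, \qquad u \sim U(0, 1)
\end{align*}
```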


Now that's probably not the greatest design in the world, but it only took me a few minutes to put together and it's incredibly easy to modify or extend. Try it!
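As one illustration of extending it, a child trait can add a method while inheriting everything else. The `LogisticDiag` trait and its `accept.rate` method are hypothetical names made up for this sketch, and the stand-in definition of `Logistic` is only there so the snippet runs on its own.

```r
library(proto)

# Stand-in so this snippet runs by itself; in practice Logistic is the
# trait defined in the listing above
if (!exists("Logistic")) Logistic <- proto()

# Hypothetical child trait: inherits new, run, update and summary from
# Logistic and adds one method of its own
LogisticDiag <- proto(Logistic)

LogisticDiag$accept.rate <- function(.) {
  # Proportion of retained transitions in which the chain moved,
  # i.e. a proposal was accepted
  moved <- rowSums(diff(.$beta) != 0) > 0
  mean(moved)
}

# Toy object with a made-up 5 x 2 chain, just to exercise the method
toy <- proto(LogisticDiag,
             beta = rbind(c(0, 0), c(0, 0), c(1, 2), c(1, 2), c(3, 1)))
toy$accept.rate()   # 0.5: the chain moved in 2 of 4 transitions
```

With the real trait loaded, the same method should be available on any object created through LogisticDiag$new(pred, y) once its run function has been called.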

Thanks to Brian Neelon for making his MCMC logistic regression code available (http://www.duke.edu/~neelo003/r/).
