Explicit, comprehensible covariance matrices in Java

As part of my code readability push in a recent work project (see also here), I’ve just tidied up the code responsible for creating, updating, and referring to covariance matrices of patient characteristics and treatment effects. I first searched the web for examples of how covariance matrices had previously been implemented in Java. I found the Covariance class in the Apache Commons Mathematics Library, which, as it happens, we were already using for its MersenneTwister implementation and a few subclasses of AbstractRealDistribution.

Looking at the data structures involved, the Covariance class calculates and returns a covariance matrix as an instance implementing the RealMatrix interface, and it accepts the raw data either as a two-dimensional array of doubles or as a RealMatrix. According to the Covariance source code, the current implementation returns an instance of BlockRealMatrix, which by default divides the matrix into 52 × 52 blocks (2,704 double values per block). Each block is flattened in row-major order into a one-dimensional array, and the blocks themselves are stored (again in row-major order) in an outer array. This is a carefully performance-tuned design: three blocks are sized to fit into a 64KB L1 cache (3 × 2,704 × 8 bytes = 64,896 bytes, or about 63.4KB), allowing block multiplication (i.e. C[i][j] += A[i][k] × B[k][j]) to be conducted entirely within the cache.
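To make the blocking idea concrete, the following is a minimal sketch of tiled matrix multiplication in plain Java. It is not the Commons Math implementation: it assumes square matrices held in ordinary double[][] arrays rather than BlockRealMatrix's flattened blocks, and the class name BlockedMultiply is hypothetical. Only the block edge of 52 is taken from the default described above.

```java
import java.util.Arrays;

// Illustrative sketch of cache-blocked (tiled) multiplication; NOT the Commons Math code.
// Assumes square double[][] inputs of equal size.
public class BlockedMultiply {

    // Same edge length as the BlockRealMatrix default: three 52 x 52 blocks of
    // doubles take 3 * 52 * 52 * 8 = 64,896 bytes (about 63.4KB).
    static final int BLOCK_SIZE = 52;

    public static double[][] multiply(double[][] a, double[][] b) {
        int n = a.length;
        double[][] c = new double[n][n];
        // Visit the matrices one tile at a time so the current tiles of A, B and C
        // can stay in cache while the three inner loops run over them.
        for (int ii = 0; ii < n; ii += BLOCK_SIZE) {
            for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
                for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
                    int iMax = Math.min(ii + BLOCK_SIZE, n);
                    int kMax = Math.min(kk + BLOCK_SIZE, n);
                    int jMax = Math.min(jj + BLOCK_SIZE, n);
                    for (int i = ii; i < iMax; i++) {
                        for (int k = kk; k < kMax; k++) {
                            double aik = a[i][k];
                            for (int j = jj; j < jMax; j++) {
                                c[i][j] += aik * b[k][j]; // C[i][j] += A[i][k] * B[k][j]
                            }
                        }
                    }
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(multiply(a, b))); // prints [[19.0, 22.0], [43.0, 50.0]]
    }
}
```

The result is the same as a naive triple loop; the tiling only changes the order in which memory is touched, so that each tile of B is reused across a whole tile of rows of A while it is (in principle) still cached.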

This is all well and good, but passing 2D arrays or RealMatrix instances around runs contrary to my current goal of writing readable code. Not only that, but I want to store the p-value of the Pearson product-moment correlation coefficient alongside each covariance parameter, so that the model can later decide whether or not to covary parameters whose correlation is not significant. With 2D arrays, I’d need two parallel matrices to retrieve both the covariance and the p-value. Given that I’m dealing with relatively small matrices and not doing anything too computationally intensive with them, I decided to write a thin wrapper class around an EnumMap of EnumMaps. The same Enum serves as the key type for both the outer and inner maps, with one element for each covaried characteristic, and each inner EnumMap maps an Enum element to a small, tuple-like nested class that stores the covariance value and p-value in public instance variables.

Since the model uses a few different covariance matrices (for example, the aforementioned baseline patient characteristics and treatment effects), I wrote a generic class that takes the Enum’s Class object as a constructor argument, such that it can be instantiated as follows:

CovarianceMatrix<PatientCharacteristic> matrix = new CovarianceMatrix<>(PatientCharacteristic.class);

The constructor then creates the 2D “matrix” of nested, tuple-like ValuePValuePair instances, with all covariance and p-values set to 0.0. This is fairly memory-inefficient, in that we immediately have n² instances of ValuePValuePair (where n is the number of Enum constants), but it helped to get this up and running quickly without hitting any NullPointerExceptions or having to perform null checks. Once the data structure’s set up, values can be added to the matrix as follows (with a couple of static final constants added at the top of the class to improve the readability of the code where it counts):

static final PatientCharacteristic HBA1C = PatientCharacteristic.HBA1C;
static final PatientCharacteristic TOTAL_CHOLESTEROL = PatientCharacteristic.TOTAL_CHOLESTEROL;

matrix.setCovarianceAndPValue(HBA1C, TOTAL_CHOLESTEROL, 0.28, 0.0001);

As a quick aside, since matrix elements apparently can no longer be referred to numerically, it initially looks as though a lot of strict matrix-like functionality has gone with them. However, since the EnumMap of EnumMaps approach relies on the getEnumConstants() method (which always returns Enum elements in the order in which they’re declared) to generate the matrix, the numeric index of each row (and column) can be derived from any given Enum element as follows:

int index = java.util.Arrays.asList(PatientCharacteristic.class.getEnumConstants()).indexOf(PatientCharacteristic.HBA1C);
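As a quick check of that index derivation, here is a small runnable sketch. It uses a hypothetical four-constant PatientCharacteristic enum (the real constants aren’t listed here) and compares the indexOf approach with Enum.ordinal(), which returns a constant’s zero-based position in its declaration and so yields the same index without building a list. Enum.values() goes the other way, from an index back to a constant.

```java
import java.util.Arrays;

public class EnumIndexDemo {

    // Hypothetical constants for illustration only
    enum PatientCharacteristic { AGE, HBA1C, TOTAL_CHOLESTEROL, SYSTOLIC_BP }

    public static void main(String[] args) {
        // Index via the array of constants, as in the line above
        int viaIndexOf = Arrays.asList(PatientCharacteristic.class.getEnumConstants())
                .indexOf(PatientCharacteristic.HBA1C);
        // ordinal() is the declaration position, so it gives the same number directly
        int viaOrdinal = PatientCharacteristic.HBA1C.ordinal();
        // Reverse mapping: numeric index back to the Enum constant
        PatientCharacteristic fromIndex = PatientCharacteristic.values()[viaOrdinal];

        System.out.println(viaIndexOf + " " + viaOrdinal + " " + fromIndex); // prints 1 1 HBA1C
    }
}
```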

Using this approach, “lossless” conversion from a RealMatrix instance to this data structure and back would be possible (assuming prior knowledge of which rows and columns map to which characteristics). But even without numeric references, this meets the current requirement, which is simply to use pre-calculated covariance matrices from SAS in Java in a clear and concise manner.
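A rough sketch of that round trip follows. To keep it free of external dependencies it uses a plain double[][] in place of RealMatrix (a RealMatrix exposes its entries as a double[][] via getData(), and one can be built from a double[][] with, for example, the Array2DRowRealMatrix constructor). It assumes row and column i of the array correspond to the Enum constant whose ordinal() is i, and it stores bare Double values rather than covariance/p-value pairs. The names MatrixConversion, fromArray, toArray and Characteristic are hypothetical, and the matrix values are made up for illustration.

```java
import java.util.Arrays;
import java.util.EnumMap;

public class MatrixConversion {

    // Hypothetical characteristics; array index i is assumed to mean ordinal() == i
    enum Characteristic { AGE, HBA1C, TOTAL_CHOLESTEROL }

    // Copy a square double[][] into an EnumMap of EnumMaps keyed by the Enum constants
    public static <K extends Enum<K>> EnumMap<K, EnumMap<K, Double>> fromArray(Class<K> type, double[][] data) {
        K[] keys = type.getEnumConstants();
        if (data.length != keys.length) {
            throw new IllegalArgumentException("Expected " + keys.length + " rows, got " + data.length);
        }
        EnumMap<K, EnumMap<K, Double>> result = new EnumMap<>(type);
        for (K row : keys) {
            double[] values = data[row.ordinal()];
            if (values.length != keys.length) {
                throw new IllegalArgumentException("Row " + row + " has " + values.length + " columns");
            }
            EnumMap<K, Double> rowMap = new EnumMap<>(type);
            for (K column : keys) {
                rowMap.put(column, values[column.ordinal()]);
            }
            result.put(row, rowMap);
        }
        return result;
    }

    // Rebuild the double[][], placing each entry at [row.ordinal()][column.ordinal()]
    public static <K extends Enum<K>> double[][] toArray(Class<K> type, EnumMap<K, EnumMap<K, Double>> map) {
        K[] keys = type.getEnumConstants();
        double[][] data = new double[keys.length][keys.length];
        for (K row : keys) {
            for (K column : keys) {
                data[row.ordinal()][column.ordinal()] = map.get(row).get(column);
            }
        }
        return data;
    }

    public static void main(String[] args) {
        double[][] original = {
            {1.00, 0.10, 0.05},
            {0.10, 0.80, 0.28},
            {0.05, 0.28, 1.20}
        };
        EnumMap<Characteristic, EnumMap<Characteristic, Double>> map = fromArray(Characteristic.class, original);
        System.out.println(map.get(Characteristic.HBA1C).get(Characteristic.TOTAL_CHOLESTEROL)); // prints 0.28
        System.out.println(Arrays.deepEquals(original, toArray(Characteristic.class, map))); // prints true
    }
}
```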

The full EnumMap wrapper class looks like this:

import java.util.EnumMap;

public class CovarianceMatrix<K extends Enum<K>> {
	
	private EnumMap<K, EnumMap<K, ValuePValuePair>> matrix;
	
	// Tuple-like holder for a single cell: the covariance and its p-value
	private static final class ValuePValuePair {
		public double value;
		public double pValue;

		public ValuePValuePair(double value, double pValue) {
			this.value = value;
			this.pValue = pValue;
		}
	}
	
	public CovarianceMatrix(Class<K> characteristics) {
		
		this.matrix = new EnumMap<>(characteristics);
		
		// Build every row up front and fill each cell with a zeroed pair, so the
		// getters and setters below never encounter a missing row or cell
		for (K row : characteristics.getEnumConstants()) {
			EnumMap<K, ValuePValuePair> rowMap = new EnumMap<>(characteristics);
			for (K column : characteristics.getEnumConstants()) {
				rowMap.put(column, new ValuePValuePair(0.0, 0.0));
			}
			this.matrix.put(row, rowMap);
		}
	}
	
	public void setCovarianceAndPValue(K row, K column, double covariance, double pValue) {
		this.matrix.get(row).get(column).value = covariance;
		this.matrix.get(row).get(column).pValue = pValue;
	}

	public void setCovariance(K row, K column, double covariance) {
		this.matrix.get(row).get(column).value = covariance;
	}
	
	public void setCovariancePValue(K row, K column, double pValue) {
		this.matrix.get(row).get(column).pValue = pValue;
	}
	
	public double getCovariance(K row, K column) {
		return this.matrix.get(row).get(column).value;
	}
	
	public double getCovariancePValue(K row, K column) {
		return this.matrix.get(row).get(column).pValue;
	}
	
}

The next step is to profile this in situ and probably switch over to instantiating the ValuePValuePair instances lazily, with the appropriate null checks added to the getters. For now, though, the class works well and achieves the goal of making the clinical modelling portions of the code much more readable.
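For the lazy instantiation mentioned above, a minimal sketch might look like the following. It is an illustration rather than a finished implementation: the class name LazyCovarianceMatrix and the example enum are hypothetical, only the combined setter and the two getters are shown, cells are created with Map.computeIfAbsent on first write, and the getters return 0.0 for any cell that was never set, matching the eager version’s defaults.

```java
import java.util.EnumMap;

// Hypothetical lazily-populated variant; rows and cells are only allocated when written
public class LazyCovarianceMatrix<K extends Enum<K>> {

    private static final class ValuePValuePair {
        double value;
        double pValue;
    }

    private final Class<K> keyType;
    private final EnumMap<K, EnumMap<K, ValuePValuePair>> matrix;

    public LazyCovarianceMatrix(Class<K> keyType) {
        this.keyType = keyType;
        this.matrix = new EnumMap<>(keyType); // starts empty: no rows, no pairs
    }

    // Fetch a cell for writing, creating its row map and pair on first use
    private ValuePValuePair cellForWrite(K row, K column) {
        return matrix.computeIfAbsent(row, r -> new EnumMap<>(keyType))
                     .computeIfAbsent(column, c -> new ValuePValuePair());
    }

    // Fetch a cell for reading; null means it was never written
    private ValuePValuePair cellForRead(K row, K column) {
        EnumMap<K, ValuePValuePair> rowMap = matrix.get(row);
        return rowMap == null ? null : rowMap.get(column);
    }

    public void setCovarianceAndPValue(K row, K column, double covariance, double pValue) {
        ValuePValuePair cell = cellForWrite(row, column);
        cell.value = covariance;
        cell.pValue = pValue;
    }

    public double getCovariance(K row, K column) {
        ValuePValuePair cell = cellForRead(row, column);
        return cell == null ? 0.0 : cell.value; // null check replaces eager zero-filling
    }

    public double getCovariancePValue(K row, K column) {
        ValuePValuePair cell = cellForRead(row, column);
        return cell == null ? 0.0 : cell.pValue;
    }

    // Hypothetical constants for the usage example
    enum PatientCharacteristic { HBA1C, TOTAL_CHOLESTEROL, SYSTOLIC_BP }

    public static void main(String[] args) {
        LazyCovarianceMatrix<PatientCharacteristic> m = new LazyCovarianceMatrix<>(PatientCharacteristic.class);
        m.setCovarianceAndPValue(PatientCharacteristic.HBA1C, PatientCharacteristic.TOTAL_CHOLESTEROL, 0.28, 0.0001);
        System.out.println(m.getCovariance(PatientCharacteristic.HBA1C, PatientCharacteristic.TOTAL_CHOLESTEROL)); // prints 0.28
        System.out.println(m.getCovariance(PatientCharacteristic.HBA1C, PatientCharacteristic.SYSTOLIC_BP)); // prints 0.0
    }
}
```

Note that neither this sketch nor the class above mirrors writes across the diagonal: setting (HBA1C, TOTAL_CHOLESTEROL) leaves (TOTAL_CHOLESTEROL, HBA1C) at 0.0, even though a covariance matrix is symmetric, so either both cells need to be set or a setter could be written to update both.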
