Explicit, comprehensible covariance matrices in Java

As part of my code readability push in a recent work project (see also here), I’ve just tidied up the code responsible for creating, updating, and referring to covariance matrices of patient characteristics and treatment effects. I first searched the web for existing Java implementations of covariance matrices and found the Covariance class in the Apache Commons Mathematics Library, which, as it happens, we were already using for its MersenneTwister implementation and a few subclasses of AbstractRealDistribution.

However, the Covariance class calculates and returns a covariance matrix as an instance implementing the RealMatrix interface (and accepts the raw data either as a two-dimensional array of doubles or as a RealMatrix). Looking at the Covariance source code, the current implementation returns an instance of BlockRealMatrix, which by default divides the matrix into 52 × 52 blocks (2,704 double values per block). Each block is flattened in row-major order into a one-dimensional array, and the blocks are themselves stored, again in row-major order, in an outer array. This is a cache-conscious design: three blocks are sized to fit into 64 KB of L1 cache (3 × 2,704 × 8 bytes = 64,896 bytes, or about 63.4 KB), allowing block multiplication (i.e. C[i][j] += A[i][k] × B[k][j]) to be conducted entirely within the cache.
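To make that layout concrete, here is a simplified, hypothetical sketch of how a row-major blocked store can map a (row, column) pair to a block in the outer array and to an offset within that flattened block. It assumes only the 52 × 52 block size described above; the class and method names are invented for illustration, and this is not the BlockRealMatrix source.

```java
/**
 * Hypothetical, simplified illustration of a blocked, row-major layout in the
 * spirit of the one described above. This is NOT Apache Commons Math's code.
 */
final class BlockedLayoutSketch {

    // Default block edge length described above: 52 x 52 = 2,704 doubles per block
    static final int BLOCK_SIZE = 52;

    final int rows;
    final int columns;
    final int blockColumns;
    // Outer array of blocks in row-major order; each block is itself flattened row-major
    final double[][] blocks;

    BlockedLayoutSketch(int rows, int columns) {
        this.rows = rows;
        this.columns = columns;
        int blockRows = (rows + BLOCK_SIZE - 1) / BLOCK_SIZE; // ceiling division
        this.blockColumns = (columns + BLOCK_SIZE - 1) / BLOCK_SIZE;
        this.blocks = new double[blockRows * blockColumns][];
        for (int iBlock = 0; iBlock < blockRows; iBlock++) {
            for (int jBlock = 0; jBlock < blockColumns; jBlock++) {
                // Blocks on the bottom and right edges are truncated to fit the matrix
                blocks[iBlock * blockColumns + jBlock] = new double[blockHeight(iBlock) * blockWidth(jBlock)];
            }
        }
    }

    private int blockHeight(int iBlock) {
        return Math.min(BLOCK_SIZE, rows - iBlock * BLOCK_SIZE);
    }

    private int blockWidth(int jBlock) {
        return Math.min(BLOCK_SIZE, columns - jBlock * BLOCK_SIZE);
    }

    private void checkBounds(int row, int column) {
        if (row < 0 || row >= rows || column < 0 || column >= columns) {
            throw new IndexOutOfBoundsException("(" + row + ", " + column + ") is outside the matrix");
        }
    }

    // Index of the block holding (row, column) within the outer array
    private int blockIndex(int row, int column) {
        return (row / BLOCK_SIZE) * blockColumns + column / BLOCK_SIZE;
    }

    // Position of (row, column) within its flattened block
    private int offsetInBlock(int row, int column) {
        return (row % BLOCK_SIZE) * blockWidth(column / BLOCK_SIZE) + column % BLOCK_SIZE;
    }

    void set(int row, int column, double value) {
        checkBounds(row, column);
        blocks[blockIndex(row, column)][offsetInBlock(row, column)] = value;
    }

    double get(int row, int column) {
        checkBounds(row, column);
        return blocks[blockIndex(row, column)][offsetInBlock(row, column)];
    }
}
```

For a 100 × 60 matrix, this produces a 2 × 2 grid of blocks, with the bottom-right block truncated to 48 × 8 = 384 values.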

This is all well and good, but passing 2D arrays or RealMatrix instances around runs contrary to my current goal of writing readable code. Not only that, but I want to store the p-value of the Pearson product-moment correlation coefficient alongside each covariance parameter, so that the model can decide later whether or not to covary parameters whose correlation is not significant. With RealMatrix, I’d really need two parallel 2D matrices to retrieve both the covariance and the p-value. Given that I’m dealing with relatively small matrices and not doing anything too computationally intensive with them, I decided to write a thin wrapper class around an EnumMap of EnumMaps. The same Enum, with one element per covaried characteristic, serves as the key type for both the outer and inner maps, and each inner EnumMap maps an Enum element to a small tuple-like nested class that stores the covariance value and p-value in public instance variables.

Since the model uses a few different covariance matrices (including the aforementioned baseline patient characteristics and treatment effects, for example), I wrote a generic class that takes the Enum class in the constructor, such that it can be instantiated as follows:

CovarianceMatrix<PatientCharacteristic> matrix = new CovarianceMatrix<>(PatientCharacteristic.class);

The constructor then creates the 2D “matrix” of the nested, tuple-like ValuePValuePair classes, with all covariance and p-values set to 0.0. This is fairly memory-inefficient, in that we immediately have n² instances of ValuePValuePair (where n is the number of Enum elements), but it helped to get this up and running quickly without hitting any NullPointerExceptions or having to perform null checks. Once the data structure is set up, values can be added to the matrix as follows (with a couple of static final constants added at the top of the class to improve the readability of the code where it counts):

static final PatientCharacteristic HBA1C = PatientCharacteristic.HBA1C;
static final PatientCharacteristic TOTAL_CHOLESTEROL = PatientCharacteristic.TOTAL_CHOLESTEROL;

matrix.setCovarianceAndPValue(HBA1C, TOTAL_CHOLESTEROL, 0.28, 0.0001);

As a quick aside, since the ability to refer to matrix elements by numeric index has (apparently) been lost, it might initially look as though a lot of strict matrix-like functionality has been lost as well. However, since the EnumMap of EnumMaps approach relies on the getEnumConstants() method (which always returns Enum elements in the order in which they’re declared) to generate the matrix, the numeric index of each row (and column) can be derived from any given Enum element as follows:

int index = java.util.Arrays.asList(PatientCharacteristic.class.getEnumConstants()).indexOf(PatientCharacteristic.HBA1C);
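Java also exposes that position directly: Enum.ordinal() returns a constant’s zero-based position in the enum declaration, which is the same order getEnumConstants() and values() use and the order in which EnumMap iterates its keys. A small sketch comparing the two, with a hypothetical PatientCharacteristic enum standing in for the real one (in either approach, reordering the enum declaration changes the indices):

```java
import java.util.Arrays;

final class EnumIndexSketch {

    // Hypothetical stand-in for the model's PatientCharacteristic enum
    enum PatientCharacteristic { AGE, HBA1C, TOTAL_CHOLESTEROL, SYSTOLIC_BLOOD_PRESSURE }

    // The approach above: position of the constant in getEnumConstants()
    static int indexViaEnumConstants(PatientCharacteristic characteristic) {
        return Arrays.asList(PatientCharacteristic.class.getEnumConstants()).indexOf(characteristic);
    }

    // Equivalent shortcut: ordinal() is the constant's zero-based declaration position
    static int indexViaOrdinal(PatientCharacteristic characteristic) {
        return characteristic.ordinal();
    }

    // Reverse mapping from a numeric index back to the enum constant
    static PatientCharacteristic fromIndex(int index) {
        return PatientCharacteristic.values()[index];
    }
}
```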

Using this approach, “lossless” conversion from a RealMatrix instance to this data structure and back would be possible (assuming prior knowledge of which rows and columns map to which characteristics). Even without numeric references, though, this meets the current requirement, which is simply to use pre-calculated covariance matrices from SAS in Java in a clear and concise manner.
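If that round trip were ever needed, it might look something like the following sketch. It assumes that row and column i of the array correspond to the Enum constant with ordinal i, stores plain Double values rather than the private ValuePValuePair class, and uses a hypothetical Characteristic enum. On the Apache Commons Math side, RealMatrix.getData() returns a double[][] copy of the matrix, and MatrixUtils.createRealMatrix(double[][]) builds a RealMatrix from one.

```java
import java.util.EnumMap;

final class MatrixConversionSketch {

    // Hypothetical characteristics; row/column i of the array maps to ordinal i
    enum Characteristic { HBA1C, TOTAL_CHOLESTEROL, SYSTOLIC_BLOOD_PRESSURE }

    // Copies a square array into an EnumMap of EnumMaps keyed by the enum's constants
    static <K extends Enum<K>> EnumMap<K, EnumMap<K, Double>> fromArray(Class<K> keyType, double[][] data) {
        K[] constants = keyType.getEnumConstants();
        if (data.length != constants.length) {
            throw new IllegalArgumentException("expected " + constants.length + " rows");
        }
        EnumMap<K, EnumMap<K, Double>> matrix = new EnumMap<>(keyType);
        for (K row : constants) {
            double[] values = data[row.ordinal()];
            if (values.length != constants.length) {
                throw new IllegalArgumentException("expected " + constants.length + " columns");
            }
            EnumMap<K, Double> columns = new EnumMap<>(keyType);
            for (K column : constants) {
                columns.put(column, values[column.ordinal()]);
            }
            matrix.put(row, columns);
        }
        return matrix;
    }

    // Inverse of fromArray: writes each cell back to [row.ordinal()][column.ordinal()]
    static <K extends Enum<K>> double[][] toArray(Class<K> keyType, EnumMap<K, EnumMap<K, Double>> matrix) {
        K[] constants = keyType.getEnumConstants();
        double[][] data = new double[constants.length][constants.length];
        for (K row : constants) {
            for (K column : constants) {
                data[row.ordinal()][column.ordinal()] = matrix.get(row).get(column);
            }
        }
        return data;
    }
}
```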

The full EnumMap wrapper class looks like this:

import java.util.EnumMap;

public class CovarianceMatrix<K extends Enum<K>> {

	// Outer map keyed by row characteristic; each inner map keyed by column characteristic
	private final EnumMap<K, EnumMap<K, ValuePValuePair>> matrix;

	// Tuple-like holder for a covariance and the p-value of the corresponding correlation
	private static final class ValuePValuePair {
		public double value;
		public double pValue;

		public ValuePValuePair(double value, double pValue) {
			this.value = value;
			this.pValue = pValue;
		}
	}

	public CovarianceMatrix(Class<K> characteristics) {

		this.matrix = new EnumMap<>(characteristics);

		// Eagerly create all n-by-n cells with zeroed values so that the getters
		// and setters never encounter a missing (null) entry
		for (K row : characteristics.getEnumConstants()) {
			EnumMap<K, ValuePValuePair> columns = new EnumMap<>(characteristics);
			for (K column : characteristics.getEnumConstants()) {
				columns.put(column, new ValuePValuePair(0.0, 0.0));
			}
			this.matrix.put(row, columns);
		}
	}

	// Sets only the (row, column) cell; the mirrored (column, row) cell is not updated
	public void setCovarianceAndPValue(K row, K column, double covariance, double pValue) {
		ValuePValuePair cell = this.matrix.get(row).get(column);
		cell.value = covariance;
		cell.pValue = pValue;
	}

	public void setCovariance(K row, K column, double covariance) {
		this.matrix.get(row).get(column).value = covariance;
	}
	
	public void setCovariancePValue(K row, K column, double pValue) {
		this.matrix.get(row).get(column).pValue = pValue;
	}
	
	public double getCovariance(K row, K column) {
		return this.matrix.get(row).get(column).value;
	}
	
	public double getCovariancePValue(K row, K column) {
		return this.matrix.get(row).get(column).pValue;
	}
	
}

The next step is to profile this in situ and probably switch over to instantiating the ValuePValuePair instances lazily, with the appropriate null checks added to the getters. For now, though, the class is working very well and achieves the goal of making the clinical modelling portions of the code much more readable.
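A minimal sketch of what that lazy variant could look like, assuming unset cells should read as 0.0 to match the eager version’s initial values. The class name LazyCovarianceMatrix and the ExampleCharacteristic enum are hypothetical. Map.computeIfAbsent (Java 8+) creates a row map or cell only on first write, so an empty matrix holds no ValuePValuePair instances at all.

```java
import java.util.EnumMap;

// Hypothetical characteristics used only to exercise the sketch
enum ExampleCharacteristic { HBA1C, TOTAL_CHOLESTEROL }

/** Hypothetical lazily populated variant: cells are created on first write; unset cells read as 0.0. */
final class LazyCovarianceMatrix<K extends Enum<K>> {

    private static final class ValuePValuePair {
        double value;
        double pValue;
    }

    private final Class<K> keyType;
    private final EnumMap<K, EnumMap<K, ValuePValuePair>> matrix;

    LazyCovarianceMatrix(Class<K> keyType) {
        this.keyType = keyType;
        this.matrix = new EnumMap<>(keyType);
    }

    // Creates the row map and/or the cell only if they do not exist yet
    private ValuePValuePair cellForWrite(K row, K column) {
        return matrix.computeIfAbsent(row, r -> new EnumMap<>(keyType))
                     .computeIfAbsent(column, c -> new ValuePValuePair());
    }

    // Returns null for cells that have never been written (the null check the text mentions)
    private ValuePValuePair cellForRead(K row, K column) {
        EnumMap<K, ValuePValuePair> columns = matrix.get(row);
        return columns == null ? null : columns.get(column);
    }

    void setCovarianceAndPValue(K row, K column, double covariance, double pValue) {
        ValuePValuePair cell = cellForWrite(row, column);
        cell.value = covariance;
        cell.pValue = pValue;
    }

    double getCovariance(K row, K column) {
        ValuePValuePair cell = cellForRead(row, column);
        return cell == null ? 0.0 : cell.value;
    }

    double getCovariancePValue(K row, K column) {
        ValuePValuePair cell = cellForRead(row, column);
        return cell == null ? 0.0 : cell.pValue;
    }

    // Number of cells actually allocated so far
    int populatedCells() {
        int count = 0;
        for (EnumMap<K, ValuePValuePair> columns : matrix.values()) {
            count += columns.size();
        }
        return count;
    }
}
```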

Replicating Excel’s logarithmic curve fitting in R

For a new work project, we’ve just been provided with a Kaplan-Meier curve showing kidney graft survival over 12 months in two groups of patients. As the data is newly generated, we don’t yet have access to the raw data and there are too many patients in the samples to see any discrete steps on the K-M curve, meaning that we can’t extract data on time to individual events. Since we wanted to get a rough idea of how our budget impact analysis (the main focus of the project) might look with these data in place, we traced the data from the K-M curve using WebPlotDigitizer by Ankit Rohatgi.

Looking at the shape of the data we had and knowing that longer-term kidney graft survival (from donors after brain death) typically looks like that in Figure 1, we opted for a simple logarithmic model of the data:

$$S(t) = \beta \ln(t) + c$$
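For reference, fitting this model by ordinary least squares is just a straight-line regression of S on x = ln(t), which has the standard closed-form solution below (writing S_i for the traced survival values at times t_i and n for the number of points). For a single predictor, R² is the squared Pearson correlation between x and S.

$$x_i = \ln(t_i), \qquad \bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad \bar{S} = \frac{1}{n}\sum_{i=1}^{n} S_i$$

$$\hat{\beta} = \frac{\sum_{i=1}^{n}(x_i - \bar{x})(S_i - \bar{S})}{\sum_{i=1}^{n}(x_i - \bar{x})^2}, \qquad \hat{c} = \bar{S} - \hat{\beta}\,\bar{x}$$

$$R^2 = \frac{\left[\sum_{i=1}^{n}(x_i - \bar{x})(S_i - \bar{S})\right]^2}{\sum_{i=1}^{n}(x_i - \bar{x})^2 \, \sum_{i=1}^{n}(S_i - \bar{S})^2}$$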

Figure 1 Long-term graft survival after first adult kidney-only transplant from donors after brain death, 2000–2012. (NHS Blood and Transplant (NHSBT) Organ Donation and Transplantation Activity Report 2013/14.)

It’s very easy to generate a logarithmic fit in Excel: select the “Add Trendline…” option for a series in a chart, choose the “Logarithmic” option, and check the “Display equation on chart” and “Display R-squared value on chart” options. These display, respectively, the coefficient and intercept, and the R-squared value (which, for a single-predictor fit like this one, is the square of the Pearson product-moment correlation coefficient between ln(t) and survival).
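The same calculation is small enough to write out directly. Below is a sketch in Java of an ordinary least squares fit of S(t) = β ln(t) + c, implementing the closed-form estimators rather than calling any particular statistics library; the class name LogarithmicFit is hypothetical. It uses Math.log, the natural logarithm, which is what R’s log() computes by default.

```java
/**
 * Hypothetical helper: ordinary least squares fit of S(t) = beta * ln(t) + c,
 * i.e. a straight-line regression of S on x = ln(t).
 */
final class LogarithmicFit {

    final double beta;      // coefficient on ln(t)
    final double intercept; // c
    final double rSquared;  // squared Pearson correlation between ln(t) and S

    private LogarithmicFit(double beta, double intercept, double rSquared) {
        this.beta = beta;
        this.intercept = intercept;
        this.rSquared = rSquared;
    }

    static LogarithmicFit fit(double[] t, double[] survival) {
        if (t.length != survival.length || t.length < 2) {
            throw new IllegalArgumentException("need at least two paired observations");
        }
        int n = t.length;
        double[] x = new double[n];
        double meanX = 0.0;
        double meanS = 0.0;
        for (int i = 0; i < n; i++) {
            if (t[i] <= 0.0) {
                throw new IllegalArgumentException("ln(t) is undefined for t <= 0");
            }
            x[i] = Math.log(t[i]); // natural logarithm
            meanX += x[i];
            meanS += survival[i];
        }
        meanX /= n;
        meanS /= n;

        // Centred sums of squares and cross-products
        double sxx = 0.0;
        double sss = 0.0;
        double sxs = 0.0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            double ds = survival[i] - meanS;
            sxx += dx * dx;
            sss += ds * ds;
            sxs += dx * ds;
        }

        double beta = sxs / sxx;
        double intercept = meanS - beta * meanX;
        double rSquared = (sxs * sxs) / (sxx * sss);
        return new LogarithmicFit(beta, intercept, rSquared);
    }

    double predict(double t) {
        return beta * Math.log(t) + intercept;
    }
}
```

Passing the t and survival vectors used in the R code below to LogarithmicFit.fit should reproduce the lm() intercept, log(t) coefficient and R-squared, since both are ordinary least squares on ln(t).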

However, since we’re ultimately going to be using R for the analysis (when the raw K-M data become available and we use the R survival package to fit a Weibull model, or similar, to the data), I thought it would make sense to also use R to give us our rough logarithmic fit of the data. As one might expect, it’s very straightforward to replicate the results that are produced in Excel using just a few lines of R:

t <- c(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)
survival <- c(1, 0.97909, 0.97171, 0.96186, 0.95694, 0.95387, 0.95264, 0.9520, 0.94956, 0.94526, 0.94279, 0.94033, 0.94033)
s_t <- data.frame(t, survival)
log_estimate <- lm(survival~log(t), data=s_t)
summary(log_estimate)

Running this gives the following output, which tallies exactly with the intercept, log(t) coefficient and R-squared value reported in Excel for the same data:

Residuals:
       Min         1Q     Median         3Q        Max 
-0.0033502 -0.0016374  0.0000542  0.0014951  0.0037662 

Coefficients:
             Estimate Std. Error t value Pr(>|t|)    
(Intercept)  0.996234   0.001636  609.13  < 2e-16 ***
log(t)      -0.022377   0.000868  -25.78 3.46e-11 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 

Residual standard error: 0.002302 on 11 degrees of freedom
Multiple R-squared: 0.9837,	Adjusted R-squared: 0.9822 
F-statistic: 664.6 on 1 and 11 DF,  p-value: 3.458e-11

And if we overlay the resulting curve over the NHSBT data in Figure 1, we can see that the simple logarithmic fit matches the clinical reality pretty well when projected over a five-year time horizon:

Figure 2 NHSBT long-term kidney graft survival with logarithmic fit to K-M data overlaid in black

We can then use this approximation of the K-M curve to derive a survival function that we’ll use in the model until we have access to the full data set. A nice, quick approach that certainly couldn’t be accused of overfitting.
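As an illustration of that step, here is a sketch of the fitted curve as a Java survival function, with the coefficients from the lm() output above hard-coded. The conditional failure probability 1 - S(t)/S(t - 1) is one common way of turning a survival curve into per-cycle transition probabilities for a cycle-based model; treating one unit of t as one model cycle, and restricting the curve to t >= 1 (where the fitted data start), are assumptions here rather than anything the data specify. The class and method names are hypothetical.

```java
/** Hypothetical survival function based on the logarithmic fit above. */
final class LogSurvivalFunction {

    // Coefficients from the lm() output above, hard-coded for illustration only
    static final double INTERCEPT = 0.996234;
    static final double LOG_T_COEFFICIENT = -0.022377;

    // S(t) = beta * ln(t) + c; below t = 1 the curve rises above the fitted range, so it is not used there
    static double survival(double t) {
        if (t < 1.0) {
            throw new IllegalArgumentException("the fitted curve is only used for t >= 1");
        }
        return LOG_T_COEFFICIENT * Math.log(t) + INTERCEPT;
    }

    // Probability of graft failure between t - 1 and t, given graft survival to t - 1
    static double conditionalFailureProbability(double t) {
        if (t < 2.0) {
            throw new IllegalArgumentException("needs t - 1 >= 1");
        }
        return 1.0 - survival(t) / survival(t - 1.0);
    }
}
```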

Returning the first instance of a Java class in a collection of instances of its superclass

In an ongoing work project, we have a Java class that models a patient with type 1 diabetes. Since a lot of the risk models that operate on the simulated patient are affected by various medications (antihypertensives, antithrombotics, lipid modifying medications, etc.), we have an ArrayList on the patient that holds instances of a Medication class (or any of its many subclasses) to represent which of these the patient is currently taking:

private List<Medication> riskAdjustingMedications;

This allows us to write things like the fairly English-sounding:

patient.getRiskAdjustingMedications().add(new AntiHypertensiveMedication());

(We actually opted to use the getter for the riskAdjustingMedications ArrayList and the .add() method directly for the very reason that the code is extremely readable when written as above.)

But how can we write something that allows us to establish whether or not the ArrayList already contains an instance of Medication itself or one of its subclasses? For now, we’re using a private method in the Patient class that accomplishes this as generally as possible:

private static <T, E extends T> E returnFirstInstanceOfClassInCollection(Class<E> theClass, Collection<T> arrayList) {
	
	for (T o : arrayList) {
		if (o != null && o.getClass() == theClass) {
			return theClass.cast(o);
		}
	}

	return null;
}

So that’s a generic method with the type parameters T and E (where E extends T) that takes the Class object for E (in our specific example, Medication.class or any of its subclasses) and a Collection of type T as arguments. If the Collection contains an object whose runtime class is exactly E, the method returns the first such object, cast to E; otherwise, it returns null. Because the check uses getClass() ==, instances of subclasses of E are not matched.
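That exact-class behaviour is worth spelling out: if the list held an instance of a further subclass (say, a hypothetical AceInhibitor extending AntiHypertensiveMedication), querying for AntiHypertensiveMedication.class would return null. If subtype matching were ever wanted, Class.isInstance is the usual alternative, since it is true for instances of the class or any of its subclasses (and false for null). The sketch below contrasts the two, using hypothetical stand-in classes rather than the model’s real Medication hierarchy.

```java
import java.util.Collection;

final class FirstInstanceSketch {

    // Hypothetical stand-ins for the model's Medication hierarchy
    static class Medication {}
    static class AntiHypertensiveMedication extends Medication {}
    static class AceInhibitor extends AntiHypertensiveMedication {}

    // Exact runtime-class match, as in the method above
    static <T, E extends T> E firstExactInstance(Class<E> theClass, Collection<T> items) {
        for (T o : items) {
            if (o != null && o.getClass() == theClass) {
                return theClass.cast(o);
            }
        }
        return null;
    }

    // Subtype-aware match: isInstance also accepts subclasses and rejects null
    static <T, E extends T> E firstInstanceIncludingSubclasses(Class<E> theClass, Collection<T> items) {
        for (T o : items) {
            if (theClass.isInstance(o)) {
                return theClass.cast(o);
            }
        }
        return null;
    }
}
```

With streams, the subtype-aware version could also be written as items.stream().filter(theClass::isInstance).map(theClass::cast).findFirst(), which returns an Optional rather than null.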

In the Patient class, we’ve then implemented a public convenience method to support querying for Medication or any of its subclasses:

public <T extends Medication> boolean isTaking(Class<T> medication) {
	return Patient.returnFirstInstanceOfClassInCollection(medication, this.riskAdjustingMedications) != null;
}

For maximum readability in our classes that are evaluating complication risk and patient physiology, we can then declare statics for the specific Medication or subclass types we’re using and write code like the following:

static final Class<AntiHypertensiveMedication> ANTI_HYPERTENSIVES = AntiHypertensiveMedication.class;

if (patient.isTaking(ANTI_HYPERTENSIVES)) {
	patient.adjustSystolicBloodPressureBy(-5.0);
}

Eminently readable.