Case study 3: online statistics#
A cautionary tale.#
In many computational modelling procedures you will need an updated estimate of statistics as the code executes. For example, you may need to track a mean or a standard deviation of a performance measure in a multi-stage algorithm or as a simulation model of a healthcare system executes.
As we have seen, numpy provides highly efficient functions for calculating a mean or standard deviation from data held in an array. I'm always tempted to make use of these built-in procedures. They are indeed fast and incredibly easy to use. The downside is that you waste computation via repeated iteration over an array. The other option, which requires more careful thought (due to floating point issues), is a running estimate of your statistics. In general, I've implemented such procedures in standard python. Let's look at an example where we compare recalculation using a numpy function with a running (sometimes called an 'online') calculation of the mean and standard deviation in standard python.
We will first refactor AttendanceSummary from Statistical procedures to an OnlineSummary class that includes an update() function. update() will accept a np.ndarray and recalculate the sample mean and standard deviation using numpy on the full dataset. The function test_complete_recalculation iteratively calls update using more data each time. For simplicity's sake we will reuse the data contained within ed_data.
Wait a minute!
This chapter is about scientific coding in numpy, but this case study is demonstrating that standard python is more efficient! Well, not quite. The overall theme of this part of the book is that code is a first class citizen in health data science. You should always think about the design of your code in any algorithms or models you implement. This case study demonstrates that there may be instances where a numpy solution is not the most efficient.
Imports#
import numpy as np
Data#
file_name = 'data/minor_illness_ed_attends.csv'
ed_data = np.loadtxt(file_name, skiprows=1, delimiter=',')
print(ed_data.shape)
(74,)
numpy solution#
class OnlineSummary:
    def __init__(self, data=None, decimal_places=2):
        """
        Track online statistics of mean and standard deviation.

        Params:
        -------
        data: np.ndarray, optional (default = None)
            Contains an initial data sample.

        decimal_places: int, optional (default=2)
            Summary decimal places.
        """
        if isinstance(data, np.ndarray):
            self.n = len(data)
            self.mean = data.mean()
            self.std = data.std(ddof=1)
        else:
            self.n = 0
            self.mean = None
            self.std = None

        self.dp = decimal_places

    def update(self, data):
        '''
        Update the mean and standard deviation using complete recalculation.

        Params:
        ------
        data: np.ndarray
            Vector of data
        '''
        self.n = len(data)

        # update the mean and std. Easy!
        self.mean = data.mean()
        self.std = data.std(ddof=1)

    def __str__(self):
        to_print = f'Mean:\t{self.mean:.{self.dp}f}' \
            + f'\nStdev:\t{self.std:.{self.dp}f}' \
            + f'\nn:\t{self.n}'
        return to_print


def test_complete_recalculation(data, start=2):
    summary = OnlineSummary(data[:start])
    for i in range(start, len(data)+1):
        summary.update(data[:i])
    return summary
summary = test_complete_recalculation(ed_data)
print(summary)
Mean: 2.92
Stdev: 0.71
n: 74
%timeit summary = test_complete_recalculation(ed_data)
1.15 ms ± 21.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
You should find the numpy implementation fairly efficient, clocking in at just over 1 ms on average. But can we do better in standard python by computing an online mean and standard deviation?
Online sample mean and variance#
To do this we will use Welford's algorithm for computing a running sample mean and standard deviation. This is a robust, accurate and old(ish) approach (1960s) that I first read about in Donald Knuth's The Art of Computer Programming Vol 2. (Just to be clear, I learnt how to do this in 2008, not 1960!) To implement it we need to refactor update. Note that we will need a fair bit more code than our simple numpy solution.
The algorithm is given in a recursive format. For our purposes here, you can just think of that as tracking the mean and standard deviation as attributes of a class that we iteratively update with a new \(x\).
The first thing you need to do is handle the first observation encountered.
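The initialisation step can be written as (these equations mirror the `n == 1` branch of the code listing below):

```latex
M_1 = x_1, \qquad S_1 = 0
```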
Then on each subsequent call you update \(M\) and \(S\) making use of the previous values. Note that \(M\) has a relatively simple interpretation: it's the sample mean. However, \(S\) is not the standard deviation. It's actually the sum of squares of differences from the current mean. We will look at how to update that first and then I'll show you the equation for converting to the standard deviation.
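For \(n > 1\), Welford's update equations (matching the `else` branch of the code listing below) are:

```latex
M_n = M_{n-1} + \frac{x_n - M_{n-1}}{n}

S_n = S_{n-1} + (x_n - M_{n-1})(x_n - M_n)
```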
If the equations are confusing you can think of \(M_n\) as the updated_mean
and \(M_{n-1}\) as the previous_mean
.
Once the update is complete it is then relatively trivial to calculate the standard deviation \(\sigma_n\). Note that we don't necessarily need to track the standard deviation, just \(S_n\). We can inexpensively calculate \(\sigma_n\) when it is needed.
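The conversion from \(S_n\) to the sample standard deviation uses the usual \(n - 1\) denominator (Bessel's correction, consistent with `ddof=1` in the numpy version):

```latex
\sigma_n = \sqrt{\frac{S_n}{n - 1}}
```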
The code listing below modifies OnlineSummary to make use of Welford's algorithm. Note that std is now a property that calculates the standard deviation on the fly using \(S_n\).
class OnlineSummary:
    def __init__(self, data=None, decimal_places=2):
        """
        Track an online estimate of the mean and standard deviation.

        Params:
        -------
        data: np.ndarray, optional (default = None)
            Contains an initial data sample.

        decimal_places: int, optional (default=2)
            Summary decimal places.
        """
        self.n = 0
        self.mean = None
        self._sq = None

        if isinstance(data, np.ndarray):
            for x in data:
                self.update(x)

        self.dp = decimal_places

    @property
    def variance(self):
        return self._sq / (self.n - 1)

    @property
    def std(self):
        return np.sqrt(self.variance)

    def update(self, x):
        '''
        Running update of mean and variance implemented using Welford's
        algorithm (1962).

        See Knuth. D `The Art of Computer Programming` Vol 2. 2nd ed. Page 216.

        Params:
        ------
        x: float
            A new observation
        '''
        self.n += 1

        # we need to do more work ourselves for online stats!
        # init values
        if self.n == 1:
            self.mean = x
            self._sq = 0.0
        else:
            # compute the updated mean
            updated_mean = self.mean + ((x - self.mean) / self.n)

            # update the sum of squares
            self._sq += (x - self.mean) * (x - updated_mean)

            # update the tracked mean
            self.mean = updated_mean

    def __str__(self):
        to_print = f'Mean:\t{self.mean:.{self.dp}f}' \
            + f'\nStdev:\t{self.std:.{self.dp}f}' \
            + f'\nn:\t{self.n}'
        return to_print


def test_online_calculation(data):
    summary = OnlineSummary()
    for observation in data:
        summary.update(observation)
    return summary
summary = test_online_calculation(ed_data)
print(summary)
Mean: 2.92
Stdev: 0.71
n: 74
%timeit summary = test_online_calculation(ed_data)
31.1 μs ± 294 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
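As a quick sanity check, we can confirm that Welford's algorithm agrees with a full numpy recalculation. The `welford` function below is a hypothetical, condensed sketch of the class logic above (not from the original listing), applied to a synthetic sample:

```python
import numpy as np


def welford(data):
    """Sketch of an online mean/std via Welford's algorithm.

    Returns the running mean and sample standard deviation
    after consuming all observations in `data`.
    """
    n = 0
    mean = 0.0
    sq = 0.0  # running sum of squared deviations (S_n)
    for x in data:
        n += 1
        if n == 1:
            mean = x
            sq = 0.0
        else:
            # M_n = M_{n-1} + (x_n - M_{n-1}) / n
            updated_mean = mean + (x - mean) / n
            # S_n = S_{n-1} + (x_n - M_{n-1})(x_n - M_n)
            sq += (x - mean) * (x - updated_mean)
            mean = updated_mean
    return mean, np.sqrt(sq / (n - 1))


# synthetic stand-in for ed_data
rng = np.random.default_rng(42)
sample = rng.normal(loc=3.0, scale=0.7, size=74)

m, s = welford(sample)
assert np.isclose(m, sample.mean())
assert np.isclose(s, sample.std(ddof=1))
```

Both estimates match numpy to floating point precision, so the single-pass approach loses no accuracy on data like this.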
Summing up#
Crikey, nothing beats a good algorithm! You should find that you are now working in microseconds (µs) as opposed to milliseconds (1 ms = 1000 µs). On my machine test_online_calculation executes in ~31 µs on average while test_complete_recalculation takes ~1.15 ms. So we are seeing a speed up of ~97%. That gap will continue to grow as the number of samples \(n\) increases. The result is explained by the fact that each online update takes constant time (a constant number of computational steps), while the cost of each numpy recalculation grows with the size of the array. That's a lesson well worth remembering when developing code for scientific applications requiring performant code.