Test-Driven Development for Scientific Programming, Part 1: Introduction

Test-driven development (TDD) is a way of writing code by first writing tests, and using those tests to drive the design of your system. My best one-sentence summary is that TDD exists to break the difficulty of designing programs into smaller, controllable pieces, in the form of tests.

If you’re looking for a more solid definition of TDD, I invite you to watch this video of Gary Bernhardt explaining it very succinctly. The essence of his definition is:

- write a failing test
- write the minimal amount of code to make the test pass
- refactor (restructure the code for optimization, style, or design purposes)

Gary Bernhardt uses an example of a small piece of a webapp to illustrate this process, but it looks very little like anything we’d see in most scientific applications. So I’d like to demonstrate briefly what this would look like if you are trying to write a numerical program using Python.

Let’s say we want to write a function in Python that adds an arbitrary number of numpy arrays together. The first thing we need to do is write a test file that looks like this:

import numpy as np


def test_array_add():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)

    assert arr_add(arr_1, arr_2) == 3*np.ones(5)


test_array_add()

Let’s break this test down a bit: we define a function that calls the function we want to build, arr_add, by feeding it two test arrays that look like

[1,1,1,1,1]
[2,2,2,2,2]

and asserts we get the output

[3,3,3,3,3]

The first thing we get when we run this test is the error

Traceback (most recent call last):
  File "test.py", line 10, in <module>
    test_array_add()
  File "test.py", line 8, in test_array_add
    assert arr_add(arr_1, arr_2) == 3*np.ones(5)
NameError: name 'arr_add' is not defined

shell returned 1

This is because there is no function called arr_add yet; as Gary Bernhardt says, we are “programming by wishful thinking”. That is, we are just writing a test for the code as we want it to exist: I want a function called arr_add that takes the arrays it operates on as arguments and returns the proper sum. So now we simply make the function exist by defining it in our test file:

import numpy as np


def arr_add():
    pass


def test_array_add():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)

    assert arr_add(arr_1, arr_2) == 3*np.ones(5)


test_array_add()

Here we’ve just added an empty function definition that does nothing (Python’s pass statement is a no-op). If we run it now we get

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    test_array_add()
  File "test.py", line 12, in test_array_add
    assert arr_add(arr_1, arr_2) == 3*np.ones(5)
TypeError: arr_add() takes 0 positional arguments but 2 were given

shell returned 1

So now we’re on to another error message, and another trivial problem with our function: it needs to be able to take arguments. That’s alright, we can fix that too:

import numpy as np


def arr_add(arr_1, arr_2):
    pass


def test_array_add():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)

    assert arr_add(arr_1, arr_2) == 3*np.ones(5)


test_array_add()

Running it now gives

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    test_array_add()
  File "test.py", line 12, in test_array_add
    assert arr_add(arr_1, arr_2) == 3*np.ones(5)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

shell returned 1

This time the error might be less obvious if you aren’t used to making assertions about numpy arrays. The problem is that == on numpy arrays compares them element by element and returns an array of booleans rather than a single True or False, and assert can’t decide whether a multi-element array counts as true. Instead, we have to specify whether we want to check that any of the elements of the two arrays are equal, or that all of them are. We want all (it would be very sad to let our test pass if it only manages to get some of the values right), so we just need to use the numpy.all function:

import numpy as np


def arr_add(arr_1, arr_2):
    pass


def test_array_add():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)

    assert np.all(arr_add(arr_1, arr_2) == 3*np.ones(5))


test_array_add()

This returns our first AssertionError message:

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    test_array_add()
  File "test.py", line 12, in test_array_add
    assert np.all(arr_add(arr_1, arr_2) == 3*np.ones(5))
AssertionError

What has happened is that our empty function returns None, which Python returns by default from any function without an explicit return statement, and arr_add has none. Comparing None with an array gives an array of False values, so np.all evaluates to False and the assertion fails.
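To see these numpy comparison rules in isolation, here is a short standalone sketch (it is not part of the test file, and the variable names are just for illustration). It shows == returning an element-wise boolean array, why assert can’t use that array directly, how np.all reduces it to one value, and why the None returned by our stub makes the assertion fail:

```python
import numpy as np

expected = 3*np.ones(5)

# == compares element by element and returns an array of booleans
elementwise = np.ones(5) == expected
print(elementwise)  # [False False False False False]

# assert needs a single truth value; a multi-element boolean array has none,
# which is where the ValueError in the earlier traceback came from
try:
    bool(elementwise)
except ValueError as err:
    print("ValueError:", err)

# np.all collapses the element-wise comparison into a single boolean
print(np.all(np.ones(5) + 2*np.ones(5) == expected))  # True

# the stub returns None; comparing None with an array is also element-wise
# (all False), so np.all gives False and the test's assertion fails
stub_result = None
print(np.all(stub_result == expected))  # False
```

np.array_equal(a, b), which returns a single boolean and also checks that the two shapes match, is another common way to write this kind of assertion.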

Now we get to a kind of weird part of TDD. We are supposed to write the simplest code we possibly can to fix the error message, and technically we can just make our function return an array of 3s, like this:

def arr_add(arr_1, arr_2):
    return 3*np.ones(5)

This is obviously kind of a stupid thing to do, since it won’t work for anything but our particular test case, but it will make the test pass. Running the code now produces no output, since the assertion is now true. Why do this? In practice, the main reason I would ever hard-code a new program to pass a test like this is to delay having to think about a general implementation. The way we get to the general implementation is to write another test:

import numpy as np


def arr_add(arr_1, arr_2):
    return 3*np.ones(5)


def test_array_add_two_inputs():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)

    assert np.all(arr_add(arr_1, arr_2) == 3*np.ones(5))


def test_array_add_many_inputs():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)
    arr_3 = 3*np.ones(5)

    assert np.all(arr_add(arr_1, arr_2, arr_3) == 6*np.ones(5))


test_array_add_two_inputs()
test_array_add_many_inputs()

What we’ve done here is rename our first test to specify that it tests two input arrays, and add a new test that checks the function with three arrays. Now if we run it we get

Traceback (most recent call last):
  File "test.py", line 24, in <module>
    test_array_add_many_inputs()
  File "test.py", line 20, in test_array_add_many_inputs
    assert np.all(arr_add(arr_1, arr_2, arr_3) == 6*np.ones(5))
TypeError: arr_add() takes 2 positional arguments but 3 were given

shell returned 1

Now the first test passes but the second fails, telling us that arr_add doesn’t even accept 3 arguments! We are forced to generalize: we need to allow 3 arguments without requiring them, since the first test passes only 2 and would then fail. In general, we’d like to be able to pass arr_add any number of arrays. The standard way to do this in Python is with a starred parameter (often written *args). If we write a function like this:

def func(*args):
    # inside func, args is a tuple of every positional argument passed in
    return args

then inside func, args is a tuple containing all of the positional arguments, which we can access as args[0], args[1], and so on. So we want to change arr_add to look like

def arr_add(*arrs):
    # arrs is a tuple of every array passed in; computing the sum comes next
    ...
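A quick check in the interpreter makes the starred parameter concrete. This is a small sketch; show_args is a hypothetical helper used only for illustration. It also shows the matching starred expression at the call site, which unpacks a list into separate arguments:

```python
def show_args(*args):
    # hypothetical helper: args is a tuple of all positional arguments
    return args

print(show_args())            # ()
print(show_args(1, 2))        # (1, 2)
print(show_args(1, 2, 3)[2])  # 3

# a starred expression in a call unpacks a sequence into positional arguments
values = [1, 2, 3]
print(show_args(*values))     # (1, 2, 3)
```

Because arrs arrives as an ordinary tuple, arr_add can receive two arrays, three arrays, or any other number without changing its signature.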

Now we can generalize further by actually adding the arrays. We could do this with a loop, but in numpy code it is usually better to let numpy’s built-in operations do the work than to write explicit Python loops, mostly for performance reasons, so we will instead try this: convert the tuple of arrays to a single numpy array via np.asarray(arrs), then use the sum method on this array over axis 0. If you aren’t familiar with numpy arrays: when we create one array out of the 1-dimensional arrays passed into arr_add (here, the three arrays from our second test), we get a 2-dimensional array like this:

[[1,1,1,1,1],
 [2,2,2,2,2],
 [3,3,3,3,3]]

You can think of this as a matrix or as a nested list. Axis 0 is the first index: it picks out a row (viewing it as a matrix), or one of the inner lists (viewing it as a nested list). Summing over axis 0 collapses that index, adding the rows together. What we want is to sum the individual nested arrays together, i.e.

[1,1,1,1,1] + [2,2,2,2,2] + [3,3,3,3,3]

which is exactly the same as adding down each column. That is why we specify the sum to be over axis=0. Now our final code looks like

import numpy as np


def arr_add(*arrs):
    return np.asarray(arrs).sum(axis=0)


def test_array_add_two_inputs():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)

    assert np.all(arr_add(arr_1, arr_2) == 3*np.ones(5))


def test_array_add_many_inputs():
    arr_1 = np.ones(5)
    arr_2 = 2*np.ones(5)
    arr_3 = 3*np.ones(5)

    assert np.all(arr_add(arr_1, arr_2, arr_3) == 6*np.ones(5))


test_array_add_two_inputs()
test_array_add_many_inputs()

and both tests now pass.
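If the axis argument still feels abstract, it can help to look at the intermediate array directly. This sketch (the variable names are only for illustration) stacks the three arrays from the second test the same way arr_add does, then compares summing over axis 0 with summing over axis 1:

```python
import numpy as np

# the tuple arr_add receives as *arrs in test_array_add_many_inputs
arrs = (np.ones(5), 2*np.ones(5), 3*np.ones(5))
stacked = np.asarray(arrs)

print(stacked.shape)        # (3, 5): 3 input arrays, 5 elements each
print(stacked.sum(axis=0))  # [6. 6. 6. 6. 6.]: rows added together, one total per column
print(stacked.sum(axis=1))  # one total per input array: 5, 10 and 15
```

Using axis=1 in arr_add would produce one total per input array instead, and our tests would fail.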

A few key points to note:

The reason we performed this ridiculous maneuver was to focus on writing the tests and sketching out the design of the function before worrying about the general algorithm. This is one of the reasons TDD can be a powerful tool: the process lets you design a system around how you will use it, then bring that system into existence error by error, while you focus on fixing each error message as simply as possible. You let your tests design the system, and only once the tests force you to do so do you write the general algorithm or method into the code.

This is a very coarse overview of TDD, one which is well supplemented by the talk I linked to above by Gary Bernhardt. In fact, if you are interested in learning more about TDD and other software engineering practices, Gary Bernhardt’s screencasts at Destroy All Software are a great resource. He goes into many concrete and less trivial examples, and discusses more subtle decisions you have to make when writing real code. All of what I know about TDD comes from Destroy All Software, and my own experience using TDD for scientific code. The screencasts do cost money, but for students and people in other situations where they might be unable to afford the cost, he offers a way to get access for free.

In part 2, I will write down some recommendations for how to write more complex tests in Python, and some software to help make running tests easier.

