policytree: Policy learning via doubly robust empirical welfare maximization over trees

There has recently been a considerable amount of work on statistical methodology for policy learning, including Manski (2004), Zhao, Zeng, Rush, & Kosorok (2012), Swaminathan & Joachims (2015), Kitagawa & Tetenov (2018), van der Laan & Luedtke (2015), Luedtke & van der Laan (2016), Mbakop & Tabord-Meehan (2016), Athey & Wager (2017), Kallus & Zhou (2018), and Zhou, Athey, & Wager (2018). In particular, Kitagawa & Tetenov (2018) show that if we only consider policies $\pi$ restricted to a class $\Pi$ with finite VC dimension and have access to data from a randomized trial with $n$ samples, then an empirical welfare maximization algorithm achieves regret that scales as $\sqrt{\mathrm{VC}(\Pi)/n}$. Athey & Wager (2017) extend this result to observational studies via doubly robust scoring, and Zhou et al. (2018) further consider the case with multiple treatment choices (there, the regret bound also depends on the tree depth, the feature space, and the number of actions).


Summary
The problem of learning treatment assignment policies from randomized or observational data arises in many fields. For example, in personalized medicine, we seek to map patient observables (like age, gender, blood pressure, etc.) to a treatment choice using a data-driven rule.
The package policytree for R (R Core Team, 2020) implements the multi-action doubly robust approach of Zhou et al. (2018) in the case where we want to learn policies $\pi$ that belong to the class $\Pi$ of depth-k decision trees. In order to use policytree, the user starts by specifying a set of doubly robust scores for policy evaluation; the software then carries out a globally optimal search over decision trees, returning the tree that maximizes the total score of the actions it assigns.
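Concretely, one common construction of such scores, and the one described in Zhou et al. (2018), is the augmented inverse-propensity weighted (AIPW) estimator (notation ours, not necessarily the package's internals):

$$\hat{\Gamma}_{i,a} = \hat{\mu}_a(X_i) + \frac{\mathbf{1}\{W_i = a\}}{\hat{e}_a(X_i)} \big( Y_i - \hat{\mu}_a(X_i) \big),$$

where $\hat{\mu}_a(\cdot)$ estimates the conditional mean outcome under action $a$ and $\hat{e}_a(\cdot)$ the probability of receiving action $a$. These scores are doubly robust in the sense that they remain valid if either the outcome model or the propensity model is estimated consistently.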
It is well known that finding an optimal tree of arbitrary depth is NP-hard. However, if we restrict our attention to trees of depth k, the problem can be solved in polynomial time. Here, we implement the global optimization via an exhaustive (unconstrained) tree search that runs in $O(P^k N^k(\log N + D) + PN \log N)$ time, where $N$ is the number of individuals, $P$ the number of characteristics observed for each individual, and $D$ the number of available treatment choices (see details below). If an individual's characteristics take on only a few discrete values, the runtime can be reduced by a factor of $N^k$. Additionally, an optional approximation parameter lets the user control how many splits to consider.
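As a quick illustration of the approximation parameter, consider the sketch below; we assume here that the parameter is exposed as the split.step argument of policy_tree, as in the package documentation, and use placeholder simulated scores:

```R
library(policytree)

# Placeholder data: N = 1000 individuals, P = 5 covariates, D = 2 actions.
X <- matrix(rnorm(1000 * 5), 1000, 5)
Gamma <- matrix(rnorm(1000 * 2), 1000, 2)  # doubly robust scores

# Exhaustive depth-2 search considers every split point along each covariate...
tree.exact <- policy_tree(X, Gamma, depth = 2)

# ...while split.step = 10 only considers every 10th split point,
# trading exact optimality for speed.
tree.approx <- policy_tree(X, Gamma, depth = 2, split.step = 10)
```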
Our package is integrated with the R package grf of Athey, Tibshirani, & Wager (2019), allowing for a simple workflow that uses random forests to estimate the nuisance components required to automatically form the doubly robust scores. We also generalize the causal_forest function from grf to multiple treatment effects via the one-vs-all encoding described in Zhou et al. (2018). The simulation example below illustrates this workflow in a setting with D = 3 actions; here, we write covariates as X, outcomes as Y, and actions as W. Figure 1 shows a tree similarly grown on a dataset considered by Kitagawa & Tetenov (2018), from the job training study of Bloom et al. (1997): there the reward matrix contains two actions, not assigning the treatment (action 1) and assigning the treatment, a job training program (action 2), and the covariate matrix contains two variables, a candidate's previous annual earnings in $1,000s and years of education. (Note: the optional package DiagrammeR is needed to plot trees.)
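A minimal sketch of this workflow follows. We assume multi_causal_forest and double_robust_scores as the entry points for the one-vs-all forests and the score construction, following the package documentation; the data-generating process is made up for illustration:

```R
library(policytree)

n <- 250  # individuals
p <- 10   # covariates
d <- 3    # actions

# Simulated data: action 2 helps when X2 is large, action 3 when X3 is large.
X <- matrix(rnorm(n * p), n, p)
W <- sample(1:d, n, replace = TRUE)
Y <- X[, 1] + X[, 2] * (W == 2) + X[, 3] * (W == 3) + runif(n)

# Fit one grf forest per action (one-vs-all) and form doubly robust scores.
multi.forest <- multi_causal_forest(X = X, Y = Y, W = W)
Gamma.matrix <- double_robust_scores(multi.forest)

# Globally optimal depth-2 tree maximizing the doubly robust reward estimate.
opt.tree <- policy_tree(X, Gamma.matrix, depth = 2)
plot(opt.tree)  # requires the optional DiagrammeR package
```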

Appendix: Details on tree search
The pseudocode for the tree search is outlined in Algorithm 1 and Algorithm 2. At a high level, in the main recursive case for $k \ge 2$, the algorithm maintains the data structure sorted_sets to quickly obtain the sort order of points along all $P$ dimensions for a given split. For each of the $P \times (N - 1)$ possible splits, and for each dimension $j$, all points on the right side of the split are stored in a set $R^{(j)}$ and all points on the left side in a set $L^{(j)}$. Moving from one split candidate to the next, a single point is moved from the right set to the left set in all dimensions. This proceeds recursively, enumerating the rewards of all possible splits.
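To make the sweep concrete, here is a minimal sketch in plain R of the depth-one ($k = 1$) case; this is our illustration, not the package's C++ implementation, and the running sums left and right play the role of the sets $L^{(j)}$ and $R^{(j)}$:

```R
# For each dimension, sweep the N - 1 split candidates in sorted order,
# moving one point at a time from the right leaf to the left leaf while
# keeping running column sums of the scores. The best (split, left action,
# right action) triple is found in O(NPD) time after an O(PN log N) sort.
best_depth_one_tree <- function(X, Gamma) {
  N <- nrow(X)
  P <- ncol(X)
  best <- list(reward = -Inf)
  for (j in 1:P) {
    ord <- order(X[, j])          # sort order along dimension j
    left <- rep(0, ncol(Gamma))   # running score sums in the left leaf
    right <- colSums(Gamma)       # remaining score sums in the right leaf
    for (i in 1:(N - 1)) {
      left <- left + Gamma[ord[i], ]    # move one point from right to left
      right <- right - Gamma[ord[i], ]
      reward <- max(left) + max(right)  # best action in each leaf
      if (reward > best$reward) {
        best <- list(reward = reward, split.dim = j,
                     split.val = X[ord[i], j],
                     left.action = which.max(left),
                     right.action = which.max(right))
      }
    }
  }
  best
}
```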
The $O(PN \log N)$ term arises from the fixed amortized cost of creating the global sort order once, for every sample along all $P$ dimensions. The remaining $O(P^k N^k(\log N + D))$ term is obtained by inductively calculating the runtime for increasing depths $k$.
Algorithm 1: Exact tree search. In the implementation, parent nodes whose two leaves share the same action are pruned. The implementation also features an optional approximation parameter that controls the number of splits to consider. The recursion has two base cases: a leaf node (k = 0), and the parent of a leaf (k = 1), where the best action in each of the two leaves can be computed jointly in $O(NPD)$ time by a dynamic-programming-style algorithm. Peripheral functions are outlined in Algorithm 2 at the end.

function tree_search(sorted_sets, Γ, k)
  Input : P-vector sorted_sets, N × D score matrix Γ, tree depth k
  Output: the optimal tree, a structure with (left node, right node, total reward, action)
  if k = 0 then
    return the leaf assigning the single best action, $\arg\max_a \sum_i \Gamma_{i,a}$
  (remaining cases as described in the text)

Recursive case. We claim that the time complexity for $k \ge 1$ (one or more splits) is $O(P^k N^k(\log N + D))$. This is satisfied by the second base case ($k = 1$) above. For the recursive case, there are $PN$ possible split points. For every single split along every dimension we remove a sample from a binary search tree and add it to another; this takes $O(\log N)$ time, and we do this for each of the $P$ dimensions, leading to time $O(P \log N)$. Further, for each split, we recursively call tree_search at depth $k - 1$; in general there are $m_1$ and $m_2$ points on the two sides of a top-level split, with $N = m_1 + m_2$. Assuming the recursive expression holds at depth $k - 1$, the amount of work done for each split is then

$O\big(P \log N + P^{k-1} m_1^{k-1}(\log m_1 + D) + P^{k-1} m_2^{k-1}(\log m_2 + D)\big).$

Note that $m_1^{k-1}(\log m_1 + D) \le m_1^{k-1}(\log N + D)$. Similarly, $m_2^{k-1}(\log m_2 + D) \le m_2^{k-1}(\log N + D)$. Further, since $m_1 + m_2 = N$ with $m_1, m_2, N > 0$, we have $m_1^{k-1} + m_2^{k-1} \le 2N^{k-1}$, and the constant factor is absorbed into the $O(\cdot)$.
Combining, the amount of work in each split is upper bounded by $O(P^{k-1} N^{k-1}(\log N + D))$.
Since we have $PN$ splits, this leads to a running time of $O\big(PN \cdot P^{k-1} N^{k-1}(\log N + D)\big) = O(P^k N^k(\log N + D))$.
Algorithm 2: Peripheral functions for Algorithm 1.

function create_sorted_sets(X)
  Input : N × P covariate matrix X
  Output: a length-P vector, the jth entry containing all N samples sorted along dimension j
  result ← vector(P)
  for j = 1:P do
    result(j) ← binary_search_tree(j)
    for i = 1:N do
      result(j).insert(x_i)
    end
  end
  return result

function create_empty_sorted_sets(P)
  Input : P, the number of dimensions
  Output: a length-P vector, the jth entry an empty set, to be sorted along dimension j
  result ← vector(P)
  for j = 1:P do
    result(j) ← binary_search_tree(j)
  end
  return result