CR-Sparse: Hardware accelerated functional algorithms for sparse signal processing in Python using JAX



Summary
We introduce CR-Sparse, a Python library that enables to efficiently solve a wide variety of sparse representation based signal processing problems. It is a cohesive collection of sublibraries working together. Individual sub-libraries provide functionalities for: wavelets, linear operators, greedy and convex optimization based sparse recovery algorithms, subspace clustering, standard signal processing transforms, and linear algebra subroutines for solving sparse linear systems. It has been built using Google JAX (Bradbury et al., 2018), which enables the same high level Python code to get efficiently compiled on CPU, GPU and TPU architectures using XLA (Abadi et al., 2017). Traditional signal processing exploits the underlying structure in signals by representing them using Fourier or wavelet orthonormal bases. In these representations, most of the signal energy is concentrated in few coefficients allowing greater flexibility in analysis and processing of signals. More flexibility can be achieved by using overcomplete dictionaries (Mallat, 2009) (e.g. unions of orthonormal bases). However, the construction of sparse representations of signals in these overcomplete dictionaries is no longer straightforward and requires use of specialized sparse coding algorithms like orthogonal matching pursuit (Pati et al., 1993) or basis pursuit (Chen et al., 2001). The key idea behind these algorithms is the fact that underdetermined systems Ax = b can be solved efficiently to provide sparse solutions x if the matrix A satisfies specific conditions on its properties like coherence. Compressive sensing takes the same idea in the other direction and contends that signals having sparse representations in suitable bases can be acquired by very few data-independent random measurements y = Φx if the sensing or measurement system Φ satisfies certain conditions like restricted isometry property (Candes, 2008). 
The same sparse coding algorithms can be tailored for sparse signal recovery from compressed measurements.
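To make the greedy sparse coding idea concrete, here is a minimal NumPy sketch of orthogonal matching pursuit recovering a sparse vector from an underdetermined system. This is an illustration of the algorithmic idea only, not the cr.sparse API; all names are hypothetical:

```python
import numpy as np

def omp(A, b, k):
    """Greedy orthogonal matching pursuit: select k atoms of A to approximate b."""
    m, n = A.shape
    residual = b.copy()
    support = []
    x = np.zeros(n)
    for _ in range(k):
        # pick the column most correlated with the current residual
        j = int(np.argmax(np.abs(A.T @ residual)))
        if j not in support:
            support.append(j)
        # least-squares fit of b on the currently selected columns
        coef, *_ = np.linalg.lstsq(A[:, support], b, rcond=None)
        residual = b - A[:, support] @ coef
    x[support] = coef
    return x

# recover a 2-sparse signal in R^30 from 15 random measurements
rng = np.random.default_rng(0)
A = rng.standard_normal((15, 30)) / np.sqrt(15)
x_true = np.zeros(30)
x_true[[3, 11]] = [1.5, -2.0]
b = A @ x_true
x_hat = omp(A, b, k=2)
```

When the correct support is identified, the final least-squares step recovers the sparse coefficients exactly.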

Package Overview
The cr.sparse.pursuit package includes greedy and thresholding based solvers for sparse recovery, including the OMP, CoSaMP, HTP, IHT, and SP algorithms. The cr.sparse.cvx package includes efficient solvers for l1-minimization problems using convex optimization methods. The cr.sparse.sls package provides JAX versions of the LSQR, ISTA, and FISTA algorithms for solving sparse linear systems. These algorithms can work with unstructured random and dense sensing matrices as well as structured sensing matrices represented as linear operators (provided in the cr.sparse.lop package). The cr.sparse.lop package includes a collection of linear operators influenced by PyLops (Ravasi & Vasconcelos, 2019). The cr.sparse.wt package includes a JAX version of the major functionality of PyWavelets (Lee et al., 2019), making it the first major pure Python wavelets implementation which can work across CPUs, GPUs and TPUs.
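As a flavor of the proximal-gradient family of solvers mentioned above, here is a minimal NumPy sketch of ISTA for the l1-regularized least squares problem. This illustrates the iteration itself, not the cr.sparse.sls API; all names are hypothetical:

```python
import numpy as np

def soft_threshold(z, t):
    """Proximal operator of t*||.||_1 (element-wise shrinkage)."""
    return np.sign(z) * np.maximum(np.abs(z) - t, 0.0)

def ista(A, b, lam, num_iters=2000):
    """Iterative shrinkage-thresholding for min_x 0.5||Ax - b||^2 + lam*||x||_1."""
    L = np.linalg.norm(A, 2) ** 2          # Lipschitz constant of the gradient
    x = np.zeros(A.shape[1])
    for _ in range(num_iters):
        grad = A.T @ (A @ x - b)           # gradient of the smooth term
        x = soft_threshold(x - grad / L, lam / L)
    return x

# recover a 3-sparse signal from 30 noiseless random measurements in R^60
rng = np.random.default_rng(1)
A = rng.standard_normal((30, 60)) / np.sqrt(30)
x_true = np.zeros(60)
x_true[[5, 17, 40]] = [2.0, -1.0, 1.5]
b = A @ x_true
x_hat = ista(A, b, lam=0.01)
```

FISTA adds a momentum term to the same iteration to accelerate convergence.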

Statement of need
Currently, there is no single package which provides a comprehensive set of tools for solving sparse recovery problems in one place. Individual researchers typically publish code only for the algorithms they have developed, alongside their research papers. Most of this work is available in the form of MATLAB (MATLAB, 2018) libraries. For example: YALL1 is the original MATLAB implementation of ADMM based sparse recovery algorithms. L1-LS is the original MATLAB implementation of the truncated Newton interior point method for solving the l1-minimization problem. Sparsify provides MATLAB implementations of the IHT, NIHT, and AIHT algorithms. aaren/wavelets is a CWT implementation following (Torrence & Compo, 1998). HTP provides a MATLAB implementation of hard thresholding pursuit. WaveLab is the reference open source wavelet implementation in MATLAB; however, its API has largely been superseded by later libraries. The Sparse and Redundant Representations book code (Elad, 2010) provides basic implementations of a number of sparse recovery and related algorithms. Several of these libraries keep key performance critical subroutines in C/C++ extensions, making portability to GPUs harder.
There are some Python libraries which focus on specific areas; however, they are generally CPU only. For example, pyCSalgos is a Python implementation of various compressed sensing algorithms. spgl1 is a NumPy based implementation of the spectral projected gradient for L1 minimization. c-lasso (Simpson et al., 2021) is a Python package for constrained sparse regression and classification; it is also CPU only. PyWavelets is an excellent CPU only wavelets implementation in Python, closely following the API of the Wavelet Toolbox in MATLAB; its performance critical parts are written entirely in C. There have been several attempts to port it to GPUs using PyTorch (PyTorch-Wavelet-Toolbox) or TensorFlow (tf-wavelets) backends. PyLops includes GPU support via a backend.py layer which switches explicitly between NumPy and CuPy. In contrast, our use of JAX enables JIT compilation with end-to-end XLA optimization abstracted across multiple backends.
The algorithms in this package have a wide variety of applications. We list a few: image denoising, deblurring, compression, inpainting, impulse noise removal, super-resolution, subspace clustering, dictionary learning, compressive imaging, medical imaging, compressive radar, wireless sensor networks, astrophysical signals, cognitive radio, sparse channel estimation, analog to information conversion, speech recognition, seismology, and direction of arrival estimation.

Sparse signal processing problems and available solvers
We provide JAX based implementations for the following algorithms: cr.sparse.pursuit.omp: Orthogonal Matching Pursuit (OMP) (Davenport & Wakin, 2010; Pati et al., 1993; Tropp, 2004). These algorithms work with linear operators (from the cr.sparse.lop package) which provide the forward and adjoint operation functions. These operators can be JIT compiled and used efficiently with the algorithms above. Our 2D and ND operators accept 2D/ND arrays as input and return 2D/ND arrays as output. The operators +, -, @, ** etc. are overridden to provide an operator calculus, i.e. ways to combine operators to generate new operators.
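The idea of a linear operator as a forward/adjoint function pair, and of composing such pairs, can be sketched in plain NumPy. This is an illustration of the concept, not the cr.sparse.lop API; all names here are hypothetical:

```python
import numpy as np
from collections import namedtuple

# a linear operator as a pair of functions: forward (times) and adjoint (trans)
Operator = namedtuple("Operator", ["times", "trans"])

def diagonal(d):
    """Operator for multiplication by a diagonal matrix diag(d)."""
    return Operator(times=lambda x: d * x, trans=lambda y: d * y)

def matrix(A):
    """Operator wrapping a dense matrix A."""
    return Operator(times=lambda x: A @ x, trans=lambda y: A.T @ y)

def compose(P, Q):
    """P @ Q: apply Q then P; the adjoint composes in reverse order."""
    return Operator(times=lambda x: P.times(Q.times(x)),
                    trans=lambda y: Q.trans(P.trans(y)))

rng = np.random.default_rng(2)
A = rng.standard_normal((4, 4))
d = np.array([1.0, 2.0, 3.0, 4.0])
T = compose(diagonal(d), matrix(A))   # behaves like diag(d) @ A
x = rng.standard_normal(4)
y = rng.standard_normal(4)
# adjoint (dot-product) test: <T x, y> == <x, T* y>
lhs = np.dot(T.times(x), y)
rhs = np.dot(x, T.trans(y))
```

The dot-product test at the end is the standard way to verify that an operator's adjoint is implemented correctly.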
As an application area, the library includes an implementation of sparse subspace clustering (SSC) by orthogonal matching pursuit (You et al., 2016) in the cr.sparse.cluster.ssc package. The cr.sparse.cluster.spectral package provides a custom implementation of spectral clustering step of SSC.

Experimental Results
We conducted a number of experiments to benchmark the runtime of CR-Sparse implementations against existing reference software in Python or MATLAB. Jupyter notebooks to reproduce these micro-benchmarks are available in the cr-sparse-companion repository.
All Python based benchmarks have been run on the same machine configuration (an Intel(R) Xeon based system). We see significant though variable gains from CR-Sparse on GPU. We have observed that the gain tends to increase for larger problem sizes; GPUs perform better as the matrix/vector products become bigger. The vmap and pmap tools provided by JAX can be used to easily parallelize the CR-Sparse algorithms over multiple data instances and processors.
Limitations
Some of the limitations in the library come from the underlying JAX library. JAX is relatively new and has not yet reached 1.0-level maturity. The programming model chosen by JAX places several restrictions on expressing program logic. For example, JAX does not support dynamic or data-dependent shapes in its JIT compiler. Thus, any algorithm parameter which determines the size/shape of individual arrays in an algorithm must be provided statically. E.g., for greedy algorithms like OMP, the sparsity level K must be known in advance and provided as a static parameter to the API, as the size of the output array depends on K.
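The static-shape requirement can be illustrated with a small JAX sketch (hypothetical names, not the cr.sparse API): because K determines the output shape, it must be marked static so that the compiler sees a concrete value at trace time.

```python
from functools import partial

import jax
import jax.numpy as jnp

# K determines the output shape, so it must be a static (compile-time) argument;
# jit compiles a separate version of the function for each distinct K.
@partial(jax.jit, static_argnums=(1,))
def largest_k(x, K):
    """Indices and values of the K largest-magnitude entries of x."""
    _, idx = jax.lax.top_k(jnp.abs(x), K)
    return idx, x[idx]

x = jnp.array([0.1, -3.0, 0.5, 2.0])
idx, vals = largest_k(x, 2)
```

Passing K as an ordinary traced argument instead would raise a tracing error, since the shape of the result would then depend on a runtime value.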
The control flow primitives like lax.while_loop, lax.fori_loop etc. in JAX require that the algorithm state flowing between iterations must not change shape or size. This makes coding algorithms like OMP or SVT (singular value thresholding) very difficult: an incremental QR or Cholesky decomposition based implementation of OMP requires a growing algorithm state. We ended up using a standard Python for loop for now, but the JIT compiler simply unrolls it and does not allow tolerance based early termination.
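When the state can be kept at a fixed shape, lax.while_loop does allow tolerance based early termination. The following sketch (hypothetical names, a simple gradient descent rather than any cr.sparse solver) shows the pattern: the loop carry is a fixed-shape tuple and the condition function checks the residual norm.

```python
import jax
import jax.numpy as jnp

def solve_fixed_state(A, b, tol=1e-6, max_iters=1000):
    """Gradient descent on 0.5*||Ax - b||^2 with a fixed-shape loop state,
    so that lax.while_loop can stop early once the residual is below tol."""
    step = 1.0 / (jnp.linalg.norm(A, 2) ** 2)   # 1 / Lipschitz constant

    def cond(state):
        x, r_norm, i = state
        return jnp.logical_and(r_norm > tol, i < max_iters)

    def body(state):
        x, _, i = state
        x = x - step * (A.T @ (A @ x - b))      # gradient step
        return x, jnp.linalg.norm(A @ x - b), i + 1

    x0 = jnp.zeros(A.shape[1])
    init = (x0, jnp.linalg.norm(b), jnp.array(0))
    x, r_norm, iters = jax.lax.while_loop(cond, body, init)
    return x, iters

A = jnp.eye(3) * 2.0
b = jnp.array([2.0, 4.0, 6.0])
x, iters = solve_fixed_state(A, b)
```

The shape of every element of the carry tuple is identical across iterations, which is exactly what the JAX control flow primitives require.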
1D convolutions are slow in JAX on CPU (JAX issue #7961). This affects the performance of DWT/IDWT in cr.sparse.dwt. We are exploring ways of making it more efficient while keeping the API intact.
These restrictions necessitate a good amount of creativity and a very disciplined coding style to develop efficient, JIT-friendly solvers.

Future Work
Currently, work is underway to provide a JAX based implementation of TFOCS (Becker et al., 2011) in the dev branch. This will help us increase the coverage to a wider set of problems (like total variation minimization, Dantzig selector, l1-analysis, nuclear norm minimization, etc.). As part of this effort, we are expanding our collection of linear operators and building a set of indicator and projector functions onto convex sets as well as proximal operators (Parikh & Boyd, 2014). This will enable us to cover other applications such as SSC-L1 (Pourkamali-Anaraki et al., 2020). In the future, we intend to increase the coverage in the following areas: more recovery algorithms (OLS, Split Bregman, SPGL1, etc.) and specialized cases (e.g. partially known support); Bayesian compressive sensing; dictionary learning (K-SVD, MOD, etc.); subspace clustering; image denoising, compression, etc. using sparse representation principles; matrix completion problems; matrix factorization problems; model based / structured compressive sensing problems; joint recovery problems from multiple measurement vectors.