from abc import ABC, abstractmethod
from enum import Enum
from typing import Any
import anndata as ad
import h5py
import numpy as np
from numba import types
from scipy import sparse as py_sparse
from illico.utils.sparse.csc import CSCMatrix
from illico.utils.sparse.csr import CSRMatrix
class Test(Enum):
    """Statistical test variant used as half of the dispatcher-registry key.

    The two members mirror the ``illico.ovo`` and ``illico.ovr`` kernel modules
    (presumably "one-vs-one" and "one-vs-rest" comparisons — confirm with callers).
    Members can be looked up from their raw string value, e.g. ``Test("ovo")``.
    """

    # Raw values are what callers may pass in place of the enum member.
    OVO = "ovo"
    OVR = "ovr"
class DispatcherRegistry(dict):
    """Registry mapping ``(Test, KernelDataFormat)`` keys to dispatchers.

    Entries are added with the ``register`` decorator and retrieved with ``get``.
    Both methods coerce their arguments through ``Test(...)`` and
    ``KernelDataFormat(...)``, so raw values such as ``"ovo"`` are accepted.

    NOTE(review): ``KernelDataFormat`` is not among this module's visible imports;
    confirm it is in scope before this class body runs, because the method
    annotations are evaluated at definition time.
    """

    def register(self, test: Test, data_format: KernelDataFormat):
        """Build a decorator that stores its target under ``(test, data_format)``.

        Coercion happens here, before any decoration, so an invalid ``test`` or
        ``data_format`` fails at the ``register(...)`` call itself. A later
        registration for the same key silently replaces the earlier one. The
        decorated object is returned unchanged.
        """
        registry_key = (Test(test), KernelDataFormat(data_format))

        def decorator(obj):
            self[registry_key] = obj
            return obj

        return decorator

    def get(self, test: Test, data_format: KernelDataFormat):
        """Return the dispatcher registered for ``(test, data_format)``.

        Unlike ``dict.get`` there is no default value: a missing combination
        raises ``KeyError``, chained from the underlying lookup error.
        """
        lookup_key = (Test(test), KernelDataFormat(data_format))
        try:
            dispatcher = self[lookup_key]
        except KeyError as missing:
            raise KeyError(f"No dispatcher registered for test {test} and data format {data_format}.") from missing
        return dispatcher
class DataHandlerRegistry(dict):
    """Map input container types to the ``DataHandler`` class that wraps them."""

    def register(self, data_format):
        """Return a decorator that records its target as the handler for ``data_format``.

        ``data_format`` is used verbatim as the dictionary key (in this module it is
        a type such as ``np.ndarray``). The decorated object is returned unchanged.
        """

        def decorator(obj):
            self[data_format] = obj
            return obj

        return decorator

    def get(self, key):
        """Wrap ``key`` in the handler registered for its type and return the handler.

        The exact type of ``key`` is tried first. Failing that, its base classes are
        tried in MRO order, so a subclass of a registered type (e.g. ``np.memmap``,
        an ``np.ndarray`` subclass) reuses its parent's handler.

        Raises:
            KeyError: if neither ``type(key)`` nor any of its bases is registered.
        """
        for klass in type(key).__mro__:
            if klass in self:
                handler_cls = self[klass]
                break
        else:
            raise KeyError(f"Support for data type {type(key)} is not implemented.")
        # Construct outside the lookup so that a KeyError raised by the handler
        # itself is not misreported as an unsupported data type.
        return handler_cls(key)
# Selects the dispatcher for a given (test type, kernel data format) pair.
dispatcher_registry = DispatcherRegistry()
# Maps each input container type to its DataHandler, which fetches column chunks
# (reading from disk when the data is backed or lazy-loaded) and converts them to
# numba-compatible objects.
data_handler_registry = DataHandlerRegistry()
class DataHandler(ABC):
    """Abstract wrapper around an input matrix.

    Concrete subclasses must implement ``fetch`` (obtain a chunk, reading from
    disk if the data is not already in memory) and ``to_nb`` (convert to a
    numba-compatible object). The class cannot be instantiated directly.
    """

    def __init__(self, data):
        # Keep a reference to the wrapped container; nothing is copied here.
        self.data = data

    @abstractmethod
    def fetch(self, *args, **kwargs) -> tuple:
        """Fetch data from disk if needed."""

    @abstractmethod
    def to_nb(self, *args, **kwargs) -> Any:
        """Convert data to numba-compatible format."""
class InRAMDataHandler(DataHandler):
    """Base handler for matrices that are already resident in memory.

    ``fetch`` performs no I/O and no slicing: it returns the whole matrix together
    with the requested bounds so that the kernels can slice it themselves.
    """

    def fetch(self, lb: int, ub: int) -> tuple:
        """If the data is already in RAM, let the kernels do optimized slicing."""
        bounds = (lb, ub)
        return self.data, bounds
@data_handler_registry.register(np.ndarray)
class DenseDataHandler(InRAMDataHandler):
    """In-memory handler for dense ``np.ndarray`` input."""

    @classmethod
    def to_nb(cls, X: np.ndarray) -> np.ndarray:
        """Return ``X`` unchanged after checking that it is a NumPy array.

        No conversion is performed; the array itself is handed to the kernels.
        """
        assert isinstance(X, np.ndarray)
        return X
@data_handler_registry.register(py_sparse.csr_matrix)
class CSRDataHandler(InRAMDataHandler):
    """In-memory handler for ``scipy.sparse.csr_matrix`` input."""

    @classmethod
    def to_nb(cls, X: py_sparse.csr_matrix) -> CSRMatrix:
        """Wrap the CSR buffers of ``X`` in the numba-side ``CSRMatrix`` container.

        ``X.data``, ``X.indices``, ``X.indptr`` and ``X.shape`` are passed straight
        to ``CSRMatrix``; no conversion is done in this method.
        """
        # Use the public ``scipy.sparse.csr_matrix`` name (the same class object as
        # ``scipy.sparse._csr.csr_matrix``): the ``scipy.sparse.csr`` submodule is
        # deprecated and emits a warning on attribute access.
        assert isinstance(X, py_sparse.csr_matrix)
        return CSRMatrix(X.data, X.indices, X.indptr, X.shape)
@data_handler_registry.register(py_sparse.csc_matrix)
class CSCDataHandler(InRAMDataHandler):
    """In-memory handler for ``scipy.sparse.csc_matrix`` input."""

    @classmethod
    def to_nb(cls, X: py_sparse.csc_matrix) -> CSCMatrix:
        """Wrap the CSC buffers of ``X`` in the numba-side ``CSCMatrix`` container.

        ``X.data``, ``X.indices``, ``X.indptr`` and ``X.shape`` are passed straight
        to ``CSCMatrix``; no conversion is done in this method.
        """
        # Use the public ``scipy.sparse.csc_matrix`` name (the same class object as
        # ``scipy.sparse._csc.csc_matrix``): the ``scipy.sparse.csc`` submodule is
        # deprecated and emits a warning on attribute access.
        assert isinstance(X, py_sparse.csc_matrix)
        return CSCMatrix(X.data, X.indices, X.indptr, X.shape)
@data_handler_registry.register(h5py.Dataset)
class H5pyDatasetDataHandler(DenseDataHandler):
    """Handler for dense matrices stored in an ``h5py.Dataset``.

    Unlike the in-RAM handlers, ``fetch`` slices the requested column range out of
    the dataset itself. The returned bounds are therefore relative to that chunk.
    ``to_nb`` is inherited from ``DenseDataHandler``.
    """

    def fetch(self, lb: int, ub: int) -> tuple:
        """Slice columns ``[lb, ub)`` from the dataset and return them with chunk-local bounds."""
        chunk = self.data[:, lb:ub]
        width = ub - lb
        return chunk, (0, width)
# NOTE(review): ``_CSCDataset`` is a private anndata class; its location may change
# between anndata releases.
@data_handler_registry.register(ad._core.sparse_dataset._CSCDataset)
class H5pyBackedCSCDataHandler(CSCDataHandler):
    """Handler for backed (on-disk) CSC matrices exposed by anndata.

    ``fetch`` slices only the requested column range from the backed dataset.
    ``to_nb`` is inherited from ``CSCDataHandler``, so the fetched chunk is expected
    to be a scipy CSC matrix.
    """

    def fetch(self, lb: int, ub: int) -> tuple:
        """Slice columns ``[lb, ub)`` from the backed dataset and return them with chunk-local bounds."""
        return self.data[:, lb:ub], (0, ub - lb)
# Import kernel modules to trigger decorator registration
# These imports must come after the registry definitions above
from illico.ovo import ( # noqa: E402, F401
csc_ovo_mwu_kernel_over_contiguous_col_chunk,
csr_ovo_mwu_kernel_over_contiguous_col_chunk,
dense_ovo_mwu_kernel_over_contiguous_col_chunk,
)
from illico.ovr import ( # noqa: E402, F401
csc_ovr_mwu_kernel_over_contiguous_col_chunk,
csr_ovr_mwu_kernel_over_contiguous_col_chunk,
dense_ovr_mwu_kernel_over_contiguous_col_chunk,
)