# Source code for group_lasso.utils

"""Utility helpers for the ``group_lasso`` package.
"""
import numpy as np


def extract_ohe_groups(onehot_encoder):
    """Extract a vector with group indices from a scikit-learn OneHotEncoder

    Each input feature of the encoder becomes one group, and every output
    column produced for that feature is labelled with the feature's index.

    Arguments
    ---------
    onehot_encoder : sklearn.preprocessing.OneHotEncoder
        A fitted encoder. Its ``categories_`` attribute is required. The
        optional ``drop_idx_`` and ``infrequent_categories_`` attributes are
        used, when present and not ``None``, to account for dropped and
        merged output columns.

    Returns
    -------
    np.ndarray
        An integer group-vector that can be used with the group lasso
        regularised linear models. It has one entry per output column of
        the encoder.

    Raises
    ------
    ValueError
        If ``onehot_encoder`` has no ``categories_`` attribute (i.e. it has
        not been fitted).
    """
    if not hasattr(onehot_encoder, "categories_"):
        raise ValueError(
            "Cannot extract group labels from an unfitted OneHotEncoder instance."
        )
    categories = onehot_encoder.categories_

    # Start from one output column per category of each feature. Use len()
    # rather than the category values: their dtype (str, object, float) must
    # not leak into the group labels.
    group_sizes = [len(category) for category in categories]

    # Infrequent categories (if any) are merged into a single output column.
    infrequent = getattr(onehot_encoder, "infrequent_categories_", None)
    if infrequent is not None:
        for group, infrequent_categories in enumerate(infrequent):
            if infrequent_categories is not None:
                group_sizes[group] -= len(infrequent_categories) - 1

    # A non-None drop index means one output column of that feature is removed.
    drop_idx = getattr(onehot_encoder, "drop_idx_", None)
    if drop_idx is not None:
        for group, idx in enumerate(drop_idx):
            if idx is not None:
                group_sizes[group] -= 1

    # Explicit int dtype so an encoder with no features yields an empty
    # integer array instead of failing on an empty (float) repeats list.
    return np.repeat(
        np.arange(len(group_sizes)), np.asarray(group_sizes, dtype=int)
    )