{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nGroupLasso for logistic regression\n==================================\n\nA sample script for group lasso regularised logistic regression.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Setup\n-----\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\nimport numpy as np\n\nfrom group_lasso import LogisticGroupLasso\n\nnp.random.seed(0)\nLogisticGroupLasso.LOG_LOSSES = True"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Set dataset parameters\n----------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# 50 groups, each with a size drawn uniformly from 10..19 (upper bound exclusive).\n",
        "group_sizes = [np.random.randint(10, 20) for _ in range(50)]\n",
        "# One 0/1 flag per group, drawn after all sizes so the random stream order is fixed.\n",
        "active_groups = [np.random.randint(2) for _ in group_sizes]\n",
        "# Group index of every coefficient: index i repeated group_sizes[i] times.\n",
        "groups = np.repeat(np.arange(len(group_sizes)), group_sizes)\n",
        "num_coeffs = len(groups)\n",
        "num_datapoints = 10_000\n",
        "noise_std = 1"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate data matrix\n--------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Standard-normal design matrix: one row per datapoint, one column per coefficient.\n",
        "X = np.random.standard_normal(size=(num_datapoints, num_coeffs))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate coefficients\n---------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Draw standard-normal coefficients group by group; multiplying by the group's\n",
        "# 0/1 flag sets every coefficient of an inactive group to zero.\n",
        "coefficients_per_group = [\n",
        "    is_active * np.random.standard_normal(size)\n",
        "    for size, is_active in zip(group_sizes, active_groups)\n",
        "]\n",
        "# Stack the groups into a single column vector.\n",
        "w = np.concatenate(coefficients_per_group)[:, np.newaxis]\n",
        "# True wherever the generating coefficient is non-zero.\n",
        "true_coefficient_mask = np.not_equal(w, 0)\n",
        "intercept = 2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate classification targets\n-------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def sigmoid(z):\n",
        "    \"\"\"Map real-valued scores to probabilities with the logistic function.\"\"\"\n",
        "    return 1 / (1 + np.exp(-z))\n",
        "\n",
        "\n",
        "# Noise-free linear predictor, and a copy perturbed by Gaussian noise.\n",
        "y_true = intercept + X @ w\n",
        "y = y_true + noise_std * np.random.randn(*y_true.shape)\n",
        "\n",
        "# Probabilities for the noise-free and the noisy predictor.\n",
        "p_true = sigmoid(y_true)\n",
        "p = sigmoid(y)\n",
        "\n",
        "# Binary labels are Bernoulli draws from the noise-free probabilities.\n",
        "c = np.random.binomial(1, p_true)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "View noisy data and compute maximum accuracy\n--------------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Scatter the noisy probabilities against the noise-free ones.\n",
        "fig, ax = plt.subplots()\n",
        "ax.plot(p, p_true, \".\")\n",
        "ax.set_xlabel(\"Noisy probabilities\")\n",
        "ax.set_ylabel(\"Noise-free probabilities\")\n",
        "\n",
        "# Fraction of sampled labels that agree with thresholding the noise-free\n",
        "# probabilities at 0.5; it is printed later as the best possible accuracy.\n",
        "best_accuracy = np.mean((p_true > 0.5) == c)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Generate estimator and train it\n-------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Collect the estimator settings in one place before constructing it.\n",
        "estimator_settings = dict(\n",
        "    groups=groups,\n",
        "    group_reg=0.05,\n",
        "    l1_reg=0,\n",
        "    scale_reg=\"inverse_group_size\",\n",
        "    subsampling_scheme=1,\n",
        "    supress_warning=True,  # keyword spelling kept exactly as the call requires\n",
        ")\n",
        "gl = LogisticGroupLasso(**estimator_settings)\n",
        "\n",
        "# Kept as the final expression so the cell still displays fit's return value.\n",
        "gl.fit(X, c)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Extract results and compute performance metrics\n-----------------------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Extract info from estimator\n",
        "pred_c = gl.predict(X)\n",
        "sparsity_mask = gl.sparsity_mask_\n",
        "w_hat = gl.coef_\n",
        "\n",
        "# Compute performance metrics.\n",
        "# c is an (n, 1) column because it is sampled from probabilities derived from\n",
        "# X @ w, with w reshaped to a column. Flatten both sides so the comparison is\n",
        "# element-wise whatever shape predict returns: comparing a flat (n,) vector\n",
        "# with an (n, 1) column would broadcast to an n-by-n matrix, and its mean\n",
        "# would not be an accuracy.\n",
        "accuracy = (np.ravel(pred_c) == np.ravel(c)).mean()\n",
        "\n",
        "# Print results\n",
        "print(f\"Number variables: {len(sparsity_mask)}\")\n",
        "print(f\"Number of chosen variables: {sparsity_mask.sum()}\")\n",
        "print(f\"Accuracy: {accuracy}, best possible accuracy = {best_accuracy}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Visualise regression coefficients\n---------------------------------\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Collapse the two columns of coef_ into one weight per variable.\n",
        "# NOTE(review): presumably one column per class, making this the class-1 vs\n",
        "# class-0 weight difference; confirm against the estimator documentation.\n",
        "coef = gl.coef_[:, 1] - gl.coef_[:, 0]\n",
        "\n",
        "# Figure 1: true and estimated weights, each scaled to unit norm so that only\n",
        "# their direction is compared.\n",
        "plt.figure()\n",
        "plt.plot(w / np.linalg.norm(w), \".\", label=\"True weights\")\n",
        "plt.plot(\n",
        "    coef / np.linalg.norm(coef), \".\", label=\"Estimated weights\",\n",
        ")\n",
        "# The label= arguments above are only displayed once a legend is drawn.\n",
        "plt.legend()\n",
        "\n",
        "# Figure 2: learned against true coefficients; the gray line joins the\n",
        "# (min, min) and (max, max) corners as a visual guide.\n",
        "plt.figure()\n",
        "plt.plot([w.min(), w.max()], [coef.min(), coef.max()], \"gray\")\n",
        "plt.scatter(w, coef, s=10)\n",
        "plt.ylabel(\"Learned coefficients\")\n",
        "plt.xlabel(\"True coefficients\")\n",
        "\n",
        "# Figure 3: the losses stored on the estimator (LOG_LOSSES was enabled in setup).\n",
        "plt.figure()\n",
        "plt.plot(gl.losses_)\n",
        "plt.ylabel(\"Loss\")\n",
        "plt.title(\"Logged losses\")\n",
        "\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}