{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": false, "tags": [ "hide-cell" ] }, "source": [ "# Neural Networks\n", "\n", "Neural networks are a very powerful class of models, especially when representations must be learned from low-level information (such as pixels, audio samples or text). sensAI therefore provides many useful abstractions for this class of models, facilitating data handling, learning and evaluation.\n", "\n", "sensAI mainly provides abstractions for PyTorch, but there is also rudimentary support for TensorFlow." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2024-08-13T22:11:29.915658Z", "iopub.status.busy": "2024-08-13T22:11:29.915454Z", "iopub.status.idle": "2024-08-13T22:11:29.930959Z", "shell.execute_reply": "2024-08-13T22:11:29.930330Z" }, "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2024-08-13T22:11:29.933962Z", "iopub.status.busy": "2024-08-13T22:11:29.933575Z", "iopub.status.idle": "2024-08-13T22:11:31.040239Z", "shell.execute_reply": "2024-08-13T22:11:31.039583Z" }, "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "%%capture\n", "import sys; sys.path.extend([\"../src\", \"..\"])\n", "import sensai\n", "import pandas as pd\n", "import numpy as np\n", "from typing import *\n", "import config\n", "import warnings\n", "import functools\n", "\n", "cfg = config.get_config()\n", "warnings.filterwarnings(\"ignore\")\n", "sensai.util.logging.configure()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Image Classification\n", "\n", "As an example use case, let us classify handwritten digits in pixel images from the MNIST dataset. Images are greyscale (no colour information) and 28x28 pixels in size."
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2024-08-13T22:11:31.043884Z", "iopub.status.busy": "2024-08-13T22:11:31.043577Z", "iopub.status.idle": "2024-08-13T22:11:33.260745Z", "shell.execute_reply": "2024-08-13T22:11:33.260043Z" } }, "outputs": [], "source": [ "mnist_df = pd.read_csv(cfg.datafile_path(\"mnist_train.csv.zip\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition to the `label` column (the digit shown in the image), the data frame contains one column per pixel (784 columns, named `1x1` to `28x28`). Each pixel is represented by an 8-bit integer (0 to 255)." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2024-08-13T22:11:33.264505Z", "iopub.status.busy": "2024-08-13T22:11:33.264042Z", "iopub.status.idle": "2024-08-13T22:11:33.293802Z", "shell.execute_reply": "2024-08-13T22:11:33.293054Z" } }, "outputs": [ { "data": { "text/html": [ "
|   | label | 1x1 | 1x2 | 1x3 | 1x4 | 1x5 | 1x6 | 1x7 | 1x8 | 1x9 | ... | 28x19 | 28x20 | 28x21 | 28x22 | 28x23 | 28x24 | 28x25 | 28x26 | 28x27 | 28x28 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 5 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 4 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 9 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |

5 rows × 785 columns
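Because each image is stored as a `label` column followed by 784 pixel columns, a row can be turned back into a 28x28 array, e.g. for plotting or for feeding a convolutional model. The sketch below is not sensAI code. It assumes that the pixel columns are in row-major order (`1x1`, `1x2`, ..., `1x28`, `2x1`, ...), and it uses a small synthetic frame built by the hypothetical helper `make_synthetic_mnist_frame` instead of the real data file. `to_images` is likewise a hypothetical helper.

```python
import numpy as np
import pandas as pd


def make_synthetic_mnist_frame(n_rows: int = 2, seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: random data with the MNIST frame's column layout."""
    rng = np.random.default_rng(seed)
    # Pixel columns named "<row>x<col>" in row-major order (assumed layout)
    pixel_cols = [f"{r}x{c}" for r in range(1, 29) for c in range(1, 29)]
    pixels = rng.integers(0, 256, size=(n_rows, len(pixel_cols))).astype(np.uint8)
    df = pd.DataFrame(pixels, columns=pixel_cols)
    # Digit label (0-9) as the first column, as in the real data frame
    df.insert(0, "label", rng.integers(0, 10, size=n_rows))
    return df


def to_images(df: pd.DataFrame) -> np.ndarray:
    """Hypothetical helper: pixel columns -> float array of shape (n, 28, 28) in [0, 1]."""
    pixel_values = df.drop(columns="label").to_numpy(dtype=np.float32)
    # Row-major reshape: column "RxC" ends up at position [R-1, C-1]
    return pixel_values.reshape(-1, 28, 28) / 255.0


images = to_images(make_synthetic_mnist_frame())
print(images.shape)  # (2, 28, 28)
```

If the column layout assumption holds for the real data, `to_images(mnist_df)` would yield one 28x28 array per training example. Scaling to [0, 1] is a common convention for neural network inputs, not a requirement.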
| model_name | accuracy | balancedAccuracy |
|---|---|---|
| MLP | 0.962250 | 0.961897 |
| CNN | 0.978333 | 0.978435 |
| CNN' | 0.977167 | 0.977261 |
| RandomForest | 0.946667 | 0.945917 |
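The comparison reports both plain accuracy and balanced accuracy. Balanced accuracy averages the recall of each class. It therefore penalises a model that neglects rare classes even when plain accuracy stays high. The MNIST classes are roughly balanced, so the two values are close here. Below is a minimal NumPy sketch of both metrics on made-up labels. The functions `accuracy` and `balanced_accuracy` are illustrative helpers, not sensAI's metric classes. For the mean of per-class recall, results should match scikit-learn's `balanced_accuracy_score` with default arguments.

```python
import numpy as np


def accuracy(y_true, y_pred) -> float:
    """Fraction of predictions that match the true label."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    return float(np.mean(y_true == y_pred))


def balanced_accuracy(y_true, y_pred) -> float:
    """Mean of per-class recall over the classes present in y_true."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    recalls = [np.mean(y_pred[y_true == c] == c) for c in np.unique(y_true)]
    return float(np.mean(recalls))


# Toy, made-up labels: class 0 dominates, and one of the two class-1 samples is missed
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0]
print(accuracy(y_true, y_pred))           # 0.875
print(balanced_accuracy(y_true, y_pred))  # 0.75
```

Plain accuracy (7/8) hides the fact that half of the minority class is misclassified. Balanced accuracy ((1.0 + 0.5) / 2) exposes it.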