This course introduces the fundamentals of machine learning techniques on meshes. We start with a short introduction to machine learning on regular structures (e.g., images) and discuss their generalization to irregular mesh structures. Although several alternative representations exist (e.g., implicits, voxels, point clouds), we focus on techniques defined directly on surface triangle meshes, which have been among the most commonly used representations in graphics for decades. Specifically, we cover the building blocks of a mesh convolutional neural network, including convolution, pooling, regularization, and loss functions. We also cover how to apply convolutional networks to classic geometry processing tasks, such as geometric texture synthesis, shape classification, and subdivision. Finally, we propose a mesh MNIST dataset for research prototyping and hands-on demos of these learning tasks.
Datasets
The MNIST dataset contains tens of thousands of images of handwritten digits and is commonly used as the "Hello World" of deep learning on images. We create an analogous 3D mesh MNIST dataset for 3D deep learning (Link).
Demos
We introduce a series of self-contained examples based on open-source libraries such as JAX and PyTorch. These examples demonstrate how to implement simple machine learning models on meshes.
1. Simple mesh CNN without pooling
We present a basic example of using a mesh CNN to classify meshes of "1" and meshes of "2" from our meshMNIST dataset. We cover the full pipeline: loading a shape, computing input features, defining a mesh CNN, and training the network. In practice, we recommend using well-optimized functions, e.g., for loading training data or for network components. For clarity, however, we implement an as-simple-as-possible mesh CNN from scratch in JAX without such utility functions.
Let's get started by importing the basic libraries that we need in this tutorial.
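A minimal set of imports, assuming NumPy for mesh preprocessing and JAX for the network (the original tutorial's exact import list may differ):

```python
import numpy as np       # array-based mesh preprocessing (edges, flaps, features)
import jax               # automatic differentiation, jit compilation, random keys
import jax.numpy as jnp  # NumPy-like array API used inside the network
```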
We represent each shape as a surface triangle mesh, stored as two matrices:
- a #vertices by 3 matrix of vertex positions
- a #faces by 3 matrix F of triangle indices into the vertex list
From these, we precompute the connectivity and input features used by the mesh CNN:
- a #edges by 2 matrix of (undirected) edge indices
- a #edges by 4 matrix flap of flap edge indices into the edge list (the 4 boundary edges of the two triangles adjacent to each edge)
- a #edges by n matrix fE of input features, where n is the feature dimension
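The edge and flap matrices above can be derived from F alone. Below is a minimal NumPy sketch; the helper name edges_and_flaps is hypothetical, and it assumes a closed, edge-manifold mesh (every edge shared by exactly two triangles). It also assumes a flap ordering in which flap[:, 0:2] come from one adjacent triangle and flap[:, 2:4] from the other, each pair listed in that triangle's vertex order; the original tutorial's convention may differ.

```python
import numpy as np

def edges_and_flaps(F):
    """Hypothetical helper: undirected edges E (#edges x 2) and flap edge indices
    flap (#edges x 4) from triangle indices F (#faces x 3).
    Assumes a closed, edge-manifold mesh: every edge is shared by exactly two faces."""
    F = np.asarray(F, dtype=np.int64)
    nH = 3 * len(F)  # number of half-edges (3 per face)
    # Half-edge k = 3*f + j runs from F[f, j] to F[f, (j+1) % 3].
    half = np.stack([F, np.roll(F, -1, axis=1)], axis=2).reshape(-1, 2)
    # Sorting each vertex pair makes both half-edges of an edge share one key.
    E, he2e = np.unique(np.sort(half, axis=1), axis=0, return_inverse=True)
    he2e = he2e.reshape(-1)  # half-edge -> edge id (flattened for NumPy-version safety)

    # The other two edges of the face containing each half-edge, in face order.
    f, j = np.arange(nH) // 3, np.arange(nH) % 3
    nxt = he2e[3 * f + (j + 1) % 3]
    prv = he2e[3 * f + (j + 2) % 3]

    counts = np.bincount(he2e, minlength=len(E))
    if not np.all(counts == 2):
        raise ValueError("expected a closed, edge-manifold triangle mesh")
    # Group the two half-edges of every edge (stable sort keeps face order).
    h = np.argsort(he2e, kind="stable").reshape(len(E), 2)
    # flap = [a, b, c, d]: a, b from the first face, c, d from the second face.
    flap = np.stack([nxt[h[:, 0]], prv[h[:, 0]], nxt[h[:, 1]], prv[h[:, 1]]], axis=1)
    return E, flap

# A tetrahedron: 4 faces, 6 edges, and each edge has 4 flap edges.
F = np.array([[0, 1, 2], [0, 3, 1], [0, 2, 3], [1, 3, 2]])
E, flap = edges_and_flaps(F)
print(E.shape, flap.shape)  # (6, 2) (6, 4)
```

On a tetrahedron, each edge's flap contains every other edge except the one opposite to it, which is a quick sanity check for the connectivity code.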
compute_input computes the input features for each edge. Which features to use is a design choice; here, for simplicity, we compute the dihedral angle and the 4 edge-length ratios, giving n = 5. Since this preprocessing usually takes a while, we often process the entire dataset once and save the results.

After processing the meshes, we can define the mesh CNN. Like many mesh CNNs (e.g., MeshCNN), our architecture consists of a stack of convolution layers, followed by global pooling and fully connected layers. For simplicity, we omit mesh pooling in this toy example.
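The compute_input step described above can be sketched in NumPy as follows. This is a hypothetical implementation: the exact definition of the 4 edge-length ratios is not given here, so we assume each ratio is a flap edge's length divided by the center edge's length. We also assume flap[:, 0:2] lie in one adjacent triangle and flap[:, 2:4] in the other, and compute the unsigned interior dihedral angle (π for a flat configuration), which does not depend on face orientation.

```python
import numpy as np

def shared_vertex(Ea, Eb):
    """Row-wise vertex shared by edges Ea[i] and Eb[i] (each row is a vertex pair)."""
    match = (Ea[:, :, None] == Eb[:, None, :]).any(axis=2)  # (m, 2): is Ea[i, k] in Eb[i]?
    return Ea[np.arange(len(Ea)), match.argmax(axis=1)]

def compute_input(V, E, flap):
    """Hypothetical sketch: per-edge features [dihedral angle, 4 edge-length ratios].
    Assumes flap[:, 0:2] lie in one adjacent triangle and flap[:, 2:4] in the other."""
    V, E, flap = np.asarray(V, dtype=float), np.asarray(E), np.asarray(flap)
    u, v = V[E[:, 0]], V[E[:, 1]]
    w = V[shared_vertex(E[flap[:, 0]], E[flap[:, 1]])]  # apex of the first triangle
    x = V[shared_vertex(E[flap[:, 2]], E[flap[:, 3]])]  # apex of the second triangle
    t = (v - u) / np.linalg.norm(v - u, axis=1, keepdims=True)  # unit edge direction

    def perp(p):
        # Unit component of (p - u) orthogonal to the edge, pointing toward apex p.
        d = p - u
        d = d - np.sum(d * t, axis=1, keepdims=True) * t
        return d / np.linalg.norm(d, axis=1, keepdims=True)

    cos = np.clip(np.sum(perp(w) * perp(x), axis=1), -1.0, 1.0)
    dihedral = np.arccos(cos)  # interior dihedral angle in [0, pi]; pi means flat

    length = np.linalg.norm(v - u, axis=1)  # length of every edge
    ratios = length[flap] / length[:, None]  # (#edges, 4): flap-edge / center-edge length
    return np.concatenate([dihedral[:, None], ratios], axis=1)  # fE: (#edges, 5)

# Regular tetrahedron with precomputed edges and flaps (flap pairs grouped per face).
V = np.array([[1, 1, 1], [1, -1, -1], [-1, 1, -1], [-1, -1, 1]], dtype=float)
E = np.array([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]])
flap = np.array([[3, 1, 2, 4], [0, 3, 5, 2], [4, 0, 1, 5], [1, 0, 4, 5], [0, 2, 5, 3], [2, 1, 3, 4]])
fE = compute_input(V, E, flap)  # every dihedral angle is arccos(1/3); every ratio is 1
```

To process a dataset once and reuse it, the arrays for each shape can be stored with np.savez and read back with np.load; the file layout is up to the reader.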
In each mesh convolution, the input is an n-dimensional function defined on each edge, fE in our case. For other applications, this could be a function defined on each face or vertex. No matter where the input function is stored, the first step in mesh convolution is to re-index it into a format suited to fast convolution. In this example, we re-index the #edges by n edge functions fE into flap functions fP following Equation 2 of MeshCNN. Each flap function is an n by 5 matrix, so fP is a #edges by n by 5 3D array. After that, mesh convolution can be performed easily as a 2D convolution, where each convolution filter has size n by 5 and outputs a scalar value. We implement it from scratch for clarity, but one should switch to a better-optimized implementation for performance.

With the structure of mesh convolution in mind, we can now initialize the network parameters and define the forward pass (convolutions, global pooling, and a fully connected multilayer perceptron).

The rest of the pipeline is the same as in, e.g., image-based classification networks: we define the loss function, the optimizer, and the parameter update function, and then start training. The complete version of this tutorial can be found here.

This training code takes one shape at a time and uses stochastic gradient descent to optimize the network. One could also accumulate the gradients from a batch of shapes before taking a gradient step. Ideally, one should also use a data loader so that loading data does not become the training bottleneck, and use a validation set for early stopping to avoid overfitting.
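The re-indexing from edge functions fE to flap functions fP, and the convolution on top of it, can be sketched in NumPy as below. As we understand MeshCNN's Equation 2, each edge's flap function stacks the edge's own feature with the symmetric combinations |a − c|, a + c, |b − d|, b + d of its four flap edges; this sketch assumes a and b come from one adjacent triangle and c and d from the other. The function names edge_to_flap and mesh_conv are hypothetical. Because an n by 5 filter applied to an n by 5 input (no padding) yields a single scalar, each output channel reduces to a sum of elementwise products, written here with np.einsum.

```python
import numpy as np

def edge_to_flap(fE, flap):
    """Re-index per-edge features fE (#edges, n) into flap features fP (#edges, n, 5)
    using (e, |a - c|, a + c, |b - d|, b + d), where a, b, c, d = fE[flap[i]]."""
    a, b, c, d = (fE[flap[:, k]] for k in range(4))  # each (#edges, n)
    return np.stack([fE, np.abs(a - c), a + c, np.abs(b - d), b + d], axis=2)

def mesh_conv(fE, flap, W, bias):
    """One mesh convolution. W: (n_out, n, 5) filters, bias: (n_out,).
    Each output value is the sum of an n-by-5 filter times an edge's n-by-5 flap matrix."""
    fP = edge_to_flap(fE, flap)
    return np.einsum("enk,onk->eo", fP, W) + bias  # (#edges, n_out)

# Toy input: tetrahedron flaps and random 5-dimensional edge features.
rng = np.random.default_rng(0)
flap = np.array([[3, 1, 2, 4], [0, 3, 5, 2], [4, 0, 1, 5], [1, 0, 4, 5], [0, 2, 5, 3], [2, 1, 3, 4]])
fE = rng.normal(size=(6, 5))   # n = 5 input features per edge
W = rng.normal(size=(8, 5, 5))  # 8 filters, each of size n x 5
fP = edge_to_flap(fE, flap)     # (6, 5, 5)
out = mesh_conv(fE, flap, W, np.zeros(8))  # (6, 8)
```

The symmetric functions make fP independent of which adjacent triangle is listed first: swapping the two pairs in flap (a with c, b with d) leaves the result unchanged.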
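The network definition and one-shape-at-a-time SGD training described above can be sketched end to end in JAX. Everything here is an illustrative assumption rather than the tutorial's own code: the names init_params, forward, loss_fn, sgd_step, and train, the layer widths, the He-style initialization, global average pooling, the cross-entropy loss, and the fixed learning rate. The flap re-indexing is repeated so the block runs on its own.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-2  # assumed step size for plain SGD

def edge_to_flap(fE, flap):
    """MeshCNN-style re-indexing: (#edges, n) -> (#edges, n, 5)."""
    a, b, c, d = (fE[flap[:, k]] for k in range(4))
    return jnp.stack([fE, jnp.abs(a - c), a + c, jnp.abs(b - d), b + d], axis=2)

def init_params(key, n_in, conv_widths=(16, 32), hidden=32, n_classes=2):
    """Hypothetical initializer: conv filters (n_out, n_in, 5) and a 2-layer MLP."""
    params = {"conv": [], "mlp": []}
    for n_out in conv_widths:
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (n_out, n_in, 5)) * jnp.sqrt(2.0 / (5 * n_in))
        params["conv"].append((W, jnp.zeros(n_out)))
        n_in = n_out
    for n_out in (hidden, n_classes):
        key, sub = jax.random.split(key)
        W = jax.random.normal(sub, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        params["mlp"].append((W, jnp.zeros(n_out)))
        n_in = n_out
    return params

def forward(params, fE, flap):
    """Mesh convolutions -> global average pooling over edges -> MLP logits."""
    x = fE
    for W, b in params["conv"]:
        x = jax.nn.relu(jnp.einsum("enk,onk->eo", edge_to_flap(x, flap), W) + b)
    x = jnp.mean(x, axis=0)  # global pooling: one feature vector per mesh
    (W1, b1), (W2, b2) = params["mlp"]
    return jax.nn.relu(x @ W1 + b1) @ W2 + b2

def loss_fn(params, fE, flap, label):
    """Cross-entropy of one shape with an integer class label (e.g. 0 for "1", 1 for "2")."""
    return -jax.nn.log_softmax(forward(params, fE, flap))[label]

@jax.jit
def sgd_step(params, fE, flap, label):
    """One stochastic gradient descent step on a single shape."""
    loss, grads = jax.value_and_grad(loss_fn)(params, fE, flap, label)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss

def train(params, dataset, epochs=10):
    """dataset: hypothetical list of (fE, flap, label) tuples, one per mesh."""
    for _ in range(epochs):
        for fE, flap, label in dataset:
            params, _ = sgd_step(params, fE, flap, label)
    return params
```

A call such as train(init_params(jax.random.PRNGKey(0), n_in=5), dataset) runs the loop. Since jax.jit compiles one version per input shape, meshes with different numbers of edges trigger recompilation; padding meshes to a common size is one way around this. Batching, as mentioned above, could average the gradients from several shapes before applying the tree_map update.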