GNN

Representation

  • Node embedding: Each node embedded as a vector, and the entire graph represented as adjacency matrix and feature matrix (attribute vector)
  • Graph embedding: Entire graph represented as a vector.

Properties

  • Permutation invariance (Graph embedding), $f(PAP^\top, PX) = f(A, X)$: Permuting the rows and columns of the adjacency matrix and the rows of the feature matrix has no effect on the output. When we want to predict something about the entire graph, for example classifying a molecule, we want the model to be permutation invariant.
  • Permutation equivariance (Node embedding), $f(PAP^\top, PX) = P\,f(A, X)$: Permuting the adjacency matrix and the feature matrix is equivalent to first applying the function and then permuting its output. In other words, permuting the nodes means the output of f is permuted in a consistent way, analogous to how translating the input of a CNN results in an equivalent translation of its output.
  • We need to satisfy one of the two: invariance for graph-level outputs, or equivariance for node-level outputs.
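
Both properties can be checked numerically. The sketch below is a minimal example, assuming NumPy; `node_fn` and `graph_fn` are hypothetical stand-ins for f (sum over neighbours, a shared linear map, then a row-sum for the graph-level output), not functions from any library.

```python
import numpy as np

rng = np.random.default_rng(0)

# Small undirected graph: 4 nodes, 3 features per node (illustrative values).
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))
W = rng.normal(size=(3, 2))  # fixed weights shared by all nodes


def node_fn(A, X):
    """Node-level f: sum neighbour features, then apply a shared linear map."""
    return np.tanh(A @ X @ W)


def graph_fn(A, X):
    """Graph-level f: sum the node outputs into a single vector."""
    return node_fn(A, X).sum(axis=0)


# Random permutation matrix P (rows of the identity, shuffled).
P = np.eye(4)[rng.permutation(4)]
A_perm, X_perm = P @ A @ P.T, P @ X

equivariant = np.allclose(node_fn(A_perm, X_perm), P @ node_fn(A, X))
invariant = np.allclose(graph_fn(A_perm, X_perm), graph_fn(A, X))
print(equivariant, invariant)
```

Equivariance holds here because the non-linearity is elementwise and the weights are shared across nodes; a per-node weight matrix would break it.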

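Why do we not feed a graph to an MLP? Why do we not use a CNN? What are the similarities?

  • Similarity (with CNNs): locality, weight sharing, arbitrary input size.
  • Difference: abstract shape. A graph has no regular grid and no canonical ordering of nodes, so flattening it for an MLP or sliding a fixed kernel over it does not work directly.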

Message Passing

Take a graph $G = (V, E)$, along with a set of node features $X \in \mathbb{R}^{|V| \times d}$, and generate node embeddings $z_u, \forall u \in V$.

What can node features be?

Depends on the problem we’re solving. For molecular graphs, they can be information about each atom in the molecule; for social graphs, information about each individual member. When there are no individual node features, the input can still be statistics of the node in the graph (for example its degree), and can even contain some information about the graph itself. To break permutation equivariance, we can assign a positional encoding to each node, say as a one-hot vector.
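
As a rough illustration of the featureless case, the sketch below builds an input matrix from node degree plus one-hot node identifiers. It assumes NumPy and a small made-up adjacency matrix; the choice of degree as the statistic is only one option.

```python
import numpy as np

# Illustrative undirected graph on 4 nodes.
A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)

# Structural statistic: node degree as a single input feature per node.
degree = A.sum(axis=1, keepdims=True)             # shape (4, 1)

# Positional encoding: one-hot node identifiers (ties features to node order).
one_hot_id = np.eye(A.shape[0])                   # shape (4, 4)

X = np.concatenate([degree, one_hot_id], axis=1)  # shape (4, 5)
print(X.shape)
```

The one-hot part grows with the number of nodes and does not transfer to unseen graphs, which is the price of breaking equivariance this way.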

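Aggregate: in each round k, each node aggregates the messages (feature descriptions) from its neighbours, $m_{N(u)}^{(k)} = \text{AGGREGATE}^{(k)}\left(\{h_v^{(k)}, \forall v \in N(u)\}\right)$, and uses the result to update its own embedding.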

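  • Initial embedding is set to be the features of the node: $h_u^{(0)} = x_u, \forall u \in V$.
  • The aggregate function is a differentiable function on multisets. Because its input is an unordered multiset of neighbour embeddings, the GNN is permutation equivariant by design.
  • Common choices: sum, mean, max/min.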
  • $\text{MLP}_\theta\left(\sum_{v \in N(u)} \text{MLP}_\phi(h_v)\right)$: Universal approximation of multiset functions.
  • Receptive field of each node increases with each iteration, as information from further away in the graph is aggregated: after k rounds, a node's embedding depends on its k-hop neighbourhood.
  • Messages from the neighbours can encode structural information like degree of the neighbour node, useful in problems like analysing molecular graphs. Or can also encode feature-based information from local neighbourhood of the graph analogous to how CNNs aggregate feature information from spatially-defined patches.

Update: $h_u^{(k+1)} = \text{UPDATE}^{(k)}\left(h_u^{(k)}, m_{N(u)}^{(k)}\right)$.

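  • $h_u^{(k+1)} = \sigma\left(W_{self}^{(k)} h_u^{(k)} + W_{neigh}^{(k)} \sum_{v \in N(u)} h_v^{(k)} + b^{(k)}\right)$, where $W_{self}^{(k)}, W_{neigh}^{(k)}$ are trainable weight matrices and $\sigma$ is an elementwise non-linearity.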

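Readout: $z_u = h_u^{(K)}, \forall u \in V$: Outputs the final result after the K-th (final) iteration. For a graph-level output, the node embeddings are pooled into a single vector, like pooling in CNNs.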
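
The aggregate, update and readout steps can be sketched node by node as follows. This is a minimal sketch, assuming NumPy, sum aggregation, a tanh non-linearity, random (untrained) weights and a sum-pooling readout; `aggregate`, `update`, `params` and the toy graph are illustrative names, not from the notes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Adjacency list of a small undirected graph (illustrative).
neighbours = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
X = rng.normal(size=(4, 3))          # node features x_u
K, d = 2, 3                          # number of rounds, embedding size

# One (W_self, W_neigh, b) triple per round; random here, learned in practice.
params = [(rng.normal(size=(d, d)), rng.normal(size=(d, d)), np.zeros(d))
          for _ in range(K)]


def aggregate(h, u):
    """Sum the embeddings of u's neighbours (order does not matter)."""
    return sum(h[v] for v in neighbours[u])


def update(h_u, m_u, W_self, W_neigh, b):
    """Combine a node's own embedding with its aggregated message."""
    return np.tanh(W_self @ h_u + W_neigh @ m_u + b)


h = {u: X[u] for u in neighbours}    # h_u^(0) = x_u
for W_self, W_neigh, b in params:
    # Build the new dict from the old one so every node sees round-k values.
    h = {u: update(h[u], aggregate(h, u), W_self, W_neigh, b)
         for u in neighbours}

z = h                                # node embeddings z_u = h_u^(K)
z_graph = sum(z.values())            # graph-level readout by sum pooling
print(z_graph.shape)
```

The new embeddings are computed from a snapshot of the previous round, so the result does not depend on the order in which nodes are visited.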

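We can also define a graph-level equation for aggregate and update, $H^{(k+1)} = \sigma\left(A H^{(k)} W_{neigh}^{(k)} + H^{(k)} W_{self}^{(k)}\right)$, where row u of $H^{(k)}$ is $h_u^{(k)}$, and we can even batch aggregate and update in one equation using self-loops: $H^{(k+1)} = \sigma\left((A + I) H^{(k)} W^{(k)}\right)$.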
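
A minimal NumPy sketch of the graph-level form and the self-loop form, assuming a row-per-node embedding matrix, a tanh non-linearity and random weights (all illustrative choices):

```python
import numpy as np

rng = np.random.default_rng(0)

A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.normal(size=(4, 3))          # row u holds the embedding of node u
W_self = rng.normal(size=(3, 3))
W_neigh = rng.normal(size=(3, 3))

# Graph-level form: neighbours and self use separate weight matrices.
H_next = np.tanh(A @ H @ W_neigh + H @ W_self)

# Self-loop form: add I to A so one shared W handles both terms.
W = rng.normal(size=(3, 3))
H_next_loops = np.tanh((A + np.eye(4)) @ H @ W)

# When W_self and W_neigh are the same matrix W, the two forms coincide.
same = np.allclose(np.tanh(A @ H @ W + H @ W), H_next_loops)
print(same)
```

The self-loop form has fewer parameters, at the cost of no longer distinguishing a node's own embedding from its neighbours' messages.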

This is similar to an MLP in which each node carries a vector instead of a scalar: aggregate and update together form the linear layer and the pointwise non-linearity. So, in this sense, a GNN can learn much more than a simple MLP.

In every iteration the embedding of a single node gets updated using its neighbours, but won't that d-dimensional vector saturate after many iterations? Will a single node embedding still contain any meaningful information about its neighbourhood?

Over-smoothing: The representations of all the nodes in the graph can become very similar to one another. This makes it hard to build deeper GNN models.

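Formally, define the influence of a node’s input feature $h_u^{(0)}$ on the final-layer embedding of all other nodes in the graph. For any pair of nodes u, v in the graph, the influence of u on v is quantified using the Jacobian: $I_K(u, v) = \mathbf{1}^\top \left( \frac{\partial h_v^{(K)}}{\partial h_u^{(0)}} \right) \mathbf{1}$.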

Building deeper models can hurt performance of GNN models as with each added layer, information loss about local neighbourhood increases and learned embeddings are over-smoothed.
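
Over-smoothing is easy to see in a stripped-down setting. The sketch below is only an illustration, assuming NumPy, mean aggregation with self-loops, and no weights or non-linearity; `spread` is a hypothetical helper measuring how far apart the node embeddings still are.

```python
import numpy as np

rng = np.random.default_rng(0)

A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
A_hat = A + np.eye(4)                               # add self-loops
# Row-normalised propagation: each node averages itself and its neighbours.
P_mean = A_hat / A_hat.sum(axis=1, keepdims=True)

H = rng.normal(size=(4, 3))


def spread(H):
    """Largest per-feature gap between any two node embeddings."""
    return float((H.max(axis=0) - H.min(axis=0)).max())


spreads = [spread(H)]
for _ in range(20):
    H = P_mean @ H                                  # one round of smoothing
    spreads.append(spread(H))

print(spreads[0], spreads[5], spreads[20])
```

Each round replaces an embedding by an average, so the spread can only shrink; on this connected graph it shrinks towards zero and all nodes end up with nearly the same vector.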

More on the influence of self-update and deeper models can be found in the GRL book and Xu et al.

Generalisations

Aggregate

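  • Normalisation: Mean normalisation $m_{N(u)} = \frac{\sum_{v \in N(u)} h_v}{|N(u)|}$ or symmetric normalisation $m_{N(u)} = \sum_{v \in N(u)} \frac{h_v}{\sqrt{|N(u)|\,|N(v)|}}$.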
    • Why does symmetric normalisation work better than mean?

    • Normalisation leads to a loss of structural information, and is provably less powerful than sum aggregation. Normalisation is more useful when feature information matters more than structural information.
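  • Set pooling: Use a universal set function approximator that can approximate any permutation-invariant aggregator function: $m_{N(u)} = \text{MLP}_\theta\left(\sum_{v \in N(u)} \text{MLP}_\phi(h_v)\right)$.
  • Janossy Pooling: A permutation-sensitive function averaged over permutations of the neighbours.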
  • Attention: $m_{N(u)} = \sum_{v \in N(u)} \alpha_{u,v} h_v$: weighted aggregation. Adds an inductive bias to the model by weighting each neighbour by its importance $\alpha_{u,v}$.
  • Multiple attention heads: Compute K distinct attention weights using independently parametrised attention layers. Aggregate all messages by projection and concatenation.
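
A few of these aggregators side by side, as a minimal NumPy sketch. The dot-product attention score is an assumption made for brevity (other scoring functions are common), and the weights are untrained; the last lines show why mean normalisation loses degree information.

```python
import numpy as np

rng = np.random.default_rng(0)

A = np.array([[0, 1, 1, 0],
              [1, 0, 1, 0],
              [1, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
H = rng.normal(size=(4, 3))              # row u is the embedding of node u
deg = A.sum(axis=1)                      # |N(u)| for each node

m_sum = A @ H                                     # sum aggregation
m_mean = (A / deg[:, None]) @ H                   # mean normalisation
m_sym = (A / np.sqrt(np.outer(deg, deg))) @ H     # symmetric normalisation

# Attention: softmax over each node's neighbours of a score.
# Assumed score here: the dot product of the two embeddings.
scores = H @ H.T
scores = np.where(A > 0, scores, -np.inf)         # mask non-neighbours
scores = scores - scores.max(axis=1, keepdims=True)
alpha = np.exp(scores)
alpha = alpha / alpha.sum(axis=1, keepdims=True)  # rows sum to 1
m_att = alpha @ H

# With identical features on every node, sum still sees the degree,
# while mean returns the same value everywhere.
ones = np.ones((4, 1))
sum_on_ones = (A @ ones).ravel()
mean_on_ones = ((A / deg[:, None]) @ ones).ravel()
print(sum_on_ones, mean_on_ones)
```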

Update

  • Skip connections: Counter over-smoothing by directly preserving information from previous rounds of message passing.
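    • $\text{UPDATE}_{interpolate}(h_u, m_{N(u)}) = \alpha_1 \circ \text{UPDATE}_{base}(h_u, m_{N(u)}) + \alpha_2 \circ h_u$, where $\alpha_1, \alpha_2 \in [0, 1]^d$ and $\alpha_2 = 1 - \alpha_1$, and $\alpha_1$ can be learned jointly with other representations.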
    • Because of the properties GNNs share with CNNs, concatenation and skip connections as described in He et al. produce similar results.
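
A minimal sketch of two skip-connection updates, assuming NumPy, a tanh base update with random weights, and a gate `alpha_1` passed in by hand (in a real model it would be learned). The function names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3
h_u = rng.normal(size=d)                 # previous embedding of node u
m_u = rng.normal(size=d)                 # aggregated neighbour message
W_self = rng.normal(size=(d, d))
W_neigh = rng.normal(size=(d, d))


def update_base(h_u, m_u):
    """Plain update without any skip connection."""
    return np.tanh(W_self @ h_u + W_neigh @ m_u)


def update_interpolate(h_u, m_u, alpha_1):
    """Elementwise mix of the new update and the previous embedding."""
    alpha_2 = 1.0 - alpha_1
    return alpha_1 * update_base(h_u, m_u) + alpha_2 * h_u


def update_concat(h_u, m_u):
    """Concatenation skip connection: keep h_u alongside the base update."""
    return np.concatenate([update_base(h_u, m_u), h_u])


gate = np.array([0.0, 0.5, 1.0])         # per-dimension gate in [0, 1]
print(update_interpolate(h_u, m_u, gate))
print(update_concat(h_u, m_u).shape)
```

A gate of 0 in a dimension keeps the old value untouched and a gate of 1 takes the base update, so the previous round's information can survive many layers.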

Features and Relationships

  • Edge attributes:
  • Multi-relational: aggregation can depend on the relationship between nodes.

Generalised Message Passing

The main improvement over baseline message passing is that during each iteration, the model generates a hidden edge embedding for every edge in the graph, and an overall graph embedding corresponding to the entire graph. This helps differentiate between edge-level and node-level features and graph-level features. We can also define different loss functions for different types of embeddings and tasks.

Approximation Theory

Graph Isomorphisms: Given two graphs $G_1, G_2$, decide whether the two graphs are isomorphic. Formally, we say two graphs with adjacency matrices $A_1, A_2$ and feature matrices $X_1, X_2$ are isomorphic if and only if there exists a permutation matrix P such that $P A_1 P^\top = A_2$ and $P X_1 = X_2$. Or informally, when they have the same structure but differ in the ordering of nodes in their adjacency matrices.

  • Weisfeiler-Lehman isomorphism test
  • Distinguishing capacity of GNNs
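
The 1-dimensional Weisfeiler-Lehman test is colour refinement: each node repeatedly replaces its colour with a signature of its own colour and the multiset of its neighbours' colours, which mirrors one round of message passing. Below is a minimal pure-Python sketch, assuming no node features (all nodes start with the same colour); `wl_test` and the example graphs are illustrative.

```python
from collections import Counter


def wl_test(adj_1, adj_2, rounds=None):
    """1-WL colour refinement on two graphs.

    Returns False if the graphs are certainly non-isomorphic, True if the
    test cannot tell them apart (they may or may not be isomorphic).
    Each adj_* maps a node to the list of its neighbours.
    """
    if len(adj_1) != len(adj_2):
        return False
    graphs = (adj_1, adj_2)
    # No node features assumed: every node starts with the same colour.
    colours = [{u: 0 for u in adj} for adj in graphs]
    rounds = len(adj_1) if rounds is None else rounds
    for _ in range(rounds):
        palette = {}  # shared across both graphs so colours stay comparable
        new_colours = []
        for adj, col in zip(graphs, colours):
            new_col = {}
            for u in adj:
                # Signature: own colour plus the multiset of neighbour colours.
                signature = (col[u], tuple(sorted(col[v] for v in adj[u])))
                new_col[u] = palette.setdefault(signature, len(palette))
            new_colours.append(new_col)
        colours = new_colours
        hist_1, hist_2 = (Counter(c.values()) for c in colours)
        if hist_1 != hist_2:
            return False
    return True


# A 6-cycle and two disjoint triangles: both 2-regular, so 1-WL cannot
# separate them even though they are not isomorphic.
cycle_6 = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
two_triangles = {0: [1, 2], 1: [0, 2], 2: [0, 1],
                 3: [4, 5], 4: [3, 5], 5: [3, 4]}
# A path on 6 nodes has degree-1 endpoints, so 1-WL separates it at once.
path_6 = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3, 5], 5: [4]}

print(wl_test(cycle_6, two_triangles))  # True: indistinguishable by 1-WL
print(wl_test(cycle_6, path_6))         # False: certainly non-isomorphic
```

The cycle versus triangles case is the standard example of the limit on distinguishing capacity: a message-passing GNN with identical input features on every node produces the same embeddings for both graphs.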

Problems I’m Seeing:

  • A d-dimensional embedding vector exists for each node in the graph, so the size scales with $|V| \cdot d$. The number of edges also makes the update function expensive to compute, but the computation is actually embarrassingly parallel.
  • How to choose depth K?
  • How do you update the graph structure as you process more? Can we prune or connect more edges and nodes?

More Readings: