For my deep learning class I wanted to investigate the ability of transformers to generalize. I thought shortest paths would be an easy synthetic task to study generalization in. Boy was I wrong. I guess the transformers thought that this task was really hard. Probably this just means that I'm bad at deep learning. But Kevin did tell us all the best hyperparameters and it still didn't really work. In protest of this injustice, I'm writing down how the transformers obviously should've been computing shortest paths. Then, just to show the transformers who's boss, I'm going to implement this, and then I'm going to demonstrate that the loss basin around this solution is nice and inviting. More precisely, I'm hoping to show that if you take the weights from my hard-coded answer and perturb them a bit, gradient descent can still pretty easily find some weights that do really well. Singular learning theory folks are supposed to have some fancy way of actually measuring the size of this basin thing, called the learning coefficient. Maybe I'll look into that some more and see if we can implement that. I'm sure that whatever they do is more principled than my "let's perturb the weights and see if the loss changes a lot" approach.

Oh so I guess I have two proposed approaches:

  • change weights a bit, see how much loss changes by
  • change weights some more, see how far we can go while the model can still recover the low loss

Q: how do you get a sense of scale? (what counts as a large weight change?) um, idk, compare to the learning rate or something I guess. Yeah, I'm not sure about this one yet.
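One possible answer to the scale question, sketched below, is to measure a perturbation relative to the norm of the weights themselves, so that "1%" means the same thing whether a weight is 1 or 1000. This is only an illustration: `perturb`, `toy_loss`, and `TARGET_WEIGHTS` are hypothetical names, and the quadratic loss is a stand-in for the real model's loss, not something measured on the transformer.

```python
import math
import random

# Illustrative magnitudes only, loosely echoing the hand-coded constants.
TARGET_WEIGHTS = [1000.0, 200.0, 1.0, -1000.0]

def perturb(weights, rel_scale, rng):
    """Add Gaussian noise whose norm is rel_scale times the norm of weights."""
    noise = [rng.gauss(0.0, 1.0) for _ in weights]
    noise_norm = math.sqrt(sum(z * z for z in noise))
    weight_norm = math.sqrt(sum(w * w for w in weights))
    factor = rel_scale * weight_norm / noise_norm
    return [w + factor * z for w, z in zip(weights, noise)]

def toy_loss(weights):
    """Hypothetical stand-in loss: zero exactly at TARGET_WEIGHTS."""
    return sum((w - t) ** 2 for w, t in zip(weights, TARGET_WEIGHTS))

rng = random.Random(0)
for rel_scale in (0.001, 0.01, 0.1):
    perturbed = perturb(TARGET_WEIGHTS, rel_scale, rng)
    print(rel_scale, toy_loss(perturbed))
```

For the first approach you would swap `toy_loss` for the model's loss on a batch. For the second, you would fine-tune from `perturbed` and record the largest `rel_scale` from which training still gets back to low loss.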

Anyways, here are the formulas.

Let $n$ be the number of vertices in the graph. Let $D$ be a bound on the diameter. The sequence length will be $n+1$: one token for each vertex, and then one special token to output the answer. We won't need positional encodings.

Let $x_i^{(\ell)}$ denote the layer-$\ell$ activations for token $i$'s "deep embedding". $x_i$ will live in $\mathbb{R}^{1+2n+D}$. The first coordinate of $x_i$ will be an indicator for whether $i$ is the answer token (1 = answer token, 0 otherwise). The next $n$ coordinates of $x_i$ will be indicators for whether $i$ is a neighbor of each vertex (listed in order). The next $n$ coordinates of $x_i$ will be indicators (OK, they won't quite be 0/1 valued, but will be close to it) for whether $i$ can reach each of the other vertices. The last $D$ coordinates will be zero. For the answer token, we put all ones in the NBR / REACH slots, and the final $D$ coordinates start out as zero; they will store whether vertex 1 can reach vertex 2 in that number of steps.

We will have layers and heads per layer. The Key/Query/Value matrices will be computed as follows: or based on whether can reach . or based on whether is a neighbor of . .

Now compute the scores, and softmax it.
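To see why large scores matter here, below is a minimal pure-Python sketch of softmax over one row of scores. The values 1000 and 200 are borrowed from the code further down (the 1000 in the per-vertex query and the 200 on the ANS key); the four-token layout and the interpretation in the comments are my reading of that code, not something the description above spells out.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Keys are ordered [ANS token, vertex 1, vertex 2, vertex 3].
# Case 1: the query token can reach this head's vertex, and vertices 1 and 3
# are that vertex's neighbors, so their keys score 1000.
weights_hit = softmax([200.0, 1000.0, 0.0, 1000.0])

# Case 2: the query token cannot reach this head's vertex, so only the
# fallback score on the ANS key is nonzero.
weights_miss = softmax([200.0, 0.0, 0.0, 0.0])

print(weights_hit)
print(weights_miss)
```

With gaps this large, the softmax acts almost like a hard selection: the attention mass is split evenly among the top-scoring keys, and everything else gets essentially nothing. The even split means the head's output is an average rather than a 0/1 indicator, which is one reason a later step has to push values back toward 0/1.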

Now, concatenate the heads and feed them through the following MLP:

  • multiply stuff by a large constant (1000 in the code below)
  • sigmoid it
  • subtract 1/2
  • multiply by 2
  • position it correctly

Now this output gets added to the residual stream for that token.
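The squashing steps above amount to a steep, shifted sigmoid. Here is a small sketch of it in plain Python, using the constants from the line `2*torch.sigmoid(...*1000)-1` in the code further down. The function name `squash` and the `gain` parameter are hypothetical, and the branch on the sign of `z` is just there to avoid overflow in `math.exp`.

```python
import math

def squash(x, gain=1000.0):
    """Compute 2 * sigmoid(gain * x) - 1 without overflowing math.exp."""
    z = gain * x
    if z >= 0:
        s = 1.0 / (1.0 + math.exp(-z))
    else:
        e = math.exp(z)
        s = e / (1.0 + e)
    return 2.0 * s - 1.0

# A nonnegative "how much reach evidence did we accumulate" value goes in,
# and something close to a 0/1 indicator comes out.
for value in (0.0, 0.5, 1.0, 2.0):
    print(value, squash(value))
```

Zero maps to exactly zero, and any value that is not tiny maps to essentially 1, so sums or averages of indicators collapse back to indicators. Very small positive inputs (on the order of 1/gain) land somewhere in between, which is presumably where the "not quite 0/1 valued" caveat comes from.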

OK, maybe we need one more head. This last head enables the ANS token to look at vertex 1 and see if it can reach vertex 2 yet. If it can, then it puts a 1 in the appropriate spot of the embedding.

Finally, we can read off the answer by looking at the embedding vector for the ANS token. Oh, and you should not mess with the ANS token's embedding while doing the other heads.
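As a sketch of the read-off step, here is the rule the code further down uses (`1 + sum(1 - canreach)`), written for a single graph in plain Python. The function name `read_off_distance` and the example flag lists are made up for illustration; each flag stands for "after this layer, the ANS token has seen that vertex 1 can reach vertex 2".

```python
def read_off_distance(reached_by_layer):
    """Distance = 1 + number of layers at which the target was not yet reached."""
    return 1 + sum(1 - flag for flag in reached_by_layer)

# Target first seen at the third layer: two misses, so distance 3.
print(read_off_distance([0, 0, 1, 1, 1, 1]))

# Direct neighbors: every layer already sees the target, so distance 1.
print(read_off_distance([1, 1, 1, 1, 1, 1]))

# Never reached: with 6 layers this gives 7, matching MAXDIST = NVTXS + 1
# for NVTXS = 6 in the code.
print(read_off_distance([0, 0, 0, 0, 0, 0]))
```

This only works because the flags are monotone: once the target is reachable at some layer, it stays reachable at every later layer.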

EDIT
The above description isn't totally faithful to what I ended up doing. For what I ended up doing, please just see the code below.

"""how-tsp-should-be.ipynb
 
Automatically generated by Colab.
 
Original file is located at
    https://colab.research.google.com/drive/1InE1iW8ARzndPpvqH_9y22s81sOiHxPs
"""
 
from tqdm import tqdm
import torch
import torch.nn as nn
import matplotlib as mpl
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
 
from math import sqrt
from collections import deque
import os
import random
import pickle
# import ipdb  # third-party debugger; only needed for the commented-out breakpoint in forward()
 
#  torch.manual_seed(30)
#  random.seed(30)
torch.manual_seed(33)
random.seed(33)
 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# assert device.type == "cuda", "CUDA is not available. Please check your GPU setup."
 
NVTXS = 6
MAXDIST = NVTXS+1
AVGDEG = 2
SEQLEN = NVTXS + 1
HIDDENDIM = 4*NVTXS+2
 
# Layout of each token's HIDDENDIM-dimensional embedding:
# 0                        ANSFLAG
# 1 : NVTXS+1              NBRS
# NVTXS+1 : 2*NVTXS+1      REACH
# 2*NVTXS+1 : 3*NVTXS+1    OUT
# 3*NVTXS+1 : 4*NVTXS+1    SELF
# -1                       NOTANSFLAG
 
START_REACH = NVTXS+1
START_OUT = 2*NVTXS+1
START_SELF = 3*NVTXS+1
SRC_FLAG_IDX = START_SELF
SOURCE = 1
TARGET = 2
ANS_FLAG_IDX = 0
NOTANS_FLAG_IDX = -1
 
def print_everything(data):
    print("NBRS")
    print(data[0, 1:, 1:1+NVTXS])
    print("REACH")
    print(data[0, 1:, START_REACH:START_REACH+NVTXS])
    print("ANSFLAG")
    print(data[0, :, 0])
    print("MORE FLAGS")
    print(data[0, :, -1])
    print("SELF")
    print(data[0, 1:, START_SELF:START_SELF+NVTXS])
    print("OUT")
    print(data[0, 0, START_OUT:START_OUT+NVTXS])
 
 
def random_graph():
    data = torch.zeros((SEQLEN, HIDDENDIM))
 
    for i in range(1,NVTXS+1):
        data[i, START_SELF-1+i] = 1
 
    adj_list = [set() for _ in range(SEQLEN)]
    indices = [random.randint(1, NVTXS) for _ in range(AVGDEG * NVTXS)]
    for i in range(0, len(indices), 2):
        u = indices[i]
        v = indices[i + 1]
        if u != v:
            data[v,u] = 1
            data[u,v] = 1
            data[v,NVTXS+u] = 1
            data[u,NVTXS+v] = 1
            adj_list[u].add(v)
            adj_list[v].add(u)
 
    data[0, ANS_FLAG_IDX] = 1
    data[1:, NOTANS_FLAG_IDX] = 1
 
    # TODO: this is kind of a hack
    data[0, START_REACH:START_REACH+NVTXS] = 1
    return data, adj_list
 
"""
input: G, represented as an adjacency list
output: distance from SOURCE to TARGET
"""
def SSSP(G):
    dist = [MAXDIST for _ in G]
    dist[SOURCE] = 0
    frontier = deque()
    frontier.append(SOURCE)
    while len(frontier) > 0:
        vtx = frontier.popleft()
        for x in G[vtx]:
            if dist[x] == MAXDIST:
                dist[x] = 1 + dist[vtx]
                frontier.append(x)
                if x == TARGET:
                    return dist[TARGET]
    return MAXDIST
 
def mkbatch(size):
    graphs1 = []
    distance1 = []
 
    for i in range(size):
        data, adj_list = random_graph()
        dist = SSSP(adj_list)
        graphs1.append(data)
        distance1.append(dist)
 
        print(adj_list)
 
    data = torch.stack(graphs1)
    labels = torch.tensor(distance1, dtype=torch.float16)
    return data, labels
 
"""
TODO: WRAP EVERYTHING in nn.Parameter(torch.zeros((1, HIDDENDIM)))
and then do my perturbing parameters experiment
 
TODO:
    USE activation magic to bring everything back to the 0/1 realm instead of possibly being 0/2 valued
"""
 
class SillyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.most_KQVs = []
        for head in range(1,NVTXS+1):
          Q = torch.zeros((2, HIDDENDIM))
          Q[0, START_REACH-1+head] = 1000
          Q[1, NOTANS_FLAG_IDX] = 1
 
          K = torch.zeros((2, HIDDENDIM))
          K[0, head] = 1
          K[1, ANS_FLAG_IDX] = 200
 
          V = torch.zeros((NVTXS,HIDDENDIM))
          for i in range(NVTXS):
              V[i, START_SELF+i] = 1
 
          self.most_KQVs.append((K, Q, V))
 
        self.weird_KQVs = []
        for layer in range(NVTXS):
            K = torch.zeros((3, HIDDENDIM))
            K[0, NOTANS_FLAG_IDX] = -1000
            K[0, SRC_FLAG_IDX] = +1100
            K[1, NOTANS_FLAG_IDX] = -1000
            K[1, NVTXS+TARGET] = +1100
            K[1, ANS_FLAG_IDX] = -1100
            K[2, ANS_FLAG_IDX] = 10
 
            Q = torch.zeros((3, HIDDENDIM))
            Q[:, ANS_FLAG_IDX] = 1
 
            V = torch.zeros((NVTXS, HIDDENDIM))
            V[layer, SRC_FLAG_IDX] = 1
 
            self.weird_KQVs.append((K, Q, V))
 
    def forward(self, src):
      for layer in range(NVTXS):
        allKQVs = [self.weird_KQVs[layer]] + self.most_KQVs
        head_outputs = []
        for (K, Q, V) in allKQVs:
            ksrc = torch.matmul(src, K.unsqueeze(0).transpose(-2, -1))
            qsrc = torch.matmul(src, Q.unsqueeze(0).transpose(-2, -1))
            vsrc = torch.matmul(src, V.unsqueeze(0).transpose(-2, -1))
 
            scores = torch.matmul(qsrc, ksrc.transpose(-2, -1))
            attention_weights = torch.softmax(scores, dim=-1)
            head_output = torch.matmul(attention_weights, vsrc)
            head_outputs.append(head_output)
 
        new_reaches = sum(head_outputs[1:])
        BSZ = new_reaches.shape[0]
 
        nodelta_nbrs = torch.zeros((BSZ, SEQLEN, NVTXS+1))
        morepadlol = torch.zeros((BSZ, SEQLEN, 1+NVTXS))
 
        # Residual update: NBRS untouched, REACH gets the per-vertex heads, OUT gets the ANS head.
        DIFF = torch.cat((nodelta_nbrs, new_reaches, head_outputs[0], morepadlol), dim=2)
        src += DIFF
        src[:, :, START_REACH:START_REACH+NVTXS] = 2*torch.sigmoid(src[:,:, START_REACH:START_REACH+NVTXS]*1000)-1
 
        #  print("SRC")
        #  print_everything(src)
 
      canreach = src[:,0,START_OUT:START_OUT+NVTXS]
      #  __import__('ipdb').set_trace()
      final_output = 1+torch.sum(1-canreach,dim=1)
      return final_output
 
model = SillyTransformer()
model.to(device)
 
data, labels = mkbatch(10)
assert torch.all(model(data) == labels)