For my deep learning class I wanted to investigate the ability of transformers to generalize. I thought shortest paths would be an easy synthetic task to study generalization in.
Boy, was I wrong. I guess the transformers thought this task was really hard. Probably this just means I'm bad at deep learning. But Kevin did tell us all the best hyperparameters and it still didn't really work. In protest of this injustice, I'm writing down how the transformers obviously should've been computing shortest paths. Then, just to show the transformers who's boss, I'm going to implement this, and then I'm going to demonstrate that the loss basin around this solution is nice and inviting. More precisely, I'm hoping to show that if you take the weights from my hard-coded answer and perturb them a bit, gradient descent can still pretty easily find weights that do really well. Singular learning theory folks are supposed to have some fancy way of actually measuring the size of this basin, called the learning coefficient. Maybe I'll look into that some more and see if we can implement it. I'm sure that whatever they do is more principled than my "let's perturb the weights and see if the loss changes a lot" approach.
Oh so I guess I have two proposed approaches:
- Change the weights a bit and see how much the loss changes.
- Change the weights by a lot, and see how far we can push them while the model can still recover the low loss (with a bit of further training). Rough sketches of both are below.
Q: How do you get a sense of scale? (What counts as a large weight change?)
Um, I don't know, compare to the learning rate or something, I guess. Yeah, I'm not sure about this one yet.
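To make this concrete, here's a minimal sketch of both experiments. Everything here is an assumption on my part: the `loss_fn(model, batch)` interface, the relative noise scaling, and the optimizer settings are placeholders, not the principled SLT thing.

```python
import copy
import torch

def perturb(model, sigma):
    """Return a copy of `model` with Gaussian noise added to every parameter.
    Noise is scaled by each parameter's mean magnitude so `sigma` is a relative scale."""
    noisy = copy.deepcopy(model)
    with torch.no_grad():
        for p in noisy.parameters():
            p.add_(sigma * p.abs().mean() * torch.randn_like(p))
    return noisy

def approach_1(model, loss_fn, batch, sigmas=(1e-3, 1e-2, 1e-1), trials=5):
    """Approach 1: change the weights a bit and see how much the loss changes."""
    base = loss_fn(model, batch).item()
    deltas = {}
    for sigma in sigmas:
        losses = [loss_fn(perturb(model, sigma), batch).item() for _ in range(trials)]
        deltas[sigma] = sum(losses) / trials - base
    return deltas

def approach_2(model, loss_fn, batch, sigma=0.5, steps=200, lr=1e-3):
    """Approach 2: change the weights a lot, then check whether a bit of
    gradient descent can get back down to low loss from the perturbed point."""
    noisy = perturb(model, sigma)
    opt = torch.optim.Adam(noisy.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss_fn(noisy, batch).backward()
        opt.step()
    return loss_fn(noisy, batch).item()
```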
Anyways, here are the formulas.
Let $n$ be the number of vertices in the graph. Let $d$ be a bound on the diameter. The sequence length will be $n+1$: one token for each vertex, and then one special token (the ANS token) to output the answer.
We won't need positional encodings.
Let $x_t$ denote the layer activations for token $t$'s "deep embedding". $x_t$ will live in $\mathbb{R}^{2n+d+1}$. The first coordinate of $x_t$ will be an indicator for whether $t$ is the answer token (1 = answer token, 0 else). The next $n$ coordinates of $x_t$ will be indicators for whether $t$ is adjacent to each vertex (listed in order); call this the NBR block. The next $n$ coordinates of $x_t$ will be indicators (ok, they won't quite be 0/1 valued, but they will be close to it) for whether $t$ can reach each of the other vertices; call this the REACH block.
The last $d$ coordinates will be zero.
Now, for the ANS token we put all ones in the NBR / REACH slots, and we start the final $d$ coordinates out as zero; coordinate $k$ of that block will store whether vertex 1 can reach vertex 2 in $k$ steps.
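Here's a sketch of what those initial embeddings could look like, given an adjacency matrix. The function name, the choice to seed the REACH block with the adjacency matrix, and putting the ANS token last are my own assumptions:

```python
import torch

def initial_embeddings(adj: torch.Tensor, d: int) -> torch.Tensor:
    """Build the (n+1) x (2n+d+1) matrix of initial deep embeddings.

    adj: n x n 0/1 adjacency matrix.
    Per-token layout: [is_ANS | NBR block (n) | REACH block (n) | d answer slots].
    """
    n = adj.shape[0]
    x = torch.zeros(n + 1, 1 + 2 * n + d)
    # Vertex tokens (indices 0..n-1): vertex 1 is token 0, vertex 2 is token 1, etc.
    x[:n, 1:1 + n] = adj                  # NBR: indicators for each neighbor
    x[:n, 1 + n:1 + 2 * n] = adj          # REACH starts as "reachable in one step" (assumption)
    # ANS token (index n): all ones in the NBR / REACH slots, answer slots start at zero.
    x[n, 0] = 1.0
    x[n, 1:1 + 2 * n] = 1.0
    return x
```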
We will have $d$ layers and $2$ heads per layer.
The Key/Query/Value matrices will be computed as follows:
- In one head, the pre-softmax attention score between tokens $t$ and $s$ will be $0$ or some large constant $C$, based on whether $t$ can reach $s$.
- In the other head, the score will be $0$ or $C$, based on whether $s$ is a neighbor of $t$.
- The value vector for token $s$ will be the REACH block of $x_s$.
Now compute the scores and softmax them.
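As a sketch, here's what one of these hand-set heads computes if you operate directly on the embedding blocks instead of writing out literal $W_Q/W_K/W_V$ matrices. The constant $C$, and the choice to show the neighbor-based head, are assumptions on my part:

```python
import torch
import torch.nn.functional as F

def neighbor_head(x: torch.Tensor, n: int, C: float = 20.0) -> torch.Tensor:
    """One hand-set head: token t attends (nearly uniformly) to the tokens s that
    are its neighbors, and pulls back s's REACH block as the value.

    x: (n+1) x (2n+d+1) embeddings in the layout above.
    Returns an (n+1) x n matrix of averaged REACH indicators.
    """
    nbr = x[:, 1:1 + n]                     # NBR block of every token
    reach = x[:, 1 + n:1 + 2 * n]           # REACH block of every token (the values)
    scores = torch.zeros(x.shape[0], x.shape[0])
    scores[:, :n] = C * nbr                 # score(t, s) = C if s is a neighbor of t, else 0
    attn = F.softmax(scores, dim=-1)        # ~uniform over neighbors, ~0 elsewhere
    return attn @ reach                     # average of the neighbors' REACH blocks
```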
Now, concatenate the heads and feed them through the following MLP (sketched in code below):
- Multiply the attention output by a large constant $C$.
- Sigmoid it.
- Subtract $1/2$.
- Multiply by $2$.
- Position it correctly (i.e., route the result into the REACH slots).
- Now this output gets added to the residual stream for that token.
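In code, the whole MLP step might look something like this, where `head_out` is the output of a head like `neighbor_head` above; the constant $C$ and the exact rescaling are my guesses at the blanks:

```python
import torch

def mlp_update(head_out: torch.Tensor, n: int, d: int, C: float = 20.0) -> torch.Tensor:
    """Sharpen the head's averaged REACH values into ~0/1 indicators and position
    them in the REACH slots of a (2n+d+1)-dimensional residual update.

    head_out: (n+1) x n output of the attention head.
    sigmoid(C * 0) - 1/2 = 0, while sigmoid(C * anything decently positive) - 1/2 ~ 1/2,
    so after the final factor of 2 we get an (approximate) OR over the attended vertices.
    """
    indicator = 2.0 * (torch.sigmoid(C * head_out) - 0.5)   # multiply, sigmoid, subtract, multiply
    update = torch.zeros(head_out.shape[0], 1 + 2 * n + d)
    update[:, 1 + n:1 + 2 * n] = indicator                  # "position it correctly": the REACH slots
    return update                                           # this gets added to the residual stream
```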
Ok, maybe we need one more head. This last head enables the ANS token to look at vertex 1 and see whether it can reach vertex 2 yet.
If it can, then this head puts a 1 in the appropriate slot of the ANS token's embedding (the slot for the current step count).
Finally, we can read off the answer by looking at the embedding vector for the ANS token: the shortest-path distance is the first step count whose slot got a 1.
Oh, and um, you should not mess with the ANS token's embedding while doing the other heads.
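And here's a sketch of the extra head plus the final readout. I'm writing the head as a direct in-place update for clarity, and the convention that answer slot $k$ (0-indexed) means "reachable within $k+1$ steps" is an assumption:

```python
import torch

def ans_head(x: torch.Tensor, layer: int, n: int) -> None:
    """The extra head: the ANS token looks at vertex 1 and records, in answer slot
    `layer`, whether vertex 1's REACH block says it can reach vertex 2 yet.
    (Vertex 1 is token 0; vertex 2 sits at index 1 of the REACH block.)"""
    can_reach_2 = x[0, 1 + n + 1] > 0.5
    x[n, 1 + 2 * n + layer] = float(can_reach_2)

def read_off_answer(x: torch.Tensor, n: int, d: int) -> int:
    """Read the shortest-path distance off the ANS token's embedding: the answer is
    the first answer slot that is on. Returns -1 if vertex 2 was never reached."""
    slots = x[n, 1 + 2 * n:1 + 2 * n + d]
    hits = (slots > 0.5).nonzero().flatten()
    return int(hits[0]) + 1 if len(hits) > 0 else -1
```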
EDIT
The above description isn't totally faithful to what I ended up doing.
For what I ended up doing, please just see the code below.