That is the claim in the paper. I don't understand how it is supported by measuring results on the validation set.
Figure 3 looks nice, but it doesn't say anything on its own. I don't know what the best way to interpret it is. The paper offers an interpretation that convinces you, but not me. Sorry, this kind of work is too fuzzy for me. What happened to good, old-fashioned proofs?
Everyone's expectation would be that this is it: the model is overfitted, so it is useless. It is as good as a hash map, with zero generalization ability.
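To make the hash-map comparison concrete, here is a minimal sketch, assuming a toy modular-addition task of the kind used in grokking experiments (the modulus `P = 97`, the 50% split, and all helper names are hypothetical choices for illustration). A pure lookup table reaches perfect training accuracy and roughly chance-level validation accuracy.

```python
import random

P = 97  # hypothetical modulus for a toy modular-addition task


def make_dataset(p):
    """All equations (a, b) -> (a + b) mod p."""
    return [((a, b), (a + b) % p) for a in range(p) for b in range(p)]


def split(data, train_fraction, seed=0):
    """Shuffle and cut the data into train/validation at train_fraction."""
    rng = random.Random(seed)
    shuffled = data[:]
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


class HashMapModel:
    """A 'model' that only memorizes: zero generalization by construction."""

    def fit(self, train):
        self.table = dict(train)  # pure memorization of every training pair
        return self

    def predict(self, x):
        # Unseen inputs get a constant default answer.
        return self.table.get(x, 0)


def accuracy(model, data):
    return sum(model.predict(x) == y for x, y in data) / len(data)


train, val = split(make_dataset(P), train_fraction=0.5)
model = HashMapModel().fit(train)
print(f"train acc: {accuracy(model, train):.2f}")  # 1.00 by construction
print(f"val acc:   {accuracy(model, val):.2f}")    # near chance, roughly 1/P
```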
The paper provides empirical, factual evidence that as you continue training, something is still happening in the model. After the model has memorized the whole training dataset, and while it still has not received any feedback from the validation dataset, it starts to figure out how to solve the validation dataset.
Mind you, this is not interpretation; it is fact. Long after it has completely overfitted the training set (100% training accuracy), the model keeps increasing its accuracy on a dataset it has not seen.
It's as if we discovered that water can flow upwards.
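One way to quantify this claim is to read the gap off the logged accuracy curves: the number of steps between training accuracy and validation accuracy crossing the same threshold. A minimal sketch, assuming logs stored as `(step, accuracy)` pairs; the 0.99 threshold, the helper names, and the example numbers are hypothetical, not taken from the paper.

```python
def first_step_reaching(curve, threshold):
    """First logged step whose accuracy is >= threshold, or None if never."""
    for step, acc in curve:
        if acc >= threshold:
            return step
    return None


def grokking_delay(train_curve, val_curve, threshold=0.99):
    """Steps between train and validation accuracy crossing the threshold."""
    t_train = first_step_reaching(train_curve, threshold)
    t_val = first_step_reaching(val_curve, threshold)
    if t_train is None or t_val is None:
        return None
    return t_val - t_train


# Hypothetical (step, accuracy) logs for illustration only, not the paper's data.
train_log = [(100, 0.40), (1_000, 1.00), (10_000, 1.00), (100_000, 1.00)]
val_log = [(100, 0.01), (1_000, 0.02), (10_000, 0.35), (100_000, 0.99)]
print(grokking_delay(train_log, val_log))  # 99000
```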
Grokking was discovered by someone forgetting to turn off their computer.
Nobody knows why, so nobody is able to make any theoretical deductions about it.
But I agree that Fig. 3 requires interpretation. By itself it does not say a lot, but similar structures appear in other models, like the one where we discuss element sequence prediction. To me, it looks like the models figure out some underlying structure of the problem, and we are able to interpret that structure.
I tend to look at it from a Bayesian perspective. This type of evidence increases my belief that the models are learning what I would call semantics. It is a separate line of evidence from benchmark results. Here we get a glimpse of how some models may be making simple predictions, and it does not look like memorization.
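The Bayesian framing can be made explicit in odds form: each line of evidence multiplies the prior odds by its likelihood ratio. A minimal sketch; the prior and the likelihood ratios are made-up numbers, and multiplying them assumes the two lines of evidence are conditionally independent.

```python
def bayes_update(prior, likelihood_ratio):
    """Posterior P(H | E) from prior P(H) and LR = P(E | H) / P(E | not H)."""
    prior_odds = prior / (1.0 - prior)
    posterior_odds = prior_odds * likelihood_ratio
    return posterior_odds / (1.0 + posterior_odds)


# Hypothetical numbers: H = "the model learns structure, not just a lookup table".
belief = 0.5
for lr in (2.0, 1.5):  # e.g. benchmark results, then structure in learned weights
    belief = bayes_update(belief, lr)
print(round(belief, 3))  # 0.75
```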
Yes, but the researchers get plenty of feedback from the validation set, and there's nothing easier for them than to tweak their system to perform well on it. That's overfitting on the validation set by proxy. It's absolutely inevitable when the validation set is visible to the researchers, and it's very difficult to guard against: a team that has spent a month or two working on a system, with a publication deadline looming, is not going to just give up on their work once they figure out it doesn't work very well. They're going to tweak it and tweak it and tweak it until it does what they want it to. They are going to converge, inevitably, on some ideal set of hyperparameters that optimises their system's performance on its validation set (or the test set; it doesn't matter what it's called, it matters that it is visible to the authors). They will even find a region of the weight space where it's best to initialise their system to get it to perform well on the validation set. And, of course, if they can't find a way to get good performance out of their system, you and I will never hear about it, because nobody ever publishes negative results.
So there are very strong confirmation and survivorship biases at play, and it's not surprising to see, as you say, that the system keeps doing better. That alone suffices to explain its performance, without the need for any mysterious post-overfitting grokking ability.
But maybe I haven't read the paper carefully enough, and they do guard against this sort of overfitting-by-proxy? Have you found something like that in the paper? If so, sorry for missing it.
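The overfitting-by-proxy argument can be illustrated with a toy simulation, assuming the worst case where every configuration has no real skill at all (each is a coin-flipping binary classifier). Selecting the configuration with the best score on a visible validation set still reports well above 50%, while the same choice scores around 50% on data nobody tuned against. All names and sizes here are hypothetical.

```python
import random


def simulate_selection_bias(n_configs=200, n_val=200, n_test=200, seed=0):
    """Return (best validation accuracy, that config's accuracy on fresh data).

    Every hypothetical 'config' is a random guesser with true accuracy 0.5,
    so any apparent skill on the validation set is pure selection bias.
    """
    rng = random.Random(seed)

    def score(n):
        # Accuracy of a coin-flipping classifier on n binary examples.
        return sum(rng.random() < 0.5 for _ in range(n)) / n

    # "Tweak and tweak": try many configs, keep the best validation score.
    best_val = max(score(n_val) for _ in range(n_configs))
    # The chosen config has no real skill, so fresh data is just another coin flip.
    fresh_test = score(n_test)
    return best_val, fresh_test


best_val, fresh_test = simulate_selection_bias()
print(f"best on validation: {best_val:.2f}, on fresh data: {fresh_test:.2f}")
```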
It actually still does not suffice. The effect is unexpected no matter what the authors were doing.
Just the fact that they managed to get that effect is interesting.
Granted, the phenomenon may be limited in scope. For example, on ImageNet it may require ridiculously long time scales. But maybe there is some underlying reason we can exploit to get to grokking faster.
It's basically all in Fig. 2:
- they use 3 random seeds per result
- they show results for 12 different simple algorithmic datasets
- they evaluate 12 different combinations of hyperparameters
- for each hyperparameter combination they use 10+ different train-to-validation split ratios
So they do some 10*12*3*2 = 720 runs.
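A sweep like this is a Cartesian product over its axes, so the number of runs is the product of the axis sizes. A minimal sketch with `itertools.product`; the axis names and values below are placeholders, not the paper's actual grid.

```python
from itertools import product


def sweep(**axes):
    """Yield one config dict per point in the Cartesian product of the axes."""
    names = list(axes)
    for values in product(*(axes[n] for n in names)):
        yield dict(zip(names, values))


# Placeholder axis values for illustration, not the paper's actual settings.
configs = list(sweep(
    dataset=["x+y mod 97", "x*y mod 97"],
    weight_decay=[0.0, 1.0],
    train_fraction=[0.3, 0.5, 0.7],
    seed=[0, 1, 2],
))
print(len(configs))  # 36 = 2 * 2 * 3 * 3
```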
They conclude that hyperparameters are important. Weight decay seems especially important for the grokking phenomenon to happen when the model is trained on only a small fraction of the data.
Also, at least two other people have managed to replicate those results:
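To see why weight decay can still matter after memorization, consider a plain SGD step with weight decay, which minimizes the loss plus `(weight_decay / 2) * ||w||^2`. Even when the data gradient is zero (training loss already flat), each step still shrinks the weights, so training keeps moving toward smaller-norm solutions. This is a minimal sketch of one intuition sometimes offered, not necessarily the optimizer the paper used; adaptive optimizers such as AdamW apply the decay in a decoupled way.

```python
import math


def sgd_step(w, grad, lr=0.1, weight_decay=0.0):
    """One SGD step on loss + (weight_decay / 2) * ||w||^2."""
    return [wi - lr * (gi + weight_decay * wi) for wi, gi in zip(w, grad)]


def norm(w):
    return math.sqrt(sum(wi * wi for wi in w))


w = [3.0, 4.0]           # norm 5.0
zero_grad = [0.0, 0.0]   # pretend the training loss is already flat here
for _ in range(10):
    # Each step multiplies w by (1 - lr * weight_decay) = 0.95.
    w = sgd_step(w, zero_grad, lr=0.1, weight_decay=0.5)
print(round(norm(w), 4))  # 5 * 0.95**10, about 2.9937
```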
https://twitter.com/sea_snell/status/1461344037504380931
https://twitter.com/lieberum_t/status/1480779426535288834
One hypothesis is that models are just biased to randomly stumble upon wide, flat local minima, and wide, flat local minima generalize well.
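Flatness can be made concrete with a crude proxy: the average increase in loss when the weights at a minimum are randomly perturbed. A minimal 1-D sketch, assuming two quadratic losses with the same minimum but different curvature; the `sharpness` helper, radius, and sample count are illustrative choices, and how flatness relates to generalization is itself still debated.

```python
import random


def sharpness(loss, w_star, radius=0.1, n_samples=1000, seed=0):
    """Average loss increase under random perturbations of size <= radius.

    Larger values mean a sharper minimum (one simple, informal proxy).
    """
    rng = random.Random(seed)
    base = loss(w_star)
    total = 0.0
    for _ in range(n_samples):
        eps = rng.uniform(-radius, radius)
        total += loss(w_star + eps) - base
    return total / n_samples


def flat(w):
    return 0.5 * 1.0 * w ** 2    # curvature 1, minimum at 0


def sharp(w):
    return 0.5 * 100.0 * w ** 2  # curvature 100, same minimum at 0


print(sharpness(flat, 0.0), sharpness(sharp, 0.0))
```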