This idea came into existence after reading Jay Alammar's article on the RETRO model, The Illustrated Retrieval Transformer. It is a short but intuitively written post about the RETRO model and how it works. I highly recommend reading it before continuing with this post.
The basic idea of the RETRO model can be summarized in 5 steps.
In the original RETRO paper, the authors used BERT plus a database to retrieve additional information, that is, steps 1–3, to train a RETRO model. My idea is to combine steps 1 through 3 and approximate them with a probabilistic model. Instead of retrieving similar/neighbouring chunks from a database, we sample from this probabilistic model to get similar/neighbouring low-dimensional latent vectors. We then use these latent vectors to train a RETRO-like MLP model. The idea is that the probabilistic model will act as a memory reservoir for storing data, and the MLP will fetch additional data from this reservoir to predict an output. The assumption is that, instead of using a single input, the additional retrieved neighbouring vectors will help the model give a more generalized/accurate output without using a large database.
The original RETRO paper was trained on textual data. Because of resource constraints, we opted to use the MNIST image data to train this memory-retrieval-plus-classification model. In this project a Variational AutoEncoder (VAE) was chosen as the probabilistic model; other probabilistic models should also work, as long as we can sample latent vectors from them. The VAE model was not built by myself: I used Jackson Kang's VAE model with slight modifications to the latent dimension.
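For readers unfamiliar with the VAE side, a minimal MLP-style VAE for MNIST might look like the sketch below. This is only an illustration of the general shape of such a model, not the actual architecture from Jackson Kang's repository; the hidden and latent sizes here are hypothetical placeholders.

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    """Minimal MLP VAE for flattened 28x28 MNIST images.
    A sketch only: layer sizes (hidden_dim, latent_dim) are
    hypothetical, not the values used in the actual project."""

    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=16):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # encoder mean head
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)   # encoder log-variance head
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, input_dim), nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def encode(self, x):
        h = self.enc(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # Standard reparameterization trick: z = mu + sigma * eps
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.dec(z), mu, logvar
```

After pre-training, only the encoder's mean and variance outputs are needed for the retrieval step described next.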
First, the VAE was pre-trained on the MNIST data. From this pre-trained VAE, for a given image we can fetch its mean and variance from the encoder. Using this mean and variance we can sample 'n' latent vectors (here n = 10) using the formula below, where N denotes the Normal distribution. When calculating the latent vectors, the mean vector was multiplied by an alpha drawn from N(1, 0.5); this was done to add stochasticity, so that we sample from neighbouring vectors. The 'n' latent vectors generated were then used as input to train an MLP classifier. We used the latent vectors to train the MLP, instead of the images generated by the VAE, because of their smaller number of dimensions: the lower-dimensional latent vector encodes the important features, and it also allows faster training and inference.

Another, more speculative reason for choosing the latent vector is that we believe the brain encodes and stores information in smaller dimensions. Every day, all the time, the human brain is bombarded with signals from various analog sensory inputs such as the eyes, ears, and skin. The human brain is a finite storage machine, and it is unlikely to store all of these high-dimensional sensor inputs as raw data. It is more advantageous to store memories in lower dimensions and retrieve those memories probabilistically.
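The neighbour-sampling step above can be sketched as follows. Note that this is one plausible reading of the formula, assuming z = alpha * mu + sigma * eps with alpha ~ N(1, 0.5) and eps a standard Gaussian; the exact equation and whether 0.5 is the standard deviation or the variance are assumptions here.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_neighbour_latents(mu, logvar, n=10, rng=rng):
    """Sample n 'neighbouring' latent vectors from the encoder's
    mean and log-variance.

    Assumed formula (a guess at the post's equation):
        z = alpha * mu + sigma * eps,
    with alpha ~ N(1, 0.5) and eps ~ N(0, I). Scaling mu by a
    random alpha adds extra stochasticity so the samples spread
    into the neighbourhood of the original latent vector.
    """
    mu = np.asarray(mu, dtype=float)
    sigma = np.exp(0.5 * np.asarray(logvar, dtype=float))  # std from log-variance
    alpha = rng.normal(1.0, 0.5, size=(n, 1))              # per-sample mean scaling
    eps = rng.normal(size=(n, mu.shape[-1]))               # standard Gaussian noise
    return alpha * mu + sigma * eps                        # shape (n, latent_dim)
```

Each input image thus yields a small cloud of n = 10 latent vectors rather than a single point, which plays the role of the retrieved neighbour chunks in RETRO.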
This model was trained in PyTorch with a batch size of 1, SGD as the optimizer, and cross-entropy as the loss function. I trained it for 3 epochs on the full MNIST data, and the final results are:
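A single training step of the classifier stage might look like the sketch below, assuming the n sampled latent vectors are flattened into one input for the MLP. The layer sizes and the way the n vectors are combined are assumptions; the actual architecture lives in the linked repository.

```python
import torch
import torch.nn as nn

# Hypothetical dimensions; the post does not state the exact latent size.
latent_dim, n_neighbours, n_classes = 16, 10, 10

# Simple MLP classifier over the flattened n latent vectors (a sketch;
# the real model in the repo may combine the vectors differently).
mlp = nn.Sequential(
    nn.Linear(n_neighbours * latent_dim, 128),
    nn.ReLU(),
    nn.Linear(128, n_classes),
)
optimizer = torch.optim.SGD(mlp.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# One training step with batch size 1, as described in the post.
z = torch.randn(1, n_neighbours * latent_dim)  # stand-in for sampled latents
y = torch.tensor([3])                          # stand-in class label
optimizer.zero_grad()
loss = criterion(mlp(z), y)
loss.backward()
optimizer.step()
```

In the full training loop this step would run over every MNIST image for 3 epochs, with the latents coming from the frozen pre-trained VAE rather than `torch.randn`.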
Project Github Page: https://github.com/Gananath/probablistic_memory_retrieval