Johnson–Lindenstrauss lemma
https://youtu.be/9-Jl0dxWQs8?list=PLZx_FHIHR8AwKD9csfl6Sl_pgCXX19eer&t=1125
The number of nearly orthogonal vectors that can be fit into a space grows exponentially with its dimension.
Useful for LLMs: an embedding space can store far more ideas as distinct directions than it has dimensions.
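As a rough back-of-envelope sketch of that scaling (the constant 0.25 and the union-bound heuristic below are illustrative assumptions, not a sharp statement of the lemma), the number of unit vectors that can stay pairwise within eps of orthogonal in N dimensions grows roughly like exp(c * eps^2 * N):
import math
# Heuristic capacity estimate: roughly exp(c * eps**2 * N) unit vectors can be
# pairwise within eps of orthogonal in N dimensions (c is an illustrative constant)
c, eps = 0.25, 0.1  # eps = 0.1 corresponds to angles roughly in the 84-96 degree band
for dim in (768, 4096, 12288):  # hypothetical embedding widths, for illustration only
    print(f"N={dim:>6}: ~{math.exp(c * eps**2 * dim):.3g} nearly orthogonal directions")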
Demo: plotting M > N almost orthogonal vectors in an N-dimensional space.
An optimisation process nudges them towards being perpendicular, with pairwise angles ending up between 89 and 91 degrees.
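Before the optimisation below, a minimal side check (a sketch independent of the main script, assuming the same torch setup) shows that plain random unit vectors in 100 dimensions are already close to orthogonal, typically within a few degrees of 90; the optimisation only tightens this spread.
import torch
# Baseline: pairwise angles of random unit vectors, no optimisation at all
torch.manual_seed(0)
vecs = torch.randn(1000, 100)
vecs /= vecs.norm(p=2, dim=1, keepdim=True)
cosines = vecs @ vecs.T
mask = ~torch.eye(len(vecs), dtype=torch.bool)
angles = torch.rad2deg(torch.acos(cosines[mask].clamp(-1, 1)))
print(f"mean angle: {angles.mean().item():.1f} deg, std: {angles.std().item():.1f} deg")
# Expect a mean near 90 with a spread of roughly 5-6 degrees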
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm
# List of vectors in some dimension, with many
# more vectors than there are dimensions
num_vectors = 10000
vector_len = 100
big_matrix = torch.randn(num_vectors, vector_len)
big_matrix /= big_matrix.norm(p=2, dim=1, keepdim=True)
big_matrix.requires_grad_(True)
# Set up an optimization loop to create nearly-perpendicular vectors
optimizer = torch.optim.Adam([big_matrix], lr=0.01)
num_steps = 250
losses = []
dot_diff_cutoff = 0.01
big_id = torch.eye(num_vectors, num_vectors)
for step_num in tqdm(range(num_steps)):
    optimizer.zero_grad()
    dot_products = big_matrix @ big_matrix.T
    # Punish deviation from orthogonality
    diff = dot_products - big_id
    loss = (diff.abs() - dot_diff_cutoff).relu().sum()
    # Extra incentive to keep rows normalized
    loss += num_vectors * diff.diag().pow(2).sum()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
# Plot loss curve
plt.plot(losses)
plt.grid(True)
plt.show()
# Compute angle distribution
dot_products = big_matrix @ big_matrix.T
norms = torch.sqrt(torch.diag(dot_products))
normed_dot_products = dot_products / torch.outer(norms, norms)
angles_degrees = torch.rad2deg(torch.acos(normed_dot_products.detach()))
# Mask out the diagonal (each vector paired with itself)
self_orthogonality_mask = ~(torch.eye(num_vectors, num_vectors).bool())
plt.hist(angles_degrees[self_orthogonality_mask].numpy().ravel(), bins=1000, range=(0, 180))
plt.grid(True)
plt.show()
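As a quick check on the 89-91 degree claim above, a short follow-up sketch (assuming angles_degrees and self_orthogonality_mask from the script are still in scope) reports the fraction of pairs that landed in that band:
# Fraction of vector pairs whose angle ended up between 89 and 91 degrees
off_diag_angles = angles_degrees[self_orthogonality_mask]
in_band = ((off_diag_angles > 89) & (off_diag_angles < 91)).float().mean()
print(f"pairs within 89-91 degrees: {in_band.item():.1%}")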