Rainbow Roxy

Hey, great read as always! It’s fascinating how you break this down. I was wondering if you could elaborate a bit on the actual masking implementation that stops the model from “peeking” at future tokens, as that always strikes me as the clever bit.

Dr. Ashish Bamania

Thank you! The masking is performed using a triangular mask applied to the attention scores.

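The mask is created with the torch.tril function. Applied to a square matrix of ones (for example, torch.tril(torch.ones(T, T)) for a sequence of length T), it returns a lower triangular matrix with 1s on and below the diagonal and 0s above it.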

The positions with 0s (above the diagonal) represent future positions that should be masked out.
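A minimal PyTorch sketch of the mask construction, assuming a single sequence with no batch dimension; the sequence length T = 4 is an arbitrary illustrative value:

```python
import torch

T = 4  # sequence length (illustrative value)

# torch.tril keeps the lower triangle (including the diagonal) of its input
# and zeroes everything above it. Applied to a matrix of ones, row i ends up
# with 1s in columns 0..i (current and past tokens) and 0s in columns i+1..T-1
# (future tokens).
mask = torch.tril(torch.ones(T, T))

print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```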

Next, the mask is applied with the masked_fill method, which sets the attention scores at all future positions (where the mask is 0) to -inf.
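A minimal sketch of the masked_fill step. The scores tensor here is random placeholder data standing in for the real query-key attention scores, and T = 4 is chosen only for illustration:

```python
import torch

torch.manual_seed(0)
T = 4  # sequence length (illustrative value)

# Placeholder attention scores; in a real model these would come from the
# query-key dot products.
scores = torch.randn(T, T)

# Causal mask: 1 = allowed (current or past token), 0 = future token.
mask = torch.tril(torch.ones(T, T))

# masked_fill writes -inf wherever the boolean condition (mask == 0) is True,
# i.e. at every position above the diagonal. All other scores are untouched.
masked_scores = scores.masked_fill(mask == 0, float("-inf"))

print(masked_scores)
```

Note that masked_fill returns a new tensor; the in-place variant masked_fill_ modifies scores directly.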

Finally, the softmax turns these -inf scores into 0 (since exp(-inf) = 0), so the attention weights for future tokens are exactly 0, while the weights over the current and earlier tokens still sum to 1.
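Putting the three steps together, here is a minimal single-head causal attention sketch. The function name causal_attention, reusing the same tensor x as queries, keys and values, the 1/sqrt(d) scaling, and the sizes T = 5, d = 8 are illustrative assumptions, not code from the original post. The final check perturbs only the last token and confirms that the outputs for all earlier positions do not change, which is what “no peeking” means in practice:

```python
import math

import torch
import torch.nn.functional as F


def causal_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Single-head scaled dot-product attention with a causal mask.

    q, k, v have shape (T, d). Returns (output of shape (T, d), weights of shape (T, T)).
    """
    T, d = q.shape
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)        # raw attention scores, (T, T)
    mask = torch.tril(torch.ones(T, T, device=q.device))    # 1 = allowed, 0 = future
    scores = scores.masked_fill(mask == 0, float("-inf"))   # hide future positions
    weights = F.softmax(scores, dim=-1)                     # exp(-inf) = 0 -> zero weight
    return weights @ v, weights


torch.manual_seed(0)
T, d = 5, 8
x = torch.randn(T, d)
out, weights = causal_attention(x, x, x)

# Change only the last token. Since no earlier position can attend to it,
# the outputs for positions 0..T-2 must stay the same.
x2 = x.clone()
x2[-1] = torch.randn(d)
out2, _ = causal_attention(x2, x2, x2)

print(torch.allclose(out[:-1], out2[:-1]))  # True
```

In PyTorch 2.0 and later, torch.nn.functional.scaled_dot_product_attention with is_causal=True applies an equivalent causal mask internally, so hand-written masking like this is mainly useful for understanding the mechanism.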

Hope this makes sense!