Perplexity

1

What is Perplexity in LLMs?

Perplexity is a common evaluation metric used in Natural Language Processing (NLP) and Large Language Models (LLMs) to measure how well a model predicts a sample. It reflects the uncertainty or confidence of the model in its predictions. Lower perplexity indicates better performance, as it implies the model assigns higher probabilities to the correct tokens.

Definition

Perplexity (PPL) is defined as:

PPL = 2^{-\frac{1}{N} \sum_{i=1}^{N} \log_2 P(x_i)}

Where:

  • N: Total number of tokens in the dataset or sequence.

  • P(x_i): Probability assigned by the model to the i-th token in the sequence.

Alternatively, in terms of the cross-entropy loss (computed with the natural logarithm):

PPL = e^{\text{Cross-Entropy}}
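As a quick illustration of the definition, here is a minimal Python sketch (the token probabilities below are arbitrary toy values) showing that the base-2 form and the exp-of-cross-entropy form give the same number:

```python
import math

# Probabilities the model assigned to the actual next token at each position
# (toy values, chosen only to illustrate the formula).
token_probs = [0.25, 0.1, 0.5, 0.05]

N = len(token_probs)

# Base-2 definition: PPL = 2^(-(1/N) * sum(log2 P(x_i)))
ppl_base2 = 2 ** (-sum(math.log2(p) for p in token_probs) / N)

# Equivalent form via cross-entropy with the natural logarithm: PPL = e^H
cross_entropy = -sum(math.log(p) for p in token_probs) / N
ppl_exp = math.exp(cross_entropy)

print(ppl_base2, ppl_exp)  # both print the same value (~6.32)
```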

Intuition

  • A perplexity of k means the model is as uncertain as if it were choosing uniformly from k options.

  • Lower perplexity means the model is more certain and assigns higher probabilities to the actual tokens, indicating better language understanding or prediction.

Use in LLMs

In the context of Large Language Models (LLMs), perplexity is used to:

  1. Evaluate the Model's Quality: It quantifies the model's predictive performance on a test dataset.

  2. Compare Models: A model with a lower perplexity on the same dataset is generally considered better at language understanding or generation.

  3. Monitor Training: Perplexity is monitored during training to ensure the model is improving.

Limitations

  • Doesn't Capture Real-World Usage: Perplexity focuses on token-level probabilities and may not correlate directly with human-perceived fluency or utility in downstream tasks.

  • Dataset Dependency: Perplexity is sensitive to the dataset used for evaluation, and its interpretation can vary depending on tokenization or the domain of the data.

Example

If a model assigns probabilities P(x_1), P(x_2), \ldots, P(x_N) to the actual tokens in a sentence, the perplexity reflects how "surprised" the model is by those tokens. A perfect model (one that always assigns a probability of 1 to the correct token) has a perplexity of 1, while a model that assigns uniform probabilities over a vocabulary of size V has a perplexity of V.
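A short sanity check of these two extremes, assuming an arbitrary vocabulary size of 10 and a sequence of 5 tokens:

```python
import math

vocab_size = 10

# Perfect model: probability 1 on the correct token at every position.
perfect_probs = [1.0] * 5
ppl_perfect = math.exp(-sum(math.log(p) for p in perfect_probs) / len(perfect_probs))
print(ppl_perfect)  # 1.0

# Uniform model: probability 1/vocab_size on every token.
uniform_probs = [1.0 / vocab_size] * 5
ppl_uniform = math.exp(-sum(math.log(p) for p in uniform_probs) / len(uniform_probs))
print(ppl_uniform)  # ~10.0, i.e. equal to the vocabulary size
```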

Practical Insights

  • For modern LLMs, perplexity on held-out text corpora such as Wikipedia or book datasets is commonly used as a benchmark.

  • Very low perplexity may not always translate to practical task performance, so it's often combined with other metrics.


2

Write code from scratch for Perplexity

This code snippet computes the perplexity of a model's predictions given a batch of logits (raw scores output by the model) and the corresponding target tokens (ground truth labels). A sketch of the snippet is shown below, followed by a detailed explanation of each part.
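Since the snippet itself is not reproduced here, the following is a minimal PyTorch sketch consistent with the step-by-step description below; the function name compute_perplexity is an assumption, not necessarily the original name:

```python
import torch
import torch.nn.functional as F

def compute_perplexity(logits: torch.Tensor, target: torch.Tensor) -> float:
    # 2. Convert logits to log probabilities: (batch_size, seq_length, vocab_size)
    log_probs = F.log_softmax(logits, dim=-1)

    # 3. Gather the log probability of the correct target token at each position:
    #    result has shape (batch_size, seq_length)
    target_log_probs = log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)

    # 4. Negative log likelihood per token
    nll = -target_log_probs

    # 5. Mean NLL over all tokens in the batch
    mean_nll = nll.mean()

    # 6. Perplexity = exp(mean NLL)
    perplexity = torch.exp(mean_nll)

    # 7. Return as a plain Python float
    return perplexity.item()
```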


1. Function Overview

This function calculates the perplexity, a measure of how well the model predicts the given targets. It takes:

  • logits: A 3D tensor of shape (batch_size, seq_length, vocab_size) representing the model's raw predictions (before applying softmax) for each token in the sequence.

  • target: A 2D tensor of shape (batch_size, seq_length) containing the ground truth token indices.


2. Convert Logits to Log Probabilities

  • Purpose: Convert the raw logits to log probabilities. The softmax function converts logits into probabilities, and the log_softmax computes the log of these probabilities to avoid numerical instability.

  • Output Shape: Same as logits: (batch_size, seq_length, vocab_size).


3. Gather Log Probabilities for Target Tokens

  • Purpose: Extract the log probabilities of the correct target tokens (a small concrete example follows this list).

    • target.unsqueeze(-1): Adds an extra dimension to target to match the last dimension of log_probs.

      • Original target: (batch_size, seq_length)

      • After unsqueeze: (batch_size, seq_length, 1)

    • log_probs.gather(dim=-1, index=...): For each position in the sequence, it picks the log probability corresponding to the correct target token.

    • squeeze(-1): Removes the last singleton dimension to return a 2D tensor of shape (batch_size, seq_length).
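To make the gathering step concrete, here is a tiny, made-up example with one sequence of two positions over a three-token vocabulary:

```python
import torch

# log_probs: (batch_size=1, seq_length=2, vocab_size=3)
log_probs = torch.tensor([[[-0.5, -1.2, -2.3],
                           [-1.6, -0.4, -3.0]]])

# target: (batch_size=1, seq_length=2) -> correct tokens are index 2, then index 1
target = torch.tensor([[2, 1]])

picked = log_probs.gather(dim=-1, index=target.unsqueeze(-1)).squeeze(-1)
print(picked)  # tensor([[-2.3000, -0.4000]]) -- one log probability per position
```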


4. Calculate Negative Log Likelihood

  • Purpose: Negative log likelihood (NLL) measures the penalty for predicting the correct tokens with low probability. This is the main component for calculating perplexity.


5. Compute Mean NLL Across Tokens

  • Purpose: Average the NLL values across all tokens in the batch. This gives a scalar representing the overall performance of the model.


6. Calculate Perplexity

  • Purpose: Perplexity is computed as the exponential of the mean negative log likelihood. This step maps the average loss out of the log domain into an interpretable "effective number of choices".

  • Interpretation: A lower perplexity value indicates better model performance, as it means the model predicts the correct tokens with higher confidence.


7. Return Perplexity

  • Converts the scalar tensor perplexity to a Python float for easier usage.


Example Usage

The example demonstrates how to use the function with simulated data; a runnable sketch follows the list below.

  1. Simulated logits:

This creates a random tensor with:

  • batch_size = 2: Two sequences.

  • seq_length = 4: Each sequence contains four tokens.

  • vocab_size = 10: The model predicts from a vocabulary of size 10.

  2. Simulated target:

Ground truth token indices for the two sequences.

  3. Calculate Perplexity:

The function processes the logits and target to compute the perplexity score.

  4. Print Result: Prints the computed perplexity value.


Key Points:

  • Logits: Raw model outputs (unscaled scores for each token in the vocabulary).

  • Target: Ground truth token indices.

  • Perplexity: Exponential of the average negative log likelihood. Lower perplexity means the model is better at predicting the targets.

  • Usage: Evaluate language models' quality in tasks like text generation or language understanding.


3

What is a logit?

A logit is the raw output of a machine learning model before any activation function, such as softmax or sigmoid, is applied. It represents the unnormalized scores assigned by the model for each class in a classification task.

In mathematical terms:

  • For a single instance in a classification problem, a model's output might be a vector of logits, one for each class. These logits can take any real value in (-\infty, \infty).

  • The logits are typically passed through an activation function like softmax (for multiclass classification) or sigmoid (for binary classification) to convert them into probabilities.

Characteristics of Logits

  • Unnormalized: Logits do not sum to 1 or fall within a bounded range like probabilities.

  • Linear Outputs: They are directly produced by the model's final layer before activation, often as a linear combination of learned weights and input features.

  • Raw Scores: Logits can be positive or negative and are not constrained like probabilities.

Example: Softmax and Logits

Consider a classification model with 3 classes. The output logits might look like this:

\text{logits} = [2.5, 0.3, -1.2]

To convert logits to probabilities, apply the softmax function:

P(y_i) = \frac{e^{\text{logit}_i}}{\sum_{j=1}^{3} e^{\text{logit}_j}}

This gives probabilities like:

\text{probabilities} \approx [0.88, 0.10, 0.02]
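A quick way to check these numbers in plain Python (no framework required):

```python
import math

logits = [2.5, 0.3, -1.2]

# Softmax: exponentiate each logit and normalize by the sum
exps = [math.exp(x) for x in logits]
total = sum(exps)
probs = [e / total for e in exps]

print([round(p, 2) for p in probs])  # [0.88, 0.1, 0.02]
```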

Why Use Logits?

  1. Mathematical Simplicity: Logits are easier for optimization algorithms (like gradient descent) to work with compared to probabilities.

  2. Activation Functions: Logits are transformed into probabilities only when needed, such as during inference or loss computation.

In the Context of LLMs

In language models (LLMs), logits are the output of the model for each token in the vocabulary. For example:

  • In a vocabulary of size 10,000, the model outputs a vector of 10,000 logits for each token position in a sequence.

  • These logits are then passed through a softmax function to compute probabilities, which determine the likelihood of each word/token in the vocabulary (a small sketch follows this list).
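A small sketch of that flow, using a deliberately tiny vocabulary of 10 tokens and random logits in place of a real model's output:

```python
import torch
import torch.nn.functional as F

vocab_size = 10

# Pretend these are the model's logits for the next-token position
next_token_logits = torch.randn(vocab_size)

# Softmax turns the logits into a probability distribution over the vocabulary
next_token_probs = F.softmax(next_token_logits, dim=-1)

print(next_token_probs.sum())     # ~1.0 (a valid probability distribution)
print(next_token_probs.argmax())  # index of the most likely next token
```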

Summary

  • Logit: Raw, unnormalized output scores of a model.

  • Purpose: They are intermediate values used to calculate probabilities.

  • Context in LLMs: Logits represent the model's confidence in predicting the next token in the vocabulary before normalization.

