- categories: Data Science, Method, Statistics, Probability
Definition:
Softmax regression, also known as multinomial logistic regression, is a generalization of logistic regression to handle multiclass classification. It models the probability distribution over classes for a given input by using the Softmax Function in the output layer.
Model:
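For $K$ classes, the probability of an input $\mathbf{x}$ belonging to class $k$ is given by:
$$P(y = k \mid \mathbf{x}; \theta) = \frac{\exp(\theta_k^\top \mathbf{x})}{\sum_{j=1}^{K} \exp(\theta_j^\top \mathbf{x})}$$
where:
- $\mathbf{x}$: Feature vector (may include a bias term).
- $\theta = [\theta_1, \theta_2, \ldots, \theta_K]$: Parameter matrix, where $\theta_k$ is the weight vector for class $k$.
- $y \in \{1, \ldots, K\}$: Target class label.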
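A minimal pure-Python sketch of the model $P(y = k \mid \mathbf{x}; \theta) = \exp(\theta_k^\top \mathbf{x}) / \sum_j \exp(\theta_j^\top \mathbf{x})$. It assumes $\theta$ is stored as a list of $K$ weight vectors and that any bias is already folded into $\mathbf{x}$ as a constant feature. The helper names `dot` and `softmax_probs` are illustrative, and classes are indexed $0, \ldots, K-1$ in code. Subtracting the largest logit before exponentiating leaves every probability unchanged, because the factor $e^{-\max_j z_j}$ cancels between numerator and denominator. It also keeps `math.exp` from overflowing on large scores.

```python
import math

def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def softmax_probs(theta, x):
    """Return [P(y=k | x; theta) for k in 0..K-1].

    theta: list of K weight vectors (one per class), each the same length as x.
    x: feature vector (a bias can be included as a constant 1 feature).
    """
    logits = [dot(theta_k, x) for theta_k in theta]
    m = max(logits)                           # shift by the max logit for stability
    exps = [math.exp(z - m) for z in logits]  # every exponent is <= 0, so no overflow
    total = sum(exps)
    return [e / total for e in exps]

# Usage: 3 classes, 2 features (illustrative values only)
theta = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
x = [2.0, 1.0]
print(softmax_probs(theta, x))
```

With logits of 1000 and 0, a naive `math.exp(1000)` raises `OverflowError`. The shifted version instead returns `[1.0, 0.0]`.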
Key Properties:
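- Probabilities:
The softmax output forms a valid probability distribution: every term is positive, and the terms sum to one.
$$P(y = k \mid \mathbf{x}; \theta) > 0, \qquad \sum_{k=1}^{K} P(y = k \mid \mathbf{x}; \theta) = 1$$
- Linear Decision Boundaries:
Softmax regression assumes linear decision boundaries between classes. Classes $k$ and $j$ are equally likely exactly when $\theta_k^\top \mathbf{x} = \theta_j^\top \mathbf{x}$, so the boundary between them is the hyperplane $(\theta_k - \theta_j)^\top \mathbf{x} = 0$.
- Interpretability:
Each $\theta_k$ can be interpreted as the contribution of the features to the likelihood of class $k$. A positive weight $\theta_{k,i}$ raises the linear score $\theta_k^\top \mathbf{x}$ of class $k$ as feature $x_i$ grows. Only differences between the $\theta_k$ affect the probabilities, because adding the same vector to every $\theta_k$ leaves the softmax unchanged.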
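The sketch below is a quick numerical check of these properties, using hypothetical parameters. It verifies three things:
- The probabilities are positive and sum to one.
- A point on the hyperplane $(\theta_1 - \theta_2)^\top \mathbf{x} = 0$ gets equal probability for classes 1 and 2, which is what a linear boundary means here.
- Adding the same vector to every $\theta_k$ leaves the probabilities unchanged, so only differences between the class weight vectors carry meaning.

The values and the helper name `softmax_probs` are illustrative assumptions.

```python
import math

def softmax_probs(theta, x):
    # P(y=k | x; theta) for each class k, shifted by the max logit for stability
    logits = [sum(w * xi for w, xi in zip(theta_k, x)) for theta_k in theta]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical parameters: K = 3 classes, 2 features
theta = [[2.0, -1.0], [0.5, 1.0], [-1.0, 0.0]]
x = [0.3, -0.7]

# 1) Valid distribution: all probabilities positive and summing to 1
p = softmax_probs(theta, x)
is_distribution = all(pk > 0 for pk in p) and abs(sum(p) - 1.0) < 1e-12

# 2) Linear boundary: theta_1 - theta_2 = [1.5, -2.0], so x_b = [2.0, 1.5]
#    satisfies (theta_1 - theta_2)^T x_b = 0; classes 1 and 2 (indices 0 and 1) tie
p_b = softmax_probs(theta, [2.0, 1.5])
on_boundary_tie = abs(p_b[0] - p_b[1]) < 1e-12

# 3) Shift invariance: add the same vector c to every theta_k
c = [10.0, -3.0]
shifted = [[w + ci for w, ci in zip(theta_k, c)] for theta_k in theta]
p_shifted = softmax_probs(shifted, x)
shift_invariant = all(abs(a - b) < 1e-9 for a, b in zip(p, p_shifted))

print(is_distribution, on_boundary_tie, shift_invariant)  # True True True
```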
Loss Function:
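Softmax regression is trained by minimizing the cross-entropy loss:
$$J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=1}^{K} \mathbb{1}\{y^{(i)} = k\} \log P(y^{(i)} = k \mid \mathbf{x}^{(i)}; \theta)$$
where:
- $m$: Number of training examples.
- $\mathbb{1}\{y^{(i)} = k\}$: Indicator function (1 if the true label of $\mathbf{x}^{(i)}$ is $k$, 0 otherwise).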
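Here is a sketch of the cross-entropy loss in pure Python. Because the indicator keeps only the true-class term, the double sum collapses to $J(\theta) = -\frac{1}{m}\sum_i \log P(y^{(i)} \mid \mathbf{x}^{(i)}; \theta)$, one log-probability per example. Labels are assumed to be integers $0, \ldots, K-1$ in code. The helper names `log_softmax` and `cross_entropy` are illustrative. The log-probability is computed as $z_y - \log\sum_j e^{z_j}$ (log-sum-exp), which avoids taking the log of a number that has underflowed to zero.

```python
import math

def log_softmax(logits):
    # log P(y=k|x) = z_k - log(sum_j exp(z_j)), computed stably via log-sum-exp
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

def cross_entropy(theta, X, y):
    """Average negative log-likelihood over m examples.

    theta: K weight vectors; X: list of m feature vectors; y: m labels in 0..K-1.
    """
    total = 0.0
    for x_i, y_i in zip(X, y):
        logits = [sum(w * xj for w, xj in zip(theta_k, x_i)) for theta_k in theta]
        # The indicator 1{y_i = k} selects only the true class's log-probability
        total -= log_softmax(logits)[y_i]
    return total / len(X)

# Usage with hypothetical data: K = 3 classes, first feature is a bias of 1
X = [[1.0, 0.5], [1.0, -1.0], [1.0, 2.0]]
y = [0, 1, 2]
theta_zero = [[0.0, 0.0] for _ in range(3)]
# All-zero parameters predict 1/3 for every class, so the loss is log(3), about 1.0986
print(cross_entropy(theta_zero, X, y))
```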
Gradient Descent for Optimization:
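The gradient of the loss function with respect to $\theta_k$ is:
$$\nabla_{\theta_k} J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \mathbf{x}^{(i)} \left( \mathbb{1}\{y^{(i)} = k\} - P(y^{(i)} = k \mid \mathbf{x}^{(i)}; \theta) \right)$$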
This gradient is used in optimization algorithms such as stochastic gradient descent (SGD).
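The sketch below implements the gradient $\nabla_{\theta_k} J(\theta) = -\frac{1}{m}\sum_{i} \mathbf{x}^{(i)}\big(\mathbb{1}\{y^{(i)}=k\} - P(y^{(i)}=k \mid \mathbf{x}^{(i)};\theta)\big)$. It then compares the result with a central finite-difference estimate of the cross-entropy loss, a common way to catch sign or indexing mistakes. The data, parameters, step size `h`, and helper names are hypothetical, and labels are 0-based in code.

```python
import math

def probs(theta, x):
    # Softmax probabilities P(y=k | x; theta), stabilized by subtracting the max logit
    logits = [sum(w * xj for w, xj in zip(theta_k, x)) for theta_k in theta]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def loss(theta, X, y):
    # Cross-entropy J(theta): mean of -log P(true class | x)
    return -sum(math.log(probs(theta, x)[c]) for x, c in zip(X, y)) / len(X)

def gradient(theta, X, y):
    """grad[k][j] = -(1/m) * sum_i x_ij * (1{y_i = k} - P(y_i = k | x_i))."""
    K, n, m = len(theta), len(X[0]), len(X)
    grad = [[0.0] * n for _ in range(K)]
    for x, c in zip(X, y):
        p = probs(theta, x)
        for k in range(K):
            err = (1.0 if c == k else 0.0) - p[k]
            for j in range(n):
                grad[k][j] -= x[j] * err / m
    return grad

# Hypothetical data: 4 examples, 3 features (first is a bias of 1), 3 classes
X = [[1.0, 0.5, -1.2], [1.0, -0.3, 0.8], [1.0, 1.5, 0.1], [1.0, -1.0, -0.4]]
y = [0, 1, 2, 1]
theta = [[0.1, -0.2, 0.3], [0.0, 0.4, -0.1], [-0.3, 0.2, 0.05]]

analytic = gradient(theta, X, y)
h = 1e-6
max_diff = 0.0
for k in range(len(theta)):
    for j in range(len(theta[k])):
        # Perturb a single parameter up and down, leaving theta itself untouched
        plus = [row[:] for row in theta]
        plus[k][j] += h
        minus = [row[:] for row in theta]
        minus[k][j] -= h
        numeric = (loss(plus, X, y) - loss(minus, X, y)) / (2 * h)
        max_diff = max(max_diff, abs(numeric - analytic[k][j]))

print(max_diff)  # expected to be tiny: analytic and numeric gradients agree
```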
Steps for Training Softmax Regression:
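- Initialize $\theta$ (e.g., small random values).
- Compute the softmax probabilities for each class using:
$$P(y^{(i)} = k \mid \mathbf{x}^{(i)}; \theta) = \frac{\exp(\theta_k^\top \mathbf{x}^{(i)})}{\sum_{j=1}^{K} \exp(\theta_j^\top \mathbf{x}^{(i)})}$$
- Compute the cross-entropy loss.
- Compute the gradients and update $\theta$ with an optimization method such as gradient descent: $\theta_k \leftarrow \theta_k - \alpha \nabla_{\theta_k} J(\theta)$, where $\alpha$ is the learning rate.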
- Iterate until convergence.
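Putting the steps together gives the minimal batch gradient-descent sketch below, written in pure Python. Several choices are assumptions made for illustration:
- the learning rate `alpha`;
- the fixed iteration budget, which stands in for a convergence check;
- the Gaussian initialization scale;
- the tiny synthetic dataset;
- all helper names.

A constant 1 is appended to each input as the bias feature, and labels are 0-based.

```python
import math
import random

def probs(theta, x):
    # Softmax over the K logits theta_k . x (max-shifted for numerical stability)
    logits = [sum(w * xj for w, xj in zip(theta_k, x)) for theta_k in theta]
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def train_softmax(X, y, K, alpha=0.2, iters=2000, seed=0):
    """Batch gradient descent on the cross-entropy loss; returns (theta, losses)."""
    rng = random.Random(seed)
    n, m = len(X[0]), len(X)
    # Step 1: initialize theta with small random values
    theta = [[rng.gauss(0.0, 0.01) for _ in range(n)] for _ in range(K)]
    losses = []
    for _ in range(iters):  # Step 5: fixed iteration budget instead of a convergence test
        grad = [[0.0] * n for _ in range(K)]
        loss = 0.0
        for x, c in zip(X, y):
            p = probs(theta, x)            # Step 2: softmax probabilities
            loss -= math.log(p[c]) / m     # Step 3: cross-entropy loss
            for k in range(K):             # Step 4a: accumulate the gradient
                err = (1.0 if c == k else 0.0) - p[k]
                for j in range(n):
                    grad[k][j] -= x[j] * err / m
        losses.append(loss)
        # Step 4b: gradient-descent update theta_k <- theta_k - alpha * grad_k
        for k in range(K):
            for j in range(n):
                theta[k][j] -= alpha * grad[k][j]
    return theta, losses

def predict(theta, x):
    # Index of the most probable class
    p = probs(theta, x)
    return max(range(len(p)), key=lambda k: p[k])

# Hypothetical data: three well-separated 2-D clusters, bias feature appended as 1.0
X = [[0.0, 0.0, 1.0], [0.2, 0.1, 1.0], [3.0, 0.0, 1.0], [3.1, 0.3, 1.0],
     [0.0, 3.0, 1.0], [0.2, 3.2, 1.0]]
y = [0, 0, 1, 1, 2, 2]
theta, losses = train_softmax(X, y, K=3)
print(losses[0], losses[-1])
print([predict(theta, x) for x in X])
```

Full-batch updates keep the sketch short. Stochastic gradient descent would instead update $\theta$ after each example or mini-batch.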
Example:
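Consider $K = 3$ classes, an input feature vector $\mathbf{x}$, and parameter vectors $\theta_1, \theta_2, \theta_3$.
Compute logits (linear combinations):
- $z_k = \theta_k^\top \mathbf{x}$ for $k = 1, 2, 3$.
Apply the softmax function:
- $P(y = k \mid \mathbf{x}) = \dfrac{e^{z_k}}{e^{z_1} + e^{z_2} + e^{z_3}}$ for $k = 1, 2, 3$.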
Predict the class with the highest probability.
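Here is a worked version of these steps. The numbers are hypothetical and were chosen only so the arithmetic stays exact in floating point.

```python
import math

# Hypothetical values for illustration only (K = 3 classes, 2 features)
x = [2.0, 1.0]
theta = [
    [0.5, -0.25],   # theta_1
    [0.25, 0.5],    # theta_2
    [-0.5, 0.25],   # theta_3
]

# Step 1: logits z_k = theta_k^T x
z = [sum(w * xi for w, xi in zip(theta_k, x)) for theta_k in theta]
print(z)  # [0.75, 1.0, -0.75]

# Step 2: softmax, P(y = k | x) = exp(z_k) / sum_j exp(z_j)
exps = [math.exp(zk) for zk in z]
probs = [e / sum(exps) for e in exps]
print([round(p, 3) for p in probs])  # [0.399, 0.512, 0.089]

# Step 3: predict the most probable class (reported 1-based, matching theta_1..theta_3)
predicted = max(range(len(probs)), key=lambda k: probs[k]) + 1
print(predicted)  # 2
```

The exponential is increasing and the denominator is shared by all classes. The most probable class is therefore also the one with the largest logit, so prediction alone does not require computing the softmax.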
Applications:
- Multiclass classification problems: text classification, image recognition, etc.
- Neural networks: Softmax is often used in the output layer of multiclass models.
Limitations:
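- Assumes linear decision boundaries, so it may underperform when the classes are not linearly separable in the given features unless those features are transformed first (e.g., polynomial features).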
- Requires sufficient data to avoid overfitting, especially with many classes, since the number of parameters grows as $K \times n$ for $n$ features.