24 add cross attention labels text #60
base: main
Conversation
- Module and config created to add label cross-attention; mainly attached to the TextEmbedder (it aggregates the token embeddings to produce a sentence embedding, instead of naive averaging).
- The rest of the code has been adapted, especially categorical variable handling in TextClassificationModel.
- It is used as a namespace afterwards, so not converting it throws a bug.
- Given a parameter, retrieve the attention matrix (compatible with captum attributions).
- Update tests accordingly.
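Since the description mentions compatibility with captum attributions, here is a generic captum sketch for context. It uses a toy model, not this repository's API; every name and shape below is an illustrative assumption.

```python
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients


# Toy classifier over pre-computed token embeddings; purely illustrative,
# not the TextClassificationModel from this repository.
class ToyClassifier(nn.Module):
    def __init__(self, embed_dim: int = 32, n_classes: int = 4):
        super().__init__()
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, token_embeddings):
        # token_embeddings: (B, T, D) -> mean-pool over tokens -> logits (B, n_classes)
        return self.head(token_embeddings.mean(dim=1))


model = ToyClassifier()
inputs = torch.randn(2, 16, 32)  # (batch, tokens, embedding dim)
ig = IntegratedGradients(model)
# Attributions have the same shape as the inputs: one score per embedding dim per token.
attributions = ig.attribute(inputs, target=0)
```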
Pull request overview
Adds optional “label attention” cross-attention to the text classification pipeline so the model can produce label-specific sentence embeddings and (optionally) return label×token attention matrices for explainability.
Changes:
- Introduces LabelAttentionConfig/LabelAttentionClassifier and integrates label attention into TextEmbedder.
- Updates model forward/predict paths to support returning label-attention matrices and adjusts classifier head output shape when label attention is enabled.
- Extends pipeline tests to cover label-attention-enabled training/prediction.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| torchTextClassifiers/torchTextClassifiers.py | Wires label-attention config through initialization, updates predict() explainability options, and deserializes label-attention config on load. |
| torchTextClassifiers/model/model.py | Propagates label-attention enablement, adjusts forward pass to optionally return attention matrices, and normalizes embeddings before the head. |
| torchTextClassifiers/model/lightning.py | Minor formatting-only change in validation_step. |
| torchTextClassifiers/model/components/text_embedder.py | Adds label-attention config/module and changes embedder outputs to include sentence embeddings + optional attention matrices. |
| torchTextClassifiers/model/components/__init__.py | Exports LabelAttentionConfig. |
| tests/test_pipeline.py | Adds a label-attention-enabled pipeline test and updates explainability assertions for new return keys. |
Comments suppressed due to low confidence (1)
torchTextClassifiers/model/components/text_embedder.py:209
TextEmbedder._get_sentence_embedding now sometimes returns a raw tensor (for aggregation_method 'first'/'last'), but TextEmbedder.forward unconditionally treats the result as a dict and calls .values(). This will crash when aggregation_method != 'mean'. Make _get_sentence_embedding return a consistent structure (e.g., always a dict with sentence_embedding and label_attention_matrix).
if self.attention_config is not None:
    if self.attention_config.aggregation_method is not None:  # default is "mean"
        if self.attention_config.aggregation_method == "first":
            return token_embeddings[:, 0, :]
        elif self.attention_config.aggregation_method == "last":
            lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
            return token_embeddings[
                torch.arange(token_embeddings.size(0)),
                lengths - 1,
                :,
            ]
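A minimal sketch of the consistent return structure suggested in that comment, assuming the same surrounding attributes (attention_config, token_embeddings, attention_mask) as in the snippet above:

```python
# Sketch only: make the 'first'/'last' aggregation paths return the same dict
# structure as the 'mean' path, so TextEmbedder.forward can always treat the
# result as a dict.
if self.attention_config.aggregation_method == "first":
    return {
        "sentence_embedding": token_embeddings[:, 0, :],
        "label_attention_matrix": None,
    }
elif self.attention_config.aggregation_method == "last":
    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
    return {
        "sentence_embedding": token_embeddings[
            torch.arange(token_embeddings.size(0)), lengths - 1, :
        ],
        "label_attention_matrix": None,
    }
```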
@meilame-tayebjee I've opened a new pull request, #61, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #62, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #64, to work on those changes. Once the pull request is ready, I'll request review from you.
Pull request overview
Copilot reviewed 6 out of 7 changed files in this pull request and generated 9 comments.
Comments suppressed due to low confidence (1)
torchTextClassifiers/model/components/text_embedder.py:226
- TextEmbedder._get_sentence_embedding returns a raw Tensor for aggregation_method == 'last', but callers now expect a dict. This path will break whenever aggregation_method is set to 'last'. Return {'sentence_embedding': ..., 'label_attention_matrix': None} instead.
elif self.attention_config.aggregation_method == "last":
    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
    return token_embeddings[
        torch.arange(token_embeddings.size(0)),
        lengths - 1,
        :,
    ]
# size (B, n_head, n_labels, seq_len)
attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)

# Apply mask to attention scores before softmax
if attention_mask is not None:
    # attn_mask is already in the right shape: (B, 1, 1, T)
    # We need to apply it to scores of shape (B, n_head, n_labels, T)
Copilot AI commented on Jan 27, 2026
When compute_attention_matrix=True and GQA is enabled (n_head != n_kv_head), torch.matmul(q, k.transpose(-2, -1)) will fail because q is (B, n_head, ...) while k is (B, n_kv_head, ...). Expand/repeat k (and the mask) to n_head (e.g., repeat_interleave by n_head//n_kv_head) or compute per-group weights so attention_matrix works under GQA.
Suggested change:

# q: (B, n_head, n_labels, head_dim)
# k: (B, n_kv_head, seq_len, head_dim)
# When using GQA (n_head != n_kv_head), expand k across heads so shapes match.
k_for_scores = k
if self.enable_gqa:
    # Each key/value head is shared by a group of query heads.
    # Repeat k along the head dimension to get shape (B, n_head, seq_len, head_dim).
    expand_factor = self.n_head // self.n_kv_head
    k_for_scores = k_for_scores.repeat_interleave(expand_factor, dim=1)
# size (B, n_head, n_labels, seq_len)
attention_scores = torch.matmul(q, k_for_scores.transpose(-2, -1)) / (self.head_dim**0.5)
# Apply mask to attention scores before softmax
if attention_mask is not None:
    # attn_mask is already in the right shape: (B, 1, 1, T)
    # It will broadcast over (B, n_head, n_labels, T)
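For intuition, here is a small standalone check of the shape fix suggested above; the dimensions are toy values chosen for illustration, not taken from the PR:

```python
import torch

# Toy dimensions, illustration only.
B, n_head, n_kv_head, n_labels, seq_len, head_dim = 2, 8, 2, 5, 16, 32
q = torch.randn(B, n_head, n_labels, head_dim)
k = torch.randn(B, n_kv_head, seq_len, head_dim)

# Without expansion, the matmul fails: head dims 8 and 2 cannot broadcast.
# Repeating each key/value head for its group of query heads lines the shapes up.
k_expanded = k.repeat_interleave(n_head // n_kv_head, dim=1)  # (B, n_head, seq_len, head_dim)
scores = torch.matmul(q, k_expanded.transpose(-2, -1)) / head_dim**0.5
assert scores.shape == (B, n_head, n_labels, seq_len)
```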
would rather remove GQA for label attention to simplify for now @copilot
@meilame-tayebjee I've opened a new pull request, #65, to work on those changes. Once the pull request is ready, I'll request review from you.

@meilame-tayebjee I've opened a new pull request, #66, to work on those changes. Once the pull request is ready, I'll request review from you.

@meilame-tayebjee I've opened a new pull request, #67, to work on those changes. Once the pull request is ready, I'll request review from you.

@meilame-tayebjee I've opened a new pull request, #68, to work on those changes. Once the pull request is ready, I'll request review from you.
@meilame-tayebjee I've opened a new pull request, #69, to work on those changes. Once the pull request is ready, I'll request review from you.
This pull request introduces label attention as an optional feature in the text classification pipeline, allowing the model to generate label-specific sentence embeddings using a cross-attention mechanism. The changes include new configuration classes, updates to the TextEmbedder and model logic, and new tests to ensure label attention works as intended.

Label Attention Mechanism:
- New LabelAttentionConfig and LabelAttentionClassifier to enable label-specific sentence embeddings using cross-attention, where labels act as queries over token embeddings (a minimal sketch of this mechanism follows below). [1] [2]
- Updated TextEmbedderConfig and TextEmbedder to support label attention, including a new output structure and logic to handle label attention matrices. [1] [2] [3] [4] [5] [6] [7]
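The sketch below illustrates the idea of labels acting as queries over token embeddings. It is not the PR's actual LabelAttentionClassifier: the class name, the single-head design, and all dimensions are assumptions made for illustration. It also reflects the note elsewhere in this PR that the classification head outputs a single value per label when label attention is enabled.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelCrossAttentionSketch(nn.Module):
    """Single-head label cross-attention: one learned query per label (illustrative only)."""

    def __init__(self, n_labels: int, embed_dim: int):
        super().__init__()
        self.label_queries = nn.Parameter(torch.randn(n_labels, embed_dim))
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        # With label attention, the head maps each label-specific embedding
        # to a single score, giving one logit per label.
        self.head = nn.Linear(embed_dim, 1)

    def forward(self, token_embeddings, attention_mask=None):
        # token_embeddings: (B, T, D); attention_mask: (B, T), 1 = real token
        k = self.key_proj(token_embeddings)                                 # (B, T, D)
        v = self.value_proj(token_embeddings)                               # (B, T, D)
        q = self.label_queries.unsqueeze(0)                                 # (1, n_labels, D)
        scores = torch.matmul(q, k.transpose(-2, -1)) / k.size(-1) ** 0.5   # (B, n_labels, T)
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)                                    # label-to-token attention
        label_embeddings = torch.matmul(attn, v)                            # (B, n_labels, D)
        logits = self.head(label_embeddings).squeeze(-1)                    # (B, n_labels)
        return {
            "logits": logits,
            "sentence_embedding": label_embeddings,
            "label_attention_matrix": attn,
        }


# Quick shape check with random inputs (shapes are illustrative):
module = LabelCrossAttentionSketch(n_labels=5, embed_dim=64)
out = module(torch.randn(2, 16, 64), torch.ones(2, 16, dtype=torch.long))
assert out["logits"].shape == (2, 5)
assert out["label_attention_matrix"].shape == (2, 5, 16)
```

The returned dict mirrors the sentence_embedding / label_attention_matrix keys discussed in the review comments above.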
Model and Pipeline Integration:
- Updated the model (model.py) to validate and propagate label attention configuration, including enforcing that the classification head outputs a single value when label attention is enabled and updating the number of classes accordingly.

Testing Enhancements:
- Added test_label_attention_enabled and corresponding updates to the helper functions. [1] [2] [3] [4] [5] [6] [7]

Miscellaneous: