
Conversation

@meilame-tayebjee (Member) commented:

This pull request introduces label attention as an optional feature in the text classification pipeline, allowing the model to generate label-specific sentence embeddings using a cross-attention mechanism. The changes include new configuration classes, updates to the TextEmbedder and model logic, and new tests to ensure label attention works as intended.

Label Attention Mechanism:

  • Added LabelAttentionConfig and LabelAttentionClassifier to enable label-specific sentence embeddings using cross-attention, where labels act as queries over token embeddings (a minimal sketch follows this list).
  • Updated TextEmbedderConfig and TextEmbedder to support label attention, including a new output structure and logic to handle label attention matrices.
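
For intuition, here is a minimal, illustrative sketch of the mechanism, not the actual LabelAttentionClassifier implementation: a learnable query per label attends over the token embeddings, producing one sentence embedding per label plus a label-token attention matrix. The class and parameter names below (n_labels, embed_dim) are assumptions.

    import torch
    import torch.nn as nn

    class LabelAttentionSketch(nn.Module):
        """Single-head cross-attention where labels act as queries over tokens."""

        def __init__(self, n_labels: int, embed_dim: int):
            super().__init__()
            self.label_queries = nn.Parameter(torch.randn(n_labels, embed_dim))
            self.key_proj = nn.Linear(embed_dim, embed_dim)
            self.value_proj = nn.Linear(embed_dim, embed_dim)

        def forward(self, token_embeddings, attention_mask=None):
            # token_embeddings: (B, T, D); attention_mask: (B, T), 1 = real token
            k = self.key_proj(token_embeddings)    # (B, T, D)
            v = self.value_proj(token_embeddings)  # (B, T, D)
            q = self.label_queries.unsqueeze(0)    # (1, n_labels, D), broadcast over the batch
            scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)  # (B, n_labels, T)
            if attention_mask is not None:
                scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, float("-inf"))
            attention = scores.softmax(dim=-1)     # label-token attention matrix (B, n_labels, T)
            sentence_embeddings = torch.matmul(attention, v)  # one embedding per label (B, n_labels, D)
            return sentence_embeddings, attention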

Model and Pipeline Integration:

  • Modified the main model (model.py) to validate and propagate the label attention configuration, including enforcing that the classification head outputs a single value when label attention is enabled and updating the number of classes accordingly (see the short example after this list).
  • Adjusted the forward method and related pipeline logic to support returning label attention matrices and to handle the new embedding shapes.
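
To see why the head must output a single value, here is a hedged sketch (the shapes and the head name are assumptions, not the actual model code): with label attention the embedder emits one embedding per label, so a shared Linear(d, 1) head applied across the label axis already yields one logit per class.

    import torch
    import torch.nn as nn

    batch_size, n_classes, dim = 4, 10, 96
    # With label attention enabled, the embedder outputs one embedding per label.
    label_specific_embeddings = torch.randn(batch_size, n_classes, dim)
    head = nn.Linear(dim, 1)  # classification head must output a single value per label
    logits = head(label_specific_embeddings).squeeze(-1)  # (batch_size, n_classes)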

Testing Enhancements:

  • Updated and extended the test pipeline to support label attention, including a new test function test_label_attention_enabled and corresponding updates to the helper functions.

Miscellaneous:

  • Minor code improvements, such as import adjustments and docstring updates to reflect the new model name and features.

- A dedicated module and config were created for this.
- It is mainly attached to the TextEmbedder: it aggregates the token embeddings to produce a sentence embedding, instead of naive averaging.
- The rest of the code has been adapted, especially categorical variable handling in TextClassificationModel, which is used as a namespace afterwards, so leaving it unconverted throws a bug.
- Given a parameter, the attention matrix can be retrieved (see the sketch after this list).
- Compatible with Captum attributions.
- Tests updated accordingly.
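
A hypothetical usage sketch of that retrieval; the dict keys match those mentioned in the review comments below, but the exact call signature is an assumption, not the merged API:

    # Hypothetical sketch; the actual signature may differ from the merged code.
    output = text_embedder(input_ids, attention_mask)
    sentence_embedding = output["sentence_embedding"]    # (B, n_labels, D) with label attention
    attention_matrix = output["label_attention_matrix"]  # (B, n_labels, T), e.g. for explainability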

Copilot AI left a comment


Pull request overview

Adds optional “label attention” cross-attention to the text classification pipeline so the model can produce label-specific sentence embeddings and (optionally) return label×token attention matrices for explainability.

Changes:

  • Introduces LabelAttentionConfig / LabelAttentionClassifier and integrates label attention into TextEmbedder.
  • Updates model forward/predict paths to support returning label-attention matrices and adjusts classifier head output shape when label attention is enabled.
  • Extends pipeline tests to cover label-attention-enabled training/prediction.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.

Summary per file:

  • torchTextClassifiers/torchTextClassifiers.py: Wires label-attention config through initialization, updates predict() explainability options, and deserializes label-attention config on load.
  • torchTextClassifiers/model/model.py: Propagates label-attention enablement, adjusts the forward pass to optionally return attention matrices, and normalizes embeddings before the head.
  • torchTextClassifiers/model/lightning.py: Minor formatting-only change in validation_step.
  • torchTextClassifiers/model/components/text_embedder.py: Adds the label-attention config/module and changes embedder outputs to include sentence embeddings plus optional attention matrices.
  • torchTextClassifiers/model/components/__init__.py: Exports LabelAttentionConfig.
  • tests/test_pipeline.py: Adds a label-attention-enabled pipeline test and updates explainability assertions for the new return keys.
Comments suppressed due to low confidence (1)

torchTextClassifiers/model/components/text_embedder.py:209

  • TextEmbedder._get_sentence_embedding now sometimes returns a raw tensor (for aggregation_method 'first'/'last'), but TextEmbedder.forward unconditionally treats the result as a dict and calls .values(). This will crash when aggregation_method != 'mean'. Make _get_sentence_embedding return a consistent structure (e.g., always a dict with sentence_embedding and label_attention_matrix).
        if self.attention_config is not None:
            if self.attention_config.aggregation_method is not None:  # default is "mean"
                if self.attention_config.aggregation_method == "first":
                    return token_embeddings[:, 0, :]
                elif self.attention_config.aggregation_method == "last":
                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
                    return token_embeddings[
                        torch.arange(token_embeddings.size(0)),
                        lengths - 1,
                        :,
                    ]
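
A minimal sketch of the fix suggested above, assuming the surrounding method context from the excerpt (illustrative, not the merged code): every aggregation branch returns the same dict structure.

    # Illustrative: make every aggregation branch return a consistent dict,
    # so TextEmbedder.forward can always unpack the same keys.
    if self.attention_config.aggregation_method == "first":
        sentence_embedding = token_embeddings[:, 0, :]
    elif self.attention_config.aggregation_method == "last":
        lengths = attention_mask.sum(dim=1).clamp(min=1)
        sentence_embedding = token_embeddings[
            torch.arange(token_embeddings.size(0)), lengths - 1, :
        ]
    return {"sentence_embedding": sentence_embedding, "label_attention_matrix": None}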


meilame-tayebjee and others added 2 commits January 26, 2026 18:24

Copilot AI commented Jan 26, 2026

@meilame-tayebjee I've opened a new pull request, #61, to work on those changes. Once the pull request is ready, I'll request review from you.

meilame-tayebjee and others added 3 commits January 26, 2026 18:31

Copilot AI commented Jan 26, 2026

@meilame-tayebjee I've opened a new pull request, #62, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI and others added 3 commits January 27, 2026 10:22

Copilot AI commented Jan 27, 2026

@meilame-tayebjee I've opened a new pull request, #64, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI and others added 2 commits January 27, 2026 10:55

Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 7 changed files in this pull request and generated 9 comments.

Comments suppressed due to low confidence (1)

torchTextClassifiers/model/components/text_embedder.py:226

  • TextEmbedder._get_sentence_embedding returns a raw Tensor for aggregation_method == 'last', but callers now expect a dict. This path will break whenever aggregation_method is set to 'last'. Return {'sentence_embedding': ..., 'label_attention_matrix': None} instead.
                elif self.attention_config.aggregation_method == "last":
                    lengths = attention_mask.sum(dim=1).clamp(min=1)  # last non-pad token index + 1
                    return token_embeddings[
                        torch.arange(token_embeddings.size(0)),
                        lengths - 1,
                        :,
                    ]


Comment on lines +384 to +390:

        # size (B, n_head, n_labels, seq_len)
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim**0.5)

        # Apply mask to attention scores before softmax
        if attention_mask is not None:
            # attn_mask is already in the right shape: (B, 1, 1, T)
            # We need to apply it to scores of shape (B, n_head, n_labels, T)

Copilot AI Jan 27, 2026


When compute_attention_matrix=True and GQA is enabled (n_head != n_kv_head), torch.matmul(q, k.transpose(-2, -1)) will fail because q is (B, n_head, ...) while k is (B, n_kv_head, ...). Expand/repeat k (and the mask) to n_head (e.g., repeat_interleave by n_head//n_kv_head) or compute per-group weights so attention_matrix works under GQA.

Suggested change:

        # q: (B, n_head, n_labels, head_dim)
        # k: (B, n_kv_head, seq_len, head_dim)
        # When using GQA (n_head != n_kv_head), expand k across heads so shapes match.
        k_for_scores = k
        if self.enable_gqa:
            # Each key/value head is shared by a group of query heads.
            # Repeat k along the head dimension to get shape (B, n_head, seq_len, head_dim).
            expand_factor = self.n_head // self.n_kv_head
            k_for_scores = k_for_scores.repeat_interleave(expand_factor, dim=1)

        # size (B, n_head, n_labels, seq_len)
        attention_scores = torch.matmul(q, k_for_scores.transpose(-2, -1)) / (self.head_dim**0.5)

        # Apply mask to attention scores before softmax
        if attention_mask is not None:
            # attn_mask is already in the right shape: (B, 1, 1, T)
            # It will broadcast over (B, n_head, n_labels, T)

@meilame-tayebjee (Member, Author) commented Jan 27, 2026:


would rather remove GQA for label attention to simplify for now @copilot


Copilot AI commented Jan 27, 2026

@meilame-tayebjee I've opened a new pull request, #65, to work on those changes. Once the pull request is ready, I'll request review from you.


Copilot AI commented Jan 27, 2026

@meilame-tayebjee I've opened a new pull request, #66, to work on those changes. Once the pull request is ready, I'll request review from you.


Copilot AI commented Jan 27, 2026

@meilame-tayebjee I've opened a new pull request, #67, to work on those changes. Once the pull request is ready, I'll request review from you.


Copilot AI commented Jan 27, 2026

@meilame-tayebjee I've opened a new pull request, #68, to work on those changes. Once the pull request is ready, I'll request review from you.


Copilot AI commented Jan 27, 2026

@meilame-tayebjee I've opened a new pull request, #69, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI and others added 6 commits January 27, 2026 11:31

Development

Successfully merging this pull request may close these issues.

Add Cross Attention Labels / Text
