From 1d4cb6c7240e28a3aba990d895bf10c934ce7f6a Mon Sep 17 00:00:00 2001
From: meilame-tayebjee
Date: Wed, 14 Jan 2026 18:43:25 +0000
Subject: [PATCH] feat!(save&load): enable save and load for ttc object

- implies pickling tokenizer, metadata and using ckpt for lightning
- adapt test
---
 tests/test_pipeline.py                       |   1 +
 .../model/components/text_embedder.py        |   3 +
 torchTextClassifiers/model/lightning.py     |   2 +-
 torchTextClassifiers/torchTextClassifiers.py | 123 ++++++++++++++++++
 4 files changed, 128 insertions(+), 1 deletion(-)

diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 27da44b..56dff6c 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -159,6 +159,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
         y_val=Y,
         training_config=training_config,
     )
+    ttc.load(ttc.save_path)  # test load
 
     # Predict with explanations
     top_k = 5
diff --git a/torchTextClassifiers/model/components/text_embedder.py b/torchTextClassifiers/model/components/text_embedder.py
index e9703b7..b317c91 100644
--- a/torchTextClassifiers/model/components/text_embedder.py
+++ b/torchTextClassifiers/model/components/text_embedder.py
@@ -23,6 +23,9 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
         self.config = text_embedder_config
         self.attention_config = text_embedder_config.attention_config
 
+        if isinstance(self.attention_config, dict):
+            self.attention_config = AttentionConfig(**self.attention_config)
+
         if self.attention_config is not None:
             self.attention_config.n_embd = text_embedder_config.embedding_dim
 
diff --git a/torchTextClassifiers/model/lightning.py b/torchTextClassifiers/model/lightning.py
index ac94eff..1ebc697 100644
--- a/torchTextClassifiers/model/lightning.py
+++ b/torchTextClassifiers/model/lightning.py
@@ -36,7 +36,7 @@ def __init__(
             scheduler_interval: Scheduler interval.
         """
         super().__init__()
-        self.save_hyperparameters(ignore=["model", "loss"])
+        self.save_hyperparameters(ignore=["model"])
         self.model = model
         self.loss = loss
 
diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py
index 7b3d9d2..79ce301 100644
--- a/torchTextClassifiers/torchTextClassifiers.py
+++ b/torchTextClassifiers/torchTextClassifiers.py
@@ -1,6 +1,8 @@
 import logging
+import pickle
 import time
 from dataclasses import asdict, dataclass, field
+from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 try:
@@ -75,6 +77,7 @@ class TrainingConfig:
     trainer_params: Optional[dict] = None
     optimizer_params: Optional[dict] = None
     scheduler_params: Optional[dict] = None
+    save_path: Optional[str] = "my_ttc"
 
     def to_dict(self) -> Dict[str, Any]:
         data = asdict(self)
@@ -362,6 +365,7 @@ def train(
         logger.info(f"Training completed in {end - start:.2f} seconds.")
 
         best_model_path = trainer.checkpoint_callback.best_model_path
+        self.checkpoint_path = best_model_path
 
         self.lightning_module = TextClassificationModule.load_from_checkpoint(
             best_model_path,
@@ -372,6 +376,9 @@ def train(
 
         self.pytorch_model = self.lightning_module.model.to(self.device)
 
+        self.save_path = training_config.save_path
+        self.save(self.save_path)
+
         self.lightning_module.eval()
 
     def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
@@ -576,6 +583,122 @@ def predict(
             "confidence": confidence,
         }
 
+    def save(self, path: Union[str, Path]) -> None:
+        """Save the complete torchTextClassifiers instance to disk.
+
+        This saves:
+        - Model configuration
+        - Tokenizer state
+        - PyTorch Lightning checkpoint (if trained)
+        - All other instance attributes
+
+        Args:
+            path: Directory path where the model will be saved
+
+        Example:
+            >>> ttc = torchTextClassifiers(tokenizer, model_config)
+            >>> ttc.train(X_train, y_train, training_config)
+            >>> ttc.save("my_model")
+        """
+        path = Path(path)
+        path.mkdir(parents=True, exist_ok=True)
+
+        # Save the checkpoint if model has been trained
+        checkpoint_path = None
+        if hasattr(self, "lightning_module"):
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            # Save the current state as a checkpoint
+            trainer = pl.Trainer()
+            trainer.strategy.connect(self.lightning_module)
+            trainer.save_checkpoint(checkpoint_path)
+
+        # Prepare metadata to save
+        metadata = {
+            "model_config": self.model_config.to_dict(),
+            "ragged_multilabel": self.ragged_multilabel,
+            "vocab_size": self.vocab_size,
+            "embedding_dim": self.embedding_dim,
+            "categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
+            "num_classes": self.num_classes,
+            "checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
+            "device": str(self.device) if hasattr(self, "device") else None,
+        }
+
+        # Save metadata
+        with open(path / "metadata.pkl", "wb") as f:
+            pickle.dump(metadata, f)
+
+        # Save tokenizer
+        tokenizer_path = path / "tokenizer.pkl"
+        with open(tokenizer_path, "wb") as f:
+            pickle.dump(self.tokenizer, f)
+
+        logger.info(f"Model saved successfully to {path}")
+
+    @classmethod
+    def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
+        """Load a torchTextClassifiers instance from disk.
+
+        Args:
+            path: Directory path where the model was saved
+            device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)
+
+        Returns:
+            Loaded torchTextClassifiers instance
+
+        Example:
+            >>> loaded_ttc = torchTextClassifiers.load("my_model")
+            >>> predictions = loaded_ttc.predict(X_test)
+        """
+        path = Path(path)
+
+        if not path.exists():
+            raise FileNotFoundError(f"Model directory not found: {path}")
+
+        # Load metadata
+        with open(path / "metadata.pkl", "rb") as f:
+            metadata = pickle.load(f)
+
+        # Load tokenizer
+        with open(path / "tokenizer.pkl", "rb") as f:
+            tokenizer = pickle.load(f)
+
+        # Reconstruct model_config
+        model_config = ModelConfig.from_dict(metadata["model_config"])
+
+        # Create instance
+        instance = cls(
+            tokenizer=tokenizer,
+            model_config=model_config,
+            ragged_multilabel=metadata["ragged_multilabel"],
+        )
+
+        # Set device
+        if device == "auto":
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        else:
+            device = torch.device(device)
+        instance.device = device
+
+        # Load checkpoint if it exists
+        if metadata["checkpoint_path"]:
+            checkpoint_path = path / "model_checkpoint.ckpt"
+            if checkpoint_path.exists():
+                # Load the checkpoint with weights_only=False since it's our own trusted checkpoint
+                instance.lightning_module = TextClassificationModule.load_from_checkpoint(
+                    str(checkpoint_path),
+                    model=instance.pytorch_model,
+                    weights_only=False,
+                )
+                instance.pytorch_model = instance.lightning_module.model.to(device)
+                instance.checkpoint_path = str(checkpoint_path)
+                logger.info(f"Model checkpoint loaded from {checkpoint_path}")
+            else:
+                logger.warning(f"Checkpoint file not found at {checkpoint_path}")
+
+        logger.info(f"Model loaded successfully from {path}")
+        return instance
+
     def __repr__(self):
         model_type = (
             self.lightning_module.__repr__()
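
Usage sketch, doctest-style as in the new docstrings: the end-to-end round trip this patch enables. `tokenizer`, `model_config`, and the X/Y arrays are assumed to come from the library's existing pipeline and are not defined here; only `TrainingConfig.save_path`, the implicit `save()` at the end of `train()`, and the `load()` classmethod come from this diff.

    >>> # tokenizer, model_config, X_train, y_train, X_test assumed from the existing pipeline
    >>> ttc = torchTextClassifiers(tokenizer, model_config)
    >>> training_config = TrainingConfig(save_path="my_ttc")  # new field from this patch
    >>> ttc.train(X_train, y_train, training_config)  # now also writes metadata.pkl,
    ...                                               # tokenizer.pkl and model_checkpoint.ckpt
    ...                                               # under save_path when training finishes
    >>> restored = torchTextClassifiers.load(ttc.save_path)  # device="auto": CUDA if available
    >>> predictions = restored.predict(X_test)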