Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def run_full_pipeline(tokenizer, sample_text_data, categorical_data, labels, mod
y_val=Y,
training_config=training_config,
)
ttc.load(ttc.save_path) # test load

# Predict with explanations
top_k = 5
Expand Down
3 changes: 3 additions & 0 deletions torchTextClassifiers/model/components/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def __init__(self, text_embedder_config: TextEmbedderConfig):
self.config = text_embedder_config

self.attention_config = text_embedder_config.attention_config
if isinstance(self.attention_config, dict):
self.attention_config = AttentionConfig(**self.attention_config)

if self.attention_config is not None:
self.attention_config.n_embd = text_embedder_config.embedding_dim

Expand Down
2 changes: 1 addition & 1 deletion torchTextClassifiers/model/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
scheduler_interval: Scheduler interval.
"""
super().__init__()
self.save_hyperparameters(ignore=["model", "loss"])
self.save_hyperparameters(ignore=["model"])

self.model = model
self.loss = loss
Expand Down
123 changes: 123 additions & 0 deletions torchTextClassifiers/torchTextClassifiers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import pickle
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

try:
Expand Down Expand Up @@ -75,6 +77,7 @@ class TrainingConfig:
trainer_params: Optional[dict] = None
optimizer_params: Optional[dict] = None
scheduler_params: Optional[dict] = None
save_path: Optional[str] = "my_ttc"

def to_dict(self) -> Dict[str, Any]:
data = asdict(self)
Expand Down Expand Up @@ -362,6 +365,7 @@ def train(
logger.info(f"Training completed in {end - start:.2f} seconds.")

best_model_path = trainer.checkpoint_callback.best_model_path
self.checkpoint_path = best_model_path

self.lightning_module = TextClassificationModule.load_from_checkpoint(
best_model_path,
Expand All @@ -372,6 +376,9 @@ def train(

self.pytorch_model = self.lightning_module.model.to(self.device)

self.save_path = training_config.save_path
self.save(self.save_path)

self.lightning_module.eval()

def _check_XY(self, X: np.ndarray, Y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -576,6 +583,122 @@ def predict(
"confidence": confidence,
}

def save(self, path: Union[str, Path]) -> None:
"""Save the complete torchTextClassifiers instance to disk.

This saves:
- Model configuration
- Tokenizer state
- PyTorch Lightning checkpoint (if trained)
- All other instance attributes

Args:
path: Directory path where the model will be saved

Example:
>>> ttc = torchTextClassifiers(tokenizer, model_config)
>>> ttc.train(X_train, y_train, training_config)
>>> ttc.save("my_model")
"""
path = Path(path)
path.mkdir(parents=True, exist_ok=True)

# Save the checkpoint if model has been trained
checkpoint_path = None
if hasattr(self, "lightning_module"):
checkpoint_path = path / "model_checkpoint.ckpt"
# Save the current state as a checkpoint
trainer = pl.Trainer()
trainer.strategy.connect(self.lightning_module)
trainer.save_checkpoint(checkpoint_path)

# Prepare metadata to save
metadata = {
"model_config": self.model_config.to_dict(),
"ragged_multilabel": self.ragged_multilabel,
"vocab_size": self.vocab_size,
"embedding_dim": self.embedding_dim,
"categorical_vocabulary_sizes": self.categorical_vocabulary_sizes,
"num_classes": self.num_classes,
"checkpoint_path": str(checkpoint_path) if checkpoint_path else None,
"device": str(self.device) if hasattr(self, "device") else None,
}

# Save metadata
with open(path / "metadata.pkl", "wb") as f:
pickle.dump(metadata, f)

# Save tokenizer
tokenizer_path = path / "tokenizer.pkl"
with open(tokenizer_path, "wb") as f:
pickle.dump(self.tokenizer, f)

logger.info(f"Model saved successfully to {path}")

@classmethod
def load(cls, path: Union[str, Path], device: str = "auto") -> "torchTextClassifiers":
"""Load a torchTextClassifiers instance from disk.

Args:
path: Directory path where the model was saved
device: Device to load the model on ('auto', 'cpu', 'cuda', etc.)

Returns:
Loaded torchTextClassifiers instance

Example:
>>> loaded_ttc = torchTextClassifiers.load("my_model")
>>> predictions = loaded_ttc.predict(X_test)
"""
path = Path(path)

if not path.exists():
raise FileNotFoundError(f"Model directory not found: {path}")

# Load metadata
with open(path / "metadata.pkl", "rb") as f:
metadata = pickle.load(f)

# Load tokenizer
with open(path / "tokenizer.pkl", "rb") as f:
tokenizer = pickle.load(f)

# Reconstruct model_config
model_config = ModelConfig.from_dict(metadata["model_config"])

# Create instance
instance = cls(
tokenizer=tokenizer,
model_config=model_config,
ragged_multilabel=metadata["ragged_multilabel"],
)

# Set device
if device == "auto":
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
device = torch.device(device)
instance.device = device

# Load checkpoint if it exists
if metadata["checkpoint_path"]:
checkpoint_path = path / "model_checkpoint.ckpt"
if checkpoint_path.exists():
# Load the checkpoint with weights_only=False since it's our own trusted checkpoint
instance.lightning_module = TextClassificationModule.load_from_checkpoint(
str(checkpoint_path),
model=instance.pytorch_model,
weights_only=False,
)
instance.pytorch_model = instance.lightning_module.model.to(device)
instance.checkpoint_path = str(checkpoint_path)
logger.info(f"Model checkpoint loaded from {checkpoint_path}")
else:
logger.warning(f"Checkpoint file not found at {checkpoint_path}")

logger.info(f"Model loaded successfully from {path}")
return instance

def __repr__(self):
model_type = (
self.lightning_module.__repr__()
Expand Down