API Reference

load_from_foundation_model

monad.ui.module.load_from_foundation_model

monad.ui.module.load_from_foundation_model(checkpoint_path, downstream_task, target_fn, with_head=False, pl_logger=None, predictions_to_include_fn=None, predictions_to_exclude_fn=None, split=None, **kwargs)

Initializes a trainer from a foundation model checkpoint, adapting the pretrained network to the specified downstream task.

from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,
    target_fn=...,
)

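
To make the elided target_fn argument more concrete, here is a minimal sketch of a binary, churn-style target function. It mirrors only the documented argument order (two Events objects, Attributes, and a dict[str, float]). The parameter names history, future, attributes, and context, the plain-list stand-in for Events, and the reading of a None return as "no target for this entity" are assumptions made for illustration, not confirmed library behavior.

```python
from typing import Any, Optional, Sequence

import numpy as np

# Stand-in only: a plain sequence of event dicts replaces the library's
# Events type so that this sketch runs on its own. The real Events
# interface may look different.
EventsStandIn = Sequence[dict]


def churn_target_fn(
    history: EventsStandIn,      # assumed: events before the split point
    future: EventsStandIn,       # assumed: events after the split point
    attributes: Any,             # stand-in for Attributes, unused here
    context: dict[str, float],   # the documented dict[str, float] argument
) -> Optional[np.ndarray]:
    """Return [1.0] if the entity has no future events, else [0.0]."""
    if len(history) == 0:
        # None is permitted by the documented return type; treating it as
        # "skip this entity" is an assumption of this sketch.
        return None
    churned = 1.0 if len(future) == 0 else 0.0
    return np.array([churned], dtype=np.float32)
```

A function of this shape would be passed as target_fn alongside a matching task such as BinaryClassificationTask. Real code would read events through the library's Events interface rather than len().
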
Parameters

checkpoint_path: str | pathlib.Path
Directory where the foundation model checkpoint artifacts are stored.


downstream_task: Task
Specifies the downstream task type. One of: BinaryClassificationTask, MulticlassClassificationTask, MultilabelClassificationTask, RegressionTask, RecommendationTask, OneHotRecommendationTask.


target_fn: Callable[[Events, Events, Attributes, dict[str, float]], np.ndarray | tuple[np.ndarray, np.ndarray] | Sketch | tuple[Sketch, Sketch] | None]
User-defined function that computes training targets for the specified downstream task.


with_head: bool
Default: False.
Whether to reuse the last layer (head) of the foundation model. May improve quality on recommendation tasks.


pl_logger: Optional[pytorch_lightning.loggers.Logger]
Default: None.
A PyTorch Lightning logger instance.


predictions_to_include_fn: Optional[Callable[[Events, Attributes, dict[str, float]], numpy.typing.NDArray[np.str_] | list[str] | None]]
Default: None.
A function that returns items/classes the predictions should be narrowed to for each entity. Mutually exclusive with predictions_to_exclude_fn.


predictions_to_exclude_fn: Optional[Callable[[Events, Attributes, dict[str, float]], numpy.typing.NDArray[np.str_] | list[str] | None]]
Default: None.
A function that returns items/classes that should be excluded from the predictions for each entity. Mutually exclusive with predictions_to_include_fn.


split: Optional[TimeSplitOverride | EntitySplitOverride]
Default: None.
Optional override for the split configuration set during the previous run.


kwargs: Any
Default: empty dict.
Data configuration parameters to change.
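
The two prediction filter parameters described above can be sketched as follows. The example shows a function with the documented three-argument shape that returns a list[str], plus a small guard that rejects setting both filters at once. The names exclude_already_seen and check_prediction_filters, the "item_id" key, the dict-based stand-in for Events, and the choice to return None when there is nothing to exclude are all hypothetical and serve illustration only.

```python
from typing import Any, Callable, Optional

# Hypothetical alias: the three arguments stand for Events, Attributes, and
# the dict[str, float] context from the documented signature.
PredictionFilterFn = Callable[[Any, Any, dict[str, float]], Optional[list[str]]]


def exclude_already_seen(
    events: Any, attributes: Any, context: dict[str, float]
) -> Optional[list[str]]:
    """Return the item ids this entity has already interacted with.

    Assumes, for illustration only, that `events` is an iterable of dicts
    with an "item_id" key. The real Events type may expose this differently.
    """
    seen = sorted({str(e["item_id"]) for e in events if "item_id" in e})
    # Returning None when nothing was seen is an assumption of this sketch.
    return seen or None


def check_prediction_filters(
    include_fn: Optional[PredictionFilterFn],
    exclude_fn: Optional[PredictionFilterFn],
) -> None:
    """Raise early if both filters are set, since they are mutually exclusive."""
    if include_fn is not None and exclude_fn is not None:
        raise ValueError(
            "predictions_to_include_fn and predictions_to_exclude_fn "
            "are mutually exclusive"
        )
```

Calling check_prediction_filters before load_from_foundation_model surfaces a conflicting configuration with a clear message in user code. How the library itself reacts to both arguments being set is not stated here.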

Returns

MonadModule