monad.ui.module.load_from_foundation_model
```python
monad.ui.module.load_from_foundation_model(
    checkpoint_path,
    downstream_task,
    target_fn,
    with_head=False,
    pl_logger=None,
    predictions_to_include_fn=None,
    predictions_to_exclude_fn=None,
    split=None,
    **kwargs,
)
```
Initialize a trainer from a foundation model checkpoint. Loading through this function ensures that the network is correctly adapted to the downstream task.
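```python
from monad.ui.module import load_from_foundation_model

trainer = load_from_foundation_model(
    checkpoint_path="<path/to/store/pretrain/artifacts>",
    downstream_task=...,  # a Task instance, e.g. a BinaryClassificationTask
    target_fn=...,  # user-defined function that computes training targets
)
```

Parameters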
checkpoint_path: str | pathlib.Path
Directory where the foundation model checkpoint artifacts are stored.
downstream_task: Task
The downstream task type: one of BinaryClassificationTask, MulticlassClassificationTask, MultilabelClassificationTask, RegressionTask, RecommendationTask, or OneHotRecommendationTask.
target_fn: Callable[[Events, Events, Attributes, dict[str, float]], np.ndarray | tuple[np.ndarray, np.ndarray] | Sketch | tuple[Sketch, Sketch] | None]
User-defined function that computes training targets for the specified downstream task.
with_head: bool
Default: False.
Whether to reuse the last layer of the foundation model. This may improve quality on recommendation tasks.
pl_logger: Optional[pytorch_lightning.loggers.Logger]
Default: None.
A PyTorch Lightning logger instance.
predictions_to_include_fn: Optional[Callable[[Events, Attributes, dict[str, float]], numpy.typing.NDArray[np.str_] | list[str] | None]]
Default: None.
A function that returns, for each entity, the items/classes to which its predictions should be restricted.
Mutually exclusive with predictions_to_exclude_fn.
predictions_to_exclude_fn: Optional[Callable[[Events, Attributes, dict[str, float]], numpy.typing.NDArray[np.str_] | list[str] | None]]
Default: None.
A function that returns, for each entity, the items/classes to exclude from its predictions.
Mutually exclusive with predictions_to_include_fn.
split: Optional[TimeSplitOverride | EntitySplitOverride]
Default: None.
Override for the split configuration set during the previous run.
kwargs: Any
Default: {} (no overrides).
Data configuration parameters to override, passed as keyword arguments.
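The shape of a `target_fn` can be illustrated with a minimal, self-contained sketch. Everything here is an assumption for illustration: `Events` and `Attributes` are hypothetical stand-in dataclasses rather than monad's real classes, reading the two `Events` arguments as the entity's history and future events is inferred only from the signature, `ctx` is a hypothetical name for the `dict[str, float]` argument, and the one-element float array is an assumed target format for a binary task.

```python
from __future__ import annotations

from dataclasses import dataclass, field

import numpy as np


# Hypothetical stand-ins for monad's Events and Attributes; the real classes
# have their own fields, which this sketch does not reproduce.
@dataclass
class Events:
    items: list[str] = field(default_factory=list)


@dataclass
class Attributes:
    values: dict[str, str] = field(default_factory=dict)


def activity_target_fn(
    history: Events,  # assumed: events before the target window
    future: Events,  # assumed: events inside the target window
    attributes: Attributes,
    ctx: dict[str, float],  # hypothetical name for the dict[str, float] argument
) -> np.ndarray:
    """Binary target: 1.0 if the entity has any future event, otherwise 0.0."""
    has_future_activity = len(future.items) > 0
    return np.array([1.0 if has_future_activity else 0.0], dtype=np.float32)


# Usage with the stand-in types:
print(activity_target_fn(Events(["a"]), Events(["b"]), Attributes(), {}))  # [1.]
```

With monad's real types, a function of this shape is what would be passed as `target_fn`; which of the permitted return types it should produce depends on the chosen `downstream_task`.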
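For a recommendation task, a `predictions_to_exclude_fn` might drop items an entity has already interacted with. The sketch below is hypothetical: `Events` (with an `items` field) and `Attributes` are stand-ins, not monad's classes, and `check_prediction_filters` is an illustrative helper that only mirrors the documented rule that the include and exclude functions are mutually exclusive; it is not part of monad's API.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Callable, Optional

import numpy as np
import numpy.typing as npt


# Hypothetical stand-ins for monad's Events and Attributes.
@dataclass
class Events:
    items: list[str] = field(default_factory=list)


@dataclass
class Attributes:
    values: dict[str, str] = field(default_factory=dict)


def exclude_seen_items(
    history: Events, attributes: Attributes, ctx: dict[str, float]
) -> npt.NDArray[np.str_]:
    """Return the distinct items already in the entity's history, to be excluded."""
    # np.unique deduplicates and sorts; dtype=str gives a NumPy string array.
    return np.unique(np.array(history.items, dtype=str))


FilterFn = Callable[[Events, Attributes, dict[str, float]], object]


def check_prediction_filters(
    include_fn: Optional[FilterFn], exclude_fn: Optional[FilterFn]
) -> None:
    """Raise if both filters are given (illustrative check, not monad's own)."""
    if include_fn is not None and exclude_fn is not None:
        raise ValueError(
            "predictions_to_include_fn and predictions_to_exclude_fn are mutually exclusive"
        )


# Usage with the stand-in types:
check_prediction_filters(None, exclude_seen_items)  # passes: only one filter set
print(exclude_seen_items(Events(["b", "a", "b"]), Attributes(), {}))  # ['a' 'b']
```

Returning a NumPy string array matches one of the return types the signature permits; a plain `list[str]` would also be allowed.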
Returns
