
speculators.train.checkpointer

Classes:

BaseCheckpointer

BaseCheckpointer(path: Path | str)

Helper class to save and load checkpoints.

Checkpoint file structure:

    ../path/
        0/                              # epoch number
            model.safetensors
            optimizer_state_dict.pt
            scheduler_state_dict.pt     (optional)
        1/
            model.safetensors
            optimizer_state_dict.pt
            scheduler_state_dict.pt     (optional)
        ...
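The layout above can be expressed as a small path helper. This is a minimal sketch, not part of the speculators API: `epoch_checkpoint_files` and its `with_scheduler` flag are hypothetical names, and only the directory and file names come from the structure shown above.

```python
from __future__ import annotations

from pathlib import Path


def epoch_checkpoint_files(
    root: Path | str, epoch: int, with_scheduler: bool = False
) -> dict[str, Path]:
    """Return the expected file paths for one epoch's checkpoint.

    Hypothetical helper that mirrors the documented layout; it only builds
    paths and does not touch the filesystem.
    """
    # Each epoch gets its own subdirectory named after the epoch number
    epoch_dir = Path(root) / str(epoch)
    files = {
        "model": epoch_dir / "model.safetensors",
        "optimizer": epoch_dir / "optimizer_state_dict.pt",
    }
    # The scheduler state file is optional in the layout
    if with_scheduler:
        files["scheduler"] = epoch_dir / "scheduler_state_dict.pt"
    return files
```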

Source code in speculators/train/checkpointer.py
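def __init__(self, path: Path | str):
    # Root directory that holds one subdirectory per saved epoch
    self.path = Path(path)
    # Latest epoch found under self.path; -1 means no checkpoint exists yet
    self.previous_epoch = self._get_previous_epoch()

    if self.previous_epoch != -1:
        # Directory of the latest saved epoch, e.g. <path>/3/
        self.prev_path: Path | None = self.path / str(self.previous_epoch)
    else:
        # No earlier checkpoint to load from
        self.prev_path = None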
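The constructor relies on `_get_previous_epoch`, whose body is not shown here. A plausible implementation, given the epoch-numbered directory layout and the `-1` sentinel checked above, scans `path` for numeric subdirectories and returns the largest. The sketch below is a hypothetical stand-in named `find_previous_epoch`; the real method may differ, for example by also verifying that the checkpoint files exist.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def find_previous_epoch(path: Path | str) -> int:
    """Return the highest epoch number saved under ``path``, or -1 if none.

    Hypothetical stand-in for BaseCheckpointer._get_previous_epoch.
    """
    root = Path(path)
    if not root.is_dir():
        # No checkpoint directory yet, so there is nothing to resume from
        return -1
    epochs = [
        int(child.name)
        for child in root.iterdir()
        # Only directories named with a number count as epoch checkpoints
        if child.is_dir() and child.name.isdecimal()
    ]
    # max() with a default returns -1 when no epoch directories were found
    return max(epochs, default=-1)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        for epoch in (0, 1, 2):
            (Path(tmp) / str(epoch)).mkdir()
        print(find_previous_epoch(tmp))  # → 2
```

Using `max(..., default=-1)` keeps the same sentinel the constructor tests for, so an empty or missing directory naturally yields `prev_path = None`.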