speculators.model

Base model classes for the Speculators library.

This module contains the base model classes for speculative decoding implementations in the Speculators library. These classes provide the foundation for creating speculator models that can perform speculative token generation with verifier models for accelerated inference.

The models extend Hugging Face's PreTrainedModel and GenerationMixin to maintain full compatibility with the transformers ecosystem while adding speculative decoding capabilities. They support automatic model registration and discovery, dynamic model loading based on configuration, and flexible verifier attachment.

Classes:

  • SpeculatorModel: Abstract base class for all speculator models with transformers compatibility, automatic registry support, and speculative generation methods.

Functions:

  • reload_and_populate_models: Automatically populates the model registry for discovery and instantiation of registered speculator models.

Classes:

  • SpeculatorModel

    Abstract base class for all speculator models in the Speculators library.

Functions:

  • reload_and_populate_models

    Automatically populates the model registry for discovery and instantiation of registered speculator models.

SpeculatorModel

SpeculatorModel(
    config: SpeculatorModelConfig,
    verifier: str | PathLike | PreTrainedModel | None,
    verifier_attachment_mode: Literal[
        "detached", "full", "train_only"
    ]
    | None,
    **kwargs,
)

Bases: ClassRegistryMixin, PreTrainedModel, GenerationMixin

Abstract base class for all speculator models in the Speculators library.

This class provides the foundation for implementing speculative decoding models that can generate candidate tokens to be verified by a base verifier model. It combines the functionality of Hugging Face's PreTrainedModel and GenerationMixin with automatic model registration and discovery capabilities. All concrete speculator model implementations must inherit from this class, register with SpeculatorModel.register(NAME), and implement the abstract forward method.
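
To make registration concrete, here is an illustrative sketch only (it assumes register acts as a class decorator, a common pattern for registry mixins; MySpeculator and MySpeculatorConfig are hypothetical names, the latter borrowed from the from_training_args example below):

@SpeculatorModel.register("my_speculator")
class MySpeculator(SpeculatorModel):
    config_class = MySpeculatorConfig  # hypothetical config subclass

    def forward(self, *args, **kwargs):
        ...  # algorithm-specific draft-token computation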

Example:

# Load a speculator model with automatic class resolution
model = SpeculatorModel.from_pretrained("path/to/speculator")

# Optionally attach a new verifier model
verifier = AutoModel.from_pretrained("path/to/verifier")
model.attach_verifier(verifier)

# Generate with speculative decoding
outputs = model.generate(input_ids, max_length=100)

Initialize a SpeculatorModel instance.

Sets up the basic structure for a speculator model, including configuration storage and optional verifier model attachment. The verifier model is used during speculative decoding to validate the tokens proposed by the speculator.

If no verifier is provided during initialization, it must be attached later using the attach_verifier method before calling generate.

Parameters:

  • config

    (SpeculatorModelConfig) –

    The configuration for the speculator model. Must be a SpeculatorModelConfig instance containing model hyperparameters and speculative decoding settings.

  • verifier

    (str | PathLike | PreTrainedModel | None) –

    The verifier model to attach. This can be a path to a local model directory, a Hugging Face model identifier, or an instance of PreTrainedModel. If provided, the speculator will use this verifier for speculative decoding. If None, the speculator will load the verifier from the config if specified, or it must be attached later using the attach_verifier method.

  • verifier_attachment_mode

    (Literal['detached', 'full', 'train_only'] | None) –

    Optional mode for how the verifier is attached to the speculator. If "detached", any verifier passed in or resolved from the config will not be attached. If "full", the verifier is fully integrated into the speculator's forward pass and generation methods. If "train_only", only the portions of the verifier needed for training are attached, allowing for better resource utilization during training. If None and a verifier is provided, it defaults to "full". If no verifier is provided and none is found in the config, this parameter is ignored. A minimal sketch of these modes follows this parameter list.

  • kwargs

    Additional keyword arguments passed to the parent PreTrainedModel constructor.
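
A minimal sketch of the three attachment modes; the checkpoint paths are hypothetical:

# Load without attaching the verifier (e.g., for offline inspection)
model = SpeculatorModel.from_pretrained(
    "path/to/speculator", verifier_attachment_mode="detached"
)

# Later, attach with full integration for generation (mode defaults to "full")
model.attach_verifier("path/to/verifier")

# Or, alternatively, attach only what training needs
model.attach_verifier("path/to/verifier", mode="train_only")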

Methods:

  • attach_verifier

    Attach a verifier model that validates the candidate tokens generated by the speculator.

  • detach_verifier

    Removes the reference to the attached verifier model and frees the associated memory.

  • forward

    Defines the forward pass computation for the speculator model.

  • from_pretrained

    Load a pretrained speculator model from the Hugging Face Hub or local directory.

  • from_training_args

    Create model instance from training arguments.

  • generate

    Generate text using speculative decoding with the attached verifier model.

  • get_trainer_kwargs

    Get algorithm-specific kwargs for training and validation.

  • registered_model_class_from_config

    Looks up the appropriate speculator model class from the registry based on the configuration type.

  • resolve_verifier

    Resolves the verifier model from a given path or identifier.

  • state_dict

    Overrides PyTorch's state_dict method so save pathways exclude the verifier's parameters.

  • verify_training_compatible

    Verify that a model instance is compatible with training infrastructure.

Source code in speculators/model.py
def __init__(
    self,
    config: SpeculatorModelConfig,
    verifier: str | os.PathLike | PreTrainedModel | None,
    verifier_attachment_mode: Literal["detached", "full", "train_only"] | None,
    **kwargs,
):
    """
    Initialize a SpeculatorModel instance.

    Sets up the basic structure for a speculator model, including configuration
    storage and optional verifier model attachment. The verifier model is used
    during speculative decoding to validate the tokens proposed by the speculator.

    If no verifier is provided during initialization, it must be attached later
    using the attach_verifier method before calling generate.

    :param config: The configuration for the speculator model. Must be a
        SpeculatorModelConfig instance containing model hyperparameters and
        speculative decoding settings.
    :param verifier: The verifier model to attach. This can be a path to a local
        model directory, a Hugging Face model identifier, or an instance of
        PreTrainedModel. If provided, the speculator will use this verifier for
        speculative decoding. If None, the speculator will load the verifier from
        the config if specified, or it must be attached later using the
        `attach_verifier` method.
    :param verifier_attachment_mode: Optional mode for how the verifier is
        attached to the speculator. If "detached", any verifier passed in or
        resolved from the config will not be attached.
        If "full", the verifier is fully integrated into the
        speculator's forward pass and generation methods.
        If "train_only", only the portions of the verifier needed for training
        are attached, allowing for better resource utilization during training.
        If None and a verifier is provided, it defaults to "full".
        If a verifier is not provided and None is found in the config,
        this parameter is ignored.
    :param kwargs: Additional keyword arguments passed to the parent
        PreTrainedModel constructor.
    """
    if not config:
        raise ValueError(
            "Config must be provided to initialize a SpeculatorModel. "
            "Use SpeculatorModelConfig to create a valid configuration."
        )

    if not isinstance(config, SpeculatorModelConfig):
        raise TypeError(
            f"Expected config to be an instance of SpeculatorModelConfig, "
            f"got {type(config)} {config}."
        )

    super().__init__(config, **kwargs)
    self.config: SpeculatorModelConfig = config
    self.verifier: PreTrainedModel | None = None
    self.verifier_attachment_mode: Literal["detached", "full", "train_only"] = (
        "detached"
    )

    verifier = verifier or config.speculators_config.verifier.name_or_path
    if verifier is not None and verifier_attachment_mode != "detached":
        self.attach_verifier(verifier, mode=verifier_attachment_mode)

attach_verifier

attach_verifier(
    verifier: str | PathLike | PreTrainedModel,
    mode: Literal["full", "train_only"] | None = None,
)

Attach a verifier model to the speculator for running inference or training. The verifier validates the candidate tokens generated by the speculator during the speculative decoding process, and it should be compatible with the speculator's configuration in terms of vocabulary, architecture, and tokenization.

Example:

# Load and attach a verifier
verifier = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
speculator.attach_verifier(verifier)

# Now ready for generation
outputs = speculator.generate(input_ids)

Parameters:

  • verifier

    (str | PathLike | PreTrainedModel) –

    The verifier model to attach. This can be a path to a local model directory, a Hugging Face model identifier, or an instance of PreTrainedModel. If a path or identifier is provided, the model will be loaded automatically. If an instance is provided, it will be used directly.

  • mode

    (Literal['full', 'train_only'] | None, default: None ) –

    Optional mode for how the verifier is attached to the speculator. If "full", the verifier is fully integrated into the speculator's forward pass and generation methods. If "train_only", only the portions of the verifier needed for training are attached, allowing for better resource utilization during training. If None, defaults to "full".

Returns:

  • The PreTrainedModel instance for the verifier that was attached.

Source code in speculators/model.py
def attach_verifier(
    self,
    verifier: str | os.PathLike | PreTrainedModel,
    mode: Literal["full", "train_only"] | None = None,
):
    """
    Attach a verifier model to the speculator for running inference or
    training. The verifier validates the candidate tokens generated by the
    speculator during the speculative decoding process, and it should be
    compatible with the speculator's configuration in terms of vocabulary,
    architecture, and tokenization.

    Example:
        ```python
        # Load and attach a verifier
        verifier = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
        speculator.attach_verifier(verifier)

        # Now ready for generation
        outputs = speculator.generate(input_ids)
        ```

    :param verifier: The verifier model to attach. This can be a path to a local
        model directory, a Hugging Face model identifier, or an instance of
        PreTrainedModel. If a path or identifier is provided, the model will be
        loaded automatically. If an instance is provided, it will be used directly.
    :param mode: Optional mode for how the verifier is attached to the speculator.
        If "full", the verifier is fully integrated into the speculator's forward
        pass and generation methods. If "train_only", only the portions of the
        verifier needed for training are attached, allowing for better resource
        utilization during training. If None, defaults to "full".
    :return: The PreTrainedModel instance for the verifier that was attached.
    """
    if self.verifier_attachment_mode != "detached":
        raise RuntimeError(
            "Cannot attach a verifier when the speculator is not in detached mode. "
            "Detach the current verifier first using `detach_verifier()`."
        )

    if mode not in {"full", "train_only", None}:
        raise ValueError(
            f"Invalid verifier_attachment_mode: {mode}. "
            "Must be one of 'full', 'train_only', or None."
        )

    self.verifier_attachment_mode = mode or "full"
    self.verifier = (
        self.resolve_verifier(verifier)
        if self.verifier_attachment_mode == "full"
        else None
    )  # Expect subclasses to handle references if train_only

    # Return the attached verifier to match the documented contract
    return self.verifier

detach_verifier

detach_verifier()

Removes the reference to the attached verifier model and frees up the associated memory. After calling this method, the speculator will not be able to perform forward passes or generation until a new verifier is attached.
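
A brief usage sketch for swapping verifiers; the path is hypothetical:

# Free the current verifier, then attach a replacement
speculator.detach_verifier()
speculator.attach_verifier("path/to/other/verifier")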

Source code in speculators/model.py
def detach_verifier(self):
    """
    Removes the reference to the attached verifier model and frees up the
    associated memory. After calling this method, the speculator will not
    be able to perform forward passes or generation until a new verifier
    is attached.
    """
    if self.verifier_attachment_mode == "detached":
        raise RuntimeError(
            "Verifier is already detached, cannot be called again until "
            "a new verifier is attached."
        )

    if self.verifier is not None:
        del self.verifier

    self.verifier = None
    self.verifier_attachment_mode = "detached"

forward

forward(*args, **kwargs)

Defines the forward pass computation for the speculator model.

This method must be implemented by all concrete speculator model subclasses. It defines how the model processes inputs to generate candidate tokens or logits specifically for training pipelines.

Use model.generate for generation tasks, which will handle speculative decoding with the attached verifier.

Parameters:

  • args

    Positional arguments for the forward pass, typically including input_ids and potentially attention_mask, position_ids, etc.

  • kwargs

    Keyword arguments for the forward pass, which may include various model-specific parameters and options.

Returns:

  • Model outputs, typically including logits or candidate token sequences, depending on the specific speculator implementation.
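
As an illustration only, a skeletal subclass forward might look like this; the embed_tokens, fusion, and lm_head attributes are hypothetical and vary per algorithm:

class MySpeculator(SpeculatorModel):
    def forward(self, input_ids, hidden_states, **kwargs):
        # Fuse the verifier's hidden states with the draft token embeddings,
        # then project to vocabulary logits (algorithm-specific).
        embeds = self.embed_tokens(input_ids)
        fused = self.fusion(torch.cat([embeds, hidden_states], dim=-1))
        return self.lm_head(fused)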

Source code in speculators/model.py
def forward(self, *args, **kwargs):
    """
    Defines the forward pass computation for the speculator model.

    This method must be implemented by all concrete speculator model
    subclasses. It defines how the model processes inputs to generate candidate
    tokens or logits specifically for training pipelines.

    Use `model.generate` for generation tasks, which will handle
    speculative decoding with the attached verifier.

    :param args: Positional arguments for the forward pass, typically including
        input_ids and potentially attention_mask, position_ids, etc.
    :param kwargs: Keyword arguments for the forward pass, which may include
        various model-specific parameters and options.
    :return: Model outputs, typically including logits or candidate token
        sequences, depending on the specific speculator implementation.
    """
    raise NotImplementedError(
        "The forward method is only supported on concrete "
        "speculator model subclasses."
    )

from_pretrained classmethod

from_pretrained(
    pretrained_model_name_or_path: str | PathLike | None,
    *model_args,
    verifier: str
    | PathLike
    | PreTrainedModel
    | None = None,
    verifier_attachment_mode: Literal[
        "detached", "full", "train_only"
    ]
    | None = None,
    config: PretrainedConfig | str | PathLike | None = None,
    cache_dir: str | PathLike | None = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    use_safetensors: bool | None = None,
    weights_only: bool = True,
    **kwargs,
) -> SpeculatorModel

Load a pretrained speculator model from the Hugging Face Hub or local directory.

This method automatically resolves the correct speculator model class based on the configuration type and loads the model with the appropriate weights. If called on the base SpeculatorModel class, it will automatically determine and instantiate the correct subclass based on the model configuration.

Example:

# Load with automatic class resolution
model = SpeculatorModel.from_pretrained("RedHatAI/speculator-llama-7b")

# Load from local directory
model = SpeculatorModel.from_pretrained("./my_speculator")

# Load with custom config
config = SpeculatorModelConfig.from_pretrained("RedHatAI/eagle-llama-7b")
model = SpeculatorModel.from_pretrained(
    None, config=config, state_dict=state_dict
)

Parameters:

  • pretrained_model_name_or_path

    (str | PathLike | None) –

    The model identifier on Hugging Face Hub, or path to a local directory containing the model files. Can be None if a config is provided along with a state_dict in kwargs.

  • model_args

    Additional positional arguments passed to the model constructor.

  • verifier

    (str | PathLike | PreTrainedModel | None, default: None ) –

    Optional verifier model to attach the speculator to. Can be a path to a local model directory, a Hugging Face model identifier, or an instance of PreTrainedModel. If provided, the speculator will use this verifier for speculative decoding. If None, the speculator will load the verifier from the config if specified, or it must be attached later using the attach_verifier method.

  • verifier_attachment_mode

    (Literal['detached', 'full', 'train_only'] | None, default: None ) –

    Optional mode for how the verifier is attached to the speculator. If "detached", any verifier passed in or resolved from the config will not be attached. If "full", the verifier is fully integrated into the speculator's forward pass and generation methods. If "train_only", only the portions of the verifier needed for training are attached, allowing for better resource utilization during training. If None and a verifier is provided, it defaults to "full". If no verifier is provided and none is found in the config, this parameter is ignored.

  • config

    (PretrainedConfig | str | PathLike | None, default: None ) –

    Optional configuration for the model. Can be a SpeculatorModelConfig instance, a path to a config file, or None to load from model directory.

  • cache_dir

    (str | PathLike | None, default: None ) –

    Directory to cache downloaded files. If None, uses default transformers cache directory.

  • ignore_mismatched_sizes

    (bool, default: False ) –

    Whether to ignore size mismatches when loading pretrained weights. Useful for loading models with different architectures.

  • force_download

    (bool, default: False ) –

    Whether to force re-download of model files even if they exist in cache.

  • local_files_only

    (bool, default: False ) –

    Whether to avoid downloading files and only use local cached files. Raises an error if files are not found locally.

  • token

    (str | bool | None, default: None ) –

    Optional authentication token for accessing private models on Hugging Face Hub. Can be a string token or True to use saved token.

  • revision

    (str, default: 'main' ) –

    The specific model revision to load (branch name, tag, or commit hash). Defaults to "main".

  • use_safetensors

    (bool | None, default: None ) –

    Whether to use safetensors format for loading weights. If None, automatically detects the available format.

  • weights_only

    (bool, default: True ) –

    Whether to only load model weights without optimizer states or other training artifacts.

  • kwargs

    Additional keyword arguments passed to the model constructor and loading process.

Returns:

  • SpeculatorModel

    A SpeculatorModel instance of the appropriate subclass, loaded with the pretrained weights and configuration.

Source code in speculators/model.py
@classmethod
def from_pretrained(
    cls: type["SpeculatorModel"],
    pretrained_model_name_or_path: str | os.PathLike | None,
    *model_args,
    verifier: str | os.PathLike | PreTrainedModel | None = None,
    verifier_attachment_mode: Literal["detached", "full", "train_only"]
    | None = None,
    config: PretrainedConfig | str | os.PathLike | None = None,
    cache_dir: str | os.PathLike | None = None,
    ignore_mismatched_sizes: bool = False,
    force_download: bool = False,
    local_files_only: bool = False,
    token: str | bool | None = None,
    revision: str = "main",
    use_safetensors: bool | None = None,
    weights_only: bool = True,
    **kwargs,
) -> "SpeculatorModel":
    """
    Load a pretrained speculator model from the Hugging Face Hub or local directory.

    This method automatically resolves the correct speculator model class based on
    the configuration type and loads the model with the appropriate weights. If
    called on the base SpeculatorModel class, it will automatically determine and
    instantiate the correct subclass based on the model configuration.

    Example:
        ```python
        # Load with automatic class resolution
        model = SpeculatorModel.from_pretrained("RedHatAI/speculator-llama-7b")

        # Load from local directory
        model = SpeculatorModel.from_pretrained("./my_speculator")

        # Load with custom config
        config = SpeculatorModelConfig.from_pretrained("RedHatAI/eagle-llama-7b")
        model = SpeculatorModel.from_pretrained(
            None, config=config, state_dict=state_dict
        )
        ```

    :param pretrained_model_name_or_path: The model identifier on Hugging Face
        Hub, or path to a local directory containing the model files. Can be
        None if a config is provided along with a state_dict in kwargs.
    :param model_args: Additional positional arguments passed to the model
        constructor.
    :param verifier: Optional verifier model to attach the speculator to.
        Can be a path to a local model directory, a Hugging Face model identifier,
        or an instance of PreTrainedModel. If provided, the speculator will use this
        verifier for speculative decoding. If None, the speculator will load the
        verifier from the config if specified, or it must be attached later
        using the `attach_verifier` method.
    :param verifier_attachment_mode: Optional mode for how the verifier is
        attached to the speculator. If "detached", any verifier passed in or
        resolved from the config will not be attached.
        If "full", the verifier is fully integrated into the
        speculator's forward pass and generation methods.
        If "train_only", only the portions of the verifier needed for training
        are attached, allowing for better resource utilization during training.
        If None and a verifier is provided, it defaults to "full".
        If a verifier is not provided and None is found in the config,
        this parameter is ignored.
    :param config: Optional configuration for the model. Can be a
        SpeculatorModelConfig instance, a path to a config file, or None to load
        from model directory.
    :param cache_dir: Directory to cache downloaded files. If None, uses default
        transformers cache directory.
    :param ignore_mismatched_sizes: Whether to ignore size mismatches when loading
        pretrained weights. Useful for loading models with different architectures.
    :param force_download: Whether to force re-download of model files even if
        they exist in cache.
    :param local_files_only: Whether to avoid downloading files and only use local
        cached files. Raises an error if files are not found locally.
    :param token: Optional authentication token for accessing private models on
        Hugging Face Hub. Can be a string token or True to use saved token.
    :param revision: The specific model revision to load (branch name, tag, or
        commit hash). Defaults to "main".
    :param use_safetensors: Whether to use safetensors format for loading weights.
        If None, automatically detects the available format.
    :param weights_only: Whether to only load model weights without optimizer
        states or other training artifacts.
    :param kwargs: Additional keyword arguments passed to the model constructor
        and loading process.
    :return: A SpeculatorModel instance of the appropriate subclass, loaded with
        the pretrained weights and configuration.
    """
    if not config:
        if not pretrained_model_name_or_path:
            raise ValueError(
                "Either `config` or `pretrained_model_name_or_path` must be "
                "provided to load a SpeculatorModel."
            )
        config = cls.config_class.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    if not isinstance(config, SpeculatorModelConfig):
        # once conversion is added, need to handle the case where a non speculator
        # config is passed in as a kwarg and auto convert
        raise TypeError(
            f"Expected config to be an instance of SpeculatorModelConfig, "
            f"got {type(config)}."
        )

    if not pretrained_model_name_or_path and not kwargs.get("state_dict"):
        raise ValueError(
            "Either `pretrained_model_name_or_path` or `state_dict` must be "
            "provided to load a SpeculatorModel."
        )

    if cls is SpeculatorModel:
        # generic call to from_pretrained on this class, need to resolve the
        # specific model class to use for loading based on the config and registry
        model_class = cls.registered_model_class_from_config(config)
        return model_class.from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            verifier=verifier,
            verifier_attachment_mode=verifier_attachment_mode,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )

    return super().from_pretrained(  # type: ignore[misc]
        pretrained_model_name_or_path,
        *model_args,
        verifier=verifier,
        verifier_attachment_mode=verifier_attachment_mode,
        config=config,
        cache_dir=cache_dir,
        ignore_mismatched_sizes=ignore_mismatched_sizes,
        force_download=force_download,
        local_files_only=local_files_only,
        token=token,
        revision=revision,
        use_safetensors=use_safetensors,
        weights_only=weights_only,
        **kwargs,
    )

from_training_args abstractmethod classmethod

from_training_args(
    verifier_config: PretrainedConfig, **kwargs
) -> SpeculatorModel

Create model instance from training arguments.

This factory method is used by the training script to instantiate models from command-line arguments. Each algorithm must implement this to support the training infrastructure.

Parameters:

  • verifier_config

    (PretrainedConfig) –

    Configuration from the verifier/base model.

  • kwargs

    Training arguments as keyword arguments. Each algorithm extracts the parameters it needs.

Returns:

  • SpeculatorModel

    Initialized model instance ready for training.

Example:

@classmethod
def from_training_args(cls, verifier_config, **kwargs):
    config = MySpeculatorConfig(
        transformer_layer_config=verifier_config,
        num_layers=kwargs['num_layers'],
        ...
    )
    return cls(config=config, t2d=kwargs.get('t2d'), d2t=kwargs.get('d2t'))

Source code in speculators/model.py
@classmethod
@abstractmethod
def from_training_args(
    cls, verifier_config: PretrainedConfig, **kwargs
) -> "SpeculatorModel":
    """Create model instance from training arguments.

    This factory method is used by the training script to instantiate models
    from command-line arguments. Each algorithm must implement this to support
    the training infrastructure.

    Args:
        verifier_config: Configuration from the verifier/base model.
        **kwargs: Training arguments as keyword arguments. Each algorithm
            extracts the parameters it needs.

    Returns:
        Initialized model instance ready for training.

    Example:
        ```python
        @classmethod
        def from_training_args(cls, verifier_config, **kwargs):
            config = MySpeculatorConfig(
                transformer_layer_config=verifier_config,
                num_layers=kwargs['num_layers'],
                ...
            )
            return cls(config=config, t2d=kwargs.get('t2d'), d2t=kwargs.get('d2t'))
        ```
    """
    raise NotImplementedError(
        f"{cls.__name__} must implement from_training_args() classmethod "
        "to support training infrastructure."
    )

generate

generate(
    inputs: Tensor | None = None,
    generation_config: GenerationConfig | None = None,
    logits_processor: LogitsProcessorList | None = None,
    stopping_criteria: StoppingCriteriaList | None = None,
    prefix_allowed_tokens_fn: Callable[
        [int, Tensor], list[int]
    ]
    | None = None,
    synced_gpus: bool | None = None,
    assistant_model: Optional[PreTrainedModel] = None,
    streamer: Optional[BaseStreamer] = None,
    negative_prompt_ids: Tensor | None = None,
    negative_prompt_attention_mask: Tensor | None = None,
    use_model_defaults: bool | None = None,
    custom_generate: str | Callable[..., Any] | None = None,
    **kwargs,
) -> GenerateOutput | torch.LongTensor

Generate text using speculative decoding with the attached verifier model. The method follows the standard transformers generation interface, making it compatible with existing generation workflows while adding speculative decoding capabilities allowing for faster generation.
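
A minimal usage sketch, assuming a concrete subclass with an attached verifier and a tokenizer matching the verifier's vocabulary:

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
outputs = speculator.generate(input_ids, max_length=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))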

Parameters:

  • inputs

    (Tensor | None, default: None ) –

    The input token IDs to generate from. Can be None if input_ids are provided in kwargs.

  • generation_config

    (GenerationConfig | None, default: None ) –

    Configuration for generation parameters like max_length, temperature, top_p, etc. If None, uses model defaults.

  • logits_processor

    (LogitsProcessorList | None, default: None ) –

    List of logits processors to apply during generation for tasks like repetition penalty, top-k filtering, etc.

  • stopping_criteria

    (StoppingCriteriaList | None, default: None ) –

    List of stopping criteria to determine when to stop generation (e.g., max length, end-of-sequence tokens).

  • prefix_allowed_tokens_fn

    (Callable[[int, Tensor], list[int]] | None, default: None ) –

    Function to constrain generation to allowed tokens based on the current prefix. Useful for structured generation.

  • synced_gpus

    (bool | None, default: None ) –

    Whether to synchronize GPUs during distributed generation. Relevant for multi-GPU setups.

  • assistant_model

    (Optional[PreTrainedModel], default: None ) –

    An assistant model to use for generation. This parameter maintains compatibility with transformers but may not be used in speculative decoding.

  • streamer

    (Optional[BaseStreamer], default: None ) –

    A streamer to output tokens as they are generated, enabling real-time streaming of the generation process.

  • negative_prompt_ids

    (Tensor | None, default: None ) –

    Token IDs for negative prompting to steer generation away from certain content.

  • negative_prompt_attention_mask

    (Tensor | None, default: None ) –

    Attention mask for negative prompt tokens to properly handle padding.

  • use_model_defaults

    (bool | None, default: None ) –

    Whether to use model-specific default generation parameters instead of transformers defaults.

  • kwargs

    Additional keyword arguments for generation, including input_ids, attention_mask, max_length, etc.

Returns:

  • GenerateOutput | LongTensor

    Generated token sequences as either a GenerateOutput object (containing additional metadata) or a LongTensor of token IDs.

Source code in speculators/model.py
@torch.no_grad()
def generate(
    self,
    inputs: torch.Tensor | None = None,  # noqa: ARG002
    generation_config: GenerationConfig | None = None,  # noqa: ARG002
    logits_processor: LogitsProcessorList | None = None,  # noqa: ARG002
    stopping_criteria: StoppingCriteriaList | None = None,  # noqa: ARG002
    prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]]  # noqa: ARG002
    | None = None,
    synced_gpus: bool | None = None,  # noqa: ARG002
    assistant_model: Optional["PreTrainedModel"] = None,  # type: ignore[override]  # noqa: ARG002
    streamer: Optional["BaseStreamer"] = None,  # noqa: ARG002
    negative_prompt_ids: torch.Tensor | None = None,  # noqa: ARG002
    negative_prompt_attention_mask: torch.Tensor | None = None,  # noqa: ARG002
    use_model_defaults: bool | None = None,  # noqa: ARG002
    custom_generate: str | Callable[..., Any] | None = None,  # noqa: ARG002
    **kwargs,  # noqa: ARG002
) -> GenerateOutput | torch.LongTensor:
    """
    Generate text using speculative decoding with the attached verifier model.
    The method follows the standard transformers generation interface, making it
    compatible with existing generation workflows while adding speculative
    decoding capabilities allowing for faster generation.

    :param inputs: The input token IDs to generate from. Can be None if input_ids
        are provided in kwargs.
    :param generation_config: Configuration for generation parameters like
        max_length, temperature, top_p, etc. If None, uses model defaults.
    :param logits_processor: List of logits processors to apply during generation
        for tasks like repetition penalty, top-k filtering, etc.
    :param stopping_criteria: List of stopping criteria to determine when to
        stop generation (e.g., max length, end-of-sequence tokens).
    :param prefix_allowed_tokens_fn: Function to constrain generation to allowed
        tokens based on the current prefix. Useful for structured generation.
    :param synced_gpus: Whether to synchronize GPUs during distributed generation.
        Relevant for multi-GPU setups.
    :param assistant_model: An assistant model to use for generation. This
        parameter maintains compatibility with transformers but may not be
        used in speculative decoding.
    :param streamer: A streamer to output tokens as they are generated, enabling
        real-time streaming of the generation process.
    :param negative_prompt_ids: Token IDs for negative prompting to steer
        generation away from certain content.
    :param negative_prompt_attention_mask: Attention mask for negative prompt
        tokens to properly handle padding.
    :param use_model_defaults: Whether to use model-specific default generation
        parameters instead of transformers defaults.
    :param kwargs: Additional keyword arguments for generation, including
        input_ids, attention_mask, max_length, etc.
    :return: Generated token sequences as either a GenerateOutput object
        (containing additional metadata) or a LongTensor of token IDs.
    """
    if self.verifier is None:
        raise ValueError(
            "Verifier model is not attached. Please attach a verifier model "
            "before calling generate."
        )

    raise NotImplementedError(
        "The generate method for speculator models is not implemented yet."
    )

get_trainer_kwargs abstractmethod staticmethod

get_trainer_kwargs(**kwargs) -> tuple[dict, dict]

Get algorithm-specific kwargs for training and validation.

This method extracts algorithm-specific parameters from the training arguments and returns separate kwargs dictionaries for training and validation forward passes.

Parameters:

  • kwargs

    Training arguments containing algorithm-specific parameters.

Returns:

  • tuple[dict, dict]

    A tuple of (train_kwargs, val_kwargs): train_kwargs is the dict passed to model.forward() during training; val_kwargs is the dict passed to model.forward() during validation.

Example:

@staticmethod
def get_trainer_kwargs(**kwargs):
    train_kwargs = {
        "num_steps": kwargs["num_steps"],
        "use_special_mode": True,
    }
    val_kwargs = {
        "num_steps": kwargs["num_steps"],
        "use_special_mode": False,
    }
    return train_kwargs, val_kwargs

Source code in speculators/model.py
@staticmethod
@abstractmethod
def get_trainer_kwargs(**kwargs) -> tuple[dict, dict]:
    """Get algorithm-specific kwargs for training and validation.

    This method extracts algorithm-specific parameters from the training
    arguments and returns separate kwargs dictionaries for training and
    validation forward passes.

    Args:
        **kwargs: Training arguments containing algorithm-specific parameters.

    Returns:
        Tuple of (train_kwargs, val_kwargs) where:
            - train_kwargs: Dict passed to model.forward() during training
            - val_kwargs: Dict passed to model.forward() during validation

    Example:
        ```python
        @staticmethod
        def get_trainer_kwargs(**kwargs):
            train_kwargs = {
                "num_steps": kwargs["num_steps"],
                "use_special_mode": True,
            }
            val_kwargs = {
                "num_steps": kwargs["num_steps"],
                "use_special_mode": False,
            }
            return train_kwargs, val_kwargs
        ```
    """
    raise NotImplementedError(
        "Model must implement get_trainer_kwargs() staticmethod "
        "to support training infrastructure."
    )

registered_model_class_from_config classmethod

registered_model_class_from_config(
    config: SpeculatorModelConfig,
) -> type[SpeculatorModel]

Looks up the appropriate speculator model class from the registry based on the configuration type. It matches the config class to the corresponding model class that was registered during auto-discovery or manual registration.
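
A short resolution sketch, assuming SpeculatorModelConfig.from_pretrained returns the concrete config subclass; the checkpoint path is hypothetical:

config = SpeculatorModelConfig.from_pretrained("path/to/speculator")
model_class = SpeculatorModel.registered_model_class_from_config(config)
model = model_class.from_pretrained("path/to/speculator", config=config)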

Parameters:

  • config

    (SpeculatorModelConfig) –

    The configuration for which to find the registered model class. Must be an instance of a SpeculatorModelConfig subclass.

Returns:

  • type[SpeculatorModel]

    The registered model class that matches the configuration type.

Source code in speculators/model.py
@classmethod
def registered_model_class_from_config(
    cls, config: SpeculatorModelConfig
) -> type["SpeculatorModel"]:
    """
    Looks up the appropriate speculator model class from the registry
    based on the configuration type. It matches the config class to the
    corresponding model class that was registered during auto-discovery or manual
    registration.

    :param config: The configuration for which to find the registered model class.
        Must be an instance of a SpeculatorModelConfig subclass.
    :return: The registered model class that matches the configuration type.
    """
    if not isinstance(config, SpeculatorModelConfig):
        raise TypeError(
            f"Expected config to be an instance of SpeculatorModelConfig, "
            f"got {type(config)} {config}."
        )

    if type(config) is SpeculatorModelConfig:
        raise TypeError(
            "Received a SpeculatorModelConfig instance but expected a subclass. "
            "Use the specific subclass of SpeculatorModelConfig instead. "
            f"Received: {type(config)} {config}"
        )

    if not cls.registry:
        raise ValueError(
            "No registered model classes found. "
            "Ensure that models are registered with "
            "`SpeculatorModel.register(NAME)` or that auto-discovery is enabled."
        )

    for _, model_class in cls.registry.items():
        model_config_class: type[SpeculatorModelConfig] = model_class.config_class

        if type(config) is model_config_class:
            return model_class

    raise ValueError(
        f"No registered model class found for config type {type(config)}. "
        f"Available registered model classes: {list(cls.registry.keys())}."
    )

resolve_verifier

resolve_verifier(
    verifier: str | PathLike | PreTrainedModel,
) -> PreTrainedModel

Resolves the verifier model from a given path or identifier.

This method loads the verifier model from a specified path or identifier, ensuring it is compatible with the speculator's configuration. If the verifier is already attached, it returns the existing verifier instance.
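
For illustration, each of the accepted input types resolves as follows; the identifiers and existing_model are examples only:

verifier = speculator.resolve_verifier("meta-llama/Llama-2-7b-hf")  # Hub ID, loaded via AutoModelForCausalLM
verifier = speculator.resolve_verifier("./local/verifier")  # local directory, loaded the same way
verifier = speculator.resolve_verifier(existing_model)  # PreTrainedModel instance, returned as-is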

Parameters:

  • verifier

    (str | PathLike | PreTrainedModel) –

    The verifier model to resolve. Can be a path to a local model directory, a Hugging Face model identifier, or an instance of PreTrainedModel.

Returns:

  • PreTrainedModel

    The resolved PreTrainedModel instance for the verifier.

Source code in speculators/model.py
def resolve_verifier(
    self, verifier: str | os.PathLike | PreTrainedModel
) -> PreTrainedModel:
    """
    Resolves the verifier model from a given path or identifier.

    This method loads the verifier model from a specified path or identifier,
    ensuring it is compatible with the speculator's configuration. If the
    verifier is already attached, it returns the existing verifier instance.

    :param verifier: The verifier model to resolve. Can be a path to a local
        model directory, a Hugging Face model identifier, or an instance of
        PreTrainedModel.
    :return: The resolved PreTrainedModel instance for the verifier.
    """
    if not verifier:
        raise ValueError(
            "Verifier must be provided as a path, identifier, or PreTrainedModel. "
        )

    if not isinstance(verifier, (str, os.PathLike, PreTrainedModel)):
        raise TypeError(
            f"Expected verifier to be a PreTrainedModel, a string path, "
            f"or an os.PathLike object, got {type(verifier)} {verifier}."
        )

    if isinstance(verifier, PreTrainedModel):
        return verifier

    return AutoModelForCausalLM.from_pretrained(verifier)

state_dict

state_dict(
    *,
    destination: dict[str, Any] = None,
    prefix: str = "",
    keep_vars: bool = False,
)

Overrides PyTorch's state_dict method to ensure that save pathways within Transformers' PreTrainedModel do not include the verifier model's parameters. This is important so that the speculator model can be saved and loaded without including the verifier's state, which is expected to be managed separately.
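
A quick sketch of the effect, assuming a speculator with an attached verifier; the save path is hypothetical:

# Only the speculator's own parameters are returned; the verifier's
# weights are excluded, so save_pretrained does not duplicate them.
state = speculator.state_dict()
speculator.save_pretrained("path/to/save")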

Parameters:

  • destination

    (dict[str, Any], default: None ) –

    Optional dictionary to store the state.

  • prefix

    (str, default: '' ) –

    Optional prefix for parameter names.

  • keep_vars

    (bool, default: False ) –

    Whether to keep Variables in the state_dict.

Returns:

  • A dictionary containing the state of the speculator model, excluding the verifier model's parameters. This dictionary can be used to save the model's state to disk or for further processing.

Source code in speculators/model.py
def state_dict(
    self,
    *,
    destination: dict[str, Any] = None,  # type: ignore[assignment]
    prefix: str = "",
    keep_vars: bool = False,
):
    """
    Overrides the state_dict method from PyTorch to ensure that save pathways
    within Transformers PreTrainedModel do not include the verifier model's
    parameters. This is important to ensure that the speculator model
    can be saved and loaded without including the verifier's state, which
    is expected to be managed separately.

    :param destination: Optional dictionary to store the state.
    :param prefix: Optional prefix for parameter names.
    :param keep_vars: Whether to keep Variables in the state_dict.
    :return: A dictionary containing the state of the speculator model,
        excluding the verifier model's parameters. This dictionary can be used
        to save the model's state to disk or for further processing.
    """
    tmp_verifier = self.verifier
    self.verifier = None
    state = super().state_dict(  # type: ignore[misc]
        destination=destination, prefix=prefix, keep_vars=keep_vars
    )
    self.verifier = tmp_verifier

    return state

verify_training_compatible classmethod

verify_training_compatible(model: SpeculatorModel) -> None

Verify that a model instance is compatible with training infrastructure.

This method validates that the given model:

 1. Is an instance of SpeculatorModel
 2. Is registered in the SpeculatorModel registry
 3. Has a layers attribute (required for FSDP wrapping)

Parameters:

  • model

    (SpeculatorModel) –

    The model instance to verify.

Raises:

  • TypeError – If model is not a SpeculatorModel instance.
  • ValueError – If the model's class is not in the registry.
  • AttributeError – If the model doesn't have a layers attribute.
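
A brief usage sketch; the checkpoint path is hypothetical:

model = SpeculatorModel.from_pretrained("path/to/speculator")
SpeculatorModel.verify_training_compatible(model)  # raises if incompatible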

Source code in speculators/model.py
@classmethod
def verify_training_compatible(cls, model: "SpeculatorModel") -> None:
    """Verify that a model instance is compatible with training infrastructure.

    This method validates that the given model is:
    1. An instance of SpeculatorModel
    2. Registered in the SpeculatorModel registry
    3. Has a `layers` attribute (required for FSDP wrapping)

    Args:
        model: The model instance to verify

    Raises:
        TypeError: If model is not a SpeculatorModel instance
        ValueError: If model's class is not in the registry
        AttributeError: If model doesn't have a `layers` attribute
    """
    if not isinstance(model, SpeculatorModel):
        raise TypeError(
            f"Model must be a SpeculatorModel, got {type(model).__name__}"
        )

    model_class = type(model)
    registry = cls.registry
    if registry is None or model_class not in registry.values():
        raise ValueError(
            f"Model {model_class.__name__} is not registered in "
            f"SpeculatorModel.registry. "
            f"Available models: {list(registry.keys()) if registry else []}"
        )

    if not hasattr(model, "layers"):
        raise AttributeError(
            f"Model {model_class.__name__} must have a 'layers' attribute "
            f"containing decoder layers for FSDP wrapping"
        )

reload_and_populate_models

reload_and_populate_models()

Triggers the automatic discovery and registration of all SpeculatorModel subclasses found in the speculators.models package that have been registered with SpeculatorModel.register(NAME). This enables dynamic model loading and instantiation based on configuration types without requiring explicit imports.
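
A short sketch of the typical call pattern; the registry contents depend on which model implementations are installed:

from speculators.model import SpeculatorModel, reload_and_populate_models

reload_and_populate_models()
print(list(SpeculatorModel.registry.keys()))  # names of registered speculators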

Source code in speculators/model.py
def reload_and_populate_models():
    """
    Triggers the automatic discovery and registration of all
    SpeculatorModel subclasses found in the speculators.models package
    that have been registered with `SpeculatorModel.register(NAME)`. This
    enables dynamic model loading and instantiation based on configuration
    types without requiring explicit imports.
    """
    SpeculatorModel.auto_populate_registry()