speculators.data_generation.custom_worker

Custom worker extension for hidden states capture.

Classes:

HiddenStatesWorkerExtension

Worker extension that adds hidden states capture functionality.

This extension hooks into vLLM's worker initialization when it is specified in ParallelConfig.worker_extension_cls. It patches the model's forward pass to intercept and capture intermediate-layer hidden states during inference.
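The forward-pass patching idea can be illustrated without vLLM at all. The sketch below is a minimal, hypothetical stand-in, not the extension's actual code: `ToyModel`, `patch_forward_for_capture`, and the list-of-lists "tensors" are all assumptions made so the example runs in plain Python. It replaces a model's `forward` with a version that runs the same layers but records the output of each layer whose index is in a frozenset.

```python
from typing import Callable, Dict, FrozenSet, List

Vector = List[float]  # stand-in for one token's hidden state (hypothetical)


class ToyModel:
    """Hypothetical stand-in for a decoder: a stack of layer callables."""

    def __init__(self, layers: List[Callable[[List[Vector]], List[Vector]]]):
        self.layers = layers

    def forward(self, hidden: List[Vector]) -> List[Vector]:
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden


def patch_forward_for_capture(model: ToyModel, layer_ids: FrozenSet[int],
                              captured: Dict[int, List[List[Vector]]]) -> None:
    """Replace model.forward with a version that records selected layer outputs."""

    def patched_forward(hidden: List[Vector]) -> List[Vector]:
        for idx, layer in enumerate(model.layers):
            hidden = layer(hidden)
            if idx in layer_ids:  # O(1) membership test on a frozenset
                # Copy rows so later layers cannot mutate the captured values.
                captured.setdefault(idx, []).append([row[:] for row in hidden])
        return hidden

    # The instance attribute shadows the class method; layer logic is unchanged.
    model.forward = patched_forward


def add_one(hidden: List[Vector]) -> List[Vector]:
    """Toy layer: adds 1.0 to every element."""
    return [[x + 1.0 for x in row] for row in hidden]


model = ToyModel([add_one, add_one, add_one])
captured: Dict[int, List[List[Vector]]] = {}
patch_forward_for_capture(model, frozenset({0, 2}), captured)
out = model.forward([[0.0, 0.0]])
# out == [[3.0, 3.0]]
# captured == {0: [[[1.0, 1.0]]], 2: [[[3.0, 3.0]]]}
```

Patching `forward` rather than editing each layer keeps the layers themselves untouched, so the capture can be installed after the model is built.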

Key behaviors:

- Only captures on tensor parallel (TP) rank 0 to avoid duplicate data when using tensor parallelism. All TP ranks compute the same hidden states, so capturing from rank 0 is sufficient.
- Stores captured states in GPU memory during batch processing as lists of tensors, concatenating them only when retrieved via _get_captured_states().
- Supports pipeline parallelism by handling IntermediateTensors correctly.
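The three behaviors above (rank-0 gating, per-batch lists with deferred concatenation, and pipeline-parallel payloads) can be sketched together as follows. This is a hypothetical illustration in plain Python, not the extension's implementation: `HiddenStateCapture`, `IntermediateTensorsStub`, and `get_captured_states` are invented names, lists of rows stand in for GPU tensors, and list flattening stands in for concatenation along the token dimension. The page does not say how IntermediateTensors is handled, so the sketch simply assumes the hidden states sit in a dict under a "hidden_states" key.

```python
from dataclasses import dataclass, field
from typing import Dict, FrozenSet, List, Union

Row = List[float]
Batch = List[Row]  # stand-in for a [num_tokens, hidden_size] tensor (hypothetical)


@dataclass
class IntermediateTensorsStub:
    """Hypothetical stand-in for a pipeline-parallel payload of named tensors."""
    tensors: Dict[str, Batch]


@dataclass
class HiddenStateCapture:
    """Sketch of rank-gated capture with concatenation deferred to retrieval."""
    layer_ids: FrozenSet[int]
    tp_rank: int
    _captured: Dict[int, List[Batch]] = field(default_factory=dict, init=False)

    def record(self, layer_idx: int,
               output: Union[Batch, IntermediateTensorsStub]) -> None:
        # All TP ranks hold identical hidden states, so only rank 0 records.
        if self.tp_rank != 0 or layer_idx not in self.layer_ids:
            return
        # Assumption: pipeline-parallel payloads carry a "hidden_states" entry.
        if isinstance(output, IntermediateTensorsStub):
            output = output.tensors["hidden_states"]
        # Append one entry per batch; nothing is concatenated yet.
        self._captured.setdefault(layer_idx, []).append(output)

    def get_captured_states(self) -> Dict[int, Batch]:
        # Concatenate along the token dimension only when results are requested.
        return {idx: [row for batch in batches for row in batch]
                for idx, batches in self._captured.items()}


rank0 = HiddenStateCapture(frozenset({1}), tp_rank=0)
rank1 = HiddenStateCapture(frozenset({1}), tp_rank=1)
for cap in (rank0, rank1):
    cap.record(1, [[0.1, 0.2]])
    cap.record(1, IntermediateTensorsStub({"hidden_states": [[0.3, 0.4]]}))
    cap.record(0, [[9.9, 9.9]])  # layer 0 is not selected, so it is skipped
states0 = rank0.get_captured_states()
states1 = rank1.get_captured_states()
# states0 == {1: [[0.1, 0.2], [0.3, 0.4]]}
# states1 == {}
```

Deferring concatenation means each forward pass only pays for a list append; the cost of building one contiguous result is paid once, at retrieval time.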

Attributes:

- _layer_ids: Frozenset of layer indices, giving O(1) membership lookup during capture.
- _captured_states: Accumulated hidden states per layer (GPU tensors).
- model_runner: Reference to the vLLM model runner.