Diff Coverage

Diff: origin/master...HEAD, staged and unstaged changes

Source File Diff Coverage (%) Missing Lines
hyper_parallel/infer/__init__.py 0.0% 17-18,25-26,33,44
hyper_parallel/infer/generation.py 0.0% 16-17,19-20,22,27-28,39,41-43,46,57,66-72,76-80,83-87,91,97-101,104-105,108,110-112,115-119,122,124-126,129-133,136,138-151,154-157,160,166-174,180-184,187,189-195,197,201-203,206,208-221,226,228-237,240,245-250,258-259,262,266-270,273,280-286,292-294,302,306-309,313-314,317,323-327,330,332-333,339,342,350-354,356,359-369,372,376-381,384,390-391,396,405,412-414,419-423,428-429,452,458,463,469-470,482-486,489,495-496,500-502,505,507-508,513-514,517,521-522,525,528,535-539,542,549,553-554,560,573,578,584,596-597,605-610,612-617,623-625,630-633,635-638,640,648-649
hyper_parallel/infer/kv_cache.py 0.0% 16-17,19,21,24,26-31,34,36-43,46-47,50-54,56-57,59,62,68-78,87,94-96,98-104,111,114,117-118,120-121,123,127,129,131,133-139,141,143-164,168,170-172,174-176,178-180,185,188-196,198,200-202,207-208,210,216-222,224,230-236,241,246-251,256-258,263-268,270,272-273,275-277,279-283
hyper_parallel/infer/mixin.py 0.0% 17,20,23,25
hyper_parallel/infer/sampler.py 0.0% 16-17,19,22-27,30,32-34,37,43-48,51,57-69,72,78-89,92,98,103-109
hyper_parallel/infer/utils.py 0.0% 16-17,19-20,23-24,27-42,44-51,53-62,64,66-75,77,79-85,87-92,94,96-99,102-105,109,111-112,114-119,122,127-131,134-137,140,142-152,157-167,170,172-182,187-191,195-197,200,202-203,206,212-219,222,228-238,241,247-251,254,258-265,268,273-275,280
hyper_parallel/infer/__init__.py
13
14
15
16
17
18
19
20
21
22
# limitations under the License.
# ============================================================================
"""Autoregressive generation utilities."""

from hyper_parallel.infer.generation import generate
from hyper_parallel.infer.kv_cache import (
    ContextParallelKVCache,
    KVCache,
    SequenceShardInfo,
    get_sequence_shard_info,
21
22
23
24
25
26
27
28
29
30
    SequenceShardInfo,
    get_sequence_shard_info,
    shard_past_key_values,
)
from hyper_parallel.infer.mixin import GenerateMixin
from hyper_parallel.infer.sampler import (
    apply_repetition_penalty,
    greedy_sample,
    sample_next_token,
    top_k_sample,
29
30
31
32
33
34
35
36
37
    sample_next_token,
    top_k_sample,
    top_p_sample,
)
from hyper_parallel.infer.utils import (
    GenerationConfig,
    apply_logits_processors,
    build_causal_mask,
    build_position_ids,
40
41
42
43
44
45
46
47
48
    prepare_logits_for_sampling,
    should_stop_generation,
)

__all__ = [
    "GenerationConfig",
    "GenerateMixin",
    "KVCache",
    "ContextParallelKVCache",
hyper_parallel/infer/generation.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Prefill + decode generation loop."""
import inspect
from typing import Optional

import torch
import torch.distributed as dist

from hyper_parallel.infer.kv_cache import (
    ContextParallelKVCache,
    KVCache,
    detach_and_validate_past_key_values,
)
from hyper_parallel.infer.sampler import sample_next_token
from hyper_parallel.infer.utils import (
    GenerationConfig,
    append_attention_mask,
    apply_logits_processors,
    build_causal_mask,
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
    should_stop_generation,
)


def _get_output(outputs, name: str):
    """Fetch field ``name`` from a model output.

    ``dict`` instances (including subclasses) are read with ``.get``; any
    other object is read via attribute access. A missing field yields ``None``
    in both cases.
    """
    if not isinstance(outputs, dict):
        return getattr(outputs, name, None)
    return outputs.get(name)


def _model_forward(
    model,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
53
54
55
56
57
58
59
60
61
    sequence_shard_info=None,
    global_seq_len: Optional[int] = None,
):
    """Call model.forward with only the keyword arguments it accepts."""
    kwargs = {
        "input_ids": input_ids,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "past_key_values": past_key_values,
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
        "use_cache": use_cache,
        "sequence_shard_info": sequence_shard_info,
        "global_seq_len": global_seq_len,
    }
    forward = getattr(model, "forward", model)
    try:
        signature = inspect.signature(forward)
    except (TypeError, ValueError):
        return forward(**kwargs)
    parameters = signature.parameters
    accepts_kwargs = any(
        param.kind == inspect.Parameter.VAR_KEYWORD
        for param in parameters.values()
    )
    if not accepts_kwargs:
        for name in list(kwargs):
            if name not in parameters:
                kwargs.pop(name)
    return forward(**kwargs)


def _resolve_context_parallel_rank_world(config: GenerationConfig) -> tuple[int, int]:
    """Return ``(rank, world_size)`` for the context-parallel KV cache.

    An explicit ``config.context_parallel_rank`` takes precedence and must be
    accompanied by ``config.context_parallel_world_size``. Otherwise both
    values are queried from ``torch.distributed`` on
    ``config.context_process_group``.

    Raises:
        ValueError: if an explicit rank is given without a world size, or if
            no explicit rank is given and ``torch.distributed`` is unavailable
            or not initialized.
    """
    if config.context_parallel_rank is not None:
        world_size = config.context_parallel_world_size
        if world_size is None:
            # Without this check a ``None`` world size would be returned and
            # handed to the cache constructor instead of failing here, next
            # to the misconfiguration.
            raise ValueError(
                "context_parallel_rank requires context_parallel_world_size",
            )
        return config.context_parallel_rank, world_size
    if not dist.is_available() or not dist.is_initialized():
        raise ValueError(
            "context_parallel_cache requires initialized torch.distributed "
            "or explicit context_parallel_rank/context_parallel_world_size",
        )
    group = config.context_process_group
    return dist.get_rank(group=group), dist.get_world_size(group=group)
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
        dist.get_world_size(group=config.context_process_group),
    )


def _init_cache(config: GenerationConfig) -> KVCache:
    """Build an empty cache matching ``config``.

    Returns a ``ContextParallelKVCache`` bound to the resolved rank and world
    size when ``config.context_parallel_cache`` is truthy, otherwise a plain
    ``KVCache``.
    """
    if config.context_parallel_cache:
        cp_rank, cp_world = _resolve_context_parallel_rank_world(config)
        return ContextParallelKVCache(rank=cp_rank, world_size=cp_world)
    return KVCache()


def _cache_shard_info(cache: KVCache):
    """Return ``cache.shard_info`` for context-parallel caches, else ``None``."""
    if not isinstance(cache, ContextParallelKVCache):
        return None
    return cache.shard_info


def _cache_seq_len(past_key_values) -> Optional[int]:
    """Return the number of cached positions, or ``None`` when there is no cache.

    An object that exposes ``get_seq_length`` and is not a list/tuple is
    queried directly. Anything else is validated as a sequence of
    ``(key, value)`` pairs, and the length is read from dim ``-2`` of the first
    key. An empty sequence reports ``0``.
    """
    if past_key_values is None:
        return None
    is_opaque = hasattr(past_key_values, "get_seq_length") and not isinstance(
        past_key_values, (list, tuple),
    )
    if is_opaque:
        return int(past_key_values.get_seq_length())
    pairs = detach_and_validate_past_key_values(past_key_values)
    return int(pairs[0][0].shape[-2]) if pairs else 0


def _cache_batch_size(past_key_values) -> Optional[int]:
    """Return the batch size of a tuple-style cache, or ``None`` when unknown.

    ``None`` is returned in three cases: the cache is missing; it is an opaque
    object exposing ``get_seq_length`` (not a list/tuple), whose tensors are
    not inspected; or it holds no layers. Otherwise the size is dim 0 of the
    first layer's key.
    """
    opaque = (
        past_key_values is not None
        and hasattr(past_key_values, "get_seq_length")
        and not isinstance(past_key_values, (list, tuple))
    )
    if past_key_values is None or opaque:
        return None
    pairs = detach_and_validate_past_key_values(past_key_values)
    if pairs:
        return int(pairs[0][0].shape[0])
    return None


def _resolve_prefix_length(config: GenerationConfig) -> int:
    """Determine the global length of the reusable prefix cache.

    Every available length source is collected, in this order:

    - ``prefix_cache_length``;
    - the last dim of ``prefix_attention_mask``;
    - ``prefix_sequence_shard_info.global_seq_len``;
    - the length reported by the cache itself.

    The cache's own length is used only when no shard info is given; a
    sharded cache presumably holds just a local slice. All collected sources
    must agree.

    Returns:
        0 when ``prefix_past_key_values`` is ``None``; otherwise the agreed length.

    Raises:
        ValueError: if no length source is available or the sources disagree.
    """
    if config.prefix_past_key_values is None:
        return 0
    shard_info = config.prefix_sequence_shard_info
    candidates = []
    if config.prefix_cache_length is not None:
        candidates.append(int(config.prefix_cache_length))
    if config.prefix_attention_mask is not None:
        candidates.append(int(config.prefix_attention_mask.shape[-1]))
    if shard_info is not None:
        candidates.append(int(shard_info.global_seq_len))
    # Always inspect the cache so malformed entries are still rejected.
    cached_len = _cache_seq_len(config.prefix_past_key_values)
    if cached_len is not None and shard_info is None:
        candidates.append(cached_len)
    if not candidates:
        raise ValueError(
            "prefix_past_key_values requires prefix_cache_length for opaque caches",
        )
    first, *rest = candidates
    if any(other != first for other in rest):
        raise ValueError("prefix cache length metadata is inconsistent")
    return first


def _prepare_prefix_attention_mask(
    config: GenerationConfig,
    input_ids: torch.Tensor,
    device,
) -> tuple[Optional[torch.Tensor], int]:
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
    input_ids: torch.Tensor,
    device,
) -> tuple[Optional[torch.Tensor], int]:
    """Prepare a 2-D attention mask for reusable prefix cache."""
    prefix_len = _resolve_prefix_length(config)
    if prefix_len == 0:
        return None, 0
    cache_batch_size = _cache_batch_size(config.prefix_past_key_values)
    if cache_batch_size is not None and cache_batch_size != input_ids.size(0):
        raise ValueError("prefix cache batch size must match input_ids batch size")
    prefix_attention_mask = config.prefix_attention_mask
    if prefix_attention_mask is None:
        return torch.ones(
            input_ids.size(0),
            prefix_len,
            device=device,
            dtype=torch.long,
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
            prefix_len,
            device=device,
            dtype=torch.long,
        ), prefix_len
    if prefix_attention_mask.ndim != 2:
        raise ValueError("prefix_attention_mask must have shape (batch, prefix_seq)")
    if prefix_attention_mask.shape != (input_ids.size(0), prefix_len):
        raise ValueError("prefix_attention_mask batch/sequence length mismatch")
    return prefix_attention_mask.to(device=device), prefix_len


def _init_cache_with_prefix(config: GenerationConfig) -> tuple[KVCache, int]:
    """Create the generation cache and seed it with any reusable prefix.

    Returns:
        ``(cache, prefix_len)``. When ``prefix_len`` is 0 the cache is left
        untouched.
    """
    cache = _init_cache(config)
    prefix_len = _resolve_prefix_length(config)
    if prefix_len != 0:
        prefix = config.prefix_past_key_values
        if not isinstance(cache, ContextParallelKVCache):
            cache.update(prefix)
        elif config.prefix_sequence_shard_info is not None:
            # The prefix is already a local shard described by the shard info.
            cache.update_local(prefix, config.prefix_sequence_shard_info)
        else:
            # No shard metadata: pass the unsharded prefix to update_full.
            cache.update_full(prefix)
    return cache, prefix_len


def _update_cache(cache: KVCache, outputs) -> None:
    """Store the ``past_key_values`` returned by a forward pass in ``cache``.

    A plain cache always receives the value, even ``None``. A
    context-parallel cache behaves as follows:

    - a missing value is ignored;
    - a local-shard update is used when the model reports ``sequence_shard_info``;
    - a full update is used only while the cache is still empty.

    Raises:
        ValueError: for a non-empty context-parallel cache when the model
            output has no ``sequence_shard_info``.
    """
    new_kv = _get_output(outputs, "past_key_values")
    if isinstance(cache, ContextParallelKVCache):
        if new_kv is None:
            return
        shard_info = _get_output(outputs, "sequence_shard_info")
        if shard_info is not None:
            cache.update_local(new_kv, shard_info)
        elif cache.is_empty:
            cache.update_full(new_kv)
        else:
            raise ValueError(
                "context-parallel cached decode requires model output sequence_shard_info",
            )
        return
    cache.update(new_kv)


def _resolve_mask_dtype(model, config: GenerationConfig) -> torch.dtype:
    """Pick the dtype used for additive attention masks.

    ``config.mask_dtype`` wins when it is set. Otherwise the dtype comes from
    the first floating-point tensor yielded by ``model.parameters()``, then
    ``model.buffers()``; either method is skipped when the model lacks it.
    Falls back to ``torch.float32``.
    """
    if config.mask_dtype is not None:
        return config.mask_dtype
    # Lazily look up each accessor so ``buffers`` is only touched if needed.
    accessors = (getattr(model, attr, None) for attr in ("parameters", "buffers"))
    for accessor in accessors:
        if accessor is None:
            continue
        first_float = next((t for t in accessor() if t.is_floating_point()), None)
        if first_float is not None:
            return first_float.dtype
    return torch.float32


def _build_decode_key_mask(
    attention_mask: Optional[torch.Tensor],
    dtype: torch.dtype,
) -> Optional[torch.Tensor]:
    """Build additive key padding mask for one-token cached decode."""
    if attention_mask is None:
        return None
    if attention_mask.ndim != 2:
        raise ValueError("attention_mask must have shape (batch, seq)")
    batch_size, seq_len = attention_mask.shape
    mask = torch.zeros(
        batch_size,
        1,
        1,
        seq_len,
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
        seq_len,
        device=attention_mask.device,
        dtype=dtype,
    )
    padding = attention_mask == 0
    return mask.masked_fill(padding.view(batch_size, 1, 1, seq_len), float("-inf"))


def _combined_attention_mask(
    prefix_attention_mask: Optional[torch.Tensor],
    attention_mask: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    """Join the prefix and current attention masks along the last dimension.

    If either mask is ``None``, the other is returned as-is (possibly
    ``None``). Otherwise the prefix mask is placed before the current mask.
    """
    present = [m for m in (prefix_attention_mask, attention_mask) if m is not None]
    if len(present) < 2:
        return present[0] if present else None
    return torch.cat(present, dim=-1)


def _build_prefill_mask(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    prefix_attention_mask: Optional[torch.Tensor],
    dtype: torch.dtype,
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
    prefix_attention_mask: Optional[torch.Tensor],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Build the prefill causal mask, including optional prefix keys."""
    if prefix_attention_mask is None:
        return build_causal_mask(input_ids, attention_mask, dtype=dtype)
    batch_size, query_len = input_ids.shape
    prefix_len = prefix_attention_mask.shape[-1]
    device = input_ids.device
    if attention_mask is None:
        attention_mask = torch.ones(
            batch_size,
            query_len,
            device=device,
            dtype=prefix_attention_mask.dtype,
288
289
290
291
292
293
294
295
296
297
298
            query_len,
            device=device,
            dtype=prefix_attention_mask.dtype,
        )
    if attention_mask.shape != input_ids.shape:
        raise ValueError("attention_mask must match input_ids shape")
    mask = torch.zeros(
        batch_size,
        1,
        query_len,
        prefix_len + query_len,
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
        prefix_len + query_len,
        device=device,
        dtype=dtype,
    )
    causal = torch.triu(
        torch.full((query_len, query_len), float("-inf"), device=device, dtype=dtype),
        diagonal=1,
    )
    mask[:, :, :, prefix_len:] = causal.view(1, 1, query_len, query_len)
    current_padding = attention_mask == 0
    current_key_padding = torch.cat([prefix_attention_mask == 0, current_padding], dim=-1)
    mask = mask.masked_fill(
        current_key_padding.view(batch_size, 1, 1, prefix_len + query_len),
        float("-inf"),
    )
    mask = mask.masked_fill(current_padding.view(batch_size, 1, query_len, 1), 0.0)
    return mask


def _build_prefill_position_ids(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    prefix_attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
    attention_mask: Optional[torch.Tensor],
    prefix_attention_mask: Optional[torch.Tensor],
) -> torch.Tensor:
    """Build position ids for prefill with optional prefix offset."""
    position_ids = build_position_ids(input_ids, attention_mask)
    if prefix_attention_mask is None:
        return position_ids
    prefix_lengths = prefix_attention_mask.long().sum(dim=-1).view(-1, 1)
    return position_ids + prefix_lengths


def _prompt_lengths(input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor]):
    """Count valid prompt tokens per batch row."""
    if attention_mask is None:
        return torch.full(
            (input_ids.size(0),),
            input_ids.size(1),
            device=input_ids.device,
            dtype=torch.long,
335
336
337
338
339
340
341
342
343
344
345
346
            input_ids.size(1),
            device=input_ids.device,
            dtype=torch.long,
        )
    return attention_mask.long().sum(dim=-1)


def _finalize_sequences(
    sequences: torch.Tensor,
    initial_attention_mask: Optional[torch.Tensor],
    prompt_lengths: torch.Tensor,
    generated_counts: torch.Tensor,
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
    generated_counts: torch.Tensor,
    pad_token_id: int,
) -> torch.Tensor:
    """Strip left padding and right-pad finalized generated sequences."""
    rows = []
    max_len = 0
    for batch_idx in range(sequences.size(0)):
        if initial_attention_mask is None:
            start = 0
        else:
            starts = torch.nonzero(
                initial_attention_mask[batch_idx].bool(), as_tuple=False,
            )
            if starts.numel() == 0:
                raise ValueError("attention_mask row must contain at least one valid token")
            start = int(starts[0].item())
        total_len = int(prompt_lengths[batch_idx].item() + generated_counts[batch_idx].item())
        row = sequences[batch_idx, start:start + total_len]
        rows.append(row)
        max_len = max(max_len, row.numel())
    output = sequences.new_full((len(rows), max_len), pad_token_id)
    for idx, row in enumerate(rows):
        output[idx, :row.numel()] = row
    return output


def _validate_generate_inputs(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
) -> None:
    """Reject malformed ``generate`` inputs before any model call.

    Raises:
        ValueError: if ``input_ids`` is not 2-D, if ``attention_mask`` does not
            have the same shape, or if any mask row sums to zero (for a 0/1
            mask, a row with no valid token).
    """
    if input_ids.ndim != 2:
        raise ValueError("input_ids must have shape (batch, seq)")
    if attention_mask is None:
        return
    if attention_mask.shape != input_ids.shape:
        raise ValueError("attention_mask must match input_ids shape")
    valid_per_row = attention_mask.long().sum(dim=-1)
    if torch.any(valid_per_row == 0):
        raise ValueError("attention_mask rows must contain at least one valid token")


def _finalize_zero_new_tokens(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    config: GenerationConfig,
) -> torch.Tensor:
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
    attention_mask: Optional[torch.Tensor],
    config: GenerationConfig,
) -> torch.Tensor:
    """Finalize left-padded prompts when no new tokens are requested."""
    prompt_lengths = _prompt_lengths(input_ids, attention_mask)
    generated_counts = torch.zeros(
        input_ids.size(0),
        device=input_ids.device,
        dtype=torch.long,
    )
    return _finalize_sequences(
        input_ids.clone(),
        initial_attention_mask=attention_mask,
        prompt_lengths=prompt_lengths,
        generated_counts=generated_counts,
401
402
403
404
405
406
407
408
409
        pad_token_id=config.pad_token_id,
    )


def _prepare_generation_context(
    model,
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    config: GenerationConfig,
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
    attention_mask: Optional[torch.Tensor],
    config: GenerationConfig,
):
    """Create the mutable generation context used by the decode loop."""
    mask_dtype = _resolve_mask_dtype(model, config)
    sequences = input_ids.clone()
    prefix_attention_mask, prefix_len = _prepare_prefix_attention_mask(
        config,
        input_ids,
        input_ids.device,
    )
    current_attention_mask = attention_mask.clone() if attention_mask is not None else None
    if prefix_attention_mask is not None and current_attention_mask is None:
        current_attention_mask = torch.ones_like(sequences, dtype=torch.long)
    prompt_lengths = _prompt_lengths(input_ids, current_attention_mask)
    prefix_valid_lengths = (
        prefix_attention_mask.long().sum(dim=-1)
        if prefix_attention_mask is not None
        else torch.zeros(input_ids.size(0), device=input_ids.device, dtype=torch.long)
    )
    cache, prefix_len = _init_cache_with_prefix(config)
    return {
        "mask_dtype": mask_dtype,
        "sequences": sequences,
        "prefix_attention_mask": prefix_attention_mask,
        "current_attention_mask": current_attention_mask,
448
449
450
451
452
453
454
455
456
        "prefix_len": prefix_len,
    }


def _prefill(
    model,
    config: GenerationConfig,
    context: dict,
):
454
455
456
457
458
459
460
461
462
463
464
465
466
467
    config: GenerationConfig,
    context: dict,
):
    """Run the initial full-prompt forward pass."""
    position_ids = _build_prefill_position_ids(
        context["sequences"],
        context["current_attention_mask"],
        context["prefix_attention_mask"],
    )
    attention_mask = _build_prefill_mask(
        context["sequences"],
        context["current_attention_mask"],
        context["prefix_attention_mask"],
        dtype=context["mask_dtype"],
465
466
467
468
469
470
471
472
473
474
        context["current_attention_mask"],
        context["prefix_attention_mask"],
        dtype=context["mask_dtype"],
    )
    cache = context["cache"]
    return _model_forward(
        model,
        input_ids=context["sequences"],
        position_ids=position_ids,
        attention_mask=attention_mask,
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
        global_seq_len=context["prefix_len"] + context["sequences"].shape[-1],
    )


def _required_logits(outputs) -> torch.Tensor:
    """Return the ``logits`` field of a model output.

    Raises:
        ValueError: if the output has no (or a ``None``) ``logits`` field.
    """
    logits = _get_output(outputs, "logits")
    if logits is not None:
        return logits
    raise ValueError("model output must contain logits")


def _finalize_prefill_outputs(
    config: GenerationConfig,
    context: dict,
    outputs,
) -> tuple[torch.Tensor, bool]:
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
    context: dict,
    outputs,
) -> tuple[torch.Tensor, bool]:
    """Validate prefill output and decide whether cached decode can be used."""
    logits = _required_logits(outputs)
    if (
        config.prefix_past_key_values is not None
        and _get_output(outputs, "past_key_values") is None
    ):
        raise ValueError("prefix_past_key_values requires model to return past_key_values")
    _update_cache(context["cache"], outputs)
    return logits, config.use_cache and not context["cache"].is_empty


def _append_next_token(context: dict, next_tokens: torch.Tensor, config: GenerationConfig):
    """Append one sampled token per row and advance per-row bookkeeping.

    When ``config.eos_token_id`` is set, rows that had already finished get
    ``config.pad_token_id`` instead of the sampled token. Rows whose appended
    token equals the EOS id are then marked finished. ``generated_counts``
    grows only for rows that were unfinished before this step.

    Args:
        context: mutable generation state. Reads and rewrites ``"sequences"``,
            ``"generated_counts"``, ``"current_attention_mask"`` and
            ``"unfinished"``.
        next_tokens: sampled tokens; presumably ``(batch, 1)`` since they are
            concatenated on the last dim and squeezed on it (TODO confirm).
        config: generation settings.

    Returns:
        The tokens actually appended, after pad substitution.
    """
    eos_id = config.eos_token_id
    if eos_id is not None:
        still_running = context["unfinished"].view(-1, 1)
        pad_fill = torch.full_like(next_tokens, config.pad_token_id)
        next_tokens = torch.where(still_running, next_tokens, pad_fill)
    context["sequences"] = torch.cat((context["sequences"], next_tokens), dim=-1)
    # Count this step for every row that was unfinished before EOS is checked.
    context["generated_counts"] = context["generated_counts"] + context["unfinished"].long()
    context["current_attention_mask"] = append_attention_mask(
        context["current_attention_mask"], next_tokens,
    )
    if eos_id is not None:
        not_eos = next_tokens.squeeze(-1) != eos_id
        context["unfinished"] = context["unfinished"] & not_eos
    return next_tokens


def _should_finish_generation(
    context: dict,
    logits: torch.Tensor,
    config: GenerationConfig,
    step: int,
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
    config: GenerationConfig,
    step: int,
) -> bool:
    """Check EOS, custom stopping criteria, and max token limit."""
    if config.eos_token_id is not None and not context["unfinished"].any():
        return True
    if should_stop_generation(context["sequences"], logits, config):
        return True
    return step == config.max_new_tokens - 1


def _decode(
    model,
    context: dict,
    next_tokens: torch.Tensor,
    use_cached_decode: bool,
545
546
547
548
549
550
551
552
553
554
555
556
557
558
    next_tokens: torch.Tensor,
    use_cached_decode: bool,
):
    """Run one cached or no-cache decode step."""
    model_attention_mask = _combined_attention_mask(
        context["prefix_attention_mask"],
        context["current_attention_mask"],
    )
    if use_cached_decode:
        decode_pos = (
            context["prefix_valid_lengths"]
            + context["prompt_lengths"]
            + context["generated_counts"]
            - 1
556
557
558
559
560
561
562
563
564
            + context["prompt_lengths"]
            + context["generated_counts"]
            - 1
        )
        return _model_forward(
            model,
            input_ids=next_tokens,
            position_ids=decode_pos.view(-1, 1),
            attention_mask=_build_decode_key_mask(
569
570
571
572
573
574
575
576
577
578
579
580
581
582
            use_cache=True,
            sequence_shard_info=_cache_shard_info(context["cache"]),
            global_seq_len=context["prefix_len"] + context["sequences"].shape[-1],
        )
    decode_pos = _build_prefill_position_ids(
        context["sequences"],
        context["current_attention_mask"],
        context["prefix_attention_mask"],
    )
    decode_mask = _build_prefill_mask(
        context["sequences"],
        context["current_attention_mask"],
        context["prefix_attention_mask"],
        dtype=context["mask_dtype"],
580
581
582
583
584
585
586
587
588
        context["current_attention_mask"],
        context["prefix_attention_mask"],
        dtype=context["mask_dtype"],
    )
    return _model_forward(
        model,
        input_ids=context["sequences"],
        position_ids=decode_pos,
        attention_mask=decode_mask,
592
593
594
595
596
597
598
599
600
601
        global_seq_len=context["prefix_len"] + context["sequences"].shape[-1],
    )


@torch.no_grad()
def generate(
    model,
    input_ids: torch.Tensor,
    generation_config: Optional[GenerationConfig] = None,
    attention_mask: Optional[torch.Tensor] = None,
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
    attention_mask: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Generate token ids from a causal language model."""
    if kwargs:
        raise TypeError(f"Unexpected generate kwargs: {sorted(kwargs)}")
    config = generation_config or GenerationConfig()
    _validate_generate_inputs(input_ids, attention_mask)
    if config.max_new_tokens == 0:
        return _finalize_zero_new_tokens(input_ids, attention_mask, config)

    was_training = getattr(model, "training", False)
    model.eval()
    try:
        context = _prepare_generation_context(model, input_ids, attention_mask, config)
        outputs = _prefill(model, config, context)
        logits, use_cached_decode = _finalize_prefill_outputs(
            config,
            context,
            outputs,
        )
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
            context,
            outputs,
        )

        for step in range(config.max_new_tokens):
            next_logits = prepare_logits_for_sampling(logits[:, -1, :], config)
            next_logits = apply_logits_processors(
                context["sequences"],
                next_logits,
                config,
            )
            next_tokens = sample_next_token(next_logits, context["sequences"], config)
            next_tokens = _append_next_token(context, next_tokens, config)
            if _should_finish_generation(context, next_logits, config, step):
                break

            outputs = _decode(model, context, next_tokens, use_cached_decode)
            if use_cached_decode:
                _update_cache(context["cache"], outputs)
            logits = _required_logits(outputs)

        return _finalize_sequences(
            context["sequences"],
            initial_attention_mask=context["initial_attention_mask"],
            prompt_lengths=context["prompt_lengths"],
            generated_counts=context["generated_counts"],
644
645
646
647
648
649
            generated_counts=context["generated_counts"],
            pad_token_id=config.pad_token_id,
        )
    finally:
        if was_training:
            model.train()
hyper_parallel/infer/kv_cache.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""KV cache container for generation."""
from dataclasses import dataclass
from typing import Any, Iterable, List, Optional, Tuple

import torch

PastKeyValues = List[Tuple[torch.Tensor, torch.Tensor]]


def _validate_pair_shapes(key: torch.Tensor, value: torch.Tensor) -> None:
    """Check that ``key`` and ``value`` form one valid cache layer.

    Raises:
        ValueError: unless both are ``torch.Tensor`` instances with 4
            dimensions and identical shapes.
    """
    both_tensors = isinstance(key, torch.Tensor) and isinstance(value, torch.Tensor)
    if not both_tensors:
        raise ValueError("key and value must be tensors")
    if (key.ndim, value.ndim) != (4, 4):
        raise ValueError("key and value must have shape (batch, heads, seq, dim)")
    if key.shape != value.shape:
        raise ValueError("key and value batch/heads/seq/dim dimensions must match")


def detach_and_validate_past_key_values(past_key_values: Iterable) -> PastKeyValues:
    """Validate every ``(key, value)`` layer and return detached tensors.

    Args:
        past_key_values: iterable whose entries are 2-item tuples/lists of tensors.

    Returns:
        A new list of ``(key.detach(), value.detach())`` tuples in input order.
        ``detach`` results share storage with the originals.

    Raises:
        ValueError: if an entry is not a 2-item tuple/list, or if a pair fails
            ``_validate_pair_shapes``.
    """
    def _checked(entry):
        if not isinstance(entry, (tuple, list)) or len(entry) != 2:
            raise ValueError("each cache entry must be a (key, value) pair")
        key, value = entry
        _validate_pair_shapes(key, value)
        return key.detach(), value.detach()

    return [_checked(entry) for entry in past_key_values]


@dataclass(frozen=True)
class SequenceShardInfo:
    """Immutable description of the sequence slice owned by one CP rank.

    The rank covers positions ``start`` up to (but excluding) ``end`` of a
    sequence whose full length is ``global_seq_len``.
    """

    # Index of this rank within the context-parallel group.
    rank: int
    # Number of ranks in the context-parallel group.
    world_size: int
    # First position covered by this rank (inclusive).
    start: int
    # End of the covered range (exclusive).
    end: int
    # Length of the full, unsharded sequence.
    global_seq_len: int

    @property
    def local_seq_len(self) -> int:
        """Number of positions covered locally, i.e. ``end - start``."""
        span = self.end - self.start
        return span


def get_sequence_shard_info(
    global_seq_len: int,
    rank: int,
    world_size: int,
) -> SequenceShardInfo:
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
    rank: int,
    world_size: int,
) -> SequenceShardInfo:
    """Return the contiguous sequence shard range for a CP rank."""
    if global_seq_len < 0:
        raise ValueError("global_seq_len must be >= 0")
    if world_size <= 0:
        raise ValueError("world_size must be > 0")
    if rank < 0 or rank >= world_size:
        raise ValueError("rank must be in [0, world_size)")
    base = global_seq_len // world_size
    remainder = global_seq_len % world_size
    start = rank * base + min(rank, remainder)
    end = start + base + (1 if rank < remainder else 0)
    return SequenceShardInfo(
        rank=rank,
        world_size=world_size,
        start=start,
        end=end,
83
84
85
86
87
88
89
90
91
        global_seq_len=global_seq_len,
    )


def shard_past_key_values(
    past_key_values: Iterable,
    rank: int,
    world_size: int,
    global_seq_len: Optional[int] = None,
) -> Tuple[PastKeyValues, SequenceShardInfo]:
    """Split a full (unsharded) K/V cache along the sequence dim for a CP rank.

    Args:
        past_key_values: Iterable of per-layer ``(key, value)`` pairs that hold
            the whole sequence.
        rank: Context-parallel rank whose slice should be returned.
        world_size: Number of context-parallel ranks.
        global_seq_len: Expected full sequence length. Defaults to the length
            found in the cache, or 0 for an empty cache.

    Returns:
        ``(sharded, shard_info)``. ``sharded`` holds contiguous copies of the
        ``[shard_info.start, shard_info.end)`` slice of every layer.

    Raises:
        ValueError: If a cache entry is malformed, the layers disagree on
            sequence length, ``global_seq_len`` does not match the cache, or
            the rank/world_size pair is invalid.
    """
    values = detach_and_validate_past_key_values(past_key_values)
    if not values:
        seq_len = 0 if global_seq_len is None else global_seq_len
    else:
        seq_len = values[0][0].shape[-2]
        # Every layer is sliced with the same range, so all layers must share
        # one sequence length. Otherwise narrow() fails with an opaque
        # RuntimeError or silently truncates longer layers. Keys and values
        # already have identical shapes (checked by the validator).
        if any(key.shape[-2] != seq_len for key, _ in values):
            raise ValueError("all cache layers must share the same sequence length")
    if global_seq_len is None:
        global_seq_len = seq_len
    if seq_len != global_seq_len:
        raise ValueError("global_seq_len must match full cache sequence length")
    shard_info = get_sequence_shard_info(global_seq_len, rank, world_size)
    start, length = shard_info.start, shard_info.local_seq_len
    sharded = [
        (
            key.narrow(-2, start, length).contiguous(),
            value.narrow(-2, start, length).contiguous(),
        )
        for key, value in values
    ]
    return sharded, shard_info


class KVCache:
    """Holds the per-layer key/value tensors produced during generation.

    The stored value is one of three things:

    - ``None``.
    - A list of detached ``(key, value)`` tensor pairs.
    - An opaque cache object, kept by reference. This is any object that
      exposes ``get_seq_length`` and is not a list/tuple.
    """

    def __init__(self):
        # ``None`` means nothing has been cached yet.
        self.past_key_values: Optional[Any] = None

    @property
    def is_empty(self) -> bool:
        """Return True when the cache is ``None`` or an empty list."""
        cached = self.past_key_values
        if cached is None:
            return True
        return isinstance(cached, list) and not cached

    def clear(self) -> None:
        """Forget everything that is currently cached."""
        self.past_key_values = None

    def update(self, past_key_values: Optional[Iterable]) -> None:
        """Overwrite the cache with *past_key_values*.

        ``None`` is ignored. Opaque cache objects are stored by reference.
        Tuple/list caches are validated and detached, and an empty one resets
        the cache to ``None``.
        """
        if past_key_values is None:
            return
        if self._is_opaque_cache(past_key_values):
            self.past_key_values = past_key_values
        else:
            detached = self._detach_and_validate(past_key_values)
            self.past_key_values = detached if detached else None

    def merge(self, past_key_values: Optional[Iterable]) -> None:
        """Concatenate incoming per-layer K/V tensors onto the cache along dim -2.

        ``None`` and empty inputs are no-ops. An opaque cache replaces the
        stored value, and the first non-empty input becomes the cache as-is.

        Raises:
            ValueError: If the layer count differs from the stored cache, or a
                layer's non-sequence dimensions do not line up.
        """
        if past_key_values is None:
            return
        if self._is_opaque_cache(past_key_values):
            self.past_key_values = past_key_values
            return
        incoming = self._detach_and_validate(past_key_values)
        if not incoming:
            return
        if self.past_key_values is None:
            self.past_key_values = incoming
            return
        if len(self.past_key_values) != len(incoming):
            raise ValueError("past_key_values layer count mismatch")
        combined = []
        for cached_pair, fresh_pair in zip(self.past_key_values, incoming):
            cached_key, cached_value = cached_pair
            fresh_key, fresh_value = fresh_pair
            self._validate_pair_shapes(cached_key, cached_value)
            self._validate_pair_shapes(fresh_key, fresh_value)
            # Every dimension except the sequence dim (-2) must agree.
            keys_fit = (
                cached_key.shape[:-2] == fresh_key.shape[:-2]
                and cached_key.shape[-1] == fresh_key.shape[-1]
            )
            if not keys_fit:
                raise ValueError("key cache shape mismatch")
            values_fit = (
                cached_value.shape[:-2] == fresh_value.shape[:-2]
                and cached_value.shape[-1] == fresh_value.shape[-1]
            )
            if not values_fit:
                raise ValueError("value cache shape mismatch")
            combined.append((
                torch.cat((cached_key, fresh_key), dim=-2),
                torch.cat((cached_value, fresh_value), dim=-2),
            ))
        self.past_key_values = combined

    @classmethod
    def _detach_and_validate(cls, past_key_values: Iterable) -> PastKeyValues:
        """Delegate to the module-level detach/validate helper."""
        return detach_and_validate_past_key_values(past_key_values)

    @staticmethod
    def _validate_pair_shapes(key: torch.Tensor, value: torch.Tensor) -> None:
        """Delegate to the module-level ``(key, value)`` shape check."""
        _validate_pair_shapes(key, value)

    @staticmethod
    def _is_opaque_cache(past_key_values) -> bool:
        """Return True for objects exposing ``get_seq_length`` that are not a list/tuple."""
        has_seq_api = hasattr(past_key_values, "get_seq_length")
        return has_seq_api and not isinstance(past_key_values, (list, tuple))


class ContextParallelKVCache(KVCache):
    """KV cache that stores only this context-parallel (CP) rank's sequence slice.

    ``shard_info`` describes the contiguous ``[start, end)`` range of the global
    sequence that the stored tensors cover for ``rank`` out of ``world_size``.
    """

    def __init__(self, rank: int, world_size: int):
        """Create an empty cache for ``rank`` out of ``world_size`` CP ranks.

        Raises:
            ValueError: If ``world_size`` is not positive or ``rank`` is outside
                ``[0, world_size)``.
        """
        super().__init__()
        if world_size <= 0:
            raise ValueError("world_size must be > 0")
        if rank < 0 or rank >= world_size:
            raise ValueError("rank must be in [0, world_size)")
        self.rank = rank
        self.world_size = world_size
        # Start out describing an empty global sequence.
        self.shard_info = get_sequence_shard_info(0, rank, world_size)

    def update_full(self, past_key_values: Optional[Iterable]) -> None:
        """Shard a full prefill K/V cache and keep only this rank's slice.

        ``None`` is ignored. Raises ValueError (via ``shard_past_key_values``)
        for malformed caches.
        """
        if past_key_values is None:
            return
        sharded, shard_info = shard_past_key_values(
            past_key_values,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.past_key_values = sharded
        self.shard_info = shard_info

    def update_local(
        self,
        past_key_values: Optional[Iterable],
        shard_info: SequenceShardInfo,
    ) -> None:
        """Store K/V tensors that already hold only this rank's slice.

        ``None`` is ignored.

        Args:
            past_key_values: Per-layer ``(key, value)`` pairs for this rank.
                Each must have ``shard_info.local_seq_len`` positions on dim -2.
            shard_info: Slice metadata. It must name this cache's rank and
                world_size.

        Raises:
            ValueError: On mismatched shard metadata or local sequence lengths.
        """
        if past_key_values is None:
            return
        self._validate_shard_info(shard_info)
        values = self._detach_and_validate(past_key_values)
        self._validate_local_seq_len(values, shard_info.local_seq_len)
        self.past_key_values = values
        self.shard_info = shard_info

    def merge_local(
        self,
        past_key_values: Optional[Iterable],
        global_seq_len: Optional[int] = None,
    ) -> None:
        """Append local incremental K/V tensors and advance the shard metadata.

        If no tensors are stored yet (``None`` or an empty list),
        *past_key_values* becomes the cache. Its length must then equal the
        local length of the resulting shard. Otherwise the tensors are
        concatenated, and the local growth must match the shard size change
        implied by the new global length.

        Args:
            past_key_values: Per-layer ``(key, value)`` tensors local to this
                rank. ``None`` is ignored.
            global_seq_len: Global sequence length after the append. When
                omitted, it is inferred as the current global length plus the
                incoming local length. It is mandatory for the initial merge
                when ``world_size > 1``.

        Raises:
            ValueError: If the global length is required but missing, or if
                local lengths do not agree with the CP shard metadata.

        NOTE(review): only sequence lengths are checked. Whether the appended
        tensors really are the positions this rank owns under the new
        contiguous split is left to the caller. Confirm this at call sites.
        """
        if past_key_values is None:
            return
        new_values = self._detach_and_validate(past_key_values)
        # Use is_empty rather than "is None": update_full/update_local can store
        # an empty list. Treating that as an existing cache made the later
        # KVCache.merge fail with a layer-count mismatch.
        if self.is_empty:
            if global_seq_len is None and self.world_size > 1:
                raise ValueError("global_seq_len is required for initial CP local cache")
            inferred_global = self._next_global_seq_len(new_values, global_seq_len)
            shard_info = get_sequence_shard_info(
                0 if inferred_global is None else inferred_global,
                self.rank,
                self.world_size,
            )
            self._validate_local_seq_len(new_values, shard_info.local_seq_len)
            self.past_key_values = new_values
            self.shard_info = shard_info
            return
        old_local_seq_len = self.shard_info.local_seq_len
        next_global_seq_len = self._next_global_seq_len(new_values, global_seq_len)
        if next_global_seq_len is None:
            raise ValueError("global_seq_len is required for empty incremental cache")
        shard_info = get_sequence_shard_info(
            next_global_seq_len,
            self.rank,
            self.world_size,
        )
        actual_growth = new_values[0][0].shape[-2] if new_values else 0
        # Every layer must grow by the same amount. Otherwise layers drift out
        # of sync with each other and with shard_info. Keys and values already
        # share a shape (checked by the validator).
        if any(key.shape[-2] != actual_growth for key, _ in new_values):
            raise ValueError("incremental cache layers must share the same sequence length")
        expected_growth = shard_info.local_seq_len - old_local_seq_len
        if expected_growth != actual_growth:
            raise ValueError("local cache growth does not match CP shard metadata")
        super().merge(new_values)
        self.shard_info = shard_info

    def clear(self) -> None:
        """Drop all cached tensors and reset the shard metadata to an empty sequence."""
        super().clear()
        self.shard_info = get_sequence_shard_info(0, self.rank, self.world_size)

    def _next_global_seq_len(
        self,
        new_values: PastKeyValues,
        global_seq_len: Optional[int],
    ) -> Optional[int]:
        """Return *global_seq_len*, or infer it from the incoming local length.

        Inference adds the first incoming layer's sequence length to the current
        global length. The result is ``None`` when no length is given and there
        is nothing to infer from.
        """
        if global_seq_len is not None or not new_values:
            return global_seq_len
        return self.shard_info.global_seq_len + new_values[0][0].shape[-2]

    def _validate_shard_info(self, shard_info: SequenceShardInfo) -> None:
        """Reject shard metadata that belongs to a different rank or world size."""
        if shard_info.rank != self.rank or shard_info.world_size != self.world_size:
            raise ValueError("shard_info does not match this CP cache")

    @staticmethod
    def _validate_local_seq_len(values: PastKeyValues, local_seq_len: int) -> None:
        """Require every key/value tensor to hold exactly *local_seq_len* positions on dim -2."""
        for key, value in values:
            if key.shape[-2] != local_seq_len or value.shape[-2] != local_seq_len:
                raise ValueError("local cache sequence length does not match shard_info")
hyper_parallel/infer/mixin.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# limitations under the License.
# ============================================================================
"""Mixin exposing a model.generate style API."""

from hyper_parallel.infer.generation import generate


class GenerateMixin:
    """Mixin that forwards to :func:`hyper_parallel.infer.generate`."""

    def generate(self, input_ids, generation_config=None, attention_mask=None, **kwargs):
        """Generate token ids with the common HyperParallel generation loop."""
        return generate(
            self,
            input_ids=input_ids,
            generation_config=generation_config,
            attention_mask=attention_mask,
hyper_parallel/infer/sampler.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sampling helpers for autoregressive generation."""
import torch
from torch.nn import functional as F

from hyper_parallel.infer.utils import GenerationConfig


def _filter_top_k(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Set every logit below the k-th largest (along the last dim) to ``-inf``.

    Args:
        logits: Logits whose last dimension is the vocabulary. Any number of
            leading dimensions is accepted.
        top_k: Number of highest logits to keep. Non-positive values, or
            values ``>=`` the vocabulary size, disable filtering.

    Returns:
        ``logits`` itself when filtering is disabled. Otherwise a new tensor
        with the filtered entries set to ``-inf``. Entries tied with the k-th
        largest value are all kept.
    """
    if top_k <= 0 or top_k >= logits.size(-1):
        return logits
    # Only the k-th largest value per row is needed as the cut-off. Indexing
    # with ``...`` keeps this valid for 1-D and batched N-D logits alike.
    threshold = torch.topk(logits, k=top_k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < threshold, float("-inf"))


def greedy_sample(logits: torch.Tensor) -> torch.Tensor:
    """Pick the argmax token of every row.

    Args:
        logits: Tensor of shape ``(batch, vocab)``.

    Returns:
        Tensor of shape ``(batch, 1)`` with the highest-scoring token ids.

    Raises:
        ValueError: If ``logits`` is not 2-D.
    """
    if logits.ndim == 2:
        return torch.argmax(logits, dim=-1, keepdim=True)
    raise ValueError("logits must have shape (batch, vocab)")


def top_k_sample(
    logits: torch.Tensor,
    top_k: int,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Draw one token per row from the temperature-scaled top-k distribution.

    Args:
        logits: Tensor of shape ``(batch, vocab)``.
        top_k: Number of highest logits eligible for sampling. ``_filter_top_k``
            disables filtering for non-positive values or values ``>= vocab``.
        temperature: Divisor applied to the logits before the softmax.

    Returns:
        Tensor of shape ``(batch, 1)`` with the sampled token ids.

    Raises:
        ValueError: If ``logits`` is not 2-D.
    """
    if logits.ndim != 2:
        raise ValueError("logits must have shape (batch, vocab)")
    scaled = _filter_top_k(logits, top_k) / temperature
    return torch.multinomial(F.softmax(scaled, dim=-1), num_samples=1)


def top_p_sample(
    logits: torch.Tensor,
    top_p: float,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Draw one token per row from the nucleus (top-p) distribution.

    Tokens are ranked by logit. A token stays eligible while the
    temperature-scaled probability mass of the tokens ranked strictly above
    it does not exceed ``top_p``. With a non-negative ``top_p`` the best
    token is therefore always kept.

    Args:
        logits: Tensor of shape ``(batch, vocab)``.
        top_p: Nucleus mass threshold. Values ``>= 1.0`` sample from the full
            distribution.
        temperature: Divisor applied to the logits before each softmax.

    Returns:
        Tensor of shape ``(batch, 1)`` with sampled ids in original vocabulary
        order.

    Raises:
        ValueError: If ``logits`` is not 2-D.
    """
    if logits.ndim != 2:
        raise ValueError("logits must have shape (batch, vocab)")
    if top_p >= 1.0:
        # The nucleus covers the whole vocabulary: plain temperature sampling.
        full_probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(full_probs, num_samples=1)
    ranked_logits, ranked_ids = torch.sort(logits, dim=-1, descending=True)
    ranked_probs = F.softmax(ranked_logits / temperature, dim=-1)
    # Exclusive prefix sum: the mass strictly ahead of each ranked token.
    mass_before = torch.cumsum(ranked_probs, dim=-1) - ranked_probs
    nucleus_logits = ranked_logits.masked_fill(mass_before > top_p, float("-inf"))
    choice = torch.multinomial(
        F.softmax(nucleus_logits / temperature, dim=-1),
        num_samples=1,
    )
    # Map positions in the ranked order back to vocabulary ids.
    return torch.gather(ranked_ids, -1, choice)


def apply_repetition_penalty(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    penalty: float,
) -> torch.Tensor:
    """Penalize the logits of tokens that already appear in each row of ``input_ids``.

    For a seen token, a negative logit is multiplied by ``penalty`` and a
    non-negative logit is divided by it. Ids outside ``[0, vocab)``, such as
    negative padding ids, are ignored.

    Args:
        logits: Tensor of shape ``(batch, vocab)``.
        input_ids: Tensor of shape ``(batch, seq)`` with previously seen token ids.
        penalty: Repetition penalty. ``1.0`` returns ``logits`` unchanged.

    Returns:
        A new tensor with the penalized logits, or ``logits`` itself when
        ``penalty == 1.0``.

    Raises:
        ValueError: If either tensor is not 2-D or their batch sizes differ.
    """
    if penalty == 1.0:
        return logits
    if logits.ndim != 2 or input_ids.ndim != 2:
        raise ValueError("logits and input_ids must be 2-D tensors")
    if input_ids.size(0) != logits.size(0):
        raise ValueError("logits and input_ids must have the same batch size")
    vocab_size = logits.size(-1)
    ids = input_ids.to(device=logits.device)
    valid = (ids >= 0) & (ids < vocab_size)
    # Out-of-range ids are clamped only so they can serve as scatter indices,
    # and they add 0 to the count. A clamped id therefore can never un-mark a
    # real token at the same index. A plain scatter_ with duplicate indices
    # keeps an arbitrary write, which made this non-deterministic.
    safe_ids = ids.to(dtype=torch.long).clamp(min=0, max=max(vocab_size - 1, 0))
    seen_counts = torch.zeros(logits.shape, dtype=torch.long, device=logits.device)
    seen_counts.scatter_add_(1, safe_ids, valid.to(torch.long))
    seen_mask = seen_counts > 0
    penalized = torch.where(logits < 0, logits * penalty, logits / penalty)
    return torch.where(seen_mask, penalized, logits)


def sample_next_token(
    logits: torch.Tensor,
    input_ids: torch.Tensor,
    config: GenerationConfig,
) -> torch.Tensor:
    """Choose the next token id of every row according to *config*.

    The steps are:

    1. Apply ``config.repetition_penalty``.
    2. Without ``config.do_sample``, return the greedy argmax.
    3. Otherwise apply top-k filtering. Then use nucleus sampling when
       ``config.top_p < 1.0``, or plain temperature sampling when it is not.

    For ``(batch, vocab)`` logits the result has shape ``(batch, 1)``.
    """
    penalized = apply_repetition_penalty(
        logits,
        input_ids=input_ids,
        penalty=config.repetition_penalty,
    )
    if config.do_sample:
        candidates = _filter_top_k(penalized, config.top_k)
        if config.top_p < 1.0:
            return top_p_sample(candidates, top_p=config.top_p, temperature=config.temperature)
        return torch.multinomial(
            F.softmax(candidates / config.temperature, dim=-1),
            num_samples=1,
        )
    return greedy_sample(penalized)
hyper_parallel/infer/utils.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generation configuration and mask helpers."""
from dataclasses import dataclass
from typing import Any, Callable, List, Optional

import torch
import torch.distributed as dist


@dataclass
class GenerationConfig:
    """Options that control the autoregressive generation loop.

    Field values are checked in ``__post_init__``. Invalid values or
    combinations raise ``ValueError``.
    """

    # Length and sampling controls.
    max_new_tokens: int = 256
    temperature: float = 1.0
    top_k: int = 50
    top_p: float = 1.0
    do_sample: bool = False
    eos_token_id: Optional[int] = 2
    pad_token_id: int = 0
    repetition_penalty: float = 1.0
    use_cache: bool = True
    # Optional prefix cache; every prefix_* field requires prefix_past_key_values.
    prefix_past_key_values: Optional[Any] = None
    prefix_attention_mask: Optional[torch.Tensor] = None
    prefix_sequence_shard_info: Optional[Any] = None
    prefix_cache_length: Optional[int] = None
    # Context-parallel cache and logits ownership.
    context_parallel_cache: bool = False
    context_parallel_rank: Optional[int] = None
    context_parallel_world_size: Optional[int] = None
    # context_logits_rank is local to context_process_group.
    context_logits_rank: Optional[Any] = None
    context_process_group: Optional[Any] = None
    # Tensor-parallel logits gathering (gather dim must be negative).
    gather_logits: bool = False
    logits_process_group: Optional[Any] = None
    logits_gather_dim: int = -1
    # Mask dtype and user extension hooks.
    mask_dtype: Optional[torch.dtype] = None
    logits_processor: Optional[List[Callable]] = None
    stopping_criteria: Optional[List[Callable]] = None

    def __post_init__(self):
        # Order matters: the first failing group determines the error raised.
        for validate in (
            self._validate_sampling,
            self._validate_prefix,
            self._validate_context_parallel,
            self._validate_logits_options,
        ):
            validate()

    def _validate_sampling(self) -> None:
        """Check the numeric length, sampling and penalty options."""
        if self.max_new_tokens < 0:
            raise ValueError("max_new_tokens must be >= 0")
        if self.temperature <= 0:
            raise ValueError("temperature must be > 0")
        if self.top_k < 0:
            raise ValueError("top_k must be >= 0")
        if not 0 < self.top_p <= 1.0:
            raise ValueError("top_p must be in (0, 1]")
        if self.repetition_penalty <= 0:
            raise ValueError("repetition_penalty must be > 0")

    def _validate_prefix(self) -> None:
        """Check that prefix-cache metadata is consistent with the prefix cache."""
        if self.prefix_past_key_values is not None:
            if not self.use_cache:
                raise ValueError("prefix_past_key_values requires use_cache=True")
            length = self.prefix_cache_length
            if length is not None and length < 0:
                raise ValueError("prefix_cache_length must be >= 0")
            if self.prefix_sequence_shard_info is not None and not self.context_parallel_cache:
                raise ValueError("prefix_sequence_shard_info requires context_parallel_cache=True")
            return
        # Without a prefix cache, none of its companion fields may be set.
        dependents = (
            ("prefix_attention_mask", "prefix_attention_mask requires prefix_past_key_values"),
            (
                "prefix_sequence_shard_info",
                "prefix_sequence_shard_info requires prefix_past_key_values",
            ),
            ("prefix_cache_length", "prefix_cache_length requires prefix_past_key_values"),
        )
        for attr_name, message in dependents:
            if getattr(self, attr_name) is not None:
                raise ValueError(message)

    def _validate_context_parallel(self) -> None:
        """Check the context-parallel cache flag and rank/world-size pair."""
        if self.context_parallel_cache and not self.use_cache:
            raise ValueError("context_parallel_cache requires use_cache=True")
        cp_rank = self.context_parallel_rank
        cp_world = self.context_parallel_world_size
        if (cp_rank is None) != (cp_world is None):
            raise ValueError(
                "context_parallel_rank and context_parallel_world_size must be set together",
            )
        if cp_world is None:
            return
        if cp_world <= 0:
            raise ValueError("context_parallel_world_size must be > 0")
        if cp_rank < 0 or cp_rank >= cp_world:
            raise ValueError("context_parallel_rank must be in [0, context_parallel_world_size)")

    def _validate_logits_options(self) -> None:
        """Check the logits gather dim, mask dtype and extension hooks."""
        if self.logits_gather_dim >= 0:
            raise ValueError("logits_gather_dim must be negative")
        dtype = self.mask_dtype
        if dtype is not None and not isinstance(dtype, torch.dtype):
            raise ValueError("mask_dtype must be a torch.dtype")
        self._validate_callables(self.logits_processor, "logits_processor")
        self._validate_callables(self.stopping_criteria, "stopping_criteria")

    @staticmethod
    def _validate_callables(values: Optional[List[Callable]], field_name: str) -> None:
        """Require *values* to be ``None`` or a list made only of callables."""
        if values is None:
            return
        if not isinstance(values, list):
            raise ValueError(f"{field_name} must be a list of callables")
        if any(not callable(item) for item in values):
            raise ValueError(f"{field_name} must contain only callables")


def build_position_ids(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Return position ids that skip over left padding.

    Without a mask, every row gets ``0 .. seq-1``. With a mask, each position
    gets the count of attended tokens up to and including it, minus one,
    clamped at 0. Padding before the first real token therefore maps to 0.

    Raises:
        ValueError: If ``input_ids`` is not 2-D or the mask shape differs.
    """
    if input_ids.ndim != 2:
        raise ValueError("input_ids must have shape (batch, seq)")
    if attention_mask is not None:
        if attention_mask.shape != input_ids.shape:
            raise ValueError("attention_mask must match input_ids shape")
        running = torch.cumsum(attention_mask.long(), dim=-1)
        return (running - 1).clamp_min_(0)
    batch, seq_len = input_ids.shape
    positions = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
    return positions.unsqueeze(0).expand(batch, -1)


def gather_context_parallel_logits(logits: torch.Tensor, config: GenerationConfig) -> torch.Tensor:
    """Return the logits computed by the CP rank that owns the final token.

    Every rank in ``config.context_process_group`` contributes its local
    ``logits`` through ``all_gather``. The result is taken from the rank(s)
    named by ``config.context_logits_rank``: either one rank for the whole
    batch (scalar) or one rank per batch row (1-D tensor of length batch).

    ``logits`` is returned untouched when any of these holds:

    - no owner rank is configured;
    - ``torch.distributed`` is unavailable or not initialized;
    - the group has a single rank.

    Raises:
        ValueError: If an owner rank is outside ``[0, group size)`` or the
            owner tensor has an unsupported shape.
    """
    if config.context_logits_rank is None:
        return logits
    if not (dist.is_available() and dist.is_initialized()):
        return logits
    group = config.context_process_group
    group_size = dist.get_world_size(group=group)
    if group_size == 1:
        return logits
    per_rank = [torch.empty_like(logits) for _ in range(group_size)]
    dist.all_gather(per_rank, logits, group=group)
    # Leading dim indexes the source rank.
    all_logits = torch.stack(per_rank, dim=0)
    owner = torch.as_tensor(
        config.context_logits_rank,
        device=logits.device,
        dtype=torch.long,
    )
    if owner.ndim == 0:
        owner_rank = int(owner.item())
        if not 0 <= owner_rank < group_size:
            raise ValueError("context_logits_rank contains an invalid rank")
        return all_logits[owner_rank]
    batch = logits.shape[0]
    if owner.shape != (batch,):
        raise ValueError("context_logits_rank must be a scalar or a batch-sized tensor")
    if torch.any((owner < 0) | (owner >= group_size)):
        raise ValueError("context_logits_rank contains an invalid rank")
    # Row b is taken from rank owner[b].
    rows = torch.arange(batch, device=logits.device)
    return all_logits[owner, rows]


def gather_tensor_parallel_logits(logits: torch.Tensor, config: GenerationConfig) -> torch.Tensor:
    """Reassemble vocab-sharded logits across the tensor-parallel group.

    Local shards from ``config.logits_process_group`` are concatenated along
    ``config.logits_gather_dim``. ``logits`` is returned untouched when any of
    these holds:

    - ``config.gather_logits`` is false;
    - ``torch.distributed`` is unavailable or not initialized;
    - the group has a single rank.

    Raises:
        ValueError: If the gather dim is out of range for ``logits``, or the
            ranks hold different shard sizes.
    """
    if not config.gather_logits:
        return logits
    if not (dist.is_available() and dist.is_initialized()):
        return logits
    group = config.logits_process_group
    group_size = dist.get_world_size(group=group)
    if group_size == 1:
        return logits
    # logits_gather_dim is expected to be negative. Convert it to an absolute
    # axis so it can be bounds-checked against this tensor.
    axis = logits.ndim + config.logits_gather_dim
    if not 0 <= axis < logits.ndim:
        raise ValueError("logits_gather_dim is out of range for logits")
    # all_gather needs identical shapes on every rank, so first confirm that
    # every rank holds the same shard size.
    my_size = torch.tensor(
        [logits.shape[axis]],
        device=logits.device,
        dtype=torch.long,
    )
    size_slots = [torch.empty_like(my_size) for _ in range(group_size)]
    dist.all_gather(size_slots, my_size, group=group)
    all_sizes = torch.cat(size_slots)
    if torch.any(all_sizes != all_sizes[0]):
        raise ValueError(
            "tensor-parallel logits gather requires equal local vocab shard sizes; "
            "pad vocab shards before generation",
        )
    shards = [torch.empty_like(logits) for _ in range(group_size)]
    dist.all_gather(shards, logits, group=group)
    return torch.cat(shards, dim=config.logits_gather_dim)


def prepare_logits_for_sampling(logits: torch.Tensor, config: GenerationConfig) -> torch.Tensor:
    """Run the distributed logits hand-offs that must precede sampling.

    First the owning context-parallel rank's logits are selected. Then vocab
    shards are gathered across the tensor-parallel group. Each step returns
    its input unchanged when its feature is not configured.
    """
    for handoff in (gather_context_parallel_logits, gather_tensor_parallel_logits):
        logits = handoff(logits, config)
    return logits


def apply_logits_processors(
    input_ids: torch.Tensor,
    logits: torch.Tensor,
    config: GenerationConfig,
) -> torch.Tensor:
    """Run the ``config.logits_processor`` hooks over *logits* in list order.

    Each hook is called as ``hook(input_ids, logits)``, and its output feeds
    the next hook. Returns ``logits`` unchanged when no hooks are configured.

    Raises:
        ValueError: If a hook returns something other than a ``torch.Tensor``.
    """
    hooks = config.logits_processor
    if hooks is None:
        return logits
    current = logits
    for hook in hooks:
        current = hook(input_ids, current)
        if isinstance(current, torch.Tensor):
            continue
        raise ValueError("logits_processor must return a tensor")
    return current


def should_stop_generation(
    input_ids: torch.Tensor,
    logits: torch.Tensor,
    config: GenerationConfig,
) -> bool:
    """Evaluate ``config.stopping_criteria`` and report whether to stop.

    Criteria are called in order as ``criterion(input_ids, logits)``. The
    first truthy result returns ``True``. A tensor result must hold exactly
    one element.

    Raises:
        ValueError: If a criterion returns a tensor that does not hold exactly
            one element.
    """
    criteria = config.stopping_criteria
    if criteria is None:
        return False
    for criterion in criteria:
        verdict = criterion(input_ids, logits)
        if isinstance(verdict, torch.Tensor):
            if verdict.numel() != 1:
                raise ValueError("stopping_criteria tensor output must be scalar")
            verdict = bool(verdict.item())
        if verdict:
            return True
    return False


def build_causal_mask(
    input_ids: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the additive attention mask used for prefill.

    The result has shape ``(batch, 1, seq, seq)``, with ``0`` where a query may
    attend and ``-inf`` elsewhere. Future keys are always blocked. When
    ``attention_mask`` is given, keys at padded positions are blocked too,
    and query rows that are themselves padding are reset to all zeros. A
    padded query row is therefore never entirely ``-inf``.

    Raises:
        ValueError: If ``input_ids`` is not 2-D or the mask shape differs.
    """
    if input_ids.ndim != 2:
        raise ValueError("input_ids must have shape (batch, seq)")
    batch_size, seq_len = input_ids.shape
    device = input_ids.device
    neg_inf = float("-inf")
    # Strictly-upper-triangular -inf blocks attention to future positions.
    future = torch.triu(
        torch.full((seq_len, seq_len), neg_inf, device=device, dtype=dtype),
        diagonal=1,
    )
    base = torch.zeros(batch_size, 1, seq_len, seq_len, device=device, dtype=dtype)
    mask = base + future.view(1, 1, seq_len, seq_len)
    if attention_mask is None:
        return mask
    if attention_mask.shape != input_ids.shape:
        raise ValueError("attention_mask must match input_ids shape")
    is_pad = attention_mask.to(device=device) == 0
    # Block attention *to* padded keys (columns)...
    mask = mask.masked_fill(is_pad[:, None, None, :], neg_inf)
    # ...then clear padded query rows entirely.
    mask = mask.masked_fill(is_pad[:, None, :, None], 0.0)
    return mask


def append_attention_mask(
    attention_mask: Optional[torch.Tensor],
    next_tokens: torch.Tensor,
) -> Optional[torch.Tensor]:
    """Extend *attention_mask* with ones for the newly generated tokens.

    A block of ones shaped like ``next_tokens``, with the mask's dtype and
    device, is concatenated along the last dimension. Returns ``None`` when
    no mask is being tracked.
    """
    if attention_mask is not None:
        new_entries = attention_mask.new_ones(next_tokens.shape)
        return torch.cat((attention_mask, new_entries), dim=-1)
    return None