Update custom_generate/generate.py
custom_generate/generate.py  CHANGED  (+58 -114)

@@ -1,18 +1,22 @@
-
+import logging
+from typing import TYPE_CHECKING, Optional, Union
+
 import torch
-
+import torch.nn as nn
+
+from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
+from transformers.configuration_utils import PretrainedConfig
 from transformers.generation.utils import (
-
-    GenerateNonBeamOutput,
+    ALL_CACHE_NAMES,
     GenerateDecoderOnlyOutput,
+    GenerateEncoderDecoderOutput,
+    GenerateNonBeamOutput,
+    GenerationMixin,
 )
-from transformers.cache_utils import Cache, EncoderDecoderCache, DynamicCache
 from transformers.modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from transformers.generation.utils import GenerateEncoderDecoderOutput, ALL_CACHE_NAMES
 from transformers.utils import ModelOutput
-
-import torch.nn as nn
-import logging
+
 
 if TYPE_CHECKING:
     from transformers.generation.streamers import BaseStreamer
@@ -20,9 +24,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-def stack_model_outputs(
-    model_outputs: list[ModelOutput], config: PretrainedConfig
-) -> ModelOutput:
+def stack_model_outputs(model_outputs: list[ModelOutput], config: PretrainedConfig) -> ModelOutput:
     """
     Stack a list of ModelOutput objects (or its subclasses) along the batch_size dimension. The function infers the
     specific ModelOutput subclass from the list provided.
@@ -50,17 +52,11 @@ def stack_model_outputs(
             # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
             if isinstance(data[0][0], tuple):
                 return tuple(
-                    tuple(
-                        torch.cat([attr[i][j] for attr in data], dim=0)
-                        for j in range(len(data[0][0]))
-                    )
+                    tuple(torch.cat([attr[i][j] for attr in data], dim=0) for j in range(len(data[0][0])))
                     for i in range(len(data[0]))
                 )
             else:
-                return tuple(
-                    torch.cat([attr[i] for attr in data], dim=0)
-                    for i in range(len(data[0]))
-                )
+                return tuple(torch.cat([attr[i] for attr in data], dim=0) for i in range(len(data[0])))
         elif isinstance(data[0], (int, float)):
             # If the elements are integers or floats, return a tensor
             return torch.tensor(data)
@@ -92,9 +88,7 @@ def _ranking_fast(
     """
     norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
     norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
-    cosine_matrix = torch.matmul(
-        norm_context_hidden, norm_next_hidden.transpose(1, 2)
-    ).squeeze(-1)  # [B*K, S]
+    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
 
     # Penalize cosine_matrix based on the cosine_matrix_mask (ignore padding positions)
     # Using a large negative value for masked positions
@@ -105,9 +99,7 @@ def _ranking_fast(
     degeneration_penalty, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
     next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
     contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty
-    contrastive_score = torch.stack(
-        torch.split(contrastive_score, beam_width)
-    )  # [B, K]
+    contrastive_score = torch.stack(torch.split(contrastive_score, beam_width))  # [B, K]
     _, selected_idx = contrastive_score.max(dim=-1)  # [B]
     return selected_idx
 
@@ -163,9 +155,7 @@ def _contrastive_search(
             f"contrastive search is not supported with stateful models, such as {model.__class__.__name__}"
         )
     # init values
-    has_eos_stopping_criteria = any(
-        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
-    )
+    has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
     top_k = generation_config.top_k
     penalty_alpha = generation_config.penalty_alpha
     pad_token_id = generation_config._pad_token_tensor
@@ -181,39 +171,22 @@ def _contrastive_search(
     scores = () if (return_dict_in_generate and output_scores) else None
     decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
-    decoder_hidden_states = (
-        () if (return_dict_in_generate and output_hidden_states) else None
-    )
+    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
 
     # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
     if return_dict_in_generate and model.config.is_encoder_decoder:
-        encoder_attentions = (
-            model_kwargs["encoder_outputs"].get("attentions")
-            if output_attentions
-            else None
-        )
-        encoder_hidden_states = (
-            model_kwargs["encoder_outputs"].get("hidden_states")
-            if output_hidden_states
-            else None
-        )
+        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+        encoder_hidden_states = model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
 
     # keep track of which sequences are already finished
     batch_size, cur_len = input_ids.shape[:2]
-    unfinished_sequences = torch.ones(
-        batch_size, dtype=torch.long, device=input_ids.device
-    )
-    model_kwargs = model._get_initial_cache_position(
-        cur_len, input_ids.device, model_kwargs
-    )
+    unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+    model_kwargs = model._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
 
     # Create cosine_matrix_mask based on the attention_mask
     cosine_matrix_mask = torch.ones_like(input_ids, dtype=torch.long)
     if model.config.is_encoder_decoder:
-        if (
-            "decoder_attention_mask" in model_kwargs
-            and model_kwargs["decoder_attention_mask"] is not None
-        ):
+        if "decoder_attention_mask" in model_kwargs and model_kwargs["decoder_attention_mask"] is not None:
             cosine_matrix_mask = model_kwargs["decoder_attention_mask"]
     else:
         cosine_matrix_mask = model_kwargs["attention_mask"]
@@ -221,9 +194,7 @@ def _contrastive_search(
 
     this_peer_finished = False
 
-    while model._has_unfinished_sequences(
-        this_peer_finished, synced_gpus, device=input_ids.device
-    ):
+    while model._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
         # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
         # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
         if model_kwargs.get("past_key_values") is None or (
@@ -232,9 +203,7 @@ def _contrastive_search(
         ):
             # prepare inputs
             model_kwargs["use_cache"] = True
-            model_inputs = model.prepare_inputs_for_generation(
-                input_ids, **model_kwargs
-            )
+            model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
             # the `encoder_outputs`
@@ -256,9 +225,7 @@ def _contrastive_search(
             # Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
             # (the clone itmodel is always small)
             # torch.float32 is needed to retain precision for later logits manipulations
-            logit_for_next_step = outputs.logits[:, -1, :].to(
-                copy=True, dtype=torch.float32, device=input_ids.device
-            )
+            logit_for_next_step = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
 
             model_kwargs = model._update_model_kwargs_for_generation(
                 outputs,
@@ -282,7 +249,18 @@ def _contrastive_search(
                     f"{model.__class__.__name__} does not support caching and therefore **can't** be used "
                     "for contrastive search."
                 )
-
+            # Only those caches have the necesary methods
+            elif not (
+                isinstance(past_key_values, DynamicCache)
+                or (
+                    isinstance(past_key_values, EncoderDecoderCache)
+                    and isinstance(past_key_values.self_attention_cache, DynamicCache)
+                )
+            ):
+                raise ValueError(
+                    f"Unsupported cache type: {type(outputs['past_key_values'])}. Contrastive search requires "
+                    "dynamic cache, so set `cache_implementation='dynamic'` in the generation config."
+                )
 
         # contrastive_search main logic start:
         # contrastive search decoding consists of two steps: (1) candidate tokens recall; (2) candidate re-rank by
@@ -300,18 +278,14 @@ def _contrastive_search(
                 scores += (processed_logit_for_next_step,)
             if output_attentions:
                 decoder_attentions += (
-                    (outputs.decoder_attentions,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.attentions,)
+                    (outputs.decoder_attentions,) if model.config.is_encoder_decoder else (outputs.attentions,)
                 )
                 if model.config.is_encoder_decoder:
                     cross_attentions += (outputs.cross_attentions,)
 
             if output_hidden_states:
                 decoder_hidden_states += (
-                    (outputs.decoder_hidden_states,)
-                    if model.config.is_encoder_decoder
-                    else (outputs.hidden_states,)
+                    (outputs.decoder_hidden_states,) if model.config.is_encoder_decoder else (outputs.hidden_states,)
                 )
 
         # This is needed to properly delete outputs.logits which may be very large for this first iteration
@@ -323,8 +297,7 @@ def _contrastive_search(
             past = model_kwargs["past_key_values"]
             # If it is a static cache, modify it in-place layer after layer to save memory
             if isinstance(past, DynamicCache) or (
-                isinstance(past, EncoderDecoderCache)
-                and isinstance(past.self_attention_cache, DynamicCache)
+                isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
             ):
                 past.batch_repeat_interleave(top_k)
             else:
@@ -344,9 +317,7 @@ def _contrastive_search(
             all_outputs = []
             for i in range(top_k):
                 # compute the candidate tokens by the language model and collect their hidden_states
-                next_model_inputs = model.prepare_inputs_for_generation(
-                    top_k_ids[:, i].view(-1, 1), **model_kwargs
-                )
+                next_model_inputs = model.prepare_inputs_for_generation(top_k_ids[:, i].view(-1, 1), **model_kwargs)
 
                 outputs = model(
                     **next_model_inputs,
@@ -356,9 +327,7 @@ def _contrastive_search(
                 )
                 if isinstance(outputs["past_key_values"], DynamicCache) or (
                     isinstance(outputs["past_key_values"], EncoderDecoderCache)
-                    and isinstance(
-                        outputs["past_key_values"].self_attention_cache, DynamicCache
-                    )
+                    and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
                 ):
                     # Remove past K-V from output since we don't need to stack later
                     outputs["past_key_values"] = None
@@ -376,9 +345,7 @@ def _contrastive_search(
         else:
             # compute the candidate tokens by the language model and collect their hidden_states
             # assembles top_k_ids into batch of size k
-            next_model_inputs = model.prepare_inputs_for_generation(
-                top_k_ids.view(-1, 1), **model_kwargs
-            )
+            next_model_inputs = model.prepare_inputs_for_generation(top_k_ids.view(-1, 1), **model_kwargs)
 
             outputs = model(
                 **next_model_inputs,
@@ -424,9 +391,7 @@ def _contrastive_search(
         selected_idx = selected_idx.to("cpu")
 
         # This will be used instead of the previous inneficient torch.stack(torch.split())
-        augmented_idx = torch.tensor(
-            [x + i * top_k for i, x in enumerate(selected_idx)]
-        )
+        augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
 
         # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
         # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
@@ -434,15 +399,11 @@ def _contrastive_search(
         next_tokens = top_k_ids[range(len(top_k_ids)), selected_idx]
         next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), top_k))
         next_hidden = next_hidden[range(batch_size), selected_idx, :]
-        last_hidden_states = torch.cat(
-            [last_hidden_states, next_hidden.unsqueeze(1)], dim=1
-        )
+        last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
 
         next_decoder_hidden_states = ()
         for layer in full_hidden_states:
-            layer = torch.stack(torch.split(layer, top_k))[
-                range(batch_size), selected_idx, :
-            ]
+            layer = torch.stack(torch.split(layer, top_k))[range(batch_size), selected_idx, :]
             next_decoder_hidden_states += (layer,)
 
         # generate past_key_values cache of only the selected token
@@ -462,9 +423,7 @@ def _contrastive_search(
         else:
            next_past_key_values = None
             for possible_cache_name in ALL_CACHE_NAMES:
-                next_past_key_values = next_past_key_values or getattr(
-                    outputs, possible_cache_name, None
-                )
+                next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
             # Do it in-place layer per layer to save memory
             if isinstance(next_past_key_values, DynamicCache) or (
                 isinstance(next_past_key_values, EncoderDecoderCache)
@@ -482,9 +441,7 @@ def _contrastive_search(
 
                 next_past_key_values = tuple(new_key_values)
 
-        logit_for_next_step = torch.stack(torch.split(logits, top_k))[
-            range(batch_size), selected_idx, :
-        ]
+        logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
         logit_for_next_step = logit_for_next_step.to(input_ids.device)
 
         # Rebuilds the relevant parts of the model output for the selected token, for use in the next iteration
@@ -493,14 +450,10 @@ def _contrastive_search(
             next_step_decoder_attentions = ()
             if output_attentions:
                 for layer in outputs.cross_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_cross_attentions += (layer,)
                 for layer in outputs.decoder_attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_decoder_attentions += (layer,)
             outputs = Seq2SeqLMOutput(
                 past_key_values=next_past_key_values,
@@ -512,9 +465,7 @@ def _contrastive_search(
             next_step_attentions = ()
             if output_attentions:
                 for layer in outputs.attentions:
-                    layer = torch.stack(torch.split(layer, top_k, dim=0))[
-                        range(batch_size), selected_idx, ...
-                    ]
+                    layer = torch.stack(torch.split(layer, top_k, dim=0))[range(batch_size), selected_idx, ...]
                     next_step_attentions += (layer,)
             outputs = CausalLMOutputWithPast(
                 past_key_values=next_past_key_values,
@@ -534,9 +485,7 @@ def _contrastive_search(
 
         # finished sentences should have their next token be a padding token
         if has_eos_stopping_criteria:
-            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
-                1 - unfinished_sequences
-            )
+            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
 
         # update generated ids, model inputs, and length for next step
         input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
@@ -544,9 +493,7 @@ def _contrastive_search(
             streamer.put(next_tokens.cpu())
 
         # stop when each sentence is finished
-        unfinished_sequences = unfinished_sequences & ~stopping_criteria(
-            input_ids, scores
-        )
+        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
         this_peer_finished = unfinished_sequences.max() == 0
 
     if streamer is not None:
@@ -558,9 +505,7 @@ def _contrastive_search(
         if model_kwargs.get("past_key_values") is not None:
             if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
                 isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
-                and isinstance(
-                    model_kwargs["past_key_values"].self_attention_cache, DynamicCache
-                )
+                and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
             ):
                 model_kwargs["past_key_values"].crop(-1)
             else:
@@ -607,8 +552,7 @@ def generate(model, *args, **kwargs):
     """
     cache_implementation = kwargs.pop("cache_implementation", "dynamic_full")
     if cache_implementation != "dynamic_full" and (
-        "sliding_attention"
-        in getattr(model.config.get_text_config(), "layer_types", [])
+        "sliding_attention" in getattr(model.config.get_text_config(), "layer_types", [])
        or getattr(model.config.get_text_config(), "sliding_window", 0) > 0
     ):
        logger.warning_once(

