
[quantization][draft] Prefill-decode logic #570

Draft
stamalakhov wants to merge 1 commit into Samsung:main from stamalakhov:model_cache_br

Conversation

@stamalakhov
Contributor

@stamalakhov stamalakhov commented Mar 23, 2026

This PR implements prefill-decode logic for Llama-compatible LLMs.

--prefill_decode run on Maykeye/TinyLLama-v0
Namespace(model='Maykeye/TinyLLama-v0', device='cuda', dtype='float32', seed=42, trust_remote_code=False, hf_token=None, no_tqdm=False, no_GPTQ=True, no_PTQ=False, save_circle_to_folder='.', save_layers_to_folder='.', cache_dir='/mnt/storage/transformers_cache', nsamples_for_qcalibration=128, linear_weight_bits=4, gptq_mse='mse', max_seq_len=2048, calibrate_seq_len=2048, embedding_weight_bits=8, lm_head_weight_bits=4, eval_tasks=None, sensitivity_path=None, prefill_decode=True)
=== Config ===
Model            : Maykeye/TinyLLama-v0
Device           : cuda
DType            : float32

Loading FP model …
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Warning: tokenizer doesn't have pad_token. Prefill-decoding scheme may fail.

Calculating perplexities …
Token indices sequence length is longer than the specified maximum sequence length for this model (324381 > 2048). Running this sequence through the model will result in indexing errors
PPL:  99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 158/159 [00:08<00:00, 18.21it/s]

┌── Wikitext-2 test perplexity ─────────────
│ FP32 :  7584.31
└───────────────────────────────────────────
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:16<00:00,  7.89it/s]

┌── Wikitext-2 prefill_prefill original test perplexity ─────────────
│ FP32 :     7.27
└───────────────────────────────────────────
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:15<00:00,  8.11it/s]

┌── Wikitext-2 prefill_decode initial train perplexity ──
│ FP32 :     8.15
└───────────────────────────────────────────
Wrapping layers with PTQWrapper …
Calibrating PTQ observers…
[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [01:15<00:00,  1.69it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [01:53<00:00,  1.13it/s]

┌── Wikitext-2 prefill_prefill train calibration perplexity ─────────────
│ int16 :     8.40
└───────────────────────────────────────────
Computing calibration set for decode-model
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [00:41<00:00,  3.05it/s]
Wrapping layers with PTQWrapper …
Calibrating PTQ observers…
[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection.
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [01:09<00:00,  1.83it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [01:51<00:00,  1.15it/s]

┌── Wikitext-2 prefill_decode train calibration perplexity ─────────────
│ int16 :     8.40
└───────────────────────────────────────────

Calculating perplexities …
PPL:  99%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉ | 158/159 [00:50<00:00,  3.10it/s]

┌── Wikitext-2 test perplexity ─────────────
│ int16 :  7227.65
└───────────────────────────────────────────
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [01:54<00:00,  1.12it/s]

┌── Wikitext-2 prefill_decode quantized test perplexity ─────────────
│ FP32 :     7.76
└───────────────────────────────────────────
Saving prefill-model layer_0 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_0.q.circle
Saving prefill-model layer_1 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_1.q.circle
Saving prefill-model layer_2 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_2.q.circle
Saving prefill-model layer_3 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_3.q.circle
Saving prefill-model layer_4 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_4.q.circle
Saving prefill-model layer_5 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_5.q.circle
Saving prefill-model layer_6 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_6.q.circle
Saving prefill-model layer_7 to /mnt/storage/slow_repos/TICO/decoder_layer_prefill_7.q.circle
Saving decode-model layer_0 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_0.q.circle
Saving decode-model layer_1 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_1.q.circle
Saving decode-model layer_2 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_2.q.circle
Saving decode-model layer_3 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_3.q.circle
Saving decode-model layer_4 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_4.q.circle
Saving decode-model layer_5 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_5.q.circle
Saving decode-model layer_6 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_6.q.circle
Saving decode-model layer_7 to /mnt/storage/slow_repos/TICO/decoder_layer_decode_7.q.circle
Saving torch fake_quantized prefill-decode models to /mnt/storage/slow_repos/TICO/prefill_decode_models.q.pt
saving the whole model to /mnt/storage/slow_repos/TICO/model_prefill.q.circle
saving the whole model to /mnt/storage/slow_repos/TICO/model_decode.q.circle
evaluate prefill_decode pipeline of LLama_1B

Namespace(model='unsloth/Llama-3.2-1B-Instruct', device='cuda', dtype='float32', hf_token=None, trust_remote_code=False, cache_dir=None, fk_model_path='./prefill_decode_models_LLama_1B.q.pt', prompt='The capital of France is', eval_tasks=None, max_new_tokens=100)
=== Config ===
Model            : unsloth/Llama-3.2-1B-Instruct
Device           : cuda
DType            : float32
Prompt           : The capital of France is

Loading FP model …
Loading weights: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 146/146 [00:00<00:00, 615.13it/s]
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Original model prompt: The capital of France is Paris. The Eiffel Tower is located in Paris. The Louvre Museum is also located in Paris. The famous French artist Claude Monet was born in Paris. The famous French writer Victor Hugo was born in Paris. The famous French composer Claude Debussy was born in Paris. The famous French painter Pierre-Auguste Renoir was born in Paris. The famous French sculptor Auguste Rodin was born in Paris. The famous French architect Antoni Gaudí was born in Reims, but
Fake quantized model prompt: The capital of France is Paris. The capital of Germany is Berlin. The capital of Sweden is Stockholm. The capital of Norway is Oslo. The capital of Denmark is Copenhagen. The capital of Belgium is Brussels. The capital of the Netherlands is Amsterdam. The capital of Luxembourg is Brussels. The capital of Austria is Vienna. The capital of Switzerland is Bern. The capital of Italy is Rome. The capital of Spain is Madrid. The capital of Portugal is Lisbon. The capital of Greece is Athens. The capital of Turkey is Istanbul

Note: right padding performs poorly during generation; see https://huggingface.co/docs/transformers/llm_tutorial#wrong-padding-side for details (vllm-project/vllm#5236 also describes a padding issue). That is why left padding is used in this draft.
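The padding-side issue can be seen with plain token-id lists: with right padding, a short sequence ends on a pad token, so next-token logits would be read off a pad position; with left padding every row ends on a real token. A minimal illustration (hypothetical pad id 0; in transformers this corresponds to setting `tokenizer.padding_side = "left"`):

```python
PAD = 0  # hypothetical pad token id

def pad_batch(seqs, side="left"):
    """Pad variable-length token-id lists to a common length."""
    width = max(len(s) for s in seqs)
    if side == "left":
        return [[PAD] * (width - len(s)) + s for s in seqs]
    return [s + [PAD] * (width - len(s)) for s in seqs]

batch = [[5, 6], [7, 8, 9]]
left = pad_batch(batch, side="left")    # [[0, 5, 6], [7, 8, 9]]
right = pad_batch(batch, side="right")  # [[5, 6, 0], [7, 8, 9]]
# With left padding, generation can always continue from position -1;
# with right padding, the short row ends on PAD.
```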

Related: #586

TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com

@stamalakhov stamalakhov self-assigned this Mar 23, 2026
@stamalakhov stamalakhov force-pushed the model_cache_br branch 3 times, most recently from 23e4c61 to 192349a Compare March 23, 2026 14:31
@stamalakhov stamalakhov force-pushed the model_cache_br branch 4 times, most recently from e20b7aa to 317079e Compare March 31, 2026 07:35
@stamalakhov stamalakhov changed the title [quantization][draft] Ouput kv-tuples [quantization][draft] Prefill-decode logic Mar 31, 2026
@stamalakhov stamalakhov force-pushed the model_cache_br branch 8 times, most recently from 724ae23 to a50716b Compare April 3, 2026 05:55
@stamalakhov stamalakhov force-pushed the model_cache_br branch 8 times, most recently from 94f12ff to f79f531 Compare April 9, 2026 13:56
This PR outputs kv-tuples in case `use_cache` was set.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
# print(f"│ FP32 : {prefill_decode_ppl:8.2f}")
# print("└───────────────────────────────────────────")

def run_ptq(self, q_m, calib_inputs):
Contributor


I think the decode calibration data should ideally be generated from the float/reference prefill path, not from the already-quantized prefill model.

If decode calibration inputs are produced by the quantized prefill model, then the KV cache and related decode inputs already contain prefill quantization error. In that case, decode PTQ is calibrated on a distorted distribution rather than on the original float decode distribution.

A cleaner approach would be:

  1. run prefill calibration batches through the float/reference model,
  2. collect float decode inputs (especially past_key_values),
  3. use those inputs for decode PTQ calibration.

That way, decode calibration remains aligned with the original model behavior, and prefill quantization error does not leak into the decode calibration set.
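The three steps above could be sketched roughly as follows. This is a rough Python sketch, not code from the PR: `float_model` stands for the reference model, and all names are illustrative.

```python
def collect_decode_calibration(float_model, calib_batches):
    """Collect decode-stage calibration inputs from the float prefill path.

    Running prefill through the *float* model keeps past_key_values free of
    prefill quantization error, so decode PTQ is calibrated on the original
    float distribution.
    """
    decode_inputs = []
    for batch in calib_batches:
        # Step 1: run the prefill calibration batch through the float model.
        out = float_model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            use_cache=True,
        )
        # Step 2: keep the float decode inputs, especially past_key_values.
        # (With tensors this slicing would be input_ids[:, -1:]; plain lists here.)
        decode_inputs.append({
            "past_key_values": out["past_key_values"],
            "input_ids": [seq[-1:] for seq in batch["input_ids"]],
        })
    # Step 3: feed decode_inputs to decode PTQ calibration.
    return decode_inputs
```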

Contributor

@mhs4670go Apr 12, 2026


The reason to prefer float-based prefill for decode calibration is that PTQ fundamentally tries to approximate the float model behavior. If decode calibration inputs are generated from a quantized prefill model, the input distribution is already distorted, and the optimization target effectively changes. This mixes prefill quantization error into decode calibration and makes it harder to isolate and properly optimize decode quantization.

That being said, it would be useful to compare both approaches, so I’m opening this issue to track it.

Contributor Author


Due to #618, this PR should be reworked. As far as I understand, we don't need two torch models — just a single prefill-decode torch model, calibrated with use_cache enabled. So I'll be back with a reworked PR.

return inputs


def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
Contributor


IMHO, the current evaluation API should be simplified.

Right now, we have multiple similarly named functions such as:

  • evaluate_ppl_of_model_on_dataset
  • evaluate_ppl_of_ref_prefill_model_on_dataset
  • evaluate_ppl_of_prefill_ref_model_on_dataset
  • evaluate_ppl_of_prefill_decode_model_on_dataset

This makes the API hard to follow because the function names encode multiple dimensions at once:

  • metric (ppl)
  • model role (ref / quant)
  • execution path (prefill / prefill_decode)

A cleaner design would be to expose a single public API such as:

evaluate_ppl(
    model,
    dataset,
    *,
    model_role="ref" | "quant",
    execution_mode="prefill" | "prefill_decode",
)

If it's hard to refactor them at once, I'll do this later when the code is merged.
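One way to realize that single entry point is a dispatch table keyed on the two dimensions. A minimal sketch, in which the `_ppl_*` placeholders stand in for the existing `evaluate_ppl_of_*_on_dataset` implementations:

```python
# Placeholder implementations; the real ones would run the model on the dataset.
def _ppl_ref_prefill(model, dataset): ...
def _ppl_ref_prefill_decode(model, dataset): ...
def _ppl_quant_prefill(model, dataset): ...
def _ppl_quant_prefill_decode(model, dataset): ...

def evaluate_ppl(model, dataset, *, model_role="ref", execution_mode="prefill"):
    """Single public entry point; dispatches on (model_role, execution_mode)."""
    impls = {
        ("ref", "prefill"): _ppl_ref_prefill,
        ("ref", "prefill_decode"): _ppl_ref_prefill_decode,
        ("quant", "prefill"): _ppl_quant_prefill,
        ("quant", "prefill_decode"): _ppl_quant_prefill_decode,
    }
    key = (model_role, execution_mode)
    if key not in impls:
        raise ValueError(f"unsupported combination: {key}")
    return impls[key](model, dataset)
```

Unsupported combinations fail loudly instead of silently picking a default, which also documents the valid (role, mode) pairs in one place.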

Contributor Author


@mhs4670go Sorry for the lack of details. We only need evaluate_ppl_of_prefill_decode_model_on_dataset; evaluate_ppl_of_ref_prefill_model_on_dataset and evaluate_ppl_of_prefill_ref_model_on_dataset will be removed. Thank you for your review.
