[quantization][draft] Prefill-decode logic #570
stamalakhov wants to merge 1 commit into Samsung:main from
Conversation
This PR outputs kv-tuples when `use_cache` is set.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
```python
def run_ptq(self, q_m, calib_inputs):
```
I think the decode calibration data should ideally be generated from the float/reference prefill path, not from the already-quantized prefill model.
If decode calibration inputs are produced by the quantized prefill model, then the KV cache and related decode inputs already contain prefill quantization error. In that case, decode PTQ is calibrated on a distorted distribution rather than on the original float decode distribution.
A cleaner approach would be:
- run prefill calibration batches through the float/reference model,
- collect float decode inputs (especially past_key_values),
- use those inputs for decode PTQ calibration.
That way, decode calibration remains aligned with the original model behavior, and prefill quantization error does not leak into the decode calibration set.
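The suggested flow can be sketched in plain Python (all names here are illustrative stand-ins, not the PR's actual API): decode calibration inputs are collected from the float prefill path, so quantized-prefill error never leaks into the decode calibration set.

```python
# Hypothetical sketch: collect decode calibration inputs from the FLOAT
# prefill model rather than the already-quantized one.

def collect_decode_calib_inputs(float_prefill, calib_batches):
    """Run calibration batches through the float prefill model and keep
    the float KV cache plus next-token inputs for decode PTQ calibration."""
    decode_calib = []
    for batch in calib_batches:
        out = float_prefill(batch)  # float reference, NOT the quantized model
        decode_calib.append({
            "past_key_values": out["past_key_values"],  # float KV cache
            "input_ids": out["next_token_ids"],
        })
    return decode_calib

# Toy stand-in for the float prefill model, just to make the flow concrete:
# it emits one fake (key, value) pair per input token.
def fake_float_prefill(batch):
    return {
        "past_key_values": [("k", "v")] * len(batch),
        "next_token_ids": [batch[-1] + 1],
    }

decode_inputs = collect_decode_calib_inputs(fake_float_prefill, [[1, 2, 3]])
# Decode PTQ would then be calibrated on `decode_inputs` instead of on
# outputs produced by the quantized prefill model.
```

The key point is only the data flow: the quantized prefill model never appears in the decode calibration loop.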
The reason to prefer float-based prefill for decode calibration is that PTQ fundamentally tries to approximate the float model behavior. If decode calibration inputs are generated from a quantized prefill model, the input distribution is already distorted, and the optimization target effectively changes. This mixes prefill quantization error into decode calibration and makes it harder to isolate and properly optimize decode quantization.
That being said, it would be useful to compare both approaches, so I’m opening this issue to track it.
Due to #618, this PR should be reworked. As far as I understand, we don't need two torch models — just a single prefill-decode torch model, calibrated with `use_cache` enabled. So I'll be back with a reworked PR.
```python
def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
```
IMHO, the current evaluation API should be simplified.
Right now, we have multiple similarly named functions such as:
- `evaluate_ppl_of_model_on_dataset`
- `evaluate_ppl_of_ref_prefill_model_on_dataset`
- `evaluate_ppl_of_prefill_ref_model_on_dataset`
- `evaluate_ppl_of_prefill_decode_model_on_dataset`
This makes the API hard to follow because the function names encode multiple dimensions at once:
- metric (`ppl`)
- model role (`ref`/`quant`)
- execution path (`prefill`/`prefill_decode`)
A cleaner design would be to expose a single public API such as:

```python
evaluate_ppl(
    model,
    dataset,
    *,
    model_role="ref" | "quant",
    execution_mode="prefill" | "prefill_decode",
)
```

If it's hard to refactor them all at once, I'll do this later when the code is merged.
@mhs4670go Sorry for the lack of details. We only need `evaluate_ppl_of_prefill_decode_model_on_dataset`; `evaluate_ppl_of_ref_prefill_model_on_dataset` and `evaluate_ppl_of_prefill_ref_model_on_dataset` will be removed. Thank you for your review.
This PR implements prefill_decode logic for Llama-compatible LLMs.
- `--prefill_decode` run on Maykeye/TinyLLama-v0
- evaluate prefill_decode pipeline of LLama_1B
Note: right padding performs poorly; see https://huggingface.co/docs/transformers/llm_tutorial#wrong-padding-side for details (vllm-project/vllm#5236 also describes a padding issue), which is why left padding is used in this draft.

Related: #586
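A toy illustration (pure Python, no tokenizer; not the PR's code) of why left padding is preferred for generation: with right padding, shorter sequences in a batch end in PAD tokens, so the model's next-token prediction is conditioned on padding, while left padding keeps every sequence ending at a real token.

```python
# Compare left vs right padding of a ragged batch of token ids.
PAD = 0

def pad_batch(seqs, side="left"):
    """Pad every sequence to the batch max length on the given side."""
    width = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        pads = [PAD] * (width - len(s))
        padded.append(pads + s if side == "left" else s + pads)
    return padded

batch = [[5, 6, 7], [8]]
left = pad_batch(batch, side="left")    # [[5, 6, 7], [0, 0, 8]]
right = pad_batch(batch, side="right")  # [[5, 6, 7], [8, 0, 0]]

# The last position is what autoregressive decoding conditions on next:
assert [s[-1] for s in left] == [7, 8]     # real tokens for every row
assert [s[-1] for s in right] == [7, PAD]  # the short row ends in padding
```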
TICO-DCO-1.0-Signed-off-by: s.malakhov s.malakhov@partner.samsung.com