From 69537f2a4f3286f827088e40cfbf503d70790f21 Mon Sep 17 00:00:00 2001
From: "s.malakhov"
Date: Mon, 6 Apr 2026 11:34:35 +0300
Subject: [PATCH] [quantization] Move prefill logic

This PR moves the prefill logic into `PrefillQModelProcessor` to make
the script ready for a prefill-decode pipeline.

TICO-DCO-1.0-Signed-off-by: s.malakhov
---
 .../quantize_full_qmodel_with_gptq.py | 283 ++++++++++++++----
 1 file changed, 222 insertions(+), 61 deletions(-)

diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
index 61986d32..4a659bda 100644
--- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
+++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
@@ -163,7 +163,7 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     # -------------------------------------------------------------------------
     # Single-pass activation calibration
     # -------------------------------------------------------------------------
-    print("Calibrating PTQ obeservers…")
+    print("Calibrating PTQ observers…")
 
     # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
@@ -190,7 +190,7 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     return q_m
 
 
-def evaluate(q_m, tokenizer, dataset_test, args):
+def evaluate(q_m, tokenizer, dataset_test, args, quantized=False):
     # -------------------------------------------------------------------------
     # Evaluate perplexity on Wikitext-2
     # -------------------------------------------------------------------------
@@ -200,18 +200,217 @@ def evaluate(q_m, tokenizer, dataset_test, args):
         q_m, enc, args.device, max_length=args.max_seq_len, stride=args.max_seq_len
     )
 
+    ppl_info_str = "int16" if quantized else "FP32"
     print("\n┌── Wikitext-2 test perplexity ─────────────")
-    print(f"│ int16 : {ppl_uint8:8.2f}")
+    print(f"│ {ppl_info_str} : {ppl_uint8:8.2f}")
     print("└───────────────────────────────────────────")
 
     if args.eval_tasks is not None:
         results = evaluate_llm_on_tasks(
             q_m, tokenizer, args.eval_tasks, max_length=args.max_seq_len
         )
-        print("Quantized RESULTS ARE:")
+        acc_info_str = "Quantized" if quantized else "Original"
+        print(f"{acc_info_str} RESULTS ARE:")
         print(make_table(results))
 
 
+class QModelProcessor:
+    """Processor handling the GPTQ and PTQ quantization steps for a model.
+
+    Attributes:
+        model: The underlying FP model.
+        tokenizer: Tokenizer associated with the model.
+        device: Torch device derived from args.
+        args: Parsed command-line arguments.
+    """
+
+    def __init__(self, model, tokenizer, args):
+        self.model = model
+        self.tokenizer = tokenizer
+        self.device = torch.device(args.device)
+        self.args = args
+
+    def get_tokenized_inputs(self, dataset, shuffle=True):
+        """Generate tokenized inputs for calibration.
+
+        Concatenates all text from the provided dataset, tokenizes it using the stored
+        tokenizer, and extracts ``nsamples`` slices of length ``max_position_embeddings``.
+        If ``shuffle`` is True, slices are selected randomly; otherwise they are taken
+        sequentially with a fixed stride.
+
+        Args:
+            dataset: The dataset object containing a ``text`` field.
+            shuffle (bool): Whether to randomly sample slices.
+
+        Returns:
+            List[torch.Tensor]: A list of tokenized input tensors.
+        """
+        text = " ".join(dataset["text"])
+        ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
+        tokenized_inputs = []
+        nsamples = self.args.nsamples_for_qcalibration
+        seqlen = self.model.config.max_position_embeddings
+        if shuffle:
+            random.seed(self.args.seed)
+        else:
+            stride = (ids.shape[1] - seqlen - 1) // seqlen
+        for index in range(nsamples):
+            if shuffle:
+                i = random.randint(0, ids.shape[1] - seqlen - 1)
+            else:
+                i = index * stride
+            j = i + seqlen
+            inp = ids[:, i:j]
+            tokenized_inputs.append(inp.cpu())
+        return tokenized_inputs
+
+    def run_gptq(self, calib_inputs):
+        """Run GPTQ weight‑only quantization on the model.
+
+        Args:
+            calib_inputs (List[torch.Tensor]): Calibration inputs used for
+                GPTQ weight quantization.
+
+        Returns:
+            torch.nn.Module: The quantized model with INT‑weight tensors.
+        """
+        print("Applying GPTQ …")
+
+        sens = None
+        if self.args.gptq_mse is not None and self.args.gptq_mse == "smse":
+            if self.args.sensitivity_path is not None:
+                sens = torch.load(self.args.sensitivity_path)
+            else:
+                calibrator = SensitivityCalibrator(self.model, calib_inputs)
+                sens = calibrator.compute_sensitivity_info()
+
+        gptq_config = GPTQConfig(
+            weight_bits=self.args.linear_weight_bits,
+            perchannel=True,
+            mse=self.args.gptq_mse,
+            sensitivity=sens,
+        )
+        q_m = prepare(self.model, gptq_config, inplace=True)
+        with torch.no_grad():
+            for inp in calib_inputs:
+                q_m(inp.to(self.device))
+
+        q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
+        return q_m
+
+    def evaluate_original(self, dataset_test):
+        """Evaluate the original (FP) model on the test dataset.
+
+        Args:
+            dataset_test: The test split of the dataset.
+
+        Returns:
+            Any: The result of the ``evaluate`` helper (typically prints metrics).
+        """
+        return evaluate(
+            self.model,
+            self.tokenizer,
+            dataset_test,
+            self.args,
+            quantized=False,
+        )
+
+    def evaluate_quantized(self, model, dataset_test):
+        """Placeholder for evaluating the quantized model.
+
+        This method should be overridden in subclasses to implement evaluation of the
+        quantized model.
+
+        Args:
+            model: The quantized model instance.
+            dataset_test: The test dataset.
+
+        Raises:
+            NotImplementedError: Indicates the method is not implemented in the base class.
+        """
+        raise NotImplementedError
+
+    def save_quantized(self, model, calib_inputs):
+        """Placeholder for saving the quantized model.
+
+        Subclasses should implement saving logic for the quantized model or its layers.
+
+        Args:
+            model: The quantized model instance.
+            calib_inputs: Calibration inputs used for possible model saving.
+
+        Raises:
+            NotImplementedError: Indicates the method is not implemented in the base class.
+        """
+        raise NotImplementedError
+
+
+class PrefillQModelProcessor(QModelProcessor):
+    """
+    PrefillQModelProcessor extends QModelProcessor for models that operate in
+    a simple prefill mode without KV cache.
+
+    It provides implementations for PTQ quantization, evaluation of the
+    quantized model, and optional saving of layers or the whole model.
+    """
+
+    def __init__(self, model, tokenizer, args):
+        super().__init__(model, tokenizer, args)
+
+    def run_ptq(self, q_m, calib_inputs):
+        """Run PTQ activation quantization on the model.
+
+        Args:
+            q_m: The model (potentially already weight‑quantized) to wrap with PTQ.
+            calib_inputs: Calibration inputs for activation observers.
+
+        Returns:
+            torch.nn.Module: The PTQ‑wrapped and calibrated model.
+        """
+        return quantize_using_PTQ(q_m, calib_inputs, self.args)
+
+    def evaluate_quantized(self, model, dataset_test):
+        """Evaluate the quantized model on the test dataset.
+
+        Args:
+            model: The quantized model instance.
+            dataset_test: The test split of the dataset.
+
+        Returns:
+            None. The function prints evaluation metrics.
+        """
+        evaluate(model, self.tokenizer, dataset_test, self.args, quantized=True)
+
+    def save_quantized(self, model, calib_inputs):
+        """Save the quantized model and its components.
+
+        This method respects the command‑line arguments ``save_layers_to_folder``
+        and ``save_circle_to_folder``. If a folder is provided, the corresponding
+        helper is invoked to persist individual layers and/or the whole model.
+
+        Args:
+            model: The quantized model instance.
+            calib_inputs: Calibration inputs required when saving the full model
+                (used to construct a dummy batch for conversion to Circle format).
+        """
+        if self.args.save_layers_to_folder is not None:
+            save_layers_to(
+                model, self.args.max_seq_len, self.args.save_layers_to_folder
+            )
+
+        if self.args.save_circle_to_folder is not None:
+            calib_inputs = list(
+                torch.stack(calib_inputs).reshape(-1, 1, self.args.max_seq_len)
+            )
+            save_model_to(model, calib_inputs, self.args.save_circle_to_folder)
+
+
+def get_qmodel_processor(model, tokenizer, args):
+    # TODO add more processors
+
+    return PrefillQModelProcessor(model, tokenizer, args)
+
+
 def main():
     parser = argparse.ArgumentParser(
         description="GPTQ+PTQ pipeline (weight-only + activation)"
@@ -388,65 +587,27 @@ def main():
         DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT, cache_dir=args.cache_dir
     )
 
-    print("\nCalculating original perplexities …")
-    enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
-    ppl_fp32 = perplexity(
-        model, enc, device, max_length=args.max_seq_len, stride=args.max_seq_len
-    )
-
-    print("\n┌── Wikitext-2 test perplexity ─────────────")
-    print(f"│ FP32 : {ppl_fp32:8.2f}")
-    print("└───────────────────────────────────────────")
+    # -------------------------------------------------------------------------
+    # Create a processor for the model
+    # -------------------------------------------------------------------------
+    qmodel_processor = get_qmodel_processor(model, tokenizer, args)
 
-    if args.eval_tasks is not None:
-        results = evaluate_llm_on_tasks(
-            model, tokenizer, args.eval_tasks, max_length=args.max_seq_len
-        )
-        print("Original RESULTS ARE:")
-        print(make_table(results))
+    # -------------------------------------------------------------------------
+    # Compute original (FP) metrics as a baseline for degradation
+    # -------------------------------------------------------------------------
+    qmodel_processor.evaluate_original(dataset_test)
 
     # -------------------------------------------------------------------------
     # Prepare calibration dataset
     # -------------------------------------------------------------------------
     dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
-    calib_txt = " ".join(dataset_train["text"])
-    train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
-    calib_inputs = []
-    nsamples = args.nsamples_for_qcalibration
-    seqlen = model.config.max_position_embeddings
-    random.seed(args.seed)
-    for _ in range(nsamples):
-        i = random.randint(0, train_ids.shape[1] - seqlen - 1)
-        j = i + seqlen
-        inp = train_ids[:, i:j]
-        calib_inputs.append(inp.cpu())
+    calib_inputs = qmodel_processor.get_tokenized_inputs(dataset_train, shuffle=True)
 
     # -------------------------------------------------------------------------
     # Run GPTQ (weight-only) pass
     # -------------------------------------------------------------------------
     if not args.no_GPTQ:
-        print("Applying GPTQ …")
-
-        sens = None
-        if args.gptq_mse is not None and args.gptq_mse == "smse":
-            if args.sensitivity_path is not None:
-                sens = torch.load(args.sensitivity_path)
-            else:
-                calibrator = SensitivityCalibrator(model, calib_inputs)
-                sens = calibrator.compute_sensitivity_info()
-
-        gptq_config = GPTQConfig(
-            weight_bits=args.linear_weight_bits,
-            perchannel=True,
-            mse=args.gptq_mse,
-            sensitivity=sens,
-        )
-        q_m = prepare(model, gptq_config, inplace=True)
-        with torch.no_grad():
-            for inp in calib_inputs:
-                q_m(inp.to(args.device))
-
-        q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
+        q_m = qmodel_processor.run_gptq(calib_inputs)
     else:
         q_m = model
 
@@ -454,17 +615,17 @@
     # Wrap every layer with PTQWrapper
     # -------------------------------------------------------------------------
     if not args.no_PTQ:
-        q_m = quantize_using_PTQ(q_m, calib_inputs, args)
+        q_m = qmodel_processor.run_ptq(q_m, calib_inputs)
 
-    # after PTQ quantizer only fixed-length input sequences are valid
-    evaluate(q_m, tokenizer, dataset_test, args)
-
-    if args.save_layers_to_folder is not None:
-        save_layers_to(q_m, args.max_seq_len, args.save_layers_to_folder)
+    # -------------------------------------------------------------------------
+    # Compute quantized metrics to estimate degradation vs. the FP baseline
+    # -------------------------------------------------------------------------
+    qmodel_processor.evaluate_quantized(q_m, dataset_test)
 
-    if args.save_circle_to_folder is not None:
-        calib_inputs = list(torch.stack(calib_inputs).reshape(-1, 1, args.max_seq_len))
-        save_model_to(q_m, calib_inputs, args.save_circle_to_folder)
+    # -------------------------------------------------------------------------
+    # Save layers and model
+    # -------------------------------------------------------------------------
+    qmodel_processor.save_quantized(q_m, calib_inputs)
 
 
 if __name__ == "__main__":
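
Usage sketch of the refactored flow, mirroring `main()` in this patch. This is
illustrative only: it assumes `model`, `tokenizer`, `args`, `dataset_train`,
and `dataset_test` are already set up as in the script, and omits the
`no_GPTQ`/`no_PTQ` branches.

    # Pick a processor (currently always the prefill variant).
    processor = get_qmodel_processor(model, tokenizer, args)
    processor.evaluate_original(dataset_test)           # FP32 baseline metrics
    calib_inputs = processor.get_tokenized_inputs(dataset_train, shuffle=True)
    q_m = processor.run_gptq(calib_inputs)              # weight-only GPTQ pass
    q_m = processor.run_ptq(q_m, calib_inputs)          # activation PTQ pass
    processor.evaluate_quantized(q_m, dataset_test)     # quantized metrics
    processor.save_quantized(q_m, calib_inputs)         # optional layer/Circle export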