Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 221 additions & 61 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
# -------------------------------------------------------------------------
# Single-pass activation calibration
# -------------------------------------------------------------------------
print("Calibrating PTQ obeservers…")
print("Calibrating PTQ observers…")

# Overwrite weight observers with GPTQ statistics
if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
Expand All @@ -190,7 +190,7 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
return q_m


def evaluate(q_m, tokenizer, dataset_test, args):
def evaluate(q_m, tokenizer, dataset_test, args, quantized=False):
# -------------------------------------------------------------------------
# Evaluate perplexity on Wikitext-2
# -------------------------------------------------------------------------
Expand All @@ -200,18 +200,216 @@ def evaluate(q_m, tokenizer, dataset_test, args):
q_m, enc, args.device, max_length=args.max_seq_len, stride=args.max_seq_len
)

ppl_info_str = "int16" if quantized is True else "FP32"
print("\n┌── Wikitext-2 test perplexity ─────────────")
print(f"│ int16 : {ppl_uint8:8.2f}")
print(f"│ {ppl_info_str} : {ppl_uint8:8.2f}")
print("└───────────────────────────────────────────")

if args.eval_tasks is not None:
results = evaluate_llm_on_tasks(
q_m, tokenizer, args.eval_tasks, max_length=args.max_seq_len
)
print("Quantized RESULTS ARE:")
acc_info_str = "Quantized" if quantized is True else "Original"
print(f"{acc_info_str} RESULTS ARE:")
print(make_table(results))


class QModelProcessor:
    """Processor for quantization model handling GPTQ and PTQ steps.

    Attributes:
        model: The underlying FP model.
        tokenizer: Tokenizer associated with the model.
        device: Torch device derived from args.
        args: Parsed command-line arguments.
    """

    def __init__(self, model, tokenizer, args):
        self.model = model
        self.tokenizer = tokenizer
        self.device = torch.device(args.device)
        self.args = args

    def get_tokenized_inputs(self, dataset, shuffle=True):
        """Generate tokenized inputs for calibration.

        Concatenates all text from the provided dataset, tokenizes it using the stored
        tokenizer, and extracts ``nsamples`` slices of length ``max_position_embeddings``.
        If ``shuffle`` is True, slices are selected randomly; otherwise they are taken
        sequentially with a fixed stride.

        Args:
            dataset: The dataset object containing a ``text`` field.
            shuffle (bool): Whether to randomly sample slices.

        Returns:
            List[torch.Tensor]: A list of tokenized input tensors.
        """
        text = " ".join(dataset["text"])
        ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device)
        tokenized_inputs = []
        nsamples = self.args.nsamples_for_qcalibration
        seqlen = self.model.config.max_position_embeddings
        if shuffle is True:
            random.seed(self.args.seed)
        else:
            # Clamp to >= 1: for corpora barely longer than seqlen the raw
            # stride would be 0 (every slice identical) or negative
            # (invalid/overlapping windows walking backwards).
            stride = max((ids.shape[1] - seqlen - 1) // seqlen, 1)
        for index in range(nsamples):
            if shuffle is True:
                i = random.randint(0, ids.shape[1] - seqlen - 1)
            else:
                i = index * stride
            j = i + seqlen
            inp = ids[:, i:j]
            # Keep calibration samples on CPU to avoid holding all slices on the device.
            tokenized_inputs.append(inp.cpu())
        return tokenized_inputs

    def run_gptq(self, calib_inputs):
        """Run GPTQ weight-only quantization on the model.

        Args:
            calib_inputs (List[torch.Tensor]): Calibration inputs used for
                GPTQ weight quantization.

        Returns:
            torch.nn.Module: The quantized model with INT-weight tensors.
        """
        print("Applying GPTQ …")

        # Sensitivity info is only needed for the "smse" MSE variant; load it
        # from disk when a path is given, otherwise compute it on the fly.
        sens = None
        if self.args.gptq_mse is not None and self.args.gptq_mse == "smse":
            if self.args.sensitivity_path is not None:
                sens = torch.load(self.args.sensitivity_path)
            else:
                calibrator = SensitivityCalibrator(self.model, calib_inputs)
                sens = calibrator.compute_sensitivity_info()

        gptq_config = GPTQConfig(
            weight_bits=self.args.linear_weight_bits,
            perchannel=True,
            mse=self.args.gptq_mse,
            sensitivity=sens,
        )
        q_m = prepare(self.model, gptq_config, inplace=True)
        with torch.no_grad():
            for inp in calib_inputs:
                q_m(inp.to(self.device))

        q_m = convert(q_m, inplace=True)  # materialize INT-weight tensors
        return q_m

    def evaluate_original(self, dataset_test):
        """Evaluate the original (FP) model on the test dataset.

        Args:
            dataset_test: The test split of the dataset.

        Returns:
            Any: The result of the ``evaluate`` helper (typically prints metrics).
        """
        return evaluate(
            self.model,
            self.tokenizer,
            dataset_test,
            self.args,
            quantized=False,
        )

    def evaluate_quantized(self, model, dataset_test):
        """Placeholder for evaluating the quantized model.

        This method should be overridden in subclasses to implement evaluation of
        the quantized model. The signature matches the subclass overrides and the
        call site in ``main`` (``processor.evaluate_quantized(q_m, dataset_test)``).

        Args:
            model: The quantized model instance to evaluate.
            dataset_test: The test dataset.

        Raises:
            NotImplementedError: Indicates the method is not implemented in the base class.
        """
        raise NotImplementedError

    def save_quantized(self, model, calib_inputs):
        """Placeholder for saving the quantized model.

        Subclasses should implement saving logic for the quantized model or its layers.

        Args:
            model: The quantized model instance.
            calib_inputs: Calibration inputs used for possible model saving.

        Raises:
            NotImplementedError: Indicates the method is not implemented in the base class.
        """
        raise NotImplementedError


class PrefillQModelProcessor(QModelProcessor):
    """
    PrefillQModelProcessor extends QModelProcessor for models that operate in
    a simple prefill mode without KV cache.

    It provides implementations for PTQ quantization, evaluation of the
    quantized model, and optional saving of layers or the whole model.
    """

    # NOTE: no __init__ override — the base-class constructor is sufficient.

    def run_ptq(self, q_m, calib_inputs):
        """Run PTQ activation quantization on the model.

        Args:
            q_m: The model (potentially already weight-quantized) to wrap with PTQ.
            calib_inputs: Calibration inputs for activation observers.

        Returns:
            torch.nn.Module: The PTQ-wrapped and calibrated model.
        """
        return quantize_using_PTQ(q_m, calib_inputs, self.args)

    def evaluate_quantized(self, model, dataset_test):
        """Evaluate the quantized model on the test dataset.

        Args:
            model: The quantized model instance.
            dataset_test: The test split of the dataset.

        Returns:
            None. The function prints evaluation metrics.
        """
        evaluate(model, self.tokenizer, dataset_test, self.args, quantized=True)

    def save_quantized(self, model, calib_inputs):
        """Save the quantized model and its components.

        This method respects the command-line arguments ``save_layers_to_folder``
        and ``save_circle_to_folder``. If a folder is provided, the corresponding
        helper is invoked to persist either individual layers and/or the whole model.

        Args:
            model: The quantized model instance.
            calib_inputs: Calibration inputs required when saving the full model
                (used to construct a dummy batch for conversion to Circle format).
        """
        if self.args.save_layers_to_folder is not None:
            save_layers_to(
                model, self.args.max_seq_len, self.args.save_layers_to_folder
            )

        if self.args.save_circle_to_folder is not None:
            # Re-batch the calibration slices into (1, max_seq_len) samples
            # expected by the Circle export helper.
            calib_inputs = list(
                torch.stack(calib_inputs).reshape(-1, 1, self.args.max_seq_len)
            )
            save_model_to(model, calib_inputs, self.args.save_circle_to_folder)


def get_qmodel_processor(model, tokenizer, args):
    """Build the quantization-model processor for the given model and args.

    Args:
        model: The FP model to be quantized.
        tokenizer: Tokenizer paired with the model.
        args: Parsed command-line arguments.

    Returns:
        QModelProcessor: A concrete processor instance.
    """
    # TODO add more processors
    processor = PrefillQModelProcessor(model, tokenizer, args)
    return processor


def main():
parser = argparse.ArgumentParser(
description="GPTQ+PTQ pipeline (weight-only + activation)"
Expand Down Expand Up @@ -388,83 +586,45 @@ def main():
DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT, cache_dir=args.cache_dir
)

print("\nCalculating original perplexities …")
enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
ppl_fp32 = perplexity(
model, enc, device, max_length=args.max_seq_len, stride=args.max_seq_len
)

print("\n┌── Wikitext-2 test perplexity ─────────────")
print(f"│ FP32 : {ppl_fp32:8.2f}")
print("└───────────────────────────────────────────")
# -------------------------------------------------------------------------
# Create a processor for the model
# -------------------------------------------------------------------------
qmodel_processor = get_qmodel_processor(model, tokenizer, args)

if args.eval_tasks is not None:
results = evaluate_llm_on_tasks(
model, tokenizer, args.eval_tasks, max_length=args.max_seq_len
)
print("Original RESULTS ARE:")
print(make_table(results))
# -------------------------------------------------------------------------
# Compute original metrics to estimate metrics degradation
# -------------------------------------------------------------------------
qmodel_processor.evaluate_original(dataset_test)

# -------------------------------------------------------------------------
# Prepare calibration dataset
# -------------------------------------------------------------------------
dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT)
calib_txt = " ".join(dataset_train["text"])
train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device)
calib_inputs = []
nsamples = args.nsamples_for_qcalibration
seqlen = model.config.max_position_embeddings
random.seed(args.seed)
for _ in range(nsamples):
i = random.randint(0, train_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = train_ids[:, i:j]
calib_inputs.append(inp.cpu())
calib_inputs = qmodel_processor.get_tokenized_inputs(dataset_train, shuffle=True)

# -------------------------------------------------------------------------
# Run GPTQ (weight-only) pass
# -------------------------------------------------------------------------
if not args.no_GPTQ:
print("Applying GPTQ …")

sens = None
if args.gptq_mse is not None and args.gptq_mse == "smse":
if args.sensitivity_path is not None:
sens = torch.load(args.sensitivity_path)
else:
calibrator = SensitivityCalibrator(model, calib_inputs)
sens = calibrator.compute_sensitivity_info()

gptq_config = GPTQConfig(
weight_bits=args.linear_weight_bits,
perchannel=True,
mse=args.gptq_mse,
sensitivity=sens,
)
q_m = prepare(model, gptq_config, inplace=True)
with torch.no_grad():
for inp in calib_inputs:
q_m(inp.to(args.device))

q_m = convert(q_m, inplace=True) # materialize INT-weight tensors
q_m = qmodel_processor.run_gptq(calib_inputs)
else:
q_m = model

# -------------------------------------------------------------------------
# Wrap every layer with PTQWrapper
# -------------------------------------------------------------------------
if not args.no_PTQ:
q_m = quantize_using_PTQ(q_m, calib_inputs, args)
q_m = qmodel_processor.run_ptq(q_m, calib_inputs)

# after PTQ quantizer only fixed-length input sequences are valid
evaluate(q_m, tokenizer, dataset_test, args)

if args.save_layers_to_folder is not None:
save_layers_to(q_m, args.max_seq_len, args.save_layers_to_folder)
# -------------------------------------------------------------------------
# Compute quantized model metrics to estimate metrics degradation
# -------------------------------------------------------------------------
qmodel_processor.evaluate_quantized(q_m, dataset_test)

if args.save_circle_to_folder is not None:
calib_inputs = list(torch.stack(calib_inputs).reshape(-1, 1, args.max_seq_len))
save_model_to(q_m, calib_inputs, args.save_circle_to_folder)
# -------------------------------------------------------------------------
# Save layers and model
# -------------------------------------------------------------------------
qmodel_processor.save_quantized(q_m, calib_inputs)


if __name__ == "__main__":
Expand Down
Loading