From 4f160b961251cd4afade2d0faa86808a6ef4cc04 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Thu, 16 Apr 2026 08:06:27 +0300 Subject: [PATCH 1/2] [quantization] Output kv-tuples This PR outputs kv-tuples when `use_cache` is set. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../wrapq/wrappers/llama/test_quant_model.py | 65 +++ .../examples/evaluate_fake_quantized_model.py | 431 ++++++++++++++++++ .../examples/static_llama_layer_runtime.py | 3 +- .../wrapq/wrappers/llama/quant_model.py | 84 +++- .../llama/quant_model_for_causal_lm.py | 2 + 5 files changed, 560 insertions(+), 25 deletions(-) create mode 100644 tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py diff --git a/test/quantization/wrapq/wrappers/llama/test_quant_model.py b/test/quantization/wrapq/wrappers/llama/test_quant_model.py index 06437878..f1068011 100644 --- a/test/quantization/wrapq/wrappers/llama/test_quant_model.py +++ b/test/quantization/wrapq/wrappers/llama/test_quant_model.py @@ -104,3 +104,68 @@ def test_forward_diff(self): self.assertGreater(diff, 0.0) self.assertLess(diff, 0.4) self.assertEqual(fp_out.shape, q_out.shape) + + +@unittest.skipUnless(has_transformers_for("llama"), skip_msg) +class TestQuantLlamaModelWithCache(unittest.TestCase): + seq_len: int + vocab_size: int + hid_layers: int + fp_model: torch.nn.Module + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + + from transformers.models.llama.configuration_llama import LlamaConfig + from transformers.models.llama.modeling_llama import LlamaModel + + cls.seq_len = 16 + cls.vocab_size = 10000 + cls.hid_layers = 3 + cfg = LlamaConfig( + hidden_size=8, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + num_hidden_layers=cls.hid_layers, + max_position_embeddings=cls.seq_len, + use_cache=True, + return_dict=False, + ) + cls.fp_model = LlamaModel(cfg) + + def test_model_output(self): + qmodel = QuantLlamaModel( 
self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill") + ) + self.assertIs(qmodel._mode, Mode.NO_QUANT) + + qmodel.enable_calibration() + self.assertIs(qmodel._mode, Mode.CALIB) + + x = torch.randint( + 0, + self.vocab_size, + ( + 1, + self.seq_len, + ), + ) + output = qmodel(x) + + assert len(output) == 2 # last_hidden_states + past_key_values + past_key_values = output[1] + assert len(past_key_values) == self.hid_layers + for index in range(self.hid_layers): + past_key_value = past_key_values[index] + assert isinstance(past_key_value, tuple) + + past_key = past_key_value[0] + assert past_key.shape[-2] == self.seq_len + + past_value = past_key_value[1] + assert past_value.shape[-2] == self.seq_len diff --git a/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py b/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py new file mode 100644 index 00000000..51f92a6a --- /dev/null +++ b/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py @@ -0,0 +1,431 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import torch + +from lm_eval.utils import make_table + +from transformers import AutoModelForCausalLM, AutoTokenizer + +from tico.quantization.evaluation.script.llm_tasks_eval import evaluate_llm_on_tasks +from tico.quantization.wrapq.examples.static_llama_layer_runtime import ( + _build_decode_attention_mask, + _build_rope_templates_from_config, + _slice_rope, +) + +from tico.quantization.wrapq.examples.quantize_full_qmodel_with_gptq import pad_input + +DTYPE_MAP = { + "float32": torch.float32, + # TODO Support more dtypes + # "bfloat16": torch.bfloat16, + # "float16": torch.float16, +} + +#import os +#os.environ["CUDA_VISIBLE_DEVICES"]= "0" + +@torch.no_grad() +class GreedyDecoder: + def __init__(self, model, tokenizer, device): + self.model = model + self.tokenizer = tokenizer + self.device = device + + def generate(self, prompt, max_length): + inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + + eos_token_id = self.tokenizer.eos_token_id + + with torch.no_grad(): + while inputs.shape[-1] < max_length: + logits = self.model(inputs).logits + next_token = torch.tensor([[torch.argmax(logits[..., -1, :])]], device=inputs.device) + if eos_token_id is not None and torch.all(next_token == eos_token_id): + break + inputs = torch.cat([inputs, next_token], dim=1) + + return inputs + +def pad_input_to_left(input, pad_token, max_seq_len): + """Pad a tensor to a maximum sequence length using the specified pad token.""" + pads = torch.full( + (input.shape[0], max_seq_len - input.shape[1]), + fill_value=pad_token, + device=input.device, + ) + return torch.cat((pads, input), dim=1) + +class PrefillDecodeGreedyDecoder: + def __init__(self, model, orig_model, tokenizer, max_seq_len, config, device): + prefill_model, decode_model = model + self.prefill_model = prefill_model + self.decode_model = decode_model + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + self.device = device + self.orig_model = orig_model + 
self.rope_cos, self.rope_sin = _build_rope_templates_from_config( + config, max_seq=max_seq_len, device=device, dtype=torch.float32 + ) + + + def generate_left_padding(self, prompt, max_new_tokens): + inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + assert isinstance(inputs, torch.Tensor) + + eos_token_id = self.tokenizer.eos_token_id + + generated = inputs.clone() + cur_seq_len = inputs.shape[-1] + prefill_max_seq_len = self.max_seq_len - 1 + prefill_input = pad_input_to_left(inputs, self.tokenizer.pad_token_id, prefill_max_seq_len) + attn_mask = self.build_prefill_padded_attention_mask(cur_seq_len, prefill_max_seq_len, self.device, right_padding=False) + position_embeddings = self.build_prefill_position_embeddings(cur_seq_len, prefill_max_seq_len, self.device, right_padding=False) + + with torch.no_grad(): + outputs = self.prefill_model(prefill_input, attention_mask = attn_mask, position_embeddings=position_embeddings, use_cache = True) + + # orig_inputs = self.tokenizer(prompt, return_tensors="pt", max_length=prefill_max_seq_len, padding='max_length', padding_side="left").to(self.device) + # orig_attn_mask = orig_inputs["attention_mask"] + # orig_position_ids = orig_attn_mask.long().cumsum(-1) - 1 + # orig_position_ids.masked_fill_(orig_attn_mask == 0, 0) + # orig_inputs["position_ids"] = orig_position_ids + # #orig_outs = self.orig_model.to(self.device)(**orig_inputs) + + logits = outputs.logits + past_key_values = outputs.past_key_values + + self.prefill_model = self.prefill_model.cpu() + self.decode_model = self.decode_model.to(self.device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + produced_tokens = 0 + with torch.no_grad(): + while produced_tokens < max_new_tokens: + next_token = torch.tensor([[torch.argmax(logits[..., -1, :])]], device=self.device) + if eos_token_id is not None and torch.all(next_token == eos_token_id): + break + generated = torch.cat([generated, next_token], dim=1) + cur_seq_len += 1 + 
produced_tokens += 1 + + dec_inputs = self.get_input_for_decode_model(next_token, past_key_values=past_key_values, cur_seq_len = cur_seq_len, right_padding=False) + outputs = self.decode_model(**dec_inputs) + logits = outputs.logits + new_key_values = outputs.past_key_values + # shift past_key_values + for i in range(prefill_max_seq_len - 1): + #cur_seq_idx = prefill_max_seq_len - cur_seq_len + i - 1 + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, i : i + 1, :] =\ + past_key_values[idx][0][:, :, i + 1: i + 2, :] + past_key_values[idx][1][:, :, i : i + 1, :] =\ + past_key_values[idx][1][:, :, i + 1 : i + 2, :] + + # update past_key_values + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, -1 :, :] = new_key_values[idx][0] + past_key_values[idx][1][:, :, -1 :, :] = new_key_values[idx][1] + + + return generated + + def generate_right_padding(self, prompt, max_new_tokens): + + inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + assert isinstance(inputs, torch.Tensor) + + eos_token_id = self.tokenizer.eos_token_id + + generated = inputs.clone() + cur_seq_len = inputs.shape[-1] + prefill_max_seq_len = self.max_seq_len - 1 + prefill_input = pad_input(inputs, self.tokenizer.pad_token_id, prefill_max_seq_len) + attn_mask = self.build_prefill_padded_attention_mask(cur_seq_len, prefill_max_seq_len, self.device, right_padding=True) + position_embeddings = self.build_prefill_position_embeddings(cur_seq_len, prefill_max_seq_len, self.device, right_padding=True) + + with torch.no_grad(): + outputs = self.prefill_model(prefill_input, attention_mask = attn_mask, position_embeddings=position_embeddings, use_cache = True) + + # orig_inputs = self.tokenizer(prompt, return_tensors="pt", max_length=prefill_max_seq_len, padding='max_length', padding_side="right").to(self.device) + # orig_attn_mask = orig_inputs["attention_mask"] + # orig_position_ids = orig_attn_mask.long().cumsum(-1) - 1 + # 
orig_position_ids.masked_fill_(orig_attn_mask == 0, 0) + # orig_inputs["position_ids"] = orig_position_ids + # orig_outs = self.orig_model.to(self.device)(**orig_inputs) + + + logits = outputs.logits + past_key_values = outputs.past_key_values + + self.prefill_model = self.prefill_model.cpu() + self.decode_model = self.decode_model.to(self.device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + produced_tokens = 0 + with torch.no_grad(): + while produced_tokens < max_new_tokens: + next_token = torch.tensor([[torch.argmax(logits[..., -1, :])]], device=self.device) + if eos_token_id is not None and torch.all(next_token == eos_token_id): + break + generated = torch.cat([generated, next_token], dim=1) + cur_seq_len += 1 + produced_tokens += 1 + + dec_inputs = self.get_input_for_decode_model(next_token, past_key_values=past_key_values, cur_seq_len = cur_seq_len-1, right_padding=True) + outputs = self.decode_model(**dec_inputs) + logits = outputs.logits + new_key_values = outputs.past_key_values + # shift past_key_values + for i in range(prefill_max_seq_len - 1): + #cur_seq_idx = prefill_max_seq_len - cur_seq_len + i - 1 + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, i : i + 1, :] =\ + past_key_values[idx][0][:, :, i + 1: i + 2, :] + past_key_values[idx][1][:, :, i : i + 1, :] =\ + past_key_values[idx][1][:, :, i + 1 : i + 2, :] + + # update past_key_values to be the last added + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, -1 :, :] = new_key_values[idx][0] + past_key_values[idx][1][:, :, -1 :, :] = new_key_values[idx][1] + + + + return generated + + def generate(self, prompt, max_new_tokens): + return self.generate_left_padding(prompt, max_new_tokens) + #return self.generate_right_padding(prompt, max_new_tokens) + + def build_prefill_padded_attention_mask(self, cur_seq_len, max_seq_len, device, right_padding = False, mask_value:float = -120.0): + dtype = torch.float32 + mask = torch.full((1, max_seq_len, 
max_seq_len), mask_value, device=device, dtype=dtype) + + if right_padding: + for i in range(max_seq_len): + for j in range(max_seq_len): + if i >= j and j < cur_seq_len: + mask[..., i, j] = 0 + else: + for i in range(max_seq_len): + for j in range(max_seq_len): + if i >= j and j >= max_seq_len - cur_seq_len: + mask[..., i, j] = 0 + + return mask + + def build_prefill_position_embeddings(self, cur_seq_len, max_seq_len, device, right_padding = False): + dtype = torch.float32 + position_embeddings = self.prefill_model.wrapped.model.wrapped.get_position_embeddings_for(dtype, device) + cos = torch.ones_like(position_embeddings[0][:, : max_seq_len, :]) + sin = torch.zeros_like(position_embeddings[1][:, : max_seq_len, :]) + + sl_cos, sl_sin = position_embeddings + sl_cos = sl_cos[:, : cur_seq_len, :] + sl_sin = sl_sin[:, : cur_seq_len, :] + if right_padding is True: + cos[..., :cur_seq_len, :] = sl_cos + sin[..., :cur_seq_len, :] = sl_sin + else: + # left padding + cos[..., max_seq_len - cur_seq_len : max_seq_len, :] = sl_cos + sin[..., max_seq_len - cur_seq_len : max_seq_len, :] = sl_sin + + return (cos, sin) + + def get_input_for_decode_model(self, next_token, past_key_values, cur_seq_len, right_padding=False): + dtype = torch.float32 + if right_padding: + attention_mask = _build_decode_attention_mask( + batch_size=1, + past_len=cur_seq_len, + max_seq=self.max_seq_len, + device=self.device, + dtype=dtype, + ) + + + position_embeddings = _slice_rope( + self.rope_cos, + self.rope_sin, + position=cur_seq_len, + batch_size=1, + device=self.device, + dtype=dtype, + ) + else: + attention_mask = self.build_prefill_padded_attention_mask(cur_seq_len, self.max_seq_len, self.device) + attention_mask = attention_mask[..., -1, :].unsqueeze(0) + position_embeddings = self.build_prefill_position_embeddings(cur_seq_len, self.max_seq_len, self.device) + (cos, sin) = position_embeddings + cos = cos[..., -1, :].unsqueeze(0) + sin = sin[..., -1, :].unsqueeze(0) + position_embeddings = 
(cos, sin) + + # fill in input + inputs = {} + inputs["input_ids"] = next_token + inputs["attention_mask"] = attention_mask + inputs["position_embeddings"] = position_embeddings + inputs["past_key_values"] = past_key_values + return inputs + +def main(): + parser = argparse.ArgumentParser( + description="Try a fake-quantized models" + ) + parser.add_argument( + "--model", type=str, required=True, help="HF repo name or local path." + ) + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run on (cuda|cpu|mps).", + ) + parser.add_argument( + "--dtype", + choices=list(DTYPE_MAP.keys()), + default="float32", + help="Model dtype for load.", + ) + parser.add_argument( + "--hf-token", + type=str, + default=None, + help="Optional HF token for gated/private repos.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Enable only if you trust the model repo code.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="cache_dir for using model/datasets loading", + ) + parser.add_argument( + "--fk_model_path", type=str, required=True, help="Path to fake_quantized model" + ) + parser.add_argument( + "--prompt", type=str, default="The capital of France is", help="Prompt to decode" + ) + parser.add_argument( + "--eval_tasks", + type=str, + default=None, + help="tasks to be evaluated using lm_eval, e.g. 
`winogrande,arc_easy,arc_challenge,openbookqa,mmlu_pro,ifeval,bbh`", + ) + parser.add_argument( + "--max_new_tokens", type=int, default=128, help="Maximum new okens to produce" + ) + + args = parser.parse_args() + print(args) + + # Basic setup + + device = torch.device(args.device) + dtype = DTYPE_MAP[args.dtype] + + print("=== Config ===") + print(f"Model : {args.model}") + print(f"Device : {device.type}") + print(f"DType : {args.dtype}") + print(f"Prompt : {args.prompt}") + print() + + # ------------------------------------------------------------------------- + # 2. Load the FP backbone and tokenizer + # ------------------------------------------------------------------------- + print("Loading FP model …") + tokenizer = AutoTokenizer.from_pretrained( + args.model, + trust_remote_code=args.trust_remote_code, + token=args.hf_token, + cache_dir=args.cache_dir, + ) + model = AutoModelForCausalLM.from_pretrained( + args.model, + dtype=dtype, + trust_remote_code=args.trust_remote_code, + token=args.hf_token, + cache_dir=args.cache_dir, + ).cpu().eval() + + if tokenizer.pad_token is None: + print( + "Warning: tokenizer doesn't have pad_token. Prefill-decoding scheme may fail." 
+ ) + tokenizer.pad_token = tokenizer.eos_token + + fk_model = torch.load(args.fk_model_path, weights_only=False) + + if isinstance(fk_model, tuple): + fk_model = (fk_model[0].eval().cpu(), fk_model[1].eval().cpu()) + config = fk_model[0].wrapped.config + else: + fk_model.eval() + fk_model = fk_model.cpu() + config = fk_model.wrapped.config + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + max_seq_len = config.max_position_embeddings + inputs = tokenizer(args.prompt, return_tensors="pt", max_length=max_seq_len - 1, padding='max_length', padding_side="left").to(device) #just try with right padding below + #inputs = tokenizer(args.prompt, return_tensors="pt", max_length=max_seq_len - 1, padding='max_length', padding_side="right", device=args.device).to(device) + + model.config.use_cache = True + if args.eval_tasks is not None: + results = evaluate_llm_on_tasks( + model, tokenizer, args.eval_tasks, max_length=args.max_seq_len + ) + print("Original RESULTS ARE:") + print(make_table(results)) + + out_ids = model.to(args.device).generate(**inputs, max_length=max_seq_len + args.max_new_tokens, do_sample = False) + output = tokenizer.decode(out_ids.squeeze(), skip_special_tokens=True) + print(f"Original model prompt: {output}") + model = model.cpu() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + fk_model = fk_model.to(args.device) if not isinstance(fk_model, tuple) else (fk_model[0].to(args.device), fk_model[1].cpu()) + if args.eval_tasks is not None: + results = evaluate_llm_on_tasks( + fk_model, tokenizer, args.eval_tasks, max_length=max_seq_len + ) + print("Quantized RESULTS ARE:") + print(make_table(results)) + + fk_decoder = GreedyDecoder(fk_model, tokenizer, args.device) if not isinstance(fk_model, tuple) else PrefillDecodeGreedyDecoder(fk_model, model, tokenizer, max_seq_len, config, args.device) + + out_ids = fk_decoder.generate(args.prompt, max_new_tokens=args.max_new_tokens) + output = tokenizer.decode(out_ids.squeeze(), 
skip_special_tokens=True) + print(f"Fake quantized model prompt: {output}") + fk_model = fk_model.cpu() if not isinstance(fk_model, tuple) else (fk_model[0].cpu(), fk_model[1].cpu()) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +if __name__ == "__main__": + main() diff --git a/tico/quantization/wrapq/examples/static_llama_layer_runtime.py b/tico/quantization/wrapq/examples/static_llama_layer_runtime.py index 865388b6..613db166 100644 --- a/tico/quantization/wrapq/examples/static_llama_layer_runtime.py +++ b/tico/quantization/wrapq/examples/static_llama_layer_runtime.py @@ -556,8 +556,9 @@ def main(): model = AutoModelForCausalLM.from_pretrained( args.model, dtype=torch.float32, + cache_dir = "/mnt/storage/transformers_cache" ).to(args.device) - tokenizer = AutoTokenizer.from_pretrained(args.model, legacy=False) + tokenizer = AutoTokenizer.from_pretrained(args.model, legacy=False, cache_dir = "/mnt/storage/transformers_cache") if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model.py b/tico/quantization/wrapq/wrappers/llama/quant_model.py index 7a1696ff..32e0db97 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model.py @@ -27,6 +27,7 @@ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase from tico.quantization.wrapq.wrappers.registry import try_register +Q_INF = float(-120) #quantization friendly negative infinity @try_register( "transformers.models.llama.modeling_llama.LlamaModel", @@ -96,7 +97,7 @@ def __init__( # Static causal mask template --------------------------------------- assert isinstance(self.config.max_position_embeddings, int) max_seq = self.config.max_position_embeddings - mask = torch.full((1, 1, max_seq, max_seq), float("-120")) + mask = torch.full((1, 1, max_seq, max_seq), Q_INF) mask.triu_(1) self.register_buffer("causal_mask_template", mask, 
persistent=False) @@ -143,23 +144,23 @@ def __init__( self.register_buffer("rope_cos_template", cos_t, persistent=False) self.register_buffer("rope_sin_template", sin_t, persistent=False) - def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: + def _slice_causal(self, seq_len: int, device: torch.device, offset: int = 0) -> torch.Tensor: """Return `[1,1,L,L]` causal mask slice on *device*.""" assert isinstance(self.causal_mask_template, torch.Tensor) - return self.causal_mask_template[..., :seq_len, :seq_len].to(device) + return self.causal_mask_template[..., offset : offset + seq_len, : offset + seq_len].to(device) - def get_attention_mask_for(self, x): + def get_attention_mask_for(self, x, offset: int = 0): L = x.size(1) - attention_mask = self._slice_causal(L, x.device) + attention_mask = self._slice_causal(L, x.device, offset) return attention_mask - def get_position_embeddings_for(self, hidden_states): + def get_position_embeddings_for(self, dtype, device): return ( self.rope_cos_template.to( - dtype=hidden_states.dtype, device=hidden_states.device + dtype=dtype, device=device ), self.rope_sin_template.to( - dtype=hidden_states.dtype, device=hidden_states.device + dtype=dtype, device=device ), ) @@ -175,6 +176,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -202,12 +204,16 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() - + past_key_values = [] + + present_key_values = [] if cache_position is None: past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 + 0 + if (past_key_values is None or len(past_key_values) == 0) + else past_key_values[0][0].shape[-2] ) + 
cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -223,23 +229,45 @@ def forward( if self.rotate_embedding is not None: hidden_states = self.rotate_embedding(hidden_states) - # create position_embeddings and causal_mask to be shared across all the decoder layers - causal_mask = self.get_attention_mask_for(hidden_states) - causal_mask = causal_mask.squeeze(0) + + offset = past_key_values[0][0].shape[-2] if past_key_values is not None and len(past_key_values) > 0 else 0 + + if attention_mask is not None and len(attention_mask.shape) >= 3: + causal_mask = attention_mask # set externally + else: + if attention_mask is not None: # assuming it's boolean matrix 0 - False, 1- True (e.g. padding) + # convert it to float, so that True(1) maps to 0, False(0) maps to Q_INF + attention_mask = (torch.ones_like(attention_mask) - attention_mask) * Q_INF + + # create causal_mask to be shared across all the decoder layers + causal_mask = self.get_attention_mask_for(hidden_states, offset) + if attention_mask is not None: + # in case external mask was set just `and` it with causal_mask + causal_mask = torch.max(Q_INF, causal_mask + attention_mask) + causal_mask = causal_mask.squeeze(0) causal_mask = self._fq(causal_mask, self.obs_causal_mask) - position_embeddings = self.get_position_embeddings_for(hidden_states) - cos, sin = position_embeddings - position_embeddings = ( - self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos), - self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin), - ) - + if position_embeddings is None: + position_embeddings = self.get_position_embeddings_for(hidden_states.dtype, hidden_states.device) + cos, sin = position_embeddings + + position_embeddings = ( + self._fq(cos[:, offset : offset + hidden_states.size(1), :], self.obs_cos), + self._fq(sin[:, offset : offset + hidden_states.size(1), :], self.obs_sin), + ) + else: + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, 
self.obs_cos), + self._fq(sin, self.obs_sin), + ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for idx, decoder_layer in enumerate( + self.layers[: self.config.num_hidden_layers] + ): if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] @@ -247,7 +275,11 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=( + past_key_values[idx] + if past_key_values is not None and idx < len(past_key_values) + else None + ), output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -257,6 +289,10 @@ def forward( if decoder_layer.wrapped.return_type == "tuple": hidden_states = layer_outputs[0] + elif use_cache: + hidden_states = layer_outputs[0] + assert isinstance(layer_outputs[1], tuple) + present_key_values.append(layer_outputs[1]) else: hidden_states = layer_outputs @@ -271,7 +307,7 @@ def forward( output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=present_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py index 1553e42b..4389777c 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py @@ -93,6 +93,7 @@ def forward( use_cache: bool | None = None, cache_position: torch.LongTensor | None = None, logits_to_keep: int | torch.Tensor = 0, + position_embeddings: Optional[torch.Tensor] = None, **kwargs, ) -> CausalLMOutputWithPast: @@ -112,6 +113,7 @@ def forward( output_hidden_states=output_hidden_states, 
return_dict=return_dict, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, ) From fbd5ad9f8459d7c002b6c11409bc6b0a2f70092e Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Thu, 16 Apr 2026 17:07:28 +0300 Subject: [PATCH 2/2] Rebasing. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../examples/evaluate_fake_quantized_model.py | 16 +- .../wrapq/examples/evaluate_fk_llama_model.py | 211 ++++- .../quantize_full_qmodel_with_gptq.py | 767 +++++++++++++++--- .../wrappers/llama/quant_decoder_layer.py | 2 +- 4 files changed, 884 insertions(+), 112 deletions(-) diff --git a/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py b/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py index 51f92a6a..44956f7f 100644 --- a/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py +++ b/tico/quantization/wrapq/examples/evaluate_fake_quantized_model.py @@ -26,8 +26,6 @@ _slice_rope, ) -from tico.quantization.wrapq.examples.quantize_full_qmodel_with_gptq import pad_input - DTYPE_MAP = { "float32": torch.float32, # TODO Support more dtypes @@ -38,6 +36,20 @@ #import os #os.environ["CUDA_VISIBLE_DEVICES"]= "0" +def pad_input(input, pad_token, max_seq_len, right: bool = True): + """Pad a tensor to a maximum sequence length using the specified pad token.""" + pads = torch.full( + (input.shape[0], max_seq_len - input.shape[1]), + fill_value=pad_token, + device=input.device, + ) + if right is True: + res = torch.cat((input, pads), dim=1) + else: + res = torch.cat((pads, input), dim=1) + + return res + @torch.no_grad() class GreedyDecoder: def __init__(self, model, tokenizer, device): diff --git a/tico/quantization/wrapq/examples/evaluate_fk_llama_model.py b/tico/quantization/wrapq/examples/evaluate_fk_llama_model.py index 082877a9..bc8e5c13 100644 --- a/tico/quantization/wrapq/examples/evaluate_fk_llama_model.py +++ b/tico/quantization/wrapq/examples/evaluate_fk_llama_model.py @@ -21,6 +21,13 @@ from transformers import 
AutoModelForCausalLM, AutoTokenizer from tico.quantization.evaluation.script.llm_tasks_eval import evaluate_llm_on_tasks +from tico.quantization.wrapq.examples.static_llama_layer_runtime import ( + _build_decode_attention_mask, + _build_rope_templates_from_config, + _slice_rope, +) + +from tico.quantization.wrapq.examples.quantize_full_qmodel_with_gptq import pad_input, PrefillDecodeUtils, left_pad DTYPE_MAP = { "float32": torch.float32, @@ -29,6 +36,173 @@ # "float16": torch.float16, } +@torch.no_grad() +class GreedyDecoder: + def __init__(self, model, tokenizer, device): + self.model = model + self.tokenizer = tokenizer + self.device = device + + def generate(self, prompt, max_length): + inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + + eos_token_id = self.tokenizer.eos_token_id + + with torch.no_grad(): + while inputs.shape[-1] < max_length: + logits = self.model(inputs).logits + next_token = torch.tensor([[torch.argmax(logits[..., -1, :])]], device=inputs.device) + if eos_token_id is not None and torch.all(next_token == eos_token_id): + break + inputs = torch.cat([inputs, next_token], dim=1) + + return inputs + +class PrefillDecodeGreedyDecoder: + def __init__(self, model, orig_model, tokenizer, max_seq_len, config, device): + self.prefill_model = model + self.decode_model = model + self.tokenizer = tokenizer + self.max_seq_len = max_seq_len + self.device = device + self.helper = PrefillDecodeUtils(max_seq_len, config, self.device) + self.orig_model = orig_model + self.pos_embeds = _build_rope_templates_from_config( + config, max_seq=max_seq_len, device=device, dtype=torch.float32 + ) + + + def generate_left_padding(self, prompt, max_new_tokens): + inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + assert isinstance(inputs, torch.Tensor) + + eos_token_id = self.tokenizer.eos_token_id + + generated = inputs.clone() + cur_seq_len = inputs.shape[-1] + prefill_max_seq_len = self.max_seq_len - 1 + 
prefill_input = pad_input(inputs, self.tokenizer.pad_token_id, prefill_max_seq_len, right = False) + attn_mask = self.helper.build_attention_mask_for_padded_input(cur_seq_len, right_padding=False) + position_embeddings = self.helper.build_position_embeddings_for_padded_input(self.pos_embeds, cur_seq_len, right_padding=False) + + with torch.no_grad(): + outputs = self.prefill_model(prefill_input, attention_mask = attn_mask, position_embeddings=position_embeddings, use_cache = True) + + # orig_inputs = self.tokenizer(prompt, return_tensors="pt", max_length=prefill_max_seq_len, padding='max_length', padding_side="left").to(self.device) + # orig_attn_mask = orig_inputs["attention_mask"] + # orig_position_ids = orig_attn_mask.long().cumsum(-1) - 1 + # orig_position_ids.masked_fill_(orig_attn_mask == 0, 0) + # orig_inputs["position_ids"] = orig_position_ids + # #orig_outs = self.orig_model.to(self.device)(**orig_inputs) + + logits = outputs.logits + past_key_values = outputs.past_key_values + + self.prefill_model = self.prefill_model.cpu() + self.decode_model = self.decode_model.to(self.device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + produced_tokens = 0 + with torch.no_grad(): + while produced_tokens < max_new_tokens: + next_token = torch.tensor([[torch.argmax(logits[..., -1, :])]], device=self.device) + if eos_token_id is not None and torch.all(next_token == eos_token_id): + break + generated = torch.cat([generated, next_token], dim=1) + cur_seq_len += 1 + produced_tokens += 1 + + dec_inputs = self.helper.get_input_for_decode_model(next_token, past_key_values=past_key_values, cur_seq_len = cur_seq_len, right_padding=False) + outputs = self.decode_model(**dec_inputs) + logits = outputs.logits + new_key_values = outputs.past_key_values + # shift past_key_values + for i in range(prefill_max_seq_len - 1): + #cur_seq_idx = prefill_max_seq_len - cur_seq_len + i - 1 + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, i : i + 1, :] =\ 
+ past_key_values[idx][0][:, :, i + 1: i + 2, :] + past_key_values[idx][1][:, :, i : i + 1, :] =\ + past_key_values[idx][1][:, :, i + 1 : i + 2, :] + + # update past_key_values + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, -1 :, :] = new_key_values[idx][0] + past_key_values[idx][1][:, :, -1 :, :] = new_key_values[idx][1] + + + return generated + + def generate_right_padding(self, prompt, max_new_tokens): + + inputs = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device) + assert isinstance(inputs, torch.Tensor) + + eos_token_id = self.tokenizer.eos_token_id + + generated = inputs.clone() + cur_seq_len = inputs.shape[-1] + prefill_max_seq_len = self.max_seq_len - 1 + prefill_input = pad_input(inputs, self.tokenizer.pad_token_id, prefill_max_seq_len, right=True) + attn_mask = self.helper.build_attention_mask_for_padded_input(cur_seq_len, right_padding=True) + position_embeddings = self.helper.build_position_embeddings_for_padded_input(self.pos_embeds, cur_seq_len, right_padding=True) + + with torch.no_grad(): + outputs = self.prefill_model(prefill_input, attention_mask = attn_mask, position_embeddings=position_embeddings, use_cache = True) + + # orig_inputs = self.tokenizer(prompt, return_tensors="pt", max_length=prefill_max_seq_len, padding='max_length', padding_side="right").to(self.device) + # orig_attn_mask = orig_inputs["attention_mask"] + # orig_position_ids = orig_attn_mask.long().cumsum(-1) - 1 + # orig_position_ids.masked_fill_(orig_attn_mask == 0, 0) + # orig_inputs["position_ids"] = orig_position_ids + # orig_outs = self.orig_model.to(self.device)(**orig_inputs) + + + logits = outputs.logits + past_key_values = outputs.past_key_values + + self.prefill_model = self.prefill_model.cpu() + self.decode_model = self.decode_model.to(self.device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + produced_tokens = 0 + with torch.no_grad(): + while produced_tokens < max_new_tokens: + next_token = 
torch.tensor([[torch.argmax(logits[..., -1, :])]], device=self.device) + if eos_token_id is not None and torch.all(next_token == eos_token_id): + break + generated = torch.cat([generated, next_token], dim=1) + cur_seq_len += 1 + produced_tokens += 1 + + dec_inputs = self.helper.get_input_for_decode_model(next_token, past_key_values=past_key_values, cur_seq_len = cur_seq_len-1, right_padding=True) + outputs = self.decode_model(**dec_inputs) + logits = outputs.logits + new_key_values = outputs.past_key_values + # shift past_key_values + for i in range(prefill_max_seq_len - 1): + #cur_seq_idx = prefill_max_seq_len - cur_seq_len + i - 1 + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, i : i + 1, :] =\ + past_key_values[idx][0][:, :, i + 1: i + 2, :] + past_key_values[idx][1][:, :, i : i + 1, :] =\ + past_key_values[idx][1][:, :, i + 1 : i + 2, :] + + # update past_key_values to be the last added + for idx in range(len(new_key_values)): + past_key_values[idx][0][:, :, -1 :, :] = new_key_values[idx][0] + past_key_values[idx][1][:, :, -1 :, :] = new_key_values[idx][1] + + + return generated + + def generate(self, prompt, max_new_tokens): + if left_pad: + return self.generate_left_padding(prompt, max_new_tokens) + + return self.generate_right_padding(prompt, max_new_tokens) def main(): parser = argparse.ArgumentParser( @@ -80,6 +254,17 @@ def main(): action="store_true", help="Skip original model evaluation.", ) + parser.add_argument( + "--prefill_decode", + action="store_true", + help="Model is calibrated for prefill_decode pipeline.", + ) + parser.add_argument( + "--prompt", type=str, default="The capital of France is", help="Prompt to decode" + ) + parser.add_argument( + "--max_new_tokens", type=int, default=128, help="Maximum new tokens to produce" + ) args = parser.parse_args() print(args) @@ -105,7 +290,6 @@ def main(): ) if not args.skip_fp_eval: - # ------------------------------------------------------------------------- # FP model 
evaluation # ------------------------------------------------------------------------- @@ -131,6 +315,14 @@ def main(): print("Original RESULTS ARE:") print(make_table(results)) + if args.prompt is not None: + max_seq_len = model.config.max_position_embeddings + inputs = tokenizer(args.prompt, return_tensors="pt", max_length=max_seq_len - 1, padding='max_length', padding_side="left").to(device) #just try with right padding below + #inputs = tokenizer(args.prompt, return_tensors="pt", max_length=max_seq_len - 1, padding='max_length', padding_side="right", device=args.device).to(device) + out_ids = model.to(args.device).generate(**inputs, max_length=max_seq_len + args.max_new_tokens, do_sample = False) + output = tokenizer.decode(out_ids.squeeze(), skip_special_tokens=True) + print(f"Original model prompt: {output}") + model = model.cpu() if device.type == "cuda" and torch.cuda.is_available(): torch.cuda.empty_cache() @@ -140,17 +332,26 @@ def main(): # ------------------------------------------------------------------------- print("Loading fake quantized model …") fk_model = torch.load(args.fk_model_path, weights_only=False).eval().to(args.device) - + config = fk_model.wrapped.config + max_seq_len = config.max_position_embeddings + if args.eval_tasks is not None: - config = fk_model.wrapped.config - max_seq_len = config.max_position_embeddings - results = evaluate_llm_on_tasks( fk_model, tokenizer, args.eval_tasks, max_length=max_seq_len ) print("Quantized RESULTS ARE:") print(make_table(results)) + fk_decoder = GreedyDecoder(fk_model, tokenizer, args.device) if not args.prefill_decode else PrefillDecodeGreedyDecoder(fk_model, model, tokenizer, max_seq_len, config, args.device) + max_seq_len = model.config.max_position_embeddings + inputs = tokenizer(args.prompt, return_tensors="pt", max_length=max_seq_len - 1, padding='max_length', padding_side="left").to(device) #just try with right padding below + #inputs = tokenizer(args.prompt, return_tensors="pt", 
max_length=max_seq_len - 1, padding='max_length', padding_side="right", device=args.device).to(device) + out_ids = fk_decoder.generate(args.prompt, max_new_tokens=args.max_new_tokens) + output = tokenizer.decode(out_ids.squeeze(), skip_special_tokens=True) + print(f"Fake quantized model prompt: {output}") + fk_model = fk_model.cpu() if not isinstance(fk_model, tuple) else (fk_model[0].cpu(), fk_model[1].cpu()) + if torch.cuda.is_available(): + torch.cuda.empty_cache() if __name__ == "__main__": main() diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 9736292c..9bff3978 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -26,10 +26,16 @@ # ============================================================================= import argparse +import copy import pathlib import random from typing import Any +import types + +from typing import Any, List, Optional, Tuple, Union + +import numpy as np import torch import tqdm from datasets import load_dataset @@ -44,12 +50,17 @@ from tico.quantization.config.spinquant import SpinQuantConfig from tico.quantization.evaluation.script.llm_tasks_eval import evaluate_llm_on_tasks from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.examples.static_llama_layer_runtime import ( + _build_decode_attention_mask, + _build_rope_templates_from_config, + _slice_rope, +) from tico.quantization.wrapq.observers.affine_base import AffineObserverBase from tico.quantization.wrapq.qscheme import QScheme from tico.quantization.wrapq.utils.metrics import perplexity from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase -from tico.utils.utils import SuppressWarning +from tico.utils.utils import SuppressWarning, move_to_device DTYPE_MAP = { "float32": torch.float32, @@ -64,6 +75,7 @@ TRAIN_SPLIT = 
"train" TEST_SPLIT = "test" +left_pad = True # ------------------------------------------------------------------------- # Helper — copy GPTQ (scale, zp) into PTQ observers @@ -93,24 +105,322 @@ def inject_gptq_qparams( # GPTQ quantizer attributes obs.load_qparams(quantizer.scale, quantizer.zero, lock=True) +def pad_input(input, pad_token, max_seq_len, right: bool = True): + """Pad a tensor to a maximum sequence length using the specified pad token.""" + pads = torch.full( + (input.shape[0], max_seq_len - input.shape[1]), + fill_value=pad_token, + device=input.device, + ) + if right is True: + res = torch.cat((input, pads), dim=1) + else: + res = torch.cat((pads, input), dim=1) -# ------------------------------------------------------------------------- -# Save model in circle format -# ------------------------------------------------------------------------- -def save_model_to(q_m, calib_inputs, save_circle_to_folder): + return res +class PrefillDecodeUtils: + def __init__(self, max_seq_len, config, device): + self.max_seq_len = max_seq_len + self.device = device + self.pos_embeds = _build_rope_templates_from_config( + config, max_seq=max_seq_len, device=device, dtype=torch.float32 + ) + + def build_attention_mask_for_padded_input(self, cur_seq_len, right_padding = False, mask_value:float = -120.0, prefill: bool = True): + dtype = torch.float32 + # mask = torch.full((1, max_seq_len, max_seq_len), mask_value, device=device, dtype=dtype) + # + # #return mask + # if right_padding: + # for i in range(max_seq_len): + # for j in range(max_seq_len): + # if i >= j and j < cur_seq_len: + # mask[..., i, j] = 0 + # else: + # for i in range(max_seq_len): + # for j in range(max_seq_len): + # if i >= j and j >= max_seq_len - cur_seq_len: + # mask[..., i, j] = 0 + + max_seq_len = self.max_seq_len - 1 if prefill is True else self.max_seq_len + pad_mask = [i < cur_seq_len for i in range(max_seq_len)] if right_padding else [i >= max_seq_len - cur_seq_len for i in 
range(max_seq_len)] + pad_mask = torch.tensor(pad_mask, dtype = torch.bool, device = self.device) + causal_mask = torch.full((1, max_seq_len, max_seq_len), 1, device=self.device, dtype=dtype) + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask = causal_mask == 0 # negating + res_mask = torch.logical_and(pad_mask, causal_mask) + mask_res = torch.zeros((1, max_seq_len, max_seq_len), device=self.device, dtype=dtype) + mask_res = mask_res.masked_fill(~res_mask, mask_value) + return mask_res + #assert torch.equal(mask_res, mask) + + #return mask + + def build_position_embeddings_for_padded_input(self, position_embeddings, cur_seq_len, right_padding = False, prefill: bool = True): + dtype = torch.float32 + #position_embeddings = self.prefill_model.wrapped.model.wrapped.get_position_embeddings_for(dtype, self.device) + max_seq_len = self.max_seq_len - 1 if prefill is True else self.max_seq_len + cos = torch.ones_like(position_embeddings[0][:, : max_seq_len, :]) + sin = torch.zeros_like(position_embeddings[1][:, : max_seq_len, :]) + + sl_cos, sl_sin = position_embeddings + sl_cos = sl_cos[:, : cur_seq_len, :] + sl_sin = sl_sin[:, : cur_seq_len, :] + if right_padding is True: + cos[..., :cur_seq_len, :] = sl_cos + sin[..., :cur_seq_len, :] = sl_sin + else: + # left padding + cos[..., max_seq_len - cur_seq_len : max_seq_len, :] = sl_cos + sin[..., max_seq_len - cur_seq_len : max_seq_len, :] = sl_sin + + return (cos, sin) + + def get_input_for_decode_model(self, next_token, past_key_values, cur_seq_len, right_padding=False): + dtype = torch.float32 + if right_padding: + # attention_mask = _build_decode_attention_mask( + # batch_size=1, + # past_len=cur_seq_len, + # max_seq=self.max_seq_len, + # device=self.device, + # dtype=dtype, + # ) +# + # position_embeddings = _slice_rope( + # self.pos_embeds[0], # cos + # self.pos_embeds[1], # sin + # position=cur_seq_len, + # batch_size=1, + # device=self.device, + # dtype=dtype, + # ) + attention_mask = 
self.build_attention_mask_for_padded_input(cur_seq_len, right_padding=True, prefill=False) + attention_mask = attention_mask[..., -1, :].unsqueeze(0) # last row + position_embeddings = self.build_position_embeddings_for_padded_input(self.pos_embeds, cur_seq_len, right_padding=True, prefill=False) + (cos, sin) = position_embeddings + cos = cos[..., -1, :].unsqueeze(0) + sin = sin[..., -1, :].unsqueeze(0) + position_embeddings = (cos, sin) + else: + attention_mask = self.build_attention_mask_for_padded_input(cur_seq_len, right_padding=False, prefill=False) + attention_mask = attention_mask[..., -1, :].unsqueeze(0) # last row + position_embeddings = self.build_position_embeddings_for_padded_input(self.pos_embeds, cur_seq_len, right_padding=False, prefill=False ) + (cos, sin) = position_embeddings + cos = cos[..., -1, :].unsqueeze(0) + sin = sin[..., -1, :].unsqueeze(0) + position_embeddings = (cos, sin) + + # fill in input + inputs = {} + inputs["input_ids"] = next_token + inputs["attention_mask"] = attention_mask + inputs["position_embeddings"] = position_embeddings + inputs["past_key_values"] = past_key_values + inputs["use_cache"] = True + + return inputs + +def get_decode_input( + prefill_model, + calib_input, + pad_token_id, + ropes, + max_seq_len, + device, + helper, +): + """Prepare inputs for the decode model using prefill KV‑cache and rotary embeddings.""" + prefill_input = calib_input[..., :-1] + prefill_seq_len = calib_input.shape[-1] + + prefill_max_seq_len = max_seq_len - 1 + prefill_input = pad_input(prefill_input, pad_token_id, prefill_max_seq_len, right = not left_pad) + attn_mask = helper.build_attention_mask_for_padded_input(prefill_seq_len, right_padding=not left_pad) + position_embeddings = helper.build_position_embeddings_for_padded_input(ropes, prefill_seq_len, right_padding=not left_pad) + + with torch.no_grad(): + # run prefill model to get kv-cache + outputs = prefill_model(prefill_input.to(device), attention_mask = attn_mask, 
position_embeddings=position_embeddings, use_cache = True) + + # fill inputs for decode model + next_token = calib_input[..., -1:] + dec_inputs = helper.get_input_for_decode_model(next_token, past_key_values=outputs.past_key_values, cur_seq_len = prefill_seq_len + 1, right_padding=not left_pad) + return dec_inputs + #dtype=torch.float32 + #attention_mask = _build_decode_attention_mask( + # batch_size=1, + # past_len=prefill_seq_len, + # max_seq=max_seq_len, + # device=device, + # dtype=dtype, + #) +# + #rope_cos, rope_sin = ropes + #position_embeddings = _slice_rope( + # rope_cos, + # rope_sin, + # position=prefill_seq_len - 1, + # batch_size=1, + # device=device, + # dtype=dtype, + #) + + # fill in input + inputs = {} + inputs["input_ids"] = torch.tensor([[next_token]]) + inputs["attention_mask"] = attention_mask + inputs["position_embeddings"] = position_embeddings + inputs["past_key_values"] = outputs.past_key_values + return inputs + +def evaluate_ppl_of_prefill_decode_model_on_dataset( + model, + dataset, + pad_token_id, + max_seq_len, + seed=0, + device: str = "cuda", +): + """Compute perplexity for the prefill-decode logic.""" + + config = ( + model.config + if hasattr(model, "config") + else model.wrapped.config + ) + rope_cos, rope_sin = _build_rope_templates_from_config( + config, max_seq=max_seq_len, device=device, dtype=torch.float32 + ) + + + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + + torch.manual_seed(seed) + nlls = [] + helper = PrefillDecodeUtils(max_seq_len, config, device) + + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + + prefill_seq_len = ( + torch.randint(3, max_seq_len - 1, (1,)).cpu().item() + ) # max_seq_len - 1# cropped input length + prefill_input = batch[..., :prefill_seq_len] # cropped input + #ref_output = model(prefill_input.to(device)) + + last_token = batch[..., prefill_seq_len].cpu().unsqueeze(0) + 
#torch.tensor( + # [[torch.argmax(ref_output.logits[:, -1, :], dim=-1).cpu()]] + #) + if hasattr(model, "wrapped"): + inputs = get_decode_input( + model, + prefill_input, + pad_token_id, + (rope_cos, rope_sin), + max_seq_len, + device, + helper + ) + inputs = move_to_device(inputs, device) + output = model(**inputs) + else: + input = pad_input( + prefill_input[..., :-1], pad_token_id, max_seq_len - 1, right = not left_pad + ) + prefill_attn_mask = input != pad_token_id + prefill_position_ids = prefill_attn_mask.long().cumsum(-1) - 1 + prefill_position_ids.masked_fill_(prefill_attn_mask == 0, 1) + prefill_ouput = model( + input.to(model.device), + attention_mask = prefill_attn_mask.to(model.device), + position_ids = prefill_position_ids.to(model.device), + use_cache=True + ) + next_token = prefill_input[..., -1:] + decode_attention_mask = torch.ones((1, max_seq_len)) + decode_attention_mask[..., :input.shape[-1]] = input != pad_token_id + decode_position_ids = torch.tensor([[prefill_seq_len-1]]) + + output = model( + next_token.to(model.device), + past_key_values=prefill_ouput.past_key_values, + attention_mask = decode_attention_mask.to(model.device), + position_ids = decode_position_ids.to(model.device), + use_cache=True, + ) + + lm_logits = output.logits + + if torch.isfinite(lm_logits).all(): + labels = last_token[0].to(device) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + lm_logits.reshape(-1, lm_logits.size(-1)), + labels.view(-1), + ) + nlls.append(loss) + + torch.cuda.empty_cache() + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + + +def save_model_to(q_m, calib_inputs, args, prefill = True, kwargs=None): + """ Save the whole model in circle format """ q_m.eval() q_m.cpu() + save_circle_to_folder = args.output_dir + suffix = "" if args.prefill_decode is False else "_prefill" if prefill is True else "_decode" - save_path = pathlib.Path(save_circle_to_folder, "model.q.circle") - print(f"saving the whole model 
to {save_path.resolve()}") + save_path = pathlib.Path(save_circle_to_folder, f"model{suffix}.q.circle") + print(f"saving the whole {'decode-' if prefill is False else 'prefill-' if args.prefill_decode else ''}model to {save_path.resolve()}") with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): - cm = tico.convert(q_m, (calib_inputs[0],), strict=False) + cm = tico.convert(q_m, (calib_inputs[0],), kwargs=kwargs, strict=False) cm.save(save_path) +# ----------------------------------------------------------------------------- +# copied from tico/quantization/wrapq/examples/llama/quantize_decoder_layer_decode.py +# TODO reduce code duplication +# ----------------------------------------------------------------------------- +def make_random_decode_batch(config, device, max_seq): + D = config.hidden_size + B = 1 + head_dim = getattr(config, "head_dim", D // config.num_attention_heads) + n_kv = config.num_key_value_heads + + # Single-token hidden state. + x = torch.randn(B, 1, D, device=device) + + # RoPE tables for the *current token* only. + cos = torch.randn(B, 1, head_dim, device=device) + sin = torch.randn(B, 1, head_dim, device=device) + pos = (cos, sin) -def save_layers_to(q_m, max_seq_len, save_layers_to_folder): + # Additive mask of final static width: (B, 1, MAX_SEQ) + # Simulate that only the first L_eff positions are valid and the rest are padding. + L_eff = torch.randint(low=1, high=max_seq + 1, size=(1,)).item() + mask = torch.zeros(B, 1, max_seq, device=device, dtype=torch.float32) + if L_eff < max_seq: + mask[:, :, L_eff:] = float("-120") + + # Static-sized past KV (already RoPE-applied for past tokens). 
+ past_k = torch.randn(B, n_kv, max_seq - 1, head_dim, device=device) + past_v = torch.randn(B, n_kv, max_seq - 1, head_dim, device=device) + past = (past_k, past_v) + + return x, pos, mask, past + +def save_layers_to(q_m, args, prefill = True): + """ Save all layers of the model in circle format """ + max_seq_len = args.max_seq_len + save_layers_to_folder = args.output_dir q_m.eval() q_m.cpu() @@ -120,35 +430,48 @@ def save_layers_to(q_m, max_seq_len, save_layers_to_folder): layers = q_m.wrapped.model.wrapped.layers config = q_m.wrapped.config + suffix = "" if args.prefill_decode is False else "prefill_" if prefill is True else "decode_" for i, qlayer in enumerate(layers): - save_path = pathlib.Path(save_layers_to_folder, f"decoder_layer_{i}.q.circle") - B, S, D = 1, max_seq_len, config.hidden_size - example_hidden = torch.randn(B, S, D) - - attention_mask = ( - qlayer.wrapped.causal_mask_template[..., :S, :S].squeeze(0).to("cpu") - ) - dtype = example_hidden.dtype - pos_embeds = qlayer.wrapped._slice_rope( - start=0, seq_len=S, device="cpu", dtype=dtype - ) + save_path = pathlib.Path(save_layers_to_folder, f"decoder_layer_{suffix}{i}.q.circle") + B, D = 1, config.hidden_size + if args.prefill_decode is False: + S = max_seq_len + variant = "prefill" + elif prefill is True: + S = max_seq_len - 1 + variant = "prefill" + else: + # decode + S = 1 + variant = "decode" + + if prefill: + example_hidden = torch.randn(B, S, D) + attention_mask = ( + qlayer.wrapped.causal_mask_template[..., :S, :S].squeeze(0).to("cpu") + ) + dtype = example_hidden.dtype + pos_embeds = qlayer.wrapped._slice_rope( + start=0, seq_len=S, device="cpu", dtype=dtype + ) + kwargs={"attention_mask": attention_mask, "position_embeddings": pos_embeds } + else: + example_hidden, pos_embeds, attention_mask, past_kv = make_random_decode_batch(config, "cpu", args.max_seq_len) + kwargs={"attention_mask": attention_mask, "position_embeddings": pos_embeds, "past_key_value": past_kv} - print(f"Saving model 
layer_{i} to {save_path.resolve()}") + print(f"Saving {suffix}model layer_{i} to {save_path.resolve()}") with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): # Pass attention_mask and position_embeddings as inputs to avoid # storing them per layer and increasing model size. cm = tico.convert( - qlayer.wrapped.as_export_module("prefill").eval(), + qlayer.wrapped.as_export_module(variant, return_kv=args.prefill_decode).eval(), (example_hidden,), - kwargs={ - "attention_mask": attention_mask, - "position_embeddings": pos_embeds, - }, + kwargs=kwargs ) + cm.save(save_path) - def quantize_using_PTQ(q_m, calib_inputs, args): print("Wrapping layers with PTQWrapper …") @@ -168,7 +491,7 @@ def quantize_using_PTQ(q_m, calib_inputs, args): # ------------------------------------------------------------------------- # Single-pass activation calibration # ------------------------------------------------------------------------- - print("Calibrating PTQ obeservers…") + print("Calibrating PTQ observers…") # Overwrite weight observers with GPTQ statistics if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict): @@ -187,26 +510,35 @@ def quantize_using_PTQ(q_m, calib_inputs, args): device = torch.device(args.device) with torch.no_grad(): for inp in tqdm.tqdm(calib_inputs): - q_m(inp.to(device)) - + if args.prefill_decode: + outputs = q_m(inp[..., :-1].to(device), use_cache=True) + # TODO add padding? 
+ q_m( + inp[..., -1:].to(device), + past_key_values=outputs.past_key_values, + use_cache=True, + ) + else: + q_m(inp.to(device)) # Freeze all Q-params (scale, zero-point) q_m = convert(q_m) return q_m -def evaluate(q_m, tokenizer, dataset_test, args): +def evaluate(q_m, tokenizer, dataset_test, args, quantized: bool): # ------------------------------------------------------------------------- # Evaluate perplexity on Wikitext-2 # ------------------------------------------------------------------------- print("\nCalculating perplexities …") enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") - ppl_uint8 = perplexity( + ppl = perplexity( q_m, enc, args.device, max_length=args.max_seq_len, stride=args.max_seq_len ) + help_str = "int16" if quantized is True else "FP32" print("\n┌── Wikitext-2 test perplexity ─────────────") - print(f"│ int16 : {ppl_uint8:8.2f}") + print(f"│ {help_str} : {ppl:8.2f}") print("└───────────────────────────────────────────") if args.eval_tasks is not None: @@ -250,6 +582,272 @@ def get_ptq_model_name(model, args): ) return name +class QModelProcessor: + """Base processor handling tokenization, GPTQ, and evaluation logic.""" + + def __init__(self, model, tokenizer, args): + """Initialize the processor with model, tokenizer, and arguments.""" + self.model = model + self.tokenizer = tokenizer + self.device = torch.device(args.device) + self.args = args + + def get_tokenized_inputs(self, dataset, shuffle=True): + """Tokenize the dataset into fixed‑length chunks for calibration.""" + text = " ".join(dataset["text"]) + ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device) + tokenized_inputs = [] + nsamples = self.args.nsamples_for_qcalibration + seqlen = self.model.config.max_position_embeddings + if shuffle is True: + random.seed(self.args.seed) + else: + stride = min((ids.shape[1] - seqlen - 1) // nsamples, seqlen) + for index in range(nsamples): + if shuffle is True: + i = random.randint(0, ids.shape[1] 
- seqlen - 1) + else: + i = index * stride + j = i + seqlen + inp = ids[:, i:j] + tokenized_inputs.append(inp.cpu()) + return tokenized_inputs + + def run_gptq(self, calib_inputs): + """Run GPTQ weight‑only quantization on the model using calibration inputs.""" + print("Applying GPTQ …") + + sens = None + if self.args.gptq_mse is not None and self.args.gptq_mse == "smse": + if self.args.sensitivity_path is not None: + sens = torch.load(self.args.sensitivity_path) + else: + calibrator = SensitivityCalibrator(self.model, calib_inputs) + sens = calibrator.compute_sensitivity_info() + if self.args.output_dir is not None and "sensitivity" in self.args.save: + save_name = get_sensitivities_info_name( + self.model, "wikitext", self.args.seed, len(calib_inputs) + ) + save_path = pathlib.Path(self.args.output_dir, save_name) + print(f"Saving calibrated_sensitivities to {save_path}") + torch.save(sens, save_path) + + gptq_config = GPTQConfig( + weight_bits=self.args.linear_weight_bits, + perchannel=True, + mse=self.args.gptq_mse, + sensitivity=sens, + ) + q_m = prepare(self.model, gptq_config, inplace=True) + with torch.no_grad(): + for inp in calib_inputs: + q_m(inp.to(self.device)) + + q_m = convert(q_m, inplace=True) # materialize INT-weight tensors + return q_m + + def run_ptq(self, q_m, calib_inputs): + assert(False) + + def _run_ptq(self, q_m, calib_inputs): + q_m = quantize_using_PTQ(q_m, calib_inputs, self.args) + if self.args.output_dir is not None and "ptq_checkpoint" in self.args.save: + save_name = get_ptq_model_name(self.model, self.args) + save_path = pathlib.Path(self.args.output_dir, save_name) + print(f"Saving PTQ model to {save_path}") + torch.save(q_m, save_path) + return q_m + + def evaluate_original(self, dataset_test): + """Evaluate the original (FP) model on the test dataset.""" + return evaluate( + self.model, self.tokenizer, dataset_test, self.args, quantized=False + ) + + def evaluate_quantized(self, dataset_test): + """Placeholder for evaluating the 
quantized model (implementation elsewhere).""" + assert False + + def save_quantized(self, model, calib_inputs): + """Placeholder for saving quantization artifacts (implementation elsewhere).""" + assert False + + +class PrefillQModelProcessor(QModelProcessor): + """ + Processor for simple model (just-prefill-model) which doesn't use kv cache. + """ + + def __init__(self, model, tokenizer, args): + """Initialize the prefill-only processor; all setup is delegated to the base class.""" + super().__init__(model, tokenizer, args) + + def run_ptq(self, q_m, calib_inputs): + return super()._run_ptq(q_m, calib_inputs) + + def evaluate_quantized(self, model, dataset_test): + evaluate(model, self.tokenizer, dataset_test, self.args, quantized=True) + + def save_quantized(self, model, calib_inputs): + if self.args.output_dir is not None and "circle_per_layer" in self.args.save: + save_layers_to(model, self.args) + + if self.args.output_dir is not None and "circle_full" in self.args.save: + calib_inputs = list( + torch.stack(calib_inputs).reshape(-1, 1, self.args.max_seq_len) + ) + save_model_to(model, calib_inputs, self.args) + + +class PrefillDecodeQModelProcessor(QModelProcessor): + """ + Processor for Prefill-Decode models. + Prefill-model computes kv-cache for the user input, then each new token is produced by decode-model with updated kv-cache. + """ + + def __init__(self, model, tokenizer, args): + """Initialize the prefill‑decode processor, handling tokenizer pad token and preparing rotary embeddings.""" + super().__init__(model, tokenizer, args) + if tokenizer.pad_token is None: + print( + "Warning: tokenizer doesn't have pad_token. Prefill-decoding scheme may fail." 
+ ) + tokenizer.pad_token = tokenizer.eos_token + + rope_cos, rope_sin = _build_rope_templates_from_config( + self.model.config, + max_seq=self.args.max_seq_len, + device=self.device, + dtype=torch.float32, + ) + self.rope_cos = rope_cos + self.rope_sin = rope_sin + + # debug padding + + # inputs = tokenizer("Hello!", return_tensors="pt", max_length=args.max_seq_len - 1, padding='max_length', padding_side="right").input_ids.to(device) + # #inputs = tokenizer("Hello! How are you?", return_tensors="pt").input_ids.to(device) + # model.config.use_cache = True + # model.config._attn_implementation = "eager" + # out_ids = model.generate(inputs) + # output = tokenizer.decode(out_ids.squeeze(), skip_special_tokens=True) + + # prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset(q_m, q_m, calib_inputs, pad_token_id=tokenizer.pad_token_id, max_seq_len = args.max_seq_len, seed = args.seed, device = device) + # + # print("\n┌── Wikitext-2 prefill_decode initial calibration perplexity──") + # print(f"│ FP32 : {prefill_decode_ppl:8.2f}") + # print("└───────────────────────────────────────────") + + def run_ptq(self, q_m, calib_inputs): + """Run PTQ for the prefill‑decode pipeline, calibrating for padding and kv-cache.""" + + # ?? 
pre_ptq_model = copy.deepcopy(q_m).to("cpu") # to be used in decode quntizing + + # get prefill_model + q_m = super()._run_ptq(q_m, calib_inputs) + + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + q_m, + calib_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + + print( + "\n┌── Wikitext-2 prefill_prefill train calibration perplexity ─────────────" + ) + print(f"│ int16 : {prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + + torch.manual_seed(self.args.seed) + + return q_m + + def evaluate_original(self, dataset_test): + """Evaluate the original (FP) model using the prefill‑decode pipeline.""" + super().evaluate_original(dataset_test) + + test_inputs = self.get_tokenized_inputs(dataset_test, shuffle=False) + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + self.model, + test_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + + print("\n┌── Wikitext-2 prefill_prefill original test perplexity ─────────────") + print(f"│ FP32 : {prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + + def evaluate_quantized(self, model, dataset_test): + """Evaluate the quantized prefill‑decode model on the test dataset.""" + + evaluate(model, self.tokenizer, dataset_test, self.args, quantized=True) + + test_inputs = self.get_tokenized_inputs(dataset_test, shuffle=False) + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + model, + test_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + print("\n┌── Wikitext-2 prefill_decode quantized test perplexity ─────────────") + print(f"│ int16 : {prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + + def 
save_quantized(self, model, calib_inputs): + """Save the quantized prefill and decode models (and optionally their layers) to disk.""" + + model.wrapped.config.use_cache = True + if self.args.output_dir is not None and "circle_per_layer" in self.args.save: + save_layers_to(model, self.args, prefill=True) + save_layers_to(model, self.args, prefill=False) + + if self.args.output_dir is not None and "circle_full" in self.args.save: + model = model.to("cpu") + calib_inputs = list( + torch.stack(calib_inputs).reshape(-1, 1, self.args.max_seq_len) + ) + # save prefill model + save_model_to(model, [calib_inputs[0][..., :-1]], self.args, prefill=True) # seq_len = max_seq_len - 1 + + # compute example input + prefill_seq_len = ( + torch.randint(3, self.args.max_seq_len - 1, (1,)).cpu().item() + ) # cropped input length + prefill_input = calib_inputs[0][..., :prefill_seq_len].to( + "cpu" + ) # cropped input + + inputs = get_decode_input( + model, + prefill_input, + self.tokenizer.pad_token_id, + (self.rope_cos.cpu(), self.rope_sin.cpu()), + self.args.max_seq_len, + "cpu", + helper=PrefillDecodeUtils(self.args.max_seq_len, model.wrapped.config, "cpu") + ) + + assert "input_ids" in inputs + input = inputs.pop("input_ids") + + # save decode model + save_model_to(model, [input], self.args, prefill=False, kwargs = inputs) # seq_len = 1 + + +def get_qmodel_processor(model, tokenizer, args): + if args.prefill_decode: + return PrefillDecodeQModelProcessor(model, tokenizer, args) + + return PrefillQModelProcessor(model, tokenizer, args) + def main(): parser = argparse.ArgumentParser( @@ -376,6 +974,13 @@ def main(): type=str, default=None, ) + parser.add_argument( + "--prefill_decode", + action="store_true", + default=False, + help="Wether to use cache", + ) + args = parser.parse_args() print(args) @@ -418,7 +1023,7 @@ def main(): else: print("Skipping SpinQuant preprocessing …") - model.config.use_cache = False # TODO use args for it + model.config.use_cache = False if 
args.calibrate_seq_len is not None: model.config.max_position_embeddings = min( model.config.max_position_embeddings, args.calibrate_seq_len @@ -428,72 +1033,32 @@ def main(): DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT, cache_dir=args.cache_dir ) - print("\nCalculating original perplexities …") - enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") - ppl_fp32 = perplexity( - model, enc, device, max_length=args.max_seq_len, stride=args.max_seq_len - ) - - print("\n┌── Wikitext-2 test perplexity ─────────────") - print(f"│ FP32 : {ppl_fp32:8.2f}") - print("└───────────────────────────────────────────") + # ------------------------------------------------------------------------- + # Create a processor for the model + # ------------------------------------------------------------------------- + qmodel_processor = get_qmodel_processor(model, tokenizer, args) - if args.eval_tasks is not None: - results = evaluate_llm_on_tasks( - model, tokenizer, args.eval_tasks, max_length=args.max_seq_len - ) - print("Original RESULTS ARE:") - print(make_table(results)) + # ------------------------------------------------------------------------- + # Compute original metrics to estimate metrics degradation + # ------------------------------------------------------------------------- + qmodel_processor.evaluate_original(dataset_test) # ------------------------------------------------------------------------- # Prepare calibration dataset # ------------------------------------------------------------------------- dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT) - calib_txt = " ".join(dataset_train["text"]) - train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device) - calib_inputs = [] - nsamples = args.nsamples_for_qcalibration - seqlen = model.config.max_position_embeddings - random.seed(args.seed) - for _ in range(nsamples): - i = random.randint(0, train_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = 
train_ids[:, i:j] - calib_inputs.append(inp.cpu()) + calib_inputs = qmodel_processor.get_tokenized_inputs(dataset_train, shuffle=True) + + # original_prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset(model, model, calib_inputs, pad_token_id=tokenizer.pad_token_id, max_seq_len = args.max_seq_len, seed = args.seed, device = device) + # print("\n┌── Wikitext-2 prefill_decode original calibration perplexity ─────────────") + # print(f"│ fp32 : {original_prefill_decode_ppl:8.2f}") + # print("└───────────────────────────────────────────") # ------------------------------------------------------------------------- # Run GPTQ (weight-only) pass # ------------------------------------------------------------------------- if not args.no_GPTQ: - print("Applying GPTQ …") - - sens = None - if args.gptq_mse is not None and args.gptq_mse == "smse": - if args.sensitivity_path is not None: - sens = torch.load(args.sensitivity_path) - else: - calibrator = SensitivityCalibrator(model, calib_inputs) - sens = calibrator.compute_sensitivity_info() - if args.output_dir is not None and "sensitivity" in args.save: - save_name = get_sensitivities_info_name( - model, "wikitext", args.seed, len(calib_inputs) - ) - save_path = pathlib.Path(args.output_dir, save_name) - print(f"Saving calibrated_sensitivities to {save_path}") - torch.save(sens, save_path) - - gptq_config = GPTQConfig( - weight_bits=args.linear_weight_bits, - perchannel=True, - mse=args.gptq_mse, - sensitivity=sens, - ) - q_m = prepare(model, gptq_config, inplace=True) - with torch.no_grad(): - for inp in calib_inputs: - q_m(inp.to(args.device)) - - q_m = convert(q_m, inplace=True) # materialize INT-weight tensors + q_m = qmodel_processor.run_gptq(calib_inputs) else: q_m = model @@ -501,23 +1066,17 @@ def main(): # Wrap every layer with PTQWrapper # ------------------------------------------------------------------------- if not args.no_PTQ: - q_m = quantize_using_PTQ(q_m, calib_inputs, args) - - if 
args.output_dir is not None and "ptq_checkpoint" in args.save: - save_name = get_ptq_model_name(model, args) - save_path = pathlib.Path(args.output_dir, save_name) - print(f"Saving PTQ model to {save_path}") - torch.save(q_m, save_path) + q_m = qmodel_processor.run_ptq(q_m, calib_inputs) - # after PTQ quantizer only fixed-length input sequences are valid - evaluate(q_m, tokenizer, dataset_test, args) - - if args.output_dir is not None and "circle_per_layer" in args.save: - save_layers_to(q_m, args.max_seq_len, args.output_dir) + # ------------------------------------------------------------------------- + # Compute quantized model metrics to estimate metrics degradation + # ------------------------------------------------------------------------- + qmodel_processor.evaluate_quantized(q_m, dataset_test) - if args.output_dir is not None and "circle_full" in args.save: - calib_inputs = list(torch.stack(calib_inputs).reshape(-1, 1, args.max_seq_len)) - save_model_to(q_m, calib_inputs, args.output_dir) + # ------------------------------------------------------------------------- + # Save layers and model + # ------------------------------------------------------------------------- + qmodel_processor.save_quantized(q_m, calib_inputs) if __name__ == "__main__": diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py index 3cd50744..a2a1ca25 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer.py @@ -305,7 +305,7 @@ def forward( if use_cache: outputs += (present_key_value,) # type: ignore[assignment] - if self.return_type == "tuple": + if self.return_type == "tuple" or use_cache is True: return outputs if self.return_type == "tensor": return hidden_states