import argparse
import json

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

import prompt_utils

parser = argparse.ArgumentParser(description="LLM_CLS")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--quant_path", type=str, required=True)
parser.add_argument("--data_file", type=str, required=True)
args = parser.parse_args()

# Classification prompt and label mapping shared with the rest of the pipeline.
system_prompt = prompt_utils.system_prompt_v2.system_prompt
label_map = prompt_utils.system_prompt_v2.label_map

# 4-bit AWQ, group size 128, GEMM kernel layout.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

tokenizer = AutoTokenizer.from_pretrained(args.model_path)
model = AutoAWQForCausalLM.from_pretrained(
    args.model_path, device_map="auto", use_cache=False
)

# Build calibration samples from the task's own data: each JSONL line carries
# a "text" and a "label", rendered through the chat template so the calibration
# distribution matches what the quantized model will see at inference time.
data = []
with open(args.data_file, "r", encoding="utf-8") as fi:
    for line in fi:
        line_js = json.loads(line)
        text = line_js["text"]
        label = line_js["label"]
        if label not in label_map:
            raise ValueError(f"Unknown label: {label}")
        label = label_map[label]
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": text},
            {"role": "assistant", "content": label},
        ]
        sample = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False, enable_thinking=False
        )
        data.append(sample.strip())

model.quantize(
    tokenizer,
    quant_config=quant_config,
    calib_data=data,
    n_parallel_calib_samples=1,
    max_calib_samples=256,
    max_calib_seq_len=10240,
)

model.save_quantized(args.quant_path, safetensors=True, shard_size="4GB")
tokenizer.save_pretrained(args.quant_path)
print("Quantization completed.")
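# Example invocation, for reference. The script filename and all paths below
# are illustrative placeholders, not taken from the original repository:
#
#   python quantize_awq.py \
#       --model_path /path/to/finetuned-model \
#       --quant_path ./model-awq-int4 \
#       --data_file train.jsonl
#
# Each line of --data_file is expected to be a JSON object of the form
# (field names are those read by the loop above):
#
#   {"text": "<input to classify>", "label": "<raw label, a key of label_map>"}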