DeepSeek-V3 / inference /fp8_cast_bf16.py
FILM6912's picture
Add files using upload-large-folder tool
adda7c1 verified
raw
history blame
3.24 kB
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def main(fp8_path, bf16_path):
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)