csuhan committed
Commit bafac31 · 1 parent: 361bc82
Files changed (1)
  1. app.py +5 -5
app.py CHANGED
@@ -64,7 +64,7 @@ def load(
     # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
     # ckpt_path = checkpoints[local_rank]
     print("Loading")
-    # checkpoint = torch.load(ckpt_path, map_location="cpu")
+    checkpoint = torch.load(ckpt_path, map_location="cpu")
     instruct_adapter_checkpoint = torch.load(
         instruct_adapter_path, map_location="cpu")
     caption_adapter_checkpoint = torch.load(
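This hunk re-enables loading the full LLaMA base checkpoint from ckpt_path alongside the two adapter checkpoints. A minimal sketch of the same loading pattern, with placeholder paths standing in for the Space's actual files:

import torch

ckpt_path = "consolidated.00.pth"    # base weights; placeholder path
adapter_path = "adapter.pth"         # adapter weights; placeholder path

# map_location="cpu" materializes the tensors on the host, so the load
# works on machines without a GPU and avoids holding an extra device copy.
checkpoint = torch.load(ckpt_path, map_location="cpu")
adapter_checkpoint = torch.load(adapter_path, map_location="cpu")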
@@ -88,8 +88,8 @@ def load(
     vision_model = VisionModel(model_args)

     torch.set_default_tensor_type(torch.FloatTensor)
-    # model.load_state_dict(checkpoint, strict=False)
-    # del checkpoint
+    model.load_state_dict(checkpoint, strict=False)
+    del checkpoint
     model.load_state_dict(instruct_adapter_checkpoint, strict=False)
     model.load_state_dict(caption_adapter_checkpoint, strict=False)
     vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
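The base checkpoint and both adapter checkpoints are loaded into the same model with strict=False, which lets a partial state dict apply on top of already-loaded weights instead of raising on missing keys. A small self-contained sketch of that behavior (the model and keys here are illustrative, not from app.py):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# A partial state dict, like an adapter checkpoint, covers only some keys.
partial = {"0.weight": torch.zeros(4, 4)}

# strict=False skips keys the dict does not provide instead of raising,
# and reports what was missing or unexpected.
missing, unexpected = model.load_state_dict(partial, strict=False)
print(missing)       # keys left untouched, e.g. ['0.bias', '1.weight', '1.bias']

del partial          # drop the reference, as the diff does with `del checkpoint`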
@@ -169,8 +169,8 @@ def download_llama_adapter(instruct_adapter_path, caption_adapter_path):
     # ckpt_path = "/data1/llma/7B/consolidated.00.pth"
     # param_path = "/data1/llma/7B/params.json"
     # tokenizer_path = "/data1/llma/tokenizer.model"
-    # ckpt_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="consolidated.00.pth")
-    ckpt_path = None
+    ckpt_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="consolidated.00.pth")
+    # ckpt_path = None
     param_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="params.json")
     tokenizer_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="tokenizer.model")
     instruct_adapter_path = "llama_adapter_len10_layer30_release.pth"
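Here the commit switches ckpt_path from None to a real file fetched with hf_hub_download, which downloads a single file from a Hub repo into the local cache (or reuses a cached copy) and returns its path. A sketch of the call, using the same repo as the diff:

from huggingface_hub import hf_hub_download

# Returns the absolute path of the cached file, downloading it
# from the Hub repo on first use.
param_path = hf_hub_download(repo_id="nyanko7/LLaMA-7B", filename="params.json")
print(param_path)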
 