init
app.py CHANGED
@@ -66,7 +66,7 @@ def load(
     # ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
     # ckpt_path = checkpoints[local_rank]
     print("Loading")
-    checkpoint = torch.load(ckpt_path, map_location="
+    checkpoint = torch.load(ckpt_path, map_location="cuda")
     instruct_adapter_checkpoint = torch.load(
         instruct_adapter_path, map_location="cpu")
     caption_adapter_checkpoint = torch.load(
@@ -87,12 +87,13 @@ def load(
     model_args.vocab_size = tokenizer.n_words
     torch.set_default_tensor_type(torch.cuda.HalfTensor)
     model = Transformer(model_args)
-    vision_model = VisionModel(model_args)
-
-    torch.set_default_tensor_type(torch.FloatTensor)
     model.load_state_dict(checkpoint, strict=False)
     del checkpoint
     torch.cuda.empty_cache()
+    vision_model = VisionModel(model_args)
+
+    torch.set_default_tensor_type(torch.FloatTensor)
+
     model.load_state_dict(instruct_adapter_checkpoint, strict=False)
     model.load_state_dict(caption_adapter_checkpoint, strict=False)
     vision_model.load_state_dict(caption_adapter_checkpoint, strict=False)
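For context on the first hunk: `map_location` tells `torch.load` where to materialize the deserialized tensors, so `map_location="cuda"` puts the base checkpoint straight onto the GPU rather than staging it on the CPU. Below is a minimal standalone sketch of that behavior; the file name demo.pt and the toy state dict are made up for illustration and are not part of app.py.

import torch

# Save a toy state dict so the example is self-contained.
state = {"w": torch.randn(4, 4)}
torch.save(state, "demo.pt")

# map_location="cpu": tensors are deserialized onto the CPU.
on_cpu = torch.load("demo.pt", map_location="cpu")
print(on_cpu["w"].device)  # cpu

# map_location="cuda": tensors land directly on the default CUDA device,
# as in the patched line above (guarded here because it needs a GPU).
if torch.cuda.is_available():
    on_gpu = torch.load("demo.pt", map_location="cuda")
    print(on_gpu["w"].device)  # cuda:0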
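The second hunk reorders the load sequence: the base checkpoint is loaded into `model` and freed before `VisionModel(model_args)` is constructed, and the `torch.cuda.HalfTensor` default is reset to `torch.FloatTensor` only afterwards, so the vision model's parameters are still created in half precision on the GPU. The sketch below illustrates that default-tensor-type mechanism with a toy `nn.Linear` standing in for the real `Transformer`/`VisionModel`; it shows the general PyTorch behavior, not code from app.py, and the first half needs a CUDA device.

import torch
import torch.nn as nn

if torch.cuda.is_available():
    # While the default tensor type is torch.cuda.HalfTensor, freshly
    # constructed modules get fp16 parameters on the GPU.
    torch.set_default_tensor_type(torch.cuda.HalfTensor)
    half_module = nn.Linear(8, 8)
    print(half_module.weight.dtype, half_module.weight.device)  # torch.float16 cuda:0

# After resetting to torch.FloatTensor, new modules get fp32 CPU parameters.
torch.set_default_tensor_type(torch.FloatTensor)
float_module = nn.Linear(8, 8)
print(float_module.weight.dtype, float_module.weight.device)  # torch.float32 cpu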