from main import *


def get_satclip(ckpt_path, device, return_all=False):
    """Load a SatCLIP Lightning checkpoint and return the model for inference.

    The checkpoint's stored hyper-parameters are used to rebuild a
    ``SatCLIPLightningModule``, its ``state_dict`` is restored, and the
    module is switched to eval mode.

    Args:
        ckpt_path: Path (or file-like object) passed to ``torch.load``.
        device: Device the checkpoint is mapped to and the module is moved to.
        return_all: If true, return the whole ``lightning_model.model``;
            otherwise return only its ``location`` attribute (presumably the
            location encoder -- confirm against the model definition).

    Returns:
        ``lightning_model.model`` when ``return_all`` is true, else
        ``lightning_model.model.location``.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    ckpt = torch.load(ckpt_path, map_location=device)

    # These hyper-parameters are stored in the checkpoint but are not passed
    # to the module constructor. Use a default so checkpoints that were saved
    # without them do not raise KeyError.
    hyper_parameters = ckpt['hyper_parameters']
    for unused_key in (
        'eval_downstream',
        'air_temp_data_path',
        'election_data_path',
    ):
        hyper_parameters.pop(unused_key, None)

    lightning_model = SatCLIPLightningModule(**hyper_parameters).to(device)
    lightning_model.load_state_dict(ckpt['state_dict'])
    lightning_model.eval()

    geo_model = lightning_model.model
    return geo_model if return_all else geo_model.location