Spaces:
Runtime error
Runtime error
import inspect | |
import traceback | |
import torch | |
import vision_models | |
# Registry mapping a process name to the callable that runs the corresponding model.
consumers = {}
def _print_gpu_memory_summary():
    """Print the CUDA memory summary, or a short notice when no CUDA context exists."""
    # torch.cuda.memory_summary() is not guaranteed to work without an initialized CUDA context
    # (CPU-only host, or nothing has touched the GPU yet), so only call it when one exists.
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        print(torch.cuda.memory_summary())
    else:
        print("CUDA is not initialized; no GPU memory summary available.")


def load_models():
    """
    Discover every model class in ``vision_models`` and register a callable for each of its processes.

    All classes in ``vision_models`` that subclass ``vision_models.BaseModel`` (excluding ``BaseModel``
    itself) are collected, sorted by their ``load_order`` attribute and, for every name returned by the
    class' ``list_processes()``, a callable built by ``make_fn`` is stored in the module-level
    ``consumers`` dict under that process name. Progress and GPU memory usage are printed along the way.
    """
    # ``consumers`` is only mutated in place (never rebound), so no ``global`` statement is required.
    list_models = [m[1] for m in inspect.getmembers(vision_models, inspect.isclass)
                   if issubclass(m[1], vision_models.BaseModel) and m[1] != vision_models.BaseModel]
    list_models.sort(key=lambda x: x.load_order)
    print("-" * 10, "List models", list_models)

    # Running index over all processes (not classes); make_fn uses it to pick a device.
    counter_ = 0
    for model_class_ in list_models:
        print("-" * 10, "Now loading {}:".format(model_class_))
        for process_name_ in model_class_.list_processes():
            consumers[process_name_] = make_fn(model_class_, process_name_, counter_)
            counter_ += 1
        print("-" * 10, "Loading {} finished. Current gpu:".format(model_class_))
        _print_gpu_memory_summary()

    print("-" * 10, "Model loading finished. Final gpu:")
    _print_gpu_memory_summary()
def make_fn(model_class, process_name, counter):
    """
    Instantiate ``model_class`` and return a callable that runs its ``forward`` method.

    model_class.name and process_name will be the same unless the same model is used in multiple processes, for
    different tasks

    Args:
        model_class: class exposing ``name`` and ``to_batch`` attributes, constructible as
            ``model_class(gpu_number=...)``; its instances must provide ``forward``.
        process_name: name of the process the returned callable serves. When it differs from
            ``model_class.name`` it is forwarded to the model as the ``process_name`` keyword.
        counter: running index used to assign devices round-robin.

    Returns:
        A function ``_function(*args, **kwargs)`` returning the model output, or ``None`` when the
        call to ``forward`` raised (the error and traceback are printed).
    """
    # We initialize each one on a separate GPU, to make sure there are no out of memory errors
    num_gpus = torch.cuda.device_count()
    # device_count() is 0 when no GPU is visible; avoid a modulo-by-zero and fall back to index 0.
    # How the model class treats that index on a CPU-only host is up to the class itself.
    gpu_number = counter % num_gpus if num_gpus > 0 else 0
    model_instance = model_class(gpu_number=gpu_number)

    def _function(*args, **kwargs):
        if process_name != model_class.name:
            kwargs['process_name'] = process_name

        if model_class.to_batch:
            # Batchify the input. Model expects a batch. And later un-batchify the output.
            args = [[arg] for arg in args]
            kwargs = {k: [v] for k, v in kwargs.items()}

            # The defaults that are not in args or kwargs, also need to listify
            full_arg_spec = inspect.getfullargspec(model_instance.forward)
            if full_arg_spec.defaults is None:
                default_dict = {}
            else:
                # Defaults align with the trailing positional parameters.
                default_dict = dict(zip(full_arg_spec.args[-len(full_arg_spec.defaults):], full_arg_spec.defaults))
            # Skip the first entry (the bound instance, still reported by getfullargspec) and the
            # parameters already consumed positionally.
            non_given_args = set(full_arg_spec.args[1:][len(args):]) - set(kwargs)
            for arg_name in non_given_args:
                # A parameter without a default is a genuinely missing argument: leave it out so that
                # forward() reports it inside the try block below instead of raising KeyError here.
                if arg_name in default_dict:
                    kwargs[arg_name] = [default_dict[arg_name]]

        try:
            out = model_instance.forward(*args, **kwargs)
            if model_class.to_batch:
                # Un-batchify: the batch had exactly one element.
                out = out[0]
        except Exception as e:
            print(f'Error in {process_name} model:', e)
            traceback.print_exc()
            out = None
        return out

    return _function
def forward(model_name, *args, **kwargs):
    """Look up the callable registered under ``model_name`` in ``consumers`` and call it with the given arguments."""
    handler = consumers[model_name]
    return handler(*args, **kwargs)