"""
aggregate.py - module for 'reducing' multiple 'summary chunks' into one.

An overly complicated class kept for legacy compatibility reasons; for usage of
the 2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage
"""
# Standard library
import logging
import pprint as pp
import time

# Third-party
import torch
from transformers import GenerationConfig, pipeline

# Configure root logging once at import time so class-level loggers inherit it.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class BatchAggregator:
    """
    BatchAggregator aggregates text from multiple sources into one summary.

    Wraps a Hugging Face ``text2text-generation`` pipeline and applies a
    default generation configuration after loading.

    Usage:
        from aggregate import BatchAggregator
        aggregator = BatchAggregator()
        agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
        print(agg)
    """

    # Default generation settings applied to the loaded model.
    # NOTE(review): ``truncation`` is a tokenizer argument, not a standard
    # GenerationConfig field — it is stored but likely ignored at generate time.
    GENERIC_CONFIG = GenerationConfig(
        max_new_tokens=512,
        num_beams=4,
        early_stopping=True,
        do_sample=False,
        truncation=True,
    )

    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-summary-map-reduce",
        force_cpu: bool = False,
        **kwargs,
    ):
        """
        Initialize the BatchAggregator and load the model.

        :param str model_name: model name to use,
            default: "pszemraj/bart-large-summary-map-reduce"
        :param bool force_cpu: force the model to run on CPU, default: False
        :param kwargs: accepted for backward compatibility; currently unused
        """
        self.device = None
        self.is_compiled = False
        self.model_name = None
        self.aggregator = None
        self.force_cpu = force_cpu
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize (or re-initialize) the underlying pipeline.

        :param model_name: The name of the model to use.
        """
        # Free up GPU memory left over from any previously loaded model.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info("Setting model to %s", model_name)
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
    ) -> pipeline:
        """
        Create the text2text-generation pipeline for the model.

        :param str model_name: model name to use
        :return pipeline: the pipeline for the model
        :raises Exception: re-raised if the pipeline cannot be created
        """
        device_map = (
            "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
        )
        try:
            self.logger.info(
                "Creating pipeline with model %s on device %s", model_name, device_map
            )
            return pipeline(
                "text2text-generation",
                model=model_name,
                device_map=device_map,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error("Failed to create pipeline: %s", e)
            raise

    def _configure_model(self):
        """
        Configure the model for generation: try torch.compile, then apply
        the default generation config.
        """
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            # Compilation is best-effort; the uncompiled model still works.
            self.logger.warning("Could not compile model with Torch 2.0: %s", e)

        self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Apply GENERIC_CONFIG (non-default fields only) to the model's
        generation configuration.
        """
        self.aggregator.model.generation_config.update(
            **self.GENERIC_CONFIG.to_diff_dict()
        )

    def update_generation_config(self, **kwargs):
        """
        Update the generation configuration with the specified parameters.

        Args:
            **kwargs: The parameters to update in the generation configuration.
        """
        self.logger.info("Updating generation config with %s", pp.pformat(kwargs))
        self.aggregator.model.generation_config.update(**kwargs)

    def get_generation_config(self) -> dict:
        """
        Get the current generation configuration.

        Returns:
            dict: The current generation configuration.
        """
        return self.aggregator.model.generation_config.to_dict()

    def update_loglevel(self, level: str = "INFO"):
        """
        Update the log level of this instance's logger.

        Args:
            level (str): The log level to set. Defaults to "INFO".
        """
        self.logger.setLevel(level)

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = None,  # Kept for backward compatibility but not used
        **kwargs,
    ) -> str:
        """
        Infer a consolidated summary from a list of texts.

        Args:
            text_list (list): The texts to summarize.
            instruction (str): Not used by this model, kept for compatibility.
            **kwargs: Additional parameters to update in the generation configuration.

        Returns:
            str: The generated summary ("" if text_list is empty).
        """
        # Guard: nothing to aggregate — avoid running the model on an empty string.
        if not text_list:
            self.logger.warning("infer_aggregate called with empty text_list")
            return ""

        joined_text = "\n\n".join(text_list)
        if kwargs:
            self.update_generation_config(**kwargs)

        st = time.perf_counter()
        self.logger.info("inference on %d texts ...", len(text_list))
        result = self.aggregator(
            joined_text,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info("Done. runtime:\t%ss", round(time.perf_counter() - st, 2))
        self.logger.info(
            "Input tokens:\t%d. Output tokens:\t%d",
            self.count_tokens(joined_text),
            self.count_tokens(result),
        )
        self.logger.debug("Generated text:\n%s", result)
        return result

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text (0 for empty/None input)."""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )