""" aggregate.py - module for 'reducing' multiple 'summary chunks' into one an overly complicated class for legacy compatibility reasons, for usage of the 2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage """ import logging import pprint as pp import time import torch from transformers import GenerationConfig, pipeline # Setting up logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) class BatchAggregator: """ BatchAggregator is a class for aggregating text from multiple sources. Usage: from aggregate import BatchAggregator aggregator = BatchAggregator() agg = aggregator.infer_aggregate(["This is a test", "This is another test"]) print(agg) """ GENERIC_CONFIG = GenerationConfig( max_new_tokens=512, num_beams=4, early_stopping=True, do_sample=False, truncation=True, ) def __init__( self, model_name: str = "pszemraj/bart-large-summary-map-reduce", force_cpu: bool = False, **kwargs, ): """ __init__ initializes the BatchAggregator class. :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce" :param bool force_cpu: force the model to run on CPU, default: False """ self.device = None self.is_compiled = False self.model_name = None self.aggregator = None self.force_cpu = force_cpu self.logger = logging.getLogger(__name__) self.init_model(model_name) def init_model(self, model_name: str) -> None: """ Initialize the model. :param model_name: The name of the model to use. """ # Free up memory if torch.cuda.is_available(): torch.cuda.empty_cache() self.logger.info(f"Setting model to {model_name}") self.model_name = model_name self.aggregator = self._create_pipeline(model_name) self._configure_model() def _create_pipeline( self, model_name: str = "pszemraj/bart-large-summary-map-reduce" ) -> pipeline: """ _create_pipeline creates a pipeline for the model. :param str model_name: model name to use :return pipeline: the pipeline for the model :raises Exception: if the pipeline cannot be created """ device_map = ( "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu" ) try: self.logger.info( f"Creating pipeline with model {model_name} on device {device_map}" ) return pipeline( "text2text-generation", model=model_name, device_map=device_map, torch_dtype=torch.float32, ) except Exception as e: self.logger.error(f"Failed to create pipeline: {e}") raise def _configure_model(self): """ Configure the model for generation. """ try: self.aggregator.model = torch.compile(self.aggregator.model) self.is_compiled = True except Exception as e: self.logger.warning(f"Could not compile model with Torch 2.0: {e}") self._set_default_generation_config() self.logger.info(self.aggregator.model.generation_config.to_json_string()) def _set_default_generation_config(self): """ Set the default generation configuration for the model. """ self.aggregator.model.generation_config.update( **self.GENERIC_CONFIG.to_diff_dict() ) def update_generation_config(self, **kwargs): """ Update the generation configuration with the specified parameters. Args: **kwargs: The parameters to update in the generation configuration. """ self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}") self.aggregator.model.generation_config.update(**kwargs) def get_generation_config(self) -> dict: """ Get the current generation configuration. Returns: dict: The current generation configuration. """ return self.aggregator.model.generation_config.to_dict() def update_loglevel(self, level: str = "INFO"): """ Update the log level. Args: level (str): The log level to set. Defaults to "INFO". """ self.logger.setLevel(level) def infer_aggregate( self, text_list: list, instruction: str = None, # Kept for backward compatibility but not used **kwargs, ) -> str: """ infer_aggregate - infers a consolidated summary from a list of texts. Args: text_list (list): The texts to summarize. instruction (str): Not used by this model, kept for compatibility. **kwargs: Additional parameters to update in the generation configuration. Returns: The generated summary. """ joined_text = "\n\n".join(text_list) if kwargs: self.update_generation_config(**kwargs) st = time.perf_counter() self.logger.info(f"inference on {len(text_list)} texts ...") result = self.aggregator( joined_text, generation_config=self.aggregator.model.generation_config, )[0]["generated_text"] self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s") self.logger.info( f"Input tokens:\t{self.count_tokens(joined_text)}. Output tokens:\t{self.count_tokens(result)}" ) self.logger.debug(f"Generated text:\n{result}") return result def count_tokens(self, text: str) -> int: """count the number of tokens in a text""" return ( len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False)) if text else 0 )