"""
aggregate.py - module for 'reducing' multiple 'summary chunks' into one
an overly complicated class for legacy compatibility reasons, for usage of the
2024 map-reduce models see hf.co/pszemraj/bart-large-summary-map-reduce#usage
"""

import logging
import pprint as pp
import time
from typing import List, Optional

import torch
from transformers import GenerationConfig, Pipeline, pipeline

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """
    BatchAggregator is a class for aggregating text from multiple sources.

    Usage:
        from aggregate import BatchAggregator
        aggregator = BatchAggregator()
        agg = aggregator.infer_aggregate(["This is a test", "This is another test"])
        print(agg)
    """

    # Default decoding settings: deterministic beam search, up to 512 new tokens.
    GENERIC_CONFIG = GenerationConfig(
        max_new_tokens=512,
        num_beams=4,
        early_stopping=True,
        do_sample=False,
        truncation=True,
    )

    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-summary-map-reduce",
        force_cpu: bool = False,
        **kwargs,
    ):
        """
        __init__ initializes the BatchAggregator class.

        :param str model_name: model name to use, default: "pszemraj/bart-large-summary-map-reduce"
        :param bool force_cpu: force the model to run on CPU, default: False
        """
        self.device = None
        self.is_compiled = False
        self.model_name = None
        self.aggregator = None
        self.force_cpu = force_cpu
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize the model.

        :param model_name: The name of the model to use.
        """
        # Free up memory before (re)loading a model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info(f"Setting model to {model_name}")
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-summary-map-reduce"
    ) -> Pipeline:
        """
        _create_pipeline creates a text2text-generation pipeline for the model.

        :param str model_name: model name to use
        :return Pipeline: the pipeline for the model
        :raises Exception: if the pipeline cannot be created
        """
        device_map = (
            "auto" if torch.cuda.is_available() and not self.force_cpu else "cpu"
        )
        try:
            self.logger.info(
                f"Creating pipeline with model {model_name} on device {device_map}"
            )
            return pipeline(
                "text2text-generation",
                model=model_name,
                device_map=device_map,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error(f"Failed to create pipeline: {e}")
            raise

    def _configure_model(self):
        """
        Configure the model for generation.
        """
        # torch.compile is best-effort; fall back to the eager model if it fails
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            self.logger.warning(f"Could not compile model with torch.compile: {e}")

        self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Set the default generation configuration for the model.
        """
        self.aggregator.model.generation_config.update(
            **self.GENERIC_CONFIG.to_diff_dict()
        )

    def update_generation_config(self, **kwargs):
        """
        Update the generation configuration with the specified parameters.

        Args:
            **kwargs: The parameters to update in the generation configuration.
        """
        self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
        self.aggregator.model.generation_config.update(**kwargs)
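
    # Example (illustrative): trade quality for speed with greedy decoding:
    #   aggregator.update_generation_config(num_beams=1, max_new_tokens=256)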

    def get_generation_config(self) -> dict:
        """
        Get the current generation configuration.

        Returns:
            dict: The current generation configuration.
        """
        return self.aggregator.model.generation_config.to_dict()

    def update_loglevel(self, level: str = "INFO"):
        """
        Update the log level.

        Args:
            level (str): The log level to set. Defaults to "INFO".
        """
        self.logger.setLevel(level)

    def infer_aggregate(
        self,
        text_list: List[str],
        instruction: Optional[str] = None,  # kept for backward compatibility; unused
        **kwargs,
    ) -> str:
        """
        infer_aggregate - infer a consolidated summary from a list of texts.

        Args:
            text_list (list): The texts to summarize.
            instruction (str): Not used by this model; kept for compatibility.
            **kwargs: Additional parameters to update in the generation configuration.

        Returns:
            str: The generated summary.
        """
        joined_text = "\n\n".join(text_list)
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"Running inference on {len(text_list)} texts ...")
        result = self.aggregator(
            joined_text,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. Runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(joined_text)}. "
            f"Output tokens:\t{self.count_tokens(result)}"
        )
        self.logger.debug(f"Generated text:\n{result}")
        return result

    def count_tokens(self, text: str) -> int:
        """Count the number of tokens in a text."""
        return (
            len(
                self.aggregator.tokenizer.encode(text, truncation=False, padding=False)
            )
            if text
            else 0
        )
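

# Minimal smoke test -- a sketch assuming the default checkpoint can be
# downloaded; the two input strings are illustrative only.
if __name__ == "__main__":
    aggregator = BatchAggregator(force_cpu=True)
    print(
        aggregator.infer_aggregate(
            [
                "The report describes rising Q1 sales driven by new products.",
                "The report notes that Q2 sales fell due to supply issues.",
            ]
        )
    )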