gistillery / src /gistillery /preprocessing.py
Benjamin Bossan
Use transformers agents where applicable
01ae0bb
raw
history blame
2.99 kB
import abc
import io
import logging
import re
from typing import Optional
import trafilatura
from httpx import Client
from PIL import Image
from gistillery.base import JobInput
from gistillery.tools import get_agent
# Module-level logger named after this module. Its level is set to DEBUG, so
# the logger itself does not filter out any record of DEBUG severity or above.
logger: logging.Logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
RE_URL = re.compile(r"(https?://[^\s]+)")
def get_url(text: str) -> str | None:
urls: list[str] = list(RE_URL.findall(text))
if len(urls) == 1:
url = urls[0]
return url
return None
class Processor(abc.ABC):
    """Abstract base class for objects that turn a job into a string.

    Subclasses implement ``match`` (can this processor handle the job?) and
    ``process`` (produce the resulting text). Calling an instance runs
    ``process`` and logs a message before and after.
    """

    def get_name(self) -> str:
        """Return the class name of the concrete processor."""
        return self.__class__.__name__

    def __call__(self, job: JobInput) -> str:
        """Run ``process`` on ``job`` and return its result.

        Args:
            job: The job to process; its ``id`` is used in the log messages.

        Returns:
            Whatever ``process`` returns for ``job``.
        """
        short_id = job.id[:8]
        # Only the processor name and a short id are logged. Interpolating
        # ``input`` here would print the builtin function of that name, since
        # the parameter of this method is called ``job``.
        logger.info(
            "Processing input with %s (id=%s)", self.get_name(), short_id
        )
        result = self.process(job)
        logger.info("Finished processing input (id=%s)", short_id)
        return result

    @abc.abstractmethod
    def process(self, input: JobInput) -> str:
        """Convert ``input`` into the string result of this processor."""
        raise NotImplementedError

    @abc.abstractmethod
    def match(self, input: JobInput) -> bool:
        """Return whether this processor is able to handle ``input``."""
        raise NotImplementedError
class RawTextProcessor(Processor):
    """Processor that accepts every job and returns its content as text."""

    def match(self, input: JobInput) -> bool:
        """Accept any job unconditionally."""
        return True

    def process(self, input: JobInput) -> str:
        """Return the job content without leading/trailing whitespace."""
        content = input.content
        return content.strip()
class DefaultUrlProcessor(Processor):
    """Download the page behind a single URL and extract its text.

    ``match`` stores the URL it finds on the instance and ``process`` reads it
    from there, so ``process`` only works after a successful ``match``.
    """

    def __init__(self) -> None:
        self.client = Client()
        # No URL is known until ``match`` succeeds.
        self.url: Optional[str] = None
        self.template = "{url}\n\n{content}"

    def match(self, input: JobInput) -> bool:
        """Return True and remember the URL if the content has exactly one."""
        url = get_url(input.content.strip())
        if url is None:
            return False
        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Get content of website and return it as string.

        Returns:
            The URL, a blank line, and the text extracted from the page.

        Raises:
            TypeError: If no URL has been stored by a previous ``match``.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        # NOTE(review): the HTTP status of the response is not checked here.
        html = self.client.get(self.url).text

        extracted = trafilatura.extract(html)
        if extracted is None:
            # trafilatura returns None when it cannot extract anything; use an
            # empty body instead of rendering the literal text "None".
            logger.warning("No content could be extracted from %s", self.url)
            extracted = ""

        return self.template.format(url=self.url, content=extracted)
class ImageUrlProcessor(Processor):
    """Download the image behind a single URL and return a caption for it.

    ``match`` stores the URL it finds on the instance and ``process`` reads it
    from there, so ``process`` only works after a successful ``match``.
    """

    def __init__(self) -> None:
        self.client = Client()
        # No URL is known until ``match`` succeeds.
        self.url: Optional[str] = None
        # Not used by ``process`` below; kept as an instance attribute.
        self.template = "{url}\n\n{content}"
        self.image_suffixes = {'jpg', 'jpeg', 'png', 'gif'}

    def match(self, input: JobInput) -> bool:
        """Return True if the content is one URL ending in an image suffix.

        The suffix is whatever follows the last ``.`` of the URL, compared
        case-insensitively against ``image_suffixes``.
        """
        url = get_url(input.content.strip())
        if url is None:
            return False

        suffix = url.rsplit(".", 1)[-1].lower()
        if suffix not in self.image_suffixes:
            return False

        self.url = url
        return True

    def process(self, input: JobInput) -> str:
        """Fetch the stored image URL and return the agent's caption.

        Raises:
            TypeError: If no URL has been stored by a previous ``match``.
        """
        if not isinstance(self.url, str):
            raise TypeError("self.url must be a string")

        response = self.client.get(self.url)
        # ``convert`` returns a new image, so the handle opened on the
        # downloaded bytes can be closed as soon as the conversion is done.
        with Image.open(io.BytesIO(response.content)) as raw_image:
            image = raw_image.convert('RGB')

        caption = get_agent().run("Caption the following image", image=image)
        return str(caption)