|
--- |
|
library_name: transformers |
|
license: mit |
|
datasets: |
|
- SpursgoZmy/MMTab |
|
- apoidea/pubtabnet-html |
|
language: |
|
- en |
|
base_model: google/pix2struct-base |
|
--- |
|
|
|
# pix2struct-base-table2html |
|
|
|
*Turn table images into HTML!* |
|
|
|
|
|
## Demo app |
|
|
|
Try the [demo app](), which combines both table detection and table recognition!
|
|
|
|
|
## About |
|
|
|
This model takes an image of a table and outputs HTML: it parses the image, performing both optical character recognition (OCR) and table structure recognition, and renders the result as HTML markup.
|
|
|
The model expects an image containing only a table. If the table is embedded in a document, first use a table detection model to extract it. |
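For example, a minimal sketch of cropping a detected table out of a page image with PIL (the file name and bounding-box coordinates here are hypothetical placeholders for the output of a table detection model):

```python
from PIL import Image

# Hypothetical bounding box (left, upper, right, lower) produced by a table detector
box = (50, 120, 980, 640)

page = Image.open("document_page.png")
table_image = page.crop(box)  # crop to the table region before recognition
```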
|
|
|
The model is finetuned from the [Pix2Struct base model](https://huggingface.co/google/pix2struct-base) with `max_patches` set to 1024 and a maximum generation length of 1024 tokens. The `max_patches` value should likely be kept at 1024 for inference, but the generation length can be adjusted.
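For instance, reusing the inputs from the full example in the Usage section below, a longer generation budget for large tables might look like this (the value 2048 is only an illustration):

```python
# Allow longer HTML output for large tables; max_patches stays at 1024
predictions = model.generate(
    flattened_patches=flattened_patches,
    attention_mask=attention_mask,
    max_new_tokens=2048,
)
```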
|
|
|
The model has been trained using two datasets: [MMTab](https://huggingface.co/datasets/SpursgoZmy/MMTab) and [PubTabNet](https://huggingface.co/datasets/apoidea/pubtabnet-html). |
|
|
|
## Usage |
|
|
|
Below is a complete example of loading the model and performing inference on an example table image (example from the [MMTab dataset](https://huggingface.co/datasets/SpursgoZmy/MMTab)): |
|
|
|
```python |
|
import torch |
|
from transformers import AutoProcessor, Pix2StructForConditionalGeneration |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
|
|
# Load model and processor |
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
processor = AutoProcessor.from_pretrained("pix2struct-base-table2html") |
|
model = Pix2StructForConditionalGeneration.from_pretrained("pix2struct-base-table2html") |
|
model.to(device) |
|
model.eval() |
|
|
|
# Load example image from URL |
|
url = "https://example.com/path_to_table_image.jpg" |
|
response = requests.get(url) |
|
image = Image.open(BytesIO(response.content)) |
|
|
|
# Run model inference |
|
encoding = processor(image, return_tensors="pt", max_patches=1024) |
|
with torch.inference_mode(): |
|
flattened_patches = encoding.pop("flattened_patches").to(device) |
|
attention_mask = encoding.pop("attention_mask").to(device) |
|
predictions = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask, max_new_tokens=1024) |
|
|
|
predictions_decoded = processor.tokenizer.batch_decode(predictions, skip_special_tokens=True) |
|
|
|
# Show predictions as text |
|
print(predictions_decoded[0]) |
|
``` |
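
The decoded prediction is an HTML `<table>` string. As a follow-up, one way to work with it programmatically is to parse it into a pandas DataFrame (a sketch, assuming the generated markup is well-formed and that `lxml` is installed for `pandas.read_html`):

```python
from io import StringIO

import pandas as pd

# Parse the generated <table> markup into a list of DataFrames
tables = pd.read_html(StringIO(predictions_decoded[0]))
df = tables[0]
print(df)
```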
|
|