Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ This example load LayoutLMv3 model and confirm its accuracy and speed based on [
```shell
pip install neural-compressor
pip install -r requirements.txt
bash install_layoutlmft.sh
```
> Note: Validated ONNX Runtime [Version](/docs/source/installation_guide.md#validated-software-environment).

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from dataclasses import dataclass, field
from typing import Optional, Union

import torch

from detectron2.structures import ImageList
from detectron2.data.detection_utils import read_image
from detectron2.data.transforms import ResizeTransform, TransformList

from transformers import PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy


@dataclass
class DataCollatorForKeyValueExtraction:
"""
Data collator that will dynamically pad the inputs received, as well as the labels.

Args:
tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
The tokenizer used for encoding the data.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:

* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.

This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (:obj:`int`, `optional`, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignore by PyTorch loss functions).
"""

tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100

def __call__(self, features):
label_name = "label" if "label" in features[0].keys() else "labels"
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

has_image_input = "image" in features[0]
has_bbox_input = "bbox" in features[0]
if has_image_input:
image = ImageList.from_tensors([torch.tensor(feature["image"]) for feature in features], 32)
for feature in features:
del feature["image"]
batch = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
return_tensors="pt" if labels is None else None,
)

if labels is None:
return batch

sequence_length = torch.tensor(batch["input_ids"]).shape[1]
padding_side = self.tokenizer.padding_side
if padding_side == "right":
batch["labels"] = [label + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels]
if has_bbox_input:
batch["bbox"] = [bbox + [[0, 0, 0, 0]] * (sequence_length - len(bbox)) for bbox in batch["bbox"]]
else:
batch["labels"] = [[self.label_pad_token_id] * (sequence_length - len(label)) + label for label in labels]
if has_bbox_input:
batch["bbox"] = [[[0, 0, 0, 0]] * (sequence_length - len(bbox)) + bbox for bbox in batch["bbox"]]

batch = {k: torch.tensor(v, dtype=torch.int64) if isinstance(v[0], list) else v for k, v in batch.items()}
if has_image_input:
batch["image"] = image
return batch

@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""

task_name: Optional[str] = field(default="ner", metadata={"help": "The name of the task (ner, pos...)."})
dataset_name: Optional[str] = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dataset_config_name: Optional[str] = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
)
train_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a csv or JSON file)."}
)
validation_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input evaluation data file to evaluate on (a csv or JSON file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to predict on (a csv or JSON file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},
)
pad_to_max_length: bool = field(
default=True,
metadata={
"help": "Whether to pad all samples to model maximum sentence length. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
"efficient on GPU but very bad for TPU."
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
},
)
max_val_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)
label_all_tokens: bool = field(
default=False,
metadata={
"help": "Whether to put the label for one word on all tokens of generated by that word or just on the "
"one (in which case the other tokens will have a padding index)."
},
)
return_entity_level_metrics: bool = field(
default=False,
metadata={"help": "Whether to return all the entity levels during evaluation or just the overall ones."},
)


@dataclass
class XFUNDataTrainingArguments(DataTrainingArguments):
lang: Optional[str] = field(default="en")
additional_langs: Optional[str] = field(default=None)

def normalize_bbox(bbox, size):
return [
int(1000 * bbox[0] / size[0]),
int(1000 * bbox[1] / size[1]),
int(1000 * bbox[2] / size[0]),
int(1000 * bbox[3] / size[1]),
]

def load_image(image_path):
image = read_image(image_path, format="BGR")
h = image.shape[0]
w = image.shape[1]
img_trans = TransformList([ResizeTransform(h=h, w=w, new_h=224, new_w=224)])
image = torch.tensor(img_trans.apply_image(image).copy()).permute(2, 0, 1) # copy to make it writeable
return image, (w, h)
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# coding=utf-8

import json
import os

import datasets

from data_utils import load_image, normalize_bbox


logger = datasets.logging.get_logger(__name__)


_CITATION = """\
@article{Jaume2019FUNSDAD,
title={FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents},
author={Guillaume Jaume and H. K. Ekenel and J. Thiran},
journal={2019 International Conference on Document Analysis and Recognition Workshops (ICDARW)},
year={2019},
volume={2},
pages={1-6}
}
"""

_DESCRIPTION = """\
https://guillaumejaume.github.io/FUNSD/
"""


class FunsdConfig(datasets.BuilderConfig):
"""BuilderConfig for FUNSD"""

def __init__(self, **kwargs):
"""BuilderConfig for FUNSD.

Args:
**kwargs: keyword arguments forwarded to super.
"""
super(FunsdConfig, self).__init__(**kwargs)


class Funsd(datasets.GeneratorBasedBuilder):
"""Conll2003 dataset."""

BUILDER_CONFIGS = [
FunsdConfig(name="funsd", version=datasets.Version("1.0.0"), description="FUNSD dataset"),
]

def _info(self):
return datasets.DatasetInfo(
description=_DESCRIPTION,
features=datasets.Features(
{
"id": datasets.Value("string"),
"tokens": datasets.Sequence(datasets.Value("string")),
"bboxes": datasets.Sequence(datasets.Sequence(datasets.Value("int64"))),
"ner_tags": datasets.Sequence(
datasets.features.ClassLabel(
names=["O", "B-HEADER", "I-HEADER", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
)
),
"image": datasets.Array3D(shape=(3, 224, 224), dtype="uint8"),
}
),
supervised_keys=None,
homepage="https://guillaumejaume.github.io/FUNSD/",
citation=_CITATION,
)

def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
downloaded_file = dl_manager.download_and_extract("https://guillaumejaume.github.io/FUNSD/dataset.zip")
return [
datasets.SplitGenerator(
name=datasets.Split.TRAIN, gen_kwargs={"filepath": f"{downloaded_file}/dataset/training_data/"}
),
datasets.SplitGenerator(
name=datasets.Split.TEST, gen_kwargs={"filepath": f"{downloaded_file}/dataset/testing_data/"}
),
]

def _generate_examples(self, filepath):
logger.info("⏳ Generating examples from = %s", filepath)
ann_dir = os.path.join(filepath, "annotations")
img_dir = os.path.join(filepath, "images")
for guid, file in enumerate(sorted(os.listdir(ann_dir))):
tokens = []
bboxes = []
ner_tags = []

file_path = os.path.join(ann_dir, file)
with open(file_path, "r", encoding="utf8") as f:
data = json.load(f)
image_path = os.path.join(img_dir, file)
image_path = image_path.replace("json", "png")
image, size = load_image(image_path)
for item in data["form"]:
words, label = item["words"], item["label"]
words = [w for w in words if w["text"].strip() != ""]
if len(words) == 0:
continue
if label == "other":
for w in words:
tokens.append(w["text"])
ner_tags.append("O")
bboxes.append(normalize_bbox(w["box"], size))
else:
tokens.append(words[0]["text"])
ner_tags.append("B-" + label.upper())
bboxes.append(normalize_bbox(words[0]["box"], size))
for w in words[1:]:
tokens.append(w["text"])
ner_tags.append("I-" + label.upper())
bboxes.append(normalize_bbox(w["box"], size))

yield guid, {"id": str(guid), "tokens": tokens, "bboxes": bboxes, "ner_tags": ner_tags, "image": image}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import numpy as np
from datasets import ClassLabel, load_dataset, load_metric

import layoutlmft.data.datasets.funsd
import funsd
import transformers
from layoutlmft.data import DataCollatorForKeyValueExtraction
from layoutlmft.data.data_args import DataTrainingArguments
from layoutlmft.trainers import FunsdTrainer as Trainer
from data_utils import DataCollatorForKeyValueExtraction, DataTrainingArguments
from trainer import FunsdTrainer as Trainer
from transformers import (
AutoConfig,
AutoModelForTokenClassification,
Expand Down Expand Up @@ -184,7 +183,7 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)

datasets = load_dataset(os.path.abspath(layoutlmft.data.datasets.funsd.__file__))
datasets = load_dataset(os.path.abspath(funsd.__file__))
if training_args.do_train:
column_names = datasets["train"].column_names
features = datasets["train"].features
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
datasets==1.6.2
transformers==4.6
huggingface-hub==0.0.8
seqeval==1.2.2
tensorboard==2.7.0
accelerate
datasets
transformers
huggingface-hub
seqeval
tensorboard
sentencepiece
timm==0.4.12
timm
Pillow
einops
textdistance
Expand All @@ -17,4 +18,4 @@ onnx
onnxruntime
onnxruntime-extensions; python_version < '3.10'
-f https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.10/index.html
detectron2
detectron2
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Any, Dict, Union

import torch

from transformers import Trainer


class FunsdTrainer(Trainer):
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
"""
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
handling potential state.
"""
for k, v in inputs.items():
if hasattr(v, "to") and hasattr(v, "device"):
inputs[k] = v.to(self.args.device)

if self.args.past_index >= 0 and self._past is not None:
inputs["mems"] = self._past

return inputs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ This example load LayoutLMv3 model and confirm its accuracy and speed based on [
```shell
pip install neural-compressor
pip install -r requirements.txt
bash install_layoutlmft.sh
```
> Note: Validated ONNX Runtime [Version](/docs/source/installation_guide.md#validated-software-environment).

Expand Down
Loading