Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions codegen/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load("@tflm_pip_deps//:requirements.bzl", "requirement")

package(
default_visibility = ["//:__subpackages__"],
licenses = ["notice"],
)

py_library(
name = "inference_generator",
srcs = [
"inference_generator.py",
],
data = [
"templates/inference.cc.mako",
"templates/inference.h.mako",
],
deps = [
requirement("mako"),
],
)

py_binary(
name = "code_generator",
srcs = [
"code_generator.py",
],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":inference_generator",
"@absl_py//absl:app",
"@absl_py//absl/flags",
],
)
3 changes: 3 additions & 0 deletions codegen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TFLM Code Generator

This is a work in progress experiment. It is not ready for use.
58 changes: 58 additions & 0 deletions codegen/code_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Generates C/C++ source code capable of performing inference for a model. """

import os

from absl import app
from absl import flags
from collections.abc import Sequence

from tflite_micro.codegen import inference_generator

# Usage information:
# Default:
# `bazel run codegen:code_generator -- --model=</path/to/my_model.tflite>`
# Output will be located at: /path/to/my_model.h|cc

_MODEL_PATH = flags.DEFINE_string(name="model",
default=None,
help="Path to the TFLite model file.",
required=True)

_OUTPUT_DIR = flags.DEFINE_string(
name="output_dir",
default=None,
help="Path to write generated source to. Leave blank to use 'model' path.",
required=False)

_OUTPUT_NAME = flags.DEFINE_string(
name="output_name",
default=None,
help=("The output basename for the generated .h/.cc. Leave blank to use "
"'model' basename."),
required=False)


def main(argv: Sequence[str]) -> None:
output_dir = _OUTPUT_DIR.value or os.path.dirname(_MODEL_PATH.value)
output_name = _OUTPUT_NAME.value or os.path.splitext(
os.path.basename(_MODEL_PATH.value))[0]

inference_generator.generate(output_dir, output_name)


if __name__ == "__main__":
app.run(main)
61 changes: 61 additions & 0 deletions codegen/inference_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
""" Generates C/C++ inference source code. """

import pathlib

from mako import template
from typing import TypedDict

_TEMPLATE_DIR = pathlib.Path(__file__).parent / 'templates'
_HEADER_TEMPLATE = _TEMPLATE_DIR / 'inference.h.mako'
_SOURCE_TEMPLATE = _TEMPLATE_DIR / 'inference.cc.mako'


class ModelData(TypedDict):
header_file: str
model_name: str


def _render(output_file: pathlib.Path, template_file: pathlib.Path,
model_data: ModelData) -> None:
print("Generating {}".format(output_file))
t = template.Template(filename=str(template_file))
with output_file.open('w+') as file:
file.write(t.render(**model_data))


def _generate_header(header_path: pathlib.Path, model_data: ModelData) -> None:
_render(header_path, _HEADER_TEMPLATE, model_data)


def _generate_source(source_path: pathlib.Path, model_data: ModelData) -> None:
_render(source_path, _SOURCE_TEMPLATE, model_data)


def generate(output_dir: str, output_name: str) -> None:
""" Generate C/C++ inference code. """
header_file = f"{output_name}.h"
model_data: ModelData = {
'header_file': header_file,
'model_name': output_name
}

# Ensure output directory exists
output_path = pathlib.Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)

_generate_header(output_path / header_file, model_data)
_generate_source(output_path / f"{output_name}.cc", model_data)
24 changes: 24 additions & 0 deletions codegen/templates/inference.cc.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/* AUTOMATICALLY GENERATED DO NOT MODIFY */

#include "${header_file}"

namespace ${model_name} {

void Invoke() {}

} // ${model_name}
24 changes: 24 additions & 0 deletions codegen/templates/inference.h.mako
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/* AUTOMATICALLY GENERATED DO NOT MODIFY */

#pragma once

namespace ${model_name} {

void Invoke();

} // ${model_name}