Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Rafael Stahl authored and PhilippvK committed Jan 23, 2024
1 parent 33562e7 commit 1284f25
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 40 deletions.
49 changes: 25 additions & 24 deletions python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,9 @@
def add_compile_parser(subparsers, main_parser, json_params, argv):
"""Include parser for 'compile' subcommand"""

parser = subparsers.add_parser("compile", help="compile a model.", add_help=False)
parser = subparsers.add_parser("compile", help="compile a model.")
parser.set_defaults(func=drive_compile)

parser.add_argument(
"--experimental-tvm-extension",
default=[],
action="append",
help="path from which to load packages named tvm_extension which implement the "
"TVMExtension interface.",
)
disposable_parser = TVMCSuppressedArgumentParser(main_parser)
try:
known_args, _ = disposable_parser.parse_known_args(argv)
except TVMCException:
pass
try:
ext_dirs = known_args.experimental_tvm_extension
except AttributeError:
ext_dirs = []
_handle_extensions(ext_dirs)

parser.add_argument(
"--cross-compiler",
default="",
Expand Down Expand Up @@ -135,16 +117,13 @@ def add_compile_parser(subparsers, main_parser, json_params, argv):
"e.g. '--pass-config tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'.",
)

generate_target_args(parser)
parser.add_argument(
"--tuning-records",
metavar="PATH",
default="",
help="path to an auto-tuning log file by AutoTVM. If not presented, "
"the fallback/tophub configs will be used.",
)
generate_registry_args(parser, Executor, "graph")
generate_registry_args(parser, Runtime, "cpp")

parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity.")
# TODO (@leandron) This is a path to a physical file, but
Expand Down Expand Up @@ -198,13 +177,35 @@ def add_compile_parser(subparsers, main_parser, json_params, argv):
for one_entry in json_params:
parser.set_defaults(**one_entry)

parser.add_argument(
"--experimental-tvmc-extension",
default=[],
action="append",
help="path from which to load packages named tvmc_extension which implement the "
"TVMCExtension interface.",
)
disposable_parser = TVMCSuppressedArgumentParser(main_parser)
try:
known_args, _ = disposable_parser.parse_known_args(argv)
except TVMCException:
known_args = None
try:
ext_dirs = known_args.experimental_tvmc_extension
except AttributeError:
ext_dirs = []
_handle_extensions(ext_dirs)

generate_target_args(parser)
generate_registry_args(parser, Executor, "graph")
generate_registry_args(parser, Runtime, "cpp")

generate_workspace_pools_args(parser)


def _handle_extensions(extra_paths):
extension_paths = extra_paths
if os.environ.get("TVM_EXTENSION_DIR", None):
extension_paths.append(os.environ["TVM_EXTENSION_DIR"])
if os.environ.get("TVMC_EXTENSION_DIR", None):
extension_paths.append(os.environ["TVMC_EXTENSION_DIR"])

load_extensions(extension_paths)
for ext in get_extensions():
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/driver/tvmc/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
_EXTENSIONS = []


class TVMExtension(object):
class TVMCExtension(object):
@abstractmethod
def uma_backends(self):
return []
Expand All @@ -46,16 +46,16 @@ def load_extensions(paths):
"""
Loads extensions from the given locations.
Extensions must implement the `TVMExtension` interface and be stored in a directory called
`tvm_extension`.
Extensions must implement the `TVMCExtension` interface and be stored in a directory called
`tvmc_extension`.
"""

path_backup = copy.copy(sys.path)
sys.path.extend(paths)

top_modules = []
try:
mod = importlib.import_module("tvm_extension")
mod = importlib.import_module("tvmc_extension")
top_modules.append(mod)
except ImportError:
pass
Expand Down Expand Up @@ -125,4 +125,4 @@ def _scan_all(top_level):


def _is_concrete_extension_type(obj):
return inspect.isclass(obj) and issubclass(obj, TVMExtension) and not inspect.isabstract(obj)
return inspect.isclass(obj) and issubclass(obj, TVMCExtension) and not inspect.isabstract(obj)
2 changes: 1 addition & 1 deletion python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,4 +666,4 @@ def export_model_library_format(

_make_tar(tempdir.path, file_name, modules)

return file_name
return str(file_name)
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/contrib/uma/api/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def _get_tensors(te_cached_func):
)

compiler_attr = relay_prim_func.attrs["Compiler"]
target = tvm.target.Target(compiler_attr)
target = tvm.target.Target.current()
if target is None or target.kind.name != compiler_attr:
target = tvm.target.Target(compiler_attr)

tir_prim_func = tir_prim_func.with_attr("target", target)
tir_prim_func = tir_prim_func.with_attr("relay_attrs", relay_prim_func.attrs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import nonexistingmodule

nonexistingmodule.value = 1
35 changes: 28 additions & 7 deletions tests/python/contrib/test_uma/test_tvmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
pytest.importorskip("tensorflow")

import os
from unittest import mock
import sys
import tvm
from tensorflow import keras
from tvm.relay.backend.contrib.uma import uma_available
Expand All @@ -29,7 +29,16 @@
pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available")


def test_conv2d(tmpdir_factory):
def run_test(tmpdir_factory, ext_dir_name, check_relay=True):
if "tvmc_extension" in sys.modules:
del sys.modules["tvmc_extension"]
from tvm.driver.tvmc.extensions import _EXTENSIONS

_EXTENSIONS.clear()
from tvm.driver.tvmc.composite_target import REGISTERED_CODEGEN

REGISTERED_CODEGEN.clear()

tmpdir = tmpdir_factory.mktemp("data")
model_path = os.path.join(tmpdir, "model.h5")
package_path = os.path.join(tmpdir, "out.tar")
Expand All @@ -38,20 +47,32 @@ def test_conv2d(tmpdir_factory):
[
keras.layers.InputLayer(input_shape=[10, 10, 3], batch_size=1),
keras.layers.Conv2D(5, kernel_size=(3, 3)),
keras.layers.Activation("relu"),
]
)
model.save(model_path)

extension_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "vanilla_ext")
extension_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), ext_dir_name)
compile_str = (
f"tvmc compile --target vanilla_accelerator,c -f mlf "
f"--experimental-tvm-extension={extension_dir} "
f"--desired-layout NCHW "
f"--output={package_path} {model_path}"
f"--experimental-tvmc-extension {extension_dir} "
f"--desired-layout NCHW --dump-code relay "
f"--output {package_path} {model_path}"
)
compile_args = compile_str.split(" ")[1:]
assert _main(compile_args) == 0
if check_relay:
with open(package_path + ".relay") as f:
assert 'Compiler="vanilla_accelerator"' in f.read()


def test_conv2d(tmpdir_factory):
run_test(tmpdir_factory, "vanilla_ext")


def test_invalid_ext(tmpdir_factory):
with pytest.warns(UserWarning):
with pytest.raises(RuntimeError):
run_test(tmpdir_factory, "invalid_ext", check_relay=False)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
# specific language governing permissions and limitations
# under the License.

from tvm.driver.tvmc.extensions import TVMExtension
from tvm.driver.tvmc.extensions import TVMCExtension
from tests.python.contrib.test_uma.test_uma_vanilla_accelerator import VanillaAcceleratorBackend


class VanillaExtension(TVMExtension):
class VanillaExtension(TVMCExtension):
def __init__(self):
self.backend = VanillaAcceleratorBackend()

Expand Down

0 comments on commit 1284f25

Please sign in to comment.