load("//tensorflow:tensorflow.bzl", "if_google", "tf_cc_binary")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")

# Google-internal packages that get the same access as runtime_fallback itself.
# The list is wrapped in if_google(), and no open-source alternative is passed,
# so outside Google builds it presumably contributes nothing.
_RUNTIME_FALLBACK_GOOGLE_PACKAGES = if_google([
    "//learning/brain/experimental/mlir/tflite/tfmrt/...",
    "//learning/brain/experimental/tfrt/...",
    "//learning/brain/mobile/lite/...",
])

# Package group referenced as ":internal". It covers everything under
# //tensorflow/core/runtime_fallback plus the Google-only packages above.
package_group(
    name = "internal",
    packages = [
        "//tensorflow/core/runtime_fallback/...",
    ] + _RUNTIME_FALLBACK_GOOGLE_PACKAGES,
)

# Package-wide defaults:
# - targets fall under the "notice" license;
# - targets without their own `visibility` are restricted to the ":internal"
#   package group.
package(
    licenses = ["notice"],
    default_visibility = [":internal"],
)

# TensorFlow core dependency for :tf_bef_executor, chosen per platform.
_TF_BEF_EXECUTOR_PLATFORM_DEPS = select({
    "//tensorflow:android": [
        "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
    ],
    "//conditions:default": [
        "//tensorflow/core:all_kernels",
    ],
})

# Runtime/kernel fallback GPU support, linked only in CUDA-enabled builds.
_TF_BEF_EXECUTOR_CUDA_FALLBACK_DEPS = if_cuda([
    "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_gpu_alwayslink",
    "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_gpu_alwayslink",
])

# Native TFRT GPU op handler and test ops, linked only when TFRT's
# gpu_enabled config setting matches.
_TF_BEF_EXECUTOR_TFRT_GPU_DEPS = select({
    "@tf_runtime//:gpu_enabled": [
        "@tf_runtime//backends/gpu:gpu_op_handler_alwayslink",
        "@tf_runtime//backends/gpu:gpu_test_ops_alwayslink",
    ],
    "//conditions:default": [],
})

# BEF executor binary used for fallback testing. It links the runtime fallback
# kernels together with only a selected subset of the native TFRT kernels.
# Some native TFRT ops/kernels (e.g. Eigen-based matmul) are expensive to build
# and are not needed for fallback testing, so they are left out.
tf_cc_binary(
    name = "tf_bef_executor",
    tags = ["no_oss"],
    testonly = True,
    # The list order below is kept exactly as-is. Entries marked
    # "copybara:uncomment" are presumably enabled only in the Google-internal
    # copy of this file.
    deps = [
        ":bef_executor_lib",
        "//tensorflow/core/platform:stream_executor",
        "@com_google_absl//absl/strings",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_kernels_alwayslink",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_op_handler",
        # copybara:uncomment "//tensorflow/core/runtime_fallback/test:forwarding_test_kernels",
        # copybara:uncomment "//tensorflow/core/runtime_fallback/test:tfrt_forwarding_kernels_alwayslink",
        "//tensorflow/core/runtime_fallback/conversion:conversion_alwayslink",
        "//tensorflow/compiler/mlir/tfrt:tf_cpurt_kernels_alwayslink",
        "@tf_runtime//:basic_kernels_alwayslink",
        "@tf_runtime//:core_runtime_alwayslink",
        "@tf_runtime//:hostcontext_alwayslink",
        "@tf_runtime//:tensor_alwayslink",
        "@tf_runtime//:test_kernels_alwayslink",
        "@tf_runtime//:data_alwayslink",
        # copybara:uncomment "@tf_runtime//backends/cpu:proto_alwayslink",
        # copybara:uncomment "@tf_runtime//backends/cpu:image_alwayslink",
        "@tf_runtime//backends/cpu:core_runtime_alwayslink",
        "@tf_runtime//backends/cpu:test_ops_alwayslink",
        "@tf_runtime//backends/cpu:cpurt_corert_kernels_alwayslink",
    ] + _TF_BEF_EXECUTOR_PLATFORM_DEPS + _TF_BEF_EXECUTOR_CUDA_FALLBACK_DEPS + _TF_BEF_EXECUTOR_TFRT_GPU_DEPS,
)

# Flag support for the BEF executor (bef_executor_flags.{h,cc}). It is built on
# absl flags and TFRT's bef_executor_driver, and is test-only. Besides the
# current package, it is exported to //third_party/tf_runtime_google.
cc_library(
    name = "bef_executor_flags",
    srcs = ["bef_executor_flags.cc"],
    hdrs = ["bef_executor_flags.h"],
    deps = [
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@tf_runtime//:bef_executor_driver",
    ],
    testonly = True,
    visibility = ["//third_party/tf_runtime_google:__pkg__"],
)

# Google-only dependencies of :bef_executor_lib: the runtime_fallback test
# kernels and tf_runtime_google's xprof tracing sink.
_BEF_EXECUTOR_LIB_GOOGLE_DEPS = if_google([
    "//tensorflow/core/runtime_fallback/test:test_tf_opkernels_alwayslink",
    "//tensorflow/core/runtime_fallback/test:test_tfrt_kernels_alwayslink",
    "//third_party/tf_runtime_google:xprof_tracing_sink_alwayslink",
])

# TensorFlow core/platform dependency for :bef_executor_lib, chosen per platform.
_BEF_EXECUTOR_LIB_PLATFORM_DEPS = select({
    "//tensorflow:android": [
        "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
    ],
    "//conditions:default": [
        "//tensorflow/core/platform:platform_port",
    ],
})

# Test-only library built from tf_bef_executor_main.cc, which presumably holds
# the executor's entry point. :tf_bef_executor links it in, and it is also
# visible to //tensorflow/core/tfrt/eager.
cc_library(
    name = "bef_executor_lib",
    testonly = True,
    srcs = ["tf_bef_executor_main.cc"],
    visibility = [
        ":internal",
        "//tensorflow/core/tfrt/eager:__subpackages__",
    ],
    tags = ["no_oss"],
    deps = _BEF_EXECUTOR_LIB_GOOGLE_DEPS + [
        ":bef_executor_flags",
        "@com_google_absl//absl/flags:flag",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",
        "//tensorflow/core/runtime_fallback/runtime:runtime_fallback_alwayslink",
        "//tensorflow/core/runtime_fallback/util:fallback_test_util",
        "@tf_runtime//:bef_executor_driver",
        "@tf_runtime//:hostcontext_alwayslink",
        "@tf_runtime//:io_alwayslink",
    ] + _BEF_EXECUTOR_LIB_PLATFORM_DEPS,
)

# Opt-style driver (built on MLIR's MlirOptLib). It links the TFRT fallback,
# fallback-async and fallback-sync op definitions together with the core TFRT
# dialects (init_tfrt_dialects).
tf_cc_binary(
    name = "tfrt_fallback_opt",
    srcs = ["tfrt_fallback_opt.cc"],
    deps = [
        ":tfrt_fallback_async_opdefs",
        ":tfrt_fallback_opdefs",
        ":tfrt_fallback_sync_opdefs",
        "@llvm-project//mlir:MlirOptLib",
        "@tf_runtime//:init_tfrt_dialects",
    ],
)

# Source of the tfrt_translate_main entry point for :tfrt_fallback_translate:
# - Google builds take it from //third_party/tf_runtime_llvm;
# - other builds take it from tf_runtime's third_party/llvm_derived.
_TFRT_FALLBACK_TRANSLATE_MAIN = if_google(
    ["//third_party/tf_runtime_llvm:tfrt_translate_main"],
    ["@tf_runtime//third_party/llvm_derived:tfrt_translate_main"],
)

# Translate tool for the TFRT fallback dialects. It combines MLIR's Translation
# library with the fallback registration, the tf_cpurt registration and TFRT's
# mlirtobef_translate.
tf_cc_binary(
    name = "tfrt_fallback_translate",
    srcs = ["tfrt_fallback_translate_registration.cc"],
    deps = [
        ":tfrt_fallback_async_opdefs",
        ":tfrt_fallback_opdefs",
        ":tfrt_fallback_registration",
        ":tfrt_fallback_sync_opdefs",
        "@llvm-project//mlir:Translation",
        "//tensorflow/compiler/mlir/tfrt:tf_cpurt_registration",
        "@tf_runtime//:init_tfrt_dialects",
        "@tf_runtime//:mlirtobef_translate",
    ] + _TFRT_FALLBACK_TRANSLATE_MAIN,
)

# Google-only packages allowed to depend on :tfrt_fallback_registration.
# NOTE(review): no open-source alternative is passed to if_google(). Outside
# Google this presumably yields `visibility = []`; confirm that the intended
# result is package-private.
_TFRT_FALLBACK_REGISTRATION_VISIBILITY = if_google([
    "//learning/brain/experimental/tfrt/distributed_runtime:__pkg__",
    "//learning/brain/experimental/tfrt/visualization:__pkg__",
    # Allow visibility from the mlir language server.
    "//learning/brain/mlir/mlir_lsp_server:__pkg__",
])

# Registration entry points for the TFRT fallback dialects
# (tfrt_fallback_registration.{h,cc}). It depends on the fallback,
# fallback-async and fallback-sync op definitions and on MLIR IR.
cc_library(
    name = "tfrt_fallback_registration",
    srcs = ["tfrt_fallback_registration.cc"],
    hdrs = ["tfrt_fallback_registration.h"],
    deps = [
        ":tfrt_fallback_async_opdefs",
        ":tfrt_fallback_opdefs",
        ":tfrt_fallback_sync_opdefs",
        "@llvm-project//mlir:IR",
    ],
    visibility = _TFRT_FALLBACK_REGISTRATION_VISIBILITY,
)

# Local name for the async fallback op definitions in
# //tensorflow/core/runtime_fallback/opdefs. Besides ":internal", it is
# exported to the TF MLIR TFRT compiler packages and to tfrt/saved_model.
alias(
    name = "tfrt_fallback_async_opdefs",
    visibility = [
        ":internal",
        "//tensorflow/compiler/mlir/tfrt:__subpackages__",
        "//tensorflow/core/tfrt/saved_model:__pkg__",
    ],
    actual = "//tensorflow/core/runtime_fallback/opdefs:tfrt_fallback_async_opdefs",
)

# Local name for the fallback op definitions in
# //tensorflow/core/runtime_fallback/opdefs. Besides ":internal", it is
# exported to:
# - the TF MLIR TFRT compiler packages;
# - tfrt/saved_model;
# - tf_runtime_google.
alias(
    name = "tfrt_fallback_opdefs",
    visibility = [
        ":internal",
        "//tensorflow/compiler/mlir/tfrt:__subpackages__",
        "//tensorflow/core/tfrt/saved_model:__pkg__",
        "//third_party/tf_runtime_google:__subpackages__",
    ],
    actual = "//tensorflow/core/runtime_fallback/opdefs:tfrt_fallback_opdefs",
)

# Local name for the sync fallback op definitions in
# //tensorflow/core/runtime_fallback/opdefs. No `visibility` is given, so the
# package's default_visibility applies.
alias(
    name = "tfrt_fallback_sync_opdefs",
    actual = "//tensorflow/core/runtime_fallback/opdefs:tfrt_fallback_sync_opdefs",
)
