load("//tensorflow:tensorflow.bzl", "if_google", "if_oss", "tf_cc_test")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")

# Unless a target overrides `visibility`, everything declared in this package is
# visible only to the packages listed in the `friends` package group.
package(
    default_visibility = [
        ":friends",
    ],
    licenses = ["notice"],
)

# Packages allowed to depend on targets here. This group is the package's
# default visibility.
# NOTE(review): the `copybara:uncomment` lines are presumably re-enabled by
# copybara in the internal export. Keep their exact text so the rewrite still
# matches.
package_group(
    name = "friends",
    packages = [
        # copybara:uncomment "//learning/brain/experimental/tfrt/...",
        # copybara:uncomment "//learning/brain/tfrt/...",
        "//tensorflow/c/eager/...",
        "//tensorflow/compiler/mlir/tfrt/...",
        "//tensorflow/core/runtime_fallback/...",
        "//tensorflow/core/tfrt/...",
        "//tensorflow/python/...",
        # copybara:uncomment "//tensorflow_serving/batching/google/...",
        # copybara:uncomment "//tensorflow_serving/servables/tensorflow/google/...",
        # copybara:uncomment "//third_party/tf_runtime_google/...",
    ],
)

# Helpers declared in utils.h and implemented in utils.cc. The textual include
# `dtype.def` is exported alongside the header.
cc_library(
    name = "utils",
    srcs = ["utils.cc"],
    hdrs = [
        "dtype.def",
        "utils.h",
    ],
    deps = [
        "//tensorflow/core:framework",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/framework:types_proto_cc",
        "//tensorflow/core/lib/gtl:array_slice",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/tfrt/eager:virtual_device",
        "//tensorflow/core/tpu:virtual_device",
        "@tf_runtime//:befexecutor",
        "@tf_runtime//:dtype",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:support",
    ],
)

# Unit tests for :utils.
tf_cc_test(
    name = "utils_test",
    srcs = [
        "utils_test.cc",
    ],
    deps = [
        ":utils",
        "//tensorflow/core:framework",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core_no_xla",
        "@com_google_googletest//:gtest_main",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//cpp_tests:common",
    ],
)

# Library declared in tensor_util.h and implemented in tensor_util.cc.
#
# On Android, TensorFlow core comes from the portable lite library. All other
# configurations depend on the individual framework/platform targets instead
# (see the select below).
#
# NOTE(review): `dtype.def` is also listed in the hdrs of :utils. Exporting one
# header from two targets can confuse layering checks; consider giving it a
# single owning target.
cc_library(
    name = "tensor_util",
    srcs = ["tensor_util.cc"],
    hdrs = [
        "dtype.def",
        "tensor_util.h",
    ],
    # Kept in buildifier order: local labels, then //-labels, then @-labels.
    deps = [
        ":statusor",
        "//tensorflow/core/framework:tensor_shape",
        "//tensorflow/core/framework:types_proto_cc",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:tstring",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor",
        "//tensorflow/core/runtime_fallback/util:tensor_util",
        "//tensorflow/core/runtime_fallback/util:type_util",
        "//tensorflow/stream_executor/lib",
        "//third_party/eigen3",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/strings",
        "@tf_runtime//:core_runtime",
        "@tf_runtime//:dtype",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:support",
        "@tf_runtime//:tensor",
    ] + select({
        "//tensorflow:android": [
            "//tensorflow/core:portable_tensorflow_lib_lite",  # TODO(annarev): exclude runtime srcs
        ],
        "//conditions:default": [
            "//tensorflow/core/framework:tensor",
            "//tensorflow/core/platform:status",
        ],
    }),
)

# Unit tests for :tensor_util.
# Tagged `no_oss`; presumably skipped in open-source CI.
tf_cc_test(
    name = "tensor_util_test",
    srcs = ["tensor_util_test.cc"],
    tags = [
        "no_oss",
    ],
    deps = [
        ":tensor_util",
        "//tensorflow/core:framework",
        "//tensorflow/core/runtime_fallback/kernel:kernel_fallback_tensor",
        "@com_google_googletest//:gtest_main",
        "@llvm-project//llvm:Support",
        "@tf_runtime//:core_runtime",
        "@tf_runtime//:dtype",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:support",
        "@tf_runtime//:tensor",
        "@tf_runtime//cpp_tests:common",
    ],
)

# Library declared in error_util.h and implemented in error_util.cc. The
# textual include `error_type.def` is exported alongside the header.
cc_library(
    name = "error_util",
    srcs = ["error_util.cc"],
    hdrs = [
        "error_type.def",
        "error_util.h",
    ],
    deps = [
        "//tensorflow/core/platform:status",
        "@tf_runtime//:hostcontext",
        "@tf_runtime//:support",
    ],
)

# Unit tests for :error_util.
# Tagged `no_oss`; presumably skipped in open-source CI.
tf_cc_test(
    name = "error_util_test",
    srcs = [
        "error_util_test.cc",
    ],
    tags = [
        "no_oss",
    ],
    deps = [
        ":error_util",
        "//tensorflow/core/platform:status",
        "@com_google_googletest//:gtest_main",
        "@tf_runtime//:support",
        "@tf_runtime//cpp_tests:common",
    ],
)

# Header-only target exposing statusor.h.
cc_library(
    name = "statusor",
    hdrs = [
        "statusor.h",
    ],
    deps = [
        "//tensorflow/stream_executor/lib",
    ],
)

# Library declared in tfrt_graph_execution_state.h. It depends on the MLIR
# graph-upgrade utilities, grappler utils and the TFRT fallback state.
cc_library(
    name = "tfrt_graph_execution_state",
    srcs = [
        "tfrt_graph_execution_state.cc",
    ],
    hdrs = [
        "tfrt_graph_execution_state.h",
    ],
    deps = [
        ":statusor",
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/compiler/mlir/tensorflow:upgrade_graph",
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core:framework",
        "//tensorflow/core/common_runtime:core_cpu_internal",
        "//tensorflow/core/framework:graph_proto_cc",
        "//tensorflow/core/framework:op_def_proto_cc",
        "//tensorflow/core/framework:versions_proto_cc",
        "//tensorflow/core/grappler:utils",
        "//tensorflow/core/platform:errors",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/protobuf:for_core_protos_cc",
        "//tensorflow/core/tfrt/fallback:fallback_state",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/time",
    ],
)

# Unit tests for :tfrt_graph_execution_state.
tf_cc_test(
    name = "tfrt_graph_execution_state_test",
    srcs = ["tfrt_graph_execution_state_test.cc"],
    # b/169705709: protobuf matchers are not available in OSS, so the
    # manual/no_oss tags are applied only through if_oss().
    tags = if_oss(["manual", "no_oss"]),
    deps = [
        ":tfrt_graph_execution_state",
        "//tensorflow/cc:array_ops",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:const_op",
        "//tensorflow/cc:while_loop",
        "//tensorflow/core:test",
        "//tensorflow/core/grappler/utils:grappler_test",
        "@com_google_googletest//:gtest_main",
    ],
)

# Header-only target exposing model_metadata.h.
cc_library(
    name = "model_metadata",
    hdrs = [
        "model_metadata.h",
    ],
    deps = [
        "@com_google_absl//absl/base:core_headers",
    ],
)

# Library declared in fallback_tensor.h and implemented in fallback_tensor.cc.
cc_library(
    name = "fallback_tensor",
    srcs = [
        "fallback_tensor.cc",
    ],
    hdrs = [
        "fallback_tensor.h",
    ],
    deps = [
        "//tensorflow/core/common_runtime:dma_helper",
        "//tensorflow/core/framework:tensor",
        "@com_google_absl//absl/types:variant",
    ],
)

# Unit tests for :fallback_tensor.
# Tagged `no_oss`; presumably skipped in open-source CI.
tf_cc_test(
    name = "fallback_tensor_test",
    srcs = [
        "fallback_tensor_test.cc",
    ],
    tags = [
        "no_oss",
    ],
    deps = [
        ":fallback_tensor",
        "@com_google_googletest//:gtest_main",
    ],
)

# Test-only helpers from test_util.h. Unlike the package default, this target
# is visible to every package.
# NOTE(review): the target has no srcs, so `alwayslink` likely has nothing to
# force-link. Confirm before relying on it.
cc_library(
    name = "test_util",
    testonly = True,
    hdrs = ["test_util.h"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core/platform:env",
        "//tensorflow/core/platform:threadpool_interface",
    ],
    alwayslink = True,
)

# Header-only library (bridge_graph_analysis.h), visible only to
# //tensorflow/core/tfrt/saved_model. Google-internal builds also depend on the
# MLIR bridge graph analysis through if_google(). The unconditional deps are
# kept in buildifier order.
cc_library(
    name = "bridge_graph_analysis",
    hdrs = ["bridge_graph_analysis.h"],
    visibility = [
        "//tensorflow/core/tfrt/saved_model:__pkg__",
    ],
    deps = if_google([
        "//learning/brain/mlir/bridge:graph_analysis",
        "//tensorflow/core/platform:enable_tf2_utils",
    ]) + [
        "//tensorflow/core:core_cpu_base",
        "//tensorflow/core/platform:status",
        "//tensorflow/core/protobuf:for_core_protos_cc",
    ],
)
