diff --git a/tests/test_dataproc_retry.py b/tests/test_dataproc_retry.py new file mode 100644 index 00000000..117af9c0 --- /dev/null +++ b/tests/test_dataproc_retry.py @@ -0,0 +1,189 @@ +import time +from unittest.mock import Mock, patch + +import grpc +import pytest + +import yandexcloud._operation_waiter as waiter_module +from yandex.cloud.dataproc.v1.cluster_service_pb2_grpc import ClusterServiceStub +from yandex.cloud.operation.operation_pb2 import Operation +from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub +from yandexcloud._backoff import backoff_exponential_jittered_min_interval +from yandexcloud._wrappers.dataproc import ( + Dataproc, + DataprocRetryInterceptor, + InterceptorSettings, +) + + +class MockOperationService: + def __init__(self, fail_attempts=0, fail_code=grpc.StatusCode.CANCELLED): + self.fail_attempts = fail_attempts + self.fail_code = fail_code + self.call_count = 0 + + def Get(self, request): + self.call_count += 1 + if self.call_count <= self.fail_attempts: + error = grpc.RpcError() + error._state = Mock() + error._state.code = self.fail_code + error.code = lambda: self.fail_code + raise error + return Operation(id="test-op", done=True) + + +class MockClusterService: + def Create(self, request): + return Operation(id="test-cluster-op", done=False) + + +class MockSDK: + def __init__(self): + self.client_calls = [] + + def client(self, service_class, interceptor=None): + self.client_calls.append((service_class, interceptor)) + if service_class == ClusterServiceStub: + return MockClusterService() + elif service_class == OperationServiceStub: + return MockOperationService() + + def create_operation_and_get_result(self, request, service, method_name, response_type, meta_type): + operation = Operation(id="test-op", done=False) + waiter = waiter_module.operation_waiter(self, operation.id, None) + for _ in waiter: + time.sleep(0.01) + return MockOperationResult() + + +class MockOperationResult: + def 
__init__(self): + self.response = Mock() + self.response.id = "test-cluster-id" + + +@pytest.fixture +def mock_sdk(): + return MockSDK() + + +def test_dataproc_custom_interceptor_max_attempts(mock_sdk): + dataproc = Dataproc( + sdk=mock_sdk, + interceptor_settings=InterceptorSettings( + max_retry_count=50, + retriable_codes=( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.INTERNAL, + grpc.StatusCode.CANCELLED, + ), + back_off_func=backoff_exponential_jittered_min_interval(), + ), + ) + + mock_operation_service = MockOperationService(fail_attempts=51, fail_code=grpc.StatusCode.CANCELLED) + + with patch.object(waiter_module, "operation_waiter") as mock_waiter_fn: + mock_waiter = Mock() + mock_waiter.__iter__ = Mock(return_value=iter([])) + mock_waiter.operation = Operation(id="test", done=True) + mock_waiter_fn.return_value = mock_waiter + + with patch.object(mock_sdk, "client") as mock_client: + mock_client.return_value = mock_operation_service + + with pytest.raises(grpc.RpcError) as exc_info: + dataproc.create_cluster( + folder_id="test-folder", + cluster_name="test-cluster", + subnet_id="test-subnet", + service_account_id="test-sa", + ssh_public_keys="test-ssh-key", + ) + + assert exc_info.value.code() == grpc.StatusCode.CANCELLED + assert mock_operation_service.call_count <= 50 + + +def test_dataproc_interceptor_inheritance(): + interceptor = DataprocRetryInterceptor( + max_retry_count=10, retriable_codes=(grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE) + ) + + assert interceptor._RetryInterceptor__is_retriable(grpc.StatusCode.CANCELLED) is True + + assert interceptor._RetryInterceptor__is_retriable(grpc.StatusCode.UNAVAILABLE) is True + + assert interceptor._RetryInterceptor__is_retriable(grpc.StatusCode.PERMISSION_DENIED) is False + + +def test_dataproc_monkey_patch_restoration(): + mock_sdk = Mock() + original_waiter = waiter_module.operation_waiter + + dataproc = Dataproc( + sdk=mock_sdk, + 
interceptor_settings=InterceptorSettings( + max_retry_count=50, + retriable_codes=( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.INTERNAL, + grpc.StatusCode.CANCELLED, + ), + back_off_func=backoff_exponential_jittered_min_interval(), + ), + ) + + with patch.object(mock_sdk, "create_operation_and_get_result") as mock_create: + mock_create.return_value = MockOperationResult() + + result = dataproc.delete_cluster(cluster_id="test-cluster-id") + + assert result is not None + + assert waiter_module.operation_waiter == original_waiter + + +def test_dataproc_all_methods_use_wrapper(mock_sdk): + dataproc = Dataproc( + sdk=mock_sdk, + interceptor_settings=InterceptorSettings( + max_retry_count=50, + retriable_codes=( + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.RESOURCE_EXHAUSTED, + grpc.StatusCode.INTERNAL, + grpc.StatusCode.CANCELLED, + ), + back_off_func=backoff_exponential_jittered_min_interval(), + ), + ) + + methods_to_test = [ + ( + "create_cluster", + { + "folder_id": "test", + "cluster_name": "test", + "subnet_id": "test", + "service_account_id": "test", + "ssh_public_keys": "test-ssh-key", + }, + ), + ("delete_cluster", {"cluster_id": "test"}), + ("stop_cluster", {"cluster_id": "test"}), + ("start_cluster", {"cluster_id": "test"}), + ] + + with patch.object(dataproc, "_with_dataproc_waiter") as mock_wrapper: + mock_wrapper.return_value = MockOperationResult() + + for method_name, kwargs in methods_to_test: + method = getattr(dataproc, method_name) + method(**kwargs) + + assert mock_wrapper.called + mock_wrapper.reset_mock() diff --git a/yandexcloud/_wrappers/dataproc/__init__.py b/yandexcloud/_wrappers/dataproc/__init__.py index 6b39709e..6bcfd4cc 100644 --- a/yandexcloud/_wrappers/dataproc/__init__.py +++ b/yandexcloud/_wrappers/dataproc/__init__.py @@ -2,8 +2,9 @@ # mypy: ignore-errors import logging import random -from typing import Iterable, NamedTuple +from typing import Callable, Iterable, NamedTuple +import grpc 
from google.protobuf.field_mask_pb2 import FieldMask import yandex.cloud.dataproc.v1.cluster_pb2 as cluster_pb @@ -16,6 +17,9 @@ import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb import yandex.cloud.dataproc.v1.subcluster_service_pb2 as subcluster_service_pb import yandex.cloud.dataproc.v1.subcluster_service_pb2_grpc as subcluster_service_grpc_pb +import yandexcloud._operation_waiter as waiter_module +from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub +from yandexcloud._retry_interceptor import RetryInterceptor class InitializationAction(NamedTuple): @@ -31,6 +35,34 @@ def to_grpc(self): ) +class InterceptorSettings(NamedTuple): + max_retry_count: int # Maximum number of retries + retriable_codes: Iterable[grpc.StatusCode] # Retriable error codes + back_off_func: Callable[[int], float] # Backoff function + + +class DataprocRetryInterceptor(RetryInterceptor): + # pylint: disable-next=invalid-name + def _RetryInterceptor__is_retriable(self, error: grpc.StatusCode) -> bool: + if error in self._RetryInterceptor__retriable_codes: + return True + + return False + + +def create_custom_operation_waiter(interceptor_settings: InterceptorSettings): + def custom_operation_waiter(sdk, operation_id, timeout): + retry_interceptor = DataprocRetryInterceptor( + max_retry_count=interceptor_settings.max_retry_count, + retriable_codes=interceptor_settings.retriable_codes, + back_off_func=interceptor_settings.back_off_func, + ) + operation_service = sdk.client(OperationServiceStub, interceptor=retry_interceptor) + return waiter_module.OperationWaiter(operation_id, operation_service, timeout) + + return custom_operation_waiter + + class Dataproc: """ A base hook for Yandex.Cloud Data Proc. @@ -43,9 +75,18 @@ class Dataproc: :type logger: Optional[logging.Logger] :param sdk: SDK object. 
Normally is being set by Wrappers constructor :type sdk: yandexcloud.SDK + :param interceptor_settings: Settings For Custom Dataproc Interceptor + :type interceptor_settings: Optional[InterceptorSettings] """ - def __init__(self, default_folder_id=None, default_public_ssh_key=None, logger=None, sdk=None): + def __init__( + self, + default_folder_id=None, + default_public_ssh_key=None, + logger=None, + sdk=None, + interceptor_settings=None, + ): self.sdk = sdk or self.sdk self.log = logger if not self.log: @@ -55,6 +96,21 @@ def __init__(self, default_folder_id=None, default_public_ssh_key=None, logger=N self.subnet_id = None self.default_folder_id = default_folder_id self.default_public_ssh_key = default_public_ssh_key + self._custom_operation_waiter = None + if interceptor_settings: + self._custom_operation_waiter = create_custom_operation_waiter(interceptor_settings) + + def _with_dataproc_waiter(self, func, *args, **kwargs): + if not self._custom_operation_waiter: + return func(*args, **kwargs) + + original_waiter = waiter_module.operation_waiter + waiter_module.operation_waiter = self._custom_operation_waiter + + try: + return func(*args, **kwargs) + finally: + waiter_module.operation_waiter = original_waiter def create_cluster( self, @@ -313,7 +369,8 @@ def create_cluster( log_group_id=log_group_id, labels=labels, ) - result = self.sdk.create_operation_and_get_result( + result = self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=cluster_service_grpc_pb.ClusterServiceStub, method_name="Create", @@ -427,7 +484,8 @@ def create_subcluster( hosts_count=hosts_count, autoscaling_config=autoscaling_config, ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=subcluster_service_grpc_pb.SubclusterServiceStub, method_name="Create", @@ -455,7 +513,8 @@ def update_cluster_description(self, description, cluster_id=None): update_mask=mask, 
description=description, ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=cluster_service_grpc_pb.ClusterServiceStub, method_name="Update", @@ -475,7 +534,8 @@ def delete_cluster(self, cluster_id=None): self.log.info("Deleting cluster %s", cluster_id) request = cluster_service_pb.DeleteClusterRequest(cluster_id=cluster_id) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=cluster_service_grpc_pb.ClusterServiceStub, method_name="Delete", @@ -497,7 +557,8 @@ def stop_cluster(self, cluster_id=None, decommission_timeout=0): request = cluster_service_pb.StopClusterRequest( cluster_id=cluster_id, decommission_timeout=decommission_timeout ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=cluster_service_grpc_pb.ClusterServiceStub, method_name="Stop", @@ -514,7 +575,8 @@ def start_cluster(self, cluster_id=None): if not cluster_id: raise RuntimeError("Cluster id must be specified.") request = cluster_service_pb.StartClusterRequest(cluster_id=cluster_id) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=cluster_service_grpc_pb.ClusterServiceStub, method_name="Start", @@ -575,7 +637,8 @@ def create_hive_job( name=name, hive_job=hive_job, ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=job_service_grpc_pb.JobServiceStub, method_name="Create", @@ -637,7 +700,8 @@ def create_mapreduce_job( properties=properties, ), ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, 
service=job_service_grpc_pb.JobServiceStub, method_name="Create", @@ -712,7 +776,8 @@ def create_spark_job( exclude_packages=exclude_packages, ), ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=job_service_grpc_pb.JobServiceStub, method_name="Create", @@ -786,7 +851,8 @@ def create_pyspark_job( exclude_packages=exclude_packages, ), ) - return self.sdk.create_operation_and_get_result( + return self._with_dataproc_waiter( + self.sdk.create_operation_and_get_result, request, service=job_service_grpc_pb.JobServiceStub, method_name="Create",