这是indexloc提供的服务,不要输入任何密码
Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions tests/test_dataproc_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import time
from unittest.mock import Mock, patch

import grpc
import pytest

import yandexcloud._operation_waiter as waiter_module
from yandex.cloud.dataproc.v1.cluster_service_pb2_grpc import ClusterServiceStub
from yandex.cloud.operation.operation_pb2 import Operation
from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub
from yandexcloud._backoff import backoff_exponential_jittered_min_interval
from yandexcloud._wrappers.dataproc import (
Dataproc,
DataprocRetryInterceptor,
InterceptorSettings,
)


class MockOperationService:
    """Operation-service double: raises for the first N Get() calls, then succeeds."""

    def __init__(self, fail_attempts=0, fail_code=grpc.StatusCode.CANCELLED):
        # How many leading Get() calls should fail, and with which status code.
        self.fail_attempts = fail_attempts
        self.fail_code = fail_code
        self.call_count = 0

    def Get(self, request):
        """Return a finished Operation, raising fail_code for the first fail_attempts calls."""
        self.call_count += 1
        if self.call_count > self.fail_attempts:
            return Operation(id="test-op", done=True)
        failure = grpc.RpcError()
        failure._state = Mock()
        failure._state.code = self.fail_code
        failure.code = lambda: self.fail_code
        raise failure


class MockClusterService:
    """Cluster-service double whose Create() yields a not-yet-done operation."""

    def Create(self, request):
        pending_operation = Operation(id="test-cluster-op", done=False)
        return pending_operation


class MockSDK:
    """SDK double that records client() calls and drives the waiter polling loop."""

    def __init__(self):
        # One (service_class, interceptor) tuple appended per client() invocation.
        self.client_calls = []

    def client(self, service_class, interceptor=None):
        self.client_calls.append((service_class, interceptor))
        if service_class == ClusterServiceStub:
            return MockClusterService()
        if service_class == OperationServiceStub:
            return MockOperationService()
        return None

    def create_operation_and_get_result(self, request, service, method_name, response_type, meta_type):
        pending = Operation(id="test-op", done=False)
        # Poll through the (possibly monkey-patched) module-level waiter.
        for _ in waiter_module.operation_waiter(self, pending.id, None):
            time.sleep(0.01)
        return MockOperationResult()


class MockOperationResult:
    """Operation-result double exposing a response with a fixed cluster id."""

    def __init__(self):
        response = Mock()
        response.id = "test-cluster-id"
        self.response = response


@pytest.fixture
def mock_sdk():
    # Fresh MockSDK per test so recorded client_calls never leak between tests.
    return MockSDK()


def test_dataproc_custom_interceptor_max_attempts(mock_sdk):
    # Retry budget is 50; the operation service is configured to fail 51 times,
    # so the interceptor must eventually give up and surface the error.
    dataproc = Dataproc(
        sdk=mock_sdk,
        interceptor_settings=InterceptorSettings(
            max_retry_count=50,
            retriable_codes=(
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                grpc.StatusCode.INTERNAL,
                grpc.StatusCode.CANCELLED,
            ),
            back_off_func=backoff_exponential_jittered_min_interval(),
        ),
    )

    mock_operation_service = MockOperationService(fail_attempts=51, fail_code=grpc.StatusCode.CANCELLED)

    # Patch the module-level waiter factory to hand back an already-exhausted
    # iterator so the polling loop terminates immediately.
    with patch.object(waiter_module, "operation_waiter") as mock_waiter_fn:
        mock_waiter = Mock()
        mock_waiter.__iter__ = Mock(return_value=iter([]))
        mock_waiter.operation = Operation(id="test", done=True)
        mock_waiter_fn.return_value = mock_waiter

        # Every sdk.client() call now returns the failing operation service,
        # regardless of the requested stub class.
        with patch.object(mock_sdk, "client") as mock_client:
            mock_client.return_value = mock_operation_service

            with pytest.raises(grpc.RpcError) as exc_info:
                dataproc.create_cluster(
                    folder_id="test-folder",
                    cluster_name="test-cluster",
                    subnet_id="test-subnet",
                    service_account_id="test-sa",
                    ssh_public_keys="test-ssh-key",
                )

            # The surfaced error keeps the original status code, and the
            # service was queried at most max_retry_count times.
            assert exc_info.value.code() == grpc.StatusCode.CANCELLED
            assert mock_operation_service.call_count <= 50


def test_dataproc_interceptor_inheritance():
    """The custom interceptor treats exactly the configured codes as retriable."""
    interceptor = DataprocRetryInterceptor(
        max_retry_count=10, retriable_codes=(grpc.StatusCode.CANCELLED, grpc.StatusCode.UNAVAILABLE)
    )

    # Name-mangled access: the hook is defined as __is_retriable on RetryInterceptor.
    # Plain truth-value asserts instead of `== True` / `== False` (PEP 8 E712).
    assert interceptor._RetryInterceptor__is_retriable(grpc.StatusCode.CANCELLED)
    assert interceptor._RetryInterceptor__is_retriable(grpc.StatusCode.UNAVAILABLE)
    assert not interceptor._RetryInterceptor__is_retriable(grpc.StatusCode.PERMISSION_DENIED)


def test_dataproc_monkey_patch_restoration():
    """The operation_waiter monkey-patch must be undone after the wrapped call returns."""
    # Local name `sdk` avoids shadowing the module-level `mock_sdk` fixture.
    sdk = Mock()
    original_waiter = waiter_module.operation_waiter

    dataproc = Dataproc(
        sdk=sdk,
        interceptor_settings=InterceptorSettings(
            max_retry_count=50,
            retriable_codes=(
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                grpc.StatusCode.INTERNAL,
                grpc.StatusCode.CANCELLED,
            ),
            back_off_func=backoff_exponential_jittered_min_interval(),
        ),
    )

    with patch.object(sdk, "create_operation_and_get_result") as mock_create:
        mock_create.return_value = MockOperationResult()

        result = dataproc.delete_cluster(cluster_id="test-cluster-id")

    assert result is not None

    # Identity check (`is`, not `==`): the module attribute must be the very
    # same function object that was there before the call.
    assert waiter_module.operation_waiter is original_waiter


def test_dataproc_all_methods_use_wrapper(mock_sdk):
    """Every cluster lifecycle method must route through _with_dataproc_waiter."""
    dataproc = Dataproc(
        sdk=mock_sdk,
        interceptor_settings=InterceptorSettings(
            max_retry_count=50,
            retriable_codes=(
                grpc.StatusCode.UNAVAILABLE,
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                grpc.StatusCode.INTERNAL,
                grpc.StatusCode.CANCELLED,
            ),
            back_off_func=backoff_exponential_jittered_min_interval(),
        ),
    )

    methods_to_test = [
        (
            "create_cluster",
            {
                "folder_id": "test",
                "cluster_name": "test",
                "subnet_id": "test",
                "service_account_id": "test",
                "ssh_public_keys": "test-ssh-key",
            },
        ),
        ("delete_cluster", {"cluster_id": "test"}),
        ("stop_cluster", {"cluster_id": "test"}),
        ("start_cluster", {"cluster_id": "test"}),
    ]

    with patch.object(dataproc, "_with_dataproc_waiter") as mock_wrapper:
        mock_wrapper.return_value = MockOperationResult()

        for method_name, kwargs in methods_to_test:
            method = getattr(dataproc, method_name)
            method(**kwargs)

            # Assert and reset inside the loop: the original single post-loop
            # assert passed as soon as ANY method used the wrapper, and its
            # trailing reset_mock() was dead code.
            assert mock_wrapper.called, f"{method_name} did not use _with_dataproc_waiter"
            mock_wrapper.reset_mock()
90 changes: 78 additions & 12 deletions yandexcloud/_wrappers/dataproc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# mypy: ignore-errors
import logging
import random
from typing import Iterable, NamedTuple
from typing import Callable, Iterable, NamedTuple

import grpc
from google.protobuf.field_mask_pb2 import FieldMask

import yandex.cloud.dataproc.v1.cluster_pb2 as cluster_pb
Expand All @@ -16,6 +17,9 @@
import yandex.cloud.dataproc.v1.subcluster_pb2 as subcluster_pb
import yandex.cloud.dataproc.v1.subcluster_service_pb2 as subcluster_service_pb
import yandex.cloud.dataproc.v1.subcluster_service_pb2_grpc as subcluster_service_grpc_pb
import yandexcloud._operation_waiter as waiter_module
from yandex.cloud.operation.operation_service_pb2_grpc import OperationServiceStub
from yandexcloud._retry_interceptor import RetryInterceptor


class InitializationAction(NamedTuple):
Expand All @@ -31,6 +35,34 @@ def to_grpc(self):
)


class InterceptorSettings(NamedTuple):
    """Configuration for the retry interceptor used when polling Data Proc operations."""

    max_retry_count: int  # Maximum number of retries
    retriable_codes: Iterable[grpc.StatusCode]  # Retriable error codes
    back_off_func: Callable[[int], float]  # Backoff function: attempt number -> delay seconds


class DataprocRetryInterceptor(RetryInterceptor):
    """RetryInterceptor whose retriability check is pure membership in retriable_codes.

    NOTE(review): overrides the base class's name-mangled ``__is_retriable`` hook;
    presumably the base applies a stricter policy — confirm against RetryInterceptor.
    """

    # pylint: disable-next=invalid-name
    def _RetryInterceptor__is_retriable(self, error: grpc.StatusCode) -> bool:
        # Return the membership test directly instead of `if ...: return True / return False`.
        return error in self._RetryInterceptor__retriable_codes


def create_custom_operation_waiter(interceptor_settings: InterceptorSettings):
    """Build a drop-in replacement for operation_waiter that polls via a retrying client.

    The returned callable matches the (sdk, operation_id, timeout) signature of
    the module-level operation_waiter and constructs a fresh retry interceptor
    on every call.
    """

    def custom_operation_waiter(sdk, operation_id, timeout):
        interceptor = DataprocRetryInterceptor(
            max_retry_count=interceptor_settings.max_retry_count,
            retriable_codes=interceptor_settings.retriable_codes,
            back_off_func=interceptor_settings.back_off_func,
        )
        service = sdk.client(OperationServiceStub, interceptor=interceptor)
        return waiter_module.OperationWaiter(operation_id, service, timeout)

    return custom_operation_waiter


class Dataproc:
"""
A base hook for Yandex.Cloud Data Proc.
Expand All @@ -43,9 +75,18 @@ class Dataproc:
:type logger: Optional[logging.Logger]
:param sdk: SDK object. Normally is being set by Wrappers constructor
:type sdk: yandexcloud.SDK
:param interceptor_settings: Settings For Custom Dataproc Interceptor
:type interceptor_settings: Optional[InterceptorSettings]
"""

def __init__(self, default_folder_id=None, default_public_ssh_key=None, logger=None, sdk=None):
def __init__(
self,
default_folder_id=None,
default_public_ssh_key=None,
logger=None,
sdk=None,
interceptor_settings=None,
):
self.sdk = sdk or self.sdk
self.log = logger
if not self.log:
Expand All @@ -55,6 +96,21 @@ def __init__(self, default_folder_id=None, default_public_ssh_key=None, logger=N
self.subnet_id = None
self.default_folder_id = default_folder_id
self.default_public_ssh_key = default_public_ssh_key
self._custom_operation_waiter = None
if interceptor_settings:
self._custom_operation_waiter = create_custom_operation_waiter(interceptor_settings)

def _with_dataproc_waiter(self, func, *args, **kwargs):
if not self._custom_operation_waiter:
return func(*args, **kwargs)

original_waiter = waiter_module.operation_waiter
waiter_module.operation_waiter = self._custom_operation_waiter

try:
return func(*args, **kwargs)
finally:
waiter_module.operation_waiter = original_waiter

def create_cluster(
self,
Expand Down Expand Up @@ -313,7 +369,8 @@ def create_cluster(
log_group_id=log_group_id,
labels=labels,
)
result = self.sdk.create_operation_and_get_result(
result = self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=cluster_service_grpc_pb.ClusterServiceStub,
method_name="Create",
Expand Down Expand Up @@ -427,7 +484,8 @@ def create_subcluster(
hosts_count=hosts_count,
autoscaling_config=autoscaling_config,
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=subcluster_service_grpc_pb.SubclusterServiceStub,
method_name="Create",
Expand Down Expand Up @@ -455,7 +513,8 @@ def update_cluster_description(self, description, cluster_id=None):
update_mask=mask,
description=description,
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=cluster_service_grpc_pb.ClusterServiceStub,
method_name="Update",
Expand All @@ -475,7 +534,8 @@ def delete_cluster(self, cluster_id=None):

self.log.info("Deleting cluster %s", cluster_id)
request = cluster_service_pb.DeleteClusterRequest(cluster_id=cluster_id)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=cluster_service_grpc_pb.ClusterServiceStub,
method_name="Delete",
Expand All @@ -497,7 +557,8 @@ def stop_cluster(self, cluster_id=None, decommission_timeout=0):
request = cluster_service_pb.StopClusterRequest(
cluster_id=cluster_id, decommission_timeout=decommission_timeout
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=cluster_service_grpc_pb.ClusterServiceStub,
method_name="Stop",
Expand All @@ -514,7 +575,8 @@ def start_cluster(self, cluster_id=None):
if not cluster_id:
raise RuntimeError("Cluster id must be specified.")
request = cluster_service_pb.StartClusterRequest(cluster_id=cluster_id)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=cluster_service_grpc_pb.ClusterServiceStub,
method_name="Start",
Expand Down Expand Up @@ -575,7 +637,8 @@ def create_hive_job(
name=name,
hive_job=hive_job,
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=job_service_grpc_pb.JobServiceStub,
method_name="Create",
Expand Down Expand Up @@ -637,7 +700,8 @@ def create_mapreduce_job(
properties=properties,
),
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=job_service_grpc_pb.JobServiceStub,
method_name="Create",
Expand Down Expand Up @@ -712,7 +776,8 @@ def create_spark_job(
exclude_packages=exclude_packages,
),
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=job_service_grpc_pb.JobServiceStub,
method_name="Create",
Expand Down Expand Up @@ -786,7 +851,8 @@ def create_pyspark_job(
exclude_packages=exclude_packages,
),
)
return self.sdk.create_operation_and_get_result(
return self._with_dataproc_waiter(
self.sdk.create_operation_and_get_result,
request,
service=job_service_grpc_pb.JobServiceStub,
method_name="Create",
Expand Down
Loading