/*
 * Copyright 2018 LinkedIn, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.linkedin.restli.client;

import com.linkedin.parseq.Task;
import com.linkedin.r2.filter.R2Constants;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.restli.client.config.RequestConfigOverridesBuilder;
import com.linkedin.restli.examples.greetings.api.Greeting;
import com.linkedin.restli.examples.greetings.client.GreetingsBuilders;
import org.testng.annotations.Test;

import static org.testng.Assert.*;


/**
 * Integration tests for the ParSeq rest.li client built with
 * {@code setD2RequestTimeoutEnabled(true)}.
 *
 * <p>The tests check two things: which {@code withTimeout} task (if any) shows up in the
 * ParSeq trace, and which timeout attributes end up on the {@link RequestContext} that is
 * handed to the underlying {@link RestClient}. The latter is observed through a
 * {@link CapturingRestClient} wrapped around the client created by the base class.
 */
public class TestParSeqRestClientWithD2Timeout extends ParSeqRestClientIntegrationTest {

  /** Wrapper around the base class's rest client; records the request context of each request. */
  private CapturingRestClient _capturingRestClient;

  /**
   * Builds the timeout and batching configuration used by every test in this class.
   *
   * <p>NOTE(review): keys look like {@code <inbound name>.<inbound op>/<outbound name>.<outbound op>}
   * with {@code *} as a wildcard - confirm against the config builder documentation.
   *
   * @return the client configuration
   */
  @Override
  public ParSeqRestliClientConfig getParSeqRestClientConfig() {
    return new ParSeqRestliClientConfigBuilder()
        // The entry the tests below expect to be selected when the inbound name is "withD2Timeout".
        .addTimeoutMs("withD2Timeout.*/greetings.*", 5000L)
        .addTimeoutMs("*.*/greetings.GET", 9999L)
        .addTimeoutMs("*.*/greetings.*", 10001L)
        .addTimeoutMs("*.*/*.GET", 10002L)
        .addTimeoutMs("foo.*/greetings.GET", 10003L)
        .addTimeoutMs("foo.GET/greetings.GET", 10004L)
        .addTimeoutMs("foo.ACTION-*/greetings.GET", 10005L)
        .addTimeoutMs("foo.ACTION-bar/greetings.GET", 10006L)
        .addBatchingEnabled("withBatching.*/*.*", true)
        .addMaxBatchSize("withBatching.*/*.*", 3)
        .build();
  }

  /**
   * Wraps the rest client created by the superclass in a {@link CapturingRestClient} and keeps
   * a reference to it so tests can inspect the captured request contexts.
   *
   * @return the capturing wrapper
   */
  @Override
  protected RestClient createRestClient() {
    _capturingRestClient = new CapturingRestClient(null, null, super.createRestClient());
    return _capturingRestClient;
  }

  /**
   * Turns on the D2 request timeout option on the client builder.
   *
   * @param parSeqRestliClientBuilder builder to customize
   */
  @Override
  protected void customizeParSeqRestliClient(ParSeqRestliClientBuilder parSeqRestliClientBuilder) {
    parSeqRestliClientBuilder.setD2RequestTimeoutEnabled(true);
  }

  /**
   * A per-request override of 5555 ms must produce a {@code withTimeout 5555ms} task in the trace.
   */
  @Test
  public void testConfiguredD2TimeoutOutboundOverride() {
    Task<?> task = greetingGet(1L, new RequestConfigOverridesBuilder().setTimeoutMs(5555L).build());
    // Run description matches this test's name so the run is attributed to the right test.
    runAndWait(getTestClassName() + ".testConfiguredD2TimeoutOutboundOverride", task);
    assertTrue(hasTask("withTimeout 5555ms", task.getTrace()));
  }

  /**
   * With inbound name {@code withD2Timeout}, a delete on greetings must pick up the 5000 ms
   * timeout from the {@code withD2Timeout.*}{@code /greetings.*} configuration entry.
   */
  @Test
  public void testConfiguredD2TimeoutOutboundOp() {
    setInboundRequestContext(new InboundRequestContextBuilder().setName("withD2Timeout").build());
    // toTry() so that a failed delete does not fail the task itself; only the trace is asserted.
    Task<?> task = greetingDel(9999L).toTry();
    runAndWait(getTestClassName() + ".testConfiguredD2TimeoutOutboundOp", task);
    assertTrue(hasTask("withTimeout 5000ms src: withD2Timeout.*/greetings.*", task.getTrace()));
  }

  /**
   * A GET created without an explicit request context must get the configured 5000 ms timeout
   * both as a trace task and as attributes on the captured request context
   * ({@code REQUEST_TIMEOUT} = 5000, {@code REQUEST_TIMEOUT_IGNORE_IF_HIGHER_THAN_DEFAULT} = true).
   */
  @Test
  public void testTimeoutRequest() {
    setInboundRequestContext(new InboundRequestContextBuilder()
        .setName("withD2Timeout")
        .build());
    GetRequest<Greeting> request = new GreetingsBuilders().get().id(1L).build();
    Task<?> task = _parseqClient.createTask(request);
    runAndWait(getTestClassName() + ".testTimeoutRequest", task);
    assertTrue(hasTask("withTimeout 5000ms src: withD2Timeout.*/greetings.*", task.getTrace()));
    verifyRequestContextTimeout(request, 5000, Boolean.TRUE);
  }

  /**
   * When the caller's request context already carries a timeout (4000 ms) lower than the
   * configured 5000 ms, no {@code withTimeout} task must appear and the captured context must
   * still hold 4000 with no ignore-if-higher attribute.
   */
  @Test
  public void testTighterTimeoutFromContext() {
    setInboundRequestContext(new InboundRequestContextBuilder()
        .setName("withD2Timeout")
        .build());
    GetRequest<Greeting> request = new GreetingsBuilders().get().id(1L).build();
    RequestContext context = new RequestContext();
    context.putLocalAttr(R2Constants.REQUEST_TIMEOUT, 4000);
    Task<?> task = _parseqClient.createTask(request, context);
    // Run description matches this test's name so the run is attributed to the right test.
    runAndWait(getTestClassName() + ".testTighterTimeoutFromContext", task);
    assertFalse(hasTask("withTimeout", task.getTrace()));
    verifyRequestContextTimeout(request, 4000, null);
  }

  /**
   * When the caller's request context already carries a timeout (12000 ms) higher than the
   * configured 5000 ms, no {@code withTimeout} task must appear and the captured context must
   * still hold 12000 with no ignore-if-higher attribute.
   */
  @Test
  public void testLongerTimeoutFromContext() {
    setInboundRequestContext(new InboundRequestContextBuilder()
        .setName("withD2Timeout")
        .build());
    GetRequest<Greeting> request = new GreetingsBuilders().get().id(1L).build();
    RequestContext context = new RequestContext();
    context.putLocalAttr(R2Constants.REQUEST_TIMEOUT, 12000);
    Task<?> task = _parseqClient.createTask(request, context);
    // Run description matches this test's name so the run is attributed to the right test.
    runAndWait(getTestClassName() + ".testLongerTimeoutFromContext", task);
    assertFalse(hasTask("withTimeout", task.getTrace()));
    verifyRequestContextTimeout(request, 12000, null);
  }

  /**
   * Asserts the timeout attributes on the request context captured for {@code request}.
   *
   * @param request request whose captured context is inspected; must have been captured
   * @param timeout expected integer value of {@code R2Constants.REQUEST_TIMEOUT}
   * @param ignoreIfHigher expected value of
   *     {@code R2Constants.REQUEST_TIMEOUT_IGNORE_IF_HIGHER_THAN_DEFAULT}, or {@code null} to
   *     assert that the attribute is absent
   */
  private void verifyRequestContextTimeout(Request<?> request, int timeout, Boolean ignoreIfHigher) {
    assertTrue(_capturingRestClient.getCapturedRequestContexts().containsKey(request));
    RequestContext context = _capturingRestClient.getCapturedRequestContexts().get(request);
    // Read as Number: the tests store an int, while the client may store a different numeric type.
    Number contextTimeout = (Number)context.getLocalAttr(R2Constants.REQUEST_TIMEOUT);
    assertNotNull(contextTimeout);
    assertEquals(contextTimeout.intValue(), timeout);
    if (ignoreIfHigher == null) {
      assertNull(context.getLocalAttr(R2Constants.REQUEST_TIMEOUT_IGNORE_IF_HIGHER_THAN_DEFAULT));
    } else {
      assertEquals(context.getLocalAttr(R2Constants.REQUEST_TIMEOUT_IGNORE_IF_HIGHER_THAN_DEFAULT), ignoreIfHigher);
    }
  }
}