package openid

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/golang-jwt/jwt/v4"
	grpcauth "github.com/grpc-ecosystem/go-grpc-middleware/auth"
	"github.com/hashicorp/go-retryablehttp"
	"github.com/lestrrat-go/jwx/jwk"

	"github.com/Permify/permify/internal/config"
	base "github.com/Permify/permify/pkg/pb/base/v1"
)

// Authn authenticates requests by validating OIDC-issued JWT bearer tokens
// against the issuer's JSON Web Key Set. Instances are built by NewOidcAuthn.
type Authn struct {
	// IssuerURL is the expected "iss" claim of accepted tokens; the provider's
	// /.well-known/openid-configuration document is located relative to it.
	IssuerURL string
	// Audience is the expected "aud" claim of accepted tokens.
	Audience string
	// JwksURI is the JWKS location advertised by the provider's discovery document.
	JwksURI string
	// jwksSet caches the JWKS for JwksURI and refreshes it at the configured interval.
	jwksSet *jwk.AutoRefresh
	// validMethods lists the accepted JWT signing algorithms ("alg" header values).
	validMethods []string
	// jwtParser parses tokens and rejects signing methods outside validMethods.
	jwtParser *jwt.Parser
	// backoffInterval is the length of the shared backoff window; the retry
	// state is reset once this much time has passed since globalFirstSeen.
	backoffInterval time.Duration
	// backoffMaxRetries is the maximum number of JWKS refresh retries allowed
	// per backoff window (see getKeyWithRetry).
	backoffMaxRetries int

	// backoffFrequency is the delay between consecutive JWKS refresh attempts
	// made while looking for an unknown key ID.
	backoffFrequency time.Duration

	// Backoff state shared by all concurrent callers of getKeyWithRetry:
	// globalRetryCount counts JWKS refreshes in the current window,
	// globalFirstSeen marks the window start (zero when no window is active),
	// and retriedKeys records key IDs whose lookup failed in the window.
	globalRetryCount int
	globalFirstSeen  time.Time
	retriedKeys      map[string]bool
	mutex            sync.Mutex // guards globalRetryCount, globalFirstSeen and retriedKeys
}

// NewOidcAuthn creates a new OIDC authenticator.
//
// It validates the backoff settings in conf, then fetches the provider's
// discovery document from <issuer>/.well-known/openid-configuration. It
// registers the advertised JWKS URI with a jwk.AutoRefresh cache and performs
// an initial JWKS fetch, so that a misconfigured provider is reported at startup.
//
// The backoff settings are checked before any network I/O and before the
// jwk.AutoRefresh is created. The refresher's background worker lives as long
// as ctx, so rejecting the configuration after creating it would leave that
// worker running for nothing.
//
// Errors are returned for non-positive BackoffInterval, BackoffMaxRetries or
// BackoffFrequency, for a failed discovery fetch, and for a failed initial
// JWKS fetch.
func NewOidcAuthn(ctx context.Context, conf config.Oidc) (*Authn, error) {
	// Validate backoffInterval, backoffMaxRetries and backoffFrequency up front
	// so that invalid settings fail fast without touching the network.
	backoffInterval := conf.BackoffInterval
	if backoffInterval <= 0 {
		return nil, errors.New("invalid or missing backoffInterval")
	}

	backoffMaxRetries := conf.BackoffMaxRetries
	if backoffMaxRetries <= 0 {
		return nil, errors.New("invalid or missing backoffMaxRetries")
	}

	backoffFrequency := conf.BackoffFrequency
	if backoffFrequency <= 0 {
		return nil, errors.New("invalid or missing backoffFrequency")
	}

	// HTTP client with retry support. It is shared by the discovery request and
	// by JWKS fetching.
	client := retryablehttp.NewClient()
	client.Logger = SlogAdapter{Logger: slog.Default()}
	httpClient := client.StandardClient()

	// Fetch the OIDC configuration from the issuer's well-known configuration endpoint.
	discoveryURL := strings.TrimSuffix(conf.Issuer, "/") + "/.well-known/openid-configuration"
	oidcConf, err := fetchOIDCConfiguration(httpClient, discoveryURL)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch OIDC configuration: %w", err)
	}

	// Keep the JWKS up to date in the background; the refresher lives as long as ctx.
	autoRefresh := jwk.NewAutoRefresh(ctx)
	autoRefresh.Configure(oidcConf.JWKsURI, jwk.WithHTTPClient(httpClient), jwk.WithRefreshInterval(conf.RefreshInterval))

	// Initialize the Authn struct with the OIDC configuration details and other relevant settings.
	oidc := &Authn{
		IssuerURL:         conf.Issuer,
		Audience:          conf.Audience,
		JwksURI:           oidcConf.JWKsURI,
		validMethods:      conf.ValidMethods,
		jwtParser:         jwt.NewParser(jwt.WithValidMethods(conf.ValidMethods)),
		jwksSet:           autoRefresh,
		backoffInterval:   backoffInterval,
		backoffMaxRetries: backoffMaxRetries,
		backoffFrequency:  backoffFrequency,
		globalRetryCount:  0,
		retriedKeys:       make(map[string]bool),
		globalFirstSeen:   time.Time{},
		mutex:             sync.Mutex{},
	}

	// Fetch the JWKS once now so an unreachable or invalid key set is reported
	// at construction time rather than on the first request.
	if _, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI); err != nil {
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
	}

	return oidc, nil
}

// Authenticate checks the bearer token carried in the incoming gRPC metadata
// of ctx.
//
// The token is parsed with oidc.jwtParser. Its "kid" header must resolve to a
// key through getKeyWithRetry, and its "iss" and "aud" claims must match
// IssuerURL and Audience. It returns nil when every check passes. Otherwise
// the error's message is the string form of the matching base.ErrorCode:
// missing bearer token, invalid bearer token, invalid claims, invalid issuer
// or invalid audience.
func (oidc *Authn) Authenticate(ctx context.Context) error {
	rawToken, err := grpcauth.AuthFromMD(ctx, "Bearer")
	if err != nil {
		slog.Error("failed to extract authorization header from gRPC request", "error", err)
		return errors.New(base.ErrorCode_ERROR_CODE_MISSING_BEARER_TOKEN.String())
	}
	slog.Debug("Successfully extracted authorization header from gRPC request")

	// keyFunc picks the verification key named by the token's "kid" header.
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		slog.Info("starting JWT parsing and validation.")

		kid, hasKid := t.Header["kid"].(string)
		if !hasKid {
			slog.Error("jwt does not contain a key ID")
			return nil, errors.New("kid must be specified in the token header")
		}
		return oidc.getKeyWithRetry(ctx, kid)
	}

	token, err := oidc.jwtParser.Parse(rawToken, keyFunc)
	switch {
	case err != nil:
		slog.Error("token parsing or validation failed", "error", err)
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
	case !token.Valid:
		slog.Warn("parsed token is invalid")
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_BEARER_TOKEN.String())
	}

	claims, isMap := token.Claims.(jwt.MapClaims)
	if !isMap {
		slog.Warn("token claims are in an incorrect format")
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_CLAIMS.String())
	}
	slog.Debug("extracted token claims", "claims", claims)

	// The issuer must match; the claim is required.
	if !claims.VerifyIssuer(oidc.IssuerURL, true) {
		slog.Warn("token issuer is invalid", "expected", oidc.IssuerURL, "actual", claims["iss"])
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_ISSUER.String())
	}

	// The audience must match; the claim is required.
	if !claims.VerifyAudience(oidc.Audience, true) {
		slog.Warn("token audience is invalid", "expected", oidc.Audience, "actual", claims["aud"])
		return errors.New(base.ErrorCode_ERROR_CODE_INVALID_AUDIENCE.String())
	}

	slog.Info("token validation succeeded")
	return nil
}

// getKeyWithRetry resolves keyID to a verification key via fetchKey. When the
// key is not found, it refreshes the JWKS and retries under a backoff policy.
// The policy's state (globalFirstSeen, globalRetryCount and retriedKeys, all
// guarded by oidc.mutex) is shared by every concurrent caller on the same Authn.
//
// Flow:
//  1. Start a new window if none is active. That is the case when
//     globalFirstSeen is zero or oidc.backoffInterval has elapsed since it; the
//     shared state is then cleared. Otherwise, if globalRetryCount has already
//     reached oidc.backoffMaxRetries, look the key up exactly once without
//     refreshing. On failure, return "too many attempts, backoff in effect".
//  2. Otherwise, look the key up in a loop of at most
//     oidc.backoffMaxRetries+1 iterations. After each failed lookup:
//     - record keyID in retriedKeys;
//     - abort if globalRetryCount already exceeds oidc.backoffMaxRetries;
//     - wait oidc.backoffFrequency, on every iteration but the first
//       (cancellable through ctx);
//     - refresh the JWKS, unless another goroutine has incremented
//       globalRetryCount since this iteration's snapshot, in which case skip
//       the refresh.
//  3. If the loop ends without finding the key while the window is still
//     active, pin globalRetryCount to oidc.backoffMaxRetries. Later calls in
//     the same window then take the single-lookup path of step 1.
//
// The shared state is cleared in two cases:
//   - the loop finds the key after at least one failed attempt;
//   - the single-lookup path finds the key and keyID is recorded in retriedKeys.
//
// Returns the raw key from fetchKey on success. Otherwise it returns one of:
//   - ctx.Err(), if ctx is done while waiting;
//   - the error from jwksSet.Refresh, if a refresh fails;
//   - a backoff or "key ID not found" error.
//
// NOTE(review): jwksSet.Refresh is called while oidc.mutex is held, so all
// other callers block on the mutex for the duration of that request.
// NOTE(review): the refresh done in the loop's final iteration is not followed
// by another lookup within this call.
func (oidc *Authn) getKeyWithRetry(
	ctx context.Context,
	keyID string,
) (interface{}, error) {
	var raw interface{}
	var err error

	oidc.mutex.Lock()
	now := time.Now()

	// Start a new backoff window if none is active or the current one has expired.
	if oidc.globalFirstSeen.IsZero() || time.Since(oidc.globalFirstSeen) >= oidc.backoffInterval {
		slog.Info("resetting state as interval has passed or first seen is zero", "keyID", keyID)
		oidc.globalFirstSeen = now
		oidc.globalRetryCount = 0
		oidc.retriedKeys = make(map[string]bool)
	} else if oidc.globalRetryCount >= oidc.backoffMaxRetries {
		// Retry budget for this window is used up: release the lock and look
		// the key up once against the current JWKS, without refreshing.
		slog.Warn("max retries reached within interval, will check keyID once", "keyID", keyID)
		oidc.mutex.Unlock()

		// Try to fetch the keyID once
		raw, err = oidc.fetchKey(ctx, keyID)
		if err == nil { // Successfully fetched the key
			oidc.mutex.Lock()
			if _, wasRetried := oidc.retriedKeys[keyID]; wasRetried {
				// Reset global backoff state if a valid key is found and that key had been previously retried
				// Use case: prevents malicious keyIDs from blocking valid keyIDs
				// The valid KeyID should not reset counters for invalid keys
				slog.Info("valid key found in backoff period, resetting global state", "keyID", keyID)
				oidc.globalRetryCount = 0                // Reset retry counter
				oidc.globalFirstSeen = time.Time{}       // Reset timestamp
				oidc.retriedKeys = make(map[string]bool) // Clear retried keys
			}
			oidc.mutex.Unlock() // Release the lock
			return raw, nil
		}

		// Log the failure and return an error if keyID is not found
		slog.Error("failed to fetch key during backoff period", "keyID", keyID, "error", err)
		return nil, errors.New("too many attempts, backoff in effect")
	}
	oidc.mutex.Unlock()

	// Retry loop: at most backoffMaxRetries+1 iterations. Each one is a lookup
	// and, if that fails, an optional wait followed by a JWKS refresh.
	retries := 0
	for retries <= oidc.backoffMaxRetries {
		raw, err = oidc.fetchKey(ctx, keyID)
		if err == nil { // Key successfully retrieved
			if retries != 0 { // Found after at least one failed attempt: clear shared state
				oidc.mutex.Lock()
				oidc.globalRetryCount = 0
				oidc.globalFirstSeen = time.Time{}
				oidc.retriedKeys = make(map[string]bool)
				oidc.mutex.Unlock()
			}
			return raw, nil
		}
		oidc.mutex.Lock()
		// Remember the shared refresh count; if it grows while this goroutine
		// waits, another caller has already refreshed the JWKS.
		snapshotCount := oidc.globalRetryCount
		oidc.retriedKeys[keyID] = true
		if oidc.globalRetryCount > oidc.backoffMaxRetries {
			slog.Error("key ID not found in JWKS due to exceeding global retries", "keyID", keyID, "globalRetryCount", oidc.globalRetryCount)
			oidc.mutex.Unlock() // Unlock before returning
			return nil, errors.New("too many retry attempts, backoff policy active due to global retry limit")
		}
		oidc.mutex.Unlock() // Release mutex
		if retries > 0 { // No delay before the first refresh
			select {
			case <-time.After(oidc.backoffFrequency):
				slog.Info("waiting before retrying", "keyID", keyID, "retries", retries)
			case <-ctx.Done():
				slog.Error("context cancelled during retry", "keyID", keyID)
				return nil, ctx.Err()
			}
		}

		oidc.mutex.Lock()
		if oidc.globalRetryCount > snapshotCount { // Another goroutine already refreshed
			// Another concurrent request in retry loop has already refreshed the JWKS
			retries++
			slog.Warn("concurrent request has already refreshed JWKS, skipping refresh")
			oidc.mutex.Unlock() // Unlock and continue
			continue            // Skip to next iteration
		}

		oidc.globalRetryCount++ // Increment the global retry counter
		slog.Warn("retrying to fetch JWKS due to error", "keyID", keyID, "retries", retries, "error", err)
		retries++ // Increment retry counter

		// The refresh runs while oidc.mutex is held, so only one goroutine
		// refreshes at a time.
		if _, err := oidc.jwksSet.Refresh(ctx, oidc.JwksURI); err != nil {
			oidc.mutex.Unlock()
			slog.Error("failed to refresh JWKS", "error", err)
			return nil, err
		}
		// Unlock after Refresh to prevent concurrent duplicate refresh calls
		oidc.mutex.Unlock() // Release lock after successful refresh
	}

	// Retries exhausted: if the window is still active, pin the counter at the
	// limit so later calls in this window take the single-lookup path.
	oidc.mutex.Lock()
	if time.Since(oidc.globalFirstSeen) < oidc.backoffInterval {
		slog.Warn("marking state to prevent further retries", "keyID", keyID)
		oidc.globalRetryCount = oidc.backoffMaxRetries
	}
	oidc.mutex.Unlock()

	slog.Error("key ID not found in JWKS after retries", "keyID", keyID)
	return nil, errors.New("key ID not found in JWKS after retries")
}

// fetchKey looks keyID up in the JWKS that oidc.jwksSet returns for
// oidc.JwksURI. It calls Fetch, not Refresh.
//
// On success it returns the key in its raw form, as produced by Key.Raw, for
// use in JWT signature verification. It returns an error when:
//   - the set cannot be fetched;
//   - no key has the given ID;
//   - the key cannot be converted to its raw form.
func (oidc *Authn) fetchKey(ctx context.Context, keyID string) (interface{}, error) {
	slog.DebugContext(ctx, "attempting to find key in JWKS", "kid", keyID)

	set, err := oidc.jwksSet.Fetch(ctx, oidc.JwksURI)
	if err != nil {
		slog.Error("failed to fetch JWKS", "uri", oidc.JwksURI, "error", err)
		return nil, fmt.Errorf("failed to fetch JWKS: %w", err)
	}
	slog.InfoContext(ctx, "successfully fetched JWKS")

	key, found := set.LookupKeyID(keyID)
	if !found {
		slog.ErrorContext(ctx, "key ID not found in JWKS", "kid", keyID)
		return nil, fmt.Errorf("kid %s not found", keyID)
	}

	var rawKey interface{}
	if err := key.Raw(&rawKey); err != nil {
		slog.ErrorContext(ctx, "failed to get raw public key", "kid", keyID, "error", err)
		return nil, fmt.Errorf("failed to get raw public key: %w", err)
	}
	slog.DebugContext(ctx, "successfully obtained raw public key", "key", rawKey)
	return rawKey, nil
}

// Config holds the fields of an OpenID Connect discovery document
// (/.well-known/openid-configuration) that this package decodes.
type Config struct {
	// Issuer is the provider's issuer identifier (the "issuer" member).
	Issuer string `json:"issuer"`
	// JWKsURI is the URL of the provider's JSON Web Key Set (the "jwks_uri" member).
	JWKsURI string `json:"jwks_uri"`
}

// fetchOIDCConfiguration downloads the OIDC discovery document at url using
// client and decodes it into a Config.
//
// Errors are passed through unwrapped: transport and status errors come from
// doHTTPRequest, and decoding or missing-field errors from parseOIDCConfiguration.
func fetchOIDCConfiguration(client *http.Client, url string) (*Config, error) {
	payload, err := doHTTPRequest(client, url)
	if err != nil {
		return nil, err
	}
	return parseOIDCConfiguration(payload)
}

// doHTTPRequest issues a GET request for url through client and returns the
// whole response body.
//
// An error is returned if the request cannot be built, the request fails, the
// status is anything other than 200 OK, or the body cannot be read. Each step
// is logged at debug level, and failures are logged before returning.
func doHTTPRequest(client *http.Client, url string) ([]byte, error) {
	slog.Debug("creating new HTTP GET request", "url", url)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		slog.Error("failed to create HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to create HTTP request for OIDC configuration: %w", err)
	}

	slog.Debug("executing HTTP request", "url", url)
	resp, err := client.Do(req)
	if err != nil {
		slog.Error("failed to execute HTTP request", "url", url, "error", err)
		return nil, fmt.Errorf("failed to execute HTTP request for OIDC configuration: %w", err)
	}
	defer resp.Body.Close()

	status := resp.StatusCode
	slog.Debug("received HTTP response", "status_code", status, "url", url)
	if status != http.StatusOK {
		slog.Warn("received unexpected status code", "status_code", status, "url", url)
		return nil, fmt.Errorf("received unexpected status code (%d) while fetching OIDC configuration", status)
	}

	slog.Debug("reading response body", "url", url)
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		slog.Error("failed to read response body", "url", url, "error", err)
		return nil, fmt.Errorf("failed to read response body from OIDC configuration request: %w", err)
	}
	slog.Debug("successfully read response body", "url", url, "response_length", len(body))
	return body, nil
}

// parseOIDCConfiguration decodes a JSON discovery document into a Config.
//
// Decoding errors are wrapped and returned. Both "issuer" and "jwks_uri" must
// be present and non-empty; issuer is checked first, and the first missing
// field is reported as an error.
func parseOIDCConfiguration(body []byte) (*Config, error) {
	cfg := new(Config)
	if err := json.Unmarshal(body, cfg); err != nil {
		slog.Error("failed to unmarshal OIDC configuration", "error", err)
		return nil, fmt.Errorf("failed to decode OIDC configuration: %w", err)
	}
	slog.Debug("successfully decoded OIDC configuration")

	switch {
	case cfg.Issuer == "":
		slog.Warn("missing issuer value in OIDC configuration")
		return nil, errors.New("issuer value is required but missing in OIDC configuration")
	case cfg.JWKsURI == "":
		slog.Warn("missing JWKsURI value in OIDC configuration")
		return nil, errors.New("JWKsURI value is required but missing in OIDC configuration")
	}

	slog.Info("successfully parsed OIDC configuration", "issuer", cfg.Issuer, "jwks_uri", cfg.JWKsURI)
	return cfg, nil
}
