/*
Copyright 2024 The Ceph-CSI Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package rbd

import (
	"context"
	"errors"
	"fmt"

	"github.com/ceph/ceph-csi/internal/journal"
	rbderrors "github.com/ceph/ceph-csi/internal/rbd/errors"
	rbd_group "github.com/ceph/ceph-csi/internal/rbd/group"
	"github.com/ceph/ceph-csi/internal/rbd/types"
	"github.com/ceph/ceph-csi/internal/util"
	"github.com/ceph/ceph-csi/internal/util/log"
)

// Compile-time check that *rbdManager satisfies the types.Manager interface.
var _ types.Manager = &rbdManager{}

// rbdManager is the types.Manager implementation of this package. It keeps
// the request parameters and secrets it was created with, and lazily caches
// the credentials and the volume group journal that are derived from them.
type rbdManager struct {
	// driverInstance is the instance id of the CSI-driver (driver name).
	driverInstance string
	// parameters can contain the parameters of a create request.
	parameters map[string]string
	// secrets contain the credentials to connect to the Ceph cluster.
	secrets map[string]string

	// creds are the cached credentials, will be freed on Destroy()
	creds *util.Credentials
	// vgJournal is the journal that is used during operations, it will be freed on Destroy().
	vgJournal journal.VolumeGroupJournal
}

// NewManager creates a types.Manager that serves Volume and Volume Group
// requests. The driver instance, request parameters and secrets are stored
// as-is; credentials and the journal connection are only set up when a method
// first needs them.
func NewManager(driverInstance string, parameters, secrets map[string]string) types.Manager {
	mgr := new(rbdManager)
	mgr.driverInstance = driverInstance
	mgr.parameters = parameters
	mgr.secrets = secrets

	return mgr
}

// Destroy releases the resources that the manager cached: the credentials are
// deleted and the volume group journal is destroyed. Both cached references
// are reset, so calling Destroy again does nothing.
func (mgr *rbdManager) Destroy(ctx context.Context) {
	if cached := mgr.creds; cached != nil {
		cached.DeleteCredentials()
		mgr.creds = nil
	}

	if mgr.vgJournal == nil {
		return
	}

	mgr.vgJournal.Destroy()
	mgr.vgJournal = nil
}

// GetVolumeByID resolves the CSI volume handle id into a types.Volume, using
// the (cached) credentials of the manager.
//
// The returned error wraps the underlying error, so callers can keep using
// errors.Is() for rbderrors.ErrImageNotFound and util.ErrPoolNotFound.
func (mgr *rbdManager) GetVolumeByID(ctx context.Context, id string) (types.Volume, error) {
	creds, err := mgr.getCredentials()
	if err != nil {
		return nil, err
	}

	volume, err := GenVolFromVolID(ctx, id, creds, mgr.secrets)
	if err != nil {
		switch {
		case errors.Is(err, rbderrors.ErrImageNotFound):
			return nil, fmt.Errorf("volume %s not found: %w", id, err)
		case errors.Is(err, util.ErrPoolNotFound):
			// The volume is only used for the pool name in the message. It
			// must not be dereferenced blindly on the error path, as the
			// lookup may not have returned a (partial) volume at all.
			var pool string
			if volume != nil {
				pool = volume.Pool
			}

			return nil, fmt.Errorf("pool %s not found for %s: %w", pool, id, err)
		default:
			return nil, fmt.Errorf("failed to get volume from id %q: %w", id, err)
		}
	}

	return volume, nil
}

// GetSnapshotByID resolves the CSI snapshot handle id into a types.Snapshot,
// using the (cached) credentials of the manager.
//
// The returned error wraps the underlying error, so callers can keep using
// errors.Is() for rbderrors.ErrImageNotFound and util.ErrPoolNotFound.
func (mgr *rbdManager) GetSnapshotByID(ctx context.Context, id string) (types.Snapshot, error) {
	creds, err := mgr.getCredentials()
	if err != nil {
		return nil, err
	}

	snapshot, err := genSnapFromSnapID(ctx, id, creds, mgr.secrets)
	if err != nil {
		switch {
		case errors.Is(err, rbderrors.ErrImageNotFound):
			return nil, fmt.Errorf("snapshot %s not found: %w", id, err)
		case errors.Is(err, util.ErrPoolNotFound):
			// The snapshot is only used for the pool name in the message.
			// It must not be dereferenced blindly on the error path, as the
			// lookup may not have returned a (partial) snapshot at all.
			var pool string
			if snapshot != nil {
				pool = snapshot.Pool
			}

			return nil, fmt.Errorf("pool %s not found for %s: %w", pool, id, err)
		default:
			return nil, fmt.Errorf("failed to get snapshot from id %q: %w", id, err)
		}
	}

	return snapshot, nil
}

// GetVolumeGroupByID looks up the volume group that is identified by the CSI
// handle id. The lookup is delegated to rbd_group.GetVolumeGroup(), with the
// manager's driver instance and (cached) credentials.
func (mgr *rbdManager) GetVolumeGroupByID(ctx context.Context, id string) (types.VolumeGroup, error) {
	credentials, err := mgr.getCredentials()
	if err != nil {
		return nil, err
	}

	group, lookupErr := rbd_group.GetVolumeGroup(ctx, id, mgr.driverInstance, credentials, mgr)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to get volume group with id %q: %w", id, lookupErr)
	}

	return group, nil
}

// MakeVolumeGroupID builds the CSI handle for a volume group from the
// cluster-id in the manager's parameters, the given poolID and name, and the
// configured volume group name prefix.
func (mgr *rbdManager) MakeVolumeGroupID(ctx context.Context, poolID int64, name string) (string, error) {
	clusterID, err := util.GetClusterID(mgr.parameters)
	if err != nil {
		return "", fmt.Errorf("failed to get cluster-id: %w", err)
	}

	prefix := mgr.getVolumeGroupNamePrefix()

	// combine the clusterid, poolid and name into an id/handle
	handle, makeErr := journal.MakeVolumeGroupID(clusterID, poolID, name, prefix)
	if makeErr != nil {
		return "", fmt.Errorf("failed to convert name %q to a CSI-handle: %w", name, makeErr)
	}

	return handle, nil
}

// CreateVolumeGroup creates the volume group with the given request name, or
// returns the group that was reserved for that name before.
//
// The following steps are taken:
//   - read the cluster-id, "pool" (required) and "journalPool" (optional,
//     defaults to "pool") from the manager's parameters
//   - check the journal for an existing reservation of name, and reserve a
//     new UUID if there is none
//   - generate the CSI handle from the pool, cluster-id and UUID
//   - get the volume group for the handle and call Create() on it
//
// A reservation that was made by this call is undone again when a later step
// fails. The volume group object is destroyed when Create() fails.
func (mgr *rbdManager) CreateVolumeGroup(ctx context.Context, name string) (types.VolumeGroup, error) {
	creds, err := mgr.getCredentials()
	if err != nil {
		return nil, err
	}

	clusterID, err := util.GetClusterID(mgr.parameters)
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster-id: %w", err)
	}

	vgJournal, err := mgr.getVolumeGroupJournal(clusterID)
	if err != nil {
		return nil, err
	}

	// pool is a required parameter
	pool, ok := mgr.parameters["pool"]
	if !ok || pool == "" {
		return nil, errors.New("required 'pool' option missing in volume group parameters")
	}

	// journalPool is an optional parameter, use pool if it is not set
	journalPool, ok := mgr.parameters["journalPool"]
	if !ok || journalPool == "" {
		journalPool = pool
	}

	// volumeGroupNamePrefix is an optional parameter, can be an empty string
	prefix := mgr.getVolumeGroupNamePrefix()

	// check if the journal contains a generated name for the group already
	vgData, err := vgJournal.CheckReservation(ctx, journalPool, name, prefix)
	if err != nil {
		return nil, fmt.Errorf("failed to reserve volume group for name %q: %w", name, err)
	}

	var uuid string
	if vgData != nil && vgData.GroupUUID != "" {
		uuid = vgData.GroupUUID
	} else {
		log.DebugLog(ctx, "the journal does not contain a reservation for a volume group with name %q yet", name)

		var vgName string
		uuid, vgName, err = vgJournal.ReserveName(ctx, journalPool, name, uuid, prefix)
		if err != nil {
			return nil, fmt.Errorf("failed to reserve volume group for name %q: %w", name, err)
		}
		defer func() {
			// err is the function-wide error variable, it is non-nil
			// when one of the steps below failed
			if err == nil {
				return
			}

			// the reservation was made in journalPool, so that is where
			// it needs to be removed from. A dedicated variable is used
			// for the result, so that err is not overwritten.
			undoErr := vgJournal.UndoReservation(ctx, journalPool, vgName, name)
			if undoErr != nil {
				log.ErrorLog(ctx, "failed to undo the reservation for volume group %q: %v", name, undoErr)
			}
		}()
	}

	monitors, err := util.Mons(util.CsiConfigFile, clusterID)
	if err != nil {
		return nil, fmt.Errorf("failed to find MONs for cluster %q: %w", clusterID, err)
	}

	_ /*journalPoolID*/, poolID, err := util.GetPoolIDs(ctx, monitors, journalPool, pool, creds)
	if err != nil {
		return nil, fmt.Errorf("failed to generate a unique CSI volume group with uuid for %q: %w", uuid, err)
	}

	csiID, err := util.GenerateVolID(ctx, monitors, creds, poolID, pool, clusterID, uuid)
	if err != nil {
		return nil, fmt.Errorf("failed to generate a unique CSI volume group with uuid for %q: %w", uuid, err)
	}

	vg, err := rbd_group.GetVolumeGroup(ctx, csiID, mgr.driverInstance, creds, mgr)
	if err != nil {
		return nil, fmt.Errorf("failed to get volume group %q at cluster %q: %w", name, clusterID, err)
	}
	defer func() {
		// free the volume group object in case Create() failed
		if err != nil {
			vg.Destroy(ctx)
		}
	}()

	err = vg.Create(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to create volume group %q: %w", name, err)
	}

	return vg, nil
}

// GetVolumeGroupSnapshotByID looks up the volume group snapshot that is
// identified by the CSI handle id. The lookup is delegated to
// rbd_group.GetVolumeGroupSnapshot(), with the manager's driver instance and
// (cached) credentials.
func (mgr *rbdManager) GetVolumeGroupSnapshotByID(
	ctx context.Context,
	id string,
) (types.VolumeGroupSnapshot, error) {
	creds, err := mgr.getCredentials()
	if err != nil {
		return nil, err
	}

	vgs, err := rbd_group.GetVolumeGroupSnapshot(ctx, id, mgr.driverInstance, creds, mgr)
	if err != nil {
		// name the object that was actually looked up, a volume group
		// snapshot and not a volume group
		return nil, fmt.Errorf("failed to get volume group snapshot with id %q: %w", id, err)
	}

	return vgs, nil
}

// GetVolumeGroupSnapshotByName returns the volume group snapshot that belongs
// to the request name. The "pool" and cluster-id are read from the manager's
// parameters, the UUID for the name comes from getGroupUUID().
//
// An error is returned when the volume group snapshot can not be found, or
// when it does not contain any snapshots. The cleanup function of
// getGroupUUID() is called when the function-wide err is set at return time.
func (mgr *rbdManager) GetVolumeGroupSnapshotByName(
	ctx context.Context,
	name string,
) (types.VolumeGroupSnapshot, error) {
	// a missing key yields the empty string too
	pool := mgr.parameters["pool"]
	if pool == "" {
		return nil, errors.New("required 'pool' option missing in volume group parameters")
	}

	clusterID, err := util.GetClusterID(mgr.parameters)
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster-id: %w", err)
	}

	uuid, releaseUUID, err := mgr.getGroupUUID(ctx, clusterID, pool, name)
	if err != nil {
		return nil, fmt.Errorf("failed to get a UUID for volume group snapshot %q: %w", name, err)
	}
	defer func() {
		// only undo the reservation when a later step stored an error
		if err != nil {
			releaseUUID()
		}
	}()

	mons, err := util.Mons(util.CsiConfigFile, clusterID)
	if err != nil {
		return nil, fmt.Errorf("failed to find MONs for cluster %q: %w", clusterID, err)
	}

	// the same pool is passed as journal pool, only the poolID is needed
	_, poolID, err := util.GetPoolIDs(ctx, mons, pool, pool, mgr.creds)
	if err != nil {
		return nil, fmt.Errorf("failed to get the pool for volume group snapshot with uuid for %q: %w", uuid, err)
	}

	handle, err := util.GenerateVolID(ctx, mons, mgr.creds, poolID, pool, clusterID, uuid)
	if err != nil {
		return nil, fmt.Errorf("failed to generate a unique CSI volume group with uuid %q: %w", uuid, err)
	}

	group, err := rbd_group.GetVolumeGroupSnapshot(ctx, handle, mgr.driverInstance, mgr.creds, mgr)
	if err != nil {
		return nil, fmt.Errorf("failed to get existing volume group snapshot with uuid %q: %w", uuid, err)
	}

	members, err := group.ListSnapshots(ctx)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to get snapshots for volume group snapshot %q: %w", group, err)
	case len(members) == 0:
		// err is nil here, the deferred cleanup does not run for this case
		return nil, fmt.Errorf("volume group snapshot %q is incomplete, it has no snapshots", group)
	}

	return group, nil
}

// CreateVolumeGroupSnapshot creates a volume group snapshot with the request
// name for the volumes in vg, or returns an already existing one.
//
// The pool and cluster-id are taken from vg. A UUID for the name is obtained
// through getGroupUUID(), and the CSI handle is generated from it. If a
// volume group snapshot with that handle exists and has as many snapshots as
// vg has volumes, it is returned. Otherwise snapshots of the volumes in vg
// are created and combined into a new volume group snapshot.
//
// The deferred cleanups inspect the function-wide err variable at return
// time: when it is set, the created snapshots are deleted and the cleanup
// function of getGroupUUID() is called.
func (mgr *rbdManager) CreateVolumeGroupSnapshot(
	ctx context.Context,
	vg types.VolumeGroup,
	name string,
) (types.VolumeGroupSnapshot, error) {
	pool, err := vg.GetPool(ctx)
	if err != nil {
		return nil, err
	}

	clusterID, err := vg.GetClusterID(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get cluster id for volume group snapshot %q: %w", vg, err)
	}

	// the pool of the volume group is passed as journal pool
	uuid, freeUUID, err := mgr.getGroupUUID(ctx, clusterID, pool, name)
	if err != nil {
		return nil, fmt.Errorf("failed to get a UUID for volume group snapshot %q: %w", vg, err)
	}
	defer func() {
		// no error, no need to undo the reservation
		if err == nil {
			return
		}

		freeUUID()
	}()

	monitors, err := util.Mons(util.CsiConfigFile, clusterID)
	if err != nil {
		return nil, fmt.Errorf("failed to find MONs for cluster %q: %w", clusterID, err)
	}

	// mgr.creds is used directly from here on, it is expected to have been
	// set while getGroupUUID() obtained the journal
	_ /*journalPoolID*/, poolID, err := util.GetPoolIDs(ctx, monitors, pool, pool, mgr.creds)
	if err != nil {
		return nil, fmt.Errorf("failed to get PoolID for %q: %w", pool, err)
	}

	groupID, err := util.GenerateVolID(ctx, monitors, mgr.creds, poolID, pool, clusterID, uuid)
	if err != nil {
		return nil, fmt.Errorf("failed to generate a unique CSI volume group with uuid for %q: %w", uuid, err)
	}

	// check for an existing volume group snapshot first, err is inspected
	// in the else-branch below
	vgs, err := rbd_group.GetVolumeGroupSnapshot(ctx, groupID, mgr.driverInstance, mgr.creds, mgr)
	if vgs != nil {
		log.DebugLog(ctx, "found existing volume group snapshot %q for id %q", vgs, groupID)

		// validate the contents of the vgs
		// NOTE: failures in this branch are kept in vgsErr/vgErr and do not
		// touch err, so the deferred freeUUID() does not act on them
		snapshots, vgsErr := vgs.ListSnapshots(ctx)
		if vgsErr != nil {
			return nil, fmt.Errorf("failed to list snapshots of existing volume group snapshot %q: %w", vgs, vgsErr)
		}

		volumes, vgErr := vg.ListVolumes(ctx)
		if vgErr != nil {
			return nil, fmt.Errorf("failed to list volumes of volume group %q: %w", vg, vgErr)
		}

		// return the existing vgs if the contents matches
		// TODO: improve contents verification, length is a very minimal check
		if len(snapshots) == len(volumes) {
			log.DebugLog(ctx, "existing volume group snapshot %q contains %d snapshots", vgs, len(snapshots))

			return vgs, nil
		}
		// the number of snapshots does not match, continue below and create
		// the snapshots for the volume group (again)
	} else if err != nil && !errors.Is(err, rbderrors.ErrImageNotFound) {
		// ErrImageNotFound can be returned if the VolumeGroupSnapshot
		// could not be found. It is expected that it does not exist
		// yet, in which case it will be created below.
		return nil, fmt.Errorf("failed to check for existing volume group snapshot with id %q: %w", groupID, err)
	}

	// this assignment replaces a possible ErrImageNotFound in err
	snapshots, err := vg.CreateSnapshots(ctx, mgr.creds, groupID)
	if err != nil {
		return nil, fmt.Errorf("failed to create volume group snapshot %q: %w", name, err)
	}
	defer func() {
		// cleanup created snapshots in case there was a failure
		if err == nil {
			return
		}

		for _, snap := range snapshots {
			delErr := snap.Delete(ctx)
			if delErr != nil {
				log.ErrorLog(ctx, "failed to delete snapshot %q: %v", snap, delErr)
			}
		}
	}()

	log.DebugLog(ctx, "volume group snapshot %q contains %d snapshots: %v", name, len(snapshots), snapshots)

	// combine the created snapshots into the volume group snapshot
	vgs, err = rbd_group.NewVolumeGroupSnapshot(ctx, groupID, mgr.driverInstance, mgr.creds, snapshots)
	if err != nil {
		return nil, fmt.Errorf("failed to create new volume group snapshot %q: %w", name, err)
	}

	log.DebugLog(ctx, "volume group snapshot %q has been created", vgs)

	return vgs, nil
}

// RegenerateVolumeGroupJournal regenerates the omap data for the volume group.
// This performs the following operations:
//   - extracts clusterID and Mons from the cluster mapping
//   - Retrieves pool and journalPool parameters from the manager's parameters
//   - Reserves omap data, unless a reservation for requestName exists already
//   - Add volumeIDs mapping to the reserved volume group omap object
//   - Generate new volume group handle
//
// Returns the generated volume group handle.
//
// A reservation that was made by this call is undone again when a later step
// fails. The volumes that were looked up are destroyed before returning, and
// so is the journal that was used.
//
// Note: The new volume group handle will differ from the original as it includes
// poolID and clusterID, which vary between clusters.
func (mgr *rbdManager) RegenerateVolumeGroupJournal(
	ctx context.Context,
	groupID, requestName string,
	volumeIds []string,
) (string, error) {
	var (
		clusterID   string
		monitors    string
		pool        string
		journalPool string
		namePrefix  string
		groupUUID   string
		vgName      string

		gi  util.CSIIdentifier
		ok  bool
		err error
	)

	err = gi.DecomposeCSIID(groupID)
	if err != nil {
		return "", fmt.Errorf("%w: error decoding volume group ID (%w) (%s)", rbderrors.ErrInvalidVolID, err, groupID)
	}

	monitors, clusterID, err = util.FetchMappedClusterIDAndMons(ctx, gi.ClusterID)
	if err != nil {
		return "", err
	}

	// pool is a required parameter, an empty value is as good as missing
	pool, ok = mgr.parameters["pool"]
	if !ok || pool == "" {
		return "", errors.New("required 'pool' parameter missing in parameters")
	}

	// journalPool is an optional parameter, use pool if it is not set
	journalPool, ok = mgr.parameters["journalPool"]
	if !ok || journalPool == "" {
		journalPool = pool
	}

	vgJournal, err := mgr.getVolumeGroupJournal(clusterID)
	if err != nil {
		return "", err
	}
	defer func() {
		// The journal is cached in the manager. Drop the cached reference
		// together with the journal, so that the manager does not hold on
		// to (and later use or free again) a destroyed journal.
		vgJournal.Destroy()
		mgr.vgJournal = nil
	}()

	namePrefix = mgr.getVolumeGroupNamePrefix()
	vgData, err := vgJournal.CheckReservation(ctx, journalPool, requestName, namePrefix)
	if err != nil {
		return "", err
	}

	if vgData != nil {
		groupUUID = vgData.GroupUUID
		vgName = vgData.GroupName
	} else {
		log.DebugLog(ctx, "the journal does not contain a reservation for a volume group with name %q yet", requestName)
		groupUUID, vgName, err = vgJournal.ReserveName(ctx, journalPool, requestName, gi.ObjectUUID, namePrefix)
		if err != nil {
			return "", fmt.Errorf("failed to reserve volume group for name %q: %w", requestName, err)
		}
		defer func() {
			// err is the function-wide error variable, it is non-nil
			// when one of the steps below failed
			if err != nil {
				undoError := vgJournal.UndoReservation(ctx, journalPool, vgName, requestName)
				if undoError != nil {
					log.ErrorLog(ctx, "failed to undo the reservation for volume group %q: %v", requestName, undoError)
				}
			}
		}()
	}

	// Only volumes that were found are added to the list. The deferred
	// cleanup therefore never calls Destroy() on a nil entry when a lookup
	// fails halfway through.
	volumes := make([]types.Volume, 0, len(volumeIds))
	defer func() {
		for _, v := range volumes {
			v.Destroy(ctx)
		}
	}()
	for _, id := range volumeIds {
		var volume types.Volume
		volume, err = mgr.GetVolumeByID(ctx, id)
		if err != nil {
			return "", fmt.Errorf("failed to find required volume %q for volume group id %q: %w", id, vgName, err)
		}

		volumes = append(volumes, volume)
	}

	var volID string
	for _, vol := range volumes {
		volID, err = vol.GetID(ctx)
		if err != nil {
			return "", fmt.Errorf("failed to get VolumeID for %q: %w", vol, err)
		}

		toAdd := map[string]string{
			volID: "",
		}
		log.DebugLog(ctx, "adding volume mapping for volume %q to volume group %q", volID, vgName)
		err = vgJournal.AddVolumesMapping(ctx, pool, gi.ObjectUUID, toAdd)
		if err != nil {
			return "", fmt.Errorf("failed to add mapping for volume %q to volume group %q: %w", volID, vgName, err)
		}
	}

	_, poolID, err := util.GetPoolIDs(ctx, monitors, journalPool, pool, mgr.creds)
	if err != nil {
		return "", fmt.Errorf("failed to get poolID for %q: %w", groupUUID, err)
	}

	groupHandle, err := util.GenerateVolID(ctx, monitors, mgr.creds, poolID, pool, clusterID, groupUUID)
	if err != nil {
		return "", fmt.Errorf("failed to generate a unique CSI volume group with uuid for %q: %w", groupUUID, err)
	}

	log.DebugLog(ctx, "re-generated Group ID (%s) and Group Name (%s) for request name (%s)",
		groupHandle, vgName, requestName)

	return groupHandle, nil
}

// CompareVolumesInGroup returns 'true' when the list of volumes matches the
// volumes in the group. This is the case when all volumes are either in no
// group at all, or in the group vg. An empty list of volumes matches an empty
// group.
//
// 'false' without an error is returned when the volumes are in an other group
// than vg. 'false' with an error is returned when the group is not empty and
// contains a different number of volumes than the list, when the volumes are
// spread over different groups, or when the group information can not be
// retrieved.
func (mgr *rbdManager) CompareVolumesInGroup(
	ctx context.Context,
	volumes []types.Volume,
	vg types.VolumeGroup,
) (bool, error) {
	vgVols, err := vg.ListVolumes(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to list volumes in group %q: %w", vg, err)
	}

	// the vg is allowed to be empty, or have the exact number of volumes
	if !(len(vgVols) == 0 || len(vgVols) == len(volumes)) {
		return false, fmt.Errorf(
			"volume group %q has more or less volumes (%d) than expected (%d)",
			vg,
			len(vgVols),
			len(volumes))
	}

	vgID, err := vg.GetID(ctx)
	if err != nil {
		return false, fmt.Errorf("failed to get name for volume group %q: %w", vg, err)
	}

	// verify that all volumes are part of the vg, or do not have a group at all
	matchingGroup, err := mgr.VolumesInSameGroup(ctx, volumes)
	if err != nil {
		return false, err
	} else if !matchingGroup {
		return false, nil
	}

	// Without volumes there is no group to compare with. The size check
	// above guarantees that the vg is empty as well, so the contents match.
	if len(volumes) == 0 {
		return true, nil
	}

	// all volumes are in the same group
	groupID, err := volumes[0].GetVolumeGroupID(ctx, mgr)
	if err != nil && !errors.Is(err, rbderrors.ErrGroupNotFound) {
		return false, fmt.Errorf("failed to get group for volume %q: %w", volumes[0], err)
	}

	// if none of the volumes is in a group, groupID will be ""
	if groupID != "" && vgID != groupID {
		log.DebugLog(ctx, "expecting group %q but volume %q has group %q", vgID, volumes[0], groupID)

		return false, nil
	}

	return true, nil
}

// VolumesInSameGroup reports whether every volume in the list has the same
// volume group ID. A volume for which rbderrors.ErrGroupNotFound is returned
// is compared with the ID that came along with that error.
//
// 'true' is returned when all IDs are equal, which includes an empty list.
// When a volume has a different ID than the volume before it, or when an ID
// can not be retrieved, 'false' and an error are returned.
func (mgr *rbdManager) VolumesInSameGroup(ctx context.Context, volumes []types.Volume) (bool, error) {
	var (
		// expected is the group ID of the previous volume
		expected string
		// seen is set once the first volume has been inspected
		seen bool
	)

	for _, vol := range volumes {
		groupID, err := vol.GetVolumeGroupID(ctx, mgr)
		if err != nil && !errors.Is(err, rbderrors.ErrGroupNotFound) {
			return false, fmt.Errorf("failed to get group name for volume %q: %w", vol, err)
		}

		// every volume after the first one needs to match its predecessor
		if seen && groupID != expected {
			return false, fmt.Errorf("volume %q belongs to group %q, but expected %q", vol, groupID, expected)
		}

		expected, seen = groupID, true
	}

	return true, nil
}

// getCredentials returns the credentials for the Ceph cluster. They are
// created from the secrets of the manager on the first call and cached in
// mgr.creds for later calls.
func (mgr *rbdManager) getCredentials() (*util.Credentials, error) {
	if mgr.creds == nil {
		fresh, err := util.NewUserCredentials(mgr.secrets)
		if err != nil {
			return nil, fmt.Errorf("failed to get credentials: %w", err)
		}

		mgr.creds = fresh
	}

	return mgr.creds, nil
}

// getVolumeGroupNamePrefix looks up the "volumeGroupNamePrefix" parameter. The
// empty string is returned when the parameter is not set.
func (mgr *rbdManager) getVolumeGroupNamePrefix() string {
	if prefix, found := mgr.parameters["volumeGroupNamePrefix"]; found {
		return prefix
	}

	return ""
}

// getVolumeGroupJournal returns the journal for volume groups. The cached
// journal is returned when there is one, clusterID is not inspected in that
// case. Otherwise the MONs and the RADOS namespace for clusterID are looked
// up, a connection to the journal is made with the manager's credentials, and
// the result is cached in mgr.vgJournal.
func (mgr *rbdManager) getVolumeGroupJournal(clusterID string) (journal.VolumeGroupJournal, error) {
	if cached := mgr.vgJournal; cached != nil {
		return cached, nil
	}

	credentials, err := mgr.getCredentials()
	if err != nil {
		return nil, err
	}

	mons, err := util.Mons(util.CsiConfigFile, clusterID)
	if err != nil {
		return nil, fmt.Errorf("failed to find MONs for cluster %q: %w", clusterID, err)
	}

	radosNamespace, err := util.GetRBDRadosNamespace(util.CsiConfigFile, clusterID)
	if err != nil {
		return nil, fmt.Errorf("failed to find the RADOS namespace for cluster %q: %w", clusterID, err)
	}

	config := journal.NewCSIVolumeGroupJournalWithNamespace(mgr.driverInstance, radosNamespace)

	connected, err := config.Connect(mons, radosNamespace, credentials)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to journal: %w", err)
	}

	// keep the connection around, it gets reused by later calls
	mgr.vgJournal = connected

	return connected, nil
}

// getGroupUUID checks if a UUID in the volume group journal is already
// reserved. If none is reserved, a new reservation is made. Upon exit of
// getGroupUUID, the function returns:
//  1. the UUID that was reserved
//  2. an undo() function that reverts the reservation, should be called in a
//     defer when further actions failed. On failure of getGroupUUID this is a
//     function that does nothing.
//  3. an error or nil.
//
// Note that the returned undo() function reverts the reservation for name
// regardless of whether it was found in the journal or newly made.
func (mgr *rbdManager) getGroupUUID(
	ctx context.Context,
	clusterID, journalPool, name string,
) (string, func(), error) {
	nothingToUndo := func() {
		// the reservation was not done, no need to undo the reservation
	}

	prefix := mgr.getVolumeGroupNamePrefix()

	vgJournal, err := mgr.getVolumeGroupJournal(clusterID)
	if err != nil {
		return "", nothingToUndo, err
	}

	vgsData, err := vgJournal.CheckReservation(ctx, journalPool, name, prefix)
	if err != nil {
		return "", nothingToUndo, fmt.Errorf("failed to check reservation for group %q: %w", name, err)
	}

	var uuid string
	if vgsData != nil && vgsData.GroupUUID != "" {
		uuid = vgsData.GroupUUID
	} else {
		log.DebugLog(ctx, "the journal does not contain a reservation for group %q yet", name)

		uuid, _ /*vgsName*/, err = vgJournal.ReserveName(ctx, journalPool, name, uuid, prefix)
		if err != nil {
			return "", nothingToUndo, fmt.Errorf("failed to reserve a UUID for group %q: %w", name, err)
		}
	}

	log.DebugLog(ctx, "got UUID %q for group %q", uuid, name)

	// undo contains the cleanup that should be done by the caller when the
	// reservation was made, and further actions fulfilling the final
	// request failed
	undo := func() {
		// the closure keeps its own error variable, it runs after
		// getGroupUUID returned
		undoErr := vgJournal.UndoReservation(ctx, journalPool, uuid, name)
		if undoErr != nil {
			log.ErrorLog(ctx, "failed to undo the reservation for group %q: %v", name, undoErr)
		}
	}

	return uuid, undo, nil
}
