package vulnerability

import (
	"testing"

	"github.com/aquasecurity/trivy-db/pkg/db"

	"github.com/stretchr/testify/assert"

	dbTypes "github.com/aquasecurity/trivy-db/pkg/types"
	"github.com/aquasecurity/trivy/pkg/types"
)

func TestFillAndFilter(t *testing.T) {
	detectedVulns := []types.DetectedVulnerability{
		{
			VulnerabilityID: "foo",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityHigh],
			},
		},
		{
			VulnerabilityID: "piyo",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityCritical],
			},
		},
		{
			VulnerabilityID: "bar",
			PkgName:         "barpkg",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityLow],
			},
		},
		{
			VulnerabilityID: "hoge",
		},
		{
			VulnerabilityID: "baz",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityMedium],
			},
		},
	}

	severities := []dbTypes.Severity{dbTypes.SeverityLow, dbTypes.SeverityCritical,
		dbTypes.SeverityMedium, dbTypes.SeverityHigh, dbTypes.SeverityUnknown}

	mockDBConfig := new(db.MockDBConfig)
	getVulnerability := map[string]dbTypes.Vulnerability{
		"foo": {
			Title:       "footitle",
			Description: "foodesc",
			Severity:    dbTypes.SeverityHigh.String(),
			References:  []string{"fooref"},
		},
		"bar": {
			Title:       "bartitle",
			Description: "bardesc",
			Severity:    dbTypes.SeverityLow.String(),
			References:  []string{"barref"},
		},
		"baz": {
			Title:       "baztitle",
			Description: "bazdesc",
			Severity:    dbTypes.SeverityMedium.String(),
			References:  []string{"bazref"},
		},
		"piyo": {
			Title:       "piyotitle",
			Description: "piyodesc",
			Severity:    dbTypes.SeverityCritical.String(),
			References:  []string{"piyoref"},
		},
		"hoge": {
			Title:       "hogetitle",
			Description: "hogedesc",
			Severity:    dbTypes.SeverityUnknown.String(),
			References:  []string{"hogeref"},
		},
	}

	for pkgName, vulnerability := range getVulnerability {
		mockDBConfig.On("GetVulnerability", pkgName).Return(vulnerability, nil)

	}
	getSeverity := map[string]dbTypes.Severity{
		"foo":  dbTypes.SeverityHigh,
		"bar":  dbTypes.SeverityLow,
		"baz":  dbTypes.SeverityMedium,
		"piyo": dbTypes.SeverityCritical,
		"hoge": dbTypes.SeverityUnknown,
	}

	for pkgName, severity := range getSeverity {
		mockDBConfig.On("GetSeverity", pkgName).Return(severity, nil)
	}

	expected := []types.DetectedVulnerability{
		{
			VulnerabilityID: "piyo",
			Vulnerability: dbTypes.Vulnerability{
				Title:       "piyotitle",
				Description: "piyodesc",
				Severity:    dbTypes.SeverityNames[dbTypes.SeverityCritical],
				References:  []string{"piyoref"},
			},
		},
		{
			VulnerabilityID: "foo",
			Vulnerability: dbTypes.Vulnerability{
				Title:       "footitle",
				Description: "foodesc",
				Severity:    dbTypes.SeverityNames[dbTypes.SeverityHigh],
				References:  []string{"fooref"},
			},
		},
		{
			VulnerabilityID: "baz",
			Vulnerability: dbTypes.Vulnerability{
				Title:       "baztitle",
				Description: "bazdesc",
				Severity:    dbTypes.SeverityNames[dbTypes.SeverityMedium],
				References:  []string{"bazref"},
			},
		},
		{
			VulnerabilityID: "hoge",
			Vulnerability: dbTypes.Vulnerability{
				Title:       "hogetitle",
				Description: "hogedesc",
				Severity:    dbTypes.SeverityNames[dbTypes.SeverityUnknown],
				References:  []string{"hogeref"},
			},
		},
		{
			VulnerabilityID: "bar",
			PkgName:         "barpkg",
			Vulnerability: dbTypes.Vulnerability{
				Title:       "bartitle",
				Description: "bardesc",
				Severity:    dbTypes.SeverityNames[dbTypes.SeverityLow],
				References:  []string{"barref"},
			},
		},
	}

	client := Client{
		dbc: mockDBConfig,
	}
	actual := client.FillAndFilter(detectedVulns, severities, false, ".trivyignore", false)
	assert.Equal(t, expected, actual, "full db")

	expected = []types.DetectedVulnerability{
		{
			VulnerabilityID: "piyo",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityCritical],
			},
		},
		{
			VulnerabilityID: "foo",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityHigh],
			},
		},
		{
			VulnerabilityID: "baz",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityMedium],
			},
		},
		{
			VulnerabilityID: "hoge",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityUnknown],
			},
		},
		{
			VulnerabilityID: "bar",
			PkgName:         "barpkg",
			Vulnerability: dbTypes.Vulnerability{
				Severity: dbTypes.SeverityNames[dbTypes.SeverityLow],
			},
		},
	}

	actual = client.FillAndFilter(detectedVulns, severities, false, ".trivyignore", true)
	assert.Equal(t, expected, actual, "light db")
}
