diff --git a/v2/brokers/sqs/sqs.go b/v2/brokers/sqs/sqs.go index b8cdb045..c09c5e04 100644 --- a/v2/brokers/sqs/sqs.go +++ b/v2/brokers/sqs/sqs.go @@ -200,6 +200,13 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor return errors.New("received empty message, the delivery is " + delivery.GoString()) } + if sqsConfig := b.GetConfig().SQS; sqsConfig != nil && sqsConfig.VisibilityHeartBeat { + notify := make(chan struct{}) + defer close(notify) + + b.visibilityHeartbeat(delivery, notify) + } + sig := new(tasks.Signature) decoder := json.NewDecoder(strings.NewReader(*delivery.Messages[0].Body)) decoder.UseNumber() @@ -219,7 +226,9 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor // and leave the message in the queue if !b.IsTaskRegistered(sig.Name) { if sig.IgnoreWhenTaskNotRegistered { - b.deleteOne(delivery) + if err := b.deleteOne(delivery); err != nil { + log.ERROR.Printf("error when deleting the delivery. delivery is %v, Error=%s", delivery, err) + } } return fmt.Errorf("task %s is not registered", sig.Name) } @@ -227,7 +236,7 @@ func (b *Broker) consumeOne(delivery *awssqs.ReceiveMessageOutput, taskProcessor err := taskProcessor.Process(sig) if err != nil { // stop task deletion in case we want to send messages to dlq in sqs - if err == errs.ErrStopTaskDeletion { + if errors.Is(err, errs.ErrStopTaskDeletion) { return nil } return err @@ -270,9 +279,8 @@ func (b *Broker) receiveMessage(qURL *string) (*awssqs.ReceiveMessageOutput, err if b.GetConfig().SQS != nil { waitTimeSeconds = b.GetConfig().SQS.WaitTimeSeconds visibilityTimeout = b.GetConfig().SQS.VisibilityTimeout - } else { - waitTimeSeconds = 0 } + input := &awssqs.ReceiveMessageInput{ AttributeNames: []*string{ aws.String(awssqs.MessageSystemAttributeNameSentTimestamp), @@ -350,6 +358,40 @@ func (b *Broker) continueReceivingMessages(qURL *string, deliveries chan *awssqs return true, nil } +// visibilityHeartbeat starts a background goroutine that extends the visibility timeout of the delivered message every half SQS.VisibilityTimeout, keeping the message 
invisible to other consumers while it is being processed. It does nothing when SQS.VisibilityTimeout is unset or not positive; the goroutine exits when notify is closed or the broker stops receiving. +func (b *Broker) visibilityHeartbeat(delivery *awssqs.ReceiveMessageOutput, notify <-chan struct{}) { + if b.GetConfig().SQS.VisibilityTimeout == nil || *b.GetConfig().SQS.VisibilityTimeout <= 0 { + return + } + + ticker := time.NewTicker(time.Duration(*b.GetConfig().SQS.VisibilityTimeout) * 500 * time.Millisecond) + + go func() { + for { + select { + case <-notify: + ticker.Stop() + + return + case <-b.stopReceivingChan: + ticker.Stop() + + return + case <-ticker.C: + // Extend the delivery visibility timeout + _, err := b.service.ChangeMessageVisibility(&awssqs.ChangeMessageVisibilityInput{ + QueueUrl: b.defaultQueueURL(), + ReceiptHandle: delivery.Messages[0].ReceiptHandle, + VisibilityTimeout: aws.Int64(int64(*b.GetConfig().SQS.VisibilityTimeout)), + }) + if err != nil { + log.ERROR.Printf("Error when changing delivery visibility: %v", err) + } + } + } + }() +} + // stopReceiving is a method sending a signal to stopReceivingChan func (b *Broker) stopReceiving() { // Stop the receiving goroutine diff --git a/v2/brokers/sqs/sqs_export_test.go b/v2/brokers/sqs/sqs_export_test.go index 8bcd8d62..57f25efe 100644 --- a/v2/brokers/sqs/sqs_export_test.go +++ b/v2/brokers/sqs/sqs_export_test.go @@ -7,15 +7,13 @@ import ( "os" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sqs/sqsiface" - "github.com/RichardKnop/machinery/v2/brokers/iface" "github.com/RichardKnop/machinery/v2/common" "github.com/RichardKnop/machinery/v2/config" - + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" awssqs "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" ) var ( @@ -107,17 +105,23 @@ func NewTestConfig() *config.Config { DefaultQueue: "test_queue", ResultBackend: fmt.Sprintf("redis://%v", redisURL), Lock: fmt.Sprintf("redis://%v", redisURL), + SQS: &config.SQSConfig{ + VisibilityTimeout: aws.Int(30), 
+ }, } } -func NewTestBroker() *Broker { +func NewTestBroker(cnf *config.Config) *Broker { - cnf := NewTestConfig() sess := session.Must(session.NewSessionWithOptions(session.Options{ SharedConfigState: session.SharedConfigEnable, })) - svc := new(FakeSQS) + var svc sqsiface.SQSAPI = new(FakeSQS) + + if cnf.SQS != nil && cnf.SQS.Client != nil { + svc = cnf.SQS.Client + } return &Broker{ Broker: common.NewBroker(cnf), sess: sess, diff --git a/v2/brokers/sqs/sqs_test.go b/v2/brokers/sqs/sqs_test.go index 797c5de2..9d17872b 100644 --- a/v2/brokers/sqs/sqs_test.go +++ b/v2/brokers/sqs/sqs_test.go @@ -7,15 +7,19 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/stretchr/testify/assert" - "github.com/RichardKnop/machinery/v2" + eagerbck "github.com/RichardKnop/machinery/v2/backends/eager" + "github.com/RichardKnop/machinery/v2/brokers/eager" "github.com/RichardKnop/machinery/v2/brokers/sqs" "github.com/RichardKnop/machinery/v2/config" + eagerlock "github.com/RichardKnop/machinery/v2/locks/eager" "github.com/RichardKnop/machinery/v2/retry" - + "github.com/aws/aws-sdk-go/aws" awssqs "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) var ( @@ -31,14 +35,14 @@ func init() { func TestNewAWSSQSBroker(t *testing.T) { t.Parallel() - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) assert.IsType(t, broker, sqs.New(cnf)) } func TestPrivateFunc_continueReceivingMessages(t *testing.T) { - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) errorBroker := sqs.NewTestErrorBroker() qURL := broker.DefaultQueueURLForTest() @@ -87,10 +91,7 @@ func TestPrivateFunc_continueReceivingMessages(t *testing.T) { func TestPrivateFunc_consume(t *testing.T) { - server1, err := machinery.NewServer(cnf) - if err != nil { - t.Fatal(err) - } + server1 := machinery.NewServer(cnf, eager.New(), eagerbck.New(), 
eagerlock.New()) pool := make(chan struct{}) wk := server1.NewWorker("sms_worker", 0) deliveries := make(chan *awssqs.ReceiveMessageOutput) @@ -98,29 +99,25 @@ func TestPrivateFunc_consume(t *testing.T) { outputCopy.Messages = []*awssqs.Message{} go func() { deliveries <- &outputCopy }() - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) // an infinite loop will be executed only when there is no error - err = broker.ConsumeForTest(deliveries, 0, wk, pool) + err := broker.ConsumeForTest(deliveries, 0, wk, pool) assert.NotNil(t, err) } func TestPrivateFunc_consumeOne(t *testing.T) { - - server1, err := machinery.NewServer(cnf) - if err != nil { - t.Fatal(err) - } + server1 := machinery.NewServer(cnf, eager.New(), eagerbck.New(), eagerlock.New()) wk := server1.NewWorker("sms_worker", 0) - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) - err = broker.ConsumeOneForTest(receiveMessageOutput, wk) - assert.NotNil(t, err) + err := broker.ConsumeOneForTest(receiveMessageOutput, wk) + assert.Error(t, err) outputCopy := *receiveMessageOutput outputCopy.Messages = []*awssqs.Message{} err = broker.ConsumeOneForTest(&outputCopy, wk) - assert.NotNil(t, err) + assert.Error(t, err) outputCopy.Messages = []*awssqs.Message{ { @@ -128,12 +125,54 @@ func TestPrivateFunc_consumeOne(t *testing.T) { }, } err = broker.ConsumeOneForTest(&outputCopy, wk) - assert.NotNil(t, err) + assert.Error(t, err) +} + +func TestPrivateFunc_consumeOneWithVisibilityHeartBeat(t *testing.T) { + + cfg := sqs.NewTestConfig() + cfg.SQS.VisibilityHeartBeat = true + cfg.SQS.VisibilityTimeout = aws.Int(1) // seconds + + mockClient := new(MockSQSAPI) + + cfg.SQS.Client = mockClient + + broker := sqs.NewTestBroker(cfg) + + server1 := machinery.NewServer(cfg, broker, eagerbck.New(), eagerlock.New()) + + // Task runs for 2.25x the visibility timeout so that exactly four heartbeats (one every half timeout) fire, with a safe margin on both sides. 
+ err := server1.RegisterTask("test-task", func(ctx context.Context) error { + time.Sleep(time.Duration(*cfg.SQS.VisibilityTimeout) * 2250 * time.Millisecond) + + return nil + }) + require.NoError(t, err) + + wk := server1.NewWorker("sms_worker", 0) + // Use a local delivery instead of mutating the shared receiveMessageOutput fixture that other tests consume. + delivery := &awssqs.ReceiveMessageOutput{Messages: []*awssqs.Message{ + { + Body: aws.String(`{"Name": "test-task"}`), + }, + }} + + mockClient.On("ChangeMessageVisibility", mock.AnythingOfType("*sqs.ChangeMessageVisibilityInput")).Return(&awssqs.ChangeMessageVisibilityOutput{}, nil) + mockClient.On("DeleteMessage", mock.AnythingOfType("*sqs.DeleteMessageInput")).Return(&awssqs.DeleteMessageOutput{}, nil) + + err = broker.ConsumeOneForTest(delivery, wk) + assert.NoError(t, err) + + time.Sleep(time.Duration(*cfg.SQS.VisibilityTimeout) * time.Second) + + // Four heartbeats while the task ran and none after ConsumeOneForTest returned (the sleep above would expose a heartbeat that kept running). + mockClient.AssertNumberOfCalls(t, "ChangeMessageVisibility", 4) } func TestPrivateFunc_initializePool(t *testing.T) { - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) concurrency := 9 pool := make(chan struct{}, concurrency) @@ -143,13 +182,10 @@ func TestPrivateFunc_initializePool(t *testing.T) { func TestPrivateFunc_startConsuming(t *testing.T) { - server1, err := machinery.NewServer(cnf) - if err != nil { - t.Fatal(err) - } + server1 := machinery.NewServer(cnf, eager.New(), eagerbck.New(), eagerlock.New()) wk := server1.NewWorker("sms_worker", 0) - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) retryFunc := broker.GetRetryFuncForTest() stopChan := broker.GetStopChanForTest() @@ -164,7 +200,7 @@ func TestPrivateFunc_startConsuming(t *testing.T) { func TestPrivateFuncDefaultQueueURL(t *testing.T) { - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) qURL := broker.DefaultQueueURLForTest() 
@@ -173,7 +209,7 @@ func TestPrivateFuncDefaultQueueURL(t *testing.T) { func TestPrivateFunc_stopReceiving(t *testing.T) { - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) go 
broker.StopReceivingForTest() @@ -183,7 +219,7 @@ func TestPrivateFunc_stopReceiving(t *testing.T) { func TestPrivateFunc_receiveMessage(t *testing.T) { - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) qURL := broker.DefaultQueueURLForTest() output, err := broker.ReceiveMessageForTest(qURL) @@ -197,13 +233,10 @@ func TestPrivateFunc_consumeDeliveries(t *testing.T) { pool := make(chan struct{}, concurrency) errorsChan := make(chan error) deliveries := make(chan *awssqs.ReceiveMessageOutput) - server1, err := machinery.NewServer(cnf) - if err != nil { - t.Fatal(err) - } + server1 := machinery.NewServer(cnf, eager.New(), eagerbck.New(), eagerlock.New()) wk := server1.NewWorker("sms_worker", 0) - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) go func() { deliveries <- receiveMessageOutput }() whetherContinue, err := broker.ConsumeDeliveriesForTest(deliveries, concurrency, wk, pool, errorsChan) @@ -253,7 +286,7 @@ func TestPrivateFunc_consumeDeliveries(t *testing.T) { func TestPrivateFunc_deleteOne(t *testing.T) { - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) errorBroker := sqs.NewTestErrorBroker() err := broker.DeleteOneForTest(receiveMessageOutput) @@ -265,12 +298,9 @@ func TestPrivateFunc_deleteOne(t *testing.T) { func Test_CustomQueueName(t *testing.T) { - server1, err := machinery.NewServer(cnf) - if err != nil { - t.Fatal(err) - } + server1 := machinery.NewServer(cnf, eager.New(), eagerbck.New(), eagerlock.New()) - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) wk := server1.NewWorker("test-worker", 0) qURL := broker.GetQueueURLForTest(wk) @@ -284,27 +314,25 @@ func Test_CustomQueueName(t *testing.T) { func TestPrivateFunc_consumeWithConcurrency(t *testing.T) { msg := `{ - "UUID": "uuid-dummy-task", - "Name": "test-task", - "RoutingKey": "dummy-routing" - } - ` + "UUID": "uuid-dummy-task", + "Name": "test-task", + "RoutingKey": "dummy-routing" + } + ` testResp := 
"47f8b355-5115-4b45-b33a-439016400411" output := make(chan string) // The output channel cnf.ResultBackend = "eager" - server1, err := machinery.NewServer(cnf) - if err != nil { - t.Fatal(err) - } - err = server1.RegisterTask("test-task", func(ctx context.Context) error { + server1 := machinery.NewServer(cnf, eager.New(), eagerbck.New(), eagerlock.New()) + + err := server1.RegisterTask("test-task", func(ctx context.Context) error { output <- testResp return nil }) - broker := sqs.NewTestBroker() + broker := sqs.NewTestBroker(cnf) broker.SetRegisteredTaskNames([]string{"test-task"}) assert.NoError(t, err) @@ -338,3 +366,25 @@ func TestPrivateFunc_consumeWithConcurrency(t *testing.T) { t.Fatal("task not processed in 10 seconds") } } + +// MockSQSAPI is a testify mock of sqsiface.SQSAPI; calling any method not overridden below panics through the nil embedded interface. +type MockSQSAPI struct { + mock.Mock + + sqsiface.SQSAPI +} + +func (m *MockSQSAPI) ReceiveMessage(input *awssqs.ReceiveMessageInput) (*awssqs.ReceiveMessageOutput, error) { + args := m.Called(input) + return args.Get(0).(*awssqs.ReceiveMessageOutput), args.Error(1) +} + +func (m *MockSQSAPI) DeleteMessage(input *awssqs.DeleteMessageInput) (*awssqs.DeleteMessageOutput, error) { + args := m.Called(input) + return args.Get(0).(*awssqs.DeleteMessageOutput), args.Error(1) +} + +func (m *MockSQSAPI) ChangeMessageVisibility(input *awssqs.ChangeMessageVisibilityInput) (*awssqs.ChangeMessageVisibilityOutput, error) { + args := m.Called(input) + return args.Get(0).(*awssqs.ChangeMessageVisibilityOutput), args.Error(1) +} diff --git a/v2/config/config.go b/v2/config/config.go index b2b96efe..dbdcc684 100644 --- a/v2/config/config.go +++ b/v2/config/config.go @@ -3,12 +3,12 @@ package config import ( "crypto/tls" "fmt" "strings" "time" "cloud.google.com/go/pubsub" "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "go.mongodb.org/mongo-driver/mongo" ) @@ 
-97,11 +97,14 @@ type 
DynamoDBConfig struct { // SQSConfig wraps SQS related configuration type SQSConfig struct { - Client *sqs.SQS + Client sqsiface.SQSAPI WaitTimeSeconds int `yaml:"receive_wait_time_seconds" envconfig:"SQS_WAIT_TIME_SECONDS"` // https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-visibility-timeout.html // visibility timeout should default to nil to use the overall visibility timeout for the queue VisibilityTimeout *int `yaml:"receive_visibility_timeout" envconfig:"SQS_VISIBILITY_TIMEOUT"` + // https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/best-practices-processing-messages-timely-manner.html + // visibility heartbeat defaults to false (opt-in); when true, the message visibility timeout is extended every half VisibilityTimeout while the task is being processed, and it has no effect unless VisibilityTimeout is set to a positive value. + VisibilityHeartBeat bool `yaml:"visibility_heartbeat" envconfig:"SQS_VISIBILITY_HEARTBEAT"` } // RedisConfig ... diff --git a/v2/go.mod b/v2/go.mod index 434b9804..98e85d3a 100644 --- a/v2/go.mod +++ b/v2/go.mod @@ -16,7 +16,7 @@ require ( github.com/rabbitmq/amqp091-go v1.9.0 github.com/redis/go-redis/v9 v9.0.5 github.com/robfig/cron/v3 v3.0.1 - github.com/stretchr/testify v1.8.4 + github.com/stretchr/testify v1.10.0 github.com/urfave/cli v1.22.5 go.mongodb.org/mongo-driver v1.17.0 gopkg.in/yaml.v2 v2.4.0 diff --git a/v2/go.sum b/v2/go.sum index dfabe1b7..0a188ece 100644 --- a/v2/go.sum +++ b/v2/go.sum @@ -223,6 +223,8 @@ github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeV github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod 
h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod 
h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= @@ -232,6 +234,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203 h1:QVqDTf3h2WHt08YuiTGPZLls0Wq99X9bWd0Q5ZSBesM= github.com/stvp/tempredis v0.0.0-20181119212430-b82af8480203/go.mod h1:oqN97ltKNihBbwlX8dLpwxCl3+HnXKV/R0e+sRLd9C8= github.com/urfave/cli v1.22.5 h1:lNq9sAHXK2qfdI8W+GRItjCEkI+2oR4d+MEHy1CKXoU=