diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ebdf079..c3af632 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,6 +1,11 @@ name: Unit tests -on: ["push"] +on: + push: + paths: + - "**.go" + - ".github/workflows/*.yml" + - "example/hasura/docker-compose.yaml" jobs: test-go: @@ -11,12 +16,46 @@ jobs: uses: actions/checkout@v2 - uses: actions/setup-go@v2 with: - go-version: '1.16.4' + go-version: "1.16.4" + - uses: actions/cache@v2 + with: + path: | + ~/go/pkg/mod + ~/.cache/go-build + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- - name: Install dependencies run: go get -t -v ./... - name: Format run: diff -u <(echo -n) <(gofmt -d -s .) - name: Vet run: go vet ./... + - name: Setup integration test infrastructure + run: | + cd ./example/hasura + docker-compose up -d - name: Run Go unit tests - run: go test -v -race ./... \ No newline at end of file + run: go test -v -race -coverprofile=coverage.out ./... 
+ - name: Go coverage format + run: | + go get github.com/boumenot/gocover-cobertura + gocover-cobertura < coverage.out > coverage.xml + - name: Code Coverage Summary Report + uses: irongut/CodeCoverageSummary@v1.3.0 + with: + filename: coverage.xml + badge: true + fail_below_min: true + format: markdown + hide_branch_rate: false + hide_complexity: true + indicators: true + output: both + thresholds: "60 80" + - name: Add Coverage PR Comment + uses: marocchino/sticky-pull-request-comment@v2 + if: github.event_name == 'pull_request' + with: + recreate: true + path: code-coverage-results.md diff --git a/.gitignore b/.gitignore index 9f11b75..137e4f2 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .idea/ +coverage.out \ No newline at end of file diff --git a/README.md b/README.md index 61bb21c..a59eee4 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ For more information, see package [`github.com/shurcooL/githubv4`](https://githu - [Stop the subscription](#stop-the-subscription) - [Authentication](#authentication-1) - [Options](#options) + - [Subscription Protocols](#subscription-protocols) + - [Handle connection error](#handle-connection-error) - [Events](#events) - [Custom HTTP Client](#custom-http-client) - [Custom WebSocket client](#custom-websocket-client) @@ -531,10 +533,18 @@ client := graphql.NewSubscriptionClient("wss://example.com/graphql"). "headers": map[string]string{ "authentication": "...", }, + }). + // or lazy parameters with function + WithConnectionParamsFn(func () map[string]interface{} { + return map[string]interface{} { + "headers": map[string]string{ + "authentication": "...", + }, + } }) - ``` + #### Options ```Go @@ -548,8 +558,35 @@ client. // max size of response message WithReadLimit(10*1024*1024). 
// these operation event logs won't be printed - WithoutLogTypes(graphql.GQL_DATA, graphql.GQL_CONNECTION_KEEP_ALIVE) + WithoutLogTypes(graphql.GQLData, graphql.GQLConnectionKeepAlive) +``` + +#### Subscription Protocols + +The subscription client supports 2 protocols: +- [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md) (default) +- [graphql-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md) + +The protocol can be switched with the `WithProtocol` function. +```Go +client.WithProtocol(graphql.GraphQLWS) +``` + +#### Handle connection error + +GraphQL servers can define custom WebSocket error codes in the 3000-4999 range. For example, in the `graphql-ws` protocol, the server sends the invalid message error with status [4400](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#invalid-message). In this case, the subscription client should let the user handle the error through the `OnError` event. + +```go +client := graphql.NewSubscriptionClient(serverEndpoint). + OnError(func(sc *graphql.SubscriptionClient, err error) error { + if strings.Contains(err.Error(), "invalid x-hasura-admin-secret/x-hasura-access-key") { + // exit the subscription client due to unauthorized error + return err + } + // otherwise ignore the error and the client continues to run + return nil + }) ``` #### Events diff --git a/example/hasura/README.md b/example/hasura/README.md new file mode 100644 index 0000000..ca61198 --- /dev/null +++ b/example/hasura/README.md @@ -0,0 +1,27 @@ +# Examples with Hasura graphql server + +## How to run + +### Server + +Requires [Docker](https://www.docker.com/) and [docker-compose](https://docs.docker.com/compose/install/) + +```sh +docker-compose up -d +``` + +Open the console at `http://localhost:8080` with admin secret `hasura`. 
+ +### Client + +#### Subscription with subscriptions-transport-ws protocol + +```sh +go run ./client/subscriptions-transport-ws +``` + +#### Subscription with graphql-ws protocol + +```sh +go run ./client/graphql-ws +``` diff --git a/example/hasura/client/graphql-ws/client.go b/example/hasura/client/graphql-ws/client.go new file mode 100644 index 0000000..f544ccd --- /dev/null +++ b/example/hasura/client/graphql-ws/client.go @@ -0,0 +1,151 @@ +package main + +import ( + "context" + "fmt" + "log" + "math/rand" + "net/http" + "strings" + "time" + + graphql "github.com/hasura/go-graphql-client" +) + +const ( + serverEndpoint = "http://localhost:8080/v1/graphql" + adminSecret = "hasura" + xHasuraAdminSecret = "x-hasura-admin-secret" +) + +func main() { + go insertUsers() + startSubscription() +} + +func startSubscription() error { + + client := graphql.NewSubscriptionClient(serverEndpoint). + WithProtocol(graphql.GraphQLWS). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + xHasuraAdminSecret: adminSecret, + }, + }).WithLog(log.Println). 
+ OnError(func(sc *graphql.SubscriptionClient, err error) error { + if strings.Contains(err.Error(), "invalid x-hasura-admin-secret/x-hasura-access-key") { + return err + } + return nil + }) + + defer client.Close() + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + _, err := client.Subscribe(sub, nil, func(data []byte, err error) error { + + if err != nil { + log.Println(err) + return nil + } + + if data == nil { + return nil + } + log.Println(string(data)) + return nil + }) + + if err != nil { + panic(err) + } + + // automatically unsubscribe after 10 seconds + // go func() { + // time.Sleep(10 * time.Second) + // client.Unsubscribe(subId) + // }() + + return client.Run() +} + +type user_insert_input map[string]interface{} + +// insertUsers insert users to the graphql server, so the subscription client can receive messages +func insertUsers() { + + client := graphql.NewClient(serverEndpoint, &http.Client{ + Transport: headerRoundTripper{ + setHeaders: func(req *http.Request) { + req.Header.Set(xHasuraAdminSecret, adminSecret) + }, + rt: http.DefaultTransport, + }, + }) + // stop until the subscription client is connected + time.Sleep(time.Second) + for i := 0; i < 10; i++ { + /* + mutation InsertUser($objects: [user_insert_input!]!) 
{ + insert_user(objects: $objects) { + id + name + } + } + */ + var q struct { + InsertUser struct { + Returning []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"returning"` + } `graphql:"insert_user(objects: $objects)"` + } + variables := map[string]interface{}{ + "objects": []user_insert_input{ + { + "name": randomString(), + }, + }, + } + err := client.Mutate(context.Background(), &q, variables, graphql.OperationName("InsertUser")) + if err != nil { + fmt.Println(err) + } + time.Sleep(time.Second) + } +} + +func randomString() string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, 16) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} + +type headerRoundTripper struct { + setHeaders func(req *http.Request) + rt http.RoundTripper +} + +func (h headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + h.setHeaders(req) + return h.rt.RoundTrip(req) +} diff --git a/example/hasura/client/subscriptions-transport-ws/client.go b/example/hasura/client/subscriptions-transport-ws/client.go new file mode 100644 index 0000000..cba1648 --- /dev/null +++ b/example/hasura/client/subscriptions-transport-ws/client.go @@ -0,0 +1,147 @@ +package main + +import ( + "context" + "fmt" + "log" + "math/rand" + "net/http" + "time" + + graphql "github.com/hasura/go-graphql-client" +) + +const ( + serverEndpoint = "http://localhost:8080/v1/graphql" + adminSecret = "hasura" + xHasuraAdminSecret = "x-hasura-admin-secret" +) + +func main() { + go insertUsers() + startSubscription() +} + +func startSubscription() error { + + client := graphql.NewSubscriptionClient(serverEndpoint). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + xHasuraAdminSecret: adminSecret, + }, + }).WithLog(log.Println). 
+ OnError(func(sc *graphql.SubscriptionClient, err error) error { + log.Print("err", err) + return err + }) + + defer client.Close() + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(limit: 5, order_by: { id: desc })"` + } + + subId, err := client.Subscribe(sub, nil, func(data []byte, err error) error { + + if err != nil { + log.Println(err) + return nil + } + + if data == nil { + return nil + } + log.Println(string(data)) + return nil + }) + + if err != nil { + panic(err) + } + + // automatically unsubscribe after 10 seconds + go func() { + time.Sleep(10 * time.Second) + client.Unsubscribe(subId) + }() + + return client.Run() +} + +type user_insert_input map[string]interface{} + +// insertUsers insert users to the graphql server, so the subscription client can receive messages +func insertUsers() { + + client := graphql.NewClient(serverEndpoint, &http.Client{ + Transport: headerRoundTripper{ + setHeaders: func(req *http.Request) { + req.Header.Set(xHasuraAdminSecret, adminSecret) + }, + rt: http.DefaultTransport, + }, + }) + // stop until the subscription client is connected + time.Sleep(time.Second) + for i := 0; i < 10; i++ { + /* + mutation InsertUser($objects: [user_insert_input!]!) 
{ + insert_user(objects: $objects) { + id + name + } + } + */ + var q struct { + InsertUser struct { + Returning []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"returning"` + } `graphql:"insert_user(objects: $objects)"` + } + variables := map[string]interface{}{ + "objects": []user_insert_input{ + { + "name": randomString(), + }, + }, + } + err := client.Mutate(context.Background(), &q, variables, graphql.OperationName("InsertUser")) + if err != nil { + fmt.Println(err) + } + time.Sleep(time.Second) + } +} + +func randomString() string { + var letter = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") + + b := make([]rune, 16) + for i := range b { + b[i] = letter[rand.Intn(len(letter))] + } + return string(b) +} + +type headerRoundTripper struct { + setHeaders func(req *http.Request) + rt http.RoundTripper +} + +func (h headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + h.setHeaders(req) + return h.rt.RoundTrip(req) +} diff --git a/example/hasura/docker-compose.yaml b/example/hasura/docker-compose.yaml new file mode 100644 index 0000000..7dea4ad --- /dev/null +++ b/example/hasura/docker-compose.yaml @@ -0,0 +1,32 @@ +version: "3.7" + +services: + postgres: + image: postgres:15 + restart: always + volumes: + - db_data:/var/lib/postgresql/data + environment: + POSTGRES_PASSWORD: postgrespassword + + hasura: + image: hasura/graphql-engine:v2.16.1.cli-migrations-v3 + depends_on: + - "postgres" + ports: + - "8080:8080" + volumes: + - ./server/migrations:/hasura-migrations + - ./server/metadata:/hasura-metadata + restart: always + environment: + HASURA_GRAPHQL_DATABASE_URL: postgres://postgres:postgrespassword@postgres:5432/postgres + ## enable the console served by server + HASURA_GRAPHQL_ENABLE_CONSOLE: "true" # set to "false" to disable console + HASURA_GRAPHQL_ENABLED_LOG_TYPES: startup,http-log,query-log,webhook-log,websocket-log + ## enable debugging mode. 
It is recommended to disable this in production + HASURA_GRAPHQL_DEV_MODE: "true" + HASURA_GRAPHQL_ADMIN_SECRET: hasura + +volumes: + db_data: diff --git a/example/hasura/server/config.yaml b/example/hasura/server/config.yaml new file mode 100644 index 0000000..725c800 --- /dev/null +++ b/example/hasura/server/config.yaml @@ -0,0 +1,6 @@ +version: 3 +endpoint: http://localhost:8080 +metadata_directory: metadata +actions: + kind: synchronous + handler_webhook_baseurl: http://localhost:3000 diff --git a/example/hasura/server/metadata/actions.graphql b/example/hasura/server/metadata/actions.graphql new file mode 100644 index 0000000..e69de29 diff --git a/example/hasura/server/metadata/actions.yaml b/example/hasura/server/metadata/actions.yaml new file mode 100644 index 0000000..1edb4c2 --- /dev/null +++ b/example/hasura/server/metadata/actions.yaml @@ -0,0 +1,6 @@ +actions: [] +custom_types: + enums: [] + input_objects: [] + objects: [] + scalars: [] diff --git a/example/hasura/server/metadata/allow_list.yaml b/example/hasura/server/metadata/allow_list.yaml new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/example/hasura/server/metadata/allow_list.yaml @@ -0,0 +1 @@ +[] diff --git a/example/hasura/server/metadata/api_limits.yaml b/example/hasura/server/metadata/api_limits.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/example/hasura/server/metadata/api_limits.yaml @@ -0,0 +1 @@ +{} diff --git a/example/hasura/server/metadata/cron_triggers.yaml b/example/hasura/server/metadata/cron_triggers.yaml new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/example/hasura/server/metadata/cron_triggers.yaml @@ -0,0 +1 @@ +[] diff --git a/example/hasura/server/metadata/databases/databases.yaml b/example/hasura/server/metadata/databases/databases.yaml new file mode 100644 index 0000000..65a11b2 --- /dev/null +++ b/example/hasura/server/metadata/databases/databases.yaml @@ -0,0 +1,14 @@ +- name: default + kind: postgres + configuration: 
+ connection_info: + database_url: + from_env: HASURA_GRAPHQL_DATABASE_URL + isolation_level: read-committed + pool_settings: + connection_lifetime: 600 + idle_timeout: 180 + max_connections: 50 + retries: 1 + use_prepared_statements: true + tables: "!include default/tables/tables.yaml" diff --git a/example/hasura/server/metadata/databases/default/tables/public_user.yaml b/example/hasura/server/metadata/databases/default/tables/public_user.yaml new file mode 100644 index 0000000..528cbd1 --- /dev/null +++ b/example/hasura/server/metadata/databases/default/tables/public_user.yaml @@ -0,0 +1,3 @@ +table: + name: user + schema: public diff --git a/example/hasura/server/metadata/databases/default/tables/tables.yaml b/example/hasura/server/metadata/databases/default/tables/tables.yaml new file mode 100644 index 0000000..7a33703 --- /dev/null +++ b/example/hasura/server/metadata/databases/default/tables/tables.yaml @@ -0,0 +1 @@ +- "!include public_user.yaml" diff --git a/example/hasura/server/metadata/graphql_schema_introspection.yaml b/example/hasura/server/metadata/graphql_schema_introspection.yaml new file mode 100644 index 0000000..61a4dca --- /dev/null +++ b/example/hasura/server/metadata/graphql_schema_introspection.yaml @@ -0,0 +1 @@ +disabled_for_roles: [] diff --git a/example/hasura/server/metadata/inherited_roles.yaml b/example/hasura/server/metadata/inherited_roles.yaml new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/example/hasura/server/metadata/inherited_roles.yaml @@ -0,0 +1 @@ +[] diff --git a/example/hasura/server/metadata/metrics_config.yaml b/example/hasura/server/metadata/metrics_config.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/example/hasura/server/metadata/metrics_config.yaml @@ -0,0 +1 @@ +{} diff --git a/example/hasura/server/metadata/network.yaml b/example/hasura/server/metadata/network.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/example/hasura/server/metadata/network.yaml @@ 
-0,0 +1 @@ +{} diff --git a/example/hasura/server/metadata/query_collections.yaml b/example/hasura/server/metadata/query_collections.yaml new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/example/hasura/server/metadata/query_collections.yaml @@ -0,0 +1 @@ +[] diff --git a/example/hasura/server/metadata/remote_schemas.yaml b/example/hasura/server/metadata/remote_schemas.yaml new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/example/hasura/server/metadata/remote_schemas.yaml @@ -0,0 +1 @@ +[] diff --git a/example/hasura/server/metadata/rest_endpoints.yaml b/example/hasura/server/metadata/rest_endpoints.yaml new file mode 100644 index 0000000..fe51488 --- /dev/null +++ b/example/hasura/server/metadata/rest_endpoints.yaml @@ -0,0 +1 @@ +[] diff --git a/example/hasura/server/metadata/version.yaml b/example/hasura/server/metadata/version.yaml new file mode 100644 index 0000000..0a70aff --- /dev/null +++ b/example/hasura/server/metadata/version.yaml @@ -0,0 +1 @@ +version: 3 diff --git a/example/hasura/server/migrations/default/1673778595604_create_table_public_user/down.sql b/example/hasura/server/migrations/default/1673778595604_create_table_public_user/down.sql new file mode 100644 index 0000000..09bd8ca --- /dev/null +++ b/example/hasura/server/migrations/default/1673778595604_create_table_public_user/down.sql @@ -0,0 +1 @@ +DROP TABLE "public"."user"; diff --git a/example/hasura/server/migrations/default/1673778595604_create_table_public_user/up.sql b/example/hasura/server/migrations/default/1673778595604_create_table_public_user/up.sql new file mode 100644 index 0000000..dabe422 --- /dev/null +++ b/example/hasura/server/migrations/default/1673778595604_create_table_public_user/up.sql @@ -0,0 +1,17 @@ +CREATE TABLE "public"."user" ("id" serial NOT NULL, "name" text NOT NULL, "created_at" timestamptz NOT NULL DEFAULT now(), "updated_at" timestamptz NOT NULL DEFAULT now(), PRIMARY KEY ("id") ); +CREATE OR REPLACE FUNCTION 
"public"."set_current_timestamp_updated_at"() +RETURNS TRIGGER AS $$ +DECLARE + _new record; +BEGIN + _new := NEW; + _new."updated_at" = NOW(); + RETURN _new; +END; +$$ LANGUAGE plpgsql; +CREATE TRIGGER "set_public_user_updated_at" +BEFORE UPDATE ON "public"."user" +FOR EACH ROW +EXECUTE PROCEDURE "public"."set_current_timestamp_updated_at"(); +COMMENT ON TRIGGER "set_public_user_updated_at" ON "public"."user" +IS 'trigger to set value of column "updated_at" to current timestamp on row update'; diff --git a/example/subscription/client.go b/example/subscription/client.go index 25201db..49e9a23 100644 --- a/example/subscription/client.go +++ b/example/subscription/client.go @@ -22,7 +22,7 @@ func startSubscription() error { "foo": "bar", }, }).WithLog(log.Println). - WithoutLogTypes(graphql.GQL_DATA, graphql.GQL_CONNECTION_KEEP_ALIVE). + WithoutLogTypes(graphql.GQLData, graphql.GQLConnectionKeepAlive). OnError(func(sc *graphql.SubscriptionClient, err error) error { log.Print("err", err) return err diff --git a/example/tibber/README.md b/example/tibber/README.md new file mode 100644 index 0000000..b1d5a8d --- /dev/null +++ b/example/tibber/README.md @@ -0,0 +1,10 @@ +# GraphQL client demo with Tibber + +## How to run + +Go to [https://developer.tibber.com/explorer](https://developer.tibber.com/explorer) and get the demo token. + +```sh +export TIBBER_DEMO_TOKEN= +go run . +``` diff --git a/example/tibber/client.go b/example/tibber/client.go new file mode 100644 index 0000000..716e1d5 --- /dev/null +++ b/example/tibber/client.go @@ -0,0 +1,109 @@ +package main + +import ( + "log" + "net/http" + "os" + "time" + + graphql "github.com/hasura/go-graphql-client" +) + +// https://developer.tibber.com/explorer + +const ( + subscriptionEndpoint = "wss://websocket-api.tibber.com/v1-beta/gql/subscriptions" +) + +func main() { + startSubscription() +} + +// the subscription uses the Real time subscription demo +// +// subscription LiveMeasurement($homeId: ID!) 
{ +// liveMeasurement(homeId: $homeId){ +// timestamp +// power +// accumulatedConsumption +// accumulatedCost +// currency +// minPower +// averagePower +// maxPower +// } +// } +func startSubscription() error { + // get the demo token from the graphiql playground + demoToken := os.Getenv("TIBBER_DEMO_TOKEN") + if demoToken == "" { + panic("TIBBER_DEMO_TOKEN env variable is required") + } + + client := graphql.NewSubscriptionClient(subscriptionEndpoint). + WithProtocol(graphql.GraphQLWS). + WithWebSocketOptions(graphql.WebsocketOptions{ + HTTPClient: &http.Client{ + Transport: headerRoundTripper{ + setHeaders: func(req *http.Request) { + req.Header.Set("User-Agent", "go-graphql-client/0.9.0") + }, + rt: http.DefaultTransport, + }, + }, + }). + WithConnectionParams(map[string]interface{}{ + "token": demoToken, + }).WithLog(log.Println). + OnError(func(sc *graphql.SubscriptionClient, err error) error { + panic(err) + }) + + defer client.Close() + + var sub struct { + LiveMeasurement struct { + Timestamp time.Time `graphql:"timestamp"` + Power int `graphql:"power"` + AccumulatedConsumption float64 `graphql:"accumulatedConsumption"` + AccumulatedCost float64 `graphql:"accumulatedCost"` + Currency string `graphql:"currency"` + MinPower int `graphql:"minPower"` + AveragePower float64 `graphql:"averagePower"` + MaxPower float64 `graphql:"maxPower"` + } `graphql:"liveMeasurement(homeId: $homeId)"` + } + + variables := map[string]interface{}{ + "homeId": graphql.ID("96a14971-525a-4420-aae9-e5aedaa129ff"), + } + _, err := client.Subscribe(sub, variables, func(data []byte, err error) error { + + if err != nil { + log.Println("ERROR: ", err) + return nil + } + + if data == nil { + return nil + } + log.Println(string(data)) + return nil + }) + + if err != nil { + panic(err) + } + + return client.Run() +} + +type headerRoundTripper struct { + setHeaders func(req *http.Request) + rt http.RoundTripper +} + +func (h headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, 
error) { + h.setHeaders(req) + return h.rt.RoundTrip(req) +} diff --git a/graphql.go b/graphql.go index 14b3f83..0f1ed2d 100644 --- a/graphql.go +++ b/graphql.go @@ -115,10 +115,7 @@ func (c *Client) buildAndRequest(ctx context.Context, op operationType, v interf // Request the common method that send graphql request func (c *Client) request(ctx context.Context, query string, variables map[string]interface{}, options ...Option) ([]byte, *http.Response, io.Reader, Errors) { - in := struct { - Query string `json:"query"` - Variables map[string]interface{} `json:"variables,omitempty"` - }{ + in := GraphQLRequestPayload{ Query: query, Variables: variables, } diff --git a/subscription.go b/subscription.go index 994c1bd..c91fc34 100644 --- a/subscription.go +++ b/subscription.go @@ -17,38 +17,39 @@ import ( "nhooyr.io/websocket/wsjson" ) -// Subscription transport follow Apollo's subscriptions-transport-ws protocol specification -// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +// SubscriptionProtocolType represents the protocol specification enum of the subscription +type SubscriptionProtocolType String -// OperationMessageType +const ( + SubscriptionsTransportWS SubscriptionProtocolType = "subscriptions-transport-ws" + GraphQLWS SubscriptionProtocolType = "graphql-ws" + + // Receiving a message of a type or format which is not specified in the graphql-ws protocol closes the socket with event 4400. + // The close reason can be vaguely descriptive on why the received message is invalid. 
+ StatusInvalidMessage websocket.StatusCode = 4400 + // if the connection is not acknowledged, the socket will be closed immediately with the event 4401: Unauthorized + StatusUnauthorized websocket.StatusCode = 4401 + // Connection initialisation timeout + StatusConnectionInitialisationTimeout websocket.StatusCode = 4408 + // Subscriber for the given operation id already exists + StatusSubscriberAlreadyExists websocket.StatusCode = 4409 + // Too many initialisation requests + StatusTooManyInitialisationRequests websocket.StatusCode = 4429 +) + +// OperationMessageType represents a subscription message enum type type OperationMessageType string const ( - // Client sends this message after plain websocket connection to start the communication with the server - GQL_CONNECTION_INIT OperationMessageType = "connection_init" - // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server rejected the connection. - GQL_CONNECTION_ERROR OperationMessageType = "conn_err" - // Client sends this message to execute GraphQL operation - GQL_START OperationMessageType = "start" - // Client sends this message in order to stop a running GraphQL operation execution (for example: unsubscribe) - GQL_STOP OperationMessageType = "stop" - // Server sends this message upon a failing operation, before the GraphQL execution, usually due to GraphQL validation errors (resolver errors are part of GQL_DATA message, and will be added as errors array) - GQL_ERROR OperationMessageType = "error" - // The server sends this message to transfter the GraphQL execution result from the server to the client, this message is a response for GQL_START message. - GQL_DATA OperationMessageType = "data" - // Server sends this message to indicate that a GraphQL operation is done, and no more data will arrive for the specific operation. 
- GQL_COMPLETE OperationMessageType = "complete" - // Server message that should be sent right after each GQL_CONNECTION_ACK processed and then periodically to keep the client connection alive. - // The client starts to consider the keep alive message only upon the first received keep alive message from the server. - GQL_CONNECTION_KEEP_ALIVE OperationMessageType = "ka" - // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server accepted the connection. May optionally include a payload. - GQL_CONNECTION_ACK OperationMessageType = "connection_ack" - // Client sends this message to terminate the connection. - GQL_CONNECTION_TERMINATE OperationMessageType = "connection_terminate" // Unknown operation type, for logging only - GQL_UNKNOWN OperationMessageType = "unknown" + GQLUnknown OperationMessageType = "unknown" // Internal status, for logging only - GQL_INTERNAL OperationMessageType = "internal" + GQLInternal OperationMessageType = "internal" + + // @deprecated: use GQLUnknown instead + GQL_UNKNOWN = GQLUnknown + // @deprecated: use GQLInternal instead + GQL_INTERNAL = GQLInternal ) // ErrSubscriptionStopped a special error which forces the subscription stop @@ -61,6 +62,7 @@ type OperationMessage struct { Payload json.RawMessage `json:"payload,omitempty"` } +// String overrides the default Stringer to return json string for debugging func (om OperationMessage) String() string { bs, _ := json.Marshal(om) @@ -80,46 +82,194 @@ type WebsocketConn interface { SetReadLimit(limit int64) } +// SubscriptionProtocol abstracts the life-cycle of subscription protocol implementation for a specific transport protocol +type SubscriptionProtocol interface { + // GetSubprotocols returns subprotocol names of the subscription transport + // The graphql server depends on the Sec-WebSocket-Protocol header to return the correct message specification + GetSubprotocols() []string + // ConnectionInit sends a initial request to establish a 
connection within the existing socket + ConnectionInit(ctx *SubscriptionContext, connectionParams map[string]interface{}) error + // Subscribe requests an graphql operation specified in the payload message + Subscribe(ctx *SubscriptionContext, id string, sub *Subscription) error + // Unsubscribe sends a request to stop listening and complete the subscription + Unsubscribe(ctx *SubscriptionContext, id string) error + // OnMessage listens ongoing messages from server + OnMessage(ctx *SubscriptionContext, subscription *Subscription, message OperationMessage) + // Close terminates all subscriptions of the current websocket + Close(ctx *SubscriptionContext) error +} + +// SubscriptionContext represents a shared context for protocol implementations with the websocket connection inside +type SubscriptionContext struct { + context.Context + WebsocketConn + OnConnected func() + onDisconnected func() + cancel context.CancelFunc + subscriptions map[string]*Subscription + disabledLogTypes []OperationMessageType + log func(args ...interface{}) + acknowledged int64 + exitStatusCodes []int + mutex sync.Mutex +} + +// Log prints condition logging with message type filters +func (sc *SubscriptionContext) Log(message interface{}, source string, opType OperationMessageType) { + if sc == nil || sc.log == nil { + return + } + for _, ty := range sc.disabledLogTypes { + if ty == opType { + return + } + } + + sc.log(message, source) +} + +// GetWebsocketConn get the current websocket connection +func (sc *SubscriptionContext) GetWebsocketConn() WebsocketConn { + return sc.WebsocketConn +} + +// SetWebsocketConn set the current websocket connection +func (sc *SubscriptionContext) SetWebsocketConn(conn WebsocketConn) { + sc.mutex.Lock() + defer sc.mutex.Unlock() + sc.WebsocketConn = conn +} + +// GetSubscription get the subscription state by id +func (sc *SubscriptionContext) GetSubscription(id string) *Subscription { + sc.mutex.Lock() + defer sc.mutex.Unlock() + if sc.subscriptions == nil 
{ + return nil + } + sub, _ := sc.subscriptions[id] + return sub +} + +// GetSubscriptions returns a copy of the subscription map held in the context +func (sc *SubscriptionContext) GetSubscriptions() map[string]*Subscription { + newMap := make(map[string]*Subscription) + for k, v := range sc.subscriptions { + newMap[k] = v + } + return newMap +} + +// SetSubscription set the input subscription state into the context +// if subscription is nil, removes the subscription from the map +func (sc *SubscriptionContext) SetSubscription(id string, sub *Subscription) { + sc.mutex.Lock() + if sub == nil { + delete(sc.subscriptions, id) + } else { + sc.subscriptions[id] = sub + } + sc.mutex.Unlock() +} + +// GetAcknowledge get the acknowledge status +func (sc *SubscriptionContext) GetAcknowledge() bool { + return atomic.LoadInt64(&sc.acknowledged) > 0 +} + +// SetAcknowledge set the acknowledge status +func (sc *SubscriptionContext) SetAcknowledge(value bool) { + if value { + atomic.StoreInt64(&sc.acknowledged, 1) + } else { + atomic.StoreInt64(&sc.acknowledged, 0) + } +} + +// Close closes the context and the inner websocket connection if exists +func (sc *SubscriptionContext) Close() error { + if conn := sc.GetWebsocketConn(); conn != nil { + err := conn.Close() + sc.SetWebsocketConn(nil) + if err != nil { + return err + } + } + if sc.cancel != nil { + sc.cancel() + } + + return nil +} + +// Send emits a message to the graphql server +func (sc *SubscriptionContext) Send(message interface{}, opType OperationMessageType) error { + if conn := sc.GetWebsocketConn(); conn != nil { + sc.Log(message, "client", opType) + return conn.WriteJSON(message) + } + return nil +} + type handlerFunc func(data []byte, err error) error -type subscription struct { - query string - variables map[string]interface{} - handler func(data []byte, err error) - started Boolean + +// Subscription stores the subscription declaration and its state +type Subscription struct { + payload GraphQLRequestPayload + handler 
func(data []byte, err error) + started bool } -// SubscriptionClient is a GraphQL subscription client. -type SubscriptionClient struct { - url string - conn WebsocketConn - connectionParams map[string]interface{} - websocketOptions WebsocketOptions - context context.Context - subscriptions map[string]*subscription - cancel context.CancelFunc - subscribersMu sync.Mutex - timeout time.Duration - isRunning int64 - readLimit int64 // max size of response message. Default 10 MB - log func(args ...interface{}) - createConn func(sc *SubscriptionClient) (WebsocketConn, error) - retryTimeout time.Duration - onConnected func() - onDisconnected func() - onError func(sc *SubscriptionClient, err error) error - errorChan chan error - disabledLogTypes []OperationMessageType +// GetPayload returns the graphql request payload +func (s Subscription) GetPayload() GraphQLRequestPayload { + return s.payload +} + +// GetStarted a public getter for the started status +func (s Subscription) GetStarted() bool { + return s.started +} + +// SetStarted a public setter for the started status +func (s *Subscription) SetStarted(value bool) { + s.started = value } +// GetHandler a public getter for the subscription handler +func (s Subscription) GetHandler() func(data []byte, err error) { + return s.handler +} + +// SubscriptionClient is a GraphQL subscription client. +type SubscriptionClient struct { + url string + context *SubscriptionContext + connectionParams map[string]interface{} + connectionParamsFn func() map[string]interface{} + protocol SubscriptionProtocol + websocketOptions WebsocketOptions + timeout time.Duration + isRunning int64 + readLimit int64 // max size of response message. 
Default 10 MB + createConn func(sc *SubscriptionClient) (WebsocketConn, error) + retryTimeout time.Duration + onError func(sc *SubscriptionClient, err error) error + errorChan chan error +} + +// NewSubscriptionClient constructs new subscription client func NewSubscriptionClient(url string) *SubscriptionClient { return &SubscriptionClient{ - url: url, - timeout: time.Minute, - readLimit: 10 * 1024 * 1024, // set default limit 10MB - subscriptions: make(map[string]*subscription), - createConn: newWebsocketConn, - retryTimeout: time.Minute, - errorChan: make(chan error), + url: url, + timeout: time.Minute, + readLimit: 10 * 1024 * 1024, // set default limit 10MB + createConn: newWebsocketConn, + retryTimeout: time.Minute, + errorChan: make(chan error), + protocol: &subscriptionsTransportWS{}, + context: &SubscriptionContext{ + subscriptions: make(map[string]*Subscription), + }, } } @@ -128,16 +278,16 @@ func (sc *SubscriptionClient) GetURL() string { return sc.url } -// GetContext returns current context of subscription client -func (sc *SubscriptionClient) GetContext() context.Context { - return sc.context -} - -// GetContext returns write timeout of websocket client +// GetTimeout returns write timeout of websocket client func (sc *SubscriptionClient) GetTimeout() time.Duration { return sc.timeout } +// GetContext returns current context of subscription client +func (sc *SubscriptionClient) GetContext() context.Context { + return sc.context.Context +} + // WithWebSocket replaces customized websocket client constructor // In default, subscription client uses https://github.com/nhooyr/websocket func (sc *SubscriptionClient) WithWebSocket(fn func(sc *SubscriptionClient) (WebsocketConn, error)) *SubscriptionClient { @@ -145,6 +295,21 @@ func (sc *SubscriptionClient) WithWebSocket(fn func(sc *SubscriptionClient) (Web return sc } +// WithProtocol changes the subscription protocol implementation +// By default the subscription client uses the subscriptions-transport-ws 
protocol +func (sc *SubscriptionClient) WithProtocol(protocol SubscriptionProtocolType) *SubscriptionClient { + + switch protocol { + case GraphQLWS: + sc.protocol = &graphqlWS{} + case SubscriptionsTransportWS: + sc.protocol = &subscriptionsTransportWS{} + default: + panic(fmt.Sprintf("unknown subscription protocol %s", protocol)) + } + return sc +} + // WithWebSocketOptions provides options to the websocket client func (sc *SubscriptionClient) WithWebSocketOptions(options WebsocketOptions) *SubscriptionClient { sc.websocketOptions = options @@ -158,6 +323,13 @@ func (sc *SubscriptionClient) WithConnectionParams(params map[string]interface{} return sc } +// WithConnectionParamsFn set a function that returns connection params for sending to server through GQL_CONNECTION_INIT event +// It's suitable for short-lived access tokens that need to be refreshed frequently +func (sc *SubscriptionClient) WithConnectionParamsFn(fn func() map[string]interface{}) *SubscriptionClient { + sc.connectionParamsFn = fn + return sc +} + // WithTimeout updates write timeout of websocket client func (sc *SubscriptionClient) WithTimeout(timeout time.Duration) *SubscriptionClient { sc.timeout = timeout @@ -165,20 +337,21 @@ func (sc *SubscriptionClient) WithTimeout(timeout time.Duration) *SubscriptionCl } // WithRetryTimeout updates reconnecting timeout. When the websocket server was stopped, the client will retry connecting every second until timeout +// The zero value means unlimited timeout func (sc *SubscriptionClient) WithRetryTimeout(timeout time.Duration) *SubscriptionClient { sc.retryTimeout = timeout return sc } -// WithLog sets loging function to print out received messages. By default, nothing is printed +// WithLog sets logging function to print out received messages. 
By default, nothing is printed func (sc *SubscriptionClient) WithLog(logger func(args ...interface{})) *SubscriptionClient { - sc.log = logger + sc.context.log = logger return sc } // WithoutLogTypes these operation types won't be printed func (sc *SubscriptionClient) WithoutLogTypes(types ...OperationMessageType) *SubscriptionClient { - sc.disabledLogTypes = types + sc.context.disabledLogTypes = types return sc } @@ -198,17 +371,18 @@ func (sc *SubscriptionClient) OnError(onError func(sc *SubscriptionClient, err e // OnConnected event is triggered when the websocket connected to GraphQL server successfully func (sc *SubscriptionClient) OnConnected(fn func()) *SubscriptionClient { - sc.onConnected = fn + sc.context.OnConnected = fn return sc } // OnDisconnected event is triggered when the websocket client was disconnected func (sc *SubscriptionClient) OnDisconnected(fn func()) *SubscriptionClient { - sc.onDisconnected = fn + sc.context.onDisconnected = fn return sc } -func (sc *SubscriptionClient) setIsRunning(value Boolean) { +// set the running atomic lock status +func (sc *SubscriptionClient) setIsRunning(value bool) { if value { atomic.StoreInt64(&sc.isRunning, 1) } else { @@ -216,85 +390,50 @@ func (sc *SubscriptionClient) setIsRunning(value Boolean) { } } +// initializes the websocket connection func (sc *SubscriptionClient) init() error { now := time.Now() ctx, cancel := context.WithCancel(context.Background()) - sc.context = ctx - sc.cancel = cancel + sc.context.Context = ctx + sc.context.cancel = cancel for { var err error var conn WebsocketConn // allow custom websocket client - if sc.conn == nil { + if sc.context.GetWebsocketConn() == nil { conn, err = sc.createConn(sc) if err == nil { - sc.conn = conn + sc.context.SetWebsocketConn(conn) } } if err == nil { - sc.conn.SetReadLimit(sc.readLimit) + sc.context.SetReadLimit(sc.readLimit) // send connection init event to the server - err = sc.sendConnectionInit() + connectionParams := sc.connectionParams + if 
sc.connectionParamsFn != nil { + connectionParams = sc.connectionParamsFn() + } + err = sc.protocol.ConnectionInit(sc.context, connectionParams) } if err == nil { return nil } - if now.Add(sc.retryTimeout).Before(time.Now()) { - if sc.onDisconnected != nil { - sc.onDisconnected() + if sc.retryTimeout > 0 && now.Add(sc.retryTimeout).Before(time.Now()) { + if sc.context.onDisconnected != nil { + sc.context.onDisconnected() } return err } - sc.printLog(fmt.Sprintf("%s. retry in second...", err.Error()), "client", GQL_INTERNAL) + sc.context.Log(fmt.Sprintf("%s. retry in second...", err.Error()), "client", GQLInternal) time.Sleep(time.Second) } } -func (sc *SubscriptionClient) writeJSON(v interface{}) error { - if sc.conn != nil { - return sc.conn.WriteJSON(v) - } - return nil -} - -func (sc *SubscriptionClient) printLog(message interface{}, source string, opType OperationMessageType) { - if sc.log == nil { - return - } - for _, ty := range sc.disabledLogTypes { - if ty == opType { - return - } - } - - sc.log(message, source) -} - -func (sc *SubscriptionClient) sendConnectionInit() (err error) { - var bParams []byte = nil - if sc.connectionParams != nil { - - bParams, err = json.Marshal(sc.connectionParams) - if err != nil { - return - } - } - - // send connection_init event to the server - msg := OperationMessage{ - Type: GQL_CONNECTION_INIT, - Payload: bParams, - } - - sc.printLog(msg, "client", GQL_CONNECTION_INIT) - return sc.writeJSON(msg) -} - // Subscribe sends start message to server and open a channel to receive data. // The handler callback function will receive raw message data or error. If the call return error, onError event will be triggered // The function returns subscription ID and error. 
You can use subscription ID to unsubscribe the subscription @@ -332,61 +471,26 @@ func (sc *SubscriptionClient) do(v interface{}, variables map[string]interface{} func (sc *SubscriptionClient) doRaw(query string, variables map[string]interface{}, handler func(message []byte, err error) error) (string, error) { id := uuid.New().String() - sub := subscription{ - query: query, - variables: variables, - handler: sc.wrapHandler(handler), + sub := Subscription{ + payload: GraphQLRequestPayload{ + Query: query, + Variables: variables, + }, + handler: sc.wrapHandler(handler), } // if the websocket client is running, start subscription immediately if atomic.LoadInt64(&sc.isRunning) > 0 { - if err := sc.startSubscription(id, &sub); err != nil { + if err := sc.protocol.Subscribe(sc.context, id, &sub); err != nil { return "", err } } - sc.subscribersMu.Lock() - sc.subscriptions[id] = &sub - sc.subscribersMu.Unlock() + sc.context.SetSubscription(id, &sub) return id, nil } -// Subscribe sends start message to server and open a channel to receive data -func (sc *SubscriptionClient) startSubscription(id string, sub *subscription) error { - if sub == nil || sub.started { - return nil - } - - in := struct { - Query string `json:"query"` - Variables map[string]interface{} `json:"variables,omitempty"` - }{ - Query: sub.query, - Variables: sub.variables, - } - - payload, err := json.Marshal(in) - if err != nil { - return err - } - - // send stop message to the server - msg := OperationMessage{ - ID: id, - Type: GQL_START, - Payload: payload, - } - - sc.printLog(msg, "client", GQL_START) - if err := sc.writeJSON(msg); err != nil { - return err - } - - sub.started = true - return nil -} - func (sc *SubscriptionClient) wrapHandler(fn handlerFunc) func(data []byte, err error) { return func(data []byte, err error) { if errValue := fn(data, err); errValue != nil { @@ -395,22 +499,18 @@ func (sc *SubscriptionClient) wrapHandler(fn handlerFunc) func(data []byte, err } } +// Unsubscribe sends 
stop message to server and close subscription channel +// The input parameter is subscription ID that is returned from Subscribe function +func (sc *SubscriptionClient) Unsubscribe(id string) error { + return sc.protocol.Unsubscribe(sc.context, id) +} + // Run start websocket client and subscriptions. If this function is run with goroutine, it can be stopped after closed func (sc *SubscriptionClient) Run() error { if err := sc.init(); err != nil { return fmt.Errorf("retry timeout. exiting...") } - // lazily start subscriptions - sc.subscribersMu.Lock() - for k, v := range sc.subscriptions { - if err := sc.startSubscription(k, v); err != nil { - sc.Unsubscribe(k) - return err - } - } - sc.subscribersMu.Unlock() - sc.setIsRunning(true) go func() { for atomic.LoadInt64(&sc.isRunning) > 0 { @@ -418,12 +518,12 @@ func (sc *SubscriptionClient) Run() error { case <-sc.context.Done(): return default: - if sc.conn == nil { + if sc.context == nil || sc.context.GetWebsocketConn() == nil { return } var message OperationMessage - if err := sc.conn.ReadJSON(&message); err != nil { + if err := sc.context.ReadJSON(&message); err != nil { // manual EOF check if err == io.EOF || strings.Contains(err.Error(), "EOF") { if err = sc.Reset(); err != nil { @@ -432,12 +532,17 @@ func (sc *SubscriptionClient) Run() error { } } closeStatus := websocket.CloseStatus(err) - if closeStatus == websocket.StatusNormalClosure { + switch closeStatus { + case websocket.StatusNormalClosure, websocket.StatusAbnormalClosure: // close event from websocket client, exiting... return + case StatusConnectionInitialisationTimeout, StatusTooManyInitialisationRequests, StatusSubscriberAlreadyExists, StatusUnauthorized: + sc.context.Log(err, "server", GQLError) + return } - if closeStatus != -1 { - sc.printLog(fmt.Sprintf("%s. Retry connecting...", err), "client", GQL_INTERNAL) + + if closeStatus != -1 && closeStatus < 3000 && closeStatus > 4999 { + sc.context.Log(fmt.Sprintf("%s. 
Retry connecting...", err), "client", GQLInternal) if err = sc.Reset(); err != nil { sc.errorChan <- err return @@ -446,66 +551,16 @@ func (sc *SubscriptionClient) Run() error { if sc.onError != nil { if err = sc.onError(sc, err); err != nil { + // end the subscription if the callback return error + sc.Close() return } } continue } - switch message.Type { - case GQL_ERROR: - sc.printLog(message, "server", GQL_ERROR) - fallthrough - case GQL_DATA: - sc.printLog(message, "server", GQL_DATA) - id, err := uuid.Parse(message.ID) - if err != nil { - continue - } - - sc.subscribersMu.Lock() - sub, ok := sc.subscriptions[id.String()] - sc.subscribersMu.Unlock() - - if !ok { - continue - } - var out struct { - Data *json.RawMessage - Errors Errors - } - - err = json.Unmarshal(message.Payload, &out) - if err != nil { - go sub.handler(nil, err) - continue - } - if len(out.Errors) > 0 { - go sub.handler(nil, out.Errors) - continue - } - - var outData []byte - if out.Data != nil && len(*out.Data) > 0 { - outData = *out.Data - } - - go sub.handler(outData, nil) - case GQL_CONNECTION_ERROR: - sc.printLog(message, "server", GQL_CONNECTION_ERROR) - case GQL_COMPLETE: - sc.printLog(message, "server", GQL_COMPLETE) - sc.Unsubscribe(message.ID) - case GQL_CONNECTION_KEEP_ALIVE: - sc.printLog(message, "server", GQL_CONNECTION_KEEP_ALIVE) - case GQL_CONNECTION_ACK: - sc.printLog(message, "server", GQL_CONNECTION_ACK) - if sc.onConnected != nil { - sc.onConnected() - } - default: - sc.printLog(message, "server", GQL_UNKNOWN) - } + sub := sc.context.GetSubscription(message.ID) + go sc.protocol.OnMessage(sc.context, sub, message) } } }() @@ -535,82 +590,24 @@ func (sc *SubscriptionClient) Run() error { return sc.Reset() } -// Unsubscribe sends stop message to server and close subscription channel -// The input parameter is subscription ID that is returned from Subscribe function -func (sc *SubscriptionClient) Unsubscribe(id string) error { - sc.subscribersMu.Lock() - defer 
sc.subscribersMu.Unlock() - - _, ok := sc.subscriptions[id] - if !ok { - return fmt.Errorf("subscription id %s doesn't not exist", id) - } - - delete(sc.subscriptions, id) - err := sc.stopSubscription(id) - if err != nil { - return err - } - - // close the client if there is no running subscription - if len(sc.subscriptions) == 0 { - sc.printLog("no running subscription. exiting...", "client", GQL_INTERNAL) - return sc.Close() - } - return nil -} - -func (sc *SubscriptionClient) stopSubscription(id string) error { - if sc.conn != nil { - // send stop message to the server - msg := OperationMessage{ - ID: id, - Type: GQL_STOP, - } - - sc.printLog(msg, "server", GQL_STOP) - if err := sc.writeJSON(msg); err != nil { - return err - } - - } - - return nil -} - -func (sc *SubscriptionClient) terminate() error { - // send terminate message to the server - msg := OperationMessage{ - Type: GQL_CONNECTION_TERMINATE, - } - - if sc.conn != nil { - sc.printLog(msg, "client", GQL_CONNECTION_TERMINATE) - return sc.writeJSON(msg) - } - - return nil -} - // Reset restart websocket connection and subscriptions func (sc *SubscriptionClient) Reset() error { - if atomic.LoadInt64(&sc.isRunning) == 0 { - return nil - } - - sc.subscribersMu.Lock() - for id, sub := range sc.subscriptions { - _ = sc.stopSubscription(id) - sub.started = false + sc.context.SetAcknowledge(false) + isRunning := atomic.LoadInt64(&sc.isRunning) == 0 + + for id, sub := range sc.context.GetSubscriptions() { + sub.SetStarted(false) + if isRunning { + _ = sc.protocol.Unsubscribe(sc.context, id) + sc.context.SetSubscription(id, sub) + } } - sc.subscribersMu.Unlock() - if sc.conn != nil { - _ = sc.terminate() - _ = sc.conn.Close() - sc.conn = nil + if sc.context.GetWebsocketConn() != nil { + _ = sc.protocol.Close(sc.context) + _ = sc.context.Close() + sc.context.SetWebsocketConn(nil) } - sc.cancel() return sc.Run() } @@ -618,26 +615,46 @@ func (sc *SubscriptionClient) Reset() error { // Close closes all subscription 
channel and websocket as well func (sc *SubscriptionClient) Close() (err error) { sc.setIsRunning(false) - for id := range sc.subscriptions { - if err = sc.Unsubscribe(id); err != nil { - sc.cancel() + for id := range sc.context.GetSubscriptions() { + if err = sc.protocol.Unsubscribe(sc.context, id); err != nil { + sc.context.cancel() return } } - _ = sc.terminate() - if sc.conn != nil { - err = sc.conn.Close() - sc.conn = nil - if sc.onDisconnected != nil { - sc.onDisconnected() + if sc.context != nil { + _ = sc.protocol.Close(sc.context) + err = sc.context.Close() + sc.context.SetWebsocketConn(nil) + if sc.context.onDisconnected != nil { + sc.context.onDisconnected() } } - sc.cancel() return } +// the reusable function for sending connection init message. +// The payload format of both subscriptions-transport-ws and graphql-ws are the same +func connectionInit(conn *SubscriptionContext, connectionParams map[string]interface{}) error { + var bParams []byte = nil + var err error + if connectionParams != nil { + bParams, err = json.Marshal(connectionParams) + if err != nil { + return err + } + } + + // send connection_init event to the server + msg := OperationMessage{ + Type: GQLConnectionInit, + Payload: bParams, + } + + return conn.Send(msg, GQLConnectionInit) +} + // default websocket handler implementation using https://github.com/nhooyr/websocket type WebsocketHandler struct { ctx context.Context @@ -645,6 +662,7 @@ type WebsocketHandler struct { *websocket.Conn } +// WriteJSON implements the function to encode and send message in json format to the server func (wh *WebsocketHandler) WriteJSON(v interface{}) error { ctx, cancel := context.WithTimeout(wh.ctx, wh.timeout) defer cancel() @@ -652,20 +670,24 @@ func (wh *WebsocketHandler) WriteJSON(v interface{}) error { return wsjson.Write(ctx, wh.Conn, v) } +// ReadJSON implements the function to decode the json message from the server func (wh *WebsocketHandler) ReadJSON(v interface{}) error { ctx, cancel := 
context.WithTimeout(wh.ctx, wh.timeout) defer cancel() return wsjson.Read(ctx, wh.Conn, v) } +// Close implements the function to close the websocket connection func (wh *WebsocketHandler) Close() error { return wh.Conn.Close(websocket.StatusNormalClosure, "close websocket") } +// the default constructor function to create a websocket client +// which uses https://github.com/nhooyr/websocket library func newWebsocketConn(sc *SubscriptionClient) (WebsocketConn, error) { options := &websocket.DialOptions{ - Subprotocols: []string{"graphql-ws"}, + Subprotocols: sc.protocol.GetSubprotocols(), HTTPClient: sc.websocketOptions.HTTPClient, } diff --git a/subscription_graphql_ws.go b/subscription_graphql_ws.go new file mode 100644 index 0000000..adab049 --- /dev/null +++ b/subscription_graphql_ws.go @@ -0,0 +1,171 @@ +package graphql + +import ( + "encoding/json" + "fmt" +) + +// This package implements GraphQL over WebSocket Protocol (graphql-ws) +// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md + +const ( + // Indicates that the client wants to establish a connection within the existing socket. + // This connection is not the actual WebSocket communication channel, but is rather a frame within it asking the server to allow future operation requests. + GQLConnectionInit OperationMessageType = "connection_init" + // Expected response to the ConnectionInit message from the client acknowledging a successful connection with the server. + GQLConnectionAck OperationMessageType = "connection_ack" + // The Ping message can be sent at any time within the established socket. + GQLPing OperationMessageType = "ping" + // The response to the Ping message. Must be sent as soon as the Ping message is received. + GQLPong OperationMessageType = "pong" + // Requests an operation specified in the message payload. This message provides a unique ID field to connect published messages to the operation requested by this message. 
+ GQLSubscribe OperationMessageType = "subscribe" + // Operation execution result(s) from the source stream created by the binding Subscribe message. After all results have been emitted, the Complete message will follow indicating stream completion. + GQLNext OperationMessageType = "next" + // Operation execution error(s) in response to the Subscribe message. + // This can occur before execution starts, usually due to validation errors, or during the execution of the request. + GQLError OperationMessageType = "error" + // indicates that the requested operation execution has completed. If the server dispatched the Error message relative to the original Subscribe message, no Complete message will be emitted. + GQLComplete OperationMessageType = "complete" +) + +type graphqlWS struct { +} + +// GetSubprotocols returns subprotocol names of the subscription transport +func (gws graphqlWS) GetSubprotocols() []string { + return []string{"graphql-transport-ws"} +} + +// ConnectionInit sends a initial request to establish a connection within the existing socket +func (gws *graphqlWS) ConnectionInit(ctx *SubscriptionContext, connectionParams map[string]interface{}) error { + return connectionInit(ctx, connectionParams) +} + +// Subscribe requests an graphql operation specified in the payload message +func (gws *graphqlWS) Subscribe(ctx *SubscriptionContext, id string, sub *Subscription) error { + if sub.GetStarted() { + return nil + } + payload, err := json.Marshal(sub.GetPayload()) + if err != nil { + return err + } + // send start message to the server + msg := OperationMessage{ + ID: id, + Type: GQLSubscribe, + Payload: payload, + } + + if err := ctx.Send(msg, GQLSubscribe); err != nil { + return err + } + + sub.SetStarted(true) + return nil +} + +// Unsubscribe sends stop message to server and close subscription channel +// The input parameter is subscription ID that is returned from Subscribe function +func (gws *graphqlWS) Unsubscribe(ctx *SubscriptionContext, id 
string) error { + if ctx == nil || ctx.WebsocketConn == nil { + return nil + } + sub := ctx.GetSubscription(id) + + if sub == nil { + return fmt.Errorf("subscription id %s doesn't not exist", id) + } + + ctx.SetSubscription(id, nil) + + // send stop message to the server + msg := OperationMessage{ + ID: id, + Type: GQLComplete, + } + + err := ctx.Send(msg, GQLComplete) + if err != nil { + return err + } + + // close the client if there is no running subscription + if len(ctx.GetSubscriptions()) == 0 { + ctx.Log("no running subscription. exiting...", "client", GQLInternal) + return ctx.Close() + } + + return nil +} + +// OnMessage listens ongoing messages from server +func (gws *graphqlWS) OnMessage(ctx *SubscriptionContext, subscription *Subscription, message OperationMessage) { + + switch message.Type { + case GQLError: + ctx.Log(message, "server", message.Type) + case GQLNext: + ctx.Log(message, "server", message.Type) + if subscription == nil { + return + } + var out struct { + Data *json.RawMessage + Errors Errors + } + + err := json.Unmarshal(message.Payload, &out) + if err != nil { + subscription.handler(nil, err) + return + } + if len(out.Errors) > 0 { + subscription.handler(nil, out.Errors) + return + } + + var outData []byte + if out.Data != nil && len(*out.Data) > 0 { + outData = *out.Data + } + + subscription.handler(outData, nil) + case GQLComplete: + ctx.Log(message, "server", message.Type) + _ = gws.Unsubscribe(ctx, message.ID) + case GQLPing: + ctx.Log(message, "server", GQLPing) + // send pong response message back to the server + msg := OperationMessage{ + Type: GQLPong, + Payload: message.Payload, + } + + if err := ctx.Send(msg, GQLPong); err != nil { + ctx.Log(err, "client", GQLInternal) + } + case GQLConnectionAck: + // Expected response to the ConnectionInit message from the client acknowledging a successful connection with the server. + // The client is now ready to request subscription operations. 
+ ctx.Log(message, "server", GQLConnectionAck) + ctx.SetAcknowledge(true) + for id, sub := range ctx.GetSubscriptions() { + if err := gws.Subscribe(ctx, id, sub); err != nil { + gws.Unsubscribe(ctx, id) + return + } + } + if ctx.OnConnected != nil { + ctx.OnConnected() + } + default: + ctx.Log(message, "server", GQLUnknown) + } +} + +// Close terminates all subscriptions of the current websocket +func (gws *graphqlWS) Close(conn *SubscriptionContext) error { + return nil +} diff --git a/subscription_graphql_ws_test.go b/subscription_graphql_ws_test.go new file mode 100644 index 0000000..b24b20d --- /dev/null +++ b/subscription_graphql_ws_test.go @@ -0,0 +1,176 @@ +package graphql + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "testing" + "time" +) + +const ( + hasuraTestHost = "http://localhost:8080" + hasuraTestAdminSecret = "hasura" +) + +type headerRoundTripper struct { + setHeaders func(req *http.Request) + rt http.RoundTripper +} + +func (h headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + h.setHeaders(req) + return h.rt.RoundTrip(req) +} + +type user_insert_input map[string]interface{} + +func graphqlWS_setupClients() (*Client, *SubscriptionClient) { + endpoint := fmt.Sprintf("%s/v1/graphql", hasuraTestHost) + client := NewClient(endpoint, &http.Client{Transport: headerRoundTripper{ + setHeaders: func(req *http.Request) { + req.Header.Set("x-hasura-admin-secret", hasuraTestAdminSecret) + }, + rt: http.DefaultTransport, + }}) + + subscriptionClient := NewSubscriptionClient(endpoint). + WithProtocol(GraphQLWS). 
+ WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + "x-hasura-admin-secret": hasuraTestAdminSecret, + }, + }).WithLog(log.Println) + + return client, subscriptionClient +} + +func waitService(endpoint string, timeoutSecs int) error { + var err error + var res *http.Response + for i := 0; i < timeoutSecs; i++ { + res, err = http.Get(endpoint) + if err == nil && res.StatusCode == 200 { + return nil + } + + time.Sleep(time.Second) + } + + if err != nil { + return err + } + + if res != nil { + body, err := io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf(res.Status) + } + return fmt.Errorf(string(body)) + } + return errors.New("unknown error") +} + +func waitHasuraService(timeoutSecs int) error { + return waitService(fmt.Sprintf("%s/healthz", hasuraTestHost), timeoutSecs) +} + +func TestGraphqlWS_Subscription(t *testing.T) { + stop := make(chan bool) + client, subscriptionClient := graphqlWS_setupClients() + msg := randomID() + + subscriptionClient = subscriptionClient. + OnError(func(sc *SubscriptionClient, err error) error { + return err + }) + + /* + subscription { + user { + id + name + } + } + */ + var sub struct { + Users []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"user(order_by: { id: desc }, limit: 5)"` + } + + _, err := subscriptionClient.Subscribe(sub, nil, func(data []byte, e error) error { + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + log.Println("result", string(data)) + e = json.Unmarshal(data, &sub) + if e != nil { + t.Fatalf("got error: %v, want: nil", e) + return nil + } + + if len(sub.Users) > 0 && sub.Users[0].Name != msg { + t.Fatalf("subscription message does not match. 
got: %s, want: %s", sub.Users[0].Name, msg) + } + + return errors.New("exit") + }) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + go func() { + if err := subscriptionClient.Run(); err == nil || err.Error() != "exit" { + (*t).Fatalf("got error: %v, want: exit", err) + } + stop <- true + }() + + defer subscriptionClient.Close() + + // wait until the subscription client connects to the server + if err := waitHasuraService(60); err != nil { + t.Fatalf("failed to start hasura service: %s", err) + } + + // call a mutation request to send message to the subscription + /* + mutation InsertUser($objects: [user_insert_input!]!) { + insert_user(objects: $objects) { + id + name + } + } + */ + var q struct { + InsertUser struct { + Returning []struct { + ID int `graphql:"id"` + Name string `graphql:"name"` + } `graphql:"returning"` + } `graphql:"insert_user(objects: $objects)"` + } + variables := map[string]interface{}{ + "objects": []user_insert_input{ + { + "name": msg, + }, + }, + } + err = client.Mutate(context.Background(), &q, variables, OperationName("InsertUser")) + + if err != nil { + t.Fatalf("got error: %v, want: nil", err) + } + + <-stop +} diff --git a/subscription_test.go b/subscription_test.go index 6cfcb1e..64e60da 100644 --- a/subscription_test.go +++ b/subscription_test.go @@ -37,7 +37,7 @@ type HelloSaidEvent { ` func subscription_setupClients() (*Client, *SubscriptionClient) { - endpoint := "http://localhost:8080/graphql" + endpoint := "http://localhost:8081/graphql" client := NewClient(endpoint, &http.Client{Transport: http.DefaultTransport}) @@ -63,7 +63,7 @@ func subscription_setupServer() *http.Server { mux := http.NewServeMux() graphQLHandler := graphqlws.NewHandlerFunc(s, &relay.Handler{Schema: s}) mux.HandleFunc("/graphql", graphQLHandler) - server := &http.Server{Addr: ":8080", Handler: mux} + server := &http.Server{Addr: ":8081", Handler: mux} return server } diff --git a/subscriptions_transport_ws.go 
b/subscriptions_transport_ws.go new file mode 100644 index 0000000..7b5a0d6 --- /dev/null +++ b/subscriptions_transport_ws.go @@ -0,0 +1,204 @@ +package graphql + +import ( + "encoding/json" + "fmt" +) + +// Subscription transport follow Apollo's subscriptions-transport-ws protocol specification +// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md + +const ( + // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server rejected the connection. + GQLConnectionError OperationMessageType = "connection_error" + // Client sends this message to execute GraphQL operation + GQLStart OperationMessageType = "start" + // Client sends this message in order to stop a running GraphQL operation execution (for example: unsubscribe) + GQLStop OperationMessageType = "stop" + // Server sends this message to transfer the GraphQL execution result to the client, in response to a GQL_START message + GQLData OperationMessageType = "data" + // Server message that should be sent right after each GQL_CONNECTION_ACK processed and then periodically to keep the client connection alive. + // The client starts to consider the keep alive message only upon the first received keep alive message from the server. + GQLConnectionKeepAlive OperationMessageType = "ka" + // Client sends this message to terminate the connection. + GQLConnectionTerminate OperationMessageType = "connection_terminate" + + // Client sends this message after plain websocket connection to start the communication with the server + // @deprecated: use GQLConnectionInit instead + GQL_CONNECTION_INIT = GQLConnectionInit + // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server rejected the connection.
+ // @deprecated: use GQLConnectionError instead + GQL_CONNECTION_ERROR = GQLConnectionError + // Client sends this message to execute GraphQL operation + // @deprecated: use GQLStart instead + GQL_START = GQLStart + // Client sends this message in order to stop a running GraphQL operation execution (for example: unsubscribe) + // @deprecated: use GQLStop instead + GQL_STOP = GQLStop + // Server sends this message upon a failing operation, before the GraphQL execution, usually due to GraphQL validation errors (resolver errors are part of GQL_DATA message, and will be added as errors array) + // @deprecated: use GQLError instead + GQL_ERROR = GQLError + // The server sends this message to transfer the GraphQL execution result from the server to the client, this message is a response for GQL_START message. + // @deprecated: use GQLData instead + GQL_DATA = GQLData + // Server sends this message to indicate that a GraphQL operation is done, and no more data will arrive for the specific operation. + // @deprecated: use GQLComplete instead + GQL_COMPLETE = GQLComplete + // Server message that should be sent right after each GQL_CONNECTION_ACK processed and then periodically to keep the client connection alive. + // The client starts to consider the keep alive message only upon the first received keep alive message from the server. + // @deprecated: use GQLConnectionKeepAlive instead + GQL_CONNECTION_KEEP_ALIVE = GQLConnectionKeepAlive + // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server accepted the connection. May optionally include a payload. + // @deprecated: use GQLConnectionAck instead + GQL_CONNECTION_ACK = GQLConnectionAck + // Client sends this message to terminate the connection. 
+ // @deprecated: use GQLConnectionTerminate instead + GQL_CONNECTION_TERMINATE = GQLConnectionTerminate +) + +type subscriptionsTransportWS struct { +} + +// GetSubprotocols returns subprotocol names of the subscription transport +func (stw subscriptionsTransportWS) GetSubprotocols() []string { + return []string{"graphql-ws"} +} + +// ConnectionInit sends a initial request to establish a connection within the existing socket +func (stw *subscriptionsTransportWS) ConnectionInit(ctx *SubscriptionContext, connectionParams map[string]interface{}) error { + return connectionInit(ctx, connectionParams) +} + +// Subscribe requests an graphql operation specified in the payload message +func (stw *subscriptionsTransportWS) Subscribe(ctx *SubscriptionContext, id string, sub *Subscription) error { + if sub.GetStarted() { + return nil + } + payload, err := json.Marshal(sub.GetPayload()) + if err != nil { + return err + } + // send start message to the server + msg := OperationMessage{ + ID: id, + Type: GQLStart, + Payload: payload, + } + + if err := ctx.Send(msg, GQLStart); err != nil { + return err + } + + sub.SetStarted(true) + return nil +} + +// Unsubscribe sends stop message to server and close subscription channel +// The input parameter is subscription ID that is returned from Subscribe function +func (stw *subscriptionsTransportWS) Unsubscribe(ctx *SubscriptionContext, id string) error { + if ctx == nil || ctx.WebsocketConn == nil { + return nil + } + sub := ctx.GetSubscription(id) + + if sub == nil { + return fmt.Errorf("subscription id %s doesn't not exist", id) + } + + ctx.SetSubscription(id, nil) + + // send stop message to the server + msg := OperationMessage{ + ID: id, + Type: GQLStop, + } + + err := ctx.Send(msg, GQLStop) + if err != nil { + return err + } + + // close the client if there is no running subscription + if len(ctx.GetSubscriptions()) == 0 { + ctx.Log("no running subscription. 
exiting...", "client", GQLInternal) + return ctx.Close() + } + + return nil +} + +// OnMessage listens ongoing messages from server +func (stw *subscriptionsTransportWS) OnMessage(ctx *SubscriptionContext, subscription *Subscription, message OperationMessage) { + + switch message.Type { + case GQLError: + ctx.Log(message, "server", GQLError) + case GQLData: + ctx.Log(message, "server", GQLData) + if subscription == nil { + return + } + var out struct { + Data *json.RawMessage + Errors Errors + } + + err := json.Unmarshal(message.Payload, &out) + if err != nil { + subscription.handler(nil, err) + return + } + if len(out.Errors) > 0 { + subscription.handler(nil, out.Errors) + return + } + + var outData []byte + if out.Data != nil && len(*out.Data) > 0 { + outData = *out.Data + } + + subscription.handler(outData, nil) + case GQLConnectionError, "conn_err": + ctx.Log(message, "server", GQLConnectionError) + _ = stw.Close(ctx) + _ = ctx.Close() + ctx.cancel() + case GQLComplete: + ctx.Log(message, "server", GQLComplete) + _ = stw.Unsubscribe(ctx, message.ID) + case GQLConnectionKeepAlive: + ctx.Log(message, "server", GQLConnectionKeepAlive) + case GQLConnectionAck: + // Expected response to the ConnectionInit message from the client acknowledging a successful connection with the server. + // The client is now ready to request subscription operations. 
+ ctx.Log(message, "server", GQLConnectionAck) + ctx.SetAcknowledge(true) + subscriptions := ctx.GetSubscriptions() + for id, sub := range subscriptions { + if err := stw.Subscribe(ctx, id, sub); err != nil { + stw.Unsubscribe(ctx, id) + return + } + } + if ctx.OnConnected != nil { + ctx.OnConnected() + } + default: + ctx.Log(message, "server", GQLUnknown) + } +} + +// Close terminates all subscriptions of the current websocket +func (stw *subscriptionsTransportWS) Close(ctx *SubscriptionContext) error { + // send terminate message to the server + msg := OperationMessage{ + Type: GQLConnectionTerminate, + } + + if ctx.WebsocketConn != nil { + return ctx.Send(msg, GQLConnectionTerminate) + } + + return nil +} diff --git a/type.go b/type.go index c5e9138..a28f85c 100644 --- a/type.go +++ b/type.go @@ -13,3 +13,11 @@ package graphql type GraphQLType interface { GetGraphQLType() string } + +// GraphQLRequestPayload represents the graphql JSON-encoded request body +// https://graphql.org/learn/serving-over-http/#post-request +type GraphQLRequestPayload struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables,omitempty"` + OperationName string `json:"operationName,omitempty"` +}