diff --git a/jsonrpc/error.go b/jsonrpc/error.go
new file mode 100644
index 00000000..c72070d5
--- /dev/null
+++ b/jsonrpc/error.go
@@ -0,0 +1,115 @@
+package jsonrpc
+
+import (
+	"reflect"
+
+	"github.com/danielgtaylor/huma/v2"
+)
+
+// JSONRPCErrorCode is the integer "code" member of a JSON-RPC 2.0 error object.
+type JSONRPCErrorCode int
+
+const (
+	// ParseError defines invalid JSON was received by the server.
+	// An error occurred on the server while parsing the JSON text.
+	ParseError JSONRPCErrorCode = -32700
+
+	// InvalidRequestError defines the JSON sent is not a valid Request object.
+	InvalidRequestError JSONRPCErrorCode = -32600
+
+	// MethodNotFoundError defines the method does not exist / is not available.
+	MethodNotFoundError JSONRPCErrorCode = -32601
+
+	// InvalidParamsError defines invalid method parameter(s).
+	InvalidParamsError JSONRPCErrorCode = -32602
+
+	// InternalError defines an internal JSON-RPC error.
+	InternalError JSONRPCErrorCode = -32603
+)
+
+// errorMessage maps each predefined error code to its default message.
+var errorMessage = map[JSONRPCErrorCode]string{
+	ParseError:          "An error occurred on the server while parsing JSON object.",
+	InvalidRequestError: "The JSON sent is not a valid Request object.",
+	MethodNotFoundError: "The method does not exist / is not available.",
+	InvalidParamsError:  "Invalid method parameter(s).",
+	InternalError:       "Internal JSON-RPC error.",
+}
+
+// GetDefaultErrorMessage returns the default message for code, or "" when code
+// is not one of the predefined error codes above.
+func GetDefaultErrorMessage(code JSONRPCErrorCode) string {
+	return errorMessage[code]
+}
+
+// JSONRPCError defines a JSON RPC error that can be returned in a Response from the spec
+// http://www.jsonrpc.org/specification#error_object
+type JSONRPCError struct {
+	// The error type that occurred.
+	Code JSONRPCErrorCode `json:"code"`
+
+	// A short description of the error. The message SHOULD be limited to a concise
+	// single sentence.
+	Message string `json:"message"`
+
+	// Additional information about the error. The value of this member is defined by
+	// the sender (e.g. detailed error information, nested errors etc.).
+	Data interface{} `json:"data,omitempty"`
+}
+
+// Error implements error. It returns Message, or the default message for Code
+// when Message is empty.
+func (e JSONRPCError) Error() string {
+	if e.Message != "" {
+		return e.Message
+	}
+	return errorMessage[e.Code]
+}
+
+// ErrorCode returns the JSON RPC error code associated with the error.
+func (e JSONRPCError) ErrorCode() JSONRPCErrorCode {
+	return e.Code
+}
+
+// ResponseStatusError is a JSON-RPC Response carrying a status code exposed via GetStatus.
+type ResponseStatusError struct {
+	Response[any]
+	status int // read via GetStatus; unexported, so it is never marshaled to JSON
+}
+
+func (e *ResponseStatusError) Error() string {
+	if e.Response.Error != nil {
+		return e.Response.Error.Message
+	}
+	return ""
+}
+
+func (e *ResponseStatusError) GetStatus() int {
+	return e.status
+}
+
+func (e ResponseStatusError) Schema(r huma.Registry) *huma.Schema {
+	errorObjectSchema := r.Schema(reflect.TypeOf(e.Response.Error), true, "")
+
+	responseObjectSchema := &huma.Schema{
+		Type:     huma.TypeObject,
+		Required: []string{"jsonrpc"},
+		Properties: map[string]*huma.Schema{
+			"jsonrpc": {
+				Type:        huma.TypeString,
+				Enum:        []any{"2.0"},
+				Description: "JSON-RPC version, must be '2.0'",
+			},
+			"id": {
+				Description: "Request identifier. Compulsory for method responses. This MUST be null to the client in case of parse errors etc.",
+				OneOf: []*huma.Schema{
+					{Type: huma.TypeInteger},
+					{Type: huma.TypeString},
+				},
+			},
+			"error": errorObjectSchema,
+		},
+	}
+
+	return responseObjectSchema
+}
diff --git a/jsonrpc/example/endpoint.go b/jsonrpc/example/endpoint.go
new file mode 100644
index 00000000..3055929a
--- /dev/null
+++ b/jsonrpc/example/endpoint.go
@@ -0,0 +1,120 @@
+package example
+
+import (
+	"context"
+
+	"github.com/danielgtaylor/huma/v2/jsonrpc"
+)
+
+// /////////////// Handlers /////////////////
+
+// AddParams defines the parameters for the "add" method
+type AddParams struct {
+	A int `json:"a"`
+	B int `json:"b"`
+}
+
+// AddResult defines the result of the "add" and "addpositional" methods
+type AddResult struct {
+	Sum int `json:"sum"`
+}
+
+// NotifyParams defines the parameters for the "notify" notification
+type NotifyParams struct {
+	Message string `json:"message"`
+}
+
+// ConcatParams defines the parameters for the "concat" method
+type ConcatParams struct {
+	S1 string `json:"s1"`
+	S2 string `json:"s2"`
+}
+
+// PingParams defines the parameters for the "ping" notification
+type PingParams struct {
+	Message string `json:"message"`
+}
+
+// AddEndpoint is the handler for the "add" method
+func AddEndpoint(ctx context.Context, params AddParams) (AddResult, error) {
+	res := params.A + params.B
+	return AddResult{Sum: res}, nil
+}
+
+// ConcatEndpoint is the handler for the "concat" method
+func ConcatEndpoint(ctx context.Context, params ConcatParams) (string, error) {
+	return params.S1 + params.S2, nil
+}
+
+// PingEndpoint is the handler for the "ping" notification
+func PingEndpoint(ctx context.Context, params PingParams) error {
+	return nil
+}
+
+// NotifyEndpoint is the handler for the "notify" notification
+func NotifyEndpoint(ctx context.Context, params NotifyParams) error {
+	// Process notification
+	return nil
+}
+
+// GetMethodHandlers returns the example method handlers, keyed by JSON-RPC method name.
+func GetMethodHandlers() map[string]jsonrpc.IMethodHandler {
+	// Define method maps
+	methodMap := map[string]jsonrpc.IMethodHandler{
+		"add": &jsonrpc.MethodHandler[AddParams, AddResult]{Endpoint: AddEndpoint},
+		"addpositional": &jsonrpc.MethodHandler[[]int, AddResult]{
+			Endpoint: func(ctx context.Context, params []int) (AddResult, error) {
+				res := 0
+				for _, v := range params {
+					res += v
+				}
+				return AddResult{Sum: res}, nil
+			},
+		},
+		"concat": &jsonrpc.MethodHandler[ConcatParams, string]{Endpoint: ConcatEndpoint},
+		"concatOptionalIn": &jsonrpc.MethodHandler[*ConcatParams, string]{
+			Endpoint: func(ctx context.Context, params *ConcatParams) (string, error) {
+				if params != nil {
+					return params.S1 + params.S2, nil
+				}
+				return "", nil
+			},
+		},
+		"concatOptionalInOut": &jsonrpc.MethodHandler[*ConcatParams, *string]{
+			Endpoint: func(ctx context.Context, params *ConcatParams) (*string, error) {
+				if params != nil {
+					r := params.S1 + params.S2
+					return &r, nil
+				}
+				return nil, nil
+			},
+		},
+		"echo": &jsonrpc.MethodHandler[any, any]{
+			Endpoint: func(ctx context.Context, params any) (any, error) {
+				return params, nil
+			},
+		},
+		"echooptional": &jsonrpc.MethodHandler[*string, *string]{
+			Endpoint: func(ctx context.Context, e *string) (*string, error) {
+				return e, nil
+			},
+		},
+	}
+
+	return methodMap
+
+}
+
+// GetNotificationHandlers returns the example notification handlers, keyed by method name.
+func GetNotificationHandlers() map[string]jsonrpc.INotificationHandler {
+
+	notificationMap := map[string]jsonrpc.INotificationHandler{
+		"ping": &jsonrpc.NotificationHandler[PingParams]{Endpoint: PingEndpoint},
+		"notify": &jsonrpc.NotificationHandler[NotifyParams]{
+			Endpoint: NotifyEndpoint,
+		},
+	}
+
+	return notificationMap
+
+}
diff --git a/jsonrpc/example/httpsse_cli.go b/jsonrpc/example/httpsse_cli.go
new file mode 100644
index 00000000..f073ec0e
--- /dev/null
+++ b/jsonrpc/example/httpsse_cli.go
@@ -0,0 +1,114 @@
+package example
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"net/http"
+	"runtime/debug"
+	"time"
+
+	"github.com/danielgtaylor/huma/v2"
+	"github.com/danielgtaylor/huma/v2/adapters/humago"
+	"github.com/danielgtaylor/huma/v2/humacli"
+	"github.com/danielgtaylor/huma/v2/jsonrpc"
+)
+
+// Options defines the CLI options of the example server; add fields as needed.
+type Options struct {
+	Host  string `doc:"Host to listen on" default:"localhost"`
+	Port  int    `doc:"Port to listen on" default:"8080"`
+	Debug bool   `doc:"Enable debug logs" default:"false"`
+}
+
+// loggingMiddleware is a huma middleware; its logging calls are currently commented out.
+// Either a huma middleware or a net/http handler middleware can be added.
+func loggingMiddleware(ctx huma.Context, next func(huma.Context)) {
+	// log.Printf("Received request: %v %v", ctx.URL().RawPath, ctx.Operation().Path)
+	next(ctx)
+	// log.Printf("Responded to request: %v %v", ctx.URL().RawPath, ctx.Operation().Path)
+}
+
+// PanicRecoveryMiddleware is a net/http middleware that recovers from panics in
+// handlers, logs them with a stack trace and replies 500 Internal Server Error.
+func PanicRecoveryMiddleware(next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		defer func() {
+			if err := recover(); err != nil {
+				// http.ErrAbortHandler deliberately aborts a response; let net/http handle it.
+				if err == http.ErrAbortHandler {
+					panic(err)
+				}
+				// Log the panic to stderr
+				log.Printf("Recovered from panic: %+v", err)
+
+				// Optionally, log the stack trace
+				log.Printf("%s", debug.Stack())
+
+				// Return a 500 Internal Server Error
+				http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+			}
+		}()
+		next.ServeHTTP(w, r)
+	})
+}
+
+// SetupSSETransport registers the example method and notification handlers on a huma
+// API backed by a net/http ServeMux and returns the mux wrapped in PanicRecoveryMiddleware.
+func SetupSSETransport() http.Handler {
+	// Use default go router
+	router := http.NewServeMux()
+
+	api := humago.New(router, huma.DefaultConfig("Example JSONRPC API", "1.0.0"))
+	// Add any middlewares
+	api.UseMiddleware(loggingMiddleware)
+	handler := PanicRecoveryMiddleware(router)
+
+	// Init the server's method and notification handlers
+	methodMap := GetMethodHandlers()
+	notificationMap := GetNotificationHandlers()
+	op := jsonrpc.GetDefaultOperation()
+	// Register the methods
+	jsonrpc.Register(api, op, methodMap, notificationMap)
+
+	return handler
+}
+
+// GetHTTPServerCLI returns a humacli.CLI that serves SetupSSETransport's handler over HTTP.
+func GetHTTPServerCLI() humacli.CLI {
+
+	cli := humacli.New(func(hooks humacli.Hooks, opts *Options) {
+		log.Printf("Options are %+v\n", opts)
+		handler := SetupSSETransport()
+		// Initialize the http server
+		server := http.Server{
+			Addr:    fmt.Sprintf("%s:%d", opts.Host, opts.Port),
+			Handler: handler,
+			// Bound header reads so slow clients cannot hold connections open (slowloris).
+			ReadHeaderTimeout: 10 * time.Second,
+		}
+
+		// Hook the HTTP server.
+		hooks.OnStart(func() {
+			if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+				log.Fatalf("listen: %s\n", err)
+			}
+		})
+
+		hooks.OnStop(func() {
+			// Gracefully shutdown your server here
+			ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+			defer cancel()
+			_ = server.Shutdown(ctx)
+		})
+	})
+
+	return cli
+}
+
+// StartHTTPServer builds the example server CLI and runs it.
+func StartHTTPServer() {
+	cli := GetHTTPServerCLI()
+	cli.Run()
+}
diff --git a/jsonrpc/example/httpsse_test.go b/jsonrpc/example/httpsse_test.go
new file mode 100644
index 00000000..536935f5
--- /dev/null
+++ b/jsonrpc/example/httpsse_test.go
@@ -0,0 +1,664 @@
+package example
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/danielgtaylor/huma/v2/jsonrpc"
+)
+
+// JSONRPCClient sends a raw JSON-RPC payload and returns the raw response body.
+type JSONRPCClient interface {
+	Send(reqBytes []byte) ([]byte, error)
+}
+
+// HTTPJSONRPCClient is a JSONRPCClient that POSTs payloads to an HTTP endpoint.
+type HTTPJSONRPCClient struct {
+	client *http.Client
+	url    string
+}
+
+// NewHTTPClient starts a test server for SetupSSETransport, closed via t.Cleanup,
+// and returns a client that posts to its /jsonrpc path.
+func NewHTTPClient(t *testing.T) *HTTPJSONRPCClient {
+	handler := SetupSSETransport()
+	server := httptest.NewServer(handler)
+	t.Cleanup(server.Close) // Ensure server closes after test
+	client := server.Client()
+	url := server.URL + "/jsonrpc"
+	return &HTTPJSONRPCClient{
+		client: client,
+		url:    url,
+	}
+}
+
+// Send POSTs reqBytes as application/json and returns the response body.
+func (c *HTTPJSONRPCClient) Send(reqBytes []byte) ([]byte, error) {
+	resp, err := c.client.Post(c.url, "application/json", bytes.NewReader(reqBytes))
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+	return io.ReadAll(resp.Body)
+}
+
+func getClient(t *testing.T) JSONRPCClient {
+	return NewHTTPClient(t)
+}
+
+// sendJSONRPCRequest sends request ([]byte as-is, anything else JSON-encoded), logs the
+// traffic, and returns the response body, or nil when the body is empty.
+func sendJSONRPCRequest(t *testing.T, client JSONRPCClient, request interface{}) []byte {
+	var reqBytes []byte
+	var err error
+	if b, ok := request.([]byte); ok {
+		reqBytes = b
+	} else {
+		reqBytes, err = json.Marshal(request)
+		if err != nil {
+			t.Fatalf("Error marshaling request: %v", err)
+		}
+	}
+	t.Logf("Sending req %s", string(reqBytes))
+	respBody, err := client.Send(reqBytes)
+	if err != nil {
+		t.Fatalf("Error sending request: %v", err)
+	}
+	if len(respBody) == 0 {
+		t.Log("Got Empty response")
+		return nil
+	}
+	var o interface{}
+	err = json.Unmarshal(respBody, &o)
+	if err == nil {
+		r, err := json.Marshal(o)
+		if err == nil {
+			t.Logf("Json resp %s", string(r))
+		}
+	}
+	return respBody
+}
+
+func TestValidSingleRequests(t *testing.T) {
+	client := getClient(t)
+
+	tests := []struct {
+		name           string
+		request        interface{}
+		expectedResult interface{}
+	}{
+		{
+			name: "Add method with named parameters",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "add",
+				"params":  map[string]interface{}{"a": 2, "b": 3},
+				"id":      1,
+			},
+			expectedResult: map[string]float64{"sum": 5},
+		},
+		{
+			name: "Add method with positional parameters",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "addpositional",
+				"params":  []interface{}{2, 3},
+				"id":      2,
+			},
+			expectedResult: map[string]float64{"sum": 5},
+		},
+		{
+			name: "Echo method with no parameters",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "echo",
+				"id":      3,
+			},
+			expectedResult: nil,
+		},
+		{
+			name: "Echo method with optional parameters",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "echooptional",
+				"id":      "1",
+				"params":  "foo",
+			},
+			expectedResult: "foo",
+		},
+		{
+			name: "Echo method with optional parameters nil input",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "echooptional",
+				"id":      "2",
+			},
+			expectedResult: nil,
+		},
+		{
+			name: "Concat method",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "concat",
+				"params":  map[string]interface{}{"s1": "Hello, ", "s2": "World!"},
+				"id":      2,
+			},
+			expectedResult: "Hello, World!",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			respBody := sendJSONRPCRequest(t, client, tc.request)
+
+			var response struct {
+				JSONRPC      string                `json:"jsonrpc"`
+				Result       interface{}           `json:"result"`
+				JSONRPCError *jsonrpc.JSONRPCError `json:"error"`
+				ID           interface{}           `json:"id"`
+			}
+
+			err := json.Unmarshal(respBody, &response)
+			if err != nil {
+				t.Fatalf("Error unmarshaling response: %v", err)
+			}
+
+			if response.JSONRPCError != nil {
+				t.Errorf("Expected no error, but got: %+v", response.JSONRPCError)
+			} else {
+				eq, err := jsonStructEqual(response.Result, tc.expectedResult)
+				if err != nil || !eq {
+					t.Errorf("Expected result %#v, got %#v", tc.expectedResult, response.Result)
+				}
+			}
+		})
+	}
+}
+
+func TestInvalidSingleRequests(t *testing.T) {
+	client := getClient(t)
+
+	tests := []struct {
+		name          string
+		request       interface{}
+		rawRequest    []byte
+		expectedError *jsonrpc.JSONRPCError
+	}{
+		{
+			name:       "Invalid JSON request",
+			rawRequest: []byte(`{ this is invalid json }`),
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.ParseError,
+				Message: jsonrpc.GetDefaultErrorMessage(jsonrpc.ParseError),
+			},
+		},
+		{
+			name: "Method not found",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "unknown_method",
+				"id":      1,
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.MethodNotFoundError,
+				Message: "Method 'unknown_method' not found",
+			},
+		},
+		{
+			name: "Invalid parameters",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "add",
+				"params":  map[string]interface{}{"a": "two", "b": 3},
+				"id":      2,
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.InvalidRequestError,
+				Message: jsonrpc.GetDefaultErrorMessage(jsonrpc.InvalidRequestError),
+			},
+		},
+		{
+			name: "Missing jsonrpc field",
+			request: map[string]interface{}{
+				"method": "add",
+				"params": map[string]interface{}{"a": 2, "b": 3},
+				"id":     3,
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.InvalidRequestError,
+				Message: jsonrpc.GetDefaultErrorMessage(jsonrpc.InvalidRequestError),
+			},
+		},
+		{
+			name: "Invalid jsonrpc version",
+			request: map[string]interface{}{
+				"jsonrpc": "1.0",
+				"method":  "add",
+				"params":  map[string]interface{}{"a": 2, "b": 3},
+				"id":      4,
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.InvalidRequestError,
+				Message: jsonrpc.GetDefaultErrorMessage(jsonrpc.InvalidRequestError),
+			},
+		},
+		{
+			name: "Missing method field",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"params":  map[string]interface{}{"a": 2, "b": 3},
+				"id":      5,
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.InvalidRequestError,
+				Message: jsonrpc.GetDefaultErrorMessage(jsonrpc.InvalidRequestError),
+			},
+		},
+		{
+			name: "Invalid id field (array)",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "add",
+				"params":  map[string]interface{}{"a": 2, "b": 3},
+				"id":      []int{1, 2, 3},
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.InvalidRequestError,
+				Message: jsonrpc.GetDefaultErrorMessage(jsonrpc.InvalidRequestError),
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			req := tc.request
+			if tc.rawRequest != nil {
+				req = tc.rawRequest
+			}
+
+			respBody := sendJSONRPCRequest(t, client, req)
+
+			var response struct {
+				JSONRPC      string                `json:"jsonrpc"`
+				Result       interface{}           `json:"result"`
+				JSONRPCError *jsonrpc.JSONRPCError `json:"error"`
+				ID           interface{}           `json:"id"`
+			}
+
+			err := json.Unmarshal(respBody, &response)
+			if err != nil {
+				t.Fatalf("Error unmarshaling response: %v", err)
+			}
+
+			if response.JSONRPCError == nil {
+				t.Errorf("Expected error but got none")
+			} else {
+				if response.JSONRPCError.Code != tc.expectedError.Code {
+					t.Errorf("Expected error code %d, got %d", tc.expectedError.Code, response.JSONRPCError.Code)
+				}
+				if !strings.Contains(response.JSONRPCError.Message, tc.expectedError.Message) {
+					t.Errorf("Expected error message '%s', got '%s'", tc.expectedError.Message, response.JSONRPCError.Message)
+				}
+			}
+		})
+	}
+}
+
+func TestNotifications(t *testing.T) {
+	client := getClient(t)
+
+	tests := []struct {
+		name          string
+		request       interface{}
+		expectedError *jsonrpc.JSONRPCError
+	}{
+		{
+			name: "Valid notification",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "notify",
+				"params":  map[string]interface{}{"message": "Hello"},
+			},
+		},
+		{
+			name: "Notification with invalid method",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "unknown_method",
+				"params":  map[string]interface{}{"message": "Hello"},
+			},
+			expectedError: &jsonrpc.JSONRPCError{
+				Code:    jsonrpc.MethodNotFoundError,
+				Message: "Method 'unknown_method' not found",
+			},
+		},
+		{
+			name: "Ping notification",
+			request: map[string]interface{}{
+				"jsonrpc": "2.0",
+				"method":  "ping",
+				"params":  map[string]interface{}{"message": "Test Ping"},
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			respBody := sendJSONRPCRequest(t, client, tc.request)
+
+			// NOTE(review): an empty body passes even when expectedError is set; the
+			// JSON-RPC 2.0 spec says the server must not reply to notifications.
+			if len(respBody) != 0 {
+				if tc.expectedError != nil {
+					var response struct {
+						JSONRPC      string                `json:"jsonrpc"`
+						Result       interface{}           `json:"result"`
+						JSONRPCError *jsonrpc.JSONRPCError `json:"error"`
+						ID           interface{}           `json:"id"`
+					}
+					err := json.Unmarshal(respBody, &response)
+					if err != nil {
+						t.Fatalf("Error unmarshaling response: %v", err)
+					}
+					if response.JSONRPCError == nil {
+						t.Fatalf("Expected error response, but got: %s", string(respBody))
+					}
+					if response.JSONRPCError.Code != tc.expectedError.Code {
+						t.Errorf(
+							"Expected error code %d, got %d",
+							tc.expectedError.Code,
+							response.JSONRPCError.Code,
+						)
+					}
+					if !strings.Contains(response.JSONRPCError.Message, tc.expectedError.Message) {
+						t.Errorf(
+							"Expected error message '%s', got '%s'",
+							tc.expectedError.Message,
+							response.JSONRPCError.Message,
+						)
+					}
+				} else {
+					t.Errorf("Expected no response, but got: %s", string(respBody))
+				}
+			}
+
+		})
+	}
+}
+
+func TestBatchRequests(t *testing.T) {
+	client := getClient(t)
+
+	tests := []struct {
+		name               string
+		batchRequest       []interface{}
+		expectedResponses  int
+		expectedErrorCodes []int
+		expectedResults    map[interface{}]interface{}
+	}{
+		{
+			name: "Valid batch with multiple requests",
+			batchRequest: []interface{}{
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "add",
+					"params":  map[string]interface{}{"a": 1, "b": 2},
+					"id":      1,
+				},
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "echooptional",
+					"params":  "foo",
+					"id":      2,
+				},
+			},
+			expectedResponses:  2,
+			expectedErrorCodes: []int{},
+		},
+		{
+			name: "Batch with mixed valid requests and notifications",
+			batchRequest: []interface{}{
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "add",
+					"params":  map[string]interface{}{"a": 1, "b": 2},
+					"id":      1,
+				},
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "notify",
+					"params":  map[string]interface{}{"message": "Hello"},
+				},
+			},
+			expectedResponses:  1,
+			expectedErrorCodes: []int{},
+		},
+		// // This wont work as framer will not allow sending a request for stdio transport
+		// {
+		// 	name: "Batch with invalid JSON in one request",
+		// 	batchRequest: []interface{}{[]byte(`[{
+		// 	"jsonrpc": "2.0",
+		// 	"method": "add",
+		// 	"params": {"a":1,"b":2},
+		// 	"id":1
+		// }, {
+		// 	"jsonrpc": "2.0",
+		// 	"method": "invalid_method",
+		// 	"params": {},
+		// 	"id":2
+		// }`)}, // Incomplete closing square bracket
+		// 	expectedResponses:  1,
+		// 	expectedErrorCodes: []int{-32700},
+		// },
+		{
+			name: "Batch of notifications",
+			batchRequest: []interface{}{
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "notify",
+					"params":  map[string]interface{}{"message": "Hello"},
+				},
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "notify",
+					"params":  map[string]interface{}{"message": "World"},
+				},
+			},
+			expectedResponses:  0,
+			expectedErrorCodes: []int{},
+		},
+		{
+			name:               "Empty batch array",
+			batchRequest:       []interface{}{},
+			expectedResponses:  1,
+			expectedErrorCodes: []int{-32600},
+		},
+		{
+			name: "Batch with valid and invalid methods",
+			batchRequest: []interface{}{
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "add",
+					"params":  map[string]interface{}{"a": 1, "b": 2},
+					"id":      1,
+				},
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "concat",
+					"params":  map[string]interface{}{"s1": "foo", "s2": "bar"},
+					"id":      2,
+				},
+				map[string]interface{}{
+					"jsonrpc": "2.0",
+					"method":  "unknownMethod",
+					"id":      3,
+				},
+			},
+			expectedResponses:  1,
+			expectedErrorCodes: []int{int(jsonrpc.InvalidRequestError)},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			// NOTE(review): an empty batchRequest leaves batch nil, so the "Empty batch
+			// array" case sends JSON null rather than [].
+			var batch interface{}
+			if len(tc.batchRequest) > 0 {
+				if b, ok := tc.batchRequest[0].([]byte); ok {
+					batch = b
+				} else {
+					batch = tc.batchRequest
+				}
+			}
+			respBody := sendJSONRPCRequest(t, client, batch)
+
+			if tc.expectedResponses == 0 {
+				if len(respBody) != 0 {
+					t.Errorf("Expected no response, but got: %s", string(respBody))
+				}
+				return
+			}
+
+			var responses []struct {
+				JSONRPC      string                `json:"jsonrpc"`
+				Result       interface{}           `json:"result"`
+				JSONRPCError *jsonrpc.JSONRPCError `json:"error"`
+				ID           interface{}           `json:"id"`
+			}
+
+			if err := json.Unmarshal(respBody, &responses); err != nil {
+				var singleResponse struct {
+					JSONRPC      string                `json:"jsonrpc"`
+					Result       interface{}           `json:"result"`
+					JSONRPCError *jsonrpc.JSONRPCError `json:"error"`
+					ID           interface{}           `json:"id"`
+				}
+				if err := json.Unmarshal(respBody, &singleResponse); err != nil {
+					t.Fatalf("Error unmarshaling response: %v", err)
+				}
+				responses = []struct {
+					JSONRPC      string                `json:"jsonrpc"`
+					Result       interface{}           `json:"result"`
+					JSONRPCError *jsonrpc.JSONRPCError `json:"error"`
+					ID           interface{}           `json:"id"`
+				}{singleResponse}
+			}
+
+			if len(responses) != tc.expectedResponses {
+				t.Errorf("Expected %d responses, got %d", tc.expectedResponses, len(responses))
+			}
+
+			var gotErrorCodes []int
+			for _, response := range responses {
+				if response.JSONRPCError != nil {
+					gotErrorCodes = append(gotErrorCodes, int(response.JSONRPCError.Code))
+				}
+			}
+			if !arraysAreSimilar(gotErrorCodes, tc.expectedErrorCodes) {
+				t.Errorf(
+					"Mismatched error codes. Got: %#v, Expected: %#v",
+					gotErrorCodes,
+					tc.expectedErrorCodes,
+				)
+			}
+
+			if tc.expectedResults != nil {
+				for _, response := range responses {
+					id := response.ID
+					expectedResult, ok := tc.expectedResults[id]
+					if ok {
+						if response.JSONRPCError != nil {
+							t.Errorf(
+								"Expected result for id %v, but got error: %+v",
+								id,
+								response.JSONRPCError,
+							)
+						} else {
+							eq, err := jsonStructEqual(response.Result, expectedResult)
+							if err != nil {
+								t.Errorf("Error comparing result for id %v: %v", id, err)
+							} else if !eq {
+								t.Errorf("Mismatched result for id %v. Got: %+v, Expected: %+v", id, response.Result, expectedResult)
+							}
+						}
+					}
+				}
+			}
+		})
+	}
+}
+
+func jsonEqual(a, b json.RawMessage) bool {
+	var o1 interface{}
+	var o2 interface{}
+
+	if err := json.Unmarshal(a, &o1); err != nil {
+		return false
+	}
+	if err := json.Unmarshal(b, &o2); err != nil {
+		return false
+	}
+	// Direct reflect Deepequal would have issues when there are pointers, keyorders etc.
+	// unmarshalling into a interface and then doing deepequal removes those issues
+	return reflect.DeepEqual(o1, o2)
+}
+
+func jsonStringsEqual(a, b string) bool {
+	return jsonEqual([]byte(a), []byte(b))
+}
+
+func getJSONStrings(args ...interface{}) ([]string, error) {
+	ret := make([]string, 0, len(args))
+	for _, a := range args {
+		jsonBytes, err := json.Marshal(a)
+		if err != nil {
+			return nil, err
+		}
+		ret = append(ret, string(jsonBytes))
+	}
+	return ret, nil
+}
+
+// jsonStructEqual reports whether arg1 and arg2 marshal to semantically equal JSON.
+func jsonStructEqual(arg1 interface{}, arg2 interface{}) (bool, error) {
+	vals, err := getJSONStrings(arg1, arg2)
+	if err != nil {
+		return false, errors.New("could not encode struct to json")
+	}
+	return jsonStringsEqual(vals[0], vals[1]), nil
+}
+
+// arraysAreSimilar reports whether arr1 and arr2 hold the same multiset of values (order ignored).
+func arraysAreSimilar(arr1, arr2 []int) bool {
+	if len(arr1) != len(arr2) {
+		return false
+	}
+	if len(arr1) != 0 {
+		counts1 := make(map[int]int)
+		counts2 := make(map[int]int)
+
+		for _, num := range arr1 {
+			counts1[num]++
+		}
+
+		for _, num := range arr2 {
+			counts2[num]++
+		}
+
+		for key, count1 := range counts1 {
+			if count2, exists := counts2[key]; !exists || count1 != count2 {
+				return false
+			}
+		}
+	}
+
+	return true
+}
diff --git a/jsonrpc/example/run_httpsse_test.go b/jsonrpc/example/run_httpsse_test.go
new file mode 100644
index 00000000..0eb0d814
--- /dev/null
+++ b/jsonrpc/example/run_httpsse_test.go
@@ -0,0 +1,11 @@
+//go:build integration
+
+package example
+
+import "testing"
+
+// go test -v -tags=integration -run TestRunServer -count=1 ./jsonrpc/example &
+func TestRunServer(t *testing.T) {
+	// Start the server
+	StartHTTPServer()
+}
diff --git a/jsonrpc/handler_meta.go b/jsonrpc/handler_meta.go
new file mode 100644
index 00000000..f2f6d6c4
--- /dev/null
+++ b/jsonrpc/handler_meta.go
@@ -0,0 +1,165 @@
+package jsonrpc
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"reflect"
+)
+
+// IMethodHandler is an interface for handlers that process requests expecting a response.
+type IMethodHandler interface {
+	// Handle processes req and returns the response to send back to the client.
+	Handle(ctx context.Context, req Request[json.RawMessage]) (Response[json.RawMessage], error)
+	// GetTypes returns the reflect.Type of the params and of the result.
+	GetTypes() (reflect.Type, reflect.Type)
+}
+
+// INotificationHandler is an interface for handlers that process notifications (no response expected).
+type INotificationHandler interface {
+	// Handle processes a notification. The error return exists mainly for server-side
+	// logging/debugging; the client never receives an error for a notification.
+	Handle(ctx context.Context, req Request[json.RawMessage]) error
+	// GetTypes returns the reflect.Type of the params.
+	GetTypes() reflect.Type
+}
+
+// GetMetaRequestHandler creates a handler function that processes MetaRequests.
+// A missing or empty request yields a single ParseError. Items with a wrong
+// "jsonrpc" version or an empty method yield an InvalidRequestError. Items with
+// an ID go to methodMap and always yield one response; items without an ID go to
+// notificationMap and yield none, unless the method is only in methodMap (then an
+// InvalidRequestError). A nil *MetaResponse is returned when nothing is produced.
+func GetMetaRequestHandler(
+	methodMap map[string]IMethodHandler,
+	notificationMap map[string]INotificationHandler,
+) func(context.Context, *MetaRequest) (*MetaResponse, error) {
+	return func(ctx context.Context, metaReq *MetaRequest) (*MetaResponse, error) {
+		if metaReq == nil || metaReq.Body == nil || len(metaReq.Body.Items) == 0 {
+			item := Response[json.RawMessage]{
+				JSONRPC: JSONRPCVersion,
+				ID:      nil,
+				Error: &JSONRPCError{
+					Code:    ParseError,
+					Message: "No input received",
+				},
+			}
+			// Reply with a single non-batch error. NOTE(review): JSON-RPC 2.0 specifies -32600 (InvalidRequestError) for an empty batch array.
+			ret := MetaResponse{
+				Body: &Meta[Response[json.RawMessage]]{
+					IsBatch: false,
+					Items:   []Response[json.RawMessage]{item},
+				},
+			}
+			return &ret, nil
+		}
+
+		resp := MetaResponse{
+			Body: &Meta[Response[json.RawMessage]]{
+				IsBatch: metaReq.Body.IsBatch,
+				Items:   []Response[json.RawMessage]{},
+			},
+		}
+
+		for _, request := range metaReq.Body.Items {
+			// Need a valid JSONRPC version and method
+			if request.JSONRPC != JSONRPCVersion || request.Method == "" {
+				msg := fmt.Sprintf(
+					"Invalid JSON-RPC version: '%s'",
+					request.JSONRPC,
+				)
+				if request.Method == "" {
+					msg = "Method name missing"
+				}
+				resp.Body.Items = append(resp.Body.Items, Response[json.RawMessage]{
+					JSONRPC: JSONRPCVersion,
+					ID:      request.ID,
+					Error: &JSONRPCError{
+						Code:    InvalidRequestError,
+						Message: msg,
+					},
+				})
+				continue
+			}
+
+			absentRequestID := request.ID == nil
+
+			if absentRequestID {
+				// Handle notification
+				handler, ok := notificationMap[request.Method]
+				if ok {
+					// Create context with request info
+					subCtx := contextWithRequestInfo(ctx, request.Method, true, nil)
+
+					// Call the notification handler
+					// Notifications cannot report errors to the client, so the error is dropped.
+					_ = handler.Handle(subCtx, request)
+					// Notifications do not produce a response
+					continue
+				}
+
+				// No notification handler for this ID-less request. If the method exists in
+				// methodMap, report the missing ID as an invalid request; otherwise send nothing.
+				if _, ok = methodMap[request.Method]; ok {
+					resp.Body.Items = append(resp.Body.Items, Response[json.RawMessage]{
+						JSONRPC: JSONRPCVersion,
+						ID:      nil,
+						Error: &JSONRPCError{
+							Code: InvalidRequestError,
+							Message: fmt.Sprintf(
+								"Received no requestID for method: '%s'",
+								request.Method,
+							),
+						},
+					})
+				}
+
+				continue
+			}
+
+			// Handle request expecting a response
+			handler, ok := methodMap[request.Method]
+			if !ok {
+				// Method not found
+				resp.Body.Items = append(resp.Body.Items, Response[json.RawMessage]{
+					JSONRPC: JSONRPCVersion,
+					ID:      request.ID,
+					Error: &JSONRPCError{
+						Code:    MethodNotFoundError,
+						Message: fmt.Sprintf("Method '%s' not found", request.Method),
+					},
+				})
+				continue
+			}
+
+			// Create context with request info
+			subCtx := contextWithRequestInfo(ctx, request.Method, false, request.ID)
+
+			// Call the method handler
+			response, err := handler.Handle(subCtx, request)
+			if err != nil {
+				// Handler returned an error.
+				// This should generally not happen as handler is expected to convert any errors into a jsonrpc response with error object
+				resp.Body.Items = append(resp.Body.Items, Response[json.RawMessage]{
+					JSONRPC: JSONRPCVersion,
+					ID:      request.ID,
+					Error: &JSONRPCError{
+						Code:    InternalError,
+						Message: fmt.Sprintf("Handler error: %v", err),
+					},
+				})
+				continue
+			}
+
+			// Append the response
+			resp.Body.Items = append(resp.Body.Items, response)
+		}
+
+		// If there are no responses to return, return nil response.
+		if len(resp.Body.Items) == 0 {
+			return nil, nil
+		}
+
+		return &resp, nil
+	}
+}
diff --git a/jsonrpc/handler_method.go b/jsonrpc/handler_method.go
new file mode 100644
index 00000000..449d1ae4
--- /dev/null
+++ b/jsonrpc/handler_method.go
@@ -0,0 +1,111 @@
+package jsonrpc
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"reflect"
+)
+
+// MethodHandler represents a generic handler with customizable input and output types.
+// It generally expects a response to be returned to the client.
+//
+// Usage Scenarios:
+//
+//  1. Compulsory Parameters:
+//     Use concrete types for both I and O when both input and output are required.
+//
+//  2. Optional Input or Output Parameters:
+//     Use a pointer type for I or O to allow passing nil when no input or output is provided.
+//
+//  3. No Input or Output Parameters:
+//     Use struct{} for I or O when the handler does not require any input or output.
+//
+// Example:
+//
+//	// Handler with no input and output
+//	handler := MethodHandler[struct{}, struct{}]{
+//		Endpoint: func(ctx context.Context, _ struct{}) (struct{}, error) {
+//			// Implementation
+//			return struct{}{}, nil
+//		},
+//	}
+type MethodHandler[I any, O any] struct {
+	Endpoint func(ctx context.Context, params I) (O, error)
+}
+
+// Handle decodes req's params into I, calls Endpoint and wraps the outcome in a
+// JSON-RPC response carrying req.ID. Param decoding failures are converted by
+// invalidParamsResponse, endpoint and marshaling errors become JSON-RPC error
+// objects, and the returned Go error is always nil.
+func (m *MethodHandler[I, O]) Handle(
+	ctx context.Context,
+	req Request[json.RawMessage],
+) (Response[json.RawMessage], error) {
+	params, err := unmarshalParams[I](req)
+	if err != nil {
+		// Return InvalidParamsError
+		return invalidParamsResponse(req, err), nil
+	}
+
+	// Call the handler
+	result, err := m.Endpoint(ctx, params)
+	if err != nil {
+		// Check whether err is a JSON-RPC error. JSONRPCError.Error has a value
+		// receiver, so endpoints may return it by pointer or by value, and
+		// errors.As only matches the exact type it is given: probe for both.
+		var jsonrpcErr *JSONRPCError
+		matched := errors.As(err, &jsonrpcErr)
+		if !matched {
+			var jsonrpcErrVal JSONRPCError
+			if errors.As(err, &jsonrpcErrVal) {
+				jsonrpcErr, matched = &jsonrpcErrVal, true
+			}
+		}
+		if matched {
+			// Handler returned a JSON-RPC error
+			return Response[json.RawMessage]{
+				JSONRPC: JSONRPCVersion,
+				ID:      req.ID,
+				Error:   jsonrpcErr,
+			}, nil
+		}
+		// Handler returned a standard error
+		return Response[json.RawMessage]{
+			JSONRPC: JSONRPCVersion,
+			ID:      req.ID,
+			Error: &JSONRPCError{
+				Code:    InternalError,
+				Message: err.Error(),
+			},
+		}, nil
+	}
+
+	// Marshal the result.
+	resultData, err := json.Marshal(result)
+	if err != nil {
+		return Response[json.RawMessage]{
+			JSONRPC: JSONRPCVersion,
+			ID:      req.ID,
+			Error: &JSONRPCError{
+				Code:    InternalError,
+				Message: fmt.Sprintf("Error marshaling result: %v", err),
+			},
+		}, nil
+	}
+
+	// Return the response with the marshaled result
+	return Response[json.RawMessage]{
+		JSONRPC: JSONRPCVersion,
+		ID:      req.ID,
+		Result:  json.RawMessage(resultData),
+	}, nil
+}
+
+// GetTypes returns the reflect.Type of the input and output types.
+func (m *MethodHandler[I, O]) GetTypes() (reflect.Type, reflect.Type) {
+	iType := reflect.TypeOf((*I)(nil)).Elem()
+	oType := reflect.TypeOf((*O)(nil)).Elem()
+	return iType, oType
+}
diff --git a/jsonrpc/handler_notification.go b/jsonrpc/handler_notification.go
new file mode 100644
index 00000000..5b034d85
--- /dev/null
+++ b/jsonrpc/handler_notification.go
@@ -0,0 +1,51 @@
+package jsonrpc
+
+import (
+	"context"
+	"encoding/json"
+	"reflect"
+)
+
+// NotificationHandler is an RPC handler for methods that do not expect a response.
+//
+// Usage Scenarios:
+//
+//  1. Compulsory Parameters:
+//     Use concrete types for I when input is required.
+//
+//  2. Optional Input Parameters:
+//     Use a pointer type for I to allow passing nil when no input is provided.
+//
+//  3. No Input Parameters:
+//     Use struct{} for I when the handler does not require any input.
+//
+// Example:
+//
+//	// Handler with no input
+//	handler := NotificationHandler[struct{}]{
+//		Endpoint: func(ctx context.Context, _ struct{}) error {
+//			// Implementation
+//			return nil
+//		},
+//	}
+type NotificationHandler[I any] struct {
+	Endpoint func(ctx context.Context, params I) error
+}
+
+// Handle decodes req's params into I and calls Endpoint. Any error is returned to the
+// caller for server-side use only, since notifications do not get a response.
+func (n *NotificationHandler[I]) Handle(ctx context.Context, req Request[json.RawMessage]) error {
+	params, err := unmarshalParams[I](req)
+	if err != nil {
+		// Cannot send error to client in notification; possibly log internally
+		return err
+	}
+
+	// Call the endpoint
+	return n.Endpoint(ctx, params)
+}
+
+// GetTypes returns the reflect.Type of the params type I.
+func (n *NotificationHandler[I]) GetTypes() reflect.Type {
+	return reflect.TypeOf((*I)(nil)).Elem()
+}
diff --git a/jsonrpc/handler_test.go b/jsonrpc/handler_test.go
new file mode 100644
index 00000000..b04f0e12
--- /dev/null
+++ b/jsonrpc/handler_test.go
@@ -0,0 +1,589 @@
+package jsonrpc
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"testing"
+)
+
+// AddParams defines the parameters for the "add" method
+type AddParams struct {
+	A int `json:"a"`
+	B int `json:"b"`
+}
+
+type AddResult struct {
+	Sum int `json:"sum"`
+}
+
+type NotifyParams struct {
+	Message string `json:"message"`
+}
+
+// ConcatParams defines the parameters for the "concat" method
+type ConcatParams struct {
+	S1 string `json:"s1"`
+	S2 string `json:"s2"`
+}
+
+// PingParams defines the parameters for the "ping" notification
+type PingParams struct {
+	Message string `json:"message"`
+}
+
+// AddEndpoint is the handler for the "add" method
+func AddEndpoint(ctx context.Context, params AddParams) (AddResult, error) {
+	res := params.A + params.B
+	return AddResult{Sum: res}, nil
+}
+
+// ConcatEndpoint is the handler for the "concat" method
+func ConcatEndpoint(ctx context.Context, params ConcatParams) (string, error) {
+	return params.S1 + params.S2, nil
+}
+
+// PingEndpoint is the handler for the "ping" notification
+func PingEndpoint(ctx context.Context, params PingParams) error {
+	return nil
+}
+
+func NotifyEndpoint(ctx context.Context, params NotifyParams) error {
+	// Process notification
+	return nil
+}
+
+func TestGetMetaRequestHandler(t *testing.T) {
+	// Define method maps
+	methodMap :=
map[string]IMethodHandler{ + "add": &MethodHandler[AddParams, AddResult]{Endpoint: AddEndpoint}, + "addErrorSimple": &MethodHandler[AddParams, AddResult]{ + Endpoint: func(ctx context.Context, params AddParams) (AddResult, error) { + return AddResult{}, errors.New("intentional error") + }, + }, + "addErrorJSONRPC": &MethodHandler[AddParams, AddResult]{ + Endpoint: func(ctx context.Context, params AddParams) (AddResult, error) { + return AddResult{}, &JSONRPCError{ + Code: 1234, + Message: "Custom error", + } + }, + }, + "concat": &MethodHandler[ConcatParams, string]{Endpoint: ConcatEndpoint}, + } + + notificationMap := map[string]INotificationHandler{ + "ping": &NotificationHandler[PingParams]{Endpoint: PingEndpoint}, + "notify": &NotificationHandler[NotifyParams]{ + Endpoint: NotifyEndpoint, + }, + "errornotify": &NotificationHandler[NotifyParams]{ + Endpoint: func(ctx context.Context, params NotifyParams) error { + return errors.New("processing error") + }, + }, + } + + // Define test cases + tests := []struct { + name string + metaReq *MetaRequest + expectedResp *MetaResponse + }{ + { + name: "Nil MetaRequest", + metaReq: nil, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: nil, + Error: &JSONRPCError{ + Code: ParseError, + Message: "No input received for", + }, + }}, + }, + }, + }, + { + name: "Empty Body Items", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{}, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: nil, + Error: &JSONRPCError{ + Code: ParseError, + Message: "No input received for", + }, + }}, + }, + }, + }, + { + name: "Invalid JSON-RPC version", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: 
[]Request[json.RawMessage]{ + { + JSONRPC: "1.0", + Method: "add", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: &RequestID{Value: 1}, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Error: &JSONRPCError{ + Code: InvalidRequestError, + Message: "Invalid JSON-RPC version: '1.0'", + }, + }}, + }, + }, + }, + { + name: "Invalid notification method", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + Method: "unknown_notification", + Params: json.RawMessage(`{}`), + ID: nil, + }}, + }, + }, + expectedResp: nil, + }, + { + name: "Valid notification", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + Method: "ping", + Params: json.RawMessage(`{"message":"hello"}`), + ID: nil, + }}, + }, + }, + expectedResp: nil, // Notifications do not produce a response + }, + { + name: "Processing single notification", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "notify", + Params: json.RawMessage(`{"message":"Hello"}`), + ID: nil, // Notification + }, + }, + }, + }, + expectedResp: nil, + }, + { + name: "Invalid parameters in notification (unmarshaling fails)", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "notify", + Params: json.RawMessage(`{"message":123}`), + ID: nil, // Notification + }, + }, + }, + }, + expectedResp: nil, + }, + { + name: "Notify Endpoint returns an error", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ 
+ { + JSONRPC: JSONRPCVersion, + Method: "errornotify", + Params: json.RawMessage(`{"message":"Hello"}`), + ID: nil, // Notification + }, + }, + }, + }, + expectedResp: nil, + }, + { + name: "Processing batch of requests and notifications", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: true, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "add", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: &RequestID{Value: 1}, + }, + { + JSONRPC: JSONRPCVersion, + Method: "notify", + Params: json.RawMessage(`{"message":"Hello"}`), + ID: nil, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: true, + Items: []Response[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Result: json.RawMessage(`{"sum":3}`), + }, + // No response for notification + }, + }, + }, + }, + { + name: "Valid request to 'add' method", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + Method: "add", + Params: json.RawMessage(`{"a":2,"b":3}`), + ID: &RequestID{Value: 1}, + }}, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Result: json.RawMessage(`{"sum":5}`), + }}, + }, + }, + }, + { + name: "Method with missing method name", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: &RequestID{Value: 1}, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Error: &JSONRPCError{ + Code: InvalidRequestError, + 
Message: "Method name missing", + }, + }}, + }, + }, + }, + { + name: "Method not found", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + Method: "subtract", + Params: json.RawMessage(`{"a":5,"b":2}`), + ID: &RequestID{Value: 2}, + }}, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 2}, + Error: &JSONRPCError{ + Code: MethodNotFoundError, + Message: "Method 'subtract' not found", + }, + }}, + }, + }, + }, + { + name: "Method with invalid ID", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "add", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: nil, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: nil, + Error: &JSONRPCError{ + Code: InvalidRequestError, + Message: "Received no requestID for method: 'add'", + }, + }}, + }, + }, + }, + { + name: "Batch request with mixed valid and invalid methods", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: true, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "add", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: &RequestID{Value: 1}, + }, + { + JSONRPC: JSONRPCVersion, + Method: "concat", + Params: json.RawMessage(`{"s1":"hello","s2":"world"}`), + ID: &RequestID{Value: 2}, + }, + { + JSONRPC: JSONRPCVersion, + Method: "subtract", + Params: json.RawMessage(`{"a":5,"b":3}`), + ID: &RequestID{Value: 3}, + }, + { + JSONRPC: JSONRPCVersion, + Method: "ping", + Params: json.RawMessage(`{"message":"ping"}`), + ID: nil, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: 
&Meta[Response[json.RawMessage]]{ + IsBatch: true, + Items: []Response[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Result: json.RawMessage(`{"sum":3}`), + }, + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 2}, + Result: json.RawMessage(`"helloworld"`), + }, + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 3}, + Error: &JSONRPCError{ + Code: MethodNotFoundError, + Message: "Method 'subtract' not found", + }, + }, + }, + }, + }, + }, + { + name: "Method request with invalid parameters", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "add", + Params: json.RawMessage(`{"a":"one","b":2}`), + ID: &RequestID{Value: 1}, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Error: &JSONRPCError{ + Code: InvalidParamsError, + Message: "Invalid parameters: json: cannot unmarshal string into Go struct field AddParams.a of type int", + }, + }, + }, + }, + }, + }, + { + name: "Method endpoint returns simple error", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + Method: "addErrorSimple", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: &RequestID{Value: 1}, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Error: &JSONRPCError{ + Code: InternalError, + Message: "intentional error", + }, + }, + }, + }, + }, + }, + { + name: "Method Endpoint returns a *jsonrpc.Error", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{ + { + JSONRPC: 
JSONRPCVersion, + Method: "addErrorJSONRPC", + Params: json.RawMessage(`{"a":1,"b":2}`), + ID: &RequestID{Value: 1}, + }, + }, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{ + { + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 1}, + Error: &JSONRPCError{ + Code: 1234, + Message: "Custom error", + }, + }, + }, + }, + }, + }, + { + name: "Handler returns an error", + metaReq: &MetaRequest{ + Body: &Meta[Request[json.RawMessage]]{ + IsBatch: false, + Items: []Request[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + Method: "add", + Params: json.RawMessage(`invalid`), + ID: &RequestID{Value: 4}, + }}, + }, + }, + expectedResp: &MetaResponse{ + Body: &Meta[Response[json.RawMessage]]{ + IsBatch: false, + Items: []Response[json.RawMessage]{{ + JSONRPC: JSONRPCVersion, + ID: &RequestID{Value: 4}, + Error: &JSONRPCError{ + Code: InvalidParamsError, + Message: "Invalid parameters: invalid character 'i' looking for beginning of value", + }, + }}, + }, + }, + }, + } + + handlerFunc := GetMetaRequestHandler(methodMap, notificationMap) + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp, err := handlerFunc(ctx, tt.metaReq) + if err != nil { + t.Errorf("handlerFunc returned error: %v", err) + } + eq, err := jsonStructEqual(tt.expectedResp, resp) + if err != nil { + t.Fatalf("Could not compare struct") + } + if !eq { + vals, err := getJSONStrings(tt.expectedResp, resp) + if err != nil { + t.Fatalf("Could not encode json") + } + t.Errorf("Expected response %#v, got %#v", vals[0], vals[1]) + } + }) + } +} diff --git a/jsonrpc/helpers.go b/jsonrpc/helpers.go new file mode 100644 index 00000000..0fa9133b --- /dev/null +++ b/jsonrpc/helpers.go @@ -0,0 +1,72 @@ +package jsonrpc + +import ( + "context" + "encoding/json" + "fmt" +) + +type contextKey string + +const ( + ctxKeyRequestID contextKey = "jsonrpcRequestID" + ctxKeyMethodName contextKey 
= "jsonrpcMethodName" + ctxKeyIsNotification contextKey = "jsonrpcIsNotification" +) + +// GetRequestID retrieves the RequestID from the context. +func GetRequestID(ctx context.Context) (RequestID, bool) { + id, ok := ctx.Value(ctxKeyRequestID).(RequestID) + return id, ok +} + +// GetMethodName retrieves the MethodName from the context. +func GetMethodName(ctx context.Context) (string, bool) { + method, ok := ctx.Value(ctxKeyMethodName).(string) + return method, ok +} + +// IsNotification checks if the request is a notification. +func IsNotification(ctx context.Context) bool { + isNotification, ok := ctx.Value(ctxKeyIsNotification).(bool) + return ok && isNotification +} + +// Helper function to create context with request information. +func contextWithRequestInfo( + parentCtx context.Context, + methodName string, + isNotification bool, + requestID *RequestID, +) context.Context { + ctx := context.WithValue(parentCtx, ctxKeyMethodName, methodName) + ctx = context.WithValue(ctx, ctxKeyIsNotification, isNotification) + if !isNotification && requestID != nil { + ctx = context.WithValue(ctx, ctxKeyRequestID, *requestID) + } + return ctx +} + +// Helper function to unmarshal parameters from the request. 
+func unmarshalParams[I any](req Request[json.RawMessage]) (I, error) { + var params I + if req.Params == nil { + return params, nil + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return params, err + } + return params, nil +} + +// Helper function to create an InvalidParamsError response +func invalidParamsResponse(req Request[json.RawMessage], err error) Response[json.RawMessage] { + return Response[json.RawMessage]{ + JSONRPC: JSONRPCVersion, + ID: req.ID, + Error: &JSONRPCError{ + Code: InvalidParamsError, + Message: fmt.Sprintf("Invalid parameters: %v", err), + }, + } +} diff --git a/jsonrpc/openapi.go b/jsonrpc/openapi.go new file mode 100644 index 00000000..553eddff --- /dev/null +++ b/jsonrpc/openapi.go @@ -0,0 +1,254 @@ +package jsonrpc + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + + "github.com/danielgtaylor/huma/v2" +) + +type RequestAny = Request[any] +type NotificationAny = Notification[any] + +func getTypeSchema( + api huma.API, + methodName string, + mtype reflect.Type, + suffix string, +) *huma.Schema { + hint := methodName + suffix + inputSchema := api.OpenAPI().Components.Schemas.Schema(mtype, true, hint) + if inputSchema.Ref != "" { + inputSubSchema := api.OpenAPI().Components.Schemas.SchemaFromRef(inputSchema.Ref) + inputSubSchema.Title = inputSchema.Ref[strings.LastIndex(inputSchema.Ref, "/")+1:] + } else if hint != "" { + // Base types + // E.g: For string param huma name will be String and hint will be the above. 
+ // For Array + humaName := huma.DefaultSchemaNamer(mtype, hint) + titlecaseHint := strings.ToUpper(string(hint[0])) + hint[1:] + if titlecaseHint != humaName { + inputSchema.Title = titlecaseHint + " - " + humaName + } else { + inputSchema.Title = titlecaseHint + } + } + + return inputSchema +} + +func isNillableType(t reflect.Type) bool { + switch t.Kind() { + case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: + return true + default: + return false + } +} + +// Function to dynamically create the Request type with Params of type iType +func getRequestType(iType reflect.Type, isNotification bool) reflect.Type { + // Get the reflect.Type of Request[any] + requestAnyType := reflect.TypeOf(RequestAny{}) + if isNotification { + requestAnyType = reflect.TypeOf(NotificationAny{}) + } + + // Get the number of fields in the Request struct + numFields := requestAnyType.NumField() + + // Create a slice to hold the StructField definitions + fields := make([]reflect.StructField, numFields) + + // Iterate over each field in the Request struct + for i := 0; i < numFields; i++ { + field := requestAnyType.Field(i) // Get the field + + // If the field is 'Params', replace its type with iType + if field.Name == "Params" { + field.Type = iType + jsonTag := field.Tag.Get("json") + + // If iType is pointer type, add omitempty to json tag + if isNillableType(iType) { + if !strings.Contains(jsonTag, "omitempty") { + jsonTag += ",omitempty" + } + // Update the field's tag with the modified JSON and required tags + field.Tag = reflect.StructTag( + fmt.Sprintf(`json:"%s"`, jsonTag), + ) + } else { + // If iType is not a pointer, add required:true to required tag + // Update the field's tag with the modified JSON and required tags + field.Tag = reflect.StructTag( + fmt.Sprintf(`json:"%s" required:"true"`, jsonTag), + ) + } + } + + // Add the field to the fields slice + fields[i] = field + } + // Create a new struct type with the updated fields + 
reqType := reflect.StructOf(fields) + return reqType +} + +func getRequestSchema( + api huma.API, + methodName string, + paramType reflect.Type, + isNotification bool, +) *huma.Schema { + newReqType := getRequestType(paramType, isNotification) + reqSchema := getTypeSchema(api, methodName, newReqType, "Request") + if reqSchema.Properties == nil { + reqSchema.Properties = make(map[string]*huma.Schema) + } + if reqSchema.Ref != "" { + reqSubSchema := api.OpenAPI().Components.Schemas.SchemaFromRef(reqSchema.Ref) + // Set method name as a constant in the schema + reqSubSchema.Properties["method"] = &huma.Schema{ + Type: "string", + Enum: []interface{}{methodName}, + } + if !isNotification { + reqSubSchema.Required = append(reqSubSchema.Required, "id") + } + } + return reqSchema +} + +func getResponseSchema( + api huma.API, + methodName string, + paramType reflect.Type, +) *huma.Schema { + // Get the error type used in your application + errorType := reflect.TypeOf(JSONRPCError{}) + + // Create dynamic types for success and error responses + successResponseType := getSuccessResponseType(paramType) + errorResponseType := getErrorResponseType(errorType) + + // Generate schemas for these dynamic types + successSchema := getTypeSchema( + api, + methodName, + successResponseType, + "SuccessResponse", + ) + errorSchema := getTypeSchema(api, methodName, errorResponseType, "ErrorResponse") + + // Build the response schema with OneOf combining the two schemas + responseSchema := &huma.Schema{ + Title: strings.ToUpper(string(methodName[0])) + methodName[1:] + "Response", + OneOf: []*huma.Schema{ + successSchema, + errorSchema, + }, + } + + return responseSchema +} + +// Function to create the success response type dynamically +func getSuccessResponseType(resultType reflect.Type) reflect.Type { + fields := []reflect.StructField{ + { + Name: "Jsonrpc", + Type: reflect.TypeOf(""), + Tag: `json:"jsonrpc"`, + }, + { + Name: "Id", + Type: reflect.TypeOf((*IntString)(nil)).Elem(), + Tag: 
`json:"id"`, + }, + } + var resultField reflect.StructField + resultField.Name = "Result" + resultField.Type = resultType + + if isNillableType(resultType) { + // If resultType is a pointer, add omitempty to json tag + resultField.Tag = reflect.StructTag(`json:"result,omitempty"`) + } else { + resultField.Tag = reflect.StructTag(`json:"result" required:"true"`) + } + + fields = append(fields, resultField) + + return reflect.StructOf(fields) +} + +// Function to create the error response type dynamically +func getErrorResponseType(errorType reflect.Type) reflect.Type { + fields := []reflect.StructField{ + { + Name: "Jsonrpc", + Type: reflect.TypeOf(""), + Tag: `json:"jsonrpc"`, + }, + { + Name: "Id", + Type: reflect.TypeOf((*IntString)(nil)).Elem(), + Tag: `json:"id"`, + }, + { + Name: "Error", + Type: errorType, + Tag: `json:"error"`, + }, + } + return reflect.StructOf(fields) +} + +func AddSchemasToAPI( + api huma.API, + methodMap map[string]IMethodHandler, + notificationMap map[string]INotificationHandler, +) { + reqSchemas := make([]*huma.Schema, 0, len(methodMap)+len(notificationMap)) + resSchemas := make([]*huma.Schema, 0, len(methodMap)) + + // Process method handlers + for methodName, handler := range methodMap { + inputType, outputType := handler.GetTypes() + + reqSchema := getRequestSchema(api, methodName, inputType, false) + reqSchemas = append(reqSchemas, reqSchema) + + respSchema := getResponseSchema(api, methodName, outputType) + resSchemas = append(resSchemas, respSchema) + } + + // Process notification handlers + for methodName, handler := range notificationMap { + inputType := handler.GetTypes() + reqSchema := getRequestSchema(api, methodName, inputType, true) + reqSchemas = append(reqSchemas, reqSchema) + } + + // Get base Request[json.RawMessage] and Response[json.RawMessage] schemas + reqType := reflect.TypeOf((*Request[json.RawMessage])(nil)).Elem() + baseReqSchema := api.OpenAPI().Components.Schemas.Schema(reqType, false, "") + 
baseReqSchema.OneOf = reqSchemas + // Delete properties + baseReqSchema.Properties = make(map[string]*huma.Schema) + baseReqSchema.Required = []string{} + baseReqSchema.AdditionalProperties = true + baseReqSchema.Type = "" + + respType := reflect.TypeOf((*Response[json.RawMessage])(nil)).Elem() + baseRespSchema := api.OpenAPI().Components.Schemas.Schema(respType, false, "") + baseRespSchema.OneOf = resSchemas + // Delete properties + baseRespSchema.Properties = make(map[string]*huma.Schema) + baseRespSchema.Required = []string{} + baseRespSchema.AdditionalProperties = true + baseRespSchema.Type = "" +} diff --git a/jsonrpc/register.go b/jsonrpc/register.go new file mode 100644 index 00000000..b18fc303 --- /dev/null +++ b/jsonrpc/register.go @@ -0,0 +1,149 @@ +package jsonrpc + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/danielgtaylor/huma/v2" +) + +// GetDefaultOperation gets the conventional values for jsonrpc as a single operation +func GetDefaultOperation() huma.Operation { + + return huma.Operation{ + Method: http.MethodPost, + Path: "/jsonrpc", + DefaultStatus: 200, + + Tags: []string{"JSONRPC"}, + Summary: "JSONRPC endpoint", + Description: "Serve all jsonrpc methods", + OperationID: "jsonrpc", + } +} + +// GetErrorHandler is a closure returning a function that converts any errors returned into a JSONRPC error +// response object. It implements the huma StatusError interface. +// IF the JSONRPC handler is invoked, it should never throw an error, but should return a error response object. 
+// JSONRPC requires a error case to be covered via the specifications error response object +func GetErrorHandler( + methodMap map[string]IMethodHandler, + notificationMap map[string]INotificationHandler, +) func(status int, message string, errs ...error) huma.StatusError { + return func(gotStatus int, gotMessage string, errs ...error) huma.StatusError { + var foundJSONRPCError *JSONRPCError + message := gotMessage + details := make([]string, 0) + details = append(details, "Message: "+gotMessage) + // Add the HTTP status to details and set status sent back as 200 + details = append(details, fmt.Sprintf("HTTP Status:%d", gotStatus)) + status := 200 + + code := InternalError + if gotStatus >= 400 && gotStatus < 500 { + code = InvalidRequestError + message = errorMessage[InvalidRequestError] + } + + for _, err := range errs { + if converted, ok := err.(huma.ErrorDetailer); ok { + d := converted.ErrorDetail() + // See if this is parse error + if strings.Contains(d.Message, "unmarshal") || + strings.Contains(d.Message, "invalid character") || + strings.Contains(d.Message, "unexpected end") { + code = ParseError + message = errorMessage[ParseError] + } + } else if jsonRPCError, ok := err.(JSONRPCError); ok { + // Check if the error is of type JSONRPCError + foundJSONRPCError = &jsonRPCError + } + details = append(details, err.Error()) + } + + // If a JSONRPCError was found, update the message and append JSON-encoded details + if foundJSONRPCError != nil { + message = foundJSONRPCError.Message + code = foundJSONRPCError.Code + + // JSON encode the Data field of the found JSONRPCError + if jsonData, err := json.Marshal(foundJSONRPCError.Data); err == nil { + details = append(details, string(jsonData)) + } + } + + // Check for method not found + if gotMessage == "validation failed" { + // Assume that the method name is in one of the error messages + // Look for "method:" + var methodName string + for _, errMsg := range details { + idx := strings.Index(errMsg, "method:") + 
if idx != -1 { + // Extract method name up to the next space or bracket or end of string + rest := errMsg[idx+len("method:"):] + endIdx := strings.IndexFunc(rest, func(r rune) bool { + return r == ' ' || r == ']' || r == ')' + }) + if endIdx == -1 { + methodName = rest + } else { + methodName = rest[:endIdx] + } + break + } + } + // Check if method exists in methodMap or notificationMap + if methodName != "" { + if _, exists := methodMap[methodName]; !exists { + if _, exists := notificationMap[methodName]; !exists { + // Method not found + code = MethodNotFoundError // You need to define this constant + message = fmt.Sprintf("Method '%s' not found", methodName) + } + } + } + } + + return &ResponseStatusError{ + status: status, + Response: Response[any]{ + JSONRPC: JSONRPCVersion, + ID: nil, + Error: &JSONRPCError{ + Code: code, + Message: message, + Data: details, + }, + }, + } + } +} + +// Register a new JSONRPC operation. +// The `methodMap` maps from method name to request handlers. Request clients expect a response object +// The `notificationMap` maps from method name to notification handlers. 
Notification clients do not expect a response +// +// These maps can be instantiated as +// +// methodMap := map[string]jsonrpc.IMethodHandler{ +// "add": &jsonrpc.MethodHandler[AddParams, int]{Endpoint: AddEndpoint}, +// } +// +// notificationMap := map[string]jsonrpc.INotificationHandler{ +// "log": &jsonrpc.NotificationHandler[LogParams]{Endpoint: LogEndpoint}, +// } +func Register( + api huma.API, + op huma.Operation, + methodMap map[string]IMethodHandler, + notificationMap map[string]INotificationHandler, +) { + AddSchemasToAPI(api, methodMap, notificationMap) + huma.NewError = GetErrorHandler(methodMap, notificationMap) + reqHandler := GetMetaRequestHandler(methodMap, notificationMap) + huma.Register(api, op, reqHandler) +} diff --git a/jsonrpc/type_intstring.go b/jsonrpc/type_intstring.go new file mode 100644 index 00000000..badafed8 --- /dev/null +++ b/jsonrpc/type_intstring.go @@ -0,0 +1,102 @@ +package jsonrpc + +import ( + "encoding/json" + + "errors" + + "github.com/danielgtaylor/huma/v2" +) + +type IntString struct { + Value interface{} +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (is *IntString) UnmarshalJSON(data []byte) error { + if string(data) == "null" { + // If the input is "null", return an error for non-pointer types + // (UnmarshalJSON is called only for non-pointer types in this case) + return errors.New("IntString cannot be null") + } + + // Try to unmarshal data into an int + var intValue int + if err := json.Unmarshal(data, &intValue); err == nil { + is.Value = intValue + return nil + } + + // Try to unmarshal data into a string + var strValue string + if err := json.Unmarshal(data, &strValue); err == nil { + is.Value = strValue + return nil + } + + // If neither int nor string, return an error + return errors.New("IntString must be a string or an integer") +} + +// MarshalJSON implements the json.Marshaler interface. 
+func (is IntString) MarshalJSON() ([]byte, error) {
+	switch v := is.Value.(type) {
+	case int:
+		return json.Marshal(v)
+	case string:
+		return json.Marshal(v)
+	default:
+		return nil, errors.New("IntString contains unsupported type")
+	}
+}
+
+func (is IntString) Schema(r huma.Registry) *huma.Schema {
+	return &huma.Schema{
+		OneOf: []*huma.Schema{
+			{Type: huma.TypeInteger},
+			{Type: huma.TypeString},
+		},
+	}
+}
+
+// IsInt reports whether the held value is an int.
+func (is IntString) IsInt() bool {
+	_, ok := is.Value.(int)
+	return ok
+}
+
+func (is IntString) IsString() bool {
+	_, ok := is.Value.(string)
+	return ok
+}
+
+func (is IntString) IntValue() (int, bool) {
+	v, ok := is.Value.(int)
+	return v, ok
+}
+
+func (is IntString) StringValue() (string, bool) {
+	v, ok := is.Value.(string)
+	return v, ok
+}
+
+func (is *IntString) Equal(other *IntString) bool {
+	// Two nil pointers are equal; a nil and a non-nil pointer are not.
+	if is == nil && other == nil {
+		return true
+	}
+	if is == nil || other == nil {
+		return false
+	}
+	// Equal only when both the dynamic type (int/string) and the value match.
+	switch v := is.Value.(type) {
+	case int:
+		ov, ok := other.Value.(int)
+		return ok && v == ov
+	case string:
+		ov, ok := other.Value.(string)
+		return ok && v == ov
+	default:
+		return false
+	}
+}
diff --git a/jsonrpc/type_intstring_test.go b/jsonrpc/type_intstring_test.go
new file mode 100644
index 00000000..5fe522d2
--- /dev/null
+++ b/jsonrpc/type_intstring_test.go
@@ -0,0 +1,410 @@
+package jsonrpc
+
+import (
+	"encoding/json"
+	"reflect"
+	"strings"
+	"testing"
+)
+
+func TestIntString_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name       string
+		input      string
+		wantValue  interface{}
+		wantIsInt  bool
+		wantErr    bool
+		wantErrMsg string
+	}{
+		{
+			name:      "Valid integer",
+			input:     `123`,
+			wantValue: 123,
+			wantIsInt: true,
+			wantErr:   false,
+		},
+		{
+			name:      "Valid string",
+			input:     `"hello"`,
+			wantValue: "hello",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:       "Invalid type (float)",
+			input:      `123.45`,
+			wantErr:    true,
+			wantErrMsg: "IntString must be a string or an integer",
+		},
+		{
+			name:       "Invalid type (boolean)",
+			input:      `true`,
+			wantErr:    true,
+			wantErrMsg: "IntString must be a string or an integer",
+		},
+		{
+			name:       "Null value",
+			input:      `null`,
+			wantErr:    true,
+			wantErrMsg: "IntString cannot be null",
+		},
+		{
+			name:       "Invalid JSON",
+			input:      `{}`,
+			wantErr:    true,
+			wantErrMsg: "IntString must be a string or an integer",
+		},
+		{
+			name:      "Empty string",
+			input:     `""`,
+			wantValue: "",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:      "Negative integer",
+			input:     `-42`,
+			wantValue: -42,
+			wantIsInt: true,
+			wantErr:   false,
+		},
+		{
+			name:      "Zero integer",
+			input:     `0`,
+			wantValue: 0,
+			wantIsInt: true,
+			wantErr:   false,
+		},
+		{
+			name:      "String containing number",
+			input:     `"123"`,
+			wantValue: "123",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:      "String containing special characters",
+			input:     `"special_chars!@#$%^&*()"`,
+			wantValue: "special_chars!@#$%^&*()",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:      "String containing special characters html escaped",
+			input:     `"special_chars!@#$%^\u0026*()"`,
+			wantValue: "special_chars!@#$%^&*()",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:      "Whitespace string",
+			input:     `" "`,
+			wantValue: " ",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:      "Unicode string",
+			input:     `"こんにちは"`,
+			wantValue: "こんにちは",
+			wantIsInt: false,
+			wantErr:   false,
+		},
+		{
+			name:       "Invalid JSON (missing quotes)",
+			input:      `hello`,
+			wantErr:    true,
+			wantErrMsg: "invalid character",
+		},
+		{
+			name:       "Array input",
+			input:      `["hello", 123]`,
+			wantErr:    true,
+			wantErrMsg: "IntString must be a string or an integer",
+		},
+		{
+			name:       "Object input",
+			input:      `{"key": "value"}`,
+			wantErr:    true,
+			wantErrMsg: "IntString must be a string or an integer",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var is IntString
+			err := json.Unmarshal([]byte(tt.input), &is)
+			if (err != nil) != tt.wantErr {
+				t.Logf("Got is: %v", is)
+				t.Fatalf("IntString.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"IntString.UnmarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+
+			if !reflect.DeepEqual(is.Value, tt.wantValue) {
+				t.Errorf("IntString.Value = %v, want %v", is.Value, tt.wantValue)
+			}
+
+			if is.IsInt() != tt.wantIsInt {
+				t.Errorf("IntString.IsInt() = %v, want %v", is.IsInt(), tt.wantIsInt)
+			}
+
+			if is.IsString() != !tt.wantIsInt {
+				t.Errorf("IntString.IsString() = %v, want %v", is.IsString(), !tt.wantIsInt)
+			}
+		})
+	}
+}
+
+func TestIntString_MarshalJSON(t *testing.T) {
+	tests := []struct {
+		name       string
+		value      interface{}
+		wantOutput string
+		wantErr    bool
+		wantErrMsg string
+	}{
+		{
+			name:       "Integer value",
+			value:      123,
+			wantOutput: `123`,
+			wantErr:    false,
+		},
+		{
+			name:       "String value",
+			value:      "hello",
+			wantOutput: `"hello"`,
+			wantErr:    false,
+		},
+		{
+			name:       "Unsupported type (float)",
+			value:      123.45,
+			wantErr:    true,
+			wantErrMsg: "IntString contains unsupported type",
+		},
+		{
+			name:       "Unsupported type (boolean)",
+			value:      true,
+			wantErr:    true,
+			wantErrMsg: "IntString contains unsupported type",
+		},
+		{
+			name:       "Nil value",
+			value:      nil,
+			wantErr:    true,
+			wantErrMsg: "IntString contains unsupported type",
+		},
+		{
+			name:       "Empty string",
+			value:      "",
+			wantOutput: `""`,
+			wantErr:    false,
+		},
+		{
+			name:       "Negative integer",
+			value:      -42,
+			wantOutput: `-42`,
+			wantErr:    false,
+		},
+		{
+			name:       "Zero integer",
+			value:      0,
+			wantOutput: `0`,
+			wantErr:    false,
+		},
+		{
+			name:       "String containing number",
+			value:      "123",
+			wantOutput: `"123"`,
+			wantErr:    false,
+		},
+		{
+			name:  "String containing special characters",
+			value: "special_chars!@#$%^&*()",
+			// Need html escaped output
+			wantOutput: `"special_chars!@#$%^\u0026*()"`,
+			wantErr:    false,
+		},
+		{
+			name:       "Whitespace string",
+			value:      " ",
+			wantOutput: `" "`,
+			wantErr:    false,
+		},
+		{
+			name:       "Unicode string",
+			value:      "こんにちは",
+			wantOutput: `"こんにちは"`,
+			wantErr:    false,
+		},
+		{
+			name:       "Unsupported type (slice)",
+			value:      []int{1, 2, 3},
+			wantErr:    true,
+			wantErrMsg: "IntString contains unsupported type",
+		},
+		{
+			name:       "Unsupported type (map)",
+			value:      map[string]string{"key": "value"},
+			wantErr:    true,
+			wantErrMsg: "IntString contains unsupported type",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			is := IntString{Value: tt.value}
+			data, err := json.Marshal(is)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("IntString.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"IntString.MarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+
+			if string(data) != tt.wantOutput {
+				t.Errorf(
+					"IntString.MarshalJSON() output = %s, want %s",
+					string(data),
+					tt.wantOutput,
+				)
+			}
+		})
+	}
+}
+
+func TestIntString_HelperMethods(t *testing.T) {
+	tests := []struct {
+		name         string
+		value        interface{}
+		wantIsInt    bool
+		wantIsString bool
+		wantIntValue int
+		wantStrValue string
+	}{
+		{
+			name:         "Integer value",
+			value:        123,
+			wantIsInt:    true,
+			wantIsString: false,
+			wantIntValue: 123,
+		},
+		{
+			name:         "String value",
+			value:        "hello",
+			wantIsInt:    false,
+			wantIsString: true,
+			wantStrValue: "hello",
+		},
+		{
+			name:         "Nil value",
+			value:        nil,
+			wantIsInt:    false,
+			wantIsString: false,
+		},
+		{
+			name:         "Unsupported type (float)",
+			value:        123.45,
+			wantIsInt:    false,
+			wantIsString: false,
+		},
+		{
+			name:         "Unsupported type (boolean)",
+			value:        true,
+			wantIsInt:    false,
+			wantIsString: false,
+		},
+		{
+			name:         "Negative integer value",
+			value:        -42,
+			wantIsInt:    true,
+			wantIsString: false,
+			wantIntValue: -42,
+		},
+		{
+			name:         "Empty string value",
+			value:        "",
+			wantIsInt:    false,
+			wantIsString: true,
+			wantStrValue: "",
+		},
+		{
+			name:         "String containing number",
+			value:        "123",
+			wantIsInt:    false,
+			wantIsString: true,
+			wantStrValue: "123",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			is := IntString{Value: tt.value}
+
+			if is.IsInt() != tt.wantIsInt {
+				t.Errorf("IntString.IsInt() = %v, want %v", is.IsInt(), tt.wantIsInt)
+			}
+
+			if is.IsString() != tt.wantIsString {
+				t.Errorf("IntString.IsString() = %v, want %v", is.IsString(), tt.wantIsString)
+			}
+
+			intVal, ok := is.IntValue()
+			if tt.wantIsInt != ok {
+				t.Errorf("IntString.IntValue() ok = %v, want %v", ok, tt.wantIsInt)
+			}
+			if ok && intVal != tt.wantIntValue {
+				t.Errorf("IntString.IntValue() = %v, want %v", intVal, tt.wantIntValue)
+			}
+
+			strVal, ok := is.StringValue()
+			if tt.wantIsString != ok {
+				t.Errorf("IntString.StringValue() ok = %v, want %v", ok, tt.wantIsString)
+			}
+			if ok && strVal != tt.wantStrValue {
+				t.Errorf("IntString.StringValue() = %v, want %v", strVal, tt.wantStrValue)
+			}
+		})
+	}
+}
+
+// Additional test to ensure proper error messages
+func TestIntString_ErrorMessages(t *testing.T) {
+	// Unmarshaling an array should return a specific error message
+	data := `["hello", 123]`
+	var is IntString
+	err := json.Unmarshal([]byte(data), &is)
+	if err == nil {
+		t.Fatalf("Expected error when unmarshaling array, but got none")
+	}
+	expectedErrMsg := "IntString must be a string or an integer"
+	if !strings.Contains(err.Error(), expectedErrMsg) {
+		t.Errorf("Error message = %v, want %v", err.Error(), expectedErrMsg)
+	}
+
+	// Marshaling an unsupported type should return a specific error message
+	is = IntString{Value: []int{1, 2, 3}}
+	_, err = json.Marshal(is)
+	if err == nil {
+		t.Fatalf("Expected error when marshaling unsupported type, but got none")
+	}
+	expectedErrMsg = "IntString contains unsupported type"
+	if !strings.Contains(err.Error(), expectedErrMsg) {
+		t.Errorf("Error message = %v, want %v", err.Error(), expectedErrMsg)
+	}
+}
diff --git a/jsonrpc/type_meta.go b/jsonrpc/type_meta.go
new file mode 100644
index 00000000..a5165bbf
--- /dev/null
+++ b/jsonrpc/type_meta.go
@@ -0,0 +1,156 @@
+package jsonrpc
+
+import (
+	"bytes"
+	"encoding/json"
+	"reflect"
+
+	"github.com/danielgtaylor/huma/v2"
+)
+
+// checkForEmptyOrNullData returns a ParseError JSONRPCError when data is
+// empty, whitespace-only, or the JSON literal null, and nil otherwise.
+func checkForEmptyOrNullData(data []byte) error {
+	data = bytes.TrimSpace(data)
+	if len(data) == 0 {
+		return &JSONRPCError{
+			Code:    ParseError,
+			Message: "Received empty data",
+		}
+	}
+	if bytes.Equal(data, []byte("null")) {
+		return &JSONRPCError{
+			Code:    ParseError,
+			Message: "Received null data",
+		}
+	}
+	return nil
+}
+
+// unmarshalMeta decodes data that is either a single JSON value or a JSON
+// array of values (a JSON-RPC batch), sets *isBatch accordingly and appends
+// every decoded element to *items. An empty array is accepted as a batch with
+// no items; a null element inside a batch is rejected. Every failure is
+// returned as a ParseError JSONRPCError; on failure *items may already hold
+// the batch elements decoded before the bad one.
+func unmarshalMeta[T any](data []byte, isBatch *bool, items *[]T) error {
+	if err := checkForEmptyOrNullData(data); err != nil {
+		return err
+	}
+
+	data = bytes.TrimSpace(data)
+	// Try to unmarshal into []json.RawMessage to detect if it's a batch.
+	// Anything that is not a well-formed array (including a truncated one)
+	// falls through to the single-item path below.
+	var rawMessages []json.RawMessage
+	if err := json.Unmarshal(data, &rawMessages); err == nil {
+		// Data is a batch
+		*isBatch = true
+		// Process each message in the batch, empty slice input is also ok and valid
+		for _, msg := range rawMessages {
+			// Empty or null single item also should not be present
+			if err := checkForEmptyOrNullData(msg); err != nil {
+				return err
+			}
+			var item T
+			if err := json.Unmarshal(msg, &item); err != nil {
+				return &JSONRPCError{
+					Code:    ParseError,
+					Message: "Failed to unmarshal batch item: " + err.Error(),
+				}
+			}
+			*items = append(*items, item)
+		}
+	} else {
+		var item T
+		if err := json.Unmarshal(data, &item); err != nil {
+			return &JSONRPCError{
+				Code:    ParseError,
+				Message: "Failed to unmarshal single item: " + err.Error(),
+			}
+		}
+		*isBatch = false
+		*items = append(*items, item)
+	}
+	return nil
+}
+
+// marshalMeta encodes items as a JSON array when isBatch is true and as the
+// first item alone otherwise. A non-batch with no items is an error.
+func marshalMeta[T any](isBatch bool, items []T) ([]byte, error) {
+	if isBatch {
+		if items == nil {
+			// A nil slice would encode as null, which unmarshalMeta rejects;
+			// encode an empty batch as [] so the output round-trips.
+			items = []T{}
+		}
+		return json.Marshal(items)
+	}
+	if len(items) > 0 {
+		return json.Marshal(items[0])
+	}
+	// Having nothing to encode is a local failure rather than malformed input
+	// received from a peer, so report InternalError instead of ParseError.
+	return nil, &JSONRPCError{Code: InternalError, Message: "Received empty input"}
+}
+
+// intPtr returns a pointer to a copy of i (used for optional schema bounds).
+func intPtr(i int) *int {
+	return &i
+}
+
+// Meta holds either a single JSON-RPC message or a batch of them. IsBatch
+// records which wire form was decoded (or is to be encoded) and Items holds
+// the message(s). It backs both MetaRequest and MetaResponse.
+type Meta[T any] struct {
+	IsBatch bool `json:"-"`
+	Items   []T
+}
+
+// UnmarshalJSON implements json.Unmarshaler for Meta[T]. Any previous
+// contents of Items are discarded before decoding.
+func (m *Meta[T]) UnmarshalJSON(data []byte) error {
+	m.Items = make([]T, 0)
+	return unmarshalMeta(data, &m.IsBatch, &m.Items)
+}
+
+// MarshalJSON implements json.Marshaler for Meta[T].
+func (m Meta[T]) MarshalJSON() ([]byte, error) {
+	return marshalMeta(m.IsBatch, m.Items)
+}
+
+// Schema describes Meta[T] to huma as either a single T or a non-empty
+// array of T, matching the two wire forms handled by UnmarshalJSON.
+//
+// NOTE(review): UnmarshalJSON accepts an empty batch ("[]") while this
+// schema requires at least one item; confirm which behavior is intended.
+func (m Meta[T]) Schema(r huma.Registry) *huma.Schema {
+	// reflect.TypeOf on the (possibly nil) typed slice still yields []T.
+	itemsType := reflect.TypeOf(m.Items)
+
+	// Get the type of the element T
+	elementType := itemsType.Elem()
+
+	// Use the elementType to get the schema
+	elementSchema := r.Schema(elementType, true, "")
+
+	s := &huma.Schema{
+		OneOf: []*huma.Schema{elementSchema, {
+			Type:     huma.TypeArray,
+			Items:    elementSchema,
+			MinItems: intPtr(1),
+		},
+		},
+	}
+	return s
+}
+
+// MetaRequest is a request body holding a single request or a batch.
+type MetaRequest struct {
+	Body *Meta[Request[json.RawMessage]]
+}
+
+// MetaResponse is a response body holding a single response or a batch.
+type MetaResponse struct {
+	Body *Meta[Response[json.RawMessage]]
+}
diff --git a/jsonrpc/type_meta_test.go b/jsonrpc/type_meta_test.go
new file mode 100644
index 00000000..105c6362
--- /dev/null
+++ b/jsonrpc/type_meta_test.go
@@ -0,0 +1,876 @@
+package jsonrpc
+
+import (
+	"encoding/json"
+	"reflect"
+	"strings"
+	"testing"
+)
+
+// MyData is a sample data structure for testing.
+type MyData struct {
+	Name  string `json:"name"`
+	Value int    `json:"value"`
+}
+
+// Test unmarshalMeta with Request[json.RawMessage]
+func TestUnmarshalMeta_Request(t *testing.T) {
+	tests := []struct {
+		name        string
+		data        []byte
+		wantIsBatch bool
+		wantItems   []Request[json.RawMessage]
+		wantErr     bool
+		wantErrMsg  string
+	}{
+		{
+			name:       "Empty input data",
+			data:       []byte{},
+			wantErr:    true,
+			wantErrMsg: "Received empty data",
+		},
+		{
+			name:        "Valid single request",
+			data:        []byte(`{"jsonrpc": "2.0", "method": "sum", "params": [1,2,3], "id":1}`),
+			wantIsBatch: false,
+			wantItems: []Request[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Method:  "sum",
+					Params:  json.RawMessage(`[1,2,3]`),
+					ID:      &RequestID{Value: 1},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name: "Valid batch requests",
+			data: []byte(
+				`[{"jsonrpc": "2.0", "method": "sum", "params": [1,2,3], "id":1}, {"jsonrpc": "2.0", "method": "subtract", "params": [42,23], "id":2}]`,
+			),
+			wantIsBatch: true,
+			wantItems: []Request[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Method:  "sum",
+					Params:  json.RawMessage(`[1,2,3]`),
+					ID:      &RequestID{Value: 1},
+				},
+				{
+					JSONRPC: "2.0",
+					Method:  "subtract",
+					Params:  json.RawMessage(`[42,23]`),
+					ID:      &RequestID{Value: 2},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name:       "Invalid JSON",
+			data:       []byte(`{this is not valid JSON}`),
+			wantErr:    true,
+			wantErrMsg: "Failed to unmarshal single item",
+		},
+		{
+			name:        "Empty batch",
+			data:        []byte(`[]`),
+			wantErr:     false,
+			wantIsBatch: true,
+			wantItems:   []Request[json.RawMessage]{},
+		},
+		{
+			name:       "Null input",
+			data:       []byte(`null`),
+			wantErr:    true,
+			wantErrMsg: "Received null data",
+		},
+		{
+			name:       "No input",
+			data:       []byte(``),
+			wantErr:    true,
+			wantErrMsg: "Received empty data",
+		},
+		{
+			name:       "Garbage data",
+			data:       []byte(`garbage data`),
+			wantErr:    true,
+			wantErrMsg: "Failed to unmarshal single item",
+		},
+		{
+			name:       "Whitespace input",
+			data:       []byte(" "),
+			wantErr:    true,
+			wantErrMsg: "Received empty data",
+		},
+		{
+			name:       "Only null byte",
+			data:       []byte("\x00"),
+			wantErr:    true,
+			wantErrMsg: "Failed to unmarshal single item",
+		},
+		{
+			name:       "Array with null",
+			data:       []byte(`[null]`),
+			wantErr:    true,
+			wantErrMsg: "Received null data",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var isBatch bool
+			var items []Request[json.RawMessage]
+			err := unmarshalMeta(tt.data, &isBatch, &items)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("unmarshalMeta() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"unmarshalMeta() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+			if isBatch != tt.wantIsBatch {
+				t.Errorf("unmarshalMeta() isBatch = %v, want %v", isBatch, tt.wantIsBatch)
+			}
+			if !compareRequestSlices(items, tt.wantItems) {
+				t.Errorf("unmarshalMeta() items = %+v, want %+v", items, tt.wantItems)
+			}
+		})
+	}
+}
+
+// Test marshalMeta with Request[json.RawMessage]
+func TestMarshalMeta_Request(t *testing.T) {
+	tests := []struct {
+		name       string
+		isBatch    bool
+		items      []Request[json.RawMessage]
+		wantData   string
+		wantErr    bool
+		wantErrMsg string
+	}{
+		{
+			name:    "Single item",
+			isBatch: false,
+			items: []Request[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Method:  "subtract",
+					Params:  json.RawMessage(`[42,23]`),
+					ID:      &RequestID{Value: 1},
+				},
+			},
+			wantData: `{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}`,
+			wantErr:  false,
+		},
+		{
+			name:    "Batch items",
+			isBatch: true,
+			items: []Request[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Method:  "sum",
+					Params:  json.RawMessage(`[1,2,3]`),
+					ID:      &RequestID{Value: 1},
+				},
+				{
+					JSONRPC: "2.0",
+					Method:  "subtract",
+					Params:  json.RawMessage(`[42,23]`),
+					ID:      &RequestID{Value: 2},
+				},
+			},
+			wantData: `[{"jsonrpc":"2.0","method":"sum","params":[1,2,3],"id":1},{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":2}]`,
+			wantErr:  false,
+		},
+		{
+			name:       "Empty items with isBatch=false",
+			isBatch:    false,
+			items:      []Request[json.RawMessage]{},
+			wantErr:    true,
+			wantErrMsg: "Received empty input",
+		},
+		{
+			name:     "Empty items with isBatch=true",
+			isBatch:  true,
+			items:    []Request[json.RawMessage]{},
+			wantData: `[]`,
+			wantErr:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			data, err := marshalMeta(tt.isBatch, tt.items)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("marshalMeta() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"marshalMeta() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+			if !jsonStringsEqual(string(data), tt.wantData) {
+				t.Errorf("marshalMeta() data = %s, want %s", string(data), tt.wantData)
+			}
+
+		})
+	}
+}
+
+// Test Meta[T] UnmarshalJSON and MarshalJSON with MyData
+func TestMeta_MyData(t *testing.T) {
+	tests := []struct {
+		name        string
+		jsonData    string
+		wantIsBatch bool
+		wantItems   []MyData
+		wantErr     bool
+		wantErrMsg  string
+	}{
+		{
+			name:        "Single item",
+			jsonData:    `{"name": "Item1", "value": 100}`,
+			wantIsBatch: false,
+			wantItems: []MyData{
+				{Name: "Item1", Value: 100},
+			},
+			wantErr: false,
+		},
+		{
+			name:        "Batch items",
+			jsonData:    `[{"name": "Item1", "value": 100}, {"name": "Item2", "value": 200}]`,
+			wantIsBatch: true,
+			wantItems: []MyData{
+				{Name: "Item1", Value: 100},
+				{Name: "Item2", Value: 200},
+			},
+			wantErr: false,
+		},
+		{
+			name:       "Invalid JSON",
+			jsonData:   `{"name": "Item1", "value": 100`,
+			wantErr:    true,
+			wantErrMsg: "unexpected end of JSON input",
+		},
+		{
+			name:       "Empty input",
+			jsonData:   ` `,
+			wantErr:    true,
+			wantErrMsg: "unexpected end of JSON input",
+		},
+		{
+			name:       "Invalid field type",
+			jsonData:   `{"name": "Item1", "value": "one hundred"}`,
+			wantErr:    true,
+			wantErrMsg: "Failed to unmarshal single item",
+		},
+		{
+			name:        "Empty batch",
+			jsonData:    `[]`,
+			wantIsBatch: true,
+			wantItems:   []MyData{},
+			wantErr:     false,
+		},
+		{
+			name:        "Valid empty object",
+			jsonData:    `{}`,
+			wantIsBatch: false,
+			wantItems:   []MyData{{}},
+			wantErr:     false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var meta Meta[MyData]
+			err := json.Unmarshal([]byte(tt.jsonData), &meta)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("Meta.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"Meta.UnmarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+			if meta.IsBatch != tt.wantIsBatch {
+				t.Errorf("Meta.UnmarshalJSON() IsBatch = %v, want %v", meta.IsBatch, tt.wantIsBatch)
+			}
+			if !reflect.DeepEqual(meta.Items, tt.wantItems) {
+				t.Errorf("Meta.UnmarshalJSON() Items = %#v, want %#v", meta.Items, tt.wantItems)
+			}
+		})
+	}
+}
+
+func TestMeta_MarshalJSON(t *testing.T) {
+	tests := []struct {
+		name       string
+		meta       Meta[MyData]
+		wantData   string
+		wantErr    bool
+		wantErrMsg string
+	}{
+		{
+			name: "Single item",
+			meta: Meta[MyData]{
+				IsBatch: false,
+				Items: []MyData{
+					{Name: "Item1", Value: 100},
+				},
+			},
+			wantData: `{"name":"Item1","value":100}`,
+			wantErr:  false,
+		},
+		{
+			name: "Batch items",
+			meta: Meta[MyData]{
+				IsBatch: true,
+				Items: []MyData{
+					{Name: "Item1", Value: 100},
+					{Name: "Item2", Value: 200},
+				},
+			},
+			wantData: `[{"name":"Item1","value":100},{"name":"Item2","value":200}]`,
+			wantErr:  false,
+		},
+		{
+			name: "Empty items with IsBatch=false",
+			meta: Meta[MyData]{
+				IsBatch: false,
+				Items:   []MyData{},
+			},
+			wantErr:    true,
+			wantErrMsg: "Received empty input",
+		},
+		{
+			name: "Empty items with IsBatch=true",
+			meta: Meta[MyData]{
+				IsBatch: true,
+				Items:   []MyData{},
+			},
+			wantData: `[]`,
+			wantErr:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			data, err := json.Marshal(&tt.meta)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("Meta.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"Meta.MarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+			if !jsonStringsEqual(string(data), tt.wantData) {
+				t.Errorf("Meta.MarshalJSON() data = %s, want %s", string(data), tt.wantData)
+			}
+		})
+	}
+}
+
+// Test MetaRequest UnmarshalJSON and MarshalJSON
+func TestMetaRequest_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name        string
+		jsonData    string
+		wantIsBatch bool
+		wantItems   []Request[json.RawMessage]
+		wantErr     bool
+		wantErrMsg  string
+	}{
+		{
+			name:        "Valid single request",
+			jsonData:    `{"jsonrpc":"2.0","method":"sum","params":[1,2,3],"id":1}`,
+			wantIsBatch: false,
+			wantItems: []Request[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Method:  "sum",
+					Params:  json.RawMessage(`[1,2,3]`),
+					ID:      &RequestID{Value: 1},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name:        "Valid batch requests",
+			jsonData:    `[{"jsonrpc":"2.0","method":"sum","params":[1,2,3],"id":1},{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":2}]`,
+			wantIsBatch: true,
+			wantItems: []Request[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Method:  "sum",
+					Params:  json.RawMessage(`[1,2,3]`),
+					ID:      &RequestID{Value: 1},
+				},
+				{
+					JSONRPC: "2.0",
+					Method:  "subtract",
+					Params:  json.RawMessage(`[42,23]`),
+					ID:      &RequestID{Value: 2},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name:       "Empty input",
+			jsonData:   ``,
+			wantErr:    true,
+			wantErrMsg: "unexpected end of JSON input",
+		},
+		{
+			name:       "Null input",
+			jsonData:   `null`,
+			wantErr:    true,
+			wantErrMsg: "Received null data",
+		},
+		{
+			name:        "Empty batch",
+			jsonData:    `[]`,
+			wantIsBatch: true,
+			wantItems:   []Request[json.RawMessage]{},
+			wantErr:     false,
+		},
+		{
+			name:       "Invalid JSON",
+			jsonData:   `{this is not valid JSON}`,
+			wantErr:    true,
+			wantErrMsg: "invalid character 't' looking for beginning of object key string",
+		},
+		{
+			name:       "Array with null",
+			jsonData:   `[null]`,
+			wantErr:    true,
+			wantErrMsg: "Received null data",
+		},
+		{
+			name:       "Whitespace input",
+			jsonData:   " ",
+			wantErr:    true,
+			wantErrMsg: "unexpected end of JSON input",
+		},
+		{
+			name:       "Garbage data",
+			jsonData:   `garbage data`,
+			wantErr:    true,
+			wantErrMsg: "invalid character 'g' looking for beginning of value",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var metaRequest MetaRequest
+			metaRequest.Body = &Meta[Request[json.RawMessage]]{}
+			err := json.Unmarshal([]byte(tt.jsonData), metaRequest.Body)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("MetaRequest.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"MetaRequest.UnmarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+			if metaRequest.Body.IsBatch != tt.wantIsBatch {
+				t.Errorf(
+					"MetaRequest.UnmarshalJSON() IsBatch = %v, want %v",
+					metaRequest.Body.IsBatch,
+					tt.wantIsBatch,
+				)
+			}
+			if !compareRequestSlices(metaRequest.Body.Items, tt.wantItems) {
+				t.Errorf(
+					"MetaRequest.UnmarshalJSON() Items = %+v, want %+v",
+					metaRequest.Body.Items,
+					tt.wantItems,
+				)
+			}
+		})
+	}
+}
+
+func TestMetaRequest_MarshalJSON(t *testing.T) {
+	tests := []struct {
+		name       string
+		meta       MetaRequest
+		wantData   string
+		wantErr    bool
+		wantErrMsg string
+	}{
+		{
+			name: "Single request",
+			meta: MetaRequest{
+				Body: &Meta[Request[json.RawMessage]]{
+					IsBatch: false,
+					Items: []Request[json.RawMessage]{
+						{
+							JSONRPC: "2.0",
+							Method:  "subtract",
+							Params:  json.RawMessage(`[42,23]`),
+							ID:      &RequestID{Value: 1},
+						},
+					},
+				},
+			},
+			wantData: `{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":1}`,
+			wantErr:  false,
+		},
+		{
+			name: "Batch requests",
+			meta: MetaRequest{
+				Body: &Meta[Request[json.RawMessage]]{
+					IsBatch: true,
+					Items: []Request[json.RawMessage]{
+						{
+							JSONRPC: "2.0",
+							Method:  "sum",
+							Params:  json.RawMessage(`[1,2,3]`),
+							ID:      &RequestID{Value: 1},
+						},
+						{
+							JSONRPC: "2.0",
+							Method:  "subtract",
+							Params:  json.RawMessage(`[42,23]`),
+							ID:      &RequestID{Value: 2},
+						},
+					},
+				},
+			},
+			wantData: `[{"jsonrpc":"2.0","method":"sum","params":[1,2,3],"id":1},{"jsonrpc":"2.0","method":"subtract","params":[42,23],"id":2}]`,
+			wantErr:  false,
+		},
+		{
+			name: "Empty items with IsBatch=false",
+			meta: MetaRequest{
+				Body: &Meta[Request[json.RawMessage]]{
+					IsBatch: false,
+					Items:   []Request[json.RawMessage]{},
+				},
+			},
+			wantErr:    true,
+			wantErrMsg: "Received empty input",
+		},
+		{
+			name: "Empty items with IsBatch=true",
+			meta: MetaRequest{
+				Body: &Meta[Request[json.RawMessage]]{
+					IsBatch: true,
+					Items:   []Request[json.RawMessage]{},
+				},
+			},
+			wantData: `[]`,
+			wantErr:  false,
+		},
+		{
+			name: "Nil Body",
+			meta: MetaRequest{
+				Body: nil,
+			},
+			wantErr:  false,
+			wantData: "null",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			data, err := json.Marshal(tt.meta.Body)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("MetaRequest.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"MetaRequest.MarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+
+			if !jsonStringsEqual(string(data), tt.wantData) {
+				t.Errorf("MetaRequest.MarshalJSON() data = %s, want %s", string(data), tt.wantData)
+			}
+		})
+	}
+}
+
+// Test MetaResponse UnmarshalJSON and MarshalJSON
+func TestMetaResponse_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name        string
+		jsonData    string
+		wantIsBatch bool
+		wantItems   []Response[json.RawMessage]
+		wantErr     bool
+		wantErrMsg  string
+	}{
+		{
+			name:        "Valid single response",
+			jsonData:    `{"jsonrpc":"2.0","result":7,"id":1}`,
+			wantIsBatch: false,
+			wantItems: []Response[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Result:  json.RawMessage(`7`),
+					ID:      &RequestID{Value: 1},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name:        "Valid batch responses",
+			jsonData:    `[{"jsonrpc":"2.0","result":7,"id":1},{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":2}]`,
+			wantIsBatch: true,
+			wantItems: []Response[json.RawMessage]{
+				{
+					JSONRPC: "2.0",
+					Result:  json.RawMessage(`7`),
+					ID:      &RequestID{Value: 1},
+				},
+				{
+					JSONRPC: "2.0",
+					Error: &JSONRPCError{
+						Code:    -32601,
+						Message: "Method not found",
+					},
+					ID: &RequestID{Value: 2},
+				},
+			},
+			wantErr: false,
+		},
+		{
+			name:       "Empty input",
+			jsonData:   ``,
+			wantErr:    true,
+			wantErrMsg: "unexpected end of JSON input",
+		},
+		{
+			name:       "Null input",
+			jsonData:   `null`,
+			wantErr:    true,
+			wantErrMsg: "Received null data",
+		},
+		{
+			name:        "Empty batch",
+			jsonData:    `[]`,
+			wantIsBatch: true,
+			wantItems:   []Response[json.RawMessage]{},
+			wantErr:     false,
+		},
+		{
+			name:       "Invalid JSON",
+			jsonData:   `{this is not valid JSON}`,
+			wantErr:    true,
+			wantErrMsg: "invalid character 't' looking for beginning of object key string",
+		},
+		{
+			name:       "Array with null",
+			jsonData:   `[null]`,
+			wantErr:    true,
+			wantErrMsg: "Received null data",
+		},
+		{
+			name:       "Whitespace input",
+			jsonData:   " ",
+			wantErr:    true,
+			wantErrMsg: "unexpected end of JSON input",
+		},
+		{
+			name:       "Garbage data",
+			jsonData:   `garbage data`,
+			wantErr:    true,
+			wantErrMsg: "invalid character 'g' looking for beginning of value",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var metaResponse MetaResponse
+			metaResponse.Body = &Meta[Response[json.RawMessage]]{}
+			err := json.Unmarshal([]byte(tt.jsonData), metaResponse.Body)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("MetaResponse.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if tt.wantErrMsg != "" && !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"MetaResponse.UnmarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+			if metaResponse.Body.IsBatch != tt.wantIsBatch {
+				t.Errorf(
+					"MetaResponse.UnmarshalJSON() IsBatch = %v, want %v",
+					metaResponse.Body.IsBatch,
+					tt.wantIsBatch,
+				)
+			}
+			eq, err := jsonStructEqual(metaResponse.Body.Items, tt.wantItems)
+			if err != nil || !eq {
+				t.Errorf(
+					"MetaResponse.UnmarshalJSON() Items = %+v, want %+v",
+					metaResponse.Body.Items,
+					tt.wantItems,
+				)
+			}
+		})
+	}
+}
+
+func TestMetaResponse_MarshalJSON(t *testing.T) {
+	tests := []struct {
+		name       string
+		meta       MetaResponse
+		wantData   string
+		wantErr    bool
+		wantErrMsg string
+	}{
+		{
+			name: "Single response with result",
+			meta: MetaResponse{
+				Body: &Meta[Response[json.RawMessage]]{
+					IsBatch: false,
+					Items: []Response[json.RawMessage]{
+						{
+							JSONRPC: "2.0",
+							Result:  json.RawMessage(`7`),
+							ID:      &RequestID{Value: 1},
+						},
+					},
+				},
+			},
+			wantData: `{"jsonrpc":"2.0","result":7,"id":1}`,
+			wantErr:  false,
+		},
+		{
+			name: "Single response with error",
+			meta: MetaResponse{
+				Body: &Meta[Response[json.RawMessage]]{
+					IsBatch: false,
+					Items: []Response[json.RawMessage]{
+						{
+							JSONRPC: "2.0",
+							Error: &JSONRPCError{
+								Code:    -32601,
+								Message: "Method not found",
+							},
+							ID: &RequestID{Value: 2},
+						},
+					},
+				},
+			},
+			wantData: `{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":2}`,
+			wantErr:  false,
+		},
+		{
+			name: "Batch responses",
+			meta: MetaResponse{
+				Body: &Meta[Response[json.RawMessage]]{
+					IsBatch: true,
+					Items: []Response[json.RawMessage]{
+						{
+							JSONRPC: "2.0",
+							Result:  json.RawMessage(`7`),
+							ID:      &RequestID{Value: 1},
+						},
+						{
+							JSONRPC: "2.0",
+							Error: &JSONRPCError{
+								Code:    -32601,
+								Message: "Method not found",
+							},
+							ID: &RequestID{Value: 2},
+						},
+					},
+				},
+			},
+			wantData: `[{"jsonrpc":"2.0","result":7,"id":1},{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":2}]`,
+			wantErr:  false,
+		},
+		{
+			name: "Empty items with IsBatch=false",
+			meta: MetaResponse{
+				Body: &Meta[Response[json.RawMessage]]{
+					IsBatch: false,
+					Items:   []Response[json.RawMessage]{},
+				},
+			},
+			wantErr:    true,
+			wantErrMsg: "Received empty input",
+		},
+		{
+			name: "Empty items with IsBatch=true",
+			meta: MetaResponse{
+				Body: &Meta[Response[json.RawMessage]]{
+					IsBatch: true,
+					Items:   []Response[json.RawMessage]{},
+				},
+			},
+			wantData: `[]`,
+			wantErr:  false,
+		},
+		{
+			name: "Nil Body",
+			meta: MetaResponse{
+				Body: nil,
+			},
+			wantData: "null",
+			wantErr:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			data, err := json.Marshal(tt.meta.Body)
+			if (err != nil) != tt.wantErr {
+				t.Fatalf("MetaResponse.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+			}
+			if err != nil {
+				if !strings.Contains(err.Error(), tt.wantErrMsg) {
+					t.Errorf(
+						"MetaResponse.MarshalJSON() error message = %v, want %v",
+						err.Error(),
+						tt.wantErrMsg,
+					)
+				}
+				return
+			}
+
+			if !jsonStringsEqual(string(data), tt.wantData) {
+				t.Errorf("MetaResponse.MarshalJSON() data = %s, want %s", string(data), tt.wantData)
+			}
+		})
+	}
+}
diff --git a/jsonrpc/type_spec.go b/jsonrpc/type_spec.go
new file mode 100644
index 00000000..52432a04
--- /dev/null
+++ b/jsonrpc/type_spec.go
@@ -0,0 +1,41 @@
+package jsonrpc
+
+// JSONRPCVersion is the protocol version string carried in the "jsonrpc"
+// member of every message. See http://www.jsonrpc.org/specification
+const JSONRPCVersion = "2.0"
+
+// RequestID is the JSON-RPC "id" member, which may be an int or a string.
+// It is a type alias rather than a defined type so that all of IntString's
+// methods (MarshalJSON, Schema, Equal, ...) are available on it unchanged.
+type RequestID = IntString
+
+// Request is a JSON-RPC 2.0 request object whose params decode into T.
+// A request sent without an ID is a notification.
+type Request[T any] struct {
+	// JSONRPC must be "2.0" (see JSONRPCVersion).
+	JSONRPC string     `json:"jsonrpc" enum:"2.0" doc:"JSON-RPC version, must be '2.0'" required:"true"`
+	ID      *RequestID `json:"id,omitempty" doc:"RequestID is int or string for methods and absent for notifications"`
+	Method  string     `json:"method" doc:"Method to invoke" required:"true"`
+	Params  T          `json:"params,omitempty" doc:"Method parameters"`
+}
+
+// Response is a JSON-RPC 2.0 response object whose result decodes into T.
+// The specification requires exactly one of Result or Error to be present.
+//
+// NOTE(review): the specification also requires "id" in every response
+// (null when the request id could not be determined), but omitempty drops
+// it when ID is nil. Confirm this is intended.
+type Response[T any] struct {
+	JSONRPC string        `json:"jsonrpc" required:"true"`
+	ID      *RequestID    `json:"id,omitempty"`
+	Result  T             `json:"result,omitempty"`
+	Error   *JSONRPCError `json:"error,omitempty"`
+}
+
+// Notification is a JSON-RPC 2.0 request that carries no id and therefore
+// does not expect a response.
+type Notification[T any] struct {
+	JSONRPC string `json:"jsonrpc"`
+	Method  string `json:"method"`
+	Params  T      `json:"params,omitempty"`
+}
diff --git a/jsonrpc/utils_test.go b/jsonrpc/utils_test.go
new file mode 100644
index 00000000..c0164aad
--- /dev/null
+++ b/jsonrpc/utils_test.go
@@ -0,0 +1,73 @@
+package jsonrpc
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"reflect"
+)
+
+// jsonEqual reports whether a and b decode to deeply equal generic JSON
+// values. If either side is not valid JSON the result is false.
+func jsonEqual(a, b json.RawMessage) bool {
+	var o1 interface{}
+	var o2 interface{}
+
+	if err := json.Unmarshal(a, &o1); err != nil {
+		return false
+	}
+	if err := json.Unmarshal(b, &o2); err != nil {
+		return false
+	}
+	// A direct reflect.DeepEqual would be sensitive to pointers, key order etc.;
+	// unmarshalling into an interface first and then comparing removes those issues.
+	return reflect.DeepEqual(o1, o2)
+}
+
+// jsonStringsEqual is jsonEqual for string inputs.
+func jsonStringsEqual(a, b string) bool {
+	return jsonEqual([]byte(a), []byte(b))
+}
+
+// getJSONStrings marshals each argument and returns the encodings in order.
+func getJSONStrings(args ...interface{}) ([]string, error) {
+	ret := make([]string, 0, len(args))
+	for _, a := range args {
+		jsonBytes, err := json.Marshal(a)
+		if err != nil {
+			return nil, err
+		}
+		ret = append(ret, string(jsonBytes))
+	}
+	return ret, nil
+}
+
+// jsonStructEqual reports whether arg1 and arg2 have equivalent JSON encodings.
+// It returns an error wrapping the encoder failure if either cannot be marshaled.
+func jsonStructEqual(arg1 interface{}, arg2 interface{}) (bool, error) {
+	vals, err := getJSONStrings(arg1, arg2)
+	if err != nil {
+		return false, fmt.Errorf("encoding values to JSON: %w", err)
+	}
+	return jsonStringsEqual(vals[0], vals[1]), nil
+}
+
+// compareRequestSlices reports whether a and b hold the same requests in the
+// same order. Params are compared byte-for-byte, so formatting differences count.
+func compareRequestSlices(a, b []Request[json.RawMessage]) bool {
+	if len(a) != len(b) {
+		return false
+	}
+	for i := range a {
+		if a[i].JSONRPC != b[i].JSONRPC || a[i].Method != b[i].Method {
+			return false
+		}
+		if !a[i].ID.Equal(b[i].ID) {
+			return false
+		}
+		if !bytes.Equal(a[i].Params, b[i].Params) {
+			return false
+		}
+	}
+	return true
+}