diff --git a/middleware.go b/middleware.go index 1c7c0b0..f9af9ef 100644 --- a/middleware.go +++ b/middleware.go @@ -8,7 +8,6 @@ import ( "net/http" "runtime/debug" "slices" - "strconv" "strings" "time" @@ -17,6 +16,7 @@ import ( "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/propagation" + semconv "go.opentelemetry.io/otel/semconv/v1.37.0" "go.opentelemetry.io/otel/trace" ) @@ -102,41 +102,31 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx ctx := r.Context() span := trace.SpanFromContext(ctx) - hostName, rawPort, _ := strings.Cut(r.Host, ":") - port := 80 + urlScheme := r.URL.Scheme + if urlScheme == "" { + urlScheme = "http" + } - if rawPort != "" { - p, err := strconv.Atoi(rawPort) - if err == nil { - port = p - } - } else if strings.HasPrefix(r.URL.Scheme, "https") { + hostName, port, _ := SplitHostPort(r.Host) + + switch { + case port > 0: + case urlScheme == "https": port = 443 + default: + port = 80 } metricAttrs := []attribute.KeyValue{ - attribute.String("http.request.method", r.Method), - attribute.String("url.scheme", r.URL.Scheme), - attribute.String("server.address", hostName), - attribute.Int("server.port", port), + { + Key: semconv.HTTPRequestMethodKey, + Value: attribute.StringValue(r.Method), + }, + semconv.URLScheme(urlScheme), + semconv.ServerAddress(hostName), + semconv.ServerPort(port), + semconv.ClientAddress(r.RemoteAddr), } - requestPathAttr := attribute.String("http.request.path", r.URL.Path) - - if !tm.Options.HighCardinalityMetricDisabled { - metricAttrs = append(metricAttrs, requestPathAttr) - } - - activeRequestsAttrSet := metric.WithAttributeSet(attribute.NewSet(metricAttrs...)) - - tm.ActiveRequestsMetric.Add(ctx, 1, activeRequestsAttrSet) - - metricAttrs = append( - metricAttrs, - attribute.String( - "network.protocol.version", - fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor), - ), - ) if !slices.Contains(tm.Options.DebugPaths, strings.ToLower(r.URL.Path)) { ctx, span = tm.Exporters.Tracer.Start( @@ -157,7 +147,37 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx // Add HTTP semantic attributes to the server span // See: https://opentelemetry.io/docs/specs/semconv/http/http-spans/#http-server-semantic-conventions span.SetAttributes(metricAttrs...) - span.SetAttributes(requestPathAttr) + + if !tm.Options.HighCardinalityMetricDisabled { + metricAttrs = append(metricAttrs, attribute.String("http.request.path", r.URL.Path)) + } + + activeRequestsAttrSet := metric.WithAttributeSet(attribute.NewSet(metricAttrs...)) + + tm.ActiveRequestsMetric.Add(ctx, 1, activeRequestsAttrSet) + + protocolAttr := semconv.NetworkProtocolVersion(fmt.Sprintf("%d.%d", r.ProtoMajor, r.ProtoMinor)) + + metricAttrs = append( + metricAttrs, + protocolAttr, + ) + + span.SetAttributes( + protocolAttr, + semconv.URLFull(r.URL.String()), + semconv.UserAgentOriginal(r.UserAgent()), + ) + + peer, peerPort, _ := SplitHostPort(r.RemoteAddr) + + if peer != "" { + span.SetAttributes(semconv.NetworkPeerAddress(peer)) + + if peerPort > 0 { + span.SetAttributes(semconv.NetworkPeerPort(peerPort)) + } + } requestBodySize := r.ContentLength requestLogHeaders := NewTelemetryHeaders(r.Header, tm.Options.AllowedRequestHeaders...) @@ -198,12 +218,10 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx activeRequestsAttrSet, ) - statusCodeAttr := attribute.Int( - "http.response.status_code", - statusCode, - ) - latency := time.Since(start).Seconds() + statusCodeAttr := semconv.HTTPResponseStatusCode(statusCode) + span.SetAttributes(statusCodeAttr) + latency := time.Since(start).Seconds() responseLogData["status"] = statusCode logAttrs := []any{ @@ -215,7 +233,7 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx if err != nil { stack := string(debug.Stack()) logAttrs = append(logAttrs, slog.Any("error", err), slog.String("stacktrace", stack)) - span.SetAttributes(statusCodeAttr, attribute.String("stacktrace", stack)) + span.SetAttributes(semconv.ExceptionStacktrace(stack)) } metricAttrs = append(metricAttrs, statusCodeAttr) @@ -265,7 +283,7 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx if requestBodySize > 0 { requestLogData["size"] = requestBodySize - span.SetAttributes(attribute.Int64("http.request.body.size", requestBodySize)) + span.SetAttributes(semconv.HTTPRequestBodySize(int(requestBodySize))) } defer func() { @@ -284,9 +302,9 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx errBytes, jsonErr := json.Marshal(err) if jsonErr != nil { - span.SetAttributes(attribute.String("error", fmt.Sprintf("%v", err))) + span.SetAttributes(attribute.String("exception.error", fmt.Sprintf("%v", err))) } else { - span.SetAttributes(attribute.String("error", string(errBytes))) + span.SetAttributes(attribute.String("exception.error", string(errBytes))) } } }() @@ -300,7 +318,7 @@ func (tm *tracingMiddleware) ServeHTTP( //nolint:gocognit,cyclop,funlen,maintidx responseLogData["size"] = ww.BytesWritten() responseLogData["headers"] = responseLogHeaders - span.SetAttributes(attribute.Int("http.response.body.size", ww.BytesWritten())) + span.SetAttributes(semconv.HTTPResponseBodySize(ww.BytesWritten())) SetSpanHeaderAttributes(span, "http.response.header", responseLogHeaders) // skip printing very large responses. diff --git a/utils.go b/utils.go index 911e75c..d7634ff 100644 --- a/utils.go +++ b/utils.go @@ -3,12 +3,15 @@ package gotel import ( "bytes" "encoding/json" + "errors" "fmt" "io" "log/slog" + "net" "net/http" "regexp" "slices" + "strconv" "strings" "github.com/google/uuid" @@ -23,6 +26,21 @@ const ( contentTypeHeader = "Content-Type" ) +var excludedSpanHeaderAttributes = map[string]bool{ + "baggage": true, + "traceparent": true, + "traceresponse": true, + "tracestate": true, + "x-b3-sampled": true, + "x-b3-spanid": true, + "x-b3-traceid": true, + "x-b3-parentspanid": true, + "x-b3-flags": true, + "b3": true, +} + +var errInvalidHostPort = errors.New("invalid host port") + // SetSpanHeaderAttributes sets header attributes to the otel span. func SetSpanHeaderAttributes( span trace.Span, @@ -33,8 +51,13 @@ func SetSpanHeaderAttributes( allowedHeadersLength := len(allowedHeaders) for key, values := range headers { - if allowedHeadersLength == 0 || slices.Contains(allowedHeaders, strings.ToLower(key)) { - span.SetAttributes(attribute.StringSlice(prefix+strings.ToLower(key), values)) + lowerKey := strings.ToLower(key) + + if (allowedHeadersLength == 0 && !excludedSpanHeaderAttributes[lowerKey]) || + (allowedHeadersLength > 0 && slices.Contains(allowedHeaders, lowerKey)) { + span.SetAttributes( + attribute.StringSlice(fmt.Sprintf("%s.%s", prefix, lowerKey), values), + ) } } } @@ -95,10 +118,51 @@ func MaskString(input string) string { case inputLength < 12: return input[0:1] + strings.Repeat("*", inputLength-1) default: - return input[0:3] + strings.Repeat("*", 7) + fmt.Sprintf("(%d)", inputLength) + return input[0:2] + strings.Repeat("*", 8) + fmt.Sprintf("(%d)", inputLength) } } +// SplitHostPort splits a network address hostport of the form "host", +// "host%zone", "[host]", "[host%zone]", "host:port", "host%zone:port", +// "[host]:port", "[host%zone]:port", or ":port" into host or host%zone and +// port. +// +// An empty host is returned if it is not provided or unparsable. A negative +// port is returned if it is not provided or unparsable. +func SplitHostPort(hostport string) (string, int, error) { + port := -1 + + if strings.HasPrefix(hostport, "[") { + addrEnd := strings.LastIndex(hostport, "]") + if addrEnd < 0 { + // Invalid hostport. + return "", port, errInvalidHostPort + } + + if i := strings.LastIndex(hostport[addrEnd:], ":"); i < 0 { + host := hostport[1:addrEnd] + + return host, port, nil + } + } else { + if i := strings.LastIndex(hostport, ":"); i < 0 { + return hostport, port, nil + } + } + + host, pStr, err := net.SplitHostPort(hostport) + if err != nil { + return host, port, err + } + + p, err := strconv.ParseUint(pStr, 10, 16) + if err != nil { + return "", port, err + } + + return host, int(p), err +} + // returns the value or default one if value is empty. func getDefault[T comparable](value T, defaultValue T) T { var empty T diff --git a/utils_test.go b/utils_test.go index 6c5c9f1..5211766 100644 --- a/utils_test.go +++ b/utils_test.go @@ -26,7 +26,7 @@ func TestNewTelemetryHeaders(t *testing.T) { }, Expected: http.Header{ "Content-Type": []string{"application/json"}, - "Authorization": []string{"Bea*******(65)"}, + "Authorization": []string{"Be********(65)"}, "Api-Key": []string{"******"}, "Secret-Key": []string{"s*********"}, },