diff --git a/.golangci.yml b/.golangci.yml index 46547fe..c35e558 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,29 +1,442 @@ +# This file is licensed under the terms of the MIT license https://opensource.org/license/mit +# Copyright (c) 2021-2025 Marat Reymers + +## Golden config for golangci-lint v2.1.6 +# +# This is the best config for golangci-lint based on my experience and opinion. +# It is very strict, but not extremely strict. +# Feel free to adapt it to suit your needs. +# If this config helps you, please consider keeping a link to this file (see the next comment). + +# Based on https://gist.github.com/maratori/47a4d00457a92aa426dbd48a18776322 + version: "2" + +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. + # Default: 3 + max-same-issues: 50 + +formatters: + enable: + - goimports # checks if the code and import statements are formatted according to the 'goimports' command + - golines # checks if code is formatted, and fixes long lines + + ## you may want to enable + #- gci # checks if code and import statements are formatted, with additional rules + #- gofmt # checks if the code is formatted according to 'gofmt' command + #- gofumpt # enforces a stricter format than 'gofmt', while being backwards compatible + + # All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml + settings: + goimports: + # A list of prefixes, which, if set, checks import paths + # with the given prefixes are grouped after 3rd-party packages. + # Default: [] + local-prefixes: + - github.com/pitabwire/frame + + golines: + # Target maximum line length. + # Default: 100 + max-len: 120 + linters: enable: - - gocyclo - - misspell - - unconvert - - unparam - - staticcheck - - unused + - asasalint # checks for pass []any as any in variadic func(...any) + - asciicheck # checks that your code does not contain non-ASCII identifiers + - bidichk # checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - canonicalheader # checks whether net/http.Header uses canonical header + - copyloopvar # detects places where loop variables are copied (Go 1.22+) + - cyclop # checks function and package cyclomatic complexity + - depguard # checks if package imports are in a list of acceptable packages + - dupl # tool for code clone detection + - durationcheck # checks for two durations multiplied together + - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases + - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error + - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 + - exhaustive # checks exhaustiveness of enum switch statements + - exptostd # detects functions from golang.org/x/exp/ that can be replaced by std functions + - fatcontext # detects nested contexts in loops + - forbidigo # forbids identifiers + #- funcorder # checks the order of functions, methods, and constructors + - funlen # tool for detection of long functions + - gocheckcompilerdirectives # validates go compiler directive comments (//go:) + - gochecknoglobals # checks that no global variables exist + - gochecknoinits # checks that no init functions are present in Go code + - gochecksumtype # checks exhaustiveness on Go "sum types" + - gocognit # computes and checks the cognitive complexity of functions + - goconst # finds repeated strings that could be replaced by a constant + - gocritic # provides diagnostics that check for bugs, performance and style issues + - gocyclo # computes and checks the cyclomatic complexity of functions + - godot # checks if comments end in a period + - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod + - goprintffuncname # checks that printf-like functions are named with f at the end + - gosec # inspects source code for security problems + - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - iface # checks the incorrect use of interfaces, helping developers avoid interface pollution + - ineffassign # detects when assignments to existing variables are not used + - intrange # finds places where for loops could make use of an integer range + - loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap) + - makezero # finds slice declarations with non-zero initial length + - mirror # reports wrong mirror patterns of bytes/strings usage + - mnd # detects magic numbers + - musttag # enforces field tags in (un)marshaled structs + - nakedret # finds naked returns in functions greater than a specified function length + - nestif # reports deeply nested if statements + - nilerr # finds the code that returns nil even if it checks that the error is not nil + - nilnesserr # reports that it checks for err != nil, but it returns a different nil value error (powered by nilness and nilerr) + - nilnil # checks that there is no simultaneous return of nil error and an invalid value + - noctx # finds sending http request without context.Context + - nolintlint # reports ill-formed or insufficient nolint directives + - nonamedreturns # reports all named returns + - nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL + - perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative + - predeclared # finds code that shadows one of Go's predeclared identifiers + - promlinter # checks Prometheus metrics naming via promlint + - protogetter # reports direct reads from proto message fields when getters should be used + - reassign # checks that package variables are not reassigned + - recvcheck # checks for receiver type consistency + - revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint + - rowserrcheck # checks whether Err of rows is checked successfully + - sloglint # ensure consistent code style when using log/slog + - spancheck # checks for mistakes with OpenTelemetry/Census spans + - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed + - staticcheck # is a go vet on steroids, applying a ton of static analysis checks + - testableexamples # checks if examples are testable (have an expected output) + - testifylint # checks usage of github.com/stretchr/testify + - testpackage # makes you use a separate _test package + - tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # removes unnecessary type conversions + - unparam # reports unused function parameters + - unused # checks for unused constants, variables, functions and types + - usestdlibvars # detects the possibility to use variables/constants from the Go standard library + - usetesting # reports uses of functions with replacement inside the testing package + - wastedassign # finds wasted assignment statements + - whitespace # detects leading and trailing whitespace + + ## you may want to enable + #- decorder # checks declaration order and count of types, constants, variables and functions + #- exhaustruct # [highly recommend to enable] checks if all structure fields are initialized + #- ginkgolinter # [if you use ginkgo/gomega] enforces standards of using ginkgo and gomega + #- godox # detects usage of FIXME, TODO and other keywords inside comments + #- goheader # checks is file header matches to pattern + #- inamedparam # [great idea, but too strict, need to ignore a lot of cases by default] reports interfaces with unnamed method parameters + #- interfacebloat # checks the number of methods inside an interface + #- ireturn # accept interfaces, return concrete types + #- prealloc # [premature optimization, but can be used in some cases] finds slice declarations that could potentially be preallocated + #- tagalign # checks that struct tags are well aligned + #- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope + #- wrapcheck # checks that errors returned from external packages are wrapped + #- zerologlint # detects the wrong usage of zerolog that a user forgets to dispatch zerolog.Event + + ## disabled + #- containedctx # detects struct contained context.Context field + #- contextcheck # [too many false positives] checks the function whether use a non-inherited context + #- dogsled # checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) + #- dupword # [useless without config] checks for duplicate words in the source code + #- err113 # [too strict] checks the errors handling expressions + #- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted + #- forcetypeassert # [replaced by errcheck] finds forced type assertions + #- gomodguard # [use more powerful depguard] allow and block lists linter for direct Go module dependencies + #- gosmopolitan # reports certain i18n/l10n anti-patterns in your Go codebase + #- grouper # analyzes expression groups + #- importas # enforces consistent import aliases + #- lll # [replaced by golines] reports long lines + #- maintidx # measures the maintainability index of each function + #- misspell # [useless] finds commonly misspelled English words in comments + #- nlreturn # [too strict and mostly code is not more readable] checks for a new line before return and branch statements to increase code clarity + #- paralleltest # [too many false positives] detects missing usage of t.Parallel() method in your Go test + #- tagliatelle # checks the struct tags + #- thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers + #- wsl # [too strict and mostly code is not more readable] whitespace linter forces you to use empty lines + + # All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml + settings: + cyclop: + # The maximal code complexity to report. + # Default: 10 + max-complexity: 30 + # The maximal average package complexity. + # If it's higher than 0.0 (float) the check is enabled. + # Default: 0.0 + package-average: 10.0 + + depguard: + # Rules to apply. + # + # Variables: + # - File Variables + # Use an exclamation mark `!` to negate a variable. + # Example: `!$test` matches any file that is not a go test file. + # + # `$all` - matches all go files + # `$test` - matches all go test files + # + # - Package Variables + # + # `$gostd` - matches all of go's standard library (Pulled from `GOROOT`) + # + # Default (applies if no custom rules are defined): Only allow $gostd in all files. + rules: + "deprecated": + # List of file globs that will match this list of settings to compare against. + # By default, if a path is relative, it is relative to the directory where the golangci-lint command is executed. + # The placeholder '${base-path}' is substituted with a path relative to the mode defined with `run.relative-path-mode`. + # The placeholder '${config-path}' is substituted with a path relative to the configuration file. + # Default: $all + files: + - "$all" + # List of packages that are not allowed. + # Entries can be a variable (starting with $), a string prefix, or an exact match (if ending with $). + # Default: [] + deny: + - pkg: github.com/golang/protobuf + desc: Use google.golang.org/protobuf instead, see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules + - pkg: github.com/satori/go.uuid + desc: Use github.com/google/uuid instead, satori's package is not maintained + - pkg: github.com/gofrs/uuid$ + desc: Use github.com/gofrs/uuid/v5 or later, it was not a go module before v5 + "non-test files": + files: + - "!$test" + deny: + - pkg: math/rand$ + desc: Use math/rand/v2 instead, see https://go.dev/blog/randv2 + "non-main files": + files: + - "!**/main.go" + deny: + - pkg: log$ + desc: Use log/slog instead, see https://go.dev/blog/slog + + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + + exhaustive: + # Program elements to check for exhaustiveness. + # Default: [ switch ] + check: + - switch + - map + + exhaustruct: + # List of regular expressions to exclude struct packages and their names from checks. + # Regular expressions must match complete canonical struct package/name/structname. + # Default: [] + exclude: + # std libs + - ^net/http.Client$ + - ^net/http.Cookie$ + - ^net/http.Request$ + - ^net/http.Response$ + - ^net/http.Server$ + - ^net/http.Transport$ + - ^net/url.URL$ + - ^os/exec.Cmd$ + - ^reflect.StructField$ + # public libs + - ^github.com/Shopify/sarama.Config$ + - ^github.com/Shopify/sarama.ProducerMessage$ + - ^github.com/mitchellh/mapstructure.DecoderConfig$ + - ^github.com/prometheus/client_golang/.+Opts$ + - ^github.com/spf13/cobra.Command$ + - ^github.com/spf13/cobra.CompletionOptions$ + - ^github.com/stretchr/testify/mock.Mock$ + - ^github.com/testcontainers/testcontainers-go.+Request$ + - ^github.com/testcontainers/testcontainers-go.FromDockerfile$ + - ^golang.org/x/tools/go/analysis.Analyzer$ + - ^google.golang.org/protobuf/.+Options$ + - ^gopkg.in/yaml.v3.Node$ + + funcorder: + # Checks if the exported methods of a structure are placed before the non-exported ones. + # Default: true + struct-method: false + + funlen: + # Checks the number of lines in a function. + # If lower than 0, disable the check. + # Default: 60 + lines: 100 + # Checks the number of statements in a function. + # If lower than 0, disable the check. + # Default: 40 + statements: 50 + + gochecksumtype: + # Presence of `default` case in switch statements satisfies exhaustiveness, if all members are not listed. + # Default: true + default-signifies-exhaustive: false + + gocognit: + # Minimal code complexity to report. + # Default: 30 (but we recommend 10-20) + min-complexity: 20 + + gocritic: + # Settings passed to gocritic. + # The settings key is the name of a supported gocritic checker. + # The list of supported checkers can be found at https://go-critic.com/overview. + settings: + captLocal: + # Whether to restrict checker to params only. + # Default: true + paramsOnly: false + underef: + # Whether to skip (*x).method() calls where x is a pointer receiver. + # Default: true + skipRecvDeref: false + + govet: + # Enable all analyzers. + # Default: false + enable-all: true + # Disable analyzers by name. + # Run `GL_DEBUG=govet golangci-lint run --enable=govet` to see default, all available analyzers, and enabled analyzers. + # Default: [] + disable: + - fieldalignment # too strict + # Settings per analyzer. + settings: + shadow: + # Whether to be strict about shadowing; can be noisy. + # Default: false + strict: true + + inamedparam: + # Skips check for interface methods with only a single parameter. + # Default: false + skip-single-param: true + + mnd: + # List of function patterns to exclude from analysis. + # Values always ignored: `time.Date`, + # `strconv.FormatInt`, `strconv.FormatUint`, `strconv.FormatFloat`, + # `strconv.ParseInt`, `strconv.ParseUint`, `strconv.ParseFloat`. + # Default: [] + ignored-functions: + - args.Error + - flag.Arg + - flag.Duration.* + - flag.Float.* + - flag.Int.* + - flag.Uint.* + - os.Chmod + - os.Mkdir.* + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets.* + - prometheus.LinearBuckets + + nakedret: + # Make an issue if func has more lines of code than this setting, and it has naked returns. + # Default: 30 + max-func-lines: 0 + + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [ funlen, gocognit, golines ] + # Enable to require an explanation of nonzero length after each nolint directive. + # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + + perfsprint: + # Optimizes into strings concatenation. + # Default: true + strconcat: false + + reassign: + # Patterns for global variable names that are checked for reassignment. + # See https://github.com/curioswitch/go-reassign#usage + # Default: ["EOF", "Err.*"] + patterns: + - ".*" + + rowserrcheck: + # database/sql is always checked. + # Default: [] + packages: + - github.com/jmoiron/sqlx + + sloglint: + # Enforce not using global loggers. + # Values: + # - "": disabled + # - "all": report all global loggers + # - "default": report only the default slog logger + # https://github.com/go-simpler/sloglint?tab=readme-ov-file#no-global + # Default: "" + no-global: all + # Enforce using methods that accept a context. + # Values: + # - "": disabled + # - "all": report all contextless calls + # - "scope": report only if a context exists in the scope of the outermost function + # https://github.com/go-simpler/sloglint?tab=readme-ov-file#context-only + # Default: "" + context: scope + + staticcheck: + # SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks + # Example (to disable some checks): [ "all", "-SA1000", "-SA1001"] + # Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"] + checks: + - all + # Incorrect or missing package comment. + # https://staticcheck.dev/docs/checks/#ST1000 + - -ST1000 + # Use consistent method receiver names. + # https://staticcheck.dev/docs/checks/#ST1016 + - -ST1016 + # Omit embedded fields from selector expression. + # https://staticcheck.dev/docs/checks/#QF1008 + - -QF1008 + + usetesting: + # Enable/disable `os.TempDir()` detections. + # Default: false + os-temp-dir: true + exclusions: - generated: lax + # Log a warning if an exclusion rule is unused. + # Default: false + warn-unused: true + # Predefined exclusion rules. + # Default: [] presets: - - comments - - common-false-positives - - legacy - std-error-handling - paths: - - third_party$ - - builtin$ - - examples$ -formatters: - enable: - - goimports - exclusions: - generated: lax - paths: - - third_party$ - - builtin$ - - examples$ + - common-false-positives + # Excluding configuration per-path, per-linter, per-text and per-source. + rules: + - source: 'TODO' + linters: [ godot ] + - text: 'should have a package comment' + linters: [ revive ] + - text: 'exported \S+ \S+ should have comment( \(or a comment on this block\))? or be unexported' + linters: [ revive ] + - text: 'package comment should be of the form ".+"' + source: '// ?(nolint|TODO)' + linters: [ revive ] + - text: 'comment on exported \S+ \S+ should be of the form ".+"' + source: '// ?(nolint|TODO)' + linters: [ revive, staticcheck ] + - path: '_test\.go' + linters: + - bodyclose + - dupl + - errcheck + - funlen + - goconst + - gosec + - noctx + - wrapcheck \ No newline at end of file diff --git a/Makefile b/Makefile index 1d5953c..092b2ad 100644 --- a/Makefile +++ b/Makefile @@ -35,6 +35,11 @@ vet: ## run go vet on the source files doc: ## generate godocs and start a local documentation webserver on port 8085 godoc -http=:8085 -index +format: + find . -name '*.go' -not -path './.git/*' -exec sed -i '/^import (/,/^)/{/^$$/d}' {} + + find . -name '*.go' -not -path './.git/*' -exec goimports -w {} + + golangci-lint run --fix + # this command will run all tests in the repo # INTEGRATION_TEST_SUITE_PATH is used to run specific tests in Golang, # if it's not specified it will run all tests diff --git a/context.go b/context.go index 40a85be..99366de 100644 --- a/context.go +++ b/context.go @@ -2,16 +2,18 @@ package util import ( "context" - - log "github.com/sirupsen/logrus" ) // contextKeys is a type alias for string to namespace Context keys per-package. type contextKeys string -// ctxValueRequestID is the key to extract the request ID for an HTTP request +// ctxValueRequestID is the key to extract the request ID for an HTTP request. const ctxValueRequestID = contextKeys("requestid") +func ContextWithRequestID(ctx context.Context, requestID string) context.Context { + return context.WithValue(ctx, ctxValueRequestID, requestID) +} + // GetRequestID returns the request ID associated with this context, or the empty string // if one is not associated with this context. func GetRequestID(ctx context.Context) string { @@ -19,24 +21,9 @@ func GetRequestID(ctx context.Context) string { if id == nil { return "" } - return id.(string) -} - -// ctxValueLogger is the key to extract the logrus Logger. -const ctxValueLogger = contextKeys("logger") - -// GetLogger retrieves the logrus logger from the supplied context. Always returns a logger, -// even if there wasn't one originally supplied. -func GetLogger(ctx context.Context) *log.Entry { - l := ctx.Value(ctxValueLogger) - if l == nil { - // Always return a logger so callers don't need to constantly nil check. - return log.WithField("context", "missing") + str, ok := id.(string) + if !ok { + return "" } - return l.(*log.Entry) -} - -// ContextWithLogger creates a new context, which will use the given logger. -func ContextWithLogger(ctx context.Context, l *log.Entry) context.Context { - return context.WithValue(ctx, ctxValueLogger, l) + return str } diff --git a/examples_test.go b/examples_test.go new file mode 100644 index 0000000..80c7a23 --- /dev/null +++ b/examples_test.go @@ -0,0 +1,54 @@ +package util_test + +import ( + "io" + "log/slog" + "os" + "testing" + + "github.com/pitabwire/util" +) + +func TestCustomHandler(t *testing.T) { + // Create a new logger with the custom handler + logger := util.NewLogger(t.Context(), util.WithLogLevel(slog.LevelDebug)) + defer logger.Release() // Return to pool when done + + // Log some messages + logger.Info("This will be logged in JSON format") + logger.Debug("Debug message in JSON format", "key", "value") + + // output: + // (JSON-formatted log output) +} + +func TestDirectHandlerUsage(t *testing.T) { + // Create a text handler + textHandler := slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelInfo, + }) + + // Create a new logger with the direct handler + logger := util.NewLogger(t.Context(), util.WithLogHandler(textHandler), util.WithLogStackTrace()) + defer logger.Release() // Return to pool when done + + // Log some messages + logger.Info("This will be logged in text format") + logger.Error("This will include a stack trace") + + // output: + // (Text-formatted log output) +} + +// TestCustomOutputWriter tests using a custom output writer. +func TestCustomOutputWriter(t *testing.T) { + // Create a buffer to capture logs + var buf io.Writer = os.Stderr + + // Create a new logger with custom output + logger := util.NewLogger(t.Context(), util.WithLogOutput(buf)) + defer logger.Release() // Return to pool when done + + // Log some messages + logger.Info("This message will be written to the custom writer") +} diff --git a/go.mod b/go.mod index fea5690..908ee0a 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,19 @@ module github.com/pitabwire/util -go 1.23 +go 1.24 require ( + github.com/lmittmann/tint v1.1.2 github.com/rs/xid v1.6.0 - github.com/sirupsen/logrus v1.9.3 + go.opentelemetry.io/contrib/bridges/otelslog v0.12.0 ) -require golang.org/x/sys v0.28.0 // indirect +require ( + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.37.0 // indirect + go.opentelemetry.io/otel/log v0.13.0 // indirect + go.opentelemetry.io/otel/metric v1.37.0 // indirect + go.opentelemetry.io/otel/trace v1.37.0 // indirect +) diff --git a/go.sum b/go.sum index c391324..5db0b91 100644 --- a/go.sum +++ b/go.sum @@ -1,18 +1,31 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/lmittmann/tint v1.1.2 h1:2CQzrL6rslrsyjqLDwD11bZ5OpLBPU+g3G/r5LSfS8w= +github.com/lmittmann/tint v1.1.2/go.mod h1:HIS3gSy7qNwGCj+5oRjAutErFBl4BzdQP6cJZ0NfMwE= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/bridges/otelslog v0.12.0 h1:lFM7SZo8Ce01RzRfnUFQZEYeWRf/MtOA3A5MobOqk2g= +go.opentelemetry.io/contrib/bridges/otelslog v0.12.0/go.mod h1:Dw05mhFtrKAYu72Tkb3YBYeQpRUJ4quDgo2DQw3No5A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/log v0.13.0 h1:yoxRoIZcohB6Xf0lNv9QIyCzQvrtGZklVbdCoyb7dls= +go.opentelemetry.io/otel/log v0.13.0/go.mod h1:INKfG4k1O9CL25BaM1qLe0zIedOpvlS5Z7XgSbmN83E= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/json.go b/json.go index f047dd6..39f8a3e 100644 --- a/json.go +++ b/json.go @@ -7,8 +7,6 @@ import ( "net/http" "reflect" "runtime/debug" - - log "github.com/sirupsen/logrus" ) // JSONResponse represents an HTTP response which contains a JSON body. @@ -23,7 +21,7 @@ type JSONResponse struct { // Is2xx returns true if the Code is between 200 and 299. func (r JSONResponse) Is2xx() bool { - return r.Code/100 == 2 + return r.Code/100 == Status2xx } // RedirectResponse returns a JSONResponse which 302s the client to the given location. @@ -31,7 +29,7 @@ func RedirectResponse(location string) JSONResponse { headers := make(map[string]any) headers["Location"] = location return JSONResponse{ - Code: 302, + Code: StatusFound, // 302 JSON: struct{}{}, Headers: headers, } @@ -49,10 +47,10 @@ func MessageResponse(code int, msg string) JSONResponse { // ErrorResponse returns an HTTP 500 JSONResponse with the stringified form of the given error. func ErrorResponse(err error) JSONResponse { - return MessageResponse(500, err.Error()) + return MessageResponse(StatusInternalServerError, err.Error()) } -// MatrixErrorResponse is a function that returns error responses in the standard Matrix Error format (errcode / error) +// MatrixErrorResponse is a function that returns error responses in the standard Matrix Error format (errcode / error). func MatrixErrorResponse(httpStatusCode int, errCode, message string) JSONResponse { return JSONResponse{ Code: httpStatusCode, @@ -69,17 +67,17 @@ type JSONRequestHandler interface { OnIncomingRequest(req *http.Request) JSONResponse } -// jsonRequestHandlerWrapper is a wrapper to allow in-line functions to conform to util.JSONRequestHandler +// jsonRequestHandlerWrapper is a wrapper to allow in-line functions to conform to util.JSONRequestHandler. type jsonRequestHandlerWrapper struct { function func(req *http.Request) JSONResponse } -// OnIncomingRequest implements util.JSONRequestHandler +// OnIncomingRequest implements util.JSONRequestHandler. func (r *jsonRequestHandlerWrapper) OnIncomingRequest(req *http.Request) JSONResponse { return r.function(req) } -// NewJSONRequestHandler converts the given OnIncomingRequest function into a JSONRequestHandler +// NewJSONRequestHandler converts the given OnIncomingRequest function into a JSONRequestHandler. func NewJSONRequestHandler(f func(req *http.Request) JSONResponse) JSONRequestHandler { return &jsonRequestHandlerWrapper{f} } @@ -91,13 +89,11 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { defer func() { if r := recover(); r != nil { - logger := GetLogger(req.Context()) - logger.WithFields(log.Fields{ - "panic": r, - }).Errorf( + logger := Log(req.Context()) + logger.WithField("panic", r).Error( "Request panicked!\n%s", debug.Stack(), ) - respond(w, req, MessageResponse(500, "Internal Server Error")) + respond(w, req, MessageResponse(StatusInternalServerError, "Internal Server Error")) } }() handler(w, req) @@ -108,18 +104,17 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc { // http.Requests will have a logger (with a request ID/method/path logged) attached to the Context. // This can be accessed via GetLogger(Context). func RequestWithLogging(req *http.Request) *http.Request { - reqID := RandomString(12) + reqID := RandomString(DefaultRequestIDLength) // Set a Logger and request ID on the context - ctx := ContextWithLogger(req.Context(), log.WithFields(log.Fields{ - "req.method": req.Method, - "req.path": req.URL.Path, - "req.id": reqID, - })) + ctx := ContextWithLogger(req.Context(), Log(req.Context()). + WithField("req.method", req.Method). + WithField("req.path", req.URL.Path). + WithField("req.id", reqID)) ctx = context.WithValue(ctx, ctxValueRequestID, reqID) req = req.WithContext(ctx) if req.Method != http.MethodOptions { - logger := GetLogger(req.Context()) + logger := Log(req.Context()) logger.Trace("Incoming request") } @@ -135,7 +130,7 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { if req.Method == http.MethodOptions { SetCORSHeaders(w) - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) return } res := handler.OnIncomingRequest(req) @@ -149,18 +144,17 @@ func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { } func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) { - logger := GetLogger(req.Context()) + logger := Log(req.Context()) // Set custom headers if res.Headers != nil { for h, val := range res.Headers { - var headerValues []any // Check if the value is already a headerValues if reflect.TypeOf(val).Kind() == reflect.Slice { v := reflect.ValueOf(val) - for i := 0; i < v.Len(); i++ { + for i := range v.Len() { headerValues = append(headerValues, v.Index(i).Interface()) } } else { @@ -187,14 +181,14 @@ func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) { if err != nil { logger.WithError(err).Error("Failed to marshal JSONResponse") // this should never fail to be marshalled so drop err to the floor - res = MessageResponse(500, "Internal Server Error") + res = MessageResponse(StatusInternalServerError, "Internal Server Error") resBytes, _ = json.Marshal(res.JSON) } // Set status code and write the body w.WriteHeader(res.Code) if req.Method != http.MethodOptions { - logger.WithField("code", res.Code).Tracef("Responding (%d bytes)", len(resBytes)) + logger.WithField("code", res.Code).WithField("bytes", len(resBytes)).Trace("Responding") } _, _ = w.Write(resBytes) } @@ -211,7 +205,7 @@ func WithCORSOptions(handler http.HandlerFunc) http.HandlerFunc { } } -// SetCORSHeaders sets unrestricted origin Access-Control headers on the response writer +// SetCORSHeaders sets unrestricted origin Access-Control headers on the response writer. func SetCORSHeaders(w http.ResponseWriter) { if w.Header().Get("Access-Control-Allow-Origin") == "" { w.Header().Set("Access-Control-Allow-Origin", "*") @@ -219,3 +213,10 @@ func SetCORSHeaders(w http.ResponseWriter) { w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, Authorization") } + +const ( + StatusFound = 302 + StatusInternalServerError = 500 + DefaultRequestIDLength = 12 + Status2xx = 2 +) diff --git a/json_test.go b/json_test.go index 94dcf90..ede05c1 100644 --- a/json_test.go +++ b/json_test.go @@ -1,20 +1,19 @@ -package util +package util_test import ( - "context" "errors" "net/http" "net/http/httptest" "testing" - log "github.com/sirupsen/logrus" + "github.com/pitabwire/util" ) type MockJSONRequestHandler struct { - handler func(req *http.Request) JSONResponse + handler func(req *http.Request) util.JSONResponse } -func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) JSONResponse { +func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) util.JSONResponse { return h.handler(req) } @@ -23,33 +22,44 @@ type MockResponse struct { } func TestMakeJSONAPI(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output tests := []struct { - Return JSONResponse + Return util.JSONResponse ExpectCode int ExpectJSON string }{ // MessageResponse return values - {MessageResponse(500, "Everything is broken"), 500, `{"message":"Everything is broken"}`}, + { + util.MessageResponse(http.StatusInternalServerError, "Everything is broken"), + http.StatusInternalServerError, + `{"message":"Everything is broken"}`, + }, // interface return values - {JSONResponse{500, MockResponse{"yep"}, nil}, 500, `{"foo":"yep"}`}, + { + util.JSONResponse{http.StatusInternalServerError, MockResponse{"yep"}, nil}, + http.StatusInternalServerError, + `{"foo":"yep"}`, + }, // Error JSON return values which fail to be marshalled should fallback to text - {JSONResponse{500, struct { + {util.JSONResponse{http.StatusInternalServerError, struct { Foo interface{} `json:"foo"` - }{func(cannotBe, marshalled string) {}}, nil}, 500, `{"message":"Internal Server Error"}`}, + }{func(_, _ string) {}}, nil}, http.StatusInternalServerError, `{"message":"Internal Server Error"}`}, // With different status codes - {JSONResponse{201, MockResponse{"narp"}, nil}, 201, `{"foo":"narp"}`}, + {util.JSONResponse{http.StatusCreated, MockResponse{"narp"}, nil}, http.StatusCreated, `{"foo":"narp"}`}, // Top-level array success values - {JSONResponse{200, []MockResponse{{"yep"}, {"narp"}}, nil}, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, + { + util.JSONResponse{http.StatusOK, []MockResponse{{"yep"}, {"narp"}}, nil}, + http.StatusOK, + `[{"foo":"yep"},{"foo":"narp"}]`, + }, } for _, tst := range tests { - mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + mock := MockJSONRequestHandler{func(_ *http.Request) util.JSONResponse { return tst.Return }} - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() - handlerFunc := MakeJSONAPI(&mock) + handlerFunc := util.MakeJSONAPI(&mock) handlerFunc(mockWriter, mockReq) if mockWriter.Code != tst.ExpectCode { t.Errorf("TestMakeJSONAPI wanted HTTP status %d, got %d", tst.ExpectCode, mockWriter.Code) @@ -62,19 +72,19 @@ func TestMakeJSONAPI(t *testing.T) { } func TestMakeJSONAPICustomHeaders(t *testing.T) { - mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + mock := MockJSONRequestHandler{func(_ *http.Request) util.JSONResponse { headers := make(map[string]any) headers["Custom"] = "Thing" headers["X-Custom"] = "Things" - return JSONResponse{ + return util.JSONResponse{ Code: 200, JSON: MockResponse{"yep"}, Headers: headers, } }} - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() - handlerFunc := MakeJSONAPI(&mock) + handlerFunc := util.MakeJSONAPI(&mock) handlerFunc(mockWriter, mockReq) if mockWriter.Code != 200 { t.Errorf("TestMakeJSONAPICustomHeaders wanted HTTP status 200, got %d", mockWriter.Code) @@ -90,13 +100,12 @@ func TestMakeJSONAPICustomHeaders(t *testing.T) { } func TestMakeJSONAPIRedirect(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output - mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { - return RedirectResponse("https://matrix.org") + mock := MockJSONRequestHandler{func(_ *http.Request) util.JSONResponse { + return util.RedirectResponse("https://matrix.org") }} - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() - handlerFunc := MakeJSONAPI(&mock) + handlerFunc := util.MakeJSONAPI(&mock) handlerFunc(mockWriter, mockReq) if mockWriter.Code != 302 { t.Errorf("TestMakeJSONAPIRedirect wanted HTTP status 302, got %d", mockWriter.Code) @@ -108,14 +117,13 @@ func TestMakeJSONAPIRedirect(t *testing.T) { } func TestMakeJSONAPIError(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output - mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + mock := MockJSONRequestHandler{func(_ *http.Request) util.JSONResponse { err := errors.New("oops") - return ErrorResponse(err) + return util.ErrorResponse(err) }} - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() - handlerFunc := MakeJSONAPI(&mock) + handlerFunc := util.MakeJSONAPI(&mock) handlerFunc(mockWriter, mockReq) if mockWriter.Code != 500 { t.Errorf("TestMakeJSONAPIError wanted HTTP status 500, got %d", mockWriter.Code) @@ -141,7 +149,7 @@ func TestIs2xx(t *testing.T) { {500, false}, } for _, test := range tests { - j := JSONResponse{ + j := util.JSONResponse{ Code: test.Code, } actual := j.Is2xx() @@ -152,31 +160,32 @@ func TestIs2xx(t *testing.T) { } func TestGetLogger(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output - entry := log.WithField("test", "yep") - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) - ctx := context.WithValue(mockReq.Context(), ctxValueLogger, entry) + entry := util.NewLogger(t.Context()).WithField("test", "yep") + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) + ctx := util.ContextWithLogger(mockReq.Context(), entry) mockReq = mockReq.WithContext(ctx) - ctxLogger := GetLogger(mockReq.Context()) + ctxLogger := util.Log(mockReq.Context()) if ctxLogger != entry { t.Errorf("TestGetLogger wanted logger '%v', got '%v'", entry, ctxLogger) } - noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) - ctxLogger = GetLogger(noLoggerInReq.Context()) + noLoggerInReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) + ctxLogger = util.Log(noLoggerInReq.Context()) if ctxLogger == nil { t.Errorf("TestGetLogger wanted logger, got nil") } } func TestProtect(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output mockWriter := httptest.NewRecorder() - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) mockReq = mockReq.WithContext( - context.WithValue(mockReq.Context(), ctxValueLogger, log.WithField("test", "yep")), + util.ContextWithLogger( + mockReq.Context(), + util.NewLogger(t.Context()).WithField("test", "yep"), + ), ) - h := Protect(func(w http.ResponseWriter, req *http.Request) { + h := util.Protect(func(_ http.ResponseWriter, _ *http.Request) { panic("oh noes!") }) @@ -195,10 +204,9 @@ func TestProtect(t *testing.T) { } func TestProtectWithoutLogger(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output mockWriter := httptest.NewRecorder() - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) - h := Protect(func(w http.ResponseWriter, req *http.Request) { + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) + h := util.Protect(func(_ http.ResponseWriter, _ *http.Request) { panic("oh noes!") }) @@ -217,12 +225,10 @@ func TestProtectWithoutLogger(t *testing.T) { } func TestWithCORSOptions(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output mockWriter := httptest.NewRecorder() - mockReq, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) - h := WithCORSOptions(func(w http.ResponseWriter, req *http.Request) { - w.WriteHeader(200) - _, _ = w.Write([]byte("yep")) + mockReq, _ := http.NewRequest(http.MethodOptions, "http://example.com/foo", nil) + h := util.WithCORSOptions(func(_ http.ResponseWriter, _ *http.Request) { + _, _ = mockWriter.WriteString("yep") }) h(mockWriter, mockReq) if mockWriter.Code != 200 { @@ -243,18 +249,17 @@ func TestWithCORSOptions(t *testing.T) { } func TestGetRequestID(t *testing.T) { - log.SetLevel(log.PanicLevel) // suppress logs in test output reqID := "alphabetsoup" - mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) - ctx := context.WithValue(mockReq.Context(), ctxValueRequestID, reqID) + mockReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) + ctx := util.ContextWithRequestID(mockReq.Context(), reqID) mockReq = mockReq.WithContext(ctx) - ctxReqID := GetRequestID(mockReq.Context()) + ctxReqID := util.GetRequestID(mockReq.Context()) if reqID != ctxReqID { t.Errorf("TestGetRequestID wanted request ID '%s', got '%s'", reqID, ctxReqID) } - noReqIDInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) - ctxReqID = GetRequestID(noReqIDInReq.Context()) + noReqIDInReq, _ := http.NewRequest(http.MethodGet, "http://example.com/foo", nil) + ctxReqID = util.GetRequestID(noReqIDInReq.Context()) if ctxReqID != "" { t.Errorf("TestGetRequestID wanted empty request ID, got '%s'", ctxReqID) } diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..c0443fb --- /dev/null +++ b/logger.go @@ -0,0 +1,277 @@ +package util + +import ( + "context" + "fmt" + "io" + "log/slog" + "os" + "runtime" + "runtime/debug" + "sync" + + "github.com/lmittmann/tint" +) + +// contextKeyType is used as a type-safe key for context values. +type contextKeyType string + +// ctxValueLogger is the key to extract the LogEntry. +const ctxValueLogger contextKeyType = "logger" + +const ( + CallerDepth = 3 + FileLineAttr = 4 +) + +// ContextWithLogger pushes a LogEntry instance into the supplied context for easier propagation. +func ContextWithLogger(ctx context.Context, logger *LogEntry) context.Context { + return context.WithValue(ctx, ctxValueLogger, logger) +} + +// Log obtains a service instance being propagated through the context. +func Log(ctx context.Context) *LogEntry { + v := ctx.Value(ctxValueLogger) + if v != nil { + if logger, ok := v.(*LogEntry); ok { + return logger + } + } + + return NewLogger(ctx) +} + +// SLog obtains an slog interface from the log entry in the context. +func SLog(ctx context.Context) *slog.Logger { + return Log(ctx).SLog() +} + +// LogEntry handles logging functionality with immutable chained calls. +type LogEntry struct { + ctx context.Context + log *slog.Logger + stackTraces bool +} + +//nolint:gochecknoglobals // Pool is necessarily global +var logEntryPool = sync.Pool{ + New: func() interface{} { + return &LogEntry{} + }, +} + +// NewLogger creates a new instance of LogEntry with the provided context and options. +func NewLogger(ctx context.Context, opts ...Option) *LogEntry { + // Start with default options and apply provided options + options := defaultLogOptions() + for _, opt := range opts { + opt(options) + } + + // Determine output writer + var outputWriter io.Writer + + if options.output != nil { + outputWriter = options.output + } else { + if options.level >= slog.LevelError { + outputWriter = os.Stderr + } else { + outputWriter = os.Stdout + } + } + + // Create handler - use the specified handler or create one using the handler creator. + var handler slog.Handler + switch { + case options.handler != nil: + handler = options.handler + default: + // Fallback to default handler if no handler or creator specified + handler = defaultHandlerCreator(outputWriter, options) + } + + // Create logger + log := slog.New(handler) + slog.SetDefault(log) + + // Get a LogEntry from the pool + entry, ok := logEntryPool.Get().(*LogEntry) + if !ok { + // Fallback in case of type assertion failure + entry = &LogEntry{} + } + + entry.ctx = ctx + entry.log = log + entry.stackTraces = options.showStackTrace + + return entry +} + +// Release returns the LogEntry to the pool for reuse. +// Call this when you're done with a LogEntry and won't use it again. +func (e *LogEntry) Release() { + if e == nil { + return + } + + // Reset fields to avoid leaking data + e.ctx = nil + e.log = nil + e.stackTraces = false + + logEntryPool.Put(e) +} + +// clone creates a new LogEntry with the same properties as the original. +func (e *LogEntry) clone() *LogEntry { + if e == nil { + return NewLogger(context.Background()) + } + + // Get a new entry from the pool + clone, ok := logEntryPool.Get().(*LogEntry) + if !ok { + // Fallback in case of type assertion failure + clone = &LogEntry{} + } + + // Copy all fields + clone.ctx = e.ctx + clone.log = e.log + clone.stackTraces = e.stackTraces + + return clone +} + +// WithContext returns a new LogEntry with the given context. +func (e *LogEntry) WithContext(ctx context.Context) *LogEntry { + clone := e.clone() + clone.ctx = ctx + return clone +} + +// WithError returns a new LogEntry with the error added. +func (e *LogEntry) WithError(err error) *LogEntry { + return e.With(tint.Err(err)) +} + +// WithField returns a new LogEntry with the field added. +func (e *LogEntry) WithField(key string, value any) *LogEntry { + return e.With(key, value) +} + +// With returns a new LogEntry with the provided attributes added. +func (e *LogEntry) With(args ...any) *LogEntry { + // No args, return the same logger + if len(args) == 0 { + return e + } + + clone := e.clone() + clone.log = clone.log.With(args...) + return clone +} + +// _ctx returns the context or background if nil. +func (e *LogEntry) _ctx() context.Context { + if e.ctx == nil { + return context.Background() + } + return e.ctx +} + +// Log logs a message at the given level. +func (e *LogEntry) Log(ctx context.Context, level slog.Level, msg string, fields ...any) { + e.log.Log(ctx, level, msg, fields...) +} + +// Logf logs a formatted message at the given level. +func (e *LogEntry) Logf(ctx context.Context, level slog.Level, format string, args ...interface{}) { + if e.Enabled(ctx, level) { + e.log.Log(ctx, level, fmt.Sprintf(format, args...)) + } +} + +// Trace logs a message at debug level (alias for backward compatibility). +func (e *LogEntry) Trace(msg string, args ...any) { + e.Debug(msg, args...) +} + +// Debug logs a message at debug level. +func (e *LogEntry) Debug(msg string, args ...any) { + log := e.withFileLineNum() + log.DebugContext(e._ctx(), msg, args...) +} + +// Info logs a message at info level. +func (e *LogEntry) Info(msg string, args ...any) { + e.log.InfoContext(e._ctx(), msg, args...) +} + +// Printf logs a formatted message at info level. +func (e *LogEntry) Printf(format string, args ...any) { + e.Logf(e._ctx(), slog.LevelInfo, format, args...) +} + +// Warn logs a message at warn level. +func (e *LogEntry) Warn(msg string, args ...any) { + e.log.WarnContext(e._ctx(), msg, args...) +} + +// Error logs a message at error level. +func (e *LogEntry) Error(msg string, args ...any) { + log := e.withFileLineNum() + + if e.stackTraces { + log.ErrorContext(e._ctx(), fmt.Sprintf(" %s\n%s\n", msg, debug.Stack()), args...) + } + + log.ErrorContext(e._ctx(), msg, args...) +} + +// Fatal logs a message at error level and exits with code 1. +func (e *LogEntry) Fatal(msg string, args ...any) { + log := e.withFileLineNum() + + if e.stackTraces { + log.ErrorContext(e._ctx(), fmt.Sprintf(" %s\n%s\n", msg, debug.Stack()), args...) + } + e.log.ErrorContext(e._ctx(), msg, args...) + e.Exit(1) +} + +// Panic logs a message and panics. +func (e *LogEntry) Panic(msg string, _ ...any) { + panic(fmt.Sprintf(" %s\n%s\n", msg, debug.Stack())) +} + +// Exit terminates the application with the given code. +func (e *LogEntry) Exit(code int) { + os.Exit(code) +} + +// Enabled returns whether the logger will log at the given level. +func (e *LogEntry) Enabled(ctx context.Context, level slog.Level) bool { + return e.log.Enabled(ctx, level) +} + +// LevelEnabled is an alias for Enabled for backward compatibility. +func (e *LogEntry) LevelEnabled(ctx context.Context, level slog.Level) bool { + return e.Enabled(ctx, level) +} + +// SLog returns the underlying slog.Logger. +func (e *LogEntry) SLog() *slog.Logger { + return e.log +} + +// withFileLineNum adds file and line information to the log entry. +func (e *LogEntry) withFileLineNum() *slog.Logger { + _, file, line, ok := runtime.Caller(CallerDepth) + if ok { + return e.log.With(tint.Attr(FileLineAttr, slog.Any("file", fmt.Sprintf("%s:%d", file, line)))) + } + return e.log +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..fa4e0de --- /dev/null +++ b/logger_test.go @@ -0,0 +1,142 @@ +package util_test + +import ( + "errors" + "fmt" + "testing" + + "github.com/pitabwire/util" +) + +// TestLogs tests basic logging functionality. +func TestLogs(t *testing.T) { + ctx := t.Context() + logger := util.NewLogger(ctx) + logger.Info("test") + logger.Debug("debugging") + logger.Error("error occurred") + logger.Error("error occurred with field", "field", "field-value") + + err := errors.New("") + logger.WithError(err).Error("testing errors") + withLog := logger.WithField("g1", "group 1") + withLog2 := withLog.WithField("g2", "group 2") + + withLog.Info("testing group 1") + withLog2.Info("testing group 2") + + withLog3 := withLog2.WithField("g3", "group 3") + withLog2.WithError(err).Error("testing group 2 errors") + + withLog3.Info("testing group 3") + + // Release loggers back to the pool + defer withLog.Release() + defer withLog2.Release() + defer withLog3.Release() +} + +// TestStackTraceLogs tests logging with stack traces. +func TestStackTraceLogs(t *testing.T) { + ctx := t.Context() + logger := util.NewLogger(ctx, util.WithLogStackTrace()) + logger.Debug("testing debug logs") + logger.Info("testing logs") + + err := errors.New("") + logger.WithError(err).Error("testing errors") + defer logger.Release() +} + +// TestPanicLogs tests panic recovery in logging. +func TestPanicLogs(t *testing.T) { + ctx := t.Context() + logger := util.NewLogger(ctx) + + logger.Info("testing logs") + defer logger.Release() + + // Set up a deferred function that will recover from the panic + didPanic := false + defer func() { + if r := recover(); r != nil { + didPanic = true + // Optional: Check the panic message or value + // if !strings.Contains(fmt.Sprint(r), "expected panic message") { + // t.Errorf("unexpected panic message: %v", r) + // } + } + + if !didPanic { + t.Error("expected Panic() to panic, but it didn't") + } + }() + + // Call the function that should panic + logger.Panic("this should panic") + + // If we get here without panicking, the test will fail + t.Error("execution continued past panic point") +} + +// BenchmarkLoggerWithField benchmarks the logger WithField method to measure performance. +func BenchmarkLoggerWithField(b *testing.B) { + ctx := b.Context() + logger := util.NewLogger(ctx) + defer logger.Release() + + b.ResetTimer() + b.ReportAllocs() + for range b.N { + l := logger.WithField("key", "value") + l.Release() // Important to return to the pool + } +} + +// BenchmarkLoggerMultipleWithField benchmarks chaining multiple WithField calls. +func BenchmarkLoggerMultipleWithField(b *testing.B) { + ctx := b.Context() + logger := util.NewLogger(ctx) + defer logger.Release() + + b.ResetTimer() + b.ReportAllocs() + for range b.N { + l := logger.WithField("key1", "value1"). + WithField("key2", "value2"). + WithField("key3", "value3") + l.Release() // Important to return to the pool + } +} + +// BenchmarkLoggerWithoutPooling simulates the overhead without using pools. +func BenchmarkLoggerWithoutPooling(b *testing.B) { + ctx := b.Context() + logger := util.NewLogger(ctx) + defer logger.Release() + + b.ResetTimer() + b.ReportAllocs() + for range b.N { + // Intentionally creating and dropping references without explicit release + _ = logger.WithField("key1", "value1"). + WithField("key2", "value2"). + WithField("key3", "value3") + } +} + +// BenchmarkLogAllocation measures allocation in logging operations. +func BenchmarkLogAllocation(b *testing.B) { + ctx := b.Context() + logger := util.NewLogger(ctx) + defer logger.Release() + + b.ResetTimer() + b.ReportAllocs() + for i := range b.N { + // Typical logging pattern: context with some fields then log + l := logger.WithField("request_id", fmt.Sprintf("req-%d", i)) + l.Info("Processing request", "index", i) + l.Release() // Important to return to the pool + } +} diff --git a/options.go b/options.go new file mode 100644 index 0000000..c394e03 --- /dev/null +++ b/options.go @@ -0,0 +1,127 @@ +package util + +import ( + "io" + "log/slog" + "time" + + "github.com/lmittmann/tint" +) + +// logOptions contains configuration for the logging system. +// This is intentionally kept private to hide implementation details. +type logOptions struct { + // level defines the minimum log level that will be output + level slog.Level + + // addSource determines whether source code position should be added to log entries + addSource bool + + // timeFormat defines how timestamps are formatted in logs + timeFormat string + + // noColor disables colored output when set to true + noColor bool + + // showStackTrace enables automatic stack trace printing for Error and Fatal logs + showStackTrace bool + + // output specifies the destination for log output (defaults to os.Stdout or os.Stderr based on level) + output io.Writer + + // handler specifies a custom slog.Handler implementation to use + handler slog.Handler +} + +// Option is a function that configures logOptions. +type Option func(*logOptions) + +// defaultLogOptions returns a logOptions instance with sensible defaults. +func defaultLogOptions() *logOptions { + return &logOptions{ + level: slog.LevelInfo, + addSource: false, + timeFormat: time.DateTime, + noColor: false, + showStackTrace: false, + } +} + +// defaultHandlerCreator creates the default tint-based colored slog.Handler. +func defaultHandlerCreator(writer io.Writer, opts *logOptions) slog.Handler { + handlerOptions := &tint.Options{ + AddSource: opts.addSource, + Level: opts.level, + TimeFormat: opts.timeFormat, + NoColor: opts.noColor, + } + + return tint.NewHandler(writer, handlerOptions) +} + +// WithLogLevel sets the log level. +func WithLogLevel(level slog.Level) Option { + return func(o *logOptions) { + o.level = level + } +} + +// WithLogAddSource enables or disables source code position in log entries. +func WithLogAddSource(addSource bool) Option { + return func(o *logOptions) { + o.addSource = addSource + } +} + +// WithLogTimeFormat sets the time format for log timestamps. +func WithLogTimeFormat(format string) Option { + return func(o *logOptions) { + o.timeFormat = format + } +} + +// WithLogNoColor enables or disables colored output. +func WithLogNoColor(noColor bool) Option { + return func(o *logOptions) { + o.noColor = noColor + } +} + +// WithLogStackTrace enables automatic stack trace printing. +func WithLogStackTrace() Option { + return func(o *logOptions) { + o.showStackTrace = true + } +} + +// WithLogOutput sets the output writer for logs. +func WithLogOutput(output io.Writer) Option { + return func(o *logOptions) { + o.output = output + } +} + +// WithLogHandler sets a custom slog.Handler implementation. +func WithLogHandler(handler slog.Handler) Option { + return func(o *logOptions) { + o.handler = handler + } +} + +// ParseLevel converts a string to a log.level. +// It is case-insensitive. +// Returns an error if the string does not match a known level. +func ParseLevel(levelStr string) (slog.Level, error) { + switch levelStr { + case "debug", "DEBUG", "Debug", "trace", "TRACE", "Trace": + return slog.LevelDebug, nil + case "info", "INFO", "Info": + return slog.LevelInfo, nil + case "warn", "WARN", "Warn", "warning", "WARNING", "Warning": + return slog.LevelWarn, nil + case "error", "ERROR", "Error", "fatal", "FATAL", "Fatal", "panic", "PANIC", "Panic": + return slog.LevelError, nil + default: + return slog.LevelInfo, nil + } +} diff --git a/random.go b/random.go index e32d6e1..35006c5 100644 --- a/random.go +++ b/random.go @@ -1,20 +1,28 @@ package util import ( + "crypto/rand" + "math/big" "time" - "math/rand" - "github.com/rs/xid" ) const alphanumerics = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -// RandomString generates a pseudo-random string of length n. +// RandomString generates a cryptographically secure random string of length n. func RandomString(n int) string { + if n <= 0 { + return "" + } + b := make([]byte, n) for i := range b { - b[i] = alphanumerics[rand.Int63()%int64(len(alphanumerics))] + idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(alphanumerics)))) + if err != nil { + panic(err) + } + b[i] = alphanumerics[idx.Int64()] } return string(b) } diff --git a/unique.go b/unique.go index 401c556..aad61d9 100644 --- a/unique.go +++ b/unique.go @@ -1,7 +1,9 @@ +// Package util provides various utility functions for common tasks +// revive:disable:var-naming package util import ( - "fmt" + "errors" "sort" ) @@ -12,7 +14,7 @@ import ( // O(n). func Unique(data sort.Interface) int { if !sort.IsSorted(data) { - panic(fmt.Errorf("util: the input to Unique() must be sorted")) + panic(errors.New("util: the input to Unique() must be sorted")) } if data.Len() == 0 { @@ -35,23 +37,18 @@ func Unique(data sort.Interface) int { j++ } } - // Output the last element. + // output the last element. data.Swap(length-1, j) return j + 1 } -// SortAndUnique sorts a list and removes duplicate entries in place. -// Takes the same interface as sort.Sort -// Returns the length of the data without duplicates -// Uses the last occurrence of a duplicate. -// O(nlog(n)) +// SortAndUnique sorts the data and removes duplicates. O(nlog(n)). func SortAndUnique(data sort.Interface) int { sort.Sort(data) return Unique(data) } -// UniqueStrings turns a list of strings into a sorted list of unique strings. -// O(nlog(n)) +// UniqueStrings returns a sorted slice of unique strings. O(nlog(n)). func UniqueStrings(strings []string) []string { return strings[:SortAndUnique(sort.StringSlice(strings))] } diff --git a/unique_test.go b/unique_test.go index 747ac1b..4b40e58 100644 --- a/unique_test.go +++ b/unique_test.go @@ -1,8 +1,9 @@ -package util +package util_test import ( - "sort" "testing" + + "github.com/pitabwire/util" ) type sortBytes []byte @@ -24,7 +25,7 @@ func TestUnique(t *testing.T) { for _, test := range testCases { input := []byte(test.Input) want := test.Want - got := string(input[:Unique(sortBytes(input))]) + got := string(input[:util.Unique(sortBytes(input))]) if got != want { t.Fatal("Wanted ", want, " got ", got) } @@ -48,7 +49,7 @@ func TestUniquePicksLastDuplicate(t *testing.T) { "avacado", "cucumber", } - got := input[:Unique(sortByFirstByte(input))] + got := input[:util.Unique(sortByFirstByte(input))] if len(want) != len(got) { t.Errorf("Wanted %#v got %#v", want, got) @@ -63,34 +64,29 @@ func TestUniquePicksLastDuplicate(t *testing.T) { func TestUniquePanicsIfNotSorted(t *testing.T) { defer func() { if r := recover(); r == nil { - t.Error("Expected Unique() to panic on unsorted input but it didn't") + t.Errorf("Unique did not panic on unsorted input") } }() - Unique(sort.StringSlice{"out", "of", "order"}) + unsorted := sortBytes{'b', 'a'} + _ = util.Unique(unsorted) } func TestUniqueStrings(t *testing.T) { - input := []string{ - "badger", "badger", "badger", "badger", - "badger", "badger", "badger", "badger", - "badger", "badger", "badger", "badger", - "mushroom", "mushroom", - "badger", "badger", "badger", "badger", - "badger", "badger", "badger", "badger", - "badger", "badger", "badger", "badger", - "snake", "snake", - } - - want := []string{"badger", "mushroom", "snake"} - - got := UniqueStrings(input) - - if len(want) != len(got) { - t.Errorf("Wanted %#v got %#v", want, got) + testCases := []struct { + Input []string + Want []string + }{ + {[]string{"b", "a", "a", "c"}, []string{"a", "b", "c"}}, } - for i := range want { - if want[i] != got[i] { - t.Errorf("Wanted %#v got %#v", want, got) + for _, test := range testCases { + got := util.UniqueStrings(test.Input) + if len(got) != len(test.Want) { + t.Errorf("Wanted %v got %v", test.Want, got) + } + for i := range got { + if got[i] != test.Want[i] { + t.Errorf("Wanted %v got %v", test.Want, got) + } } } }