From ec9232814e7be4242138ae988b2577d17c3f4f50 Mon Sep 17 00:00:00 2001 From: toimtoimtoim Date: Fri, 26 Dec 2025 16:21:08 +0200 Subject: [PATCH 1/7] V5 changes --- .gitattributes | 5 - .github/ISSUE_TEMPLATE.md | 4 +- .github/workflows/checks.yml | 1 - .github/workflows/echo.yml | 2 +- API_CHANGES_V5.md | 1158 +++++++++++++ LICENSE | 2 +- Makefile | 13 +- README.md | 13 +- bind.go | 91 +- bind_test.go | 281 ++-- binder.go | 70 +- binder_external_test.go | 13 +- binder_generic.go | 140 +- binder_generic_test.go | 91 +- binder_test.go | 400 +++-- context.go | 668 ++++---- context_fs.go | 52 - context_fs_test.go | 135 -- context_generic.go | 11 +- context_generic_test.go | 18 +- context_test.go | 1047 +++++++----- echo.go | 1173 ++++++-------- echo_fs.go | 162 -- echo_fs_test.go | 271 ---- echo_test.go | 1699 ++++++------------- echotest/context.go | 183 +++ echotest/context_external_test.go | 27 + echotest/context_test.go | 157 ++ echotest/reader.go | 46 + echotest/reader_external_test.go | 25 + echotest/reader_test.go | 21 + echotest/testdata/test.json | 3 + go.mod | 11 +- go.sum | 15 - group.go | 157 +- group_fs.go | 33 - group_fs_test.go | 103 -- group_test.go | 647 +++++++- httperror.go | 107 ++ httperror_external_test.go | 52 + httperror_test.go | 67 + ip.go | 12 +- ip_test.go | 52 +- json.go | 17 +- json_test.go | 18 +- log.go | 41 - middleware/DEVELOPMENT.md | 11 + middleware/basic_auth.go | 138 +- middleware/basic_auth_test.go | 200 ++- middleware/body_dump.go | 154 +- middleware/body_dump_test.go | 447 ++++- middleware/body_limit.go | 71 +- middleware/body_limit_test.go | 153 +- middleware/compress.go | 81 +- middleware/compress_test.go | 331 ++-- middleware/context_timeout.go | 55 +- middleware/context_timeout_test.go | 25 +- middleware/cors.go | 235 ++- middleware/cors_test.go | 375 ++--- middleware/csrf.go | 78 +- middleware/csrf_test.go | 60 +- middleware/decompress.go | 89 +- middleware/decompress_test.go | 352 +++- middleware/extractor.go | 194 ++- middleware/extractor_test.go | 134 +- middleware/key_auth.go | 164 +- middleware/key_auth_test.go | 244 ++- middleware/logger.go | 420 ----- middleware/logger_strings.go | 242 --- middleware/logger_strings_test.go | 288 ---- middleware/logger_test.go | 540 ------ middleware/method_override.go | 22 +- middleware/method_override_test.go | 68 +- middleware/middleware.go | 17 +- middleware/middleware_test.go | 5 - middleware/proxy.go | 112 +- middleware/proxy_test.go | 247 +-- middleware/rate_limiter.go | 84 +- middleware/rate_limiter_test.go | 174 +- middleware/recover.go | 89 +- middleware/recover_test.go | 208 +-- middleware/redirect.go | 145 +- middleware/redirect_test.go | 24 +- middleware/request_id.go | 44 +- middleware/request_id_test.go | 109 +- middleware/request_logger.go | 189 +-- middleware/request_logger_test.go | 148 +- middleware/rewrite.go | 36 +- middleware/rewrite_test.go | 77 +- middleware/secure.go | 32 +- middleware/secure_test.go | 86 +- middleware/slash.go | 72 +- middleware/slash_test.go | 14 +- middleware/static.go | 164 +- middleware/static_test.go | 235 ++- middleware/timeout.go | 256 --- middleware/timeout_test.go | 492 ------ middleware/util.go | 63 +- middleware/util_test.go | 64 +- renderer.go | 7 +- renderer_test.go | 6 +- response.go | 78 +- response_test.go | 45 +- route.go | 192 +++ route_test.go | 517 ++++++ router.go | 938 +++++++---- router_concurrent.go | 47 + router_concurrent_test.go | 378 +++++ router_test.go | 2432 ++++++++++++++++++---------- server.go | 175 ++ server_test.go | 699 ++++++++ version.go | 9 + vhost.go | 20 + vhost_test.go | 117 ++ 114 files changed, 13299 insertions(+), 10032 deletions(-) create mode 100644 API_CHANGES_V5.md delete mode 100644 context_fs.go delete mode 100644 context_fs_test.go delete mode 100644 echo_fs.go delete mode 100644 echo_fs_test.go create mode 100644 echotest/context.go create mode 100644 echotest/context_external_test.go create mode 100644 echotest/context_test.go create mode 100644 echotest/reader.go create mode 100644 echotest/reader_external_test.go create mode 100644 echotest/reader_test.go create mode 100644 echotest/testdata/test.json delete mode 100644 group_fs.go delete mode 100644 group_fs_test.go create mode 100644 httperror.go create mode 100644 httperror_external_test.go create mode 100644 httperror_test.go delete mode 100644 log.go create mode 100644 middleware/DEVELOPMENT.md delete mode 100644 middleware/logger.go delete mode 100644 middleware/logger_strings.go delete mode 100644 middleware/logger_strings_test.go delete mode 100644 middleware/logger_test.go delete mode 100644 middleware/timeout.go delete mode 100644 middleware/timeout_test.go create mode 100644 route.go create mode 100644 route_test.go create mode 100644 router_concurrent.go create mode 100644 router_concurrent_test.go create mode 100644 server.go create mode 100644 server_test.go create mode 100644 version.go create mode 100644 vhost.go create mode 100644 vhost_test.go diff --git a/.gitattributes b/.gitattributes index 49b63e526..28981b84a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,8 +13,3 @@ *.js text eol=lf *.json text eol=lf LICENSE text eol=lf - -# Exclude `website` and `cookbook` from GitHub's language statistics -# https://github.com/github/linguist#using-gitattributes -cookbook/* linguist-documentation -website/* linguist-documentation diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md index 82220c0a1..1a76adca7 100644 --- a/.github/ISSUE_TEMPLATE.md +++ b/.github/ISSUE_TEMPLATE.md @@ -6,7 +6,7 @@ package main import ( - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "net/http" "net/http/httptest" "testing" @@ -15,7 +15,7 @@ import ( func TestExample(t *testing.T) { e := echo.New() - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "Hello, World!") }) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 436254a63..f8f20dccd 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -45,4 +45,3 @@ jobs: go install golang.org/x/vuln/cmd/govulncheck@latest govulncheck ./... - diff --git a/.github/workflows/echo.yml b/.github/workflows/echo.yml index c7780fd21..136986a2e 100644 --- a/.github/workflows/echo.yml +++ b/.github/workflows/echo.yml @@ -25,7 +25,7 @@ jobs: # Echo tests with last four major releases (unless there are pressing vulnerabilities) # As we depend on `golang.org/x/` libraries which only support last 2 Go releases we could have situations when # we derive from last four major releases promise. - go: ["1.22", "1.23", "1.24", "1.25"] + go: ["1.25"] name: ${{ matrix.os }} @ Go ${{ matrix.go }} runs-on: ${{ matrix.os }} steps: diff --git a/API_CHANGES_V5.md b/API_CHANGES_V5.md new file mode 100644 index 000000000..6c36a7a5a --- /dev/null +++ b/API_CHANGES_V5.md @@ -0,0 +1,1158 @@ +# Echo v5 Public API Changes + +**Comparison between `master` (v4.15.0) and `v5` (v5.0.0-alpha) branches** + +Generated: 2026-01-01 + +--- + +## Executive Summary + +Echo v5 represents a **major breaking release** with significant architectural changes focused on: +- **Updated generic helpers** to take `*Context` and rename form helpers to `FormValue*` +- **Simplified API surface** by moving Context from interface to concrete struct +- **Modern Go patterns** including slog.Logger integration +- **Enhanced routing** with explicit RouteInfo and Routes types +- **Better error handling** with simplified HTTPError +- **New test helpers** via the `echotest` package + +### Change Statistics + +- **Major Breaking Changes**: 15+ +- **New Functions Added**: 30+ +- **Type Signature Changes**: 20+ +- **Removed APIs**: 10+ +- **New Packages Added**: 1 (`echotest`) +- **Version Change**: `4.15.0` → `5.0.0-alpha` + +--- + +## Critical Breaking Changes + +### 1. **Context: Interface → Concrete Struct** + +**v4 (master):** +```go +type Context interface { + Request() *http.Request + // ... many methods +} + +// Handler signature +func handler(c echo.Context) error +``` + +**v5:** +```go +type Context struct { + // Has unexported fields +} + +// Handler signature - NOW USES POINTER! +func handler(c *echo.Context) error +``` + +**Impact:** 🔴 **CRITICAL BREAKING CHANGE** +- ALL handlers must change from `echo.Context` to `*echo.Context` +- Context is now a concrete struct, not an interface +- This affects every single handler function in user code + +**Migration:** +```go +// Before (v4) +func MyHandler(c echo.Context) error { + return c.JSON(200, map[string]string{"hello": "world"}) +} + +// After (v5) +func MyHandler(c *echo.Context) error { + return c.JSON(200, map[string]string{"hello": "world"}) +} +``` + +--- + +### 2. **Logger: Custom Interface → slog.Logger** + +**v4:** +```go +type Echo struct { + Logger Logger // Custom interface with Print, Debug, Info, etc. +} + +type Logger interface { + Output() io.Writer + SetOutput(w io.Writer) + Prefix() string + // ... many custom methods +} + +// Context returns Logger interface +func (c Context) Logger() Logger +``` + +**v5:** +```go +type Echo struct { + Logger *slog.Logger // Standard library structured logger +} + +// Context returns slog.Logger +func (c *Context) Logger() *slog.Logger +func (c *Context) SetLogger(logger *slog.Logger) +``` + +**Impact:** 🔴 **BREAKING CHANGE** +- Must use Go's standard `log/slog` package +- Logger interface completely removed +- All logging code needs updating + +--- + +### 3. **Router: From Router to DefaultRouter** + +**v4:** +```go +type Router struct { ... } + +func NewRouter(e *Echo) *Router +func (e *Echo) Router() *Router +``` + +**v5:** +```go +type DefaultRouter struct { ... } + +func NewRouter(config RouterConfig) *DefaultRouter +func (e *Echo) Router() Router // Returns interface +``` + +**Changes:** +- New `Router` interface introduced +- `DefaultRouter` is the concrete implementation +- `NewRouter()` now takes `RouterConfig` instead of `*Echo` +- Added `NewConcurrentRouter(r Router) Router` for thread-safe routing + +--- + +### 4. **Route Return Types Changed** + +**v4:** +```go +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route +func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) []*Route +func (e *Echo) Routes() []*Route +``` + +**v5:** +```go +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo +func (e *Echo) Any(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo +func (e *Echo) Match(...) Routes // Returns Routes type +func (e *Echo) Router() Router // Returns interface +``` + +**New Types:** +```go +type RouteInfo struct { + Name string + Method string + Path string + Parameters []string +} + +type Routes []RouteInfo // Collection with helper methods +``` + +**Impact:** 🔴 **BREAKING CHANGE** +- Route registration methods return `RouteInfo` instead of `*Route` +- New `Routes` collection type with filtering methods +- `Route` struct still exists but used differently + +--- + +### 5. **Response Type Changed** + +**v4:** +```go +func (c Context) Response() *Response +type Response struct { + Writer http.ResponseWriter + Status int + Size int64 + Committed bool +} +func NewResponse(w http.ResponseWriter, e *Echo) *Response +``` + +**v5:** +```go +func (c *Context) Response() http.ResponseWriter +type Response struct { + http.ResponseWriter // Embedded + Status int + Size int64 + Committed bool +} +func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response +func UnwrapResponse(rw http.ResponseWriter) (*Response, error) +``` + +**Changes:** +- Context.Response() returns `http.ResponseWriter` instead of `*Response` +- Response now embeds `http.ResponseWriter` +- NewResponse takes `*slog.Logger` instead of `*Echo` +- New `UnwrapResponse()` helper function + +--- + +### 6. **HTTPError Simplified** + +**v4:** +```go +type HTTPError struct { + Internal error + Message interface{} // Can be any type + Code int +} + +func NewHTTPError(code int, message ...interface{}) *HTTPError +``` + +**v5:** +```go +type HTTPError struct { + Code int + Message string // Now string only + // Has unexported fields (Internal moved) +} + +func NewHTTPError(code int, message string) *HTTPError +func (he HTTPError) Wrap(err error) error // New method +func (he *HTTPError) StatusCode() int // Implements HTTPStatusCoder +``` + +**Changes:** +- `Message` field changed from `interface{}` to `string` +- `NewHTTPError()` now takes `string` instead of `...interface{}` +- Added `HTTPStatusCoder` interface and `StatusCode()` method +- Added `Wrap(err error)` method for error wrapping + +--- + +### 7. **HTTPErrorHandler Signature Changed** + +**v4:** +```go +type HTTPErrorHandler func(err error, c Context) + +func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) +``` + +**v5:** +```go +type HTTPErrorHandler func(c *Context, err error) // Parameters swapped! + +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler // Now a factory +``` + +**Impact:** 🔴 **BREAKING CHANGE** +- Parameter order reversed: `(c *Context, err error)` instead of `(err error, c Context)` +- DefaultHTTPErrorHandler is now a factory function that returns HTTPErrorHandler +- Takes `exposeError` bool to control error message exposure + +--- + +## Notable API Changes in v5 + +### 1. **Generic Parameter Extraction Functions (Updated Signatures)** + +These helpers keep the same generic API but now accept `*Context`, and the +form helpers are renamed from `FormParam*` to `FormValue*`: + +```go +// Query Parameters +func QueryParam[T any](c *Context, key string, opts ...any) (T, error) +func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) +func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) +func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) + +// Path Parameters +func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) +func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) + +// Form Values +func FormValue[T any](c *Context, key string, opts ...any) (T, error) +func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) +func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) +func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) + +// Generic Parsing +func ParseValue[T any](value string, opts ...any) (T, error) +func ParseValueOr[T any](value string, defaultValue T, opts ...any) (T, error) +func ParseValues[T any](values []string, opts ...any) ([]T, error) +func ParseValuesOr[T any](values []string, defaultValue []T, opts ...any) ([]T, error) +``` + +`FormParam*` was renamed to `FormValue*`; the rest keep names but now take `*Context`. + +**Supported Types:** +- bool, string +- int, int8, int16, int32, int64 +- uint, uint8, uint16, uint32, uint64 +- float32, float64 +- time.Time, time.Duration +- BindUnmarshaler, encoding.TextUnmarshaler, json.Unmarshaler + +**Example Usage:** +```go +// v5 - Type-safe parameter binding +id, err := echo.PathParam[int](c, "id") +page, err := echo.QueryParamOr[int](c, "page", 1) +tags, err := echo.QueryParams[string](c, "tags") +``` + +--- + +### 2. **Context Store Helpers Now Use `*Context`** + +```go +// Type-safe context value retrieval +func ContextGet[T any](c *Context, key string) (T, error) +func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) + +// Error types +var ErrNonExistentKey = errors.New("non existent key") +var ErrInvalidKeyType = errors.New("invalid key type") +``` + +These helpers existed in v4 with `Context` and now accept `*Context`. + +**Example:** +```go +// v5 +user, err := echo.ContextGet[*User](c, "user") +count, err := echo.ContextGetOr[int](c, "count", 0) +``` + +--- + +### 3. **PathValues Type** + +New structured path parameter handling: + +```go +type PathValue struct { + Name string + Value string +} + +type PathValues []PathValue + +func (p PathValues) Get(name string) (string, bool) +func (p PathValues) GetOr(name string, defaultValue string) string + +// Context methods +func (c *Context) PathValues() PathValues +func (c *Context) SetPathValues(pathValues PathValues) +``` + +--- + +### 4. **Time Parsing Options** + +```go +type TimeLayout string + +const ( + TimeLayoutUnixTime = TimeLayout("UnixTime") + TimeLayoutUnixTimeMilli = TimeLayout("UnixTimeMilli") + TimeLayoutUnixTimeNano = TimeLayout("UnixTimeNano") +) + +type TimeOpts struct { + Layout TimeLayout + ParseInLocation *time.Location + ToInLocation *time.Location +} +``` + +--- + +### 5. **StartConfig for Server Configuration** + +```go +type StartConfig struct { + Address string + HideBanner bool + HidePort bool + CertFilesystem fs.FS + TLSConfig *tls.Config + ListenerNetwork string + ListenerAddrFunc func(addr net.Addr) + GracefulTimeout time.Duration + OnShutdownError func(err error) + BeforeServeFunc func(s *http.Server) error +} + +func (sc StartConfig) Start(ctx context.Context, h http.Handler) error +func (sc StartConfig) StartTLS(ctx context.Context, h http.Handler, certFile, keyFile any) error +``` + +**Example:** +```go +// v5 - More control over server startup +ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) +defer cancel() + +sc := echo.StartConfig{ + Address: ":8080", + GracefulTimeout: 10 * time.Second, +} +if err := sc.Start(ctx, e); err != nil { + log.Fatal(err) +} +``` + +--- + +### 6. **Echo Config and Constructors** + +```go +type Config struct { + // Configuration for Echo (logger, binder, renderer, etc.) +} + +func NewWithConfig(config Config) *Echo +``` + +This adds a configuration struct for creating an `Echo` instance without +mutating fields after `New()`. + +--- + +### 7. **Enhanced Routing Features** + +```go +// New route methods +func (e *Echo) AddRoute(route Route) (RouteInfo, error) +func (e *Echo) Middlewares() []MiddlewareFunc +func (e *Echo) PreMiddlewares() []MiddlewareFunc +type AddRouteError struct{ ... } + +// Routes collection with filters +type Routes []RouteInfo + +func (r Routes) Clone() Routes +func (r Routes) FilterByMethod(method string) (Routes, error) +func (r Routes) FilterByName(name string) (Routes, error) +func (r Routes) FilterByPath(path string) (Routes, error) +func (r Routes) FindByMethodPath(method string, path string) (RouteInfo, error) +func (r Routes) Reverse(routeName string, pathValues ...any) (string, error) + +// RouteInfo operations +func (r RouteInfo) Clone() RouteInfo +func (r RouteInfo) Reverse(pathValues ...any) string +``` + +--- + +### 8. **Middleware Configuration Interface** + +```go +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} +``` + +Allows middleware configs to be converted to middleware without panicking. + +--- + +### 9. **New Context Methods** + +```go +// v5 additions +func (c *Context) FileFS(file string, filesystem fs.FS) error +func (c *Context) FormValueOr(name, defaultValue string) string +func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) +func (c *Context) ParamOr(name, defaultValue string) string +func (c *Context) QueryParamOr(name, defaultValue string) string +func (c *Context) RouteInfo() RouteInfo +``` + +--- + +### 10. **Virtual Host Support** + +```go +func NewVirtualHostHandler(vhosts map[string]*Echo) *Echo +``` + +Creates an Echo instance that routes requests to different Echo instances based on host. + +--- + +### 11. **New Binder Functions** + +```go +func BindBody(c *Context, target any) error +func BindHeaders(c *Context, target any) error +func BindPathValues(c *Context, target any) error // Renamed from BindPathParams +func BindQueryParams(c *Context, target any) error +``` + +Top-level binding functions that work with `*Context`. + +--- + +### 12. **New echotest Package** + +```go +package echotest // import "github.com/labstack/echo/v5/echotest" + +func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte +func TrimNewlineEnd(bytes []byte) []byte +type ContextConfig struct{ ... } +type MultipartForm struct{ ... } +type MultipartFormFile struct{ ... } +``` + +Helpers for loading fixtures and constructing test contexts. + +--- + +## Removed APIs in v5 + +### Constants + +```go +// v4 - Removed in v5 +const CONNECT = http.MethodConnect // Use http.MethodConnect directly +``` + +**Reason:** Deprecated in v4, use stdlib `http.Method*` constants instead. + +--- + +### Constants Added in v5 + +```go +// v5 additions +const ( + NotFoundRouteName = "echo_route_not_found_name" +) +``` + +--- + +### Error Variable Changes + +**v4 exports:** +```go +ErrBadRequest +ErrInvalidKeyType +ErrNonExistentKey +``` + +**v5 exports:** +```go +ErrBadRequest // Now backed by unexported httpError type +ErrValidatorNotRegistered // New +ErrInvalidKeyType +ErrNonExistentKey +``` + +**Reason:** v5 centralizes on `NewHTTPError(code, message)` rather than a broad set +of predefined HTTP error variables. + +--- + +### Functions Removed + +```go +// v4 - Removed in v5 +func GetPath(r *http.Request) string // Use r.URL.Path or r.URL.RawPath +``` + +### Variables Removed + +```go +// v4 - Removed in v5 +var MethodNotAllowedHandler = func(c Context) error { ... } +var NotFoundHandler = func(c Context) error { ... } +``` + +### Functions Renamed + +```go +// v4 +func FormParam[T any](c Context, key string, opts ...any) (T, error) +func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) +func FormParams[T any](c Context, key string, opts ...any) ([]T, error) +func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) + +// v5 +func FormValue[T any](c *Context, key string, opts ...any) (T, error) +func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) +func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) +func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) +``` + +--- + +### Type Methods Removed/Changed + +**Echo struct changes:** +```go +// v4 fields removed in v5 +type Echo struct { + StdLogger *stdLog.Logger // Removed + Server *http.Server // Removed (use StartConfig) + TLSServer *http.Server // Removed (use StartConfig) + Listener net.Listener // Removed (use StartConfig) + TLSListener net.Listener // Removed (use StartConfig) + AutoTLSManager autocert.Manager // Removed + ListenerNetwork string // Removed + OnAddRouteHandler func(...) // Changed to OnAddRoute + DisableHTTP2 bool // Removed (use StartConfig) + Debug bool // Removed + HideBanner bool // Removed (use StartConfig) + HidePort bool // Removed (use StartConfig) +} + +// v5 Echo struct (simplified) +type Echo struct { + Binder Binder + Filesystem fs.FS // NEW + Renderer Renderer + Validator Validator + JSONSerializer JSONSerializer + IPExtractor IPExtractor + OnAddRoute func(route Route) error // Simplified + HTTPErrorHandler HTTPErrorHandler + Logger *slog.Logger // Changed from Logger interface +} +``` + +--- + +**Context interface → struct:** +```go +// v4 +type Context interface { + // Had: SetResponse(*Response) + Response() *Response + + // Had: ParamNames(), SetParamNames(), ParamValues(), SetParamValues() + // These are removed in v5 (use PathValues() instead) +} + +// v5 +type Context struct { + // Concrete struct with unexported fields +} + +func (c *Context) Response() http.ResponseWriter // Changed return type +func (c *Context) PathValues() PathValues // Replaces ParamNames/Values +``` + +--- + +**Types removed:** +```go +// v4 +type Map map[string]interface{} +``` + +**Group changes:** +```go +// v4 +func (g *Group) File(path, file string) // No return value +func (g *Group) Static(pathPrefix, fsRoot string) // No return value +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) // No return value + +// v5 +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo +func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo +``` + +Now return `RouteInfo` and accept middleware. + +--- + +### Value Binder Factory Name Changes + +```go +// v4 +func PathParamsBinder(c Context) *ValueBinder +func QueryParamsBinder(c Context) *ValueBinder +func FormFieldBinder(c Context) *ValueBinder + +// v5 +func PathValuesBinder(c *Context) *ValueBinder // Renamed +func QueryParamsBinder(c *Context) *ValueBinder +func FormFieldBinder(c *Context) *ValueBinder +``` + +--- + +## Type Signature Changes + +### Binder Interface + +```go +// v4 +type Binder interface { + Bind(i interface{}, c Context) error +} + +// v5 +type Binder interface { + Bind(c *Context, target any) error // Parameters swapped! +} +``` + +--- + +### DefaultBinder Methods + +```go +// v4 +func (b *DefaultBinder) Bind(i interface{}, c Context) error +func (b *DefaultBinder) BindBody(c Context, i interface{}) error +func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error + +// v5 +func (b *DefaultBinder) Bind(c *Context, target any) error // Swapped params +// BindBody, BindPathParams, etc. are now top-level functions +``` + +--- + +### JSONSerializer Interface + +```go +// v4 +type JSONSerializer interface { + Serialize(c Context, i interface{}, indent string) error + Deserialize(c Context, i interface{}) error +} + +// v5 +type JSONSerializer interface { + Serialize(c *Context, target any, indent string) error + Deserialize(c *Context, target any) error +} +``` + +--- + +### Renderer Interface + +```go +// v4 +type Renderer interface { + Render(io.Writer, string, interface{}, Context) error +} + +// v5 +type Renderer interface { + Render(c *Context, w io.Writer, templateName string, data any) error +} +``` + +Parameters reordered with Context first. + +--- + +### NewBindingError + +```go +// v4 +func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error + +// v5 +func NewBindingError(sourceParam string, values []string, message string, err error) error +``` + +Message parameter changed from `interface{}` to `string`. + +--- + +### HandlerName + +```go +// v5 only +func HandlerName(h HandlerFunc) string +``` + +New utility function to get handler function name. + +--- + +## Middleware Package Changes + +### Signature and Type Updates + +```go +// CORS now accepts optional allow-origins +func CORS(allowOrigins ...string) echo.MiddlewareFunc + +// BodyLimit now accepts bytes +func BodyLimit(limitBytes int64) echo.MiddlewareFunc + +// DefaultSkipper now uses *echo.Context +func DefaultSkipper(c *echo.Context) bool + +// Trailing slash configs renamed/split +func AddTrailingSlashWithConfig(config AddTrailingSlashConfig) echo.MiddlewareFunc +func RemoveTrailingSlashWithConfig(config RemoveTrailingSlashConfig) echo.MiddlewareFunc +type AddTrailingSlashConfig struct{ ... } +type RemoveTrailingSlashConfig struct{ ... } + +// Auth + extractor signatures now use *echo.Context and add ExtractorSource +type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) +type Extractor func(c *echo.Context) (string, error) +type ExtractorSource string +type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) +type KeyAuthErrorHandler func(c *echo.Context, err error) error + +// BodyDump handler now includes err +type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) + +// ValuesExtractor now returns extractor source and CreateExtractors takes a limit +type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) +func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) +type ValueExtractorError struct{ ... } + +// New constants +const KB = 1024 + +// Rate limiter store now takes a float64 limit +func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) +``` + +### Added Middleware Exports + +```go +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") +var RedirectHTTPSConfig = RedirectConfig{ ... } +var RedirectHTTPSWWWConfig = RedirectConfig{ ... } +var RedirectNonHTTPSWWWConfig = RedirectConfig{ ... } +var RedirectNonWWWConfig = RedirectConfig{ ... } +var RedirectWWWConfig = RedirectConfig{ ... } +``` + +### Removed/Consolidated Middleware Exports + +```go +// Removed in v5 +func Logger() echo.MiddlewareFunc +func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc +func Timeout() echo.MiddlewareFunc +func TimeoutWithConfig(config TimeoutConfig) echo.MiddlewareFunc +type ErrKeyAuthMissing struct{ ... } +type CSRFErrorHandler func(err error, c echo.Context) error +type LoggerConfig struct{ ... } +type LogErrorFunc func(c echo.Context, err error, stack []byte) error +type TargetProvider interface{ ... } +type TrailingSlashConfig struct{ ... } +type TimeoutConfig struct{ ... } +``` + +Also removed defaults: `DefaultBasicAuthConfig`, `DefaultBodyDumpConfig`, `DefaultBodyLimitConfig`, +`DefaultCORSConfig`, `DefaultDecompressConfig`, `DefaultGzipConfig`, `DefaultLoggerConfig`, +`DefaultRedirectConfig`, `DefaultRequestIDConfig`, `DefaultRewriteConfig`, `DefaultTimeoutConfig`, +`DefaultTrailingSlashConfig`. + +--- + +## Router Interface Changes + +### v4 Router (Concrete Struct) + +```go +type Router struct { ... } + +func NewRouter(e *Echo) *Router +func (r *Router) Add(method, path string, h HandlerFunc) +func (r *Router) Find(method, path string, c Context) +func (r *Router) Reverse(name string, params ...interface{}) string +func (r *Router) Routes() []*Route +``` + +### v5 Router (Interface + DefaultRouter) + +```go +type Router interface { + Add(routable Route) (RouteInfo, error) + Remove(method string, path string) error + Routes() Routes + Route(c *Context) HandlerFunc +} + +type DefaultRouter struct { ... } + +func NewRouter(config RouterConfig) *DefaultRouter +func NewConcurrentRouter(r Router) Router // NEW + +type RouterConfig struct { + NotFoundHandler HandlerFunc + MethodNotAllowedHandler HandlerFunc + OptionsMethodHandler HandlerFunc + AllowOverwritingRoute bool + UnescapePathParamValues bool + UseEscapedPathForMatching bool +} +``` + +**Key Changes:** +- Router is now an interface +- DefaultRouter is the concrete implementation +- Add() returns `(RouteInfo, error)` instead of being void +- New `Remove()` method +- New `Route()` method replaces `Find()` +- Configuration through `RouterConfig` + +--- + +## Echo Instance Method Changes + +### Route Registration + +```go +// v4 +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route + +// v5 +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) AddRoute(route Route) (RouteInfo, error) // NEW +``` + +### Static File Serving + +```go +// v4 +func (e *Echo) Static(pathPrefix, fsRoot string) *Route +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route +func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route + +// v5 +func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo +``` + +Return type changed from `*Route` to `RouteInfo`. + +### Server Management + +```go +// v4 +func (e *Echo) Start(address string) error +func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) error +func (e *Echo) StartAutoTLS(address string) error +func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error +func (e *Echo) StartServer(s *http.Server) error +func (e *Echo) Shutdown(ctx context.Context) error +func (e *Echo) Close() error +func (e *Echo) ListenerAddr() net.Addr +func (e *Echo) TLSListenerAddr() net.Addr +func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) + +// v5 +func (e *Echo) Start(address string) error // Simplified +func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) + +// Removed: StartTLS, StartAutoTLS, StartH2CServer, StartServer +// Use StartConfig instead for advanced server configuration +// Removed: Shutdown, Close, ListenerAddr, TLSListenerAddr +// Removed: DefaultHTTPErrorHandler (now a top-level factory function) +``` + +**v5 provides** `StartConfig` type for all advanced server configuration. + +### Router Access + +```go +// v4 +func (e *Echo) Router() *Router +func (e *Echo) Routers() map[string]*Router // For multi-host +func (e *Echo) Routes() []*Route +func (e *Echo) Reverse(name string, params ...interface{}) string +func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string +func (e *Echo) URL(h HandlerFunc, params ...interface{}) string +func (e *Echo) Host(name string, m ...MiddlewareFunc) *Group + +// v5 +func (e *Echo) Router() Router // Returns interface +// Removed: Routers(), Reverse(), URI(), URL(), Host() +// Use router.Routes() and Routes.Reverse() instead +``` + +--- + +## NewContext Changes + +```go +// v4 +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context +func NewResponse(w http.ResponseWriter, e *Echo) *Response + +// v5 +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context +func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context // Standalone +func NewResponse(w http.ResponseWriter, logger *slog.Logger) *Response +``` + +--- + +## Migration Guide Summary + +### 1. Update All Handler Signatures + +```go +// Before +func MyHandler(c echo.Context) error { ... } + +// After +func MyHandler(c *echo.Context) error { ... } +``` + +### 2. Update Logger Usage + +```go +// Before +e.Logger.Info("Server started") +c.Logger().Error("Something went wrong") + +// After +e.Logger.Info("Server started") +c.Logger().Error("Something went wrong") // Same API, different logger +``` + +### 3. Use Type-Safe Parameter Extraction + +```go +// Before +idStr := c.Param("id") +id, err := strconv.Atoi(idStr) + +// After +id, err := echo.PathParam[int](c, "id") +``` + +### 4. Update Error Handler + +```go +// Before +e.HTTPErrorHandler = func(err error, c echo.Context) { + // handle error +} + +// After +e.HTTPErrorHandler = func(c *echo.Context, err error) { // Swapped! + // handle error +} + +// Or use factory +e.HTTPErrorHandler = echo.DefaultHTTPErrorHandler(true) // exposeError=true +``` + +### 5. Update Server Startup + +```go +// Before +e.Start(":8080") +e.StartTLS(":443", "cert.pem", "key.pem") + +// After +// Simple +e.Start(":8080") + +// Advanced with graceful shutdown +ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt) +defer cancel() +sc := echo.StartConfig{Address: ":8080"} +sc.Start(ctx, e) +``` + +### 6. Update Route Info Access + +```go +// Before +routes := e.Routes() +for _, r := range routes { + fmt.Println(r.Method, r.Path) +} + +// After +routes := e.Router().Routes() +for _, r := range routes { + fmt.Println(r.Method, r.Path) +} +``` + +### 7. Update HTTPError Creation + +```go +// Before +return echo.NewHTTPError(400, "invalid request", someDetail) + +// After +return echo.NewHTTPError(400, "invalid request") +``` + +### 8. Update Custom Binder + +```go +// Before +type MyBinder struct{} +func (b *MyBinder) Bind(i interface{}, c echo.Context) error { ... } + +// After +type MyBinder struct{} +func (b *MyBinder) Bind(c *echo.Context, target any) error { ... } // Swapped! +``` + +### 9. Path Parameters + +```go +// Before +names := c.ParamNames() +values := c.ParamValues() + +// After +pathValues := c.PathValues() +for _, pv := range pathValues { + fmt.Println(pv.Name, pv.Value) +} +``` + +### 10. Response Access + +```go +// Before +resp := c.Response() +resp.Header().Set("X-Custom", "value") + +// After +c.Response().Header().Set("X-Custom", "value") // Returns http.ResponseWriter + +// To get *echo.Response +resp, err := echo.UnwrapResponse(c.Response()) +``` + +### Go Version Requirements + +- **v4**: Go 1.24.0 (per `go.mod`) +- **v5**: Go 1.25.0 (per `go.mod`) + +--- + +**Generated by comparing `go doc` output from master (v4.15.0) and v5 (v5.0.0-alpha) branches** diff --git a/LICENSE b/LICENSE index c46d0105f..2f18411bd 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2021 LabStack +Copyright (c) 2022 LabStack Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/Makefile b/Makefile index cbd78f1bf..bd075bbae 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,6 @@ PKG := "github.com/labstack/echo" PKG_LIST := $(shell go list ${PKG}/...) -tag: - @git tag `grep -P '^\tversion = ' echo.go|cut -f2 -d'"'` - @git tag|grep -v ^v - .DEFAULT_GOAL := check check: lint vet race ## Check project @@ -26,12 +22,11 @@ race: ## Run tests with data race detector @go test -race ${PKG_LIST} benchmark: ## Run benchmarks - @go test -run="-" -bench=".*" ${PKG_LIST} + @go test -run="-" -benchmem -bench=".*" ${PKG_LIST} help: ## Display this help screen @grep -h -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' -goversion ?= "1.22" -docker_user ?= "1000" -test_version: ## Run tests inside Docker with given version (defaults to 1.22 oldest supported). Example: make test_version goversion=1.22 - @docker run --rm -it --user $(docker_user) -e HOME=/tmp -e GOCACHE=/tmp/go-cache -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "mkdir -p /tmp/go-cache /tmp/.cache && cd /project && make init check" +goversion ?= "1.25" +test_version: ## Run tests inside Docker with given version (defaults to 1.25 oldest supported). Example: make test_version goversion=1.25 + @docker run --rm -it -v $(shell pwd):/project golang:$(goversion) /bin/sh -c "cd /project && make init check" diff --git a/README.md b/README.md index 5e52d1d4e..8b9d02785 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Click [here](https://github.com/sponsors/labstack) for more information on spons ```sh // go get github.com/labstack/echo/{version} -go get github.com/labstack/echo/v4 +go get github.com/labstack/echo/v5 ``` Latest version of Echo supports last four Go major [releases](https://go.dev/doc/devel/release) and might work with older versions. @@ -62,8 +62,9 @@ Latest version of Echo supports last four Go major [releases](https://go.dev/doc package main import ( - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "errors" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" "log/slog" "net/http" ) @@ -73,20 +74,20 @@ func main() { e := echo.New() // Middleware - e.Use(middleware.RequestLogger()) // use the default RequestLogger middleware with slog logger + e.Use(middleware.RequestLogger()) // use the RequestLogger middleware with slog logger e.Use(middleware.Recover()) // recover panics as errors for proper error handling // Routes e.GET("/", hello) // Start server - if err := e.Start(":8080"); err != nil && !errors.Is(err, http.ErrServerClosed) { + if err := e.Start(":8080"); err != nil { slog.Error("failed to start server", "error", err) } } // Handler -func hello(c echo.Context) error { +func hello(c *echo.Context) error { return c.String(http.StatusOK, "Hello, World!") } ``` diff --git a/bind.go b/bind.go index 1d4fe6f0a..050e8973b 100644 --- a/bind.go +++ b/bind.go @@ -7,7 +7,6 @@ import ( "encoding" "encoding/xml" "errors" - "fmt" "mime/multipart" "net/http" "reflect" @@ -18,7 +17,7 @@ import ( // Binder is the interface that wraps the Bind method. type Binder interface { - Bind(i interface{}, c Context) error + Bind(c *Context, target any) error } // DefaultBinder is the default implementation of the Binder interface. @@ -39,31 +38,22 @@ type bindMultipleUnmarshaler interface { UnmarshalParams(params []string) error } -// BindPathParams binds path params to bindable object -// -// Time format support: time.Time fields can use `format` tags to specify custom parsing layouts. -// Example: `param:"created" format:"2006-01-02T15:04"` for datetime-local format -// Example: `param:"date" format:"2006-01-02"` for date format -// Uses Go's standard time format reference time: Mon Jan 2 15:04:05 MST 2006 -// Works with form data, query parameters, and path parameters (not JSON body) -// Falls back to default time.Time parsing if no format tag is specified -func (b *DefaultBinder) BindPathParams(c Context, i interface{}) error { - names := c.ParamNames() - values := c.ParamValues() +// BindPathValues binds path parameter values to bindable object +func BindPathValues(c *Context, target any) error { params := map[string][]string{} - for i, name := range names { - params[name] = []string{values[i]} + for _, param := range c.PathValues() { + params[param.Name] = []string{param.Value} } - if err := b.bindData(i, params, "param", nil); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err := bindData(target, params, "param", nil); err != nil { + return ErrBadRequest.Wrap(err) } return nil } // BindQueryParams binds query params to bindable object -func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { - if err := b.bindData(i, c.QueryParams(), "query", nil); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindQueryParams(c *Context, target any) error { + if err := bindData(target, c.QueryParams(), "query", nil); err != nil { + return ErrBadRequest.Wrap(err) } return nil } @@ -73,7 +63,7 @@ func (b *DefaultBinder) BindQueryParams(c Context, i interface{}) error { // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See non-MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseForm // See MIMEMultipartForm: https://golang.org/pkg/net/http/#Request.ParseMultipartForm -func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { +func BindBody(c *Context, target any) (err error) { req := c.Request() if req.ContentLength == 0 { return @@ -85,58 +75,52 @@ func (b *DefaultBinder) BindBody(c Context, i interface{}) (err error) { switch mediatype { case MIMEApplicationJSON: - if err = c.Echo().JSONSerializer.Deserialize(c, i); err != nil { - switch err.(type) { - case *HTTPError: + if err = c.Echo().JSONSerializer.Deserialize(c, target); err != nil { + var hErr *HTTPError + if errors.As(err, &hErr) { return err - default: - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) } + return ErrBadRequest.Wrap(err) } case MIMEApplicationXML, MIMETextXML: - if err = xml.NewDecoder(req.Body).Decode(i); err != nil { - if ute, ok := err.(*xml.UnsupportedTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unsupported type error: type=%v, error=%v", ute.Type, ute.Error())).SetInternal(err) - } else if se, ok := err.(*xml.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: line=%v, error=%v", se.Line, se.Error())).SetInternal(err) - } - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = xml.NewDecoder(req.Body).Decode(target); err != nil { + return ErrBadRequest.Wrap(err) } case MIMEApplicationForm: - params, err := c.FormParams() + params, err := c.FormValues() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return ErrBadRequest.Wrap(err) } - if err = b.bindData(i, params, "form", nil); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(target, params, "form", nil); err != nil { + return ErrBadRequest.Wrap(err) } case MIMEMultipartForm: params, err := c.MultipartForm() if err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + return ErrBadRequest.Wrap(err) } - if err = b.bindData(i, params.Value, "form", params.File); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) + if err = bindData(target, params.Value, "form", params.File); err != nil { + return ErrBadRequest.Wrap(err) } default: - return ErrUnsupportedMediaType + return &HTTPError{Code: http.StatusUnsupportedMediaType} } return nil } // BindHeaders binds HTTP headers to a bindable object -func (b *DefaultBinder) BindHeaders(c Context, i interface{}) error { - if err := b.bindData(i, c.Request().Header, "header", nil); err != nil { - return NewHTTPError(http.StatusBadRequest, err.Error()).SetInternal(err) +func BindHeaders(c *Context, target any) error { + if err := bindData(target, c.Request().Header, "header", nil); err != nil { + return ErrBadRequest.Wrap(err) } return nil } // Bind implements the `Binder#Bind` function. // Binding is done in following order: 1) path params; 2) query params; 3) request body. Each step COULD override previous -// step binded values. For single source binding use their own methods BindBody, BindQueryParams, BindPathParams. -func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { - if err := b.BindPathParams(c, i); err != nil { +// step bound values. For single source binding use their own methods BindBody, BindQueryParams, BindPathValues. +func (b *DefaultBinder) Bind(c *Context, target any) error { + if err := BindPathValues(c, target); err != nil { return err } // Only bind query parameters for GET/DELETE/HEAD to avoid unexpected behavior with destination struct binding from body. @@ -144,15 +128,15 @@ func (b *DefaultBinder) Bind(i interface{}, c Context) (err error) { // The HTTP method check restores pre-v4.1.11 behavior to avoid these problems (see issue #1670) method := c.Request().Method if method == http.MethodGet || method == http.MethodDelete || method == http.MethodHead { - if err = b.BindQueryParams(c, i); err != nil { + if err := BindQueryParams(c, target); err != nil { return err } } - return b.BindBody(c, i) + return BindBody(c, target) } // bindData will bind data ONLY fields in destination struct that have EXPLICIT tag -func (b *DefaultBinder) bindData(destination interface{}, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { +func bindData(destination any, data map[string][]string, tag string, dataFiles map[string][]*multipart.FileHeader) error { if destination == nil || (len(data) == 0 && len(dataFiles) == 0) { return nil } @@ -163,7 +147,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri // Support binding to limited Map destinations: // - map[string][]string, // - map[string]string <-- (binds first value from data slice) - // - map[string]interface{} + // - map[string]any // You are better off binding to struct but there are user who want this map feature. Source of data for these cases are: // params,query,header,form as these sources produce string values, most of the time slice of strings, actually. if typ.Kind() == reflect.Map && typ.Key().Kind() == reflect.String { @@ -182,7 +166,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) } else if isElemInterface { // To maintain backward compatibility, we always bind to the first string value - // and not the slice of strings when dealing with map[string]interface{}{} + // and not the slice of strings when dealing with map[string]any{} val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v[0])) } else { val.SetMapIndex(reflect.ValueOf(k), reflect.ValueOf(v)) @@ -222,7 +206,7 @@ func (b *DefaultBinder) bindData(destination interface{}, data map[string][]stri // If tag is nil, we inspect if the field is a not BindUnmarshaler struct and try to bind data into it (might contain fields with tags). // structs that implement BindUnmarshaler are bound only when they have explicit tag if _, ok := structField.Addr().Interface().(BindUnmarshaler); !ok && structFieldKind == reflect.Struct { - if err := b.bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { + if err := bindData(structField.Addr().Interface(), data, tag, dataFiles); err != nil { return err } } @@ -374,7 +358,6 @@ func unmarshalInputToField(valueKind reflect.Kind, val string, field reflect.Val } fieldIValue := field.Addr().Interface() - // Handle time.Time with custom format tag if formatTag != "" { if _, isTime := fieldIValue.(*time.Time); isTime { diff --git a/bind_test.go b/bind_test.go index 3e387ba19..1d5f8ca41 100644 --- a/bind_test.go +++ b/bind_test.go @@ -25,79 +25,79 @@ import ( ) type bindTestStruct struct { - I int - PtrI *int - I8 int8 - PtrI8 *int8 - I16 int16 + T Timestamp + GoT time.Time PtrI16 *int16 - I32 int32 + PtrUI *uint + Tptr *Timestamp + PtrF32 *float32 + PtrB *bool PtrI32 *int32 - I64 int64 + GoTptr *time.Time PtrI64 *int64 - UI uint - PtrUI *uint - UI8 uint8 + PtrI *int + PtrI8 *int8 + PtrF64 *float64 PtrUI8 *uint8 - UI16 uint16 + PtrUI64 *uint64 PtrUI16 *uint16 - UI32 uint32 + PtrS *string PtrUI32 *uint32 - UI64 uint64 - PtrUI64 *uint64 - B bool - PtrB *bool - F32 float32 - PtrF32 *float32 - F64 float64 - PtrF64 *float64 S string - PtrS *string cantSet string DoesntExist string - GoT time.Time - GoTptr *time.Time - T Timestamp - Tptr *Timestamp SA StringArray + F64 float64 + I int + UI64 uint64 + UI uint + I64 int64 + F32 float32 + UI32 uint32 + I32 int32 + UI16 uint16 + I16 int16 + B bool + UI8 uint8 + I8 int8 } type bindTestStructWithTags struct { - I int `json:"I" form:"I"` - PtrI *int `json:"PtrI" form:"PtrI"` - I8 int8 `json:"I8" form:"I8"` - PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` - I16 int16 `json:"I16" form:"I16"` - PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` - I32 int32 `json:"I32" form:"I32"` - PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` - I64 int64 `json:"I64" form:"I64"` - PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` - UI uint `json:"UI" form:"UI"` - PtrUI *uint `json:"PtrUI" form:"PtrUI"` - UI8 uint8 `json:"UI8" form:"UI8"` - PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` - UI16 uint16 `json:"UI16" form:"UI16"` - PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` - UI32 uint32 `json:"UI32" form:"UI32"` - PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` - UI64 uint64 `json:"UI64" form:"UI64"` - PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` - B bool `json:"B" form:"B"` - PtrB *bool `json:"PtrB" form:"PtrB"` - F32 float32 `json:"F32" form:"F32"` - PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` - F64 float64 `json:"F64" form:"F64"` - PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` - S string `json:"S" form:"S"` - PtrS *string `json:"PtrS" form:"PtrS"` + T Timestamp `json:"T" form:"T"` + GoT time.Time `json:"GoT" form:"GoT"` + PtrI16 *int16 `json:"PtrI16" form:"PtrI16"` + PtrUI *uint `json:"PtrUI" form:"PtrUI"` + Tptr *Timestamp `json:"Tptr" form:"Tptr"` + PtrF32 *float32 `json:"PtrF32" form:"PtrF32"` + PtrB *bool `json:"PtrB" form:"PtrB"` + PtrI32 *int32 `json:"PtrI32" form:"PtrI32"` + GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` + PtrI64 *int64 `json:"PtrI64" form:"PtrI64"` + PtrI *int `json:"PtrI" form:"PtrI"` + PtrI8 *int8 `json:"PtrI8" form:"PtrI8"` + PtrF64 *float64 `json:"PtrF64" form:"PtrF64"` + PtrUI8 *uint8 `json:"PtrUI8" form:"PtrUI8"` + PtrUI64 *uint64 `json:"PtrUI64" form:"PtrUI64"` + PtrUI16 *uint16 `json:"PtrUI16" form:"PtrUI16"` + PtrS *string `json:"PtrS" form:"PtrS"` + PtrUI32 *uint32 `json:"PtrUI32" form:"PtrUI32"` + S string `json:"S" form:"S"` cantSet string DoesntExist string `json:"DoesntExist" form:"DoesntExist"` - GoT time.Time `json:"GoT" form:"GoT"` - GoTptr *time.Time `json:"GoTptr" form:"GoTptr"` - T Timestamp `json:"T" form:"T"` - Tptr *Timestamp `json:"Tptr" form:"Tptr"` SA StringArray `json:"SA" form:"SA"` + F64 float64 `json:"F64" form:"F64"` + I int `json:"I" form:"I"` + UI64 uint64 `json:"UI64" form:"UI64"` + UI uint `json:"UI" form:"UI"` + I64 int64 `json:"I64" form:"I64"` + F32 float32 `json:"F32" form:"F32"` + UI32 uint32 `json:"UI32" form:"UI32"` + I32 int32 `json:"I32" form:"I32"` + UI16 uint16 `json:"UI16" form:"UI16"` + I16 int16 `json:"I16" form:"I16"` + B bool `json:"B" form:"B"` + UI8 uint8 `json:"UI8" form:"UI8"` + I8 int8 `json:"I8" form:"I8"` } type Timestamp time.Time @@ -283,7 +283,7 @@ func TestBindHeaderParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) if assert.NoError(t, err) { assert.Equal(t, 2, u.ID) assert.Equal(t, "Jon Doe", u.Name) @@ -297,7 +297,7 @@ func TestBindHeaderParamBadType(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) u := new(user) - err := (&DefaultBinder{}).BindHeaders(c, u) + err := BindHeaders(c, u) assert.Error(t, err) httpErr, ok := err.(*HTTPError) @@ -312,13 +312,13 @@ func TestBindUnmarshalParam(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T Timestamp `query:"ts"` - TA []Timestamp `query:"ta"` - SA StringArray `query:"sa"` + T Timestamp `query:"ts"` ST Struct StWithTag struct { Foo string `query:"st"` } + TA []Timestamp `query:"ta"` + SA StringArray `query:"sa"` }{} err := c.Bind(&result) ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC)) @@ -339,10 +339,10 @@ func TestBindUnmarshalText(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) result := struct { - T time.Time `query:"ts"` + T time.Time `query:"ts"` + ST Struct TA []time.Time `query:"ta"` SA StringArray `query:"sa"` - ST Struct }{} err := c.Bind(&result) ts := time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC) @@ -447,7 +447,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string", func(t *testing.T) { dest := map[string]string{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -459,7 +459,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]string with nil map", func(t *testing.T) { var dest map[string]string - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]string{ "multiple": "1", @@ -471,7 +471,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string", func(t *testing.T) { dest := map[string][]string{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -483,7 +483,7 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string][]string with nil map", func(t *testing.T) { var dest map[string][]string - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]string{ "multiple": {"1", "2"}, @@ -494,10 +494,10 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { }) t.Run("ok, bind to map[string]interface", func(t *testing.T) { - dest := map[string]interface{}{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + dest := map[string]any{} + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, - map[string]interface{}{ + map[string]any{ "multiple": "1", "single": "3", }, @@ -506,10 +506,10 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { }) t.Run("ok, bind to map[string]interface with nil map", func(t *testing.T) { - var dest map[string]interface{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + var dest map[string]any + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, - map[string]interface{}{ + map[string]any{ "multiple": "1", "single": "3", }, @@ -519,33 +519,32 @@ func TestDefaultBinder_bindDataToMap(t *testing.T) { t.Run("ok, bind to map[string]int skips", func(t *testing.T) { dest := map[string]int{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int{}, dest) }) t.Run("ok, bind to map[string]int skips with nil map", func(t *testing.T) { var dest map[string]int - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string]int(nil), dest) }) t.Run("ok, bind to map[string][]int skips", func(t *testing.T) { dest := map[string][]int{} - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int{}, dest) }) t.Run("ok, bind to map[string][]int skips with nil map", func(t *testing.T) { var dest map[string][]int - assert.NoError(t, new(DefaultBinder).bindData(&dest, exampleData, "param", nil)) + assert.NoError(t, bindData(&dest, exampleData, "param", nil)) assert.Equal(t, map[string][]int(nil), dest) }) } func TestBindbindData(t *testing.T) { ts := new(bindTestStruct) - b := new(DefaultBinder) - err := b.bindData(ts, values, "form", nil) + err := bindData(ts, values, "form", nil) assert.NoError(t, err) assert.Equal(t, 0, ts.I) @@ -570,9 +569,13 @@ func TestBindParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - c.SetPath("/users/:id/:name") - c.SetParamNames("id", "name") - c.SetParamValues("1", "Jon Snow") + c.InitializeRoute( + &RouteInfo{Path: "/users/:id/:name"}, + &PathValues{ + {Name: "id", Value: "1"}, + {Name: "name", Value: "Jon Snow"}, + }, + ) u := new(user) err := c.Bind(u) @@ -583,9 +586,12 @@ func TestBindParam(t *testing.T) { // Second test for the absence of a param c2 := e.NewContext(req, rec) - c2.SetPath("/users/:id") - c2.SetParamNames("id") - c2.SetParamValues("1") + c2.InitializeRoute( + &RouteInfo{Path: "/users/:id"}, + &PathValues{ + {Name: "id", Value: "1"}, + }, + ) u = new(user) err = c2.Bind(u) @@ -603,9 +609,12 @@ func TestBindParam(t *testing.T) { rec2 := httptest.NewRecorder() c3 := e2.NewContext(req2, rec2) - c3.SetPath("/users/:id") - c3.SetParamNames("id") - c3.SetParamValues("1") + c3.InitializeRoute( + &RouteInfo{Path: "/users/:id"}, + &PathValues{ + {Name: "id", Value: "1"}, + }, + ) u = new(user) err = c3.Bind(u) @@ -627,9 +636,7 @@ func TestBindUnmarshalTypeError(t *testing.T) { err := c.Bind(u) - he := &HTTPError{Code: http.StatusBadRequest, Message: "Unmarshal type error: expected=int, got=string, field=id, offset=14", Internal: err.(*HTTPError).Internal} - - assert.Equal(t, he, err) + assert.EqualError(t, err, `code=400, message=Bad Request, err=json: cannot unmarshal string into Go struct field user.id of type int`) } func TestBindSetWithProperType(t *testing.T) { @@ -663,11 +670,10 @@ func TestBindSetWithProperType(t *testing.T) { func BenchmarkBindbindDataWithTags(b *testing.B) { b.ReportAllocs() ts := new(bindTestStructWithTags) - binder := new(DefaultBinder) var err error b.ResetTimer() for i := 0; i < b.N; i++ { - err = binder.bindData(ts, values, "form", nil) + err = bindData(ts, values, "form", nil) } assert.NoError(b, err) assertBindTestStruct(b, (*bindTestStruct)(ts)) @@ -742,36 +748,36 @@ func testBindError(t *testing.T, r io.Reader, ctype string, expectedInternal err strings.HasPrefix(ctype, MIMEApplicationForm), strings.HasPrefix(ctype, MIMEMultipartForm): if assert.IsType(t, new(HTTPError), err) { assert.Equal(t, http.StatusBadRequest, err.(*HTTPError).Code) - assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) + assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) } default: if assert.IsType(t, new(HTTPError), err) { assert.Equal(t, ErrUnsupportedMediaType, err) - assert.IsType(t, expectedInternal, err.(*HTTPError).Internal) + assert.IsType(t, expectedInternal, err.(*HTTPError).Unwrap()) } } } func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use - // binding is done in steps and one source could overwrite previous source binded data + // binding is done in steps and one source could overwrite previous source bound data // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Opts struct { - ID int `json:"id" form:"id" query:"id"` Node string `json:"node" form:"node" query:"node" param:"node"` Lang string + ID int `json:"id" form:"id" query:"id"` } var testCases = []struct { + givenContent io.Reader + whenBindTarget any + expect any name string givenURL string - givenContent io.Reader givenMethod string - whenBindTarget interface{} - whenNoPathParams bool - expect interface{} expectError string + whenNoPathValues bool }{ { name: "ok, POST bind to struct with: path param + query param + body", @@ -799,14 +805,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), - expect: &Opts{ID: 1, Node: "zzz"}, // body is binded last and overwrites previous (path,query) values + expect: &Opts{ID: 1, Node: "zzz"}, // body is bound last and overwrites previous (path,query) values }, { name: "ok, DELETE bind to struct with: path param + query param + body", givenMethod: http.MethodDelete, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), - expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is binded after query params + expect: &Opts{ID: 1, Node: "zzz"}, // for DELETE body is bound after query params }, { name: "ok, POST bind to struct with: path param + body", @@ -828,7 +834,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`{`), expect: &Opts{ID: 0, Node: "node_from_path"}, // query binding has already modified bind target - expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", + expectError: "code=400, message=Bad Request, err=unexpected EOF", }, { name: "nok, GET with body bind failure when types are not convertible", @@ -836,7 +842,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?id=nope", givenContent: strings.NewReader(`{"id": 1, "node": "zzz"}`), expect: &Opts{ID: 0, Node: "node_from_path"}, // path params binding has already modified bind target - expectError: "code=400, message=strconv.ParseInt: parsing \"nope\": invalid syntax, internal=strconv.ParseInt: parsing \"nope\": invalid syntax", + expectError: `code=400, message=Bad Request, err=strconv.ParseInt: parsing "nope": invalid syntax`, }, { name: "nok, GET body bind failure - trying to bind json array to struct", @@ -844,14 +850,14 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), expect: &Opts{ID: 0, Node: "xxx"}, // query binding has already modified bind target - expectError: "code=400, message=Unmarshal type error: expected=echo.Opts, got=array, field=, offset=1, internal=json: cannot unmarshal array into Go value of type echo.Opts", + expectError: `code=400, message=Bad Request, err=json: cannot unmarshal array into Go value of type echo.Opts`, }, { // query param is ignored as we do not know where exactly to bind it in slice name: "ok, GET bind to struct slice, ignore query param", givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint?node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathParams: true, + whenNoPathValues: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{ {ID: 1, Node: ""}, @@ -862,7 +868,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenMethod: http.MethodPost, givenURL: "/api/real_node/endpoint?id=nope&node=xxx", givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathParams: true, + whenNoPathValues: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{{ID: 1}}, expectError: "", @@ -882,7 +888,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { givenMethod: http.MethodGet, givenURL: "/api/real_node/endpoint", givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathParams: true, + whenNoPathValues: true, whenBindTarget: &[]Opts{}, expect: &[]Opts{{ID: 1, Node: ""}}, expectError: "", @@ -898,12 +904,13 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("node_from_path") + if !tc.whenNoPathValues { + c.SetPathValues(PathValues{ + {Name: "node", Value: "node_from_path"}, + }) } - var bindTarget interface{} + var bindTarget any if tc.whenBindTarget != nil { bindTarget = tc.whenBindTarget } else { @@ -911,7 +918,7 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { } b := new(DefaultBinder) - err := b.Bind(bindTarget, c) + err := b.Bind(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -924,28 +931,28 @@ func TestDefaultBinder_BindToStructFromMixedSources(t *testing.T) { func TestDefaultBinder_BindBody(t *testing.T) { // tests to check binding behaviour when multiple sources (path params, query params and request body) are in use - // generally when binding from request body - URL and path params are ignored - unless form is being binded. + // generally when binding from request body - URL and path params are ignored - unless form is being bound. // these tests are to document this behaviour and detect further possible regressions when bind implementation is changed type Node struct { - ID int `json:"id" xml:"id" form:"id" query:"id"` Node string `json:"node" xml:"node" form:"node" query:"node" param:"node"` + ID int `json:"id" xml:"id" form:"id" query:"id"` } type Nodes struct { Nodes []Node `xml:"node" form:"node"` } var testCases = []struct { + givenContent io.Reader + whenBindTarget any + expect any name string givenURL string - givenContent io.Reader givenMethod string givenContentType string - whenNoPathParams bool - whenChunkedBody bool - whenBindTarget interface{} - expect interface{} expectError string + whenNoPathValues bool + whenChunkedBody bool }{ { name: "ok, JSON POST bind to struct with: path + query + empty field in body", @@ -969,7 +976,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenMethod: http.MethodPost, givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`[{"id": 1}]`), - whenNoPathParams: true, + whenNoPathValues: true, whenBindTarget: &[]Node{}, expect: &[]Node{{ID: 1, Node: ""}}, expectError: "", @@ -997,7 +1004,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenContentType: MIMEApplicationJSON, givenContent: strings.NewReader(`{`), expect: &Node{ID: 0, Node: ""}, - expectError: "code=400, message=unexpected EOF, internal=unexpected EOF", + expectError: "code=400, message=Bad Request, err=unexpected EOF", }, { name: "ok, XML POST bind to struct with: path + query + empty body", @@ -1023,7 +1030,7 @@ func TestDefaultBinder_BindBody(t *testing.T) { givenContentType: MIMEApplicationXML, givenContent: strings.NewReader(`<`), expect: &Node{ID: 0, Node: ""}, - expectError: "code=400, message=Syntax error: line=1, error=XML syntax error on line 1: unexpected EOF, internal=XML syntax error on line 1: unexpected EOF", + expectError: "code=400, message=Bad Request, err=XML syntax error on line 1: unexpected EOF", }, { name: "ok, FORM POST bind to struct with: path + query + body", @@ -1113,20 +1120,20 @@ func TestDefaultBinder_BindBody(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if !tc.whenNoPathParams { - c.SetParamNames("node") - c.SetParamValues("real_node") + if !tc.whenNoPathValues { + c.SetPathValues(PathValues{ + {Name: "node", Value: "real_node"}, + }) } - var bindTarget interface{} + var bindTarget any if tc.whenBindTarget != nil { bindTarget = tc.whenBindTarget } else { bindTarget = &Node{} } - b := new(DefaultBinder) - err := b.BindBody(c, bindTarget) + err := BindBody(c, bindTarget) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -1189,7 +1196,7 @@ func TestBindUnmarshalParamExtras(t *testing.T) { }{} err := testBindURL("/?t=xxxx", &result) - assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") + assert.EqualError(t, err, `code=400, message=Bad Request, err='xxxx' is not an integer`) }) t.Run("ok, target is struct", func(t *testing.T) { @@ -1294,7 +1301,7 @@ func TestBindUnmarshalParams(t *testing.T) { }{} err := testBindURL("/?t=xxxx", &result) - assert.EqualError(t, err, "code=400, message='xxxx' is not an integer, internal='xxxx' is not an integer") + assert.EqualError(t, err, "code=400, message=Bad Request, err='xxxx' is not an integer") }) t.Run("ok, target is struct", func(t *testing.T) { @@ -1361,7 +1368,7 @@ func TestBindInt8(t *testing.T) { } p := target{} err := testBindURL("/?v=x&v=2", &p) - assert.EqualError(t, err, "code=400, message=strconv.ParseInt: parsing \"x\": invalid syntax, internal=strconv.ParseInt: parsing \"x\": invalid syntax") + assert.EqualError(t, err, `code=400, message=Bad Request, err=strconv.ParseInt: parsing "x": invalid syntax`) }) t.Run("nok, int8 embedded in struct", func(t *testing.T) { @@ -1469,7 +1476,7 @@ func TestBindMultipartFormFiles(t *testing.T) { } err := bindMultipartFiles(t, &target, file1, file2) // file2 should be ignored - assert.EqualError(t, err, "code=400, message=binding to multipart.FileHeader struct is not supported, use pointer to struct, internal=binding to multipart.FileHeader struct is not supported, use pointer to struct") + assert.EqualError(t, err, `code=400, message=Bad Request, err=binding to multipart.FileHeader struct is not supported, use pointer to struct`) }) t.Run("ok, bind single multipart file to pointer to multipart file", func(t *testing.T) { @@ -1577,7 +1584,7 @@ func TestTimeFormatBinding(t *testing.T) { DateTimeLocal time.Time `form:"datetime_local" format:"2006-01-02T15:04"` Date time.Time `query:"date" format:"2006-01-02"` CustomFormat time.Time `form:"custom" format:"01/02/2006 15:04:05"` - DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing + DefaultTime time.Time `form:"default_time"` // No format tag - should use default parsing PtrTime *time.Time `query:"ptr_time" format:"2006-01-02"` } @@ -1623,7 +1630,7 @@ func TestTimeFormatBinding(t *testing.T) { { name: "nok, wrong format should fail", contentType: MIMEApplicationForm, - data: "datetime_local=2023-12-25", // Missing time part + data: "datetime_local=2023-12-25", // Missing time part expectError: true, }, } diff --git a/binder.go b/binder.go index da15ae82a..32029ec0f 100644 --- a/binder.go +++ b/binder.go @@ -16,7 +16,7 @@ import ( /** Following functions provide handful of methods for binding to Go native types from request query or path parameters. * QueryParamsBinder(c) - binds query parameters (source URL) - * PathParamsBinder(c) - binds path parameters (source URL) + * PathValuesBinder(c) - binds path parameters (source URL) * FormFieldBinder(c) - binds form fields (source URL + body) Example: @@ -75,15 +75,11 @@ type BindingError struct { } // NewBindingError creates new instance of binding error -func NewBindingError(sourceParam string, values []string, message interface{}, internalError error) error { +func NewBindingError(sourceParam string, values []string, message string, err error) error { return &BindingError{ - Field: sourceParam, - Values: values, - HTTPError: &HTTPError{ - Code: http.StatusBadRequest, - Message: message, - Internal: internalError, - }, + Field: sourceParam, + Values: values, + HTTPError: &HTTPError{Code: http.StatusBadRequest, Message: message, err: err}, } } @@ -99,14 +95,14 @@ type ValueBinder struct { // ValuesFunc is used to get all values for parameter from request. i.e. `/api/search?ids=1&ids=2` ValuesFunc func(sourceParam string) []string // ErrorFunc is used to create errors. Allows you to use your own error type, that for example marshals to your specific json response - ErrorFunc func(sourceParam string, values []string, message interface{}, internalError error) error + ErrorFunc func(sourceParam string, values []string, message string, internalError error) error errors []error // failFast is flag for binding methods to return without attempting to bind when previous binding already failed failFast bool } // QueryParamsBinder creates query parameter value binder -func QueryParamsBinder(c Context) *ValueBinder { +func QueryParamsBinder(c *Context) *ValueBinder { return &ValueBinder{ failFast: true, ValueFunc: c.QueryParam, @@ -121,8 +117,8 @@ func QueryParamsBinder(c Context) *ValueBinder { } } -// PathParamsBinder creates path parameter value binder -func PathParamsBinder(c Context) *ValueBinder { +// PathValuesBinder creates path parameter value binder +func PathValuesBinder(c *Context) *ValueBinder { return &ValueBinder{ failFast: true, ValueFunc: c.Param, @@ -148,7 +144,7 @@ func PathParamsBinder(c Context) *ValueBinder { // NB: when binding forms take note that this implementation uses standard library form parsing // which parses form data from BOTH URL and BODY if content type is not MIMEMultipartForm // See https://golang.org/pkg/net/http/#Request.ParseForm -func FormFieldBinder(c Context) *ValueBinder { +func FormFieldBinder(c *Context) *ValueBinder { vb := &ValueBinder{ failFast: true, ValueFunc: func(sourceParam string) string { @@ -159,7 +155,7 @@ func FormFieldBinder(c Context) *ValueBinder { vb.ValuesFunc = func(sourceParam string) []string { if c.Request().Form == nil { // this is same as `Request().FormValue()` does internally - _ = c.Request().ParseMultipartForm(32 << 20) + _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) } values, ok := c.Request().Form[sourceParam] if !ok { @@ -402,17 +398,17 @@ func (b *ValueBinder) MustTextUnmarshaler(sourceParam string, dest encoding.Text // BindWithDelimiter binds parameter to destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values -func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { +func (b *ValueBinder) BindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, false) } // MustBindWithDelimiter requires parameter value to exist to bind destination by suitable conversion function. // Delimiter is used before conversion to split parameter value to separate values -func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest interface{}, delimiter string) *ValueBinder { +func (b *ValueBinder) MustBindWithDelimiter(sourceParam string, dest any, delimiter string) *ValueBinder { return b.bindWithDelimiter(sourceParam, dest, delimiter, true) } -func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest interface{}, delimiter string, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) bindWithDelimiter(sourceParam string, dest any, delimiter string, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -500,7 +496,7 @@ func (b *ValueBinder) MustInt(sourceParam string, dest *int) *ValueBinder { return b.intValue(sourceParam, dest, 0, true) } -func (b *ValueBinder) intValue(sourceParam string, dest interface{}, bitSize int, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) intValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -516,7 +512,7 @@ func (b *ValueBinder) intValue(sourceParam string, dest interface{}, bitSize int return b.int(sourceParam, value, dest, bitSize) } -func (b *ValueBinder) int(sourceParam string, value string, dest interface{}, bitSize int) *ValueBinder { +func (b *ValueBinder) int(sourceParam string, value string, dest any, bitSize int) *ValueBinder { n, err := strconv.ParseInt(value, 10, bitSize) if err != nil { if bitSize == 0 { @@ -531,18 +527,18 @@ func (b *ValueBinder) int(sourceParam string, value string, dest interface{}, bi case *int64: *d = n case *int32: - *d = int32(n) + *d = int32(n) // #nosec G115 case *int16: - *d = int16(n) + *d = int16(n) // #nosec G115 case *int8: - *d = int8(n) + *d = int8(n) // #nosec G115 case *int: *d = int(n) } return b } -func (b *ValueBinder) intsValue(sourceParam string, dest interface{}, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) intsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -557,7 +553,7 @@ func (b *ValueBinder) intsValue(sourceParam string, dest interface{}, valueMustE return b.ints(sourceParam, values, dest) } -func (b *ValueBinder) ints(sourceParam string, values []string, dest interface{}) *ValueBinder { +func (b *ValueBinder) ints(sourceParam string, values []string, dest any) *ValueBinder { switch d := dest.(type) { case *[]int64: tmp := make([]int64, len(values)) @@ -728,7 +724,7 @@ func (b *ValueBinder) MustUint(sourceParam string, dest *uint) *ValueBinder { return b.uintValue(sourceParam, dest, 0, true) } -func (b *ValueBinder) uintValue(sourceParam string, dest interface{}, bitSize int, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) uintValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -744,7 +740,7 @@ func (b *ValueBinder) uintValue(sourceParam string, dest interface{}, bitSize in return b.uint(sourceParam, value, dest, bitSize) } -func (b *ValueBinder) uint(sourceParam string, value string, dest interface{}, bitSize int) *ValueBinder { +func (b *ValueBinder) uint(sourceParam string, value string, dest any, bitSize int) *ValueBinder { n, err := strconv.ParseUint(value, 10, bitSize) if err != nil { if bitSize == 0 { @@ -759,18 +755,18 @@ func (b *ValueBinder) uint(sourceParam string, value string, dest interface{}, b case *uint64: *d = n case *uint32: - *d = uint32(n) + *d = uint32(n) // #nosec G115 case *uint16: - *d = uint16(n) + *d = uint16(n) // #nosec G115 case *uint8: // byte is alias to uint8 - *d = uint8(n) + *d = uint8(n) // #nosec G115 case *uint: - *d = uint(n) + *d = uint(n) // #nosec G115 } return b } -func (b *ValueBinder) uintsValue(sourceParam string, dest interface{}, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) uintsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -785,7 +781,7 @@ func (b *ValueBinder) uintsValue(sourceParam string, dest interface{}, valueMust return b.uints(sourceParam, values, dest) } -func (b *ValueBinder) uints(sourceParam string, values []string, dest interface{}) *ValueBinder { +func (b *ValueBinder) uints(sourceParam string, values []string, dest any) *ValueBinder { switch d := dest.(type) { case *[]uint64: tmp := make([]uint64, len(values)) @@ -991,7 +987,7 @@ func (b *ValueBinder) MustFloat32(sourceParam string, dest *float32) *ValueBinde return b.floatValue(sourceParam, dest, 32, true) } -func (b *ValueBinder) floatValue(sourceParam string, dest interface{}, bitSize int, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) floatValue(sourceParam string, dest any, bitSize int, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -1007,7 +1003,7 @@ func (b *ValueBinder) floatValue(sourceParam string, dest interface{}, bitSize i return b.float(sourceParam, value, dest, bitSize) } -func (b *ValueBinder) float(sourceParam string, value string, dest interface{}, bitSize int) *ValueBinder { +func (b *ValueBinder) float(sourceParam string, value string, dest any, bitSize int) *ValueBinder { n, err := strconv.ParseFloat(value, bitSize) if err != nil { b.setError(b.ErrorFunc(sourceParam, []string{value}, fmt.Sprintf("failed to bind field value to float%v", bitSize), err)) @@ -1023,7 +1019,7 @@ func (b *ValueBinder) float(sourceParam string, value string, dest interface{}, return b } -func (b *ValueBinder) floatsValue(sourceParam string, dest interface{}, valueMustExist bool) *ValueBinder { +func (b *ValueBinder) floatsValue(sourceParam string, dest any, valueMustExist bool) *ValueBinder { if b.failFast && b.errors != nil { return b } @@ -1038,7 +1034,7 @@ func (b *ValueBinder) floatsValue(sourceParam string, dest interface{}, valueMus return b.floats(sourceParam, values, dest) } -func (b *ValueBinder) floats(sourceParam string, values []string, dest interface{}) *ValueBinder { +func (b *ValueBinder) floats(sourceParam string, values []string, dest any) *ValueBinder { switch d := dest.(type) { case *[]float64: tmp := make([]float64, len(values)) diff --git a/binder_external_test.go b/binder_external_test.go index e44055a23..d83c891b3 100644 --- a/binder_external_test.go +++ b/binder_external_test.go @@ -7,18 +7,19 @@ package echo_test import ( "encoding/base64" "fmt" - "github.com/labstack/echo/v4" "log" "net/http" "net/http/httptest" + + "github.com/labstack/echo/v5" ) func ExampleValueBinder_BindErrors() { // example route function that binds query params to different destinations and returns all bind errors in one go - routeFunc := func(c echo.Context) error { + routeFunc := func(c *echo.Context) error { var opts struct { - Active bool IDs []int64 + Active bool } length := int64(50) // default length is 50 @@ -53,10 +54,10 @@ func ExampleValueBinder_BindErrors() { func ExampleValueBinder_BindError() { // example route function that binds query params to different destinations and stops binding on first bind error - failFastRouteFunc := func(c echo.Context) error { + failFastRouteFunc := func(c *echo.Context) error { var opts struct { - Active bool IDs []int64 + Active bool } length := int64(50) // default length is 50 @@ -89,7 +90,7 @@ func ExampleValueBinder_BindError() { func ExampleValueBinder_CustomFunc() { // example route function that binds query params using custom function closure - routeFunc := func(c echo.Context) error { + routeFunc := func(c *echo.Context) error { length := int64(50) // default length is 50 var binary []byte diff --git a/binder_generic.go b/binder_generic.go index f4d45af76..0c0eb9089 100644 --- a/binder_generic.go +++ b/binder_generic.go @@ -49,20 +49,18 @@ const ( // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: -// -// If the parameter exists but has an empty value, the zero value of type T is returned -// with no error. For example, a path parameter with value "" returns (0, nil) for int types. -// This differs from standard library behavior where parsing empty strings returns errors. -// To treat empty values as errors, validate the result separately or check the raw value. +// If the parameter exists but has an empty value, the zero value of type T is returned +// with no error. For example, a path parameter with value "" returns (0, nil) for int types. +// This differs from standard library behavior where parsing empty strings returns errors. +// To treat empty values as errors, validate the result separately or check the raw value. // // See ParseValue for supported types and options -func PathParam[T any](c Context, paramName string, opts ...any) (T, error) { - for i, name := range c.ParamNames() { - if name == paramName { - pValues := c.ParamValues() - v, err := ParseValue[T](pValues[i], opts...) +func PathParam[T any](c *Context, paramName string, opts ...any) (T, error) { + for _, pv := range c.PathValues() { + if pv.Name == paramName { + v, err := ParseValue[T](pv.Value, opts...) if err != nil { - return v, NewBindingError(paramName, []string{pValues[i]}, "path param", err) + return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) } return v, nil } @@ -76,20 +74,18 @@ func PathParam[T any](c Context, paramName string, opts ...any) (T, error) { // Returns an error only if parsing fails (e.g., "abc" for int type). // // Example: -// -// id, err := echo.PathParamOr[int](c, "id", 0) -// // If "id" is missing: returns (0, nil) -// // If "id" is "123": returns (123, nil) -// // If "id" is "abc": returns (0, BindingError) +// id, err := echo.PathParamOr[int](c, "id", 0) +// // If "id" is missing: returns (0, nil) +// // If "id" is "123": returns (123, nil) +// // If "id" is "abc": returns (0, BindingError) // // See ParseValue for supported types and options -func PathParamOr[T any](c Context, paramName string, defaultValue T, opts ...any) (T, error) { - for i, name := range c.ParamNames() { - if name == paramName { - pValues := c.ParamValues() - v, err := ParseValueOr[T](pValues[i], defaultValue, opts...) +func PathParamOr[T any](c *Context, paramName string, defaultValue T, opts ...any) (T, error) { + for _, pv := range c.PathValues() { + if pv.Name == paramName { + v, err := ParseValueOr[T](pv.Value, defaultValue, opts...) if err != nil { - return v, NewBindingError(paramName, []string{pValues[i]}, "path param", err) + return v, NewBindingError(paramName, []string{pv.Value}, "path value", err) } return v, nil } @@ -101,11 +97,10 @@ func PathParamOr[T any](c Context, paramName string, defaultValue T, opts ...any // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: -// -// If the parameter exists but has an empty value (?key=), the zero value of type T is returned -// with no error. For example, "?count=" returns (0, nil) for int types. -// This differs from standard library behavior where parsing empty strings returns errors. -// To treat empty values as errors, validate the result separately or check the raw value. +// If the parameter exists but has an empty value (?key=), the zero value of type T is returned +// with no error. For example, "?count=" returns (0, nil) for int types. +// This differs from standard library behavior where parsing empty strings returns errors. +// To treat empty values as errors, validate the result separately or check the raw value. // // Behavior Summary: // - Missing key (?other=value): returns (zero, ErrNonExistentKey) @@ -113,7 +108,7 @@ func PathParamOr[T any](c Context, paramName string, defaultValue T, opts ...any // - Invalid value (?key=abc for int): returns (zero, BindingError) // // See ParseValue for supported types and options -func QueryParam[T any](c Context, key string, opts ...any) (T, error) { +func QueryParam[T any](c *Context, key string, opts ...any) (T, error) { values, ok := c.QueryParams()[key] if !ok { var zero T @@ -136,14 +131,13 @@ func QueryParam[T any](c Context, key string, opts ...any) (T, error) { // Returns an error only if parsing fails (e.g., "abc" for int type). // // Example: -// -// page, err := echo.QueryParamOr[int](c, "page", 1) -// // If "page" is missing: returns (1, nil) -// // If "page" is "5": returns (5, nil) -// // If "page" is "abc": returns (1, BindingError) +// page, err := echo.QueryParamOr[int](c, "page", 1) +// // If "page" is missing: returns (1, nil) +// // If "page" is "5": returns (5, nil) +// // If "page" is "abc": returns (1, BindingError) // // See ParseValue for supported types and options -func QueryParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) { +func QueryParamOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { values, ok := c.QueryParams()[key] if !ok { return defaultValue, nil @@ -163,7 +157,7 @@ func QueryParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, // It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. // // See ParseValues for supported types and options -func QueryParams[T any](c Context, key string, opts ...any) ([]T, error) { +func QueryParams[T any](c *Context, key string, opts ...any) ([]T, error) { values, ok := c.QueryParams()[key] if !ok { return nil, ErrNonExistentKey @@ -181,14 +175,13 @@ func QueryParams[T any](c Context, key string, opts ...any) ([]T, error) { // Returns an error only if parsing any value fails. // // Example: -// -// ids, err := echo.QueryParamsOr[int](c, "ids", []int{}) -// // If "ids" is missing: returns ([], nil) -// // If "ids" is "1&ids=2": returns ([1, 2], nil) -// // If "ids" contains "abc": returns ([], BindingError) +// ids, err := echo.QueryParamsOr[int](c, "ids", []int{}) +// // If "ids" is missing: returns ([], nil) +// // If "ids" is "1&ids=2": returns ([1, 2], nil) +// // If "ids" contains "abc": returns ([], BindingError) // // See ParseValues for supported types and options -func QueryParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) { +func QueryParamsOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { values, ok := c.QueryParams()[key] if !ok { return defaultValue, nil @@ -201,22 +194,21 @@ func QueryParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) return result, nil } -// FormParam extracts and parses a single form value from the request by key. +// FormValue extracts and parses a single form value from the request by key. // It returns the typed value and an error if binding fails. Returns ErrNonExistentKey if parameter not found. // // Empty String Handling: -// -// If the form field exists but has an empty value, the zero value of type T is returned -// with no error. For example, an empty form field returns (0, nil) for int types. -// This differs from standard library behavior where parsing empty strings returns errors. -// To treat empty values as errors, validate the result separately or check the raw value. +// If the form field exists but has an empty value, the zero value of type T is returned +// with no error. For example, an empty form field returns (0, nil) for int types. +// This differs from standard library behavior where parsing empty strings returns errors. +// To treat empty values as errors, validate the result separately or check the raw value. // // See ParseValue for supported types and options -func FormParam[T any](c Context, key string, opts ...any) (T, error) { - formValues, err := c.FormParams() +func FormValue[T any](c *Context, key string, opts ...any) (T, error) { + formValues, err := c.FormValues() if err != nil { var zero T - return zero, fmt.Errorf("failed to parse form param, key: %s, err: %w", key, err) + return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -230,28 +222,27 @@ func FormParam[T any](c Context, key string, opts ...any) (T, error) { value := values[0] v, err := ParseValue[T](value, opts...) if err != nil { - return v, NewBindingError(key, []string{value}, "form param", err) + return v, NewBindingError(key, []string{value}, "form value", err) } return v, nil } -// FormParamOr extracts and parses a single form value from the request by key. +// FormValueOr extracts and parses a single form value from the request by key. // Returns defaultValue if the parameter is not found or has an empty value. // Returns an error only if parsing fails or form parsing errors occur. // // Example: -// -// limit, err := echo.FormValueOr[int](c, "limit", 100) -// // If "limit" is missing: returns (100, nil) -// // If "limit" is "50": returns (50, nil) -// // If "limit" is "abc": returns (100, BindingError) +// limit, err := echo.FormValueOr[int](c, "limit", 100) +// // If "limit" is missing: returns (100, nil) +// // If "limit" is "50": returns (50, nil) +// // If "limit" is "abc": returns (100, BindingError) // // See ParseValue for supported types and options -func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, error) { - formValues, err := c.FormParams() +func FormValueOr[T any](c *Context, key string, defaultValue T, opts ...any) (T, error) { + formValues, err := c.FormValues() if err != nil { var zero T - return zero, fmt.Errorf("failed to parse form param, key: %s, err: %w", key, err) + return zero, fmt.Errorf("failed to parse form value, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -263,19 +254,19 @@ func FormParamOr[T any](c Context, key string, defaultValue T, opts ...any) (T, value := values[0] v, err := ParseValueOr[T](value, defaultValue, opts...) if err != nil { - return v, NewBindingError(key, []string{value}, "form param", err) + return v, NewBindingError(key, []string{value}, "form value", err) } return v, nil } -// FormParams extracts and parses all values for a form values key as a slice. +// FormValues extracts and parses all values for a form values key as a slice. // It returns the typed slice and an error if binding any value fails. Returns ErrNonExistentKey if parameter not found. // // See ParseValues for supported types and options -func FormParams[T any](c Context, key string, opts ...any) ([]T, error) { - formValues, err := c.FormParams() +func FormValues[T any](c *Context, key string, opts ...any) ([]T, error) { + formValues, err := c.FormValues() if err != nil { - return nil, fmt.Errorf("failed to parse form params, key: %s, err: %w", key, err) + return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -283,26 +274,25 @@ func FormParams[T any](c Context, key string, opts ...any) ([]T, error) { } result, err := ParseValues[T](values, opts...) if err != nil { - return nil, NewBindingError(key, values, "form params", err) + return nil, NewBindingError(key, values, "form values", err) } return result, nil } -// FormParamsOr extracts and parses all values for a form values key as a slice. +// FormValuesOr extracts and parses all values for a form values key as a slice. // Returns defaultValue if the parameter is not found. // Returns an error only if parsing any value fails or form parsing errors occur. // // Example: -// -// tags, err := echo.FormParamsOr[string](c, "tags", []string{}) -// // If "tags" is missing: returns ([], nil) -// // If form parsing fails: returns (nil, error) +// tags, err := echo.FormValuesOr[string](c, "tags", []string{}) +// // If "tags" is missing: returns ([], nil) +// // If form parsing fails: returns (nil, error) // // See ParseValues for supported types and options -func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ([]T, error) { - formValues, err := c.FormParams() +func FormValuesOr[T any](c *Context, key string, defaultValue []T, opts ...any) ([]T, error) { + formValues, err := c.FormValues() if err != nil { - return nil, fmt.Errorf("failed to parse form params, key: %s, err: %w", key, err) + return nil, fmt.Errorf("failed to parse form values, key: %s, err: %w", key, err) } values, ok := formValues[key] if !ok { @@ -310,7 +300,7 @@ func FormParamsOr[T any](c Context, key string, defaultValue []T, opts ...any) ( } result, err := ParseValuesOr[T](values, defaultValue, opts...) if err != nil { - return nil, NewBindingError(key, values, "form params", err) + return nil, NewBindingError(key, values, "form values", err) } return result, nil } diff --git a/binder_generic_test.go b/binder_generic_test.go index 96dfc5ed8..849d75962 100644 --- a/binder_generic_test.go +++ b/binder_generic_test.go @@ -64,15 +64,16 @@ func TestPathParam(t *testing.T) { name: "nok, invalid value", givenValue: "can_parse_me", expect: false, - expectErr: `code=400, message=path param, internal=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`, + expectErr: `code=400, message=path value, err=failed to parse value, err: strconv.ParseBool: parsing "can_parse_me": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - c.SetParamNames(cmp.Or(tc.givenKey, "key")) - c.SetParamValues(tc.givenValue) + c := NewContext(nil, nil) + c.SetPathValues(PathValues{{ + Name: cmp.Or(tc.givenKey, "key"), + Value: tc.givenValue, + }}) v, err := PathParam[bool](c, "key") if tc.expectErr != "" { @@ -86,14 +87,12 @@ func TestPathParam(t *testing.T) { } func TestPathParam_UnsupportedType(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - c.SetParamNames("key") - c.SetParamValues("true") + c := NewContext(nil, nil) + c.SetPathValues(PathValues{{Name: "key", Value: "true"}}) v, err := PathParam[[]bool](c, "key") - expectErr := "code=400, message=path param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=path value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } @@ -120,14 +119,13 @@ func TestQueryParam(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, - expectErr: `code=400, message=query param, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=query param, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) v, err := QueryParam[bool](c, "key") if tc.expectErr != "" { @@ -142,12 +140,11 @@ func TestQueryParam(t *testing.T) { func TestQueryParam_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) v, err := QueryParam[[]bool](c, "key") - expectErr := "code=400, message=query param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=query param, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } @@ -174,14 +171,13 @@ func TestQueryParams(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), - expectErr: `code=400, message=query params, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=query params, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) v, err := QueryParams[bool](c, "key") if tc.expectErr != "" { @@ -196,12 +192,11 @@ func TestQueryParams(t *testing.T) { func TestQueryParams_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) v, err := QueryParams[[]bool](c, "key") - expectErr := "code=400, message=query params, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=query params, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, [][]bool(nil), v) } @@ -228,16 +223,15 @@ func TestFormValue(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=invalidbool", expect: false, - expectErr: `code=400, message=form param, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=form value, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) - v, err := FormParam[bool](c, "key") + v, err := FormValue[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { @@ -250,12 +244,11 @@ func TestFormValue(t *testing.T) { func TestFormValue_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) - v, err := FormParam[[]bool](c, "key") + v, err := FormValue[[]bool](c, "key") - expectErr := "code=400, message=form param, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=form value, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, []bool(nil), v) } @@ -282,16 +275,15 @@ func TestFormValues(t *testing.T) { name: "nok, invalid value", givenURL: "/?key=true&key=invalidbool", expect: []bool(nil), - expectErr: `code=400, message=form params, internal=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, + expectErr: `code=400, message=form values, err=failed to parse value, err: strconv.ParseBool: parsing "invalidbool": invalid syntax, field=key`, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) - v, err := FormParams[bool](c, "key") + v, err := FormValues[bool](c, "key") if tc.expectErr != "" { assert.EqualError(t, err, tc.expectErr) } else { @@ -304,12 +296,11 @@ func TestFormValues(t *testing.T) { func TestFormValues_UnsupportedType(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/?key=bool", nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) - v, err := FormParams[[]bool](c, "key") + v, err := FormValues[[]bool](c, "key") - expectErr := "code=400, message=form params, internal=failed to parse value, err: unsupported value type: *[]bool, field=key" + expectErr := "code=400, message=form values, err=failed to parse value, err: unsupported value type: *[]bool, field=key" assert.EqualError(t, err, expectErr) assert.Equal(t, [][]bool(nil), v) } @@ -1433,15 +1424,13 @@ func TestPathParamOr(t *testing.T) { givenKey: "id", givenValue: "invalid", defaultValue: 999, - expectErr: "code=400, message=path param, internal=failed to parse value", + expectErr: "code=400, message=path value, err=failed to parse value", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - c.SetParamNames(tc.givenKey) - c.SetParamValues(tc.givenValue) + c := NewContext(nil, nil) + c.SetPathValues(PathValues{{Name: tc.givenKey, Value: tc.givenValue}}) v, err := PathParamOr[int](c, "id", tc.defaultValue) if tc.expectErr != "" { @@ -1490,8 +1479,7 @@ func TestQueryParamOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) v, err := QueryParamOr[int](c, "key", tc.defaultValue) if tc.expectErr != "" { @@ -1534,8 +1522,7 @@ func TestQueryParamsOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) v, err := QueryParamsOr[int](c, "key", tc.defaultValue) if tc.expectErr != "" { @@ -1578,10 +1565,9 @@ func TestFormValueOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) - v, err := FormParamOr[string](c, "name", tc.defaultValue) + v, err := FormValueOr[string](c, "name", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { @@ -1616,10 +1602,9 @@ func TestFormValuesOr(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { req := httptest.NewRequest(http.MethodPost, tc.givenURL, nil) - e := New() - c := e.NewContext(req, nil) + c := NewContext(req, nil) - v, err := FormParamsOr[string](c, "tags", tc.defaultValue) + v, err := FormValuesOr[string](c, "tags", tc.defaultValue) if tc.expectErr != "" { assert.ErrorContains(t, err, tc.expectErr) } else { diff --git a/binder_test.go b/binder_test.go index d552b604d..8eced8208 100644 --- a/binder_test.go +++ b/binder_test.go @@ -18,7 +18,7 @@ import ( "time" ) -func createTestContext(URL string, body io.Reader, pathParams map[string]string) Context { +func createTestContext(URL string, body io.Reader, pathValues map[string]string) *Context { e := New() req := httptest.NewRequest(http.MethodGet, URL, body) if body != nil { @@ -27,15 +27,15 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if len(pathParams) > 0 { - names := make([]string, 0) - values := make([]string, 0) - for name, value := range pathParams { - names = append(names, name) - values = append(values, value) + if len(pathValues) > 0 { + params := make(PathValues, 0) + for name, value := range pathValues { + params = append(params, PathValue{ + Name: name, + Value: value, + }) } - c.SetParamNames(names...) - c.SetParamValues(values...) + c.SetPathValues(params) } return c @@ -43,12 +43,12 @@ func createTestContext(URL string, body io.Reader, pathParams map[string]string) func TestBindingError_Error(t *testing.T) { err := NewBindingError("id", []string{"1", "nope"}, "bind failed", errors.New("internal error")) - assert.EqualError(t, err, `code=400, message=bind failed, internal=internal error, field=id`) + assert.EqualError(t, err, `code=400, message=bind failed, err=internal error, field=id`) bErr := err.(*BindingError) assert.Equal(t, 400, bErr.Code) assert.Equal(t, "bind failed", bErr.Message) - assert.Equal(t, errors.New("internal error"), bErr.Internal) + assert.Equal(t, errors.New("internal error"), bErr.err) assert.Equal(t, "id", bErr.Field) assert.Equal(t, []string{"1", "nope"}, bErr.Values) @@ -62,13 +62,13 @@ func TestBindingError_ErrorJSON(t *testing.T) { assert.Equal(t, `{"field":"id","message":"bind failed"}`, string(resp)) } -func TestPathParamsBinder(t *testing.T) { +func TestPathValuesBinder(t *testing.T) { c := createTestContext("/api/user/999", nil, map[string]string{ "id": "1", "nr": "2", "slice": "3", }) - b := PathParamsBinder(c) + b := PathValuesBinder(c) id := int64(99) nr := int64(88) @@ -91,15 +91,15 @@ func TestQueryParamsBinder_FailFast(t *testing.T) { var testCases = []struct { name string whenURL string - givenFailFast bool expectError []string + givenFailFast bool }{ { name: "ok, FailFast=true stops at first error", whenURL: "/api/user/999?nr=en&id=nope", givenFailFast: true, expectError: []string{ - `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, }, }, { @@ -107,8 +107,8 @@ func TestQueryParamsBinder_FailFast(t *testing.T) { whenURL: "/api/user/999?nr=en&id=nope", givenFailFast: false, expectError: []string{ - `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, - `code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, + `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "nope": invalid syntax, field=id`, + `code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing "en": invalid syntax, field=nr`, }, }, } @@ -165,7 +165,7 @@ func TestFormFieldBinder(t *testing.T) { } func TestValueBinder_errorStopsBinding(t *testing.T) { - // this test documents "feature" that binding multiple params can change destination if it was binded before + // this test documents "feature" that binding multiple params can change destination if it was bound before // failing parameter binding c := createTestContext("/api/user/999?id=1&nr=nope", nil, nil) @@ -177,7 +177,7 @@ func TestValueBinder_errorStopsBinding(t *testing.T) { Int64("nr", &nr). BindError() - assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=nr") assert.Equal(t, int64(1), id) assert.Equal(t, int64(88), nr) } @@ -192,17 +192,17 @@ func TestValueBinder_BindError(t *testing.T) { Int64("nr", &nr). BindError() - assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id") + assert.EqualError(t, err, "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=id") assert.Nil(t, b.errors) assert.Nil(t, b.BindError()) } func TestValueBinder_GetValues(t *testing.T) { var testCases = []struct { - name string whenValuesFunc func(sourceParam string) []string - expect []int64 + name string expectError string + expect []int64 }{ { name: "ok, default implementation", @@ -266,13 +266,13 @@ func TestValueBinder_CustomFuncWithError(t *testing.T) { func TestValueBinder_CustomFunc(t *testing.T) { var testCases = []struct { + expectValue any name string - givenFailFast bool - givenFuncErrors []error whenURL string + givenFuncErrors []error expectParamValues []string - expectValue interface{} expectErrors []string + givenFailFast bool }{ { name: "ok, binds value", @@ -341,13 +341,13 @@ func TestValueBinder_CustomFunc(t *testing.T) { func TestValueBinder_MustCustomFunc(t *testing.T) { var testCases = []struct { + expectValue any name string - givenFailFast bool - givenFuncErrors []error whenURL string + givenFuncErrors []error expectParamValues []string - expectValue interface{} expectErrors []string + givenFailFast bool }{ { name: "ok, binds value", @@ -418,12 +418,12 @@ func TestValueBinder_MustCustomFunc(t *testing.T) { func TestValueBinder_String(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool expectValue string expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -494,12 +494,12 @@ func TestValueBinder_String(t *testing.T) { func TestValueBinder_Strings(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []string expectError string + givenBindErrors []error + expectValue []string + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -570,12 +570,12 @@ func TestValueBinder_Strings(t *testing.T) { func TestValueBinder_Int64_intValue(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue int64 expectError string + givenBindErrors []error + expectValue int64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -598,7 +598,7 @@ func TestValueBinder_Int64_intValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -626,7 +626,7 @@ func TestValueBinder_Int64_intValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -667,19 +667,19 @@ func TestValueBinder_Int_errorMessage(t *testing.T) { assert.Equal(t, 99, destInt) assert.Equal(t, uint(98), destUint) - assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, internal=strconv.ParseInt: parsing "nope": invalid syntax, field=param`) - assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, internal=strconv.ParseUint: parsing "nope": invalid syntax, field=param`) + assert.EqualError(t, errs[0], `code=400, message=failed to bind field value to int, err=strconv.ParseInt: parsing "nope": invalid syntax, field=param`) + assert.EqualError(t, errs[1], `code=400, message=failed to bind field value to uint, err=strconv.ParseUint: parsing "nope": invalid syntax, field=param`) } func TestValueBinder_Uint64_uintValue(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue uint64 expectError string + givenBindErrors []error + expectValue uint64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -702,7 +702,7 @@ func TestValueBinder_Uint64_uintValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -730,7 +730,7 @@ func TestValueBinder_Uint64_uintValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 99, - expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, } @@ -881,12 +881,12 @@ func TestValueBinder_Int_Types(t *testing.T) { func TestValueBinder_Int64s_intsValue(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []int64 expectError string + givenBindErrors []error + expectValue []int64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -909,7 +909,7 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []int64{99}, - expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -937,7 +937,7 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []int64{99}, - expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -970,12 +970,12 @@ func TestValueBinder_Int64s_intsValue(t *testing.T) { func TestValueBinder_Uint64s_uintsValue(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []uint64 expectError string + givenBindErrors []error + expectValue []uint64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -998,7 +998,7 @@ func TestValueBinder_Uint64s_uintsValue(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []uint64{99}, - expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1026,7 +1026,7 @@ func TestValueBinder_Uint64s_uintsValue(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []uint64{99}, - expectError: "code=400, message=failed to bind field value to uint64, internal=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to uint64, err=strconv.ParseUint: parsing \"nope\": invalid syntax, field=param", }, } @@ -1169,7 +1169,7 @@ func TestValueBinder_Ints_Types(t *testing.T) { func TestValueBinder_Ints_Types_FailFast(t *testing.T) { // FailFast() should stop parsing and return early - errTmpl := "code=400, message=failed to bind field value to %v, internal=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" + errTmpl := "code=400, message=failed to bind field value to %v, err=strconv.Parse%v: parsing \"nope\": invalid syntax, field=param" c := createTestContext("/search?param=1¶m=nope¶m=2", nil, nil) var dest64 []int64 @@ -1226,12 +1226,12 @@ func TestValueBinder_Ints_Types_FailFast(t *testing.T) { func TestValueBinder_Bool(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string + expectError string + givenBindErrors []error + givenFailFast bool whenMust bool expectValue bool - expectError string }{ { name: "ok, binds value", @@ -1254,7 +1254,7 @@ func TestValueBinder_Bool(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: false, - expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1282,7 +1282,7 @@ func TestValueBinder_Bool(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: false, - expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, } @@ -1315,12 +1315,12 @@ func TestValueBinder_Bool(t *testing.T) { func TestValueBinder_Bools(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []bool expectError string + givenBindErrors []error + expectValue []bool + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1344,14 +1344,14 @@ func TestValueBinder_Bools(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=true¶m=nope¶m=100", expectValue: []bool(nil), - expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", givenFailFast: true, whenURL: "/search?param=true¶m=nope¶m=100", expectValue: []bool(nil), - expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1380,7 +1380,7 @@ func TestValueBinder_Bools(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []bool(nil), - expectError: "code=400, message=failed to bind field value to bool, internal=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to bool, err=strconv.ParseBool: parsing \"nope\": invalid syntax, field=param", }, } @@ -1411,12 +1411,12 @@ func TestValueBinder_Bools(t *testing.T) { func TestValueBinder_Float64(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue float64 expectError string + givenBindErrors []error + expectValue float64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1439,7 +1439,7 @@ func TestValueBinder_Float64(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1467,7 +1467,7 @@ func TestValueBinder_Float64(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1500,12 +1500,12 @@ func TestValueBinder_Float64(t *testing.T) { func TestValueBinder_Float64s(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []float64 expectError string + givenBindErrors []error + expectValue []float64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1529,14 +1529,14 @@ func TestValueBinder_Float64s(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []float64(nil), - expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", givenFailFast: true, whenURL: "/search?param=0¶m=nope¶m=100", expectValue: []float64(nil), - expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1565,7 +1565,7 @@ func TestValueBinder_Float64s(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []float64(nil), - expectError: "code=400, message=failed to bind field value to float64, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float64, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1596,12 +1596,12 @@ func TestValueBinder_Float64s(t *testing.T) { func TestValueBinder_Float32(t *testing.T) { var testCases = []struct { name string - givenNoFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue float32 expectError string + givenBindErrors []error + expectValue float32 + givenNoFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1624,7 +1624,7 @@ func TestValueBinder_Float32(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1652,7 +1652,7 @@ func TestValueBinder_Float32(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 1.123, - expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1685,12 +1685,12 @@ func TestValueBinder_Float32(t *testing.T) { func TestValueBinder_Float32s(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []float32 expectError string + givenBindErrors []error + expectValue []float32 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1714,14 +1714,14 @@ func TestValueBinder_Float32s(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []float32(nil), - expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "nok, conversion fails fast, value is not changed", givenFailFast: true, whenURL: "/search?param=0¶m=nope¶m=100", expectValue: []float32(nil), - expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -1750,7 +1750,7 @@ func TestValueBinder_Float32s(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []float32(nil), - expectError: "code=400, message=failed to bind field value to float32, internal=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to float32, err=strconv.ParseFloat: parsing \"nope\": invalid syntax, field=param", }, } @@ -1781,14 +1781,14 @@ func TestValueBinder_Float32s(t *testing.T) { func TestValueBinder_Time(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { + expectValue time.Time name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool whenLayout string - expectValue time.Time expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1863,13 +1863,13 @@ func TestValueBinder_Times(t *testing.T) { exampleTime2, _ := time.Parse(time.RFC3339, "2000-01-02T09:45:31+00:00") var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool whenLayout string - expectValue []time.Time expectError string + givenBindErrors []error + expectValue []time.Time + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -1948,12 +1948,12 @@ func TestValueBinder_Duration(t *testing.T) { example := 42 * time.Second var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue time.Duration expectError string + givenBindErrors []error + expectValue time.Duration + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -2026,12 +2026,12 @@ func TestValueBinder_Durations(t *testing.T) { exampleDuration2 := 1 * time.Millisecond var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []time.Duration expectError string + givenBindErrors []error + expectValue []time.Duration + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -2103,13 +2103,13 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-23T09:45:31+02:00") var testCases = []struct { + expectValue Timestamp name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue Timestamp expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -2132,7 +2132,7 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: Timestamp{}, - expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, { name: "ok (must), binds value", @@ -2160,7 +2160,7 @@ func TestValueBinder_BindUnmarshaler(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: Timestamp{}, - expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to BindUnmarshaler interface, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, } @@ -2195,12 +2195,12 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue big.Int expectError string + expectValue big.Int + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -2223,7 +2223,7 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, { name: "ok (must), binds value", @@ -2251,7 +2251,7 @@ func TestValueBinder_JSONUnmarshaler(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to json.Unmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, } @@ -2286,12 +2286,12 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue big.Int expectError string + expectValue big.Int + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -2314,7 +2314,7 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, { name: "ok (must), binds value", @@ -2342,7 +2342,7 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=xxx", expectValue: big.Int{}, - expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, internal=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", + expectError: "code=400, message=failed to bind field value to encoding.TextUnmarshaler interface, err=math/big: cannot unmarshal \"nope\" into a *big.Int, field=param", }, } @@ -2374,9 +2374,9 @@ func TestValueBinder_TextUnmarshaler(t *testing.T) { func TestValueBinder_BindWithDelimiter_types(t *testing.T) { var testCases = []struct { + expect any name string whenURL string - expect interface{} }{ { name: "ok, strings", @@ -2522,12 +2522,12 @@ func TestValueBinder_BindWithDelimiter_types(t *testing.T) { func TestValueBinder_BindWithDelimiter(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []int64 expectError string + givenBindErrors []error + expectValue []int64 + givenFailFast bool + whenMust bool }{ { name: "ok, binds value", @@ -2550,7 +2550,7 @@ func TestValueBinder_BindWithDelimiter(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []int64(nil), - expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2578,7 +2578,7 @@ func TestValueBinder_BindWithDelimiter(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []int64(nil), - expectError: "code=400, message=failed to bind field value to int64, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to int64, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2621,13 +2621,13 @@ func TestBindWithDelimiter_invalidType(t *testing.T) { func TestValueBinder_UnixTime(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339, "2020-12-28T18:36:43+00:00") // => 1609180603 var testCases = []struct { + expectValue time.Time name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue time.Time expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value, unix time in seconds", @@ -2655,7 +2655,7 @@ func TestValueBinder_UnixTime(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2683,7 +2683,7 @@ func TestValueBinder_UnixTime(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2717,13 +2717,13 @@ func TestValueBinder_UnixTime(t *testing.T) { func TestValueBinder_UnixTimeMilli(t *testing.T) { exampleTime, _ := time.Parse(time.RFC3339Nano, "2022-03-13T15:13:30.140000000+00:00") // => 1647184410140 var testCases = []struct { + expectValue time.Time name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue time.Time expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value, unix time in milliseconds", @@ -2746,7 +2746,7 @@ func TestValueBinder_UnixTimeMilli(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2774,7 +2774,7 @@ func TestValueBinder_UnixTimeMilli(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2810,13 +2810,13 @@ func TestValueBinder_UnixTimeNano(t *testing.T) { exampleTimeNano, _ := time.Parse(time.RFC3339Nano, "2020-12-28T18:36:43.123456789+00:00") // => 1609180603123456789 exampleTimeNanoBelowSec, _ := time.Parse(time.RFC3339Nano, "1970-01-01T00:00:00.999999999+00:00") var testCases = []struct { + expectValue time.Time name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue time.Time expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "ok, binds value, unix time in nano seconds (sec precision)", @@ -2849,7 +2849,7 @@ func TestValueBinder_UnixTimeNano(t *testing.T) { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, { name: "ok (must), binds value", @@ -2877,7 +2877,7 @@ func TestValueBinder_UnixTimeNano(t *testing.T) { whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", + expectError: "code=400, message=failed to bind field value to Time, err=strconv.ParseInt: parsing \"nope\": invalid syntax, field=param", }, } @@ -2919,7 +2919,7 @@ func BenchmarkDefaultBinder_BindInt64_single(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) } } @@ -2967,17 +2967,16 @@ func BenchmarkRawFunc_Int64_single(b *testing.B) { func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { type Opts struct { - Int64 int64 `query:"int64"` - Int32 int32 `query:"int32"` - Int16 int16 `query:"int16"` - Int8 int8 `query:"int8"` - String string `query:"string"` - + String string `query:"string"` + Strings []string `query:"strings"` + Int64 int64 `query:"int64"` Uint64 uint64 `query:"uint64"` + Int32 int32 `query:"int32"` Uint32 uint32 `query:"uint32"` + Int16 int16 `query:"int16"` Uint16 uint16 `query:"uint16"` + Int8 int8 `query:"int8"` Uint8 uint8 `query:"uint8"` - Strings []string `query:"strings"` } c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) @@ -2986,7 +2985,7 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { binder := new(DefaultBinder) for i := 0; i < b.N; i++ { var dest Opts - _ = binder.Bind(&dest, c) + _ = binder.Bind(c, &dest) if dest.Int64 != 1 { b.Fatalf("int64!=1") } @@ -2995,17 +2994,16 @@ func BenchmarkDefaultBinder_BindInt64_10_fields(b *testing.B) { func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { type Opts struct { - Int64 int64 `query:"int64"` - Int32 int32 `query:"int32"` - Int16 int16 `query:"int16"` - Int8 int8 `query:"int8"` - String string `query:"string"` - + String string `query:"string"` + Strings []string `query:"strings"` + Int64 int64 `query:"int64"` Uint64 uint64 `query:"uint64"` + Int32 int32 `query:"int32"` Uint32 uint32 `query:"uint32"` + Int16 int16 `query:"int16"` Uint16 uint16 `query:"uint16"` + Int8 int8 `query:"int8"` Uint8 uint8 `query:"uint8"` - Strings []string `query:"strings"` } c := createTestContext("/search?int64=1&int32=2&int16=3&int8=4&string=test&uint64=5&uint32=6&uint16=7&uint8=8&strings=first&strings=second", nil, nil) @@ -3034,27 +3032,27 @@ func BenchmarkValueBinder_BindInt64_10_fields(b *testing.B) { func TestValueBinder_TimeError(t *testing.T) { var testCases = []struct { + expectValue time.Time name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool whenLayout string - expectValue time.Time expectError string + givenBindErrors []error + givenFailFast bool + whenMust bool }{ { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: time.Time{}, - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\": extra text: \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\": extra text: \"nope\", field=param", }, } @@ -3087,33 +3085,33 @@ func TestValueBinder_TimeError(t *testing.T) { func TestValueBinder_TimesError(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool whenLayout string - expectValue []time.Time expectError string + givenBindErrors []error + expectValue []time.Time + givenFailFast bool + whenMust bool }{ { name: "nok, fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"1\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"1\" as \"2006\", field=param", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []time.Time(nil), - expectError: "code=400, message=failed to bind field value to Time, internal=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", + expectError: "code=400, message=failed to bind field value to Time, err=parsing time \"nope\" as \"2006-01-02T15:04:05Z07:00\": cannot parse \"nope\" as \"2006\", field=param", }, } @@ -3149,25 +3147,25 @@ func TestValueBinder_TimesError(t *testing.T) { func TestValueBinder_DurationError(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue time.Duration expectError string + givenBindErrors []error + expectValue time.Duration + givenFailFast bool + whenMust bool }{ { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: 0, - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, } @@ -3200,32 +3198,32 @@ func TestValueBinder_DurationError(t *testing.T) { func TestValueBinder_DurationsError(t *testing.T) { var testCases = []struct { name string - givenFailFast bool - givenBindErrors []error whenURL string - whenMust bool - expectValue []time.Duration expectError string + givenBindErrors []error + expectValue []time.Duration + givenFailFast bool + whenMust bool }{ { name: "nok, fail fast without binding value", givenFailFast: true, whenURL: "/search?param=1¶m=100", expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: missing unit in duration \"1\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, err=time: missing unit in duration \"1\", field=param", }, { name: "nok, conversion fails, value is not changed", whenURL: "/search?param=nope¶m=100", expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, { name: "nok (must), conversion fails, value is not changed", whenMust: true, whenURL: "/search?param=nope¶m=100", expectValue: []time.Duration(nil), - expectError: "code=400, message=failed to bind field value to Duration, internal=time: invalid duration \"nope\", field=param", + expectError: "code=400, message=failed to bind field value to Duration, err=time: invalid duration \"nope\", field=param", }, } diff --git a/context.go b/context.go index 67e83181c..6fb2091b8 100644 --- a/context.go +++ b/context.go @@ -6,273 +6,158 @@ package echo import ( "bytes" "encoding/xml" + "errors" "fmt" "io" + "io/fs" + "log/slog" "mime/multipart" "net" "net/http" "net/url" + "path/filepath" "strings" "sync" ) -// Context represents the context of the current HTTP request. It holds request and -// response objects, path, path parameters, data and registered handler. -type Context interface { - // Request returns `*http.Request`. - Request() *http.Request - - // SetRequest sets `*http.Request`. - SetRequest(r *http.Request) - - // SetResponse sets `*Response`. - SetResponse(r *Response) - - // Response returns `*Response`. - Response() *Response - - // IsTLS returns true if HTTP connection is TLS otherwise false. - IsTLS() bool - - // IsWebSocket returns true if HTTP connection is WebSocket otherwise false. - IsWebSocket() bool - - // Scheme returns the HTTP protocol scheme, `http` or `https`. - Scheme() string - - // RealIP returns the client's network address based on `X-Forwarded-For` - // or `X-Real-IP` request header. - // The behavior can be configured using `Echo#IPExtractor`. - RealIP() string - - // Path returns the registered path for the handler. - Path() string - - // SetPath sets the registered path for the handler. - SetPath(p string) - - // Param returns path parameter by name. - Param(name string) string - - // ParamNames returns path parameter names. - ParamNames() []string - - // SetParamNames sets path parameter names. - SetParamNames(names ...string) - - // ParamValues returns path parameter values. - ParamValues() []string - - // SetParamValues sets path parameter values. - SetParamValues(values ...string) - - // QueryParam returns the query param for the provided name. - QueryParam(name string) string - - // QueryParams returns the query parameters as `url.Values`. - QueryParams() url.Values - - // QueryString returns the URL query string. - QueryString() string - - // FormValue returns the form field value for the provided name. - FormValue(name string) string - - // FormParams returns the form parameters as `url.Values`. - FormParams() (url.Values, error) - - // FormFile returns the multipart form file for the provided name. - FormFile(name string) (*multipart.FileHeader, error) - - // MultipartForm returns the multipart form. - MultipartForm() (*multipart.Form, error) - - // Cookie returns the named cookie provided in the request. - Cookie(name string) (*http.Cookie, error) - - // SetCookie adds a `Set-Cookie` header in HTTP response. - SetCookie(cookie *http.Cookie) - - // Cookies returns the HTTP cookies sent with the request. - Cookies() []*http.Cookie - - // Get retrieves data from the context. - Get(key string) any - - // Set saves data in the context. - Set(key string, val any) - - // Bind binds path params, query params and the request body into provided type `i`. The default binder - // binds body based on Content-Type header. - Bind(i any) error - - // Validate validates provided `i`. It is usually called after `Context#Bind()`. - // Validator must be registered using `Echo#Validator`. - Validate(i any) error - - // Render renders a template with data and sends a text/html response with status - // code. Renderer must be registered using `Echo.Renderer`. - Render(code int, name string, data any) error - - // HTML sends an HTTP response with status code. - HTML(code int, html string) error - - // HTMLBlob sends an HTTP blob response with status code. - HTMLBlob(code int, b []byte) error - - // String sends a string response with status code. - String(code int, s string) error - - // JSON sends a JSON response with status code. - JSON(code int, i any) error - - // JSONPretty sends a pretty-print JSON with status code. - JSONPretty(code int, i any, indent string) error - - // JSONBlob sends a JSON blob response with status code. - JSONBlob(code int, b []byte) error - - // JSONP sends a JSONP response with status code. It uses `callback` to construct - // the JSONP payload. - JSONP(code int, callback string, i any) error - - // JSONPBlob sends a JSONP blob response with status code. It uses `callback` - // to construct the JSONP payload. - JSONPBlob(code int, callback string, b []byte) error - - // XML sends an XML response with status code. - XML(code int, i any) error - - // XMLPretty sends a pretty-print XML with status code. - XMLPretty(code int, i any, indent string) error - - // XMLBlob sends an XML blob response with status code. - XMLBlob(code int, b []byte) error - - // Blob sends a blob response with status code and content type. - Blob(code int, contentType string, b []byte) error - - // Stream sends a streaming response with status code and content type. - Stream(code int, contentType string, r io.Reader) error - - // File sends a response with the content of the file. - File(file string) error - - // Attachment sends a response as attachment, prompting client to save the - // file. - Attachment(file string, name string) error - - // Inline sends a response as inline, opening the file in the browser. - Inline(file string, name string) error - - // NoContent sends a response with no body and a status code. - NoContent(code int) error - - // Redirect redirects the request to a provided URL with status code. - Redirect(code int, url string) error - - // Error invokes the registered global HTTP error handler. Generally used by middleware. - // A side-effect of calling global error handler is that now Response has been committed (sent to the client) and - // middlewares up in chain can not change Response status code or Response body anymore. - // - // Avoid using this method in handlers as no middleware will be able to effectively handle errors after that. - Error(err error) - - // Handler returns the matched handler by router. - Handler() HandlerFunc - - // SetHandler sets the matched handler by router. - SetHandler(h HandlerFunc) - - // Logger returns the `Logger` instance. - Logger() Logger - - // SetLogger Set the logger - SetLogger(l Logger) +const ( + // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. + // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. + // It is added to context only when Router does not find matching method handler for request. + ContextKeyHeaderAllow = "echo_header_allow" +) - // Echo returns the `Echo` instance. - Echo() *Echo +const ( + // defaultMemory is default value for memory limit that is used when + // parsing multipart forms (See (*http.Request).ParseMultipartForm) + defaultMemory int64 = 32 << 20 // 32 MB + indexPage = "index.html" +) - // Reset resets the context after request completes. It must be called along - // with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. - // See `Echo#ServeHTTP()` - Reset(r *http.Request, w http.ResponseWriter) -} +// Context represents the context of the current HTTP request. It holds request and +// response objects, path, path parameters, data and registered handler. +type Context struct { + request *http.Request + orgResponse *Response + response http.ResponseWriter + query url.Values -type context struct { - logger Logger - request *http.Request - response *Response - query url.Values - echo *Echo + // formParseMaxMemory is used for http.Request.ParseMultipartForm + formParseMaxMemory int64 - store Map - lock sync.RWMutex + route *RouteInfo + pathValues *PathValues - // following fields are set by Router - handler HandlerFunc + store map[string]any + echo *Echo + logger *slog.Logger - // path is route path that Router matched. It is empty string where there is no route match. - // Route registered with RouteNotFound is considered as a match and path therefore is not empty. path string + lock sync.RWMutex +} + +// NewContext returns a new Context instance. +// +// Note: request,response and e can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. +func NewContext(r *http.Request, w http.ResponseWriter, opts ...any) *Context { + var e *Echo + for _, opt := range opts { + switch v := opt.(type) { + case *Echo: + e = v + } + } + return newContext(r, w, e) +} - // Usually echo.Echo is sizing pvalues but there could be user created middlewares that decide to - // overwrite parameter by calling SetParamNames + SetParamValues. - // When echo.Echo allocated that slice it length/capacity is tied to echo.Echo.maxParam value. - // - // It is important that pvalues size is always equal or bigger to pnames length. - pvalues []string +func newContext(r *http.Request, w http.ResponseWriter, e *Echo) *Context { + c := &Context{ + pathValues: nil, + store: make(map[string]any), + echo: e, + logger: nil, + } + var logger *slog.Logger + paramLen := int32(0) + formParseMaxMemory := defaultMemory + if e != nil { + paramLen = e.contextPathParamAllocSize.Load() + logger = e.Logger + formParseMaxMemory = e.formParseMaxMemory + } + if logger == nil { + logger = slog.Default() + } + c.logger = logger + p := make(PathValues, 0, paramLen) + c.pathValues = &p - // pnames length is tied to param count for the matched route - pnames []string + c.SetRequest(r) + c.orgResponse = NewResponse(w, logger) + c.response = c.orgResponse + c.formParseMaxMemory = formParseMaxMemory + return c } -const ( - // ContextKeyHeaderAllow is set by Router for getting value for `Allow` header in later stages of handler call chain. - // Allow header is mandatory for status 405 (method not found) and useful for OPTIONS method requests. - // It is added to context only when Router does not find matching method handler for request. - ContextKeyHeaderAllow = "echo_header_allow" -) +// Reset resets the context after request completes. It must be called along +// with `Echo#AcquireContext()` and `Echo#ReleaseContext()`. +// See `Echo#ServeHTTP()` +func (c *Context) Reset(r *http.Request, w http.ResponseWriter) { + c.request = r + c.orgResponse.reset(w) + c.response = c.orgResponse + c.query = nil + c.store = nil + c.logger = c.echo.Logger -const ( - defaultMemory = 32 << 20 // 32 MB - indexPage = "index.html" - defaultIndent = " " -) + c.route = nil + c.path = "" + // NOTE: empty by setting length to 0. PathValues has to have capacity of c.echo.contextPathParamAllocSize at all times + *c.pathValues = (*c.pathValues)[:0] +} -func (c *context) writeContentType(value string) { - header := c.Response().Header() +func (c *Context) writeContentType(value string) { + header := c.response.Header() if header.Get(HeaderContentType) == "" { header.Set(HeaderContentType, value) } } -func (c *context) Request() *http.Request { +// Request returns `*http.Request`. +func (c *Context) Request() *http.Request { return c.request } -func (c *context) SetRequest(r *http.Request) { +// SetRequest sets `*http.Request`. +func (c *Context) SetRequest(r *http.Request) { c.request = r } -func (c *context) Response() *Response { +// Response returns `*Response`. +func (c *Context) Response() http.ResponseWriter { return c.response } -func (c *context) SetResponse(r *Response) { +// SetResponse sets `*http.ResponseWriter`. Some middleware require that given ResponseWriter implements following +// method `Unwrap() http.ResponseWriter` which eventually should return echo.Response instance. +func (c *Context) SetResponse(r http.ResponseWriter) { c.response = r } -func (c *context) IsTLS() bool { +// IsTLS returns true if HTTP connection is TLS otherwise false. +func (c *Context) IsTLS() bool { return c.request.TLS != nil } -func (c *context) IsWebSocket() bool { +// IsWebSocket returns true if HTTP connection is WebSocket otherwise false. +func (c *Context) IsWebSocket() bool { upgrade := c.request.Header.Get(HeaderUpgrade) return strings.EqualFold(upgrade, "websocket") } -func (c *context) Scheme() string { +// Scheme returns the HTTP protocol scheme, `http` or `https`. +func (c *Context) Scheme() string { // Can't use `r.Request.URL.Scheme` // See: https://groups.google.com/forum/#!topic/golang-nuts/pMUkBlQBDF0 if c.IsTLS() { @@ -293,7 +178,10 @@ func (c *context) Scheme() string { return "http" } -func (c *context) RealIP() string { +// RealIP returns the client's network address based on `X-Forwarded-For` +// or `X-Real-IP` request header. +// The behavior can be configured using `Echo#IPExtractor`. +func (c *Context) RealIP() string { if c.echo != nil && c.echo.IPExtractor != nil { return c.echo.IPExtractor(c.request) } @@ -317,83 +205,134 @@ func (c *context) RealIP() string { return ra } -func (c *context) Path() string { +// Path returns the registered path for the handler. +func (c *Context) Path() string { return c.path } -func (c *context) SetPath(p string) { +// SetPath sets the registered path for the handler. +func (c *Context) SetPath(p string) { c.path = p } -func (c *context) Param(name string) string { - for i, n := range c.pnames { - if i < len(c.pvalues) { - if n == name { - return c.pvalues[i] - } - } +// RouteInfo returns current request route information. Method, Path, Name and params if they exist for matched route. +// +// RouteInfo returns generic "empty" struct for these cases: +// * Context is accessed before Routing is done. For example inside Pre middlewares (`e.Pre()`) +// * Router did not find matching route - 404 (route not found) +// * Router did not find matching route with same method - 405 (method not allowed) +func (c *Context) RouteInfo() RouteInfo { + if c.route != nil { + return c.route.Clone() } - return "" + return RouteInfo{} +} + +// Param returns path parameter by name. +func (c *Context) Param(name string) string { + return c.pathValues.GetOr(name, "") } -func (c *context) ParamNames() []string { - return c.pnames +// ParamOr returns the path parameter or default value for the provided name. +// +// Notes for DefaultRouter implementation: +// Path parameter could be empty for cases like that: +// * route `/release-:version/bin` and request URL is `/release-/bin` +// * route `/api/:version/image.jpg` and request URL is `/api//image.jpg` +// but not when path parameter is last part of route path +// * route `/download/file.:ext` will not match request `/download/file.` +func (c *Context) ParamOr(name, defaultValue string) string { + return c.pathValues.GetOr(name, defaultValue) } -func (c *context) SetParamNames(names ...string) { - c.pnames = names +// PathValues returns path parameter values. +func (c *Context) PathValues() PathValues { + return *c.pathValues +} - l := len(names) - if len(c.pvalues) < l { - // Keeping the old pvalues just for backward compatibility, but it sounds that doesn't make sense to keep them, - // probably those values will be overridden in a Context#SetParamValues - newPvalues := make([]string, l) - copy(newPvalues, c.pvalues) - c.pvalues = newPvalues +// SetPathValues sets path parameters for current request. +func (c *Context) SetPathValues(pathValues PathValues) { + if pathValues == nil { + panic("context SetPathValues called with nil PathValues") } + c.setPathValues(&pathValues) } -func (c *context) ParamValues() []string { - return c.pvalues[:len(c.pnames)] +// InitializeRoute sets the route related variables of this request to the context. +func (c *Context) InitializeRoute(ri *RouteInfo, pathValues *PathValues) { + c.route = ri + c.path = ri.Path + c.setPathValues(pathValues) } -func (c *context) SetParamValues(values ...string) { - // NOTE: Don't just set c.pvalues = values, because it has to have length c.echo.maxParam (or bigger) at all times - // It will break the Router#Find code - limit := len(values) - if limit > len(c.pvalues) { - c.pvalues = make([]string, limit) - } - for i := 0; i < limit; i++ { - c.pvalues[i] = values[i] +func (c *Context) setPathValues(pv *PathValues) { + // Router accesses c.pathValues by index and may resize it to full capacity during routing + // for that to work without going out-of-bounds we must make sure that c.pathValues slice is not replaced with smaller + // slice than Router can set when routing Route with maximum amount of parameters. + pathValues := c.pathValues + if cap(*c.pathValues) < len(*pv) { + // normally we should not end up here. pathValues is normally sized to Echo.contextPathParamAllocSize which should not + // be smaller than anything router knows as maximum path parameter count to be. + tmp := make(PathValues, len(*pv)) + c.pathValues = &tmp + pathValues = c.pathValues + } else if len(*c.pathValues) != len(*pv) { + *pathValues = (*pathValues)[0:len(*pv)] // resize slice to given params length for copy to work } + copy(*pathValues, *pv) } -func (c *context) QueryParam(name string) string { +// QueryParam returns the query param for the provided name. +func (c *Context) QueryParam(name string) string { if c.query == nil { c.query = c.request.URL.Query() } return c.query.Get(name) } -func (c *context) QueryParams() url.Values { +// QueryParamOr returns the query param or default value for the provided name. +// Note: QueryParamOr does not distinguish if query had no value by that name or value was empty string +// This means URLs `/test?search=` and `/test` would both return `1` for `c.QueryParamOr("search", "1")` +func (c *Context) QueryParamOr(name, defaultValue string) string { + value := c.QueryParam(name) + if value == "" { + value = defaultValue + } + return value +} + +// QueryParams returns the query parameters as `url.Values`. +func (c *Context) QueryParams() url.Values { if c.query == nil { c.query = c.request.URL.Query() } return c.query } -func (c *context) QueryString() string { +// QueryString returns the URL query string. +func (c *Context) QueryString() string { return c.request.URL.RawQuery } -func (c *context) FormValue(name string) string { +// FormValue returns the form field value for the provided name. +func (c *Context) FormValue(name string) string { return c.request.FormValue(name) } -func (c *context) FormParams() (url.Values, error) { +// FormValueOr returns the form field value or default value for the provided name. +// Note: FormValueOr does not distinguish if form had no value by that name or value was empty string +func (c *Context) FormValueOr(name, defaultValue string) string { + value := c.FormValue(name) + if value == "" { + value = defaultValue + } + return value +} + +// FormValues returns the form field values as `url.Values`. +func (c *Context) FormValues() (url.Values, error) { if strings.HasPrefix(c.request.Header.Get(HeaderContentType), MIMEMultipartForm) { - if err := c.request.ParseMultipartForm(defaultMemory); err != nil { + if err := c.request.ParseMultipartForm(c.formParseMaxMemory); err != nil { return nil, err } } else { @@ -404,93 +343,106 @@ func (c *context) FormParams() (url.Values, error) { return c.request.Form, nil } -func (c *context) FormFile(name string) (*multipart.FileHeader, error) { +// FormFile returns the multipart form file for the provided name. +func (c *Context) FormFile(name string) (*multipart.FileHeader, error) { f, fh, err := c.request.FormFile(name) if err != nil { return nil, err } - f.Close() + _ = f.Close() return fh, nil } -func (c *context) MultipartForm() (*multipart.Form, error) { - err := c.request.ParseMultipartForm(defaultMemory) +// MultipartForm returns the multipart form. +func (c *Context) MultipartForm() (*multipart.Form, error) { + err := c.request.ParseMultipartForm(c.formParseMaxMemory) return c.request.MultipartForm, err } -func (c *context) Cookie(name string) (*http.Cookie, error) { +// Cookie returns the named cookie provided in the request. +func (c *Context) Cookie(name string) (*http.Cookie, error) { return c.request.Cookie(name) } -func (c *context) SetCookie(cookie *http.Cookie) { +// SetCookie adds a `Set-Cookie` header in HTTP response. +func (c *Context) SetCookie(cookie *http.Cookie) { http.SetCookie(c.Response(), cookie) } -func (c *context) Cookies() []*http.Cookie { +// Cookies returns the HTTP cookies sent with the request. +func (c *Context) Cookies() []*http.Cookie { return c.request.Cookies() } -func (c *context) Get(key string) any { +// Get retrieves data from the context. +// Method returns any(nil) when key does not exist which is different from typed nil (eg. []byte(nil)). +func (c *Context) Get(key string) any { c.lock.RLock() defer c.lock.RUnlock() return c.store[key] } -func (c *context) Set(key string, val any) { +// Set saves data in the context. +func (c *Context) Set(key string, val any) { c.lock.Lock() defer c.lock.Unlock() if c.store == nil { - c.store = make(Map) + c.store = make(map[string]any) } c.store[key] = val } -func (c *context) Bind(i any) error { - return c.echo.Binder.Bind(i, c) +// Bind binds path params, query params and the request body into provided type `i`. The default binder +// binds body based on Content-Type header. +func (c *Context) Bind(i any) error { + return c.echo.Binder.Bind(c, i) } -func (c *context) Validate(i any) error { +// Validate validates provided `i`. It is usually called after `Context#Bind()`. +// Validator must be registered using `Echo#Validator`. +func (c *Context) Validate(i any) error { if c.echo.Validator == nil { return ErrValidatorNotRegistered } return c.echo.Validator.Validate(i) } -func (c *context) Render(code int, name string, data any) (err error) { +// Render renders a template with data and sends a text/html response with status +// code. Renderer must be registered using `Echo.Renderer`. +func (c *Context) Render(code int, name string, data any) (err error) { if c.echo.Renderer == nil { return ErrRendererNotRegistered } buf := new(bytes.Buffer) - if err = c.echo.Renderer.Render(buf, name, data, c); err != nil { + if err = c.echo.Renderer.Render(c, buf, name, data); err != nil { return } return c.HTMLBlob(code, buf.Bytes()) } -func (c *context) HTML(code int, html string) (err error) { +// HTML sends an HTTP response with status code. +func (c *Context) HTML(code int, html string) (err error) { return c.HTMLBlob(code, []byte(html)) } -func (c *context) HTMLBlob(code int, b []byte) (err error) { +// HTMLBlob sends an HTTP blob response with status code. +func (c *Context) HTMLBlob(code int, b []byte) (err error) { return c.Blob(code, MIMETextHTMLCharsetUTF8, b) } -func (c *context) String(code int, s string) (err error) { +// String sends a string response with status code. +func (c *Context) String(code int, s string) (err error) { return c.Blob(code, MIMETextPlainCharsetUTF8, []byte(s)) } -func (c *context) jsonPBlob(code int, callback string, i any) (err error) { - indent := "" - if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { - indent = defaultIndent - } +func (c *Context) jsonPBlob(code int, callback string, i any) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { return } - if err = c.echo.JSONSerializer.Serialize(c, i, indent); err != nil { + if err = c.echo.JSONSerializer.Serialize(c, i, ""); err != nil { return } if _, err = c.response.Write([]byte(");")); err != nil { @@ -499,33 +451,36 @@ func (c *context) jsonPBlob(code int, callback string, i any) (err error) { return } -func (c *context) json(code int, i any, indent string) error { +func (c *Context) json(code int, i any, indent string) error { c.writeContentType(MIMEApplicationJSON) - c.response.Status = code + c.response.WriteHeader(code) return c.echo.JSONSerializer.Serialize(c, i, indent) } -func (c *context) JSON(code int, i any) (err error) { - indent := "" - if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { - indent = defaultIndent - } - return c.json(code, i, indent) +// JSON sends a JSON response with status code. +func (c *Context) JSON(code int, i any) (err error) { + return c.json(code, i, "") } -func (c *context) JSONPretty(code int, i any, indent string) (err error) { +// JSONPretty sends a pretty-print JSON with status code. +func (c *Context) JSONPretty(code int, i any, indent string) (err error) { return c.json(code, i, indent) } -func (c *context) JSONBlob(code int, b []byte) (err error) { +// JSONBlob sends a JSON blob response with status code. +func (c *Context) JSONBlob(code int, b []byte) (err error) { return c.Blob(code, MIMEApplicationJSON, b) } -func (c *context) JSONP(code int, callback string, i any) (err error) { +// JSONP sends a JSONP response with status code. It uses `callback` to construct +// the JSONP payload. +func (c *Context) JSONP(code int, callback string, i any) (err error) { return c.jsonPBlob(code, callback, i) } -func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { +// JSONPBlob sends a JSONP blob response with status code. It uses `callback` +// to construct the JSONP payload. +func (c *Context) JSONPBlob(code int, callback string, b []byte) (err error) { c.writeContentType(MIMEApplicationJavaScriptCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(callback + "(")); err != nil { @@ -538,7 +493,7 @@ func (c *context) JSONPBlob(code int, callback string, b []byte) (err error) { return } -func (c *context) xml(code int, i any, indent string) (err error) { +func (c *Context) xml(code int, i any, indent string) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) enc := xml.NewEncoder(c.response) @@ -551,19 +506,18 @@ func (c *context) xml(code int, i any, indent string) (err error) { return enc.Encode(i) } -func (c *context) XML(code int, i any) (err error) { - indent := "" - if _, pretty := c.QueryParams()["pretty"]; c.echo.Debug || pretty { - indent = defaultIndent - } - return c.xml(code, i, indent) +// XML sends an XML response with status code. +func (c *Context) XML(code int, i any) (err error) { + return c.xml(code, i, "") } -func (c *context) XMLPretty(code int, i any, indent string) (err error) { +// XMLPretty sends a pretty-print XML with status code. +func (c *Context) XMLPretty(code int, i any, indent string) (err error) { return c.xml(code, i, indent) } -func (c *context) XMLBlob(code int, b []byte) (err error) { +// XMLBlob sends an XML blob response with status code. +func (c *Context) XMLBlob(code int, b []byte) (err error) { c.writeContentType(MIMEApplicationXMLCharsetUTF8) c.response.WriteHeader(code) if _, err = c.response.Write([]byte(xml.Header)); err != nil { @@ -573,41 +527,88 @@ func (c *context) XMLBlob(code int, b []byte) (err error) { return } -func (c *context) Blob(code int, contentType string, b []byte) (err error) { +// Blob sends a blob response with status code and content type. +func (c *Context) Blob(code int, contentType string, b []byte) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = c.response.Write(b) return } -func (c *context) Stream(code int, contentType string, r io.Reader) (err error) { +// Stream sends a streaming response with status code and content type. +func (c *Context) Stream(code int, contentType string, r io.Reader) (err error) { c.writeContentType(contentType) c.response.WriteHeader(code) _, err = io.Copy(c.response, r) return } -func (c *context) Attachment(file, name string) error { +// File sends a response with the content of the file. +func (c *Context) File(file string) error { + return fsFile(c, file, c.echo.Filesystem) +} + +// FileFS serves file from given file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (c *Context) FileFS(file string, filesystem fs.FS) error { + return fsFile(c, file, filesystem) +} + +func fsFile(c *Context, file string, filesystem fs.FS) error { + f, err := filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + + fi, _ := f.Stat() + if fi.IsDir() { + file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. + f, err = filesystem.Open(file) + if err != nil { + return ErrNotFound + } + defer f.Close() + if fi, err = f.Stat(); err != nil { + return err + } + } + ff, ok := f.(io.ReadSeeker) + if !ok { + return errors.New("file does not implement io.ReadSeeker") + } + http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) + return nil +} + +// Attachment sends a response as attachment, prompting client to save the file. +func (c *Context) Attachment(file, name string) error { return c.contentDisposition(file, name, "attachment") } -func (c *context) Inline(file, name string) error { +// Inline sends a response as inline, opening the file in the browser. +func (c *Context) Inline(file, name string) error { return c.contentDisposition(file, name, "inline") } var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") -func (c *context) contentDisposition(file, name, dispositionType string) error { +func (c *Context) contentDisposition(file, name, dispositionType string) error { c.response.Header().Set(HeaderContentDisposition, fmt.Sprintf(`%s; filename="%s"`, dispositionType, quoteEscaper.Replace(name))) return c.File(file) } -func (c *context) NoContent(code int) error { +// NoContent sends a response with no body and a status code. +func (c *Context) NoContent(code int) error { c.response.WriteHeader(code) return nil } -func (c *context) Redirect(code int, url string) error { +// Redirect redirects the request to a provided URL with status code. +func (c *Context) Redirect(code int, url string) error { if code < 300 || code > 308 { return ErrInvalidRedirectCode } @@ -616,45 +617,20 @@ func (c *context) Redirect(code int, url string) error { return nil } -func (c *context) Error(err error) { - c.echo.HTTPErrorHandler(err, c) -} - -func (c *context) Echo() *Echo { - return c.echo -} - -func (c *context) Handler() HandlerFunc { - return c.handler -} - -func (c *context) SetHandler(h HandlerFunc) { - c.handler = h -} - -func (c *context) Logger() Logger { - res := c.logger - if res != nil { - return res +// Logger returns logger in Context +func (c *Context) Logger() *slog.Logger { + if c.logger != nil { + return c.logger } return c.echo.Logger } -func (c *context) SetLogger(l Logger) { - c.logger = l +// SetLogger sets logger in Context +func (c *Context) SetLogger(logger *slog.Logger) { + c.logger = logger } -func (c *context) Reset(r *http.Request, w http.ResponseWriter) { - c.request = r - c.response.reset(w) - c.query = nil - c.handler = NotFoundHandler - c.store = nil - c.path = "" - c.pnames = nil - c.logger = nil - // NOTE: Don't reset because it has to have length c.echo.maxParam (or bigger) at all times - for i := 0; i < len(c.pvalues); i++ { - c.pvalues[i] = "" - } +// Echo returns the `Echo` instance. +func (c *Context) Echo() *Echo { + return c.echo } diff --git a/context_fs.go b/context_fs.go deleted file mode 100644 index 1c25baf12..000000000 --- a/context_fs.go +++ /dev/null @@ -1,52 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "errors" - "io" - "io/fs" - "net/http" - "path/filepath" -) - -func (c *context) File(file string) error { - return fsFile(c, file, c.echo.Filesystem) -} - -// FileFS serves file from given file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (c *context) FileFS(file string, filesystem fs.FS) error { - return fsFile(c, file, filesystem) -} - -func fsFile(c Context, file string, filesystem fs.FS) error { - f, err := filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - - fi, _ := f.Stat() - if fi.IsDir() { - file = filepath.ToSlash(filepath.Join(file, indexPage)) // ToSlash is necessary for Windows. fs.Open and os.Open are different in that aspect. - f, err = filesystem.Open(file) - if err != nil { - return ErrNotFound - } - defer f.Close() - if fi, err = f.Stat(); err != nil { - return err - } - } - ff, ok := f.(io.ReadSeeker) - if !ok { - return errors.New("file does not implement io.ReadSeeker") - } - http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), ff) - return nil -} diff --git a/context_fs_test.go b/context_fs_test.go deleted file mode 100644 index 83232ea45..000000000 --- a/context_fs_test.go +++ /dev/null @@ -1,135 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestContext_File(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok, from default file system", - whenFile: "_fixture/images/walle.png", - whenFS: nil, - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "ok, from custom file system", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - if tc.whenFS != nil { - e.Filesystem = tc.whenFS - } - - handler := func(ec Context) error { - return ec.(*context).File(tc.whenFile) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestContext_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenFile string - whenFS fs.FS - expectStatus int - expectStartsWith []byte - expectError string - }{ - { - name: "ok", - whenFile: "walle.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, not existent file", - whenFile: "not.png", - whenFS: os.DirFS("_fixture/images"), - expectStatus: http.StatusOK, - expectStartsWith: nil, - expectError: "code=404, message=Not Found", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - handler := func(ec Context) error { - return ec.(*context).FileFS(tc.whenFile, tc.whenFS) - } - - req := httptest.NewRequest(http.MethodGet, "/match.png", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := handler(c) - - assert.Equal(t, tc.expectStatus, rec.Code) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} diff --git a/context_generic.go b/context_generic.go index f06041bbf..7cf8b296c 100644 --- a/context_generic.go +++ b/context_generic.go @@ -13,9 +13,12 @@ var ErrInvalidKeyType = errors.New("invalid key type") // ContextGet retrieves a value from the context store or ErrNonExistentKey error the key is missing. // Returns ErrInvalidKeyType error if the value is not castable to type T. -func ContextGet[T any](c Context, key string) (T, error) { - val := c.Get(key) - if val == any(nil) { +func ContextGet[T any](c *Context, key string) (T, error) { + c.lock.RLock() + defer c.lock.RUnlock() + + val, ok := c.store[key] + if !ok { var zero T return zero, ErrNonExistentKey } @@ -31,7 +34,7 @@ func ContextGet[T any](c Context, key string) (T, error) { // ContextGetOr retrieves a value from the context store or returns a default value when the key // is missing. Returns ErrInvalidKeyType error if the value is not castable to type T. -func ContextGetOr[T any](c Context, key string, defaultValue T) (T, error) { +func ContextGetOr[T any](c *Context, key string, defaultValue T) (T, error) { typed, err := ContextGet[T](c, key) if err == ErrNonExistentKey { return defaultValue, nil diff --git a/context_generic_test.go b/context_generic_test.go index 9b6d2d04e..ce468ac3e 100644 --- a/context_generic_test.go +++ b/context_generic_test.go @@ -10,8 +10,7 @@ import ( ) func TestContextGetOK(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) + c := NewContext(nil, nil) c.Set("key", int64(123)) @@ -21,8 +20,7 @@ func TestContextGetOK(t *testing.T) { } func TestContextGetNonExistentKey(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) + c := NewContext(nil, nil) c.Set("key", int64(123)) @@ -32,8 +30,7 @@ func TestContextGetNonExistentKey(t *testing.T) { } func TestContextGetInvalidCast(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) + c := NewContext(nil, nil) c.Set("key", int64(123)) @@ -43,8 +40,7 @@ func TestContextGetInvalidCast(t *testing.T) { } func TestContextGetOrOK(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) + c := NewContext(nil, nil) c.Set("key", int64(123)) @@ -54,8 +50,7 @@ func TestContextGetOrOK(t *testing.T) { } func TestContextGetOrNonExistentKey(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) + c := NewContext(nil, nil) c.Set("key", int64(123)) @@ -65,8 +60,7 @@ func TestContextGetOrNonExistentKey(t *testing.T) { } func TestContextGetOrInvalidCast(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) + c := NewContext(nil, nil) c.Set("key", int64(123)) diff --git a/context_test.go b/context_test.go index 1fd89edb4..1ac517cfc 100644 --- a/context_test.go +++ b/context_test.go @@ -8,20 +8,20 @@ import ( "crypto/tls" "encoding/json" "encoding/xml" - "errors" "fmt" "io" - "math" + "io/fs" + "log/slog" "mime/multipart" "net/http" "net/http/httptest" "net/url" + "os" "strings" "testing" "text/template" "time" - "github.com/labstack/gommon/log" "github.com/stretchr/testify/assert" ) @@ -29,13 +29,14 @@ type Template struct { templates *template.Template } -var testUser = user{1, "Jon Snow"} +var testUser = user{ID: 1, Name: "Jon Snow"} func BenchmarkAllocJSONP(b *testing.B) { e := New() + e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() @@ -47,9 +48,10 @@ func BenchmarkAllocJSONP(b *testing.B) { func BenchmarkAllocJSON(b *testing.B) { e := New() + e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() @@ -61,9 +63,10 @@ func BenchmarkAllocJSON(b *testing.B) { func BenchmarkAllocXML(b *testing.B) { e := New() + e.Logger = slog.New(slog.DiscardHandler) req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) b.ResetTimer() b.ReportAllocs() @@ -74,7 +77,7 @@ func BenchmarkAllocXML(b *testing.B) { } func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { - c := context{request: &http.Request{ + c := Context{request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }} for i := 0; i < b.N; i++ { @@ -82,7 +85,7 @@ func BenchmarkRealIPForHeaderXForwardFor(b *testing.B) { } } -func (t *Template) Render(w io.Writer, name string, data interface{}, c Context) error { +func (t *Template) Render(c *Context, w io.Writer, name string, data any) error { return t.templates.ExecuteTemplate(w, name, data) } @@ -91,7 +94,7 @@ func TestContextEcho(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) assert.Equal(t, e, c.Echo()) } @@ -101,7 +104,7 @@ func TestContextRequest(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) assert.NotNil(t, c.Request()) assert.Equal(t, req, c.Request()) @@ -112,7 +115,7 @@ func TestContextResponse(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) assert.NotNil(t, c.Response()) } @@ -122,12 +125,12 @@ func TestContextRenderTemplate(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) tmpl := &Template{ templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")), } - c.echo.Renderer = tmpl + c.Echo().Renderer = tmpl err := c.Render(http.StatusOK, "hello", "Jon Snow") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) @@ -140,57 +143,94 @@ func TestContextRenderErrorsOnNoRenderer(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - c.echo.Renderer = nil + c.Echo().Renderer = nil assert.Error(t, c.Render(http.StatusOK, "hello", "Jon Snow")) } -func TestContextJSON(t *testing.T) { +func TestContextStream(t *testing.T) { e := New() rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) - c := e.NewContext(req, rec).(*context) + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) - err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + r, w := io.Pipe() + go func() { + defer w.Close() + for i := 0; i < 3; i++ { + fmt.Fprintf(w, "data: index %v\n\n", i) + time.Sleep(5 * time.Millisecond) + } + }() + + err := c.Stream(http.StatusOK, "text/event-stream", r) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) - assert.Equal(t, userJSON+"\n", rec.Body.String()) + assert.Equal(t, "text/event-stream", rec.Header().Get(HeaderContentType)) + assert.Equal(t, "data: index 0\n\ndata: index 1\n\ndata: index 2\n\n", rec.Body.String()) } } -func TestContextJSONErrorsOut(t *testing.T) { +func TestContextHTML(t *testing.T) { e := New() rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) - c := e.NewContext(req, rec).(*context) + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) - err := c.JSON(http.StatusOK, make(chan bool)) - assert.EqualError(t, err, "json: unsupported type: chan bool") + err := c.HTML(http.StatusOK, "Hi, Jon Snow") + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) + } } -func TestContextJSONPrettyURL(t *testing.T) { +func TestContextHTMLBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, rec) - err := c.JSON(http.StatusOK, user{1, "Jon Snow"}) + err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow")) + if assert.NoError(t, err) { + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) + assert.Equal(t, "Hi, Jon Snow", rec.Body.String()) + } +} + +func TestContextJSON(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec) + + err := c.JSON(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) - assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } +func TestContextJSONErrorsOut(t *testing.T) { + e := New() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) + c := e.NewContext(req, rec) + + err := c.JSON(http.StatusOK, make(chan bool)) + assert.EqualError(t, err, "json: unsupported type: chan bool") +} + func TestContextJSONPretty(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - err := c.JSONPretty(http.StatusOK, user{1, "Jon Snow"}, " ") + err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) @@ -202,16 +242,16 @@ func TestContextJSONWithEmptyIntent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - u := user{1, "Jon Snow"} + u := user{ID: 1, Name: "Jon Snow"} emptyIndent := "" buf := new(bytes.Buffer) enc := json.NewEncoder(buf) enc.SetIndent(emptyIndent, emptyIndent) _ = enc.Encode(u) - err := c.json(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) + err := c.JSONPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) @@ -223,10 +263,10 @@ func TestContextJSONP(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) callback := "callback" - err := c.JSONP(http.StatusOK, callback, user{1, "Jon Snow"}) + err := c.JSONP(http.StatusOK, callback, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationJavaScriptCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -238,9 +278,9 @@ func TestContextJSONBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - data, err := json.Marshal(user{1, "Jon Snow"}) + data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) assert.NoError(t, err) err = c.JSONBlob(http.StatusOK, data) if assert.NoError(t, err) { @@ -254,10 +294,10 @@ func TestContextJSONPBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) callback := "callback" - data, err := json.Marshal(user{1, "Jon Snow"}) + data, err := json.Marshal(user{ID: 1, Name: "Jon Snow"}) assert.NoError(t, err) err = c.JSONPBlob(http.StatusOK, callback, data) if assert.NoError(t, err) { @@ -271,9 +311,9 @@ func TestContextXML(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - err := c.XML(http.StatusOK, user{1, "Jon Snow"}) + err := c.XML(http.StatusOK, user{ID: 1, Name: "Jon Snow"}) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -281,27 +321,13 @@ func TestContextXML(t *testing.T) { } } -func TestContextXMLPrettyURL(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.XML(http.StatusOK, user{1, "Jon Snow"}) - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(t, xml.Header+userXMLPretty, rec.Body.String()) - } -} - func TestContextXMLPretty(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - err := c.XMLPretty(http.StatusOK, user{1, "Jon Snow"}, " ") + err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -313,9 +339,9 @@ func TestContextXMLBlob(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - data, err := xml.Marshal(user{1, "Jon Snow"}) + data, err := xml.Marshal(user{ID: 1, Name: "Jon Snow"}) assert.NoError(t, err) err = c.XMLBlob(http.StatusOK, data) if assert.NoError(t, err) { @@ -329,16 +355,16 @@ func TestContextXMLWithEmptyIntent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) - u := user{1, "Jon Snow"} + u := user{ID: 1, Name: "Jon Snow"} emptyIndent := "" buf := new(bytes.Buffer) enc := xml.NewEncoder(buf) enc.Indent(emptyIndent, emptyIndent) _ = enc.Encode(u) - err := c.xml(http.StatusOK, user{1, "Jon Snow"}, emptyIndent) + err := c.XMLPretty(http.StatusOK, user{ID: 1, Name: "Jon Snow"}, emptyIndent) if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, MIMEApplicationXMLCharsetUTF8, rec.Header().Get(HeaderContentType)) @@ -346,71 +372,17 @@ func TestContextXMLWithEmptyIntent(t *testing.T) { } } -type responseWriterErr struct { -} - -func (responseWriterErr) Header() http.Header { - return http.Header{} -} - -func (responseWriterErr) Write([]byte) (int, error) { - return 0, errors.New("responseWriterErr") -} - -func (responseWriterErr) WriteHeader(statusCode int) { -} - -func TestContextXMLError(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - c.response.Writer = responseWriterErr{} - - err := c.XML(http.StatusOK, make(chan bool)) - assert.EqualError(t, err, "responseWriterErr") -} - -func TestContextString(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.String(http.StatusOK, "Hello, World!") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMETextPlainCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(t, "Hello, World!", rec.Body.String()) - } -} - -func TestContextHTML(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - err := c.HTML(http.StatusOK, "Hello, World!") - if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, MIMETextHTMLCharsetUTF8, rec.Header().Get(HeaderContentType)) - assert.Equal(t, "Hello, World!", rec.Body.String()) - } -} - -func TestContextStream(t *testing.T) { +func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) + err := c.JSON(http.StatusCreated, user{ID: 1, Name: "Jon Snow"}) - r := strings.NewReader("response from a stream") - err := c.Stream(http.StatusOK, "application/octet-stream", r) if assert.NoError(t, err) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "application/octet-stream", rec.Header().Get(HeaderContentType)) - assert.Equal(t, "response from a stream", rec.Body.String()) + assert.Equal(t, http.StatusCreated, rec.Code) + assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) + assert.Equal(t, userJSON+"\n", rec.Body.String()) } } @@ -436,7 +408,7 @@ func TestContextAttachment(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) err := c.Attachment("_fixture/images/walle.png", tc.whenName) if assert.NoError(t, err) { @@ -471,7 +443,7 @@ func TestContextInline(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) err := c.Inline("_fixture/images/walle.png", tc.whenName) if assert.NoError(t, err) { @@ -488,69 +460,12 @@ func TestContextNoContent(t *testing.T) { e := New() rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) c.NoContent(http.StatusOK) assert.Equal(t, http.StatusOK, rec.Code) } -func TestContextError(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/?pretty", nil) - c := e.NewContext(req, rec).(*context) - - c.Error(errors.New("error")) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.True(t, c.Response().Committed) -} - -func TestContextReset(t *testing.T) { - e := New() - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, rec).(*context) - - c.SetParamNames("foo") - c.SetParamValues("bar") - c.Set("foe", "ban") - c.query = url.Values(map[string][]string{"fon": {"baz"}}) - - c.Reset(req, httptest.NewRecorder()) - - assert.Len(t, c.ParamValues(), 0) - assert.Len(t, c.ParamNames(), 0) - assert.Len(t, c.Path(), 0) - assert.Len(t, c.QueryParams(), 0) - assert.Len(t, c.store, 0) -} - -func TestContext_JSON_CommitsCustomResponseCode(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - err := c.JSON(http.StatusCreated, user{1, "Jon Snow"}) - - if assert.NoError(t, err) { - assert.Equal(t, http.StatusCreated, rec.Code) - assert.Equal(t, MIMEApplicationJSON, rec.Header().Get(HeaderContentType)) - assert.Equal(t, userJSON+"\n", rec.Body.String()) - } -} - -func TestContext_JSON_DoesntCommitResponseCodePrematurely(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) - err := c.JSON(http.StatusCreated, map[string]float64{"a": math.NaN()}) - - if assert.Error(t, err) { - assert.False(t, c.response.Committed) - } -} - func TestContextCookie(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -559,7 +474,7 @@ func TestContextCookie(t *testing.T) { req.Header.Add(HeaderCookie, theme) req.Header.Add(HeaderCookie, user) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) // Read single cookie, err := c.Cookie("theme") @@ -596,107 +511,237 @@ func TestContextCookie(t *testing.T) { assert.Contains(t, rec.Header().Get(HeaderSetCookie), "HttpOnly") } -func TestContextPath(t *testing.T) { - e := New() - r := e.Router() +func TestContext_PathValues(t *testing.T) { + var testCases = []struct { + name string + given PathValues + expect PathValues + }{ + { + name: "param exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + expect: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + }, + { + name: "params is empty", + given: PathValues{}, + expect: PathValues{}, + }, + } - handler := func(c Context) error { return c.String(http.StatusOK, "OK") } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - r.Add(http.MethodGet, "/users/:id", handler) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1", c) + c.SetPathValues(tc.given) + + assert.EqualValues(t, tc.expect, c.PathValues()) + }) + } +} + +func TestContext_PathParam(t *testing.T) { + var testCases = []struct { + name string + given PathValues + whenParamName string + expect string + }{ + { + name: "param exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "multiple same param values exists - return first", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "uid", Value: "202"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + expect: "101", + }, + { + name: "param does not exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - assert.Equal(t, "/users/:id", c.Path()) + c.SetPathValues(tc.given) - r.Add(http.MethodGet, "/users/:uid/files/:fid", handler) - c = e.NewContext(nil, nil) - r.Find(http.MethodGet, "/users/1/files/1", c) - assert.Equal(t, "/users/:uid/files/:fid", c.Path()) + assert.EqualValues(t, tc.expect, c.Param(tc.whenParamName)) + }) + } } -func TestContextPathParam(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - c := e.NewContext(req, nil) +func TestContext_PathParamDefault(t *testing.T) { + var testCases = []struct { + name string + given PathValues + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "param exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "101", + }, + { + name: "param exists and is empty", + given: PathValues{ + {Name: "uid", Value: ""}, + {Name: "fid", Value: "501"}, + }, + whenParamName: "uid", + whenDefaultValue: "999", + expect: "", // <-- this is different from QueryParamOr behaviour + }, + { + name: "param does not exists", + given: PathValues{ + {Name: "uid", Value: "101"}, + }, + whenParamName: "nope", + whenDefaultValue: "999", + expect: "999", + }, + } - // ParamNames - c.SetParamNames("uid", "fid") - assert.EqualValues(t, []string{"uid", "fid"}, c.ParamNames()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + c := e.NewContext(req, nil) - // ParamValues - c.SetParamValues("101", "501") - assert.EqualValues(t, []string{"101", "501"}, c.ParamValues()) + c.SetPathValues(tc.given) - // Param - assert.Equal(t, "501", c.Param("fid")) - assert.Equal(t, "", c.Param("undefined")) + assert.EqualValues(t, tc.expect, c.ParamOr(tc.whenParamName, tc.whenDefaultValue)) + }) + } } -func TestContextGetAndSetParam(t *testing.T) { - e := New() - r := e.Router() - r.Add(http.MethodGet, "/:foo", func(Context) error { return nil }) - req := httptest.NewRequest(http.MethodGet, "/:foo", nil) - c := e.NewContext(req, nil) - c.SetParamNames("foo") - - // round-trip param values with modification - paramVals := c.ParamValues() - assert.EqualValues(t, []string{""}, c.ParamValues()) - paramVals[0] = "bar" - c.SetParamValues(paramVals...) - assert.EqualValues(t, []string{"bar"}, c.ParamValues()) - - // shouldn't explode during Reset() afterwards! - assert.NotPanics(t, func() { - c.Reset(nil, nil) +func TestContextGetAndSetPathValuesMutability(t *testing.T) { + t.Run("c.PathValues() does not return copy and modifying raw slice mutates value in context", func(t *testing.T) { + e := New() + e.contextPathParamAllocSize.Store(1) + + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + + params := PathValues{{Name: "foo", Value: "101"}} + c.SetPathValues(params) + + // round-trip param values with modification + paramVals := c.PathValues() + assert.Equal(t, params, c.PathValues()) + + // PathValues() does not return copy and modifying raw slice mutates value in context + paramVals[0] = PathValue{Name: "xxx", Value: "yyy"} + assert.Equal(t, PathValues{PathValue{Name: "xxx", Value: "yyy"}}, c.PathValues()) }) -} -func TestContextSetParamNamesEchoMaxParam(t *testing.T) { - e := New() - assert.Equal(t, 0, *e.maxParam) - - expectedOneParam := []string{"one"} - expectedTwoParams := []string{"one", "two"} - expectedThreeParams := []string{"one", "two", ""} - - { - c := e.AcquireContext() - c.SetParamNames("1", "2") - c.SetParamValues(expectedTwoParams...) - assert.Equal(t, 0, *e.maxParam) // has not been changed - assert.EqualValues(t, expectedTwoParams, c.ParamValues()) - e.ReleaseContext(c) - } + t.Run("calling SetPathValues with bigger size changes capacity in context", func(t *testing.T) { + e := New() + e.contextPathParamAllocSize.Store(1) - { - c := e.AcquireContext() - c.SetParamNames("1", "2", "3") - c.SetParamValues(expectedThreeParams...) - assert.Equal(t, 0, *e.maxParam) // has not been changed - assert.EqualValues(t, expectedThreeParams, c.ParamValues()) - e.ReleaseContext(c) - } + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + // increase path param capacity in context + pathValues := PathValues{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + c.SetPathValues(pathValues) + assert.Equal(t, pathValues, c.PathValues()) - { // values is always same size as names length - c := e.NewContext(nil, nil) - c.SetParamValues([]string{"one", "two"}...) // more values than names should be ok - c.SetParamNames("1") - assert.Equal(t, 0, *e.maxParam) // has not been changed - assert.EqualValues(t, expectedOneParam, c.ParamValues()) - } + // shouldn't explode during Reset() afterwards! + assert.NotPanics(t, func() { + c.Reset(nil, nil) + }) + assert.Equal(t, PathValues{}, c.PathValues()) + assert.Len(t, *c.pathValues, 0) + assert.Equal(t, 2, cap(*c.pathValues)) + }) - e.GET("/:id", handlerFunc) - assert.Equal(t, 1, *e.maxParam) // has not been changed + t.Run("calling SetPathValues with smaller size slice does not change capacity in context", func(t *testing.T) { + e := New() - { - c := e.NewContext(nil, nil) - c.SetParamValues([]string{"one", "two"}...) - c.SetParamNames("1") - assert.Equal(t, 1, *e.maxParam) // has not been changed - assert.EqualValues(t, expectedOneParam, c.ParamValues()) + req := httptest.NewRequest(http.MethodGet, "/:foo", nil) + c := e.NewContext(req, nil) + c.pathValues = &PathValues{ + {Name: "aaa", Value: "bbb"}, + {Name: "ccc", Value: "ddd"}, + } + + pathValues := PathValues{ + {Name: "aaa", Value: "bbb"}, + } + // given pathValues slice is smaller. this should not decrease c.pathValues capacity + c.SetPathValues(pathValues) + assert.Equal(t, pathValues, c.PathValues()) + + // shouldn't explode during Reset() afterwards! + assert.NotPanics(t, func() { + c.Reset(nil, nil) + }) + assert.Equal(t, PathValues{}, c.PathValues()) + assert.Len(t, *c.pathValues, 0) + assert.Equal(t, 2, cap(*c.pathValues)) + }) + +} + +// Issue #1655 +func TestContext_SetParamNamesShouldNotModifyPathValuesCapacity(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) + + assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) + expectedTwoParams := PathValues{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + } + c.SetPathValues(expectedTwoParams) + assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) + assert.Equal(t, expectedTwoParams, c.PathValues()) + + expectedThreeParams := PathValues{ + {Name: "1", Value: "one"}, + {Name: "2", Value: "two"}, + {Name: "3", Value: "three"}, } + c.SetPathValues(expectedThreeParams) + assert.Equal(t, int32(0), e.contextPathParamAllocSize.Load()) + assert.Equal(t, expectedThreeParams, c.PathValues()) } func TestContextFormValue(t *testing.T) { @@ -713,41 +758,151 @@ func TestContextFormValue(t *testing.T) { assert.Equal(t, "Jon Snow", c.FormValue("name")) assert.Equal(t, "jon@labstack.com", c.FormValue("email")) - // FormParams - params, err := c.FormParams() + // FormValueOr + assert.Equal(t, "Jon Snow", c.FormValueOr("name", "nope")) + assert.Equal(t, "default", c.FormValueOr("missing", "default")) + + // FormValues + values, err := c.FormValues() if assert.NoError(t, err) { assert.Equal(t, url.Values{ "name": []string{"Jon Snow"}, "email": []string{"jon@labstack.com"}, - }, params) + }, values) } // Multipart FormParams error req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(f.Encode())) req.Header.Add(HeaderContentType, MIMEMultipartForm) c = e.NewContext(req, nil) - params, err = c.FormParams() - assert.Nil(t, params) + values, err = c.FormValues() + assert.Nil(t, values) assert.Error(t, err) } -func TestContextQueryParam(t *testing.T) { - q := make(url.Values) - q.Set("name", "Jon Snow") - q.Set("email", "jon@labstack.com") - req := httptest.NewRequest(http.MethodGet, "/?"+q.Encode(), nil) - e := New() - c := e.NewContext(req, nil) +func TestContext_QueryParams(t *testing.T) { + var testCases = []struct { + expect url.Values + name string + givenURL string + }{ + { + name: "multiple values in url", + givenURL: "/?test=1&test=2&email=jon%40labstack.com", + expect: url.Values{ + "test": []string{"1", "2"}, + "email": []string{"jon@labstack.com"}, + }, + }, + { + name: "single value in url", + givenURL: "/?nope=1", + expect: url.Values{ + "nope": []string{"1"}, + }, + }, + { + name: "no query params in url", + givenURL: "/?", + expect: url.Values{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParams()) + }) + } +} + +func TestContext_QueryParam(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + expect: "1", + }, + { + name: "multiple values exists in url", + givenURL: "/?test=9&test=8", + whenParamName: "test", + expect: "9", // <-- first value in returned + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + expect: "", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + expect: "", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) - // QueryParam - assert.Equal(t, "Jon Snow", c.QueryParam("name")) - assert.Equal(t, "jon@labstack.com", c.QueryParam("email")) + assert.Equal(t, tc.expect, c.QueryParam(tc.whenParamName)) + }) + } +} + +func TestContext_QueryParamDefault(t *testing.T) { + var testCases = []struct { + name string + givenURL string + whenParamName string + whenDefaultValue string + expect string + }{ + { + name: "value exists in url", + givenURL: "/?test=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "1", + }, + { + name: "value does not exists in url", + givenURL: "/?nope=1", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + { + name: "value is empty in url", + givenURL: "/?test=", + whenParamName: "test", + whenDefaultValue: "999", + expect: "999", + }, + } - // QueryParams - assert.Equal(t, url.Values{ - "name": []string{"Jon Snow"}, - "email": []string{"jon@labstack.com"}, - }, c.QueryParams()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + e := New() + c := e.NewContext(req, nil) + + assert.Equal(t, tc.expect, c.QueryParamOr(tc.whenParamName, tc.whenDefaultValue)) + }) + } } func TestContextFormFile(t *testing.T) { @@ -808,16 +963,47 @@ func TestContextRedirect(t *testing.T) { assert.Error(t, c.Redirect(310, "http://labstack.github.io/echo")) } -func TestContextStore(t *testing.T) { - var c Context = new(context) - c.Set("name", "Jon Snow") - assert.Equal(t, "Jon Snow", c.Get("name")) +func TestContextGet(t *testing.T) { + var testCases = []struct { + name string + given any + whenKey string + expect any + }{ + { + name: "ok, value exist", + given: "Jon Snow", + whenKey: "key", + expect: "Jon Snow", + }, + { + name: "ok, value does not exist", + given: "Jon Snow", + whenKey: "nope", + expect: nil, + }, + { + name: "ok, value is nil value", + given: []byte(nil), + whenKey: "key", + expect: []byte(nil), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var c = new(Context) + c.Set("key", tc.given) + + v := c.Get(tc.whenKey) + assert.Equal(t, tc.expect, v) + }) + } } func BenchmarkContext_Store(b *testing.B) { e := &Echo{} - c := &context{ + c := &Context{ echo: e, } @@ -829,45 +1015,9 @@ func BenchmarkContext_Store(b *testing.B) { } } -func TestContextHandler(t *testing.T) { - e := New() - r := e.Router() - b := new(bytes.Buffer) - - r.Add(http.MethodGet, "/handler", func(Context) error { - _, err := b.Write([]byte("handler")) - return err - }) - c := e.NewContext(nil, nil) - r.Find(http.MethodGet, "/handler", c) - err := c.Handler()(c) - assert.Equal(t, "handler", b.String()) - assert.NoError(t, err) -} - -func TestContext_SetHandler(t *testing.T) { - var c Context = new(context) - - assert.Nil(t, c.Handler()) - - c.SetHandler(func(c Context) error { - return nil - }) - assert.NotNil(t, c.Handler()) -} - -func TestContext_Path(t *testing.T) { - path := "/pa/th" - - var c Context = new(context) - - c.SetPath(path) - assert.Equal(t, path, c.Path()) -} - type validator struct{} -func (*validator) Validate(i interface{}) error { +func (*validator) Validate(i any) error { return nil } @@ -893,7 +1043,7 @@ func TestContext_QueryString(t *testing.T) { } func TestContext_Request(t *testing.T) { - var c Context = new(context) + var c = new(Context) assert.Nil(t, c.Request()) @@ -905,11 +1055,11 @@ func TestContext_Request(t *testing.T) { func TestContext_Scheme(t *testing.T) { tests := []struct { - c Context + c *Context s string }{ { - &context{ + &Context{ request: &http.Request{ TLS: &tls.ConnectionState{}, }, @@ -917,7 +1067,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedProto: []string{"https"}}, }, @@ -925,7 +1075,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedProtocol: []string{"http"}}, }, @@ -933,7 +1083,7 @@ func TestContext_Scheme(t *testing.T) { "http", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedSsl: []string{"on"}}, }, @@ -941,7 +1091,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXUrlScheme: []string{"https"}}, }, @@ -949,7 +1099,7 @@ func TestContext_Scheme(t *testing.T) { "https", }, { - &context{ + &Context{ request: &http.Request{}, }, "http", @@ -963,11 +1113,11 @@ func TestContext_Scheme(t *testing.T) { func TestContext_IsWebSocket(t *testing.T) { tests := []struct { - c Context + c *Context ws assert.BoolAssertionFunc }{ { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"websocket"}}, }, @@ -975,7 +1125,7 @@ func TestContext_IsWebSocket(t *testing.T) { assert.True, }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"Websocket"}}, }, @@ -983,13 +1133,13 @@ func TestContext_IsWebSocket(t *testing.T) { assert.True, }, { - &context{ + &Context{ request: &http.Request{}, }, assert.False, }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderUpgrade: []string{"other"}}, }, @@ -1014,32 +1164,16 @@ func TestContext_Bind(t *testing.T) { req.Header.Add(HeaderContentType, MIMEApplicationJSON) err := c.Bind(u) assert.NoError(t, err) - assert.Equal(t, &user{1, "Jon Snow"}, u) -} - -func TestContext_Logger(t *testing.T) { - e := New() - c := e.NewContext(nil, nil) - - log1 := c.Logger() - assert.NotNil(t, log1) - - log2 := log.New("echo2") - c.SetLogger(log2) - assert.Equal(t, log2, c.Logger()) - - // Resetting the context returns the initial logger - c.Reset(nil, nil) - assert.Equal(t, log1, c.Logger()) + assert.Equal(t, &user{ID: 1, Name: "Jon Snow"}, u) } func TestContext_RealIP(t *testing.T) { tests := []struct { - c Context + c *Context s string }{ { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1, 127.0.1.1, "}}, }, @@ -1047,7 +1181,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1,127.0.1.1"}}, }, @@ -1055,7 +1189,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"127.0.0.1"}}, }, @@ -1063,7 +1197,7 @@ func TestContext_RealIP(t *testing.T) { "127.0.0.1", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348], 2001:db8::1, "}}, }, @@ -1071,7 +1205,7 @@ func TestContext_RealIP(t *testing.T) { "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"[2001:db8:85a3:8d3:1319:8a2e:370:7348],[2001:db8::1]"}}, }, @@ -1079,7 +1213,7 @@ func TestContext_RealIP(t *testing.T) { "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{HeaderXForwardedFor: []string{"2001:db8:85a3:8d3:1319:8a2e:370:7348"}}, }, @@ -1087,7 +1221,7 @@ func TestContext_RealIP(t *testing.T) { "2001:db8:85a3:8d3:1319:8a2e:370:7348", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"192.168.0.1"}, @@ -1097,7 +1231,7 @@ func TestContext_RealIP(t *testing.T) { "192.168.0.1", }, { - &context{ + &Context{ request: &http.Request{ Header: http.Header{ "X-Real-Ip": []string{"[2001:db8::1]"}, @@ -1108,7 +1242,7 @@ func TestContext_RealIP(t *testing.T) { }, { - &context{ + &Context{ request: &http.Request{ RemoteAddr: "89.89.89.89:1654", }, @@ -1121,3 +1255,170 @@ func TestContext_RealIP(t *testing.T) { assert.Equal(t, tt.s, tt.c.RealIP()) } } + +func TestContext_File(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenFile string + expectError string + expectStartsWith []byte + expectStatus int + }{ + { + name: "ok, from default file system", + whenFile: "_fixture/images/walle.png", + whenFS: nil, + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "ok, from custom file system", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + if tc.whenFS != nil { + e.Filesystem = tc.whenFS + } + + handler := func(ec *Context) error { + return ec.File(tc.whenFile) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestContext_FileFS(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenFile string + expectError string + expectStartsWith []byte + expectStatus int + }{ + { + name: "ok", + whenFile: "walle.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, not existent file", + whenFile: "not.png", + whenFS: os.DirFS("_fixture/images"), + expectStatus: http.StatusOK, + expectStartsWith: nil, + expectError: "Not Found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + handler := func(ec *Context) error { + return ec.FileFS(tc.whenFile, tc.whenFS) + } + + req := httptest.NewRequest(http.MethodGet, "/match.png", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler(c) + + assert.Equal(t, tc.expectStatus, rec.Code) + if tc.expectError != "" { + assert.EqualError(t, err, tc.expectError) + } else { + assert.NoError(t, err) + } + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestLogger(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) + + log1 := c.Logger() + assert.NotNil(t, log1) + assert.Equal(t, e.Logger, log1) + + customLogger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + c.SetLogger(customLogger) + assert.Equal(t, customLogger, c.Logger()) + + // Resetting the context returns the initial Echo logger + c.Reset(nil, nil) + assert.Equal(t, e.Logger, c.Logger()) +} + +func TestRouteInfo(t *testing.T) { + e := New() + c := e.NewContext(nil, nil) + + orgRI := RouteInfo{ + Name: "root", + Method: http.MethodGet, + Path: "/*", + Parameters: []string{"*"}, + } + c.route = &orgRI + ri := c.RouteInfo() + assert.Equal(t, orgRI, ri) + + // Test mutability when middlewares start to change things + + // RouteInfo inside context will not be affected when returned instance is changed + expect := orgRI.Clone() + ri.Path = "changed" + ri.Parameters[0] = "changed" + assert.Equal(t, expect, c.RouteInfo()) + + // RouteInfo inside context will not be affected when returned instance is changed + expect = c.RouteInfo() + orgRI.Name = "changed" + assert.NotEqual(t, expect, c.RouteInfo()) +} diff --git a/echo.go b/echo.go index ae2283f60..22c27a43f 100644 --- a/echo.go +++ b/echo.go @@ -9,30 +9,33 @@ Example: package main import ( - "net/http" + "log/slog" + "net/http" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/middleware" ) // Handler - func hello(c echo.Context) error { - return c.String(http.StatusOK, "Hello, World!") + func hello(c *echo.Context) error { + return c.String(http.StatusOK, "Hello, World!") } func main() { - // Echo instance - e := echo.New() + // Echo instance + e := echo.New() - // Middleware - e.Use(middleware.Logger()) - e.Use(middleware.Recover()) + // Middleware + e.Use(middleware.RequestLogger()) + e.Use(middleware.Recover()) - // Routes - e.GET("/", hello) + // Routes + e.GET("/", hello) - // Start server - e.Logger.Fatal(e.Start(":1323")) + // Start server + if err := e.Start(":8080"); err != nil { + slog.Error("failed to start server", "error", err) + } } Learn more at https://echo.labstack.com @@ -41,126 +44,80 @@ package echo import ( stdContext "context" - "crypto/tls" "encoding/json" "errors" "fmt" - stdLog "log" - "net" + "io/fs" + "log/slog" "net/http" + "net/url" "os" - "reflect" - "runtime" + "os/signal" + "path/filepath" + "strings" "sync" - "time" - - "github.com/labstack/gommon/color" - "github.com/labstack/gommon/log" - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" + "sync/atomic" + "syscall" ) // Echo is the top-level framework instance. // // Goroutine safety: Do not mutate Echo instance fields after server has started. Accessing these // fields from handlers/middlewares and changing field values at the same time leads to data-races. -// Adding new routes after the server has been started is also not safe! +// Same rule applies to adding new routes after server has been started - Adding a route is not Goroutine safe action. type Echo struct { - filesystem - common - // startupMutex is mutex to lock Echo instance access during server configuration and startup. Useful for to get - // listener address info (on which interface/port was listener bound) without having data races. - startupMutex sync.RWMutex - colorer *color.Color - - // premiddleware are middlewares that are run before routing is done. In case a pre-middleware returns - // an error the router is not executed and the request will end up in the global error handler. - premiddleware []MiddlewareFunc - middleware []MiddlewareFunc - maxParam *int - router *Router - routers map[string]*Router - pool sync.Pool - - StdLogger *stdLog.Logger - Server *http.Server - TLSServer *http.Server - Listener net.Listener - TLSListener net.Listener - AutoTLSManager autocert.Manager - HTTPErrorHandler HTTPErrorHandler + serveHTTPFunc func(http.ResponseWriter, *http.Request) + Binder Binder - JSONSerializer JSONSerializer - Validator Validator + Filesystem fs.FS Renderer Renderer - Logger Logger + Validator Validator + JSONSerializer JSONSerializer IPExtractor IPExtractor - ListenerNetwork string + OnAddRoute func(route Route) error + HTTPErrorHandler HTTPErrorHandler + Logger *slog.Logger - // OnAddRouteHandler is called when Echo adds new route to specific host router. - OnAddRouteHandler func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) - DisableHTTP2 bool - Debug bool - HideBanner bool - HidePort bool -} + contextPool sync.Pool -// Route contains a handler and information for matching against requests. -type Route struct { - Method string `json:"method"` - Path string `json:"path"` - Name string `json:"name"` -} + router Router -// HTTPError represents an error that occurred while handling a request. -type HTTPError struct { - Internal error `json:"-"` // Stores the error returned by an external dependency - Message interface{} `json:"message"` - Code int `json:"-"` -} - -// MiddlewareFunc defines a function to process middleware. -type MiddlewareFunc func(next HandlerFunc) HandlerFunc + // premiddleware are middlewares that are called before routing is done + premiddleware []MiddlewareFunc -// HandlerFunc defines a function to serve HTTP requests. -type HandlerFunc func(c Context) error + // middleware are middlewares that are called after routing is done and before handler is called + middleware []MiddlewareFunc -// HTTPErrorHandler is a centralized HTTP error handler. -type HTTPErrorHandler func(err error, c Context) + contextPathParamAllocSize atomic.Int32 -// Validator is the interface that wraps the Validate function. -type Validator interface { - Validate(i interface{}) error + // formParseMaxMemory is passed to Context for multipart form parsing (See http.Request.ParseMultipartForm) + formParseMaxMemory int64 } // JSONSerializer is the interface that encodes and decodes JSON to and from interfaces. type JSONSerializer interface { - Serialize(c Context, i interface{}, indent string) error - Deserialize(c Context, i interface{}) error + Serialize(c *Context, target any, indent string) error + Deserialize(c *Context, target any) error } -// Map defines a generic map of type `map[string]interface{}`. -type Map map[string]interface{} +// HTTPErrorHandler is a centralized HTTP error handler. +type HTTPErrorHandler func(c *Context, err error) -// Common struct for Echo & Group. -type common struct{} +// HandlerFunc defines a function to serve HTTP requests. +type HandlerFunc func(c *Context) error -// HTTP methods -// NOTE: Deprecated, please use the stdlib constants directly instead. -const ( - CONNECT = http.MethodConnect - DELETE = http.MethodDelete - GET = http.MethodGet - HEAD = http.MethodHead - OPTIONS = http.MethodOptions - PATCH = http.MethodPatch - POST = http.MethodPost - // PROPFIND = "PROPFIND" - PUT = http.MethodPut - TRACE = http.MethodTrace -) +// MiddlewareFunc defines a function to process middleware. +type MiddlewareFunc func(next HandlerFunc) HandlerFunc + +// MiddlewareConfigurator defines interface for creating middleware handlers with possibility to return configuration errors instead of panicking. +type MiddlewareConfigurator interface { + ToMiddleware() (MiddlewareFunc, error) +} + +// Validator is the interface that wraps the Validate function. +type Validator interface { + Validate(i any) error +} // MIME types const ( @@ -169,7 +126,7 @@ const ( // Deprecated: Please use MIMEApplicationJSON instead. JSON should be encoded using UTF-8 by default. // No "charset" parameter is defined for this registration. // Adding one really has no effect on compliant recipients. - // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1 + // See RFC 8259, section 8.1. https://datatracker.ietf.org/doc/html/rfc8259#section-8.1n" MIMEApplicationJSONCharsetUTF8 = MIMEApplicationJSON + "; " + charsetUTF8 MIMEApplicationJavaScript = "application/javascript" MIMEApplicationJavaScriptCharsetUTF8 = MIMEApplicationJavaScript + "; " + charsetUTF8 @@ -196,6 +153,9 @@ const ( REPORT = "REPORT" // RouteNotFound is special method type for routes handling "route not found" (404) cases RouteNotFound = "echo_route_not_found" + // RouteAny is special method type that matches any HTTP method in request. Any has lower + // priority that other methods that have been registered with Router to that path. + RouteAny = "echo_route_any" ) // Headers @@ -256,7 +216,7 @@ const ( HeaderXFrameOptions = "X-Frame-Options" HeaderContentSecurityPolicy = "Content-Security-Policy" HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" - HeaderXCSRFToken = "X-CSRF-Token" + HeaderXCSRFToken = "X-CSRF-Token" // #nosec G101 HeaderReferrerPolicy = "Referrer-Policy" // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's @@ -265,273 +225,255 @@ const ( HeaderSecFetchSite = "Sec-Fetch-Site" ) -const ( - // Version of Echo - Version = "4.15.0" - website = "https://echo.labstack.com" - // http://patorjk.com/software/taag/#p=display&f=Small%20Slant&t=Echo - banner = ` - ____ __ - / __/___/ / ___ - / _// __/ _ \/ _ \ -/___/\__/_//_/\___/ %s -High performance, minimalist Go web framework -%s -____________________________________O/_______ - O\ -` -) +// Config is configuration for NewWithConfig function +type Config struct { + // Logger is the slog logger instance used for application-wide structured logging. + // If not set, a default TextHandler writing to stdout is created. + Logger *slog.Logger -var methods = [...]string{ - http.MethodConnect, - http.MethodDelete, - http.MethodGet, - http.MethodHead, - http.MethodOptions, - http.MethodPatch, - http.MethodPost, - PROPFIND, - http.MethodPut, - http.MethodTrace, - REPORT, -} - -// Errors -var ( - ErrBadRequest = NewHTTPError(http.StatusBadRequest) // HTTP 400 Bad Request - ErrUnauthorized = NewHTTPError(http.StatusUnauthorized) // HTTP 401 Unauthorized - ErrPaymentRequired = NewHTTPError(http.StatusPaymentRequired) // HTTP 402 Payment Required - ErrForbidden = NewHTTPError(http.StatusForbidden) // HTTP 403 Forbidden - ErrNotFound = NewHTTPError(http.StatusNotFound) // HTTP 404 Not Found - ErrMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed) // HTTP 405 Method Not Allowed - ErrNotAcceptable = NewHTTPError(http.StatusNotAcceptable) // HTTP 406 Not Acceptable - ErrProxyAuthRequired = NewHTTPError(http.StatusProxyAuthRequired) // HTTP 407 Proxy AuthRequired - ErrRequestTimeout = NewHTTPError(http.StatusRequestTimeout) // HTTP 408 Request Timeout - ErrConflict = NewHTTPError(http.StatusConflict) // HTTP 409 Conflict - ErrGone = NewHTTPError(http.StatusGone) // HTTP 410 Gone - ErrLengthRequired = NewHTTPError(http.StatusLengthRequired) // HTTP 411 Length Required - ErrPreconditionFailed = NewHTTPError(http.StatusPreconditionFailed) // HTTP 412 Precondition Failed - ErrStatusRequestEntityTooLarge = NewHTTPError(http.StatusRequestEntityTooLarge) // HTTP 413 Payload Too Large - ErrRequestURITooLong = NewHTTPError(http.StatusRequestURITooLong) // HTTP 414 URI Too Long - ErrUnsupportedMediaType = NewHTTPError(http.StatusUnsupportedMediaType) // HTTP 415 Unsupported Media Type - ErrRequestedRangeNotSatisfiable = NewHTTPError(http.StatusRequestedRangeNotSatisfiable) // HTTP 416 Range Not Satisfiable - ErrExpectationFailed = NewHTTPError(http.StatusExpectationFailed) // HTTP 417 Expectation Failed - ErrTeapot = NewHTTPError(http.StatusTeapot) // HTTP 418 I'm a teapot - ErrMisdirectedRequest = NewHTTPError(http.StatusMisdirectedRequest) // HTTP 421 Misdirected Request - ErrUnprocessableEntity = NewHTTPError(http.StatusUnprocessableEntity) // HTTP 422 Unprocessable Entity - ErrLocked = NewHTTPError(http.StatusLocked) // HTTP 423 Locked - ErrFailedDependency = NewHTTPError(http.StatusFailedDependency) // HTTP 424 Failed Dependency - ErrTooEarly = NewHTTPError(http.StatusTooEarly) // HTTP 425 Too Early - ErrUpgradeRequired = NewHTTPError(http.StatusUpgradeRequired) // HTTP 426 Upgrade Required - ErrPreconditionRequired = NewHTTPError(http.StatusPreconditionRequired) // HTTP 428 Precondition Required - ErrTooManyRequests = NewHTTPError(http.StatusTooManyRequests) // HTTP 429 Too Many Requests - ErrRequestHeaderFieldsTooLarge = NewHTTPError(http.StatusRequestHeaderFieldsTooLarge) // HTTP 431 Request Header Fields Too Large - ErrUnavailableForLegalReasons = NewHTTPError(http.StatusUnavailableForLegalReasons) // HTTP 451 Unavailable For Legal Reasons - ErrInternalServerError = NewHTTPError(http.StatusInternalServerError) // HTTP 500 Internal Server Error - ErrNotImplemented = NewHTTPError(http.StatusNotImplemented) // HTTP 501 Not Implemented - ErrBadGateway = NewHTTPError(http.StatusBadGateway) // HTTP 502 Bad Gateway - ErrServiceUnavailable = NewHTTPError(http.StatusServiceUnavailable) // HTTP 503 Service Unavailable - ErrGatewayTimeout = NewHTTPError(http.StatusGatewayTimeout) // HTTP 504 Gateway Timeout - ErrHTTPVersionNotSupported = NewHTTPError(http.StatusHTTPVersionNotSupported) // HTTP 505 HTTP Version Not Supported - ErrVariantAlsoNegotiates = NewHTTPError(http.StatusVariantAlsoNegotiates) // HTTP 506 Variant Also Negotiates - ErrInsufficientStorage = NewHTTPError(http.StatusInsufficientStorage) // HTTP 507 Insufficient Storage - ErrLoopDetected = NewHTTPError(http.StatusLoopDetected) // HTTP 508 Loop Detected - ErrNotExtended = NewHTTPError(http.StatusNotExtended) // HTTP 510 Not Extended - ErrNetworkAuthenticationRequired = NewHTTPError(http.StatusNetworkAuthenticationRequired) // HTTP 511 Network Authentication Required - - ErrValidatorNotRegistered = errors.New("validator not registered") - ErrRendererNotRegistered = errors.New("renderer not registered") - ErrInvalidRedirectCode = errors.New("invalid redirect status code") - ErrCookieNotFound = errors.New("cookie not found") - ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") - ErrInvalidListenerNetwork = errors.New("invalid listener network") -) + // HTTPErrorHandler is the centralized error handler that processes errors returned + // by handlers and middleware, converting them to appropriate HTTP responses. + // If not set, DefaultHTTPErrorHandler(false) is used. + HTTPErrorHandler HTTPErrorHandler + + // Router is the HTTP request router responsible for matching URLs to handlers + // using a radix tree-based algorithm. + // If not set, NewRouter(RouterConfig{}) is used. + Router Router + + // OnAddRoute is an optional callback hook executed when routes are registered. + // Useful for route validation, logging, or custom route processing. + // If not set, no callback is executed. + OnAddRoute func(route Route) error -// NotFoundHandler is the handler that router uses in case there was no matching route found. Returns an error that results -// HTTP 404 status code. -var NotFoundHandler = func(c Context) error { - return ErrNotFound + // Filesystem is the fs.FS implementation used for serving static files. + // Supports os.DirFS, embed.FS, and custom implementations. + // If not set, defaults to current working directory. + Filesystem fs.FS + + // Binder handles automatic data binding from HTTP requests to Go structs. + // Supports JSON, XML, form data, query parameters, and path parameters. + // If not set, DefaultBinder is used. + Binder Binder + + // Validator provides optional struct validation after data binding. + // Commonly used with third-party validation libraries. + // If not set, Context.Validate() returns ErrValidatorNotRegistered. + Validator Validator + + // Renderer provides template rendering for generating HTML responses. + // Requires integration with a template engine like html/template. + // If not set, Context.Render() returns ErrRendererNotRegistered. + Renderer Renderer + + // JSONSerializer handles JSON encoding and decoding for HTTP requests/responses. + // Can be replaced with faster alternatives like jsoniter or sonic. + // If not set, DefaultJSONSerializer using encoding/json is used. + JSONSerializer JSONSerializer + + // IPExtractor defines the strategy for extracting the real client IP address + // from requests, particularly important when behind proxies or load balancers. + // Used for rate limiting, access control, and logging. + // If not set, falls back to checking X-Forwarded-For and X-Real-IP headers. + IPExtractor IPExtractor + + // FormParseMaxMemory is default value for memory limit that is used + // when parsing multipart forms (See (*http.Request).ParseMultipartForm) + FormParseMaxMemory int64 } -// MethodNotAllowedHandler is the handler thar router uses in case there was no matching route found but there was -// another matching routes for that requested URL. Returns an error that results HTTP 405 Method Not Allowed status code. -var MethodNotAllowedHandler = func(c Context) error { - // See RFC 7231 section 7.4.1: An origin server MUST generate an Allow field in a 405 (Method Not Allowed) - // response and MAY do so in any other response. For disabled resources an empty Allow header may be returned - routerAllowMethods, ok := c.Get(ContextKeyHeaderAllow).(string) - if ok && routerAllowMethods != "" { - c.Response().Header().Set(HeaderAllow, routerAllowMethods) +// NewWithConfig creates an instance of Echo with given configuration. +func NewWithConfig(config Config) *Echo { + e := New() + if config.Logger != nil { + e.Logger = config.Logger + } + if config.HTTPErrorHandler != nil { + e.HTTPErrorHandler = config.HTTPErrorHandler } - return ErrMethodNotAllowed + if config.Router != nil { + e.router = config.Router + } + if config.OnAddRoute != nil { + e.OnAddRoute = config.OnAddRoute + } + if config.Filesystem != nil { + e.Filesystem = config.Filesystem + } + if config.Binder != nil { + e.Binder = config.Binder + } + if config.Validator != nil { + e.Validator = config.Validator + } + if config.Renderer != nil { + e.Renderer = config.Renderer + } + if config.JSONSerializer != nil { + e.JSONSerializer = config.JSONSerializer + } + if config.IPExtractor != nil { + e.IPExtractor = config.IPExtractor + } + if config.FormParseMaxMemory > 0 { + e.formParseMaxMemory = config.FormParseMaxMemory + } + return e } // New creates an instance of Echo. -func New() (e *Echo) { - e = &Echo{ - filesystem: createFilesystem(), - Server: new(http.Server), - TLSServer: new(http.Server), - AutoTLSManager: autocert.Manager{ - Prompt: autocert.AcceptTOS, - }, - Logger: log.New("echo"), - colorer: color.New(), - maxParam: new(int), - ListenerNetwork: "tcp", - } - e.Server.Handler = e - e.TLSServer.Handler = e - e.HTTPErrorHandler = e.DefaultHTTPErrorHandler - e.Binder = &DefaultBinder{} - e.JSONSerializer = &DefaultJSONSerializer{} - e.Logger.SetLevel(log.ERROR) - e.StdLogger = stdLog.New(e.Logger.Output(), e.Logger.Prefix()+": ", 0) - e.pool.New = func() interface{} { - return e.NewContext(nil, nil) - } - e.router = NewRouter(e) - e.routers = map[string]*Router{} - return -} +func New() *Echo { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + e := &Echo{ + Logger: logger, + Filesystem: newDefaultFS(), + Binder: &DefaultBinder{}, + JSONSerializer: &DefaultJSONSerializer{}, + formParseMaxMemory: defaultMemory, + } -// NewContext returns a Context instance. -func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) Context { - return &context{ - request: r, - response: NewResponse(w, e), - store: make(Map), - echo: e, - pvalues: make([]string, *e.maxParam), - handler: NotFoundHandler, + e.serveHTTPFunc = e.serveHTTP + e.router = NewRouter(RouterConfig{}) + e.HTTPErrorHandler = DefaultHTTPErrorHandler(false) + e.contextPool.New = func() any { + return newContext(nil, nil, e) } + return e } -// Router returns the default router. -func (e *Echo) Router() *Router { - return e.router +// NewContext returns a new Context instance. +// +// Note: both request and response can be left to nil as Echo.ServeHTTP will call c.Reset(req,resp) anyway +// these arguments are useful when creating context for tests and cases like that. +func (e *Echo) NewContext(r *http.Request, w http.ResponseWriter) *Context { + return newContext(r, w, e) } -// Routers returns the map of host => router. -func (e *Echo) Routers() map[string]*Router { - return e.routers +// Router returns the default router. +func (e *Echo) Router() Router { + return e.router } -// DefaultHTTPErrorHandler is the default HTTP error handler. It sends a JSON response -// with status code. +// DefaultHTTPErrorHandler creates new default HTTP error handler implementation. It sends a JSON response +// with status code. `exposeError` parameter decides if returned message will contain also error message or not // -// NOTE: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). +// Note: DefaultHTTPErrorHandler does not log errors. Use middleware for it if errors need to be logged (separately) +// Note: In case errors happens in middleware call-chain that is returning from handler (which did not return an error). // When handler has already sent response (ala c.JSON()) and there is error in middleware that is returning from // handler. Then the error that global error handler received will be ignored because we have already "committed" the // response and status code header has been sent to the client. -func (e *Echo) DefaultHTTPErrorHandler(err error, c Context) { - - if c.Response().Committed { - return - } +func DefaultHTTPErrorHandler(exposeError bool) HTTPErrorHandler { + return func(c *Context, err error) { + if r, _ := UnwrapResponse(c.response); r != nil && r.Committed { + return + } - he, ok := err.(*HTTPError) - if ok { - if he.Internal != nil { - if herr, ok := he.Internal.(*HTTPError); ok { - he = herr + code := http.StatusInternalServerError + var sc HTTPStatusCoder + if errors.As(err, &sc) { + if tmp := sc.StatusCode(); tmp != 0 { + code = tmp } } - } else { - he = &HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), - } - } - // Issue #1426 - code := he.Code - message := he.Message + var result any + switch m := sc.(type) { + case json.Marshaler: // this type knows how to format itself to JSON + result = m + case *HTTPError: + sText := m.Message + if sText == "" { + sText = http.StatusText(code) + } + msg := map[string]any{"message": sText} + if exposeError { + if wrappedErr := m.Unwrap(); wrappedErr != nil { + msg["error"] = wrappedErr.Error() + } + } + result = msg + default: + msg := map[string]any{"message": http.StatusText(code)} + if exposeError { + msg["error"] = err.Error() + } + result = msg + } - switch m := he.Message.(type) { - case string: - if e.Debug { - message = Map{"message": m, "error": err.Error()} + var cErr error + if c.Request().Method == http.MethodHead { // Issue #608 + cErr = c.NoContent(code) } else { - message = Map{"message": m} + cErr = c.JSON(code, result) + } + if cErr != nil { + c.Logger().Error("echo default error handler failed to send error to client", "error", cErr) // truly rare case. ala client already disconnected } - case json.Marshaler: - // do nothing - this type knows how to format itself to JSON - case error: - message = Map{"message": m.Error()} - } - - // Send response - if c.Request().Method == http.MethodHead { // Issue #608 - err = c.NoContent(he.Code) - } else { - err = c.JSON(code, message) - } - if err != nil { - e.Logger.Error(err) } } -// Pre adds middleware to the chain which is run before router. +// Pre adds middleware to the chain which is run before router tries to find matching route. +// Meaning middleware is executed even for 404 (not found) cases. func (e *Echo) Pre(middleware ...MiddlewareFunc) { e.premiddleware = append(e.premiddleware, middleware...) } -// Use adds middleware to the chain which is run after router. +// Use adds middleware to the chain which is run after router has found matching route and before route/request handler method is executed. func (e *Echo) Use(middleware ...MiddlewareFunc) { e.middleware = append(e.middleware, middleware...) } // CONNECT registers a new CONNECT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodConnect, path, h, m...) } // DELETE registers a new DELETE route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodDelete, path, h, m...) } // GET registers a new GET route for a path with matching handler in the router -// with optional route-level middleware. -func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// with optional route-level middleware. Panics on error. +func (e *Echo) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodGet, path, h, m...) } // HEAD registers a new HEAD route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodHead, path, h, m...) } // OPTIONS registers a new OPTIONS route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodOptions, path, h, m...) } // PATCH registers a new PATCH route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPatch, path, h, m...) } // POST registers a new POST route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPost, path, h, m...) } // PUT registers a new PUT route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodPut, path, h, m...) } // TRACE registers a new TRACE route for a path with matching handler in the -// router with optional route-level middleware. -func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// router with optional route-level middleware. Panics on error. +func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(http.MethodTrace, path, h, m...) } @@ -540,8 +482,8 @@ func (e *Echo) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { // Path supports static and named/any parameters just like other http method is defined. Generally path is ended with // wildcard/match-any character (`/*`, `/download/*` etc). // -// Example: `e.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` -func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// Example: `e.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return e.Add(RouteNotFound, path, h, m...) } @@ -550,64 +492,149 @@ func (e *Echo) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *R // // Note: this method only adds specific set of supported HTTP methods as handler and is not true // "catch-any-arbitrary-method" way of matching requests. -func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) - } - return routes +func (e *Echo) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + return e.Add(RouteAny, path, handler, middleware...) } // Match registers a new route for multiple HTTP methods and path with matching -// handler in the router with optional route-level middleware. -func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = e.Add(m, path, handler, middleware...) +// handler in the router with optional route-level middleware. Panics on error. +func (e *Echo) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := e.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ris +} + +// Static registers a new route with path prefix to serve static files from the provided root directory. +func (e *Echo) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { + subFs := MustSubFS(e.Filesystem, fsRoot) + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(subFs, false), + middleware..., + ) } -func (common) file(path, file string, get func(string, HandlerFunc, ...MiddlewareFunc) *Route, - m ...MiddlewareFunc) *Route { - return get(path, func(c Context) error { +// StaticFS registers a new route with path prefix to serve static files from the provided file system. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { + return e.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + middleware..., + ) +} + +// StaticDirectoryHandler creates handler function to serve files from provided file system +// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. +func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { + return func(c *Context) error { + p := c.Param("*") + if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice + tmpPath, err := url.PathUnescape(p) + if err != nil { + return fmt.Errorf("failed to unescape path variable: %w", err) + } + p = tmpPath + } + + // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid + name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) + fi, err := fs.Stat(fileSystem, name) + if err != nil { + return ErrNotFound + } + + // If the request is for a directory and does not end with "/" + p = c.Request().URL.Path // path must not be empty. + if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { + // Redirect to ends with "/" + return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) + } + return fsFile(c, name, fileSystem) + } +} + +// FileFS registers a new route with path to serve file from the provided file system. +func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return e.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// StaticFileHandler creates handler function to serve file from provided file system +func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { + return func(c *Context) error { + return fsFile(c, file, filesystem) + } +} + +// File registers a new route with path to serve a static file with optional route-level middleware. Panics on error. +func (e *Echo) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c *Context) error { return c.File(file) - }, m...) + } + return e.Add(http.MethodGet, path, handler, middleware...) } -// File registers a new route with path to serve a static file with optional route-level middleware. -func (e *Echo) File(path, file string, m ...MiddlewareFunc) *Route { - return e.file(path, file, e.GET, m...) +// AddRoute registers a new Route with default host Router +func (e *Echo) AddRoute(route Route) (RouteInfo, error) { + return e.add(route) } -func (e *Echo) add(host, method, path string, handler HandlerFunc, middlewares ...MiddlewareFunc) *Route { - router := e.findRouter(host) - //FIXME: when handler+middleware are both nil ... make it behave like handler removal - name := handlerName(handler) - route := router.add(method, path, name, func(c Context) error { - h := applyMiddleware(handler, middlewares...) - return h(c) - }) +func (e *Echo) add(route Route) (RouteInfo, error) { + if e.OnAddRoute != nil { + if err := e.OnAddRoute(route); err != nil { + return RouteInfo{}, err + } + } - if e.OnAddRouteHandler != nil { - e.OnAddRouteHandler(host, *route, handler, middlewares) + ri, err := e.router.Add(route) + if err != nil { + return RouteInfo{}, err } - return route + paramsCount := int32(len(ri.Parameters)) // #nosec G115 + if paramsCount > e.contextPathParamAllocSize.Load() { + e.contextPathParamAllocSize.Store(paramsCount) + } + return ri, nil } // Add registers a new route for an HTTP method and path with matching handler // in the router with optional route-level middleware. -func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - return e.add("", method, path, handler, middleware...) -} - -// Host creates a new router group for the provided host and optional host-level middleware. -func (e *Echo) Host(name string, m ...MiddlewareFunc) (g *Group) { - e.routers[name] = NewRouter(e) - g = &Group{host: name, echo: e} - g.Use(m...) - return +func (e *Echo) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := e.add( + Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + Name: "", + }, + ) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri } // Group creates a new router group with prefix and optional group-level middleware. @@ -617,321 +644,105 @@ func (e *Echo) Group(prefix string, m ...MiddlewareFunc) (g *Group) { return } -// URI generates an URI from handler. -func (e *Echo) URI(handler HandlerFunc, params ...interface{}) string { - name := handlerName(handler) - return e.Reverse(name, params...) -} - -// URL is an alias for `URI` function. -func (e *Echo) URL(h HandlerFunc, params ...interface{}) string { - return e.URI(h, params...) -} - -// Reverse generates a URL from route name and provided parameters. -func (e *Echo) Reverse(name string, params ...interface{}) string { - return e.router.Reverse(name, params...) +// PreMiddlewares returns registered pre middlewares. These are middleware to the chain +// which are run before router tries to find matching route. +// Use this method to build your own ServeHTTP method. +// +// NOTE: returned slice is not a copy. Do not mutate. +func (e *Echo) PreMiddlewares() []MiddlewareFunc { + return e.premiddleware } -// Routes returns the registered routes for default router. -// In case when Echo serves multiple hosts/domains use `e.Routers()["domain2.site"].Routes()` to get specific host routes. -func (e *Echo) Routes() []*Route { - return e.router.Routes() +// Middlewares returns registered route level middlewares. Does not contain any group level +// middlewares. Use this method to build your own ServeHTTP method. +// +// NOTE: returned slice is not a copy. Do not mutate. +func (e *Echo) Middlewares() []MiddlewareFunc { + return e.middleware } // AcquireContext returns an empty `Context` instance from the pool. // You must return the context by calling `ReleaseContext()`. -func (e *Echo) AcquireContext() Context { - return e.pool.Get().(Context) +func (e *Echo) AcquireContext() *Context { + return e.contextPool.Get().(*Context) } // ReleaseContext returns the `Context` instance back to the pool. // You must call it after `AcquireContext()`. -func (e *Echo) ReleaseContext(c Context) { - e.pool.Put(c) +func (e *Echo) ReleaseContext(c *Context) { + e.contextPool.Put(c) } // ServeHTTP implements `http.Handler` interface, which serves HTTP requests. func (e *Echo) ServeHTTP(w http.ResponseWriter, r *http.Request) { - // Acquire context - c := e.pool.Get().(*context) + e.serveHTTPFunc(w, r) +} + +// serveHTTP implements `http.Handler` interface, which serves HTTP requests. +func (e *Echo) serveHTTP(w http.ResponseWriter, r *http.Request) { + c := e.contextPool.Get().(*Context) + defer e.contextPool.Put(c) + c.Reset(r, w) var h HandlerFunc if e.premiddleware == nil { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h = c.Handler() - h = applyMiddleware(h, e.middleware...) + h = applyMiddleware(e.router.Route(c), e.middleware...) } else { - h = func(c Context) error { - e.findRouter(r.Host).Find(r.Method, GetPath(r), c) - h := c.Handler() - h = applyMiddleware(h, e.middleware...) - return h(c) + h = func(cc *Context) error { + h1 := applyMiddleware(e.router.Route(cc), e.middleware...) + return h1(cc) } h = applyMiddleware(h, e.premiddleware...) } // Execute chain if err := h(c); err != nil { - e.HTTPErrorHandler(err, c) + e.HTTPErrorHandler(c, err) } - - // Release context - e.pool.Put(c) } -// Start starts an HTTP server. +// Start stars HTTP server on given address with Echo as a handler serving requests. The server can be shutdown by +// sending os.Interrupt signal with `ctrl+c`. Method returns only errors that are not http.ErrServerClosed. +// +// Note: this method is created for use in examples/demos and is deliberately simple without providing configuration +// options. +// +// In need of customization use: +// +// ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) +// defer cancel() +// sc := echo.StartConfig{Address: ":8080"} +// if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) { +// slog.Error(err.Error()) +// } +// +// // or standard library `http.Server` +// +// s := http.Server{Addr: ":8080", Handler: e} +// if err := s.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { +// slog.Error(err.Error()) +// } func (e *Echo) Start(address string) error { - e.startupMutex.Lock() - e.Server.Addr = address - if err := e.configureServer(e.Server); err != nil { - e.startupMutex.Unlock() + sc := StartConfig{Address: address} + ctx, cancel := signal.NotifyContext(stdContext.Background(), os.Interrupt, syscall.SIGTERM) // start shutdown process on ctrl+c + defer cancel() + if err := sc.Start(ctx, e); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } - e.startupMutex.Unlock() - return e.Server.Serve(e.Listener) -} - -// StartTLS starts an HTTPS server. -// If `certFile` or `keyFile` is `string` the values are treated as file paths. -// If `certFile` or `keyFile` is `[]byte` the values are treated as the certificate or key as-is. -func (e *Echo) StartTLS(address string, certFile, keyFile interface{}) (err error) { - e.startupMutex.Lock() - var cert []byte - if cert, err = filepathOrContent(certFile); err != nil { - e.startupMutex.Unlock() - return - } - - var key []byte - if key, err = filepathOrContent(keyFile); err != nil { - e.startupMutex.Unlock() - return - } - - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.Certificates = make([]tls.Certificate, 1) - if s.TLSConfig.Certificates[0], err = tls.X509KeyPair(cert, key); err != nil { - e.startupMutex.Unlock() - return - } - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func filepathOrContent(fileOrContent interface{}) (content []byte, err error) { - switch v := fileOrContent.(type) { - case string: - return os.ReadFile(v) - case []byte: - return v, nil - default: - return nil, ErrInvalidCertOrKeyType - } -} - -// StartAutoTLS starts an HTTPS server using certificates automatically installed from https://letsencrypt.org. -func (e *Echo) StartAutoTLS(address string) error { - e.startupMutex.Lock() - s := e.TLSServer - s.TLSConfig = new(tls.Config) - s.TLSConfig.GetCertificate = e.AutoTLSManager.GetCertificate - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, acme.ALPNProto) - - e.configureTLS(address) - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) -} - -func (e *Echo) configureTLS(address string) { - s := e.TLSServer - s.Addr = address - if !e.DisableHTTP2 { - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2") - } -} - -// StartServer starts a custom http server. -func (e *Echo) StartServer(s *http.Server) (err error) { - e.startupMutex.Lock() - if err := e.configureServer(s); err != nil { - e.startupMutex.Unlock() - return err - } - if s.TLSConfig != nil { - e.startupMutex.Unlock() - return s.Serve(e.TLSListener) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -func (e *Echo) configureServer(s *http.Server) error { - // Setup - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = e - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if s.TLSConfig == nil { - if e.Listener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.Listener = l - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - return nil - } - if e.TLSListener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - return err - } - e.TLSListener = tls.NewListener(l, s.TLSConfig) - } - if !e.HidePort { - e.colorer.Printf("⇨ https server started on %s\n", e.colorer.Green(e.TLSListener.Addr())) - } return nil } -// ListenerAddr returns net.Addr for Listener -func (e *Echo) ListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.Listener == nil { - return nil - } - return e.Listener.Addr() -} - -// TLSListenerAddr returns net.Addr for TLSListener -func (e *Echo) TLSListenerAddr() net.Addr { - e.startupMutex.RLock() - defer e.startupMutex.RUnlock() - if e.TLSListener == nil { - return nil - } - return e.TLSListener.Addr() -} - -// StartH2CServer starts a custom http/2 server with h2c (HTTP/2 Cleartext). -func (e *Echo) StartH2CServer(address string, h2s *http2.Server) error { - e.startupMutex.Lock() - // Setup - s := e.Server - s.Addr = address - e.colorer.SetOutput(e.Logger.Output()) - s.ErrorLog = e.StdLogger - s.Handler = h2c.NewHandler(e, h2s) - if e.Debug { - e.Logger.SetLevel(log.DEBUG) - } - - if !e.HideBanner { - e.colorer.Printf(banner, e.colorer.Red("v"+Version), e.colorer.Blue(website)) - } - - if e.Listener == nil { - l, err := newListener(s.Addr, e.ListenerNetwork) - if err != nil { - e.startupMutex.Unlock() - return err - } - e.Listener = l - } - if !e.HidePort { - e.colorer.Printf("⇨ http server started on %s\n", e.colorer.Green(e.Listener.Addr())) - } - e.startupMutex.Unlock() - return s.Serve(e.Listener) -} - -// Close immediately stops the server. -// It internally calls `http.Server#Close()`. -func (e *Echo) Close() error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Close(); err != nil { - return err - } - return e.Server.Close() -} - -// Shutdown stops the server gracefully. -// It internally calls `http.Server#Shutdown()`. -func (e *Echo) Shutdown(ctx stdContext.Context) error { - e.startupMutex.Lock() - defer e.startupMutex.Unlock() - if err := e.TLSServer.Shutdown(ctx); err != nil { - return err - } - return e.Server.Shutdown(ctx) -} - -// NewHTTPError creates a new HTTPError instance. -func NewHTTPError(code int, message ...interface{}) *HTTPError { - he := &HTTPError{Code: code, Message: http.StatusText(code)} - if len(message) > 0 { - he.Message = message[0] - } - return he -} - -// Error makes it compatible with `error` interface. -func (he *HTTPError) Error() string { - if he.Internal == nil { - return fmt.Sprintf("code=%d, message=%v", he.Code, he.Message) - } - return fmt.Sprintf("code=%d, message=%v, internal=%v", he.Code, he.Message, he.Internal) -} - -// SetInternal sets error to HTTPError.Internal -func (he *HTTPError) SetInternal(err error) *HTTPError { - he.Internal = err - return he -} - -// WithInternal returns clone of HTTPError with err set to HTTPError.Internal field -func (he *HTTPError) WithInternal(err error) *HTTPError { - return &HTTPError{ - Code: he.Code, - Message: he.Message, - Internal: err, - } -} - -// Unwrap satisfies the Go 1.13 error wrapper interface. -func (he *HTTPError) Unwrap() error { - return he.Internal -} - // WrapHandler wraps `http.Handler` into `echo.HandlerFunc`. func WrapHandler(h http.Handler) HandlerFunc { - return func(c Context) error { - h.ServeHTTP(c.Response(), c.Request()) + return func(c *Context) error { + req := c.Request() + req.Pattern = c.Path() + for _, p := range c.PathValues() { + req.SetPathValue(p.Name, p.Value) + } + + h.ServeHTTP(c.Response(), req) return nil } } @@ -939,85 +750,91 @@ func WrapHandler(h http.Handler) HandlerFunc { // WrapMiddleware wraps `func(http.Handler) http.Handler` into `echo.MiddlewareFunc` func WrapMiddleware(m func(http.Handler) http.Handler) MiddlewareFunc { return func(next HandlerFunc) HandlerFunc { - return func(c Context) (err error) { + return func(c *Context) (err error) { + req := c.Request() + req.Pattern = c.Path() + for _, p := range c.PathValues() { + req.SetPathValue(p.Name, p.Value) + } + m(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c.SetRequest(r) - c.SetResponse(NewResponse(w, c.Echo())) + c.SetResponse(NewResponse(w, c.echo.Logger)) err = next(c) - })).ServeHTTP(c.Response(), c.Request()) + })).ServeHTTP(c.Response(), req) return } } } -// GetPath returns RawPath, if it's empty returns Path from URL -// Difference between RawPath and Path is: -// - Path is where request path is stored. Value is stored in decoded form: /%47%6f%2f becomes /Go/. -// - RawPath is an optional field which only gets set if the default encoding is different from Path. -func GetPath(r *http.Request) string { - path := r.URL.RawPath - if path == "" { - path = r.URL.Path +func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) } - return path + return h } -func (e *Echo) findRouter(host string) *Router { - if len(e.routers) > 0 { - if r, ok := e.routers[host]; ok { - return r - } - } - return e.router +// defaultFS emulates os.Open behaviour with filesystem opened by `os.DirFs`. Difference between `os.Open` and `fs.Open` +// is that FS does not allow to open path that start with `..` or `/` etc. For example previously you could have `../images` +// in your application but `fs := os.DirFS("./")` would not allow you to use `fs.Open("../images")` and this would break +// all old applications that rely on being able to traverse up from current executable run path. +// NB: private because you really should use fs.FS implementation instances +type defaultFS struct { + fs fs.FS + prefix string } -func handlerName(h HandlerFunc) string { - t := reflect.ValueOf(h).Type() - if t.Kind() == reflect.Func { - return runtime.FuncForPC(reflect.ValueOf(h).Pointer()).Name() +func newDefaultFS() *defaultFS { + dir, _ := os.Getwd() + return &defaultFS{ + prefix: dir, + fs: nil, } - return t.String() } -// // PathUnescape is wraps `url.PathUnescape` -// func PathUnescape(s string) (string, error) { -// return url.PathUnescape(s) -// } - -// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted -// connections. It's used by ListenAndServe and ListenAndServeTLS so -// dead TCP connections (e.g. closing laptop mid-download) eventually -// go away. -type tcpKeepAliveListener struct { - *net.TCPListener +func (fs defaultFS) Open(name string) (fs.File, error) { + if fs.fs == nil { + return os.Open(name) // #nosec G304 + } + return fs.fs.Open(name) } -func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) { - if c, err = ln.AcceptTCP(); err != nil { - return - } else if err = c.(*net.TCPConn).SetKeepAlive(true); err != nil { - return +func subFS(currentFs fs.FS, root string) (fs.FS, error) { + root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows + if dFS, ok := currentFs.(*defaultFS); ok { + // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. + // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we + // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs + if !filepath.IsAbs(root) { + root = filepath.Join(dFS.prefix, root) + } + return &defaultFS{ + prefix: root, + fs: os.DirFS(root), + }, nil } - // Ignore error from setting the KeepAlivePeriod as some systems, such as - // OpenBSD, do not support setting TCP_USER_TIMEOUT on IPPROTO_TCP - _ = c.(*net.TCPConn).SetKeepAlivePeriod(3 * time.Minute) - return + return fs.Sub(currentFs, root) } -func newListener(address, network string) (*tcpKeepAliveListener, error) { - if network != "tcp" && network != "tcp4" && network != "tcp6" { - return nil, ErrInvalidListenerNetwork - } - l, err := net.Listen(network, address) +// MustSubFS creates sub FS from current filesystem or panic on failure. +// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. +// +// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with +// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to +// create sub fs which uses necessary prefix for directory path. +func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { + subFs, err := subFS(currentFs, fsRoot) if err != nil { - return nil, err + panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) } - return &tcpKeepAliveListener{l.(*net.TCPListener)}, nil + return subFs } -func applyMiddleware(h HandlerFunc, middleware ...MiddlewareFunc) HandlerFunc { - for i := len(middleware) - 1; i >= 0; i-- { - h = middleware[i](h) +func sanitizeURI(uri string) string { + // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri + // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash + if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { + uri = "/" + strings.TrimLeft(uri, `/\`) } - return h + return uri } diff --git a/echo_fs.go b/echo_fs.go deleted file mode 100644 index 0ffc4b0bf..000000000 --- a/echo_fs.go +++ /dev/null @@ -1,162 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "fmt" - "io/fs" - "net/http" - "net/url" - "os" - "path/filepath" - "strings" -) - -type filesystem struct { - // Filesystem is file system used by Static and File handlers to access files. - // Defaults to os.DirFS(".") - // - // When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary - // prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths - // including `assets/images` as their prefix. - Filesystem fs.FS -} - -func createFilesystem() filesystem { - return filesystem{ - Filesystem: newDefaultFS(), - } -} - -// Static registers a new route with path prefix to serve static files from the provided root directory. -func (e *Echo) Static(pathPrefix, fsRoot string) *Route { - subFs := MustSubFS(e.Filesystem, fsRoot) - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(subFs, false), - ) -} - -// StaticFS registers a new route with path prefix to serve static files from the provided file system. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (e *Echo) StaticFS(pathPrefix string, filesystem fs.FS) *Route { - return e.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// StaticDirectoryHandler creates handler function to serve files from provided file system -// When disablePathUnescaping is set then file name from path is not unescaped and is served as is. -func StaticDirectoryHandler(fileSystem fs.FS, disablePathUnescaping bool) HandlerFunc { - return func(c Context) error { - p := c.Param("*") - if !disablePathUnescaping { // when router is already unescaping we do not want to do is twice - tmpPath, err := url.PathUnescape(p) - if err != nil { - return fmt.Errorf("failed to unescape path variable: %w", err) - } - p = tmpPath - } - - // fs.FS.Open() already assumes that file names are relative to FS root path and considers name with prefix `/` as invalid - name := filepath.ToSlash(filepath.Clean(strings.TrimPrefix(p, "/"))) - fi, err := fs.Stat(fileSystem, name) - if err != nil { - return ErrNotFound - } - - // If the request is for a directory and does not end with "/" - p = c.Request().URL.Path // path must not be empty. - if fi.IsDir() && len(p) > 0 && p[len(p)-1] != '/' { - // Redirect to ends with "/" - return c.Redirect(http.StatusMovedPermanently, sanitizeURI(p+"/")) - } - return fsFile(c, name, fileSystem) - } -} - -// FileFS registers a new route with path to serve file from the provided file system. -func (e *Echo) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return e.GET(path, StaticFileHandler(file, filesystem), m...) -} - -// StaticFileHandler creates handler function to serve file from provided file system -func StaticFileHandler(file string, filesystem fs.FS) HandlerFunc { - return func(c Context) error { - return fsFile(c, file, filesystem) - } -} - -// defaultFS exists to preserve pre v4.7.0 behaviour where files were open by `os.Open`. -// v4.7 introduced `echo.Filesystem` field which is Go1.16+ `fs.Fs` interface. -// Difference between `os.Open` and `fs.Open` is that FS does not allow opening path that start with `.`, `..` or `/` -// etc. For example previously you could have `../images` in your application but `fs := os.DirFS("./")` would not -// allow you to use `fs.Open("../images")` and this would break all old applications that rely on being able to -// traverse up from current executable run path. -// NB: private because you really should use fs.FS implementation instances -type defaultFS struct { - fs fs.FS - prefix string -} - -func newDefaultFS() *defaultFS { - dir, _ := os.Getwd() - return &defaultFS{ - prefix: dir, - fs: nil, - } -} - -func (fs defaultFS) Open(name string) (fs.File, error) { - if fs.fs == nil { - return os.Open(name) - } - return fs.fs.Open(name) -} - -func subFS(currentFs fs.FS, root string) (fs.FS, error) { - root = filepath.ToSlash(filepath.Clean(root)) // note: fs.FS operates only with slashes. `ToSlash` is necessary for Windows - if dFS, ok := currentFs.(*defaultFS); ok { - // we need to make exception for `defaultFS` instances as it interprets root prefix differently from fs.FS. - // fs.Fs.Open does not like relative paths ("./", "../") and absolute paths at all but prior echo.Filesystem we - // were able to use paths like `./myfile.log`, `/etc/hosts` and these would work fine with `os.Open` but not with fs.Fs - if !filepath.IsAbs(root) { - root = filepath.Join(dFS.prefix, root) - } - return &defaultFS{ - prefix: root, - fs: os.DirFS(root), - }, nil - } - return fs.Sub(currentFs, root) -} - -// MustSubFS creates sub FS from current filesystem or panic on failure. -// Panic happens when `fsRoot` contains invalid path according to `fs.ValidPath` rules. -// -// MustSubFS is helpful when dealing with `embed.FS` because for example `//go:embed assets/images` embeds files with -// paths including `assets/images` as their prefix. In that case use `fs := echo.MustSubFS(fs, "rootDirectory") to -// create sub fs which uses necessary prefix for directory path. -func MustSubFS(currentFs fs.FS, fsRoot string) fs.FS { - subFs, err := subFS(currentFs, fsRoot) - if err != nil { - panic(fmt.Errorf("can not create sub FS, invalid root given, err: %w", err)) - } - return subFs -} - -func sanitizeURI(uri string) string { - // double slash `\\`, `//` or even `\/` is absolute uri for browsers and by redirecting request to that uri - // we are vulnerable to open redirect attack. so replace all slashes from the beginning with single slash - if len(uri) > 1 && (uri[0] == '\\' || uri[0] == '/') && (uri[1] == '\\' || uri[1] == '/') { - uri = "/" + strings.TrimLeft(uri, `/\`) - } - return uri -} diff --git a/echo_fs_test.go b/echo_fs_test.go deleted file mode 100644 index ab8faa7fa..000000000 --- a/echo_fs_test.go +++ /dev/null @@ -1,271 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "strings" - "testing" -) - -func TestEcho_StaticFS(t *testing.T) { - var testCases = []struct { - name string - givenPrefix string - givenFs fs.FS - givenFsRoot string - whenURL string - expectStatus int - expectHeaderLocation string - expectBodyStartsWith string - }{ - { - name: "ok", - givenPrefix: "/images", - givenFs: os.DirFS("./_fixture/images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "ok, from sub fs", - givenPrefix: "/images", - givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), - whenURL: "/images/walle.png", - expectStatus: http.StatusOK, - expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), - }, - { - name: "No file", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/scripts"), - whenURL: "/images/bolt.png", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory", - givenPrefix: "/images", - givenFs: os.DirFS("_fixture/images"), - whenURL: "/images/", - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Directory Redirect", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: "/folder", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory Redirect with non-root path", - givenPrefix: "/static", - givenFs: os.DirFS("_fixture"), - whenURL: "/static", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/static/", - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory 404 (request URL without slash)", - givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "Prefixed directory redirect (without slash redirect to slash)", - givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenFs: os.DirFS("_fixture"), - whenURL: "/folder", // no trailing slash - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/folder/", - expectBodyStartsWith: "", - }, - { - name: "Directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending with slash)", - givenPrefix: "/assets/", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Prefixed directory with index.html (prefix ending without slash)", - givenPrefix: "/assets", - givenFs: os.DirFS("_fixture"), - whenURL: "/assets/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "Sub-directory with index.html", - givenPrefix: "/", - givenFs: os.DirFS("_fixture"), - whenURL: "/folder/", - expectStatus: http.StatusOK, - expectBodyStartsWith: "", - }, - { - name: "do not allow directory traversal (backslash - windows separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/..\\middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "do not allow directory traversal (slash - unix separator)", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: `/../middleware/basic_auth.go`, - expectStatus: http.StatusNotFound, - expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", - }, - { - name: "open redirect vulnerability", - givenPrefix: "/", - givenFs: os.DirFS("_fixture/"), - whenURL: "/open.redirect.hackercom%2f..", - expectStatus: http.StatusMovedPermanently, - expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad - expectBodyStartsWith: "", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - - tmpFs := tc.givenFs - if tc.givenFsRoot != "" { - tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) - } - e.StaticFS(tc.givenPrefix, tmpFs) - - req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectStatus, rec.Code) - body := rec.Body.String() - if tc.expectBodyStartsWith != "" { - assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) - } else { - assert.Equal(t, "", body) - } - - if tc.expectHeaderLocation != "" { - assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) - } else { - _, ok := rec.Result().Header["Location"] - assert.False(t, ok) - } - }) - } -} - -func TestEcho_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestEcho_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - }{ - { - name: "panics for ../", - givenRoot: "../assets", - }, - { - name: "panics for /", - givenRoot: "/assets", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - assert.Panics(t, func() { - e.Static("../assets", tc.givenRoot) - }) - }) - } -} diff --git a/echo_test.go b/echo_test.go index b7f32017a..f26eed8e2 100644 --- a/echo_test.go +++ b/echo_test.go @@ -6,23 +6,21 @@ package echo import ( "bytes" stdContext "context" - "crypto/tls" "errors" "fmt" - "io" + "io/fs" + "log/slog" "net" "net/http" "net/http/httptest" "net/url" "os" - "reflect" + "runtime" "strings" "testing" "time" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/net/http2" ) type user struct { @@ -62,33 +60,48 @@ func TestEcho(t *testing.T) { // Router assert.NotNil(t, e.Router()) - // DefaultHTTPErrorHandler - e.DefaultHTTPErrorHandler(errors.New("error"), c) + e.HTTPErrorHandler(c, errors.New("error")) + assert.Equal(t, http.StatusInternalServerError, rec.Code) } -func TestEchoStatic(t *testing.T) { +func TestNewWithConfig(t *testing.T) { + e := NewWithConfig(Config{}) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + + e.GET("/", func(c *Context) error { + return c.String(http.StatusTeapot, "Hello, World!") + }) + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, `Hello, World!`, rec.Body.String()) +} + +func TestEcho_StaticFS(t *testing.T) { var testCases = []struct { + givenFs fs.FS name string givenPrefix string - givenRoot string + givenFsRoot string whenURL string - expectStatus int expectHeaderLocation string expectBodyStartsWith string + expectStatus int }{ { name: "ok", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("./_fixture/images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), }, { - name: "ok with relative path for root points to directory", + name: "ok, from sub fs", givenPrefix: "/images", - givenRoot: "./_fixture/images", + givenFs: MustSubFS(os.DirFS("./_fixture/"), "images"), whenURL: "/images/walle.png", expectStatus: http.StatusOK, expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), @@ -96,7 +109,7 @@ func TestEchoStatic(t *testing.T) { { name: "No file", givenPrefix: "/images", - givenRoot: "_fixture/scripts", + givenFs: os.DirFS("_fixture/scripts"), whenURL: "/images/bolt.png", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -104,7 +117,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory", givenPrefix: "/images", - givenRoot: "_fixture/images", + givenFs: os.DirFS("_fixture/images"), whenURL: "/images/", expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -112,7 +125,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture/"), whenURL: "/folder", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -121,7 +134,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory Redirect with non-root path", givenPrefix: "/static", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/static", expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/static/", @@ -130,7 +143,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory 404 (request URL without slash)", givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -138,7 +151,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory redirect (without slash redirect to slash)", givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder", // no trailing slash expectStatus: http.StatusMovedPermanently, expectHeaderLocation: "/folder/", @@ -147,7 +160,7 @@ func TestEchoStatic(t *testing.T) { { name: "Directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -155,7 +168,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending with slash)", givenPrefix: "/assets/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -163,7 +176,7 @@ func TestEchoStatic(t *testing.T) { { name: "Prefixed directory with index.html (prefix ending without slash)", givenPrefix: "/assets", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/assets/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -171,7 +184,7 @@ func TestEchoStatic(t *testing.T) { { name: "Sub-directory with index.html", givenPrefix: "/", - givenRoot: "_fixture", + givenFs: os.DirFS("_fixture"), whenURL: "/folder/", expectStatus: http.StatusOK, expectBodyStartsWith: "", @@ -179,7 +192,7 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (backslash - windows separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/..\\middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", @@ -187,20 +200,37 @@ func TestEchoStatic(t *testing.T) { { name: "do not allow directory traversal (slash - unix separator)", givenPrefix: "/", - givenRoot: "_fixture/", + givenFs: os.DirFS("_fixture/"), whenURL: `/../middleware/basic_auth.go`, expectStatus: http.StatusNotFound, expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", }, + { + name: "open redirect vulnerability", + givenPrefix: "/", + givenFs: os.DirFS("_fixture/"), + whenURL: "/open.redirect.hackercom%2f..", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/open.redirect.hackercom/../", // location starting with `//open` would be very bad + expectBodyStartsWith: "", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := New() - e.Static(tc.givenPrefix, tc.givenRoot) + + tmpFs := tc.givenFs + if tc.givenFsRoot != "" { + tmpFs = MustSubFS(tmpFs, tc.givenFsRoot) + } + e.StaticFS(tc.givenPrefix, tmpFs) + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, tc.expectStatus, rec.Code) body := rec.Body.String() if tc.expectBodyStartsWith != "" { @@ -219,44 +249,114 @@ func TestEchoStatic(t *testing.T) { } } -func TestEchoStaticRedirectIndex(t *testing.T) { - e := New() +func TestEcho_FileFS(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenPath string + whenFile string + givenURL string + expectStartsWith []byte + expectCode int + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } - // HandlerFunc - e.Static("/static", "_fixture") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - errCh := make(chan error) + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() - go func() { - errCh <- e.Start(":0") - }() + e.ServeHTTP(rec, req) - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) + assert.Equal(t, tc.expectCode, rec.Code) - addr := e.ListenerAddr().String() - if resp, err := http.Get("http://" + addr + "/static"); err == nil { // http.Get follows redirects by default - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - assert.Fail(t, err.Error()) + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] } - }(resp.Body) - assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} - if body, err := io.ReadAll(resp.Body); err == nil { - assert.Equal(t, true, strings.HasPrefix(string(body), "")) - } else { - assert.Fail(t, err.Error()) - } +func TestEcho_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + }{ + { + name: "panics for ../", + givenRoot: "../assets", + }, + { + name: "panics for /", + givenRoot: "/assets", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") - } else { - assert.NoError(t, err) + assert.Panics(t, func() { + e.Static("../assets", tc.givenRoot) + }) + }) } +} - if err := e.Close(); err != nil { - t.Fatal(err) +func TestEchoStaticRedirectIndex(t *testing.T) { + e := New() + + // HandlerFunc + ri := e.Static("/static", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method) + assert.Equal(t, "/static*", ri.Path) + assert.Equal(t, "GET:/static*", ri.Name) + assert.Equal(t, []string{"*"}, ri.Parameters) + + ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) + defer cancel() + addr, err := startOnRandomPort(ctx, e) + if err != nil { + assert.Fail(t, err.Error()) } + + code, body, err := doGet(fmt.Sprintf("http://%v/static", addr)) + assert.NoError(t, err) + assert.True(t, strings.HasPrefix(body, "")) + assert.Equal(t, http.StatusOK, code) } func TestEchoFile(t *testing.T) { @@ -265,8 +365,8 @@ func TestEchoFile(t *testing.T) { givenPath string givenFile string whenPath string - expectCode int expectStartsWith string + expectCode int }{ { name: "ok", @@ -315,36 +415,37 @@ func TestEchoMiddleware(t *testing.T) { buf := new(bytes.Buffer) e.Pre(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - assert.Empty(t, c.Path()) + return func(c *Context) error { + // before route match is found RouteInfo does not exist + assert.Equal(t, RouteInfo{}, c.RouteInfo()) buf.WriteString("-1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("1") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("2") return next(c) } }) e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("3") return next(c) } }) // Route - e.GET("/", func(c Context) error { + e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "OK") }) @@ -357,11 +458,11 @@ func TestEchoMiddleware(t *testing.T) { func TestEchoMiddlewareError(t *testing.T) { e := New() e.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return errors.New("error") } }) - e.GET("/", NotFoundHandler) + e.GET("/", notFoundHandler) c, _ := request(http.MethodGet, "/", e) assert.Equal(t, http.StatusInternalServerError, c) } @@ -370,7 +471,7 @@ func TestEchoHandler(t *testing.T) { e := New() // HandlerFunc - e.GET("/ok", func(c Context) error { + e.GET("/ok", func(c *Context) error { return c.String(http.StatusOK, "OK") }) @@ -381,230 +482,256 @@ func TestEchoHandler(t *testing.T) { func TestEchoWrapHandler(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + var actualID string + var actualPattern string + e.GET("/:id", WrapHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte("test")) - if err != nil { - assert.Fail(t, err.Error()) - } - })) - if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "test", rec.Body.String()) - } + w.Write([]byte("test")) + actualID = r.PathValue("id") + actualPattern = r.Pattern + }))) + + req := httptest.NewRequest(http.MethodGet, "/123", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "test", rec.Body.String()) + assert.Equal(t, "123", actualID) + assert.Equal(t, "/:id", actualPattern) } func TestEchoWrapMiddleware(t *testing.T) { e := New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - buf := new(bytes.Buffer) - mw := WrapMiddleware(func(h http.Handler) http.Handler { + + var actualID string + var actualPattern string + e.Use(WrapMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - buf.Write([]byte("mw")) + actualID = r.PathValue("id") + actualPattern = r.Pattern h.ServeHTTP(w, r) }) + })) + + e.GET("/:id", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") }) - h := mw(func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, "mw", buf.String()) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "OK", rec.Body.String()) - } + + req := httptest.NewRequest(http.MethodGet, "/123", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusTeapot, rec.Code) + assert.Equal(t, "OK", rec.Body.String()) + assert.Equal(t, "123", actualID) + assert.Equal(t, "/:id", actualPattern) } func TestEchoConnect(t *testing.T) { e := New() - testMethod(t, http.MethodConnect, "/", e) + + ri := e.CONNECT("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodConnect+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodConnect, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoDelete(t *testing.T) { e := New() - testMethod(t, http.MethodDelete, "/", e) + + ri := e.DELETE("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodDelete+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodDelete, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoGet(t *testing.T) { e := New() - testMethod(t, http.MethodGet, "/", e) + + ri := e.GET("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodGet, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodGet+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodGet, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoHead(t *testing.T) { e := New() - testMethod(t, http.MethodHead, "/", e) + + ri := e.HEAD("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodHead+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodHead, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoOptions(t *testing.T) { e := New() - testMethod(t, http.MethodOptions, "/", e) + + ri := e.OPTIONS("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodOptions+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodOptions, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPatch(t *testing.T) { e := New() - testMethod(t, http.MethodPatch, "/", e) + + ri := e.PATCH("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodPatch+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPatch, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPost(t *testing.T) { e := New() - testMethod(t, http.MethodPost, "/", e) + + ri := e.POST("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodPost+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPost, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoPut(t *testing.T) { e := New() - testMethod(t, http.MethodPut, "/", e) + + ri := e.PUT("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodPut+":/", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPut, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } func TestEchoTrace(t *testing.T) { e := New() - testMethod(t, http.MethodTrace, "/", e) -} -func TestEchoAny(t *testing.T) { // JFC - e := New() - e.Any("/", func(c Context) error { - return c.String(http.StatusOK, "Any") + ri := e.TRACE("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") }) -} -func TestEchoMatch(t *testing.T) { // JFC - e := New() - e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c Context) error { - return c.String(http.StatusOK, "Match") - }) -} + assert.Equal(t, http.MethodTrace, ri.Method) + assert.Equal(t, "/", ri.Path) + assert.Equal(t, http.MethodTrace+":/", ri.Name) + assert.Nil(t, ri.Parameters) -func TestEchoURL(t *testing.T) { - e := New() - static := func(Context) error { return nil } - getUser := func(Context) error { return nil } - getAny := func(Context) error { return nil } - getFile := func(Context) error { return nil } - - e.GET("/static/file", static) - e.GET("/users/:id", getUser) - e.GET("/documents/*", getAny) - g := e.Group("/group") - g.GET("/users/:uid/files/:fid", getFile) - - assert.Equal(t, "/static/file", e.URL(static)) - assert.Equal(t, "/users/:id", e.URL(getUser)) - assert.Equal(t, "/users/1", e.URL(getUser, "1")) - assert.Equal(t, "/users/1", e.URL(getUser, "1")) - assert.Equal(t, "/documents/foo.txt", e.URL(getAny, "foo.txt")) - assert.Equal(t, "/documents/*", e.URL(getAny)) - assert.Equal(t, "/group/users/1/files/:fid", e.URL(getFile, "1")) - assert.Equal(t, "/group/users/1/files/1", e.URL(getFile, "1", "1")) + status, body := request(http.MethodTrace, "/", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, "OK", body) } -func TestEchoRoutes(t *testing.T) { +func TestEcho_Any(t *testing.T) { e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - if assert.Equal(t, len(routes), len(e.Routes())) { - for _, r := range e.Routes() { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } - } + ri := e.Any("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK from ANY") + }) + + assert.Equal(t, RouteAny, ri.Method) + assert.Equal(t, "/activate", ri.Path) + assert.Equal(t, RouteAny+":/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK from ANY`, body) } -func TestEchoRoutesHandleAdditionalHosts(t *testing.T) { +func TestEcho_Any_hasLowerPriority(t *testing.T) { e := New() - domain2Router := e.Host("domain2.router.com") - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - domain2Router.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - e.Add(http.MethodGet, "/api", func(c Context) error { - return c.String(http.StatusOK, "OK") + + e.Any("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "ANY") + }) + e.GET("/activate", func(c *Context) error { + return c.String(http.StatusLocked, "GET") }) - domain2Routes := e.Routers()["domain2.router.com"].Routes() + status, body := request(http.MethodTrace, "/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `ANY`, body) - assert.Len(t, domain2Routes, len(routes)) - for _, r := range domain2Routes { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } + status, body = request(http.MethodGet, "/activate", e) + assert.Equal(t, http.StatusLocked, status) + assert.Equal(t, `GET`, body) } -func TestEchoRoutesHandleDefaultHost(t *testing.T) { +func TestEchoMatch(t *testing.T) { // JFC e := New() - routes := []*Route{ - {http.MethodGet, "/users/:user/events", ""}, - {http.MethodGet, "/users/:user/events/public", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/refs", ""}, - {http.MethodPost, "/repos/:owner/:repo/git/tags", ""}, - } - for _, r := range routes { - e.Add(r.Method, r.Path, func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - } - e.Host("subdomain.mysite.site").Add(http.MethodGet, "/api", func(c Context) error { - return c.String(http.StatusOK, "OK") + ris := e.Match([]string{http.MethodGet, http.MethodPost}, "/", func(c *Context) error { + return c.String(http.StatusOK, "Match") }) - - defaultRouterRoutes := e.Routes() - assert.Len(t, defaultRouterRoutes, len(routes)) - for _, r := range defaultRouterRoutes { - found := false - for _, rr := range routes { - if r.Method == rr.Method && r.Path == rr.Path { - found = true - break - } - } - if !found { - t.Errorf("Route %s %s not found", r.Method, r.Path) - } - } + assert.Len(t, ris, 2) } func TestEchoServeHTTPPathEncoding(t *testing.T) { e := New() - e.GET("/with/slash", func(c Context) error { + e.GET("/with/slash", func(c *Context) error { return c.String(http.StatusOK, "/with/slash") }) - e.GET("/:id", func(c Context) error { + e.GET("/:id", func(c *Context) error { return c.String(http.StatusOK, c.Param("id")) }) @@ -641,117 +768,16 @@ func TestEchoServeHTTPPathEncoding(t *testing.T) { } } -func TestEchoHost(t *testing.T) { - okHandler := func(c Context) error { return c.String(http.StatusOK, http.StatusText(http.StatusOK)) } - teapotHandler := func(c Context) error { return c.String(http.StatusTeapot, http.StatusText(http.StatusTeapot)) } - acceptHandler := func(c Context) error { return c.String(http.StatusAccepted, http.StatusText(http.StatusAccepted)) } - teapotMiddleware := MiddlewareFunc(func(next HandlerFunc) HandlerFunc { return teapotHandler }) - - e := New() - e.GET("/", acceptHandler) - e.GET("/foo", acceptHandler) - - ok := e.Host("ok.com") - ok.GET("/", okHandler) - ok.GET("/foo", okHandler) - - teapot := e.Host("teapot.com") - teapot.GET("/", teapotHandler) - teapot.GET("/foo", teapotHandler) - - middle := e.Host("middleware.com", teapotMiddleware) - middle.GET("/", okHandler) - middle.GET("/foo", okHandler) - - var testCases = []struct { - name string - whenHost string - whenPath string - expectBody string - expectStatus int - }{ - { - name: "No Host Root", - whenHost: "", - whenPath: "/", - expectBody: http.StatusText(http.StatusAccepted), - expectStatus: http.StatusAccepted, - }, - { - name: "No Host Foo", - whenHost: "", - whenPath: "/foo", - expectBody: http.StatusText(http.StatusAccepted), - expectStatus: http.StatusAccepted, - }, - { - name: "OK Host Root", - whenHost: "ok.com", - whenPath: "/", - expectBody: http.StatusText(http.StatusOK), - expectStatus: http.StatusOK, - }, - { - name: "OK Host Foo", - whenHost: "ok.com", - whenPath: "/foo", - expectBody: http.StatusText(http.StatusOK), - expectStatus: http.StatusOK, - }, - { - name: "Teapot Host Root", - whenHost: "teapot.com", - whenPath: "/", - expectBody: http.StatusText(http.StatusTeapot), - expectStatus: http.StatusTeapot, - }, - { - name: "Teapot Host Foo", - whenHost: "teapot.com", - whenPath: "/foo", - expectBody: http.StatusText(http.StatusTeapot), - expectStatus: http.StatusTeapot, - }, - { - name: "Middleware Host", - whenHost: "middleware.com", - whenPath: "/", - expectBody: http.StatusText(http.StatusTeapot), - expectStatus: http.StatusTeapot, - }, - { - name: "Middleware Host Foo", - whenHost: "middleware.com", - whenPath: "/foo", - expectBody: http.StatusText(http.StatusTeapot), - expectStatus: http.StatusTeapot, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, tc.whenPath, nil) - req.Host = tc.whenHost - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectStatus, rec.Code) - assert.Equal(t, tc.expectBody, rec.Body.String()) - }) - } -} - func TestEchoGroup(t *testing.T) { e := New() buf := new(bytes.Buffer) e.Use(MiddlewareFunc(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("0") return next(c) } })) - h := func(c Context) error { + h := func(c *Context) error { return c.NoContent(http.StatusOK) } @@ -764,7 +790,7 @@ func TestEchoGroup(t *testing.T) { // Group g1 := e.Group("/group1") g1.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("1") return next(c) } @@ -774,14 +800,14 @@ func TestEchoGroup(t *testing.T) { // Nested groups with middleware g2 := e.Group("/group2") g2.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("2") return next(c) } }) g3 := g2.Group("/group3") g3.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { buf.WriteString("3") return next(c) } @@ -800,19 +826,11 @@ func TestEchoGroup(t *testing.T) { assert.Equal(t, "023", buf.String()) } -func TestEchoNotFound(t *testing.T) { - e := New() - req := httptest.NewRequest(http.MethodGet, "/files", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusNotFound, rec.Code) -} - func TestEcho_RouteNotFound(t *testing.T) { var testCases = []struct { + expectRoute any name string whenURL string - expectRoute interface{} expectCode int }{ { @@ -845,10 +863,10 @@ func TestEcho_RouteNotFound(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := New() - okHandler := func(c Context) error { + okHandler := func(c *Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } - notFoundHandler := func(c Context) error { + notFoundHandler := func(c *Context) error { return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } @@ -872,10 +890,18 @@ func TestEcho_RouteNotFound(t *testing.T) { } } +func TestEchoNotFound(t *testing.T) { + e := New() + req := httptest.NewRequest(http.MethodGet, "/files", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) +} + func TestEchoMethodNotAllowed(t *testing.T) { e := New() - e.GET("/", func(c Context) error { + e.GET("/", func(c *Context) error { return c.String(http.StatusOK, "Echo!") }) req := httptest.NewRequest(http.MethodPost, "/", nil) @@ -886,348 +912,133 @@ func TestEchoMethodNotAllowed(t *testing.T) { assert.Equal(t, "OPTIONS, GET", rec.Header().Get(HeaderAllow)) } -func TestEchoContext(t *testing.T) { - e := New() - c := e.AcquireContext() - assert.IsType(t, new(context), c) - e.ReleaseContext(c) -} - -func waitForServerStart(e *Echo, errChan <-chan error, isTLS bool) error { - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 200*time.Millisecond) - defer cancel() - - ticker := time.NewTicker(5 * time.Millisecond) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-ticker.C: - var addr net.Addr - if isTLS { - addr = e.TLSListenerAddr() - } else { - addr = e.ListenerAddr() - } - if addr != nil && strings.Contains(addr.String(), ":") { - return nil // was started - } - case err := <-errChan: - if err == http.ErrServerClosed { - return nil - } - return err - } +func TestEcho_OnAddRoute(t *testing.T) { + exampleRoute := Route{ + Method: http.MethodGet, + Path: "/api/files/:id", + Handler: notFoundHandler, + Middlewares: nil, + Name: "x", } -} - -func TestEchoStart(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - err := e.Start(":0") - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - assert.NoError(t, e.Close()) -} -func TestEcho_StartTLS(t *testing.T) { var testCases = []struct { + whenRoute Route + whenError error name string - addr string - certFile string - keyFile string expectError string + expectAdded []string + expectLen int }{ { - name: "ok", - addr: ":0", + name: "ok", + whenRoute: exampleRoute, + whenError: nil, + expectAdded: []string{"/static", "/api/files/:id"}, + expectError: "", + expectLen: 2, }, { - name: "nok, invalid certFile", - addr: ":0", - certFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, invalid keyFile", - addr: ":0", - keyFile: "not existing", - expectError: "open not existing: no such file or directory", - }, - { - name: "nok, failed to create cert out of certFile and keyFile", - addr: ":0", - keyFile: "_fixture/certs/cert.pem", // we are passing cert instead of key - expectError: "tls: found a certificate rather than a key in the PEM for the private key", - }, - { - name: "nok, invalid tls address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", + name: "nok, error is returned", + whenRoute: exampleRoute, + whenError: errors.New("nope"), + expectAdded: []string{"/static"}, + expectError: "nope", + expectLen: 1, }, } - for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + e := New() - errChan := make(chan error) - go func() { - certFile := "_fixture/certs/cert.pem" - if tc.certFile != "" { - certFile = tc.certFile - } - keyFile := "_fixture/certs/key.pem" - if tc.keyFile != "" { - keyFile = tc.keyFile + added := make([]string, 0) + cnt := 0 + e.OnAddRoute = func(route Route) error { + if cnt > 0 && tc.whenError != nil { // we want to GET /static to succeed for nok tests + return tc.whenError } + cnt++ + added = append(added, route.Path) + return nil + } - err := e.StartTLS(tc.addr, certFile, keyFile) - if err != nil { - errChan <- err - } - }() + e.GET("/static", notFoundHandler) + + var err error + _, err = e.AddRoute(tc.whenRoute) - err := waitForServerStart(e, errChan, true) if tc.expectError != "" { - if _, ok := err.(*os.PathError); ok { - assert.Error(t, err) // error messages for unix and windows are different. so test only error type here - } else { - assert.EqualError(t, err, tc.expectError) - } + assert.EqualError(t, err, tc.expectError) } else { assert.NoError(t, err) } - assert.NoError(t, e.Close()) + assert.Len(t, e.Router().Routes(), tc.expectLen) + assert.Equal(t, tc.expectAdded, added) }) } } -func TestEchoStartTLSAndStart(t *testing.T) { - // We test if Echo and listeners work correctly when Echo is simultaneously attached to HTTP and HTTPS server +func TestEchoContext(t *testing.T) { e := New() - e.GET("/", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - errTLSChan := make(chan error) - go func() { - certFile := "_fixture/certs/cert.pem" - keyFile := "_fixture/certs/key.pem" - err := e.StartTLS("localhost:", certFile, keyFile) - if err != nil { - errTLSChan <- err - } - }() - - err := waitForServerStart(e, errTLSChan, true) - assert.NoError(t, err) - defer func() { - if err := e.Shutdown(stdContext.Background()); err != nil { - t.Error(err) - } - }() - - // check if HTTPS works (note: we are using self signed certs so InsecureSkipVerify=true) - client := &http.Client{Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - }} - res, err := client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - errChan := make(chan error) - go func() { - err := e.Start("localhost:") - if err != nil { - errChan <- err - } - }() - err = waitForServerStart(e, errChan, false) - assert.NoError(t, err) - - // now we are serving both HTTPS and HTTP listeners. see if HTTP works in addition to HTTPS - res, err = http.Get("http://" + e.ListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) - - // see if HTTPS works after HTTP listener is also added - res, err = client.Get("https://" + e.TLSListenerAddr().String()) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, res.StatusCode) + c := e.AcquireContext() + assert.IsType(t, new(Context), c) + e.ReleaseContext(c) } -func TestEchoStartTLSByteString(t *testing.T) { - cert, err := os.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := os.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - - testCases := []struct { - cert interface{} - key interface{} - expectedErr error - name string - }{ - { - cert: "_fixture/certs/cert.pem", - key: "_fixture/certs/key.pem", - expectedErr: nil, - name: `ValidCertAndKeyFilePath`, - }, - { - cert: cert, - key: key, - expectedErr: nil, - name: `ValidCertAndKeyByteString`, - }, - { - cert: cert, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidKeyType`, - }, - { - cert: 0, - key: key, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertType`, - }, - { - cert: 0, - key: 1, - expectedErr: ErrInvalidCertOrKeyType, - name: `InvalidCertAndKeyTypes`, - }, - } - - for _, test := range testCases { - test := test - t.Run(test.name, func(t *testing.T) { - e := New() - e.HideBanner = true - - errChan := make(chan error) - - go func() { - errChan <- e.StartTLS(":0", test.cert, test.key) - }() +func TestPreMiddlewares(t *testing.T) { + e := New() + assert.Equal(t, 0, len(e.PreMiddlewares())) - err := waitForServerStart(e, errChan, true) - if test.expectedErr != nil { - assert.EqualError(t, err, test.expectedErr.Error()) - } else { - assert.NoError(t, err) - } + e.Pre(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + return next(c) + } + }) - assert.NoError(t, e.Close()) - }) - } + assert.Equal(t, 1, len(e.PreMiddlewares())) } -func TestEcho_StartAutoTLS(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - errChan := make(chan error) - - go func() { - errChan <- e.StartAutoTLS(tc.addr) - }() +func TestMiddlewares(t *testing.T) { + e := New() + assert.Equal(t, 0, len(e.Middlewares())) - err := waitForServerStart(e, errChan, true) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } + e.Use(func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + return next(c) + } + }) - assert.NoError(t, e.Close()) - }) - } + assert.Equal(t, 1, len(e.Middlewares())) } -func TestEcho_StartH2CServer(t *testing.T) { - var testCases = []struct { - name string - addr string - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, +func TestEcho_Start(t *testing.T) { + e := New() + e.GET("/", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + rndPort, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) } + defer rndPort.Close() + errChan := make(chan error, 1) + go func() { + errChan <- e.Start(rndPort.Addr().String()) + }() - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Debug = true - h2s := &http2.Server{} - - errChan := make(chan error) - go func() { - err := e.StartH2CServer(tc.addr, h2s) - if err != nil { - errChan <- err - } - }() - - err := waitForServerStart(e, errChan, false) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - - assert.NoError(t, e.Close()) - }) + select { + case <-time.After(250 * time.Millisecond): + t.Fatal("start did not error out") + case err := <-errChan: + expectContains := "bind: address already in use" + if runtime.GOOS == "windows" { + expectContains = "bind: Only one usage of each socket address" + } + assert.Contains(t, err.Error(), expectContains) } } -func testMethod(t *testing.T, method, path string, e *Echo) { - p := reflect.ValueOf(path) - h := reflect.ValueOf(func(c Context) error { - return c.String(http.StatusOK, method) - }) - i := interface{}(e) - reflect.ValueOf(i).MethodByName(method).Call([]reflect.Value{p, h}) - _, body := request(method, path, e) - assert.Equal(t, method, body) -} - func request(method, path string, e *Echo) (int, string) { req := httptest.NewRequest(method, path, nil) rec := httptest.NewRecorder() @@ -1235,589 +1046,143 @@ func request(method, path string, e *Echo) (int, string) { return rec.Code, rec.Body.String() } -func TestHTTPError(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Equal(t, "code=400, message=map[code:12]", err.Error()) - }) - - t.Run("internal and SetInternal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) - }) - - t.Run("internal and WithInternal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err = err.WithInternal(errors.New("internal error")) - assert.Equal(t, "code=400, message=map[code:12], internal=internal error", err.Error()) - }) -} - -func TestHTTPError_Unwrap(t *testing.T) { - t.Run("non-internal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - - assert.Nil(t, errors.Unwrap(err)) - }) - - t.Run("unwrap internal and SetInternal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err.SetInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) - - t.Run("unwrap internal and WithInternal", func(t *testing.T) { - err := NewHTTPError(http.StatusBadRequest, map[string]interface{}{ - "code": 12, - }) - err = err.WithInternal(errors.New("internal error")) - assert.Equal(t, "internal error", errors.Unwrap(err).Error()) - }) +type customError struct { + Code int + Message string } -type customError struct { - s string +func (ce *customError) StatusCode() int { + return ce.Code } func (ce *customError) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.s)), nil + return []byte(fmt.Sprintf(`{"x":"%v"}`, ce.Message)), nil } func (ce *customError) Error() string { - return ce.s + return ce.Message } func TestDefaultHTTPErrorHandler(t *testing.T) { var testCases = []struct { - name string - givenDebug bool - whenPath string - expectCode int - expectBody string + whenError error + name string + whenMethod string + expectBody string + expectLogged string + expectStatus int + givenExposeError bool + givenLoggerFunc bool }{ { - name: "with Debug=true plain response contains error message", - givenDebug: true, - whenPath: "/plain", - expectCode: http.StatusInternalServerError, - expectBody: "{\n \"error\": \"an error occurred\",\n \"message\": \"Internal Server Error\"\n}\n", + name: "ok, expose error = true, HTTPError, no wrapped err", + givenExposeError: true, + whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", }, { - name: "with Debug=true special handling for HTTPError", - givenDebug: true, - whenPath: "/badrequest", - expectCode: http.StatusBadRequest, - expectBody: "{\n \"error\": \"code=400, message=Invalid request\",\n \"message\": \"Invalid request\"\n}\n", + name: "ok, expose error = true, HTTPError + wrapped error", + givenExposeError: true, + whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(errors.New("internal_error")), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"internal_error","message":"my_error"}` + "\n", }, { - name: "with Debug=true complex errors are serialized to pretty JSON", - givenDebug: true, - whenPath: "/servererror", - expectCode: http.StatusInternalServerError, - expectBody: "{\n \"code\": 33,\n \"error\": \"stackinfo\",\n \"message\": \"Something bad happened\"\n}\n", + name: "ok, expose error = true, HTTPError + wrapped HTTPError", + givenExposeError: true, + whenError: HTTPError{Code: http.StatusTeapot, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), + expectStatus: http.StatusTeapot, + expectBody: `{"error":"code=418, message=early_error","message":"my_error"}` + "\n", }, { - name: "with Debug=true if the body is already set HTTPErrorHandler should not add anything to response body", - givenDebug: true, - whenPath: "/early-return", - expectCode: http.StatusOK, - expectBody: "OK", + name: "ok, expose error = false, HTTPError", + whenError: &HTTPError{Code: http.StatusTeapot, Message: "my_error"}, + expectStatus: http.StatusTeapot, + expectBody: `{"message":"my_error"}` + "\n", }, { - name: "with Debug=true internal error should be reflected in the message", - givenDebug: true, - whenPath: "/internal-error", - expectCode: http.StatusBadRequest, - expectBody: "{\n \"error\": \"code=400, message=Bad Request, internal=internal error message body\",\n \"message\": \"Bad Request\"\n}\n", + name: "ok, expose error = false, HTTPError, no message", + whenError: &HTTPError{Code: http.StatusTeapot, Message: ""}, + expectStatus: http.StatusTeapot, + expectBody: `{"message":"I'm a teapot"}` + "\n", }, { - name: "with Debug=false the error response is shortened", - whenPath: "/plain", - expectCode: http.StatusInternalServerError, - expectBody: "{\"message\":\"Internal Server Error\"}\n", + name: "ok, expose error = false, HTTPError + internal HTTPError", + whenError: HTTPError{Code: http.StatusTooEarly, Message: "my_error"}.Wrap(&HTTPError{Code: http.StatusTeapot, Message: "early_error"}), + expectStatus: http.StatusTooEarly, + expectBody: `{"message":"my_error"}` + "\n", }, { - name: "with Debug=false the error response is shortened", - whenPath: "/badrequest", - expectCode: http.StatusBadRequest, - expectBody: "{\"message\":\"Invalid request\"}\n", + name: "ok, expose error = true, Error", + givenExposeError: true, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"error":"my errors wraps: internal_error","message":"Internal Server Error"}` + "\n", }, { - name: "with Debug=false No difference for error response with non plain string errors", - whenPath: "/servererror", - expectCode: http.StatusInternalServerError, - expectBody: "{\"code\":33,\"error\":\"stackinfo\",\"message\":\"Something bad happened\"}\n", + name: "ok, expose error = false, Error", + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: `{"message":"Internal Server Error"}` + "\n", }, { - name: "with Debug=false when httpError contains an error", - whenPath: "/error-in-httperror", - expectCode: http.StatusBadRequest, - expectBody: "{\"message\":\"error in httperror\"}\n", + name: "ok, http.HEAD, expose error = true, Error", + givenExposeError: true, + whenMethod: http.MethodHead, + whenError: fmt.Errorf("my errors wraps: %w", errors.New("internal_error")), + expectStatus: http.StatusInternalServerError, + expectBody: ``, }, { - name: "with Debug=false when httpError contains an error", - whenPath: "/customerror-in-httperror", - expectCode: http.StatusBadRequest, - expectBody: "{\"x\":\"custom error msg\"}\n", + name: "ok, custom error implement MarshalJSON + HTTPStatusCoder", + whenMethod: http.MethodGet, + whenError: &customError{Code: http.StatusTeapot, Message: "custom error msg"}, + expectStatus: http.StatusTeapot, + expectBody: `{"x":"custom error msg"}` + "\n", }, } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { + buf := new(bytes.Buffer) e := New() - e.Debug = tc.givenDebug // With Debug=true plain response contains error message - - e.Any("/plain", func(c Context) error { - return errors.New("an error occurred") - }) - - e.Any("/badrequest", func(c Context) error { // and special handling for HTTPError - return NewHTTPError(http.StatusBadRequest, "Invalid request") - }) - - e.Any("/servererror", func(c Context) error { // complex errors are serialized to pretty JSON - return NewHTTPError(http.StatusInternalServerError, map[string]interface{}{ - "code": 33, - "message": "Something bad happened", - "error": "stackinfo", - }) - }) - - // if the body is already set HTTPErrorHandler should not add anything to response body - e.Any("/early-return", func(c Context) error { - err := c.String(http.StatusOK, "OK") - if err != nil { - assert.Fail(t, err.Error()) - } - return errors.New("ERROR") - }) - - // internal error should be reflected in the message - e.GET("/internal-error", func(c Context) error { - err := errors.New("internal error message body") - return NewHTTPError(http.StatusBadRequest).SetInternal(err) - }) - - e.GET("/error-in-httperror", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, errors.New("error in httperror")) - }) - - e.GET("/customerror-in-httperror", func(c Context) error { - return NewHTTPError(http.StatusBadRequest, &customError{s: "custom error msg"}) - }) - - c, b := request(http.MethodGet, tc.whenPath, e) - assert.Equal(t, tc.expectCode, c) - assert.Equal(t, tc.expectBody, b) - }) - } -} - -func TestEchoClose(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - assert.NoError(t, e.Close()) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -func TestEchoShutdown(t *testing.T) { - e := New() - errCh := make(chan error) - - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if err := e.Close(); err != nil { - t.Fatal(err) - } - - ctx, cancel := stdContext.WithTimeout(stdContext.Background(), 10*time.Second) - defer cancel() - assert.NoError(t, e.Shutdown(ctx)) - - err = <-errCh - assert.Equal(t, err.Error(), "http: Server closed") -} - -var listenerNetworkTests = []struct { - test string - network string - address string -}{ - {"tcp ipv4 address", "tcp", "127.0.0.1:1323"}, - {"tcp ipv6 address", "tcp", "[::1]:1323"}, - {"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"}, - {"tcp6 ipv6 address", "tcp6", "[::1]:1323"}, -} - -func supportsIPv6() bool { - addrs, _ := net.InterfaceAddrs() - for _, addr := range addrs { - // Check if any interface has local IPv6 assigned - if strings.Contains(addr.String(), "::1") { - return true - } - } - return false -} - -func TestEchoListenerNetwork(t *testing.T) { - hasIPv6 := supportsIPv6() - for _, tt := range listenerNetworkTests { - if !hasIPv6 && strings.Contains(tt.address, "::") { - t.Skip("Skipping testing IPv6 for " + tt.address + ", not available") - continue - } - t.Run(tt.test, func(t *testing.T) { - e := New() - e.ListenerNetwork = tt.network - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") + e.Logger = slog.New(slog.DiscardHandler) + e.Any("/path", func(c *Context) error { + return tc.whenError }) - errCh := make(chan error) - - go func() { - errCh <- e.Start(tt.address) - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) - - if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil { - defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { - assert.Fail(t, err.Error()) - } - }(resp.Body) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - if body, err := io.ReadAll(resp.Body); err == nil { - assert.Equal(t, "OK", string(body)) - } else { - assert.Fail(t, err.Error()) - } - - } else { - assert.Fail(t, err.Error()) - } + e.HTTPErrorHandler = DefaultHTTPErrorHandler(tc.givenExposeError) - if err := e.Close(); err != nil { - t.Fatal(err) + method := http.MethodGet + if tc.whenMethod != "" { + method = tc.whenMethod } - }) - } -} - -func TestEchoListenerNetworkInvalid(t *testing.T) { - e := New() - e.ListenerNetwork = "unix" - - // HandlerFunc - e.GET("/ok", func(c Context) error { - return c.String(http.StatusOK, "OK") - }) - - assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323")) -} - -func TestEcho_OnAddRouteHandler(t *testing.T) { - type rr struct { - host string - route Route - handler HandlerFunc - middleware []MiddlewareFunc - } - dummyHandler := func(Context) error { return nil } - e := New() - - added := make([]rr, 0) - e.OnAddRouteHandler = func(host string, route Route, handler HandlerFunc, middleware []MiddlewareFunc) { - added = append(added, rr{ - host: host, - route: route, - handler: handler, - middleware: middleware, - }) - } - - e.GET("/static", dummyHandler) - e.Host("domain.site").GET("/static/*", dummyHandler, func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return next(c) - } - }) - - assert.Len(t, added, 2) - - assert.Equal(t, "", added[0].host) - assert.Equal(t, Route{Method: http.MethodGet, Path: "/static", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[0].route) - assert.Len(t, added[0].middleware, 0) - - assert.Equal(t, "domain.site", added[1].host) - assert.Equal(t, Route{Method: http.MethodGet, Path: "/static/*", Name: "github.com/labstack/echo/v4.TestEcho_OnAddRouteHandler.func1"}, added[1].route) - assert.Len(t, added[1].middleware, 1) -} - -func TestEchoReverse(t *testing.T) { - var testCases = []struct { - name string - whenRouteName string - whenParams []interface{} - expect string - }{ - { - name: "ok, not existing path returns empty url", - whenRouteName: "not-existing", - expect: "", - }, - { - name: "ok,static with no params", - whenRouteName: "/static", - expect: "/static", - }, - { - name: "ok,static with non existent param", - whenRouteName: "/static", - whenParams: []interface{}{"missing param"}, - expect: "/static", - }, - { - name: "ok, wildcard with no params", - whenRouteName: "/static/*", - expect: "/static/*", - }, - { - name: "ok, wildcard with params", - whenRouteName: "/static/*", - whenParams: []interface{}{"foo.txt"}, - expect: "/static/foo.txt", - }, - { - name: "ok, single param without param", - whenRouteName: "/params/:foo", - expect: "/params/:foo", - }, - { - name: "ok, single param with param", - whenRouteName: "/params/:foo", - whenParams: []interface{}{"one"}, - expect: "/params/one", - }, - { - name: "ok, multi param without params", - whenRouteName: "/params/:foo/bar/:qux", - expect: "/params/:foo/bar/:qux", - }, - { - name: "ok, multi param with one param", - whenRouteName: "/params/:foo/bar/:qux", - whenParams: []interface{}{"one"}, - expect: "/params/one/bar/:qux", - }, - { - name: "ok, multi param with all params", - whenRouteName: "/params/:foo/bar/:qux", - whenParams: []interface{}{"one", "two"}, - expect: "/params/one/bar/two", - }, - { - name: "ok, multi param + wildcard with all params", - whenRouteName: "/params/:foo/bar/:qux/*", - whenParams: []interface{}{"one", "two", "three"}, - expect: "/params/one/bar/two/three", - }, - { - name: "ok, backslash is not escaped", - whenRouteName: "/backslash", - whenParams: []interface{}{"test"}, - expect: `/a\b/test`, - }, - { - name: "ok, escaped colon verbs", - whenRouteName: "/params:customVerb", - whenParams: []interface{}{"PATCH"}, - expect: `/params:PATCH`, - }, - } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - dummyHandler := func(Context) error { return nil } - - e.GET("/static", dummyHandler).Name = "/static" - e.GET("/static/*", dummyHandler).Name = "/static/*" - e.GET("/params/:foo", dummyHandler).Name = "/params/:foo" - e.GET("/params/:foo/bar/:qux", dummyHandler).Name = "/params/:foo/bar/:qux" - e.GET("/params/:foo/bar/:qux/*", dummyHandler).Name = "/params/:foo/bar/:qux/*" - e.GET("/a\\b/:x", dummyHandler).Name = "/backslash" - e.GET("/params\\::customVerb", dummyHandler).Name = "/params:customVerb" + c, b := request(method, "/path", e) - assert.Equal(t, tc.expect, e.Reverse(tc.whenRouteName, tc.whenParams...)) + assert.Equal(t, tc.expectStatus, c) + assert.Equal(t, tc.expectBody, b) + assert.Equal(t, tc.expectLogged, buf.String()) }) } } -func TestEchoReverseHandleHostProperly(t *testing.T) { - dummyHandler := func(Context) error { return nil } - - e := New() - - // routes added to the default router are different form different hosts - e.GET("/static", dummyHandler).Name = "default-host /static" - e.GET("/static/*", dummyHandler).Name = "xxx" - - // different host - h := e.Host("the_host") - h.GET("/static", dummyHandler).Name = "host2 /static" - h.GET("/static/v2/*", dummyHandler).Name = "xxx" - - assert.Equal(t, "/static", e.Reverse("default-host /static")) - // when actual route does not have params and we provide some to Reverse we should get that route url back - assert.Equal(t, "/static", e.Reverse("default-host /static", "missing param")) - - host2Router := e.Routers()["the_host"] - assert.Equal(t, "/static", host2Router.Reverse("host2 /static")) - assert.Equal(t, "/static", host2Router.Reverse("host2 /static", "missing param")) - - assert.Equal(t, "/static/v2/*", host2Router.Reverse("xxx")) - assert.Equal(t, "/static/v2/foo.txt", host2Router.Reverse("xxx", "foo.txt")) - -} - -func TestEcho_ListenerAddr(t *testing.T) { - e := New() - - addr := e.ListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.Start(":0") - }() - - err := waitForServerStart(e, errCh, false) - assert.NoError(t, err) -} - -func TestEcho_TLSListenerAddr(t *testing.T) { - cert, err := os.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := os.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - +func TestDefaultHTTPErrorHandler_CommitedResponse(t *testing.T) { e := New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + resp := httptest.NewRecorder() + c := e.NewContext(req, resp) - addr := e.TLSListenerAddr() - assert.Nil(t, addr) - - errCh := make(chan error) - go func() { - errCh <- e.StartTLS(":0", cert, key) - }() - - err = waitForServerStart(e, errCh, true) - assert.NoError(t, err) -} - -func TestEcho_StartServer(t *testing.T) { - cert, err := os.ReadFile("_fixture/certs/cert.pem") - require.NoError(t, err) - key, err := os.ReadFile("_fixture/certs/key.pem") - require.NoError(t, err) - certs, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) - - var testCases = []struct { - name string - addr string - TLSConfig *tls.Config - expectError string - }{ - { - name: "ok", - addr: ":0", - }, - { - name: "ok, start with TLS", - addr: ":0", - TLSConfig: &tls.Config{Certificates: []tls.Certificate{certs}}, - }, - { - name: "nok, invalid address", - addr: "nope", - expectError: "listen tcp: address nope: missing port in address", - }, - { - name: "nok, invalid tls address", - addr: "nope", - TLSConfig: &tls.Config{InsecureSkipVerify: true}, - expectError: "listen tcp: address nope: missing port in address", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Debug = true - - server := new(http.Server) - server.Addr = tc.addr - if tc.TLSConfig != nil { - server.TLSConfig = tc.TLSConfig - } - - errCh := make(chan error) - go func() { - errCh <- e.StartServer(server) - }() + c.orgResponse.Committed = true + errHandler := DefaultHTTPErrorHandler(false) - err := waitForServerStart(e, errCh, tc.TLSConfig != nil) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - assert.NoError(t, e.Close()) - }) - } + errHandler(c, errors.New("my_error")) + assert.Equal(t, http.StatusOK, resp.Code) } -func benchmarkEchoRoutes(b *testing.B, routes []*Route) { +func benchmarkEchoRoutes(b *testing.B, routes []testRoute) { e := New() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) u := req.URL w := httptest.NewRecorder() @@ -1825,7 +1190,7 @@ func benchmarkEchoRoutes(b *testing.B, routes []*Route) { // Add routes for _, route := range routes { - e.Add(route.Method, route.Path, func(c Context) error { + e.Add(route.Method, route.Path, func(c *Context) error { return nil }) } diff --git a/echotest/context.go b/echotest/context.go new file mode 100644 index 000000000..2f665705d --- /dev/null +++ b/echotest/context.go @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echotest + +import ( + "bytes" + "io" + "mime/multipart" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/labstack/echo/v5" +) + +// ContextConfig is configuration for creating echo.Context for testing purposes. +type ContextConfig struct { + // Request will be used instead of default `httptest.NewRequest(http.MethodGet, "/", nil)` + Request *http.Request + + // Response will be used instead of default `httptest.NewRecorder()` + Response *httptest.ResponseRecorder + + // QueryValues wil be set as Request.URL.RawQuery value + QueryValues url.Values + + // Headers wil be set as Request.Header value + Headers http.Header + + // PathValues initializes context.PathValues with given value. + PathValues echo.PathValues + + // RouteInfo initializes context.RouteInfo() with given value + RouteInfo *echo.RouteInfo + + // FormValues creates form-urlencoded form out of given values. If there is no + // `content-type` header it will be set to `application/x-www-form-urlencoded` + // In case Request was not set the Request.Method is set to `POST` + // + // FormValues, MultipartForm and JSONBody are mutually exclusive. + FormValues url.Values + + // MultipartForm creates multipart form out of given value. If there is no + // `content-type` header it will be set to `multipart/form-data` + // In case Request was not set the Request.Method is set to `POST` + // + // FormValues, MultipartForm and JSONBody are mutually exclusive. + MultipartForm *MultipartForm + + // JSONBody creates JSON body out of given bytes. If there is no + // `content-type` header it will be set to `application/json` + // In case Request was not set the Request.Method is set to `POST` + // + // FormValues, MultipartForm and JSONBody are mutually exclusive. + JSONBody []byte +} + +// MultipartForm is used to create multipart form out of given value +type MultipartForm struct { + Fields map[string]string + Files []MultipartFormFile +} + +// MultipartFormFile is used to create file in multipart form out of given value +type MultipartFormFile struct { + Fieldname string + Filename string + Content []byte +} + +// ToContext converts ContextConfig to echo.Context +func (conf ContextConfig) ToContext(t *testing.T) *echo.Context { + c, _ := conf.ToContextRecorder(t) + return c +} + +// ToContextRecorder converts ContextConfig to echo.Context and httptest.ResponseRecorder +func (conf ContextConfig) ToContextRecorder(t *testing.T) (*echo.Context, *httptest.ResponseRecorder) { + if conf.Response == nil { + conf.Response = httptest.NewRecorder() + } + isDefaultRequest := false + if conf.Request == nil { + isDefaultRequest = true + conf.Request = httptest.NewRequest(http.MethodGet, "/", nil) + } + + if len(conf.QueryValues) > 0 { + conf.Request.URL.RawQuery = conf.QueryValues.Encode() + } + if len(conf.Headers) > 0 { + conf.Request.Header = conf.Headers + } + if len(conf.FormValues) > 0 { + body := strings.NewReader(url.Values(conf.FormValues).Encode()) + conf.Request.Body = io.NopCloser(body) + conf.Request.ContentLength = int64(body.Len()) + + if conf.Request.Header.Get(echo.HeaderContentType) == "" { + conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + } + if isDefaultRequest { + conf.Request.Method = http.MethodPost + } + } else if conf.MultipartForm != nil { + var body bytes.Buffer + mw := multipart.NewWriter(&body) + for field, value := range conf.MultipartForm.Fields { + if err := mw.WriteField(field, value); err != nil { + t.Fatal(err) + } + } + for _, file := range conf.MultipartForm.Files { + fw, err := mw.CreateFormFile(file.Fieldname, file.Filename) + if err != nil { + t.Fatal(err) + } + if _, err = fw.Write(file.Content); err != nil { + t.Fatal(err) + } + } + if err := mw.Close(); err != nil { + t.Fatal(err) + } + + conf.Request.Body = io.NopCloser(&body) + conf.Request.ContentLength = int64(body.Len()) + if conf.Request.Header.Get(echo.HeaderContentType) == "" { + conf.Request.Header.Set(echo.HeaderContentType, mw.FormDataContentType()) + } + if isDefaultRequest { + conf.Request.Method = http.MethodPost + } + } else if conf.JSONBody != nil { + body := bytes.NewReader(conf.JSONBody) + conf.Request.Body = io.NopCloser(body) + conf.Request.ContentLength = int64(body.Len()) + + if conf.Request.Header.Get(echo.HeaderContentType) == "" { + conf.Request.Header.Set(echo.HeaderContentType, echo.MIMEApplicationJSON) + } + if isDefaultRequest { + conf.Request.Method = http.MethodPost + } + } + + ec := echo.NewContext(conf.Request, conf.Response, echo.New()) + if conf.RouteInfo == nil { + conf.RouteInfo = &echo.RouteInfo{ + Name: "", + Method: conf.Request.Method, + Path: "/test", + Parameters: []string{}, + } + for _, p := range conf.PathValues { + conf.RouteInfo.Parameters = append(conf.RouteInfo.Parameters, p.Name) + } + } + ec.InitializeRoute(conf.RouteInfo, &conf.PathValues) + return ec, conf.Response +} + +// ServeWithHandler serves ContextConfig with given handler and returns httptest.ResponseRecorder for response checking +func (conf ContextConfig) ServeWithHandler(t *testing.T, handler echo.HandlerFunc, opts ...any) *httptest.ResponseRecorder { + c, rec := conf.ToContextRecorder(t) + + errHandler := echo.DefaultHTTPErrorHandler(false) + for _, opt := range opts { + switch o := opt.(type) { + case echo.HTTPErrorHandler: + errHandler = o + } + } + + err := handler(c) + if err != nil { + errHandler(c, err) + } + return rec +} diff --git a/echotest/context_external_test.go b/echotest/context_external_test.go new file mode 100644 index 000000000..d98257148 --- /dev/null +++ b/echotest/context_external_test.go @@ -0,0 +1,27 @@ +package echotest_test + +import ( + "net/http" + "testing" + + "github.com/labstack/echo/v5" + "github.com/labstack/echo/v5/echotest" + "github.com/stretchr/testify/assert" +) + +func TestToContext_JSONBody(t *testing.T) { + c := echotest.ContextConfig{ + JSONBody: echotest.LoadBytes(t, "testdata/test.json"), + }.ToContext(t) + + payload := struct { + Field string `json:"field"` + }{} + if err := c.Bind(&payload); err != nil { + t.Fatal(err) + } + + assert.Equal(t, "value", payload.Field) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) +} diff --git a/echotest/context_test.go b/echotest/context_test.go new file mode 100644 index 000000000..66815e4b0 --- /dev/null +++ b/echotest/context_test.go @@ -0,0 +1,157 @@ +package echotest + +import ( + "net/http" + "net/url" + "strings" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) + +func TestServeWithHandler(t *testing.T) { + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, c.QueryParam("key")) + } + testConf := ContextConfig{ + QueryValues: url.Values{"key": []string{"value"}}, + } + + resp := testConf.ServeWithHandler(t, handler) + + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "value", resp.Body.String()) +} + +func TestServeWithHandler_error(t *testing.T) { + handler := func(c *echo.Context) error { + return echo.NewHTTPError(http.StatusBadRequest, "something went wrong") + } + testConf := ContextConfig{ + QueryValues: url.Values{"key": []string{"value"}}, + } + + customErrHandler := echo.DefaultHTTPErrorHandler(true) + + resp := testConf.ServeWithHandler(t, handler, customErrHandler) + + assert.Equal(t, http.StatusBadRequest, resp.Code) + assert.Equal(t, `{"message":"something went wrong"}`+"\n", resp.Body.String()) +} + +func TestToContext_QueryValues(t *testing.T) { + testConf := ContextConfig{ + QueryValues: url.Values{"t": []string{"2006-01-02"}}, + } + c := testConf.ToContext(t) + + v, err := echo.QueryParam[string](c, "t") + + assert.NoError(t, err) + assert.Equal(t, "2006-01-02", v) +} + +func TestToContext_Headers(t *testing.T) { + testConf := ContextConfig{ + Headers: http.Header{echo.HeaderXRequestID: []string{"ABC"}}, + } + c := testConf.ToContext(t) + + id := c.Request().Header.Get(echo.HeaderXRequestID) + + assert.Equal(t, "ABC", id) +} + +func TestToContext_PathValues(t *testing.T) { + testConf := ContextConfig{ + PathValues: echo.PathValues{{ + Name: "key", + Value: "value", + }}, + } + c := testConf.ToContext(t) + + key := c.Param("key") + + assert.Equal(t, "value", key) +} + +func TestToContext_RouteInfo(t *testing.T) { + testConf := ContextConfig{ + RouteInfo: &echo.RouteInfo{ + Name: "my_route", + Method: http.MethodGet, + Path: "/:id", + Parameters: []string{"id"}, + }, + } + c := testConf.ToContext(t) + + ri := c.RouteInfo() + + assert.Equal(t, echo.RouteInfo{ + Name: "my_route", + Method: http.MethodGet, + Path: "/:id", + Parameters: []string{"id"}, + }, ri) +} + +func TestToContext_FormValues(t *testing.T) { + testConf := ContextConfig{ + FormValues: url.Values{"key": []string{"value"}}, + } + c := testConf.ToContext(t) + + assert.Equal(t, "value", c.FormValue("key")) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, echo.MIMEApplicationForm, c.Request().Header.Get(echo.HeaderContentType)) +} + +func TestToContext_MultipartForm(t *testing.T) { + testConf := ContextConfig{ + MultipartForm: &MultipartForm{ + Fields: map[string]string{ + "key": "value", + }, + Files: []MultipartFormFile{ + { + Fieldname: "file", + Filename: "test.json", + Content: LoadBytes(t, "testdata/test.json"), + }, + }, + }, + } + c := testConf.ToContext(t) + + assert.Equal(t, "value", c.FormValue("key")) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, true, strings.HasPrefix(c.Request().Header.Get(echo.HeaderContentType), "multipart/form-data; boundary=")) + + fv, err := c.FormFile("file") + if err != nil { + t.Fatal(err) + } + assert.Equal(t, "test.json", fv.Filename) + assert.Equal(t, int64(23), fv.Size) +} + +func TestToContext_JSONBody(t *testing.T) { + testConf := ContextConfig{ + JSONBody: LoadBytes(t, "testdata/test.json"), + } + c := testConf.ToContext(t) + + payload := struct { + Field string `json:"field"` + }{} + if err := c.Bind(&payload); err != nil { + t.Fatal(err) + } + + assert.Equal(t, "value", payload.Field) + assert.Equal(t, http.MethodPost, c.Request().Method) + assert.Equal(t, echo.MIMEApplicationJSON, c.Request().Header.Get(echo.HeaderContentType)) +} diff --git a/echotest/reader.go b/echotest/reader.go new file mode 100644 index 000000000..0caceca02 --- /dev/null +++ b/echotest/reader.go @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echotest + +import ( + "os" + "path/filepath" + "runtime" + "testing" +) + +type loadBytesOpts func([]byte) []byte + +// TrimNewlineEnd instructs LoadBytes to remove `\n` from the end of loaded file. +func TrimNewlineEnd(bytes []byte) []byte { + bLen := len(bytes) + if bLen > 1 && bytes[bLen-1] == '\n' { + bytes = bytes[:bLen-1] + } + return bytes +} + +// LoadBytes is helper to load file contents relative to current (where test file is) package +// directory. +func LoadBytes(t *testing.T, name string, opts ...loadBytesOpts) []byte { + bytes := loadBytes(t, name, 2) + + for _, f := range opts { + bytes = f(bytes) + } + + return bytes +} + +func loadBytes(t *testing.T, name string, callDepth int) []byte { + _, b, _, _ := runtime.Caller(callDepth) + basepath := filepath.Dir(b) + + path := filepath.Join(basepath, name) // relative path + bytes, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + return bytes[:] +} diff --git a/echotest/reader_external_test.go b/echotest/reader_external_test.go new file mode 100644 index 000000000..43fd57416 --- /dev/null +++ b/echotest/reader_external_test.go @@ -0,0 +1,25 @@ +package echotest_test + +import ( + "strings" + "testing" + + "github.com/labstack/echo/v5/echotest" + "github.com/stretchr/testify/assert" +) + +const testJSONContent = `{ + "field": "value" +}` + +func TestLoadBytesOK(t *testing.T) { + data := echotest.LoadBytes(t, "testdata/test.json") + assert.Equal(t, []byte(testJSONContent+"\n"), data) +} + +func TestLoadBytes_custom(t *testing.T) { + data := echotest.LoadBytes(t, "testdata/test.json", func(bytes []byte) []byte { + return []byte(strings.ToUpper(string(bytes))) + }) + assert.Equal(t, []byte(strings.ToUpper(testJSONContent)+"\n"), data) +} diff --git a/echotest/reader_test.go b/echotest/reader_test.go new file mode 100644 index 000000000..23b3c2dd2 --- /dev/null +++ b/echotest/reader_test.go @@ -0,0 +1,21 @@ +package echotest + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +const testJSONContent = `{ + "field": "value" +}` + +func TestLoadBytesOK(t *testing.T) { + data := LoadBytes(t, "testdata/test.json") + assert.Equal(t, []byte(testJSONContent+"\n"), data) +} + +func TestLoadBytesOK_TrimNewlineEnd(t *testing.T) { + data := LoadBytes(t, "testdata/test.json", TrimNewlineEnd) + assert.Equal(t, []byte(testJSONContent), data) +} diff --git a/echotest/testdata/test.json b/echotest/testdata/test.json new file mode 100644 index 000000000..94ae65f17 --- /dev/null +++ b/echotest/testdata/test.json @@ -0,0 +1,3 @@ +{ + "field": "value" +} diff --git a/go.mod b/go.mod index a1652a31e..abdbcace0 100644 --- a/go.mod +++ b/go.mod @@ -1,23 +1,16 @@ -module github.com/labstack/echo/v4 +module github.com/labstack/echo/v5 -go 1.24.0 +go 1.25.0 require ( - github.com/labstack/gommon v0.4.2 github.com/stretchr/testify v1.11.1 - github.com/valyala/fasttemplate v1.2.2 - golang.org/x/crypto v0.46.0 golang.org/x/net v0.48.0 golang.org/x/time v0.14.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect - github.com/mattn/go-colorable v0.1.14 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/valyala/bytebufferpool v1.0.0 // indirect - golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 405f8c8ee..6eb81abf9 100644 --- a/go.sum +++ b/go.sum @@ -1,26 +1,11 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= -github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= -github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= -github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= -github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= -github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= -github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= diff --git a/group.go b/group.go index cb37b123f..d81cd9163 100644 --- a/group.go +++ b/group.go @@ -4,6 +4,7 @@ package echo import ( + "io/fs" "net/http" ) @@ -11,119 +12,161 @@ import ( // routes that share a common middleware or functionality that should be separate // from the parent echo instance while still inheriting from it. type Group struct { - common - host string - prefix string echo *Echo + prefix string middleware []MiddlewareFunc } // Use implements `Echo#Use()` for sub-routes within the Group. +// Group middlewares are not executed on request when there is no matching route found. func (g *Group) Use(middleware ...MiddlewareFunc) { g.middleware = append(g.middleware, middleware...) - if len(g.middleware) == 0 { - return - } - // group level middlewares are different from Echo `Pre` and `Use` middlewares (those are global). Group level middlewares - // are only executed if they are added to the Router with route. - // So we register catch all route (404 is a safe way to emulate route match) for this group and now during routing the - // Router would find route to match our request path and therefore guarantee the middleware(s) will get executed. - g.RouteNotFound("", NotFoundHandler) - g.RouteNotFound("/*", NotFoundHandler) } -// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. -func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// CONNECT implements `Echo#CONNECT()` for sub-routes within the Group. Panics on error. +func (g *Group) CONNECT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodConnect, path, h, m...) } -// DELETE implements `Echo#DELETE()` for sub-routes within the Group. -func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// DELETE implements `Echo#DELETE()` for sub-routes within the Group. Panics on error. +func (g *Group) DELETE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodDelete, path, h, m...) } -// GET implements `Echo#GET()` for sub-routes within the Group. -func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// GET implements `Echo#GET()` for sub-routes within the Group. Panics on error. +func (g *Group) GET(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodGet, path, h, m...) } -// HEAD implements `Echo#HEAD()` for sub-routes within the Group. -func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// HEAD implements `Echo#HEAD()` for sub-routes within the Group. Panics on error. +func (g *Group) HEAD(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodHead, path, h, m...) } -// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. -func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// OPTIONS implements `Echo#OPTIONS()` for sub-routes within the Group. Panics on error. +func (g *Group) OPTIONS(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodOptions, path, h, m...) } -// PATCH implements `Echo#PATCH()` for sub-routes within the Group. -func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PATCH implements `Echo#PATCH()` for sub-routes within the Group. Panics on error. +func (g *Group) PATCH(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPatch, path, h, m...) } -// POST implements `Echo#POST()` for sub-routes within the Group. -func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// POST implements `Echo#POST()` for sub-routes within the Group. Panics on error. +func (g *Group) POST(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPost, path, h, m...) } -// PUT implements `Echo#PUT()` for sub-routes within the Group. -func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// PUT implements `Echo#PUT()` for sub-routes within the Group. Panics on error. +func (g *Group) PUT(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodPut, path, h, m...) } -// TRACE implements `Echo#TRACE()` for sub-routes within the Group. -func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// TRACE implements `Echo#TRACE()` for sub-routes within the Group. Panics on error. +func (g *Group) TRACE(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(http.MethodTrace, path, h, m...) } -// Any implements `Echo#Any()` for sub-routes within the Group. -func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) +// Any implements `Echo#Any()` for sub-routes within the Group. Panics on error. +func (g *Group) Any(path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + return g.Add(RouteAny, path, handler, middleware...) +} + +// Match implements `Echo#Match()` for sub-routes within the Group. Panics on error. +func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) Routes { + errs := make([]error, 0) + ris := make(Routes, 0) + for _, m := range methods { + ri, err := g.AddRoute(Route{ + Method: m, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + errs = append(errs, err) + continue + } + ris = append(ris, ri) } - return routes -} - -// Match implements `Echo#Match()` for sub-routes within the Group. -func (g *Group) Match(methods []string, path string, handler HandlerFunc, middleware ...MiddlewareFunc) []*Route { - routes := make([]*Route, len(methods)) - for i, m := range methods { - routes[i] = g.Add(m, path, handler, middleware...) + if len(errs) > 0 { + panic(errs) // this is how `v4` handles errors. `v5` has methods to have panic-free usage } - return routes + return ris } // Group creates a new sub-group with prefix and optional sub-group-level middleware. +// Important! Group middlewares are only executed in case there was exact route match and not +// for 404 (not found) or 405 (method not allowed) cases. If this kind of behaviour is needed then add +// a catch-all route `/*` for the group which handler returns always 404 func (g *Group) Group(prefix string, middleware ...MiddlewareFunc) (sg *Group) { m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) m = append(m, g.middleware...) m = append(m, middleware...) sg = g.echo.Group(g.prefix+prefix, m...) - sg.host = g.host return } -// File implements `Echo#File()` for sub-routes within the Group. -func (g *Group) File(path, file string) { - g.file(path, file, g.GET) +// Static implements `Echo#Static()` for sub-routes within the Group. +func (g *Group) Static(pathPrefix, fsRoot string, middleware ...MiddlewareFunc) RouteInfo { + subFs := MustSubFS(g.echo.Filesystem, fsRoot) + return g.StaticFS(pathPrefix, subFs, middleware...) +} + +// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. +// +// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary +// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths +// including `assets/images` as their prefix. +func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS, middleware ...MiddlewareFunc) RouteInfo { + return g.Add( + http.MethodGet, + pathPrefix+"*", + StaticDirectoryHandler(filesystem, false), + middleware..., + ) +} + +// FileFS implements `Echo#FileFS()` for sub-routes within the Group. +func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) RouteInfo { + return g.GET(path, StaticFileHandler(file, filesystem), m...) +} + +// File implements `Echo#File()` for sub-routes within the Group. Panics on error. +func (g *Group) File(path, file string, middleware ...MiddlewareFunc) RouteInfo { + handler := func(c *Context) error { + return c.File(file) + } + return g.Add(http.MethodGet, path, handler, middleware...) } // RouteNotFound implements `Echo#RouteNotFound()` for sub-routes within the Group. // -// Example: `g.RouteNotFound("/*", func(c echo.Context) error { return c.NoContent(http.StatusNotFound) })` -func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) *Route { +// Example: `g.RouteNotFound("/*", func(c *echo.Context) error { return c.NoContent(http.StatusNotFound) })` +func (g *Group) RouteNotFound(path string, h HandlerFunc, m ...MiddlewareFunc) RouteInfo { return g.Add(RouteNotFound, path, h, m...) } -// Add implements `Echo#Add()` for sub-routes within the Group. -func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) *Route { - // Combine into a new slice to avoid accidentally passing the same slice for +// Add implements `Echo#Add()` for sub-routes within the Group. Panics on error. +func (g *Group) Add(method, path string, handler HandlerFunc, middleware ...MiddlewareFunc) RouteInfo { + ri, err := g.AddRoute(Route{ + Method: method, + Path: path, + Handler: handler, + Middlewares: middleware, + }) + if err != nil { + panic(err) // this is how `v4` handles errors. `v5` has methods to have panic-free usage + } + return ri +} + +// AddRoute registers a new Routable with Router +func (g *Group) AddRoute(route Route) (RouteInfo, error) { + // Combine middleware into a new slice to avoid accidentally passing the same slice for // multiple routes, which would lead to later add() calls overwriting the // middleware from earlier calls. - m := make([]MiddlewareFunc, 0, len(g.middleware)+len(middleware)) - m = append(m, g.middleware...) - m = append(m, middleware...) - return g.echo.add(g.host, method, g.prefix+path, handler, m...) + groupRoute := route.WithPrefix(g.prefix, append([]MiddlewareFunc{}, g.middleware...)) + return g.echo.add(groupRoute) } diff --git a/group_fs.go b/group_fs.go deleted file mode 100644 index c1b7ec2d3..000000000 --- a/group_fs.go +++ /dev/null @@ -1,33 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "io/fs" - "net/http" -) - -// Static implements `Echo#Static()` for sub-routes within the Group. -func (g *Group) Static(pathPrefix, fsRoot string) { - subFs := MustSubFS(g.echo.Filesystem, fsRoot) - g.StaticFS(pathPrefix, subFs) -} - -// StaticFS implements `Echo#StaticFS()` for sub-routes within the Group. -// -// When dealing with `embed.FS` use `fs := echo.MustSubFS(fs, "rootDirectory") to create sub fs which uses necessary -// prefix for directory path. This is necessary as `//go:embed assets/images` embeds files with paths -// including `assets/images` as their prefix. -func (g *Group) StaticFS(pathPrefix string, filesystem fs.FS) { - g.Add( - http.MethodGet, - pathPrefix+"*", - StaticDirectoryHandler(filesystem, false), - ) -} - -// FileFS implements `Echo#FileFS()` for sub-routes within the Group. -func (g *Group) FileFS(path, file string, filesystem fs.FS, m ...MiddlewareFunc) *Route { - return g.GET(path, StaticFileHandler(file, filesystem), m...) -} diff --git a/group_fs_test.go b/group_fs_test.go deleted file mode 100644 index caa200940..000000000 --- a/group_fs_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "github.com/stretchr/testify/assert" - "io/fs" - "net/http" - "net/http/httptest" - "os" - "testing" -) - -func TestGroup_FileFS(t *testing.T) { - var testCases = []struct { - name string - whenPath string - whenFile string - whenFS fs.FS - givenURL string - expectCode int - expectStartsWith []byte - }{ - { - name: "ok", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle", - expectCode: http.StatusOK, - expectStartsWith: []byte{0x89, 0x50, 0x4e}, - }, - { - name: "nok, requesting invalid path", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "walle.png", - givenURL: "/assets/walle.png", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - { - name: "nok, serving not existent file from filesystem", - whenPath: "/walle", - whenFS: os.DirFS("_fixture/images"), - whenFile: "not-existent.png", - givenURL: "/assets/walle", - expectCode: http.StatusNotFound, - expectStartsWith: []byte(`{"message":"Not Found"}`), - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - g := e.Group("/assets") - g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) - - req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) - rec := httptest.NewRecorder() - - e.ServeHTTP(rec, req) - - assert.Equal(t, tc.expectCode, rec.Code) - - body := rec.Body.Bytes() - if len(body) > len(tc.expectStartsWith) { - body = body[:len(tc.expectStartsWith)] - } - assert.Equal(t, tc.expectStartsWith, body) - }) - } -} - -func TestGroup_StaticPanic(t *testing.T) { - var testCases = []struct { - name string - givenRoot string - }{ - { - name: "panics for ../", - givenRoot: "../images", - }, - { - name: "panics for /", - givenRoot: "/images", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := New() - e.Filesystem = os.DirFS("./") - - g := e.Group("/assets") - - assert.Panics(t, func() { - g.Static("/images", tc.givenRoot) - }) - }) - } -} diff --git a/group_test.go b/group_test.go index a97371418..819b6df97 100644 --- a/group_test.go +++ b/group_test.go @@ -4,31 +4,70 @@ package echo import ( + "io/fs" "net/http" "net/http/httptest" "os" + "strings" "testing" "github.com/stretchr/testify/assert" ) -// TODO: Fix me -func TestGroup(t *testing.T) { - g := New().Group("/group") - h := func(Context) error { return nil } - g.CONNECT("/", h) - g.DELETE("/", h) - g.GET("/", h) - g.HEAD("/", h) - g.OPTIONS("/", h) - g.PATCH("/", h) - g.POST("/", h) - g.PUT("/", h) - g.TRACE("/", h) - g.Any("/", h) - g.Match([]string{http.MethodGet, http.MethodPost}, "/", h) - g.Static("/static", "/tmp") - g.File("/walle", "_fixture/images//walle.png") +func TestGroup_withoutRouteWillNotExecuteMiddleware(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware it will not be executed when there are no routes under that group + _ = e.Group("/group", mw) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_withRoutesWillNotExecuteMiddlewareFor404(t *testing.T) { + e := New() + + called := false + mw := func(next HandlerFunc) HandlerFunc { + return func(c *Context) error { + called = true + return c.NoContent(http.StatusTeapot) + } + } + // even though group has middleware and routes when we have no match on some route the middlewares for that + // group will not be executed + g := e.Group("/group", mw) + g.GET("/yes", handlerFunc) + + status, body := request(http.MethodGet, "/group/nope", e) + assert.Equal(t, http.StatusNotFound, status) + assert.Equal(t, `{"message":"Not Found"}`+"\n", body) + + assert.False(t, called) +} + +func TestGroup_multiLevelGroup(t *testing.T) { + e := New() + + api := e.Group("/api") + users := api.Group("/users") + users.GET("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + status, body := request(http.MethodGet, "/api/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) } func TestGroupFile(t *testing.T) { @@ -48,29 +87,29 @@ func TestGroupRouteMiddleware(t *testing.T) { // Ensure middleware slices are not re-used e := New() g := e.Group("/group") - h := func(Context) error { return nil } + h := func(*Context) error { return nil } m1 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m3 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m4 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return c.NoContent(404) } } m5 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return c.NoContent(405) } } @@ -89,17 +128,17 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { e := New() g := e.Group("/group") m1 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { return next(c) } } m2 := func(next HandlerFunc) HandlerFunc { - return func(c Context) error { - return c.String(http.StatusOK, c.Path()) + return func(c *Context) error { + return c.String(http.StatusOK, c.RouteInfo().Path) } } - h := func(c Context) error { - return c.String(http.StatusOK, c.Path()) + h := func(c *Context) error { + return c.String(http.StatusOK, c.RouteInfo().Path) } g.Use(m1) g.GET("/help", h, m2) @@ -123,11 +162,155 @@ func TestGroupRouteMiddlewareWithMatchAny(t *testing.T) { } +func TestGroup_CONNECT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.CONNECT("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodConnect, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodConnect+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodConnect, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_DELETE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.DELETE("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodDelete, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodDelete+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodDelete, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_HEAD(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.HEAD("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodHead, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodHead+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodHead, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_OPTIONS(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.OPTIONS("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodOptions, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodOptions+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodOptions, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PATCH(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PATCH("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPatch, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodPatch+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPatch, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_POST(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.POST("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPost, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodPost+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPost, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_PUT(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.PUT("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodPut, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodPut+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodPut, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + +func TestGroup_TRACE(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.TRACE("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + + assert.Equal(t, http.MethodTrace, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, http.MethodTrace+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) +} + func TestGroup_RouteNotFound(t *testing.T) { var testCases = []struct { + expectRoute any name string whenURL string - expectRoute interface{} expectCode int }{ { @@ -161,10 +344,10 @@ func TestGroup_RouteNotFound(t *testing.T) { e := New() g := e.Group("/group") - okHandler := func(c Context) error { + okHandler := func(c *Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } - notFoundHandler := func(c Context) error { + notFoundHandler := func(c *Context) error { return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) } @@ -188,44 +371,396 @@ func TestGroup_RouteNotFound(t *testing.T) { } } +func TestGroup_Any(t *testing.T) { + e := New() + + users := e.Group("/users") + ri := users.Any("/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK from ANY") + }) + + assert.Equal(t, RouteAny, ri.Method) + assert.Equal(t, "/users/activate", ri.Path) + assert.Equal(t, RouteAny+":/users/activate", ri.Name) + assert.Nil(t, ri.Parameters) + + status, body := request(http.MethodTrace, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK from ANY`, body) +} + +func TestGroup_Match(t *testing.T) { + e := New() + + myMethods := []string{http.MethodGet, http.MethodPost} + users := e.Group("/users") + ris := users.Match(myMethods, "/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + assert.Len(t, ris, 2) + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + assert.Equal(t, http.StatusTeapot, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_MatchWithErrors(t *testing.T) { + e := New() + + users := e.Group("/users") + users.GET("/activate", func(c *Context) error { + return c.String(http.StatusOK, "OK") + }) + myMethods := []string{http.MethodGet, http.MethodPost} + + errs := func() (errs []error) { + defer func() { + if r := recover(); r != nil { + if tmpErr, ok := r.([]error); ok { + errs = tmpErr + return + } + panic(r) + } + }() + + users.Match(myMethods, "/activate", func(c *Context) error { + return c.String(http.StatusTeapot, "OK") + }) + return nil + }() + assert.Len(t, errs, 1) + assert.EqualError(t, errs[0], "GET /users/activate: adding duplicate route (same method+path) is not allowed") + + for _, m := range myMethods { + status, body := request(m, "/users/activate", e) + + expect := http.StatusTeapot + if m == http.MethodGet { + expect = http.StatusOK + } + assert.Equal(t, expect, status) + assert.Equal(t, `OK`, body) + } +} + +func TestGroup_Static(t *testing.T) { + e := New() + + g := e.Group("/books") + ri := g.Static("/download", "_fixture") + assert.Equal(t, http.MethodGet, ri.Method) + assert.Equal(t, "/books/download*", ri.Path) + assert.Equal(t, "GET:/books/download*", ri.Name) + assert.Equal(t, []string{"*"}, ri.Parameters) + + req := httptest.NewRequest(http.MethodGet, "/books/download/index.html", nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + body := rec.Body.String() + assert.True(t, strings.HasPrefix(body, "")) +} + +func TestGroup_StaticMultiTest(t *testing.T) { + var testCases = []struct { + name string + givenPrefix string + givenRoot string + whenURL string + expectHeaderLocation string + expectBodyStartsWith string + expectStatus int + }{ + { + name: "ok", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/walle.png", + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "ok, without prefix", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/testwalle.png", // `/test` + `*` creates route `/test*` witch matches `/testwalle.png` + expectStatus: http.StatusOK, + expectBodyStartsWith: string([]byte{0x89, 0x50, 0x4e, 0x47}), + }, + { + name: "nok, without prefix does not serve dir index", + givenPrefix: "", + givenRoot: "_fixture/images", + whenURL: "/test/", // `/test` + `*` creates route `/test*` + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "No file", + givenPrefix: "/images", + givenRoot: "_fixture/scripts", + whenURL: "/test/images/bolt.png", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory", + givenPrefix: "/images", + givenRoot: "_fixture/images", + whenURL: "/test/images/", + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Directory Redirect", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory Redirect with non-root path", + givenPrefix: "/static", + givenRoot: "_fixture", + whenURL: "/test/static", + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/static/", + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory 404 (request URL without slash)", + givenPrefix: "/folder/", // trailing slash will intentionally not match "/folder" + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "Prefixed directory redirect (without slash redirect to slash)", + givenPrefix: "/folder", // no trailing slash shall match /folder and /folder/* + givenRoot: "_fixture", + whenURL: "/test/folder", // no trailing slash + expectStatus: http.StatusMovedPermanently, + expectHeaderLocation: "/test/folder/", + expectBodyStartsWith: "", + }, + { + name: "Directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending with slash)", + givenPrefix: "/assets/", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Prefixed directory with index.html (prefix ending without slash)", + givenPrefix: "/assets", + givenRoot: "_fixture", + whenURL: "/test/assets/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "Sub-directory with index.html", + givenPrefix: "/", + givenRoot: "_fixture", + whenURL: "/test/folder/", + expectStatus: http.StatusOK, + expectBodyStartsWith: "", + }, + { + name: "do not allow directory traversal (backslash - windows separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/..\\middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + { + name: "do not allow directory traversal (slash - unix separator)", + givenPrefix: "/", + givenRoot: "_fixture/", + whenURL: `/test/../middleware/basic_auth.go`, + expectStatus: http.StatusNotFound, + expectBodyStartsWith: "{\"message\":\"Not Found\"}\n", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + + g := e.Group("/test") + g.Static(tc.givenPrefix, tc.givenRoot) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectStatus, rec.Code) + body := rec.Body.String() + if tc.expectBodyStartsWith != "" { + assert.True(t, strings.HasPrefix(body, tc.expectBodyStartsWith)) + } else { + assert.Equal(t, "", body) + } + + if tc.expectHeaderLocation != "" { + assert.Equal(t, tc.expectHeaderLocation, rec.Result().Header["Location"][0]) + } else { + _, ok := rec.Result().Header["Location"] + assert.False(t, ok) + } + }) + } +} + +func TestGroup_FileFS(t *testing.T) { + var testCases = []struct { + whenFS fs.FS + name string + whenPath string + whenFile string + givenURL string + expectStartsWith []byte + expectCode int + }{ + { + name: "ok", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle", + expectCode: http.StatusOK, + expectStartsWith: []byte{0x89, 0x50, 0x4e}, + }, + { + name: "nok, requesting invalid path", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "walle.png", + givenURL: "/assets/walle.png", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + { + name: "nok, serving not existent file from filesystem", + whenPath: "/walle", + whenFS: os.DirFS("_fixture/images"), + whenFile: "not-existent.png", + givenURL: "/assets/walle", + expectCode: http.StatusNotFound, + expectStartsWith: []byte(`{"message":"Not Found"}`), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + g := e.Group("/assets") + g.FileFS(tc.whenPath, tc.whenFile, tc.whenFS) + + req := httptest.NewRequest(http.MethodGet, tc.givenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectCode, rec.Code) + + body := rec.Body.Bytes() + if len(body) > len(tc.expectStartsWith) { + body = body[:len(tc.expectStartsWith)] + } + assert.Equal(t, tc.expectStartsWith, body) + }) + } +} + +func TestGroup_StaticPanic(t *testing.T) { + var testCases = []struct { + name string + givenRoot string + }{ + { + name: "panics for ../", + givenRoot: "../images", + }, + { + name: "panics for /", + givenRoot: "/images", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := New() + e.Filesystem = os.DirFS("./") + + g := e.Group("/assets") + + assert.Panics(t, func() { + g.Static("/images", tc.givenRoot) + }) + }) + } +} + func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { var testCases = []struct { - name string - givenCustom404 bool - whenURL string - expectBody interface{} - expectCode int + expectBody any + name string + whenURL string + expectCode int + givenCustom404 bool + expectMiddlewareCalled bool }{ { - name: "ok, custom 404 handler is called with middleware", - givenCustom404: true, - whenURL: "/group/test3", - expectBody: "GET /group/*", - expectCode: http.StatusNotFound, + name: "ok, custom 404 handler is called with middleware", + givenCustom404: true, + whenURL: "/group/test3", + expectBody: "404 GET /group/*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: true, // because RouteNotFound is added after middleware is added }, { - name: "ok, default group 404 handler is called with middleware", - givenCustom404: false, - whenURL: "/group/test3", - expectBody: "{\"message\":\"Not Found\"}\n", - expectCode: http.StatusNotFound, + name: "ok, default group 404 handler is not called with middleware", + givenCustom404: false, + whenURL: "/group/test3", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added }, { - name: "ok, (no slash) default group 404 handler is called with middleware", - givenCustom404: false, - whenURL: "/group", - expectBody: "{\"message\":\"Not Found\"}\n", - expectCode: http.StatusNotFound, + name: "ok, (no slash) default group 404 handler is called with middleware", + givenCustom404: false, + whenURL: "/group", + expectBody: "404 GET /*", + expectCode: http.StatusNotFound, + expectMiddlewareCalled: false, // because RouteNotFound is added before middleware is added }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - okHandler := func(c Context) error { + okHandler := func(c *Context) error { return c.String(http.StatusOK, c.Request().Method+" "+c.Path()) } - notFoundHandler := func(c Context) error { - return c.String(http.StatusNotFound, c.Request().Method+" "+c.Path()) + notFoundHandler := func(c *Context) error { + return c.String(http.StatusNotFound, "404 "+c.Request().Method+" "+c.Path()) } e := New() @@ -237,7 +772,7 @@ func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { middlewareCalled := false g.Use(func(next HandlerFunc) HandlerFunc { - return func(c Context) error { + return func(c *Context) error { middlewareCalled = true return next(c) } @@ -251,7 +786,7 @@ func TestGroup_RouteNotFoundWithMiddleware(t *testing.T) { e.ServeHTTP(rec, req) - assert.True(t, middlewareCalled) + assert.Equal(t, tc.expectMiddlewareCalled, middlewareCalled) assert.Equal(t, tc.expectCode, rec.Code) assert.Equal(t, tc.expectBody, rec.Body.String()) }) diff --git a/httperror.go b/httperror.go new file mode 100644 index 000000000..682cce2a0 --- /dev/null +++ b/httperror.go @@ -0,0 +1,107 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "errors" + "fmt" + "net/http" +) + +// HTTPStatusCoder is interface that errors can implement to produce status code for HTTP response +type HTTPStatusCoder interface { + StatusCode() int +} + +// Following errors can produce HTTP status code by implementing HTTPStatusCoder interface +var ( + ErrBadRequest = &httpError{http.StatusBadRequest} // 400 + ErrUnauthorized = &httpError{http.StatusUnauthorized} // 401 + ErrForbidden = &httpError{http.StatusForbidden} // 403 + ErrNotFound = &httpError{http.StatusNotFound} // 404 + ErrMethodNotAllowed = &httpError{http.StatusMethodNotAllowed} // 405 + ErrRequestTimeout = &httpError{http.StatusRequestTimeout} // 408 + ErrStatusRequestEntityTooLarge = &httpError{http.StatusRequestEntityTooLarge} // 413 + ErrUnsupportedMediaType = &httpError{http.StatusUnsupportedMediaType} // 415 + ErrTooManyRequests = &httpError{http.StatusTooManyRequests} // 429 + ErrInternalServerError = &httpError{http.StatusInternalServerError} // 500 + ErrBadGateway = &httpError{http.StatusBadGateway} // 502 + ErrServiceUnavailable = &httpError{http.StatusServiceUnavailable} // 503 +) + +// Following errors fall into 500 (InternalServerError) category +var ( + ErrValidatorNotRegistered = errors.New("validator not registered") + ErrRendererNotRegistered = errors.New("renderer not registered") + ErrInvalidRedirectCode = errors.New("invalid redirect status code") + ErrCookieNotFound = errors.New("cookie not found") + ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte") + ErrInvalidListenerNetwork = errors.New("invalid listener network") +) + +// NewHTTPError creates new instance of HTTPError +func NewHTTPError(code int, message string) *HTTPError { + return &HTTPError{ + Code: code, + Message: message, + } +} + +// HTTPError represents an error that occurred while handling a request. +type HTTPError struct { + // Code is status code for HTTP response + Code int `json:"-"` + Message string `json:"message"` + err error +} + +// StatusCode returns status code for HTTP response +func (he *HTTPError) StatusCode() int { + return he.Code +} + +// Error makes it compatible with `error` interface. +func (he *HTTPError) Error() string { + msg := he.Message + if msg == "" { + msg = http.StatusText(he.Code) + } + if he.err == nil { + return fmt.Sprintf("code=%d, message=%v", he.Code, msg) + } + return fmt.Sprintf("code=%d, message=%v, err=%v", he.Code, msg, he.err.Error()) +} + +// Wrap eturns new HTTPError with given errors wrapped inside +func (he HTTPError) Wrap(err error) error { + return &HTTPError{ + Code: he.Code, + Message: he.Message, + err: err, + } +} + +func (he *HTTPError) Unwrap() error { + return he.err +} + +type httpError struct { + code int +} + +func (he httpError) StatusCode() int { + return he.code +} + +func (he httpError) Error() string { + return http.StatusText(he.code) // does not include status code +} + +func (he httpError) Wrap(err error) error { + return &HTTPError{ + Code: he.code, + Message: http.StatusText(he.code), + err: err, + } +} diff --git a/httperror_external_test.go b/httperror_external_test.go new file mode 100644 index 000000000..91acdca25 --- /dev/null +++ b/httperror_external_test.go @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +// run tests as external package to get real feel for API +package echo_test + +import ( + "encoding/json" + "fmt" + "github.com/labstack/echo/v5" + "net/http" + "net/http/httptest" +) + +func ExampleDefaultHTTPErrorHandler() { + e := echo.New() + e.GET("/api/endpoint", func(c *echo.Context) error { + return &apiError{ + Code: http.StatusBadRequest, + Body: map[string]any{"message": "custom error"}, + } + }) + + req := httptest.NewRequest(http.MethodGet, "/api/endpoint?err=1", nil) + resp := httptest.NewRecorder() + + e.ServeHTTP(resp, req) + + fmt.Printf("%d %s", resp.Code, resp.Body.String()) + + // Output: 400 {"error":{"message":"custom error"}} +} + +type apiError struct { + Code int + Body any +} + +func (e *apiError) StatusCode() int { + return e.Code +} + +func (e *apiError) MarshalJSON() ([]byte, error) { + type body struct { + Error any `json:"error"` + } + return json.Marshal(body{Error: e.Body}) +} + +func (e *apiError) Error() string { + return http.StatusText(e.Code) +} diff --git a/httperror_test.go b/httperror_test.go new file mode 100644 index 000000000..9ae88abcb --- /dev/null +++ b/httperror_test.go @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors + +package echo + +import ( + "errors" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestHTTPError_StatusCode(t *testing.T) { + var err error = &HTTPError{Code: http.StatusBadRequest, Message: "my error message"} + + code := 0 + var sc HTTPStatusCoder + if errors.As(err, &sc) { + code = sc.StatusCode() + } + assert.Equal(t, http.StatusBadRequest, code) +} + +func TestHTTPError_Error(t *testing.T) { + var testCases = []struct { + name string + error error + expect string + }{ + { + name: "ok, without message", + error: &HTTPError{Code: http.StatusBadRequest}, + expect: "code=400, message=Bad Request", + }, + { + name: "ok, with message", + error: &HTTPError{Code: http.StatusBadRequest, Message: "my error message"}, + expect: "code=400, message=my error message", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expect, tc.error.Error()) + }) + } +} + +func TestHTTPError_WrapUnwrap(t *testing.T) { + err := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} + wrapped := err.Wrap(errors.New("my_error")).(*HTTPError) + + err.Code = http.StatusOK + err.Message = "changed" + + assert.Equal(t, http.StatusBadRequest, wrapped.Code) + assert.Equal(t, "bad", wrapped.Message) + + assert.Equal(t, errors.New("my_error"), wrapped.Unwrap()) + assert.Equal(t, "code=400, message=bad, err=my_error", wrapped.Error()) +} + +func TestNewHTTPError(t *testing.T) { + err := NewHTTPError(http.StatusBadRequest, "bad") + err2 := &HTTPError{Code: http.StatusBadRequest, Message: "bad"} + + assert.Equal(t, err2, err) +} diff --git a/ip.go b/ip.go index dce51f55d..e2b287bfd 100644 --- a/ip.go +++ b/ip.go @@ -224,21 +224,15 @@ func extractIP(req *http.Request) string { func ExtractIPFromRealIPHeader(options ...TrustOption) IPExtractor { checker := newIPChecker(options) return func(req *http.Request) string { - directIP := extractIP(req) realIP := req.Header.Get(HeaderXRealIP) - if realIP == "" { - return directIP - } - - if checker.trust(net.ParseIP(directIP)) { + if realIP != "" { realIP = strings.TrimPrefix(realIP, "[") realIP = strings.TrimSuffix(realIP, "]") - if rIP := net.ParseIP(realIP); rIP != nil { + if ip := net.ParseIP(realIP); ip != nil && checker.trust(ip) { return realIP } } - - return directIP + return extractIP(req) } } diff --git a/ip_test.go b/ip_test.go index e850b78cb..29bf6afde 100644 --- a/ip_test.go +++ b/ip_test.go @@ -22,8 +22,8 @@ func mustParseCIDR(s string) *net.IPNet { func TestIPChecker_TrustOption(t *testing.T) { var testCases = []struct { name string - givenOptions []TrustOption whenIP string + givenOptions []TrustOption expect bool }{ { @@ -490,14 +490,14 @@ func TestExtractIPDirect(t *testing.T) { } func TestExtractIPFromRealIPHeader(t *testing.T) { - _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.0/24") + _, ipForRemoteAddrExternalRange, _ := net.ParseCIDR("203.0.113.199/24") _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { - name string - givenTrustOptions []TrustOption whenRequest http.Request + name string expectIP string + givenTrustOptions []TrustOption }{ { name: "request has no headers, extracts IP from request remote addr", @@ -518,42 +518,36 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", - givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" - TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" - }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXRealIP: []string{"203.0.113.199"}, // <-- this is untrusted }, - RemoteAddr: "8.8.8.8:8080", // <-- this is untrusted + RemoteAddr: "203.0.113.1:8080", }, - expectIP: "8.8.8.8", + expectIP: "203.0.113.1", }, { name: "request is from external IP has valid + UNTRUSTED external X-Real-Ip header, extract IP from remote addr", - givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" - TrustIPRange(ipv6ForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" - }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[bc01:1010::9090:1888]"}, + HeaderXRealIP: []string{"[2001:db8::113:199]"}, // <-- this is untrusted }, - RemoteAddr: "[fe64:aa10::1]:8080", // <-- this is untrusted + RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "fe64:aa10::1", + expectIP: "2001:db8::113:1", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", givenTrustOptions: []TrustOption{ // case for "trust direct-facing proxy" - TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.0/24" + TrustIPRange(ipForRemoteAddrExternalRange), // we trust external IP range "203.0.113.199/24" }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"8.8.8.8"}, + HeaderXRealIP: []string{"203.0.113.199"}, }, RemoteAddr: "203.0.113.1:8080", }, - expectIP: "8.8.8.8", + expectIP: "203.0.113.199", }, { name: "request is from external IP has valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -562,11 +556,11 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[fe64:db8::113:199]"}, + HeaderXRealIP: []string{"[2001:db8::113:199]"}, }, RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "fe64:db8::113:199", + expectIP: "2001:db8::113:199", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -575,12 +569,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"8.8.8.8"}, - HeaderXForwardedFor: []string{"1.1.1.1 ,8.8.8.8"}, // <-- should not affect anything + HeaderXRealIP: []string{"203.0.113.199"}, + HeaderXForwardedFor: []string{"203.0.113.198, 203.0.113.197"}, // <-- should not affect anything }, RemoteAddr: "203.0.113.1:8080", }, - expectIP: "8.8.8.8", + expectIP: "203.0.113.199", }, { name: "request is from external IP has XFF and valid + TRUSTED X-Real-Ip header, extract IP from X-Real-Ip header", @@ -589,12 +583,12 @@ func TestExtractIPFromRealIPHeader(t *testing.T) { }, whenRequest: http.Request{ Header: http.Header{ - HeaderXRealIP: []string{"[fe64:db8::113:199]"}, - HeaderXForwardedFor: []string{"[feab:cde9::113:198], [fe64:db8::113:199]"}, // <-- should not affect anything + HeaderXRealIP: []string{"[2001:db8::113:199]"}, + HeaderXForwardedFor: []string{"[2001:db8::113:198], [2001:db8::113:197]"}, // <-- should not affect anything }, RemoteAddr: "[2001:db8::113:1]:8080", }, - expectIP: "fe64:db8::113:199", + expectIP: "2001:db8::113:199", }, } @@ -611,10 +605,10 @@ func TestExtractIPFromXFFHeader(t *testing.T) { _, ipv6ForRemoteAddrExternalRange, _ := net.ParseCIDR("2001:db8::/64") var testCases = []struct { - name string - givenTrustOptions []TrustOption whenRequest http.Request + name string expectIP string + givenTrustOptions []TrustOption }{ { name: "request has no headers, extracts IP from request remote addr", diff --git a/json.go b/json.go index 6da0aaf97..a969ccb8c 100644 --- a/json.go +++ b/json.go @@ -5,8 +5,6 @@ package echo import ( "encoding/json" - "fmt" - "net/http" ) // DefaultJSONSerializer implements JSON encoding using encoding/json. @@ -14,21 +12,18 @@ type DefaultJSONSerializer struct{} // Serialize converts an interface into a json and writes it to the response. // You can optionally use the indent parameter to produce pretty JSONs. -func (d DefaultJSONSerializer) Serialize(c Context, i interface{}, indent string) error { +func (d DefaultJSONSerializer) Serialize(c *Context, target any, indent string) error { enc := json.NewEncoder(c.Response()) if indent != "" { enc.SetIndent("", indent) } - return enc.Encode(i) + return enc.Encode(target) } // Deserialize reads a JSON from a request body and converts it into an interface. -func (d DefaultJSONSerializer) Deserialize(c Context, i interface{}) error { - err := json.NewDecoder(c.Request().Body).Decode(i) - if ute, ok := err.(*json.UnmarshalTypeError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Unmarshal type error: expected=%v, got=%v, field=%v, offset=%v", ute.Type, ute.Value, ute.Field, ute.Offset)).SetInternal(err) - } else if se, ok := err.(*json.SyntaxError); ok { - return NewHTTPError(http.StatusBadRequest, fmt.Sprintf("Syntax error: offset=%v, error=%v", se.Offset, se.Error())).SetInternal(err) +func (d DefaultJSONSerializer) Deserialize(c *Context, target any) error { + if err := json.NewDecoder(c.Request().Body).Decode(target); err != nil { + return ErrBadRequest.Wrap(err) } - return err + return nil } diff --git a/json_test.go b/json_test.go index 0b15ed1a1..1804b3e82 100644 --- a/json_test.go +++ b/json_test.go @@ -17,7 +17,7 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", nil) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) // Echo assert.Equal(t, e, c.Echo()) @@ -34,15 +34,15 @@ func TestDefaultJSONCodec_Encode(t *testing.T) { enc := new(DefaultJSONSerializer) - err := enc.Serialize(c, user{1, "Jon Snow"}, "") + err := enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, "") if assert.NoError(t, err) { assert.Equal(t, userJSON+"\n", rec.Body.String()) } req = httptest.NewRequest(http.MethodPost, "/", nil) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) - err = enc.Serialize(c, user{1, "Jon Snow"}, " ") + c = e.NewContext(req, rec) + err = enc.Serialize(c, user{ID: 1, Name: "Jon Snow"}, " ") if assert.NoError(t, err) { assert.Equal(t, userJSONPretty+"\n", rec.Body.String()) } @@ -54,7 +54,7 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { e := New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec := httptest.NewRecorder() - c := e.NewContext(req, rec).(*context) + c := e.NewContext(req, rec) // Echo assert.Equal(t, e, c.Echo()) @@ -80,10 +80,10 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { var userUnmarshalSyntaxError = user{} req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(invalidContent)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec) err = enc.Deserialize(c, &userUnmarshalSyntaxError) assert.IsType(t, &HTTPError{}, err) - assert.EqualError(t, err, "code=400, message=Syntax error: offset=1, error=invalid character 'i' looking for beginning of value, internal=invalid character 'i' looking for beginning of value") + assert.EqualError(t, err, "code=400, message=Bad Request, err=invalid character 'i' looking for beginning of value") var userUnmarshalTypeError = struct { ID string `json:"id"` @@ -92,9 +92,9 @@ func TestDefaultJSONCodec_Decode(t *testing.T) { req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON)) rec = httptest.NewRecorder() - c = e.NewContext(req, rec).(*context) + c = e.NewContext(req, rec) err = enc.Deserialize(c, &userUnmarshalTypeError) assert.IsType(t, &HTTPError{}, err) - assert.EqualError(t, err, "code=400, message=Unmarshal type error: expected=string, got=number, field=id, offset=7, internal=json: cannot unmarshal number into Go struct field .id of type string") + assert.EqualError(t, err, "code=400, message=Bad Request, err=json: cannot unmarshal number into Go struct field .id of type string") } diff --git a/log.go b/log.go deleted file mode 100644 index 0acd9ff03..000000000 --- a/log.go +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package echo - -import ( - "github.com/labstack/gommon/log" - "io" -) - -// Logger defines the logging interface. -type Logger interface { - Output() io.Writer - SetOutput(w io.Writer) - Prefix() string - SetPrefix(p string) - Level() log.Lvl - SetLevel(v log.Lvl) - SetHeader(h string) - Print(i ...interface{}) - Printf(format string, args ...interface{}) - Printj(j log.JSON) - Debug(i ...interface{}) - Debugf(format string, args ...interface{}) - Debugj(j log.JSON) - Info(i ...interface{}) - Infof(format string, args ...interface{}) - Infoj(j log.JSON) - Warn(i ...interface{}) - Warnf(format string, args ...interface{}) - Warnj(j log.JSON) - Error(i ...interface{}) - Errorf(format string, args ...interface{}) - Errorj(j log.JSON) - Fatal(i ...interface{}) - Fatalj(j log.JSON) - Fatalf(format string, args ...interface{}) - Panic(i ...interface{}) - Panicj(j log.JSON) - Panicf(format string, args ...interface{}) -} diff --git a/middleware/DEVELOPMENT.md b/middleware/DEVELOPMENT.md new file mode 100644 index 000000000..77cb226dd --- /dev/null +++ b/middleware/DEVELOPMENT.md @@ -0,0 +1,11 @@ +# Development Guidelines for middlewares + +## Best practices: + +* Do not use `panic` in middleware creator functions in case of invalid configuration. +* In case of an error in middleware function handling request avoid using `c.Error()` and returning no error instead + because previous middlewares up in call chain could have logic for dealing with returned errors. +* Create middleware configuration structs that implement `MiddlewareConfigurator` interface so can decide if they + want to create middleware with panics or with returning errors on configuration errors. +* When adding `echo.Context` to function type or fields make it first parameter so all functions with Context looks same. + diff --git a/middleware/basic_auth.go b/middleware/basic_auth.go index 4a46098e3..e0a284c67 100644 --- a/middleware/basic_auth.go +++ b/middleware/basic_auth.go @@ -4,105 +4,153 @@ package middleware import ( + "bytes" + "cmp" "encoding/base64" - "net/http" + "errors" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -// BasicAuthConfig defines the config for BasicAuth middleware. +// BasicAuthConfig defines the config for BasicAuthWithConfig middleware. +// +// SECURITY: The Validator function is responsible for securely comparing credentials. +// See BasicAuthValidator documentation for guidance on preventing timing attacks. type BasicAuthConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // Validator is a function to validate BasicAuth credentials. + // Validator is a function to validate BasicAuthWithConfig credentials. Note: if request contains multiple basic auth headers + // this function would be called once for each header until first valid result is returned // Required. Validator BasicAuthValidator - // Realm is a string to define realm attribute of BasicAuth. + // Realm is a string to define realm attribute of BasicAuthWithConfig. // Default value "Restricted". Realm string + + // AllowedCheckLimit set how many headers are allowed to be checked. This is useful + // environments like corporate test environments with application proxies restricting + // access to environment with their own auth scheme. + // Defaults to 1. + AllowedCheckLimit uint } -// BasicAuthValidator defines a function to validate BasicAuth credentials. -// The function should return a boolean indicating whether the credentials are valid, -// and an error if any error occurs during the validation process. -type BasicAuthValidator func(string, string, echo.Context) (bool, error) +// BasicAuthValidator defines a function to validate BasicAuthWithConfig credentials. +// +// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate +// valid usernames or passwords, validator implementations MUST use constant-time +// comparison for credential checking. Use crypto/subtle.ConstantTimeCompare instead +// of standard string equality (==) or switch statements. +// +// Example of SECURE implementation: +// +// import "crypto/subtle" +// +// validator := func(c *echo.Context, username, password string) (bool, error) { +// // Fetch expected credentials from database/config +// expectedUser := "admin" +// expectedPass := "secretpassword" +// +// // Use constant-time comparison to prevent timing attacks +// userMatch := subtle.ConstantTimeCompare([]byte(username), []byte(expectedUser)) == 1 +// passMatch := subtle.ConstantTimeCompare([]byte(password), []byte(expectedPass)) == 1 +// +// if userMatch && passMatch { +// return true, nil +// } +// return false, nil +// } +// +// Example of INSECURE implementation (DO NOT USE): +// +// // VULNERABLE TO TIMING ATTACKS - DO NOT USE +// validator := func(c *echo.Context, username, password string) (bool, error) { +// if username == "admin" && password == "secret" { // Timing leak! +// return true, nil +// } +// return false, nil +// } +type BasicAuthValidator func(c *echo.Context, user string, password string) (bool, error) const ( basic = "basic" defaultRealm = "Restricted" ) -// DefaultBasicAuthConfig is the default BasicAuth middleware config. -var DefaultBasicAuthConfig = BasicAuthConfig{ - Skipper: DefaultSkipper, - Realm: defaultRealm, -} - // BasicAuth returns an BasicAuth middleware. // // For valid credentials it calls the next handler. // For missing or invalid credentials, it sends "401 - Unauthorized" response. func BasicAuth(fn BasicAuthValidator) echo.MiddlewareFunc { - c := DefaultBasicAuthConfig - c.Validator = fn - return BasicAuthWithConfig(c) + return BasicAuthWithConfig(BasicAuthConfig{Validator: fn}) } -// BasicAuthWithConfig returns an BasicAuth middleware with config. -// See `BasicAuth()`. +// BasicAuthWithConfig returns an BasicAuthWithConfig middleware with config. func BasicAuthWithConfig(config BasicAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BasicAuthConfig to middleware or returns an error for invalid configuration +func (config BasicAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Validator == nil { - panic("echo: basic-auth middleware requires a validator function") + return nil, errors.New("echo basic-auth middleware requires a validator function") } if config.Skipper == nil { - config.Skipper = DefaultBasicAuthConfig.Skipper + config.Skipper = DefaultSkipper } - if config.Realm == "" { - config.Realm = defaultRealm + realm := defaultRealm + if config.Realm != "" { + realm = config.Realm } - - // Pre-compute the quoted realm for WWW-Authenticate header (RFC 7617) - quotedRealm := strconv.Quote(config.Realm) + realm = strconv.Quote(realm) + limit := cmp.Or(config.AllowedCheckLimit, 1) return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - auth := c.Request().Header.Get(echo.HeaderAuthorization) + var lastError error l := len(basic) + i := uint(0) + for _, auth := range c.Request().Header[echo.HeaderAuthorization] { + if i >= limit { + break + } + if !(len(auth) > l+1 && strings.EqualFold(auth[:l], basic)) { + continue + } + i++ - if len(auth) > l+1 && strings.EqualFold(auth[:l], basic) { // Invalid base64 shouldn't be treated as error // instead should be treated as invalid client input - b, err := base64.StdEncoding.DecodeString(auth[l+1:]) - if err != nil { - return echo.NewHTTPError(http.StatusBadRequest).SetInternal(err) + b, errDecode := base64.StdEncoding.DecodeString(auth[l+1:]) + if errDecode != nil { + lastError = echo.ErrBadRequest.Wrap(errDecode) + continue } - - cred := string(b) - user, pass, ok := strings.Cut(cred, ":") - if ok { - // Verify credentials - valid, err := config.Validator(user, pass, c) - if err != nil { - return err + idx := bytes.IndexByte(b, ':') + if idx >= 0 { + valid, errValidate := config.Validator(c, string(b[:idx]), string(b[idx+1:])) + if errValidate != nil { + lastError = errValidate } else if valid { return next(c) } } } + if lastError != nil { + return lastError + } + // Need to return `401` for browsers to pop-up login box. - // Realm is case-insensitive, so we can use "basic" directly. See RFC 7617. - c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+quotedRealm) + c.Response().Header().Set(echo.HeaderWWWAuthenticate, basic+" realm="+realm) return echo.ErrUnauthorized } - } + }, nil } diff --git a/middleware/basic_auth_test.go b/middleware/basic_auth_test.go index 2d3192615..42386354f 100644 --- a/middleware/basic_auth_test.go +++ b/middleware/basic_auth_test.go @@ -4,6 +4,7 @@ package middleware import ( + "crypto/subtle" "encoding/base64" "errors" "net/http" @@ -11,116 +12,177 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestBasicAuth(t *testing.T) { - e := echo.New() + validatorFunc := func(c *echo.Context, u, p string) (bool, error) { + // Use constant-time comparison to prevent timing attacks + userMatch := subtle.ConstantTimeCompare([]byte(u), []byte("joe")) == 1 + passMatch := subtle.ConstantTimeCompare([]byte(p), []byte("secret")) == 1 - mockValidator := func(u, p string, c echo.Context) (bool, error) { - if u == "joe" && p == "secret" { + if userMatch && passMatch { return true, nil } + + // Special case for testing error handling + if u == "error" { + return false, errors.New(p) + } + return false, nil } + defaultConfig := BasicAuthConfig{Validator: validatorFunc} - tests := []struct { - name string - authHeader string - expectedCode int - expectedAuth string - skipperResult bool - expectedErr bool - expectedErrMsg string + var testCases = []struct { + name string + givenConfig BasicAuthConfig + whenAuth []string + expectHeader string + expectErr string }{ { - name: "Valid credentials", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), - expectedCode: http.StatusOK, + name: "ok", + givenConfig: defaultConfig, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "ok, multiple", + givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 2}, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " NOT_BASE64", + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, }, { - name: "Case-insensitive header scheme", - authHeader: strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), - expectedCode: http.StatusOK, + name: "nok, multiple, valid out of limit", + givenConfig: BasicAuthConfig{Validator: validatorFunc, AllowedCheckLimit: 1}, + whenAuth: []string{ + "Bearer " + base64.StdEncoding.EncodeToString([]byte("token")), + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid_password")), + // limit only check first and should not check auth below + basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret")), + }, + expectHeader: basic + ` realm="Restricted"`, + expectErr: "Unauthorized", }, { - name: "Invalid credentials", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:invalid-password")), - expectedCode: http.StatusUnauthorized, - expectedAuth: basic + ` realm="someRealm"`, - expectedErr: true, - expectedErrMsg: "Unauthorized", + name: "nok, invalid Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="Restricted"`, + expectErr: "Unauthorized", }, { - name: "Invalid base64 string", - authHeader: basic + " invalidString", - expectedCode: http.StatusBadRequest, - expectedErr: true, - expectedErrMsg: "Bad Request", + name: "nok, not base64 Authorization header", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " NOT_BASE64"}, + expectErr: "code=400, message=Bad Request, err=illegal base64 data at input byte 3", }, { - name: "Missing Authorization header", - expectedCode: http.StatusUnauthorized, - expectedErr: true, - expectedErrMsg: "Unauthorized", + name: "nok, missing Authorization header", + givenConfig: defaultConfig, + expectHeader: basic + ` realm="Restricted"`, + expectErr: "Unauthorized", }, { - name: "Invalid Authorization header", - authHeader: base64.StdEncoding.EncodeToString([]byte("invalid")), - expectedCode: http.StatusUnauthorized, - expectedErr: true, - expectedErrMsg: "Unauthorized", + name: "ok, realm", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, }, { - name: "Skipped Request", - authHeader: basic + " " + base64.StdEncoding.EncodeToString([]byte("joe:skip")), - expectedCode: http.StatusOK, - skipperResult: true, + name: "ok, realm, case-insensitive header scheme", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("joe:secret"))}, + }, + { + name: "nok, realm, invalid Authorization header", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Realm: "someRealm"}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, + expectHeader: basic + ` realm="someRealm"`, + expectErr: "Unauthorized", + }, + { + name: "nok, validator func returns an error", + givenConfig: defaultConfig, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("error:my_error"))}, + expectErr: "my_error", + }, + { + name: "ok, skipped", + givenConfig: BasicAuthConfig{Validator: validatorFunc, Skipper: func(c *echo.Context) bool { + return true + }}, + whenAuth: []string{strings.ToUpper(basic) + " " + base64.StdEncoding.EncodeToString([]byte("invalid"))}, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) res := httptest.NewRecorder() c := e.NewContext(req, res) - if tt.authHeader != "" { - req.Header.Set(echo.HeaderAuthorization, tt.authHeader) - } + config := tc.givenConfig - h := BasicAuthWithConfig(BasicAuthConfig{ - Validator: mockValidator, - Realm: "someRealm", - Skipper: func(c echo.Context) bool { - return tt.skipperResult - }, - })(func(c echo.Context) error { - return c.String(http.StatusOK, "test") - }) + mw, err := config.ToMiddleware() + assert.NoError(t, err) - err := h(c) + h := mw(func(c *echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) - if tt.expectedErr { - var he *echo.HTTPError - errors.As(err, &he) - assert.Equal(t, tt.expectedCode, he.Code) - if tt.expectedAuth != "" { - assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) + if len(tc.whenAuth) != 0 { + for _, a := range tc.whenAuth { + req.Header.Add(echo.HeaderAuthorization, a) } + } + err = h(c) + + if tc.expectErr != "" { + assert.Equal(t, http.StatusOK, res.Code) + assert.EqualError(t, err, tc.expectErr) } else { + assert.Equal(t, http.StatusTeapot, res.Code) assert.NoError(t, err) - assert.Equal(t, tt.expectedCode, res.Code) + } + if tc.expectHeader != "" { + assert.Equal(t, tc.expectHeader, res.Header().Get(echo.HeaderWWWAuthenticate)) } }) } } +func TestBasicAuth_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuth(nil) + assert.NotNil(t, mw) + }) + + mw := BasicAuth(func(c *echo.Context, user string, password string) (bool, error) { + return true, nil + }) + assert.NotNil(t, mw) +} + +func TestBasicAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: nil}) + assert.NotNil(t, mw) + }) + + mw := BasicAuthWithConfig(BasicAuthConfig{Validator: func(c *echo.Context, user string, password string) (bool, error) { + return true, nil + }}) + assert.NotNil(t, mw) +} + func TestBasicAuthRealm(t *testing.T) { e := echo.New() - mockValidator := func(u, p string, c echo.Context) (bool, error) { + mockValidator := func(c *echo.Context, u, p string) (bool, error) { return false, nil // Always fail to trigger WWW-Authenticate header } @@ -165,15 +227,13 @@ func TestBasicAuthRealm(t *testing.T) { h := BasicAuthWithConfig(BasicAuthConfig{ Validator: mockValidator, Realm: tt.realm, - })(func(c echo.Context) error { + })(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) err := h(c) - var he *echo.HTTPError - errors.As(err, &he) - assert.Equal(t, http.StatusUnauthorized, he.Code) + assert.Equal(t, echo.ErrUnauthorized, err) assert.Equal(t, tt.expectedAuth, res.Header().Get(echo.HeaderWWWAuthenticate)) }) } diff --git a/middleware/body_dump.go b/middleware/body_dump.go index add778d67..d5c823c9b 100644 --- a/middleware/body_dump.go +++ b/middleware/body_dump.go @@ -10,8 +10,9 @@ import ( "io" "net" "net/http" + "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // BodyDumpConfig defines the config for BodyDump middleware. @@ -19,78 +20,127 @@ type BodyDumpConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // Handler receives request and response payload. + // Handler receives request, response payloads and handler error if there are any. // Required. Handler BodyDumpHandler + + // MaxRequestBytes limits how much of the request body to dump. + // If the request body exceeds this limit, only the first MaxRequestBytes + // are dumped. The handler callback receives truncated data. + // Default: 5 * MB (5,242,880 bytes) + // Set to -1 to disable limits (not recommended in production). + MaxRequestBytes int64 + + // MaxResponseBytes limits how much of the response body to dump. + // If the response body exceeds this limit, only the first MaxResponseBytes + // are dumped. The handler callback receives truncated data. + // Default: 5 * MB (5,242,880 bytes) + // Set to -1 to disable limits (not recommended in production). + MaxResponseBytes int64 } // BodyDumpHandler receives the request and response payload. -type BodyDumpHandler func(echo.Context, []byte, []byte) +type BodyDumpHandler func(c *echo.Context, reqBody []byte, resBody []byte, err error) type bodyDumpResponseWriter struct { io.Writer http.ResponseWriter } -// DefaultBodyDumpConfig is the default BodyDump middleware config. -var DefaultBodyDumpConfig = BodyDumpConfig{ - Skipper: DefaultSkipper, -} - // BodyDump returns a BodyDump middleware. // // BodyDump middleware captures the request and response payload and calls the // registered handler. +// +// SECURITY: By default, this limits dumped bodies to 5MB to prevent memory exhaustion +// attacks. To customize limits, use BodyDumpWithConfig. To disable limits (not recommended +// in production), explicitly set MaxRequestBytes and MaxResponseBytes to -1. func BodyDump(handler BodyDumpHandler) echo.MiddlewareFunc { - c := DefaultBodyDumpConfig - c.Handler = handler - return BodyDumpWithConfig(c) + return BodyDumpWithConfig(BodyDumpConfig{Handler: handler}) } // BodyDumpWithConfig returns a BodyDump middleware with config. // See: `BodyDump()`. +// +// SECURITY: If MaxRequestBytes and MaxResponseBytes are not set (zero values), they default +// to 5MB each to prevent DoS attacks via large payloads. Set them explicitly to -1 to disable +// limits if needed for your use case. func BodyDumpWithConfig(config BodyDumpConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyDumpConfig to middleware or returns an error for invalid configuration +func (config BodyDumpConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Handler == nil { - panic("echo: body-dump middleware requires a handler function") + return nil, errors.New("echo body-dump middleware requires a handler function") } if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.MaxRequestBytes == 0 { + config.MaxRequestBytes = 5 * MB + } + if config.MaxResponseBytes == 0 { + config.MaxResponseBytes = 5 * MB } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - // Request - reqBody := []byte{} - if c.Request().Body != nil { - var readErr error - reqBody, readErr = io.ReadAll(c.Request().Body) - if readErr != nil { - return readErr - } - } - c.Request().Body = io.NopCloser(bytes.NewBuffer(reqBody)) // Reset + reqBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) + reqBuf.Reset() + defer bodyDumpBufferPool.Put(reqBuf) - // Response - resBody := new(bytes.Buffer) - mw := io.MultiWriter(c.Response().Writer, resBody) - writer := &bodyDumpResponseWriter{Writer: mw, ResponseWriter: c.Response().Writer} - c.Response().Writer = writer + var bodyReader io.Reader = c.Request().Body + if config.MaxRequestBytes > 0 { + bodyReader = io.LimitReader(c.Request().Body, config.MaxRequestBytes) + } + _, readErr := io.Copy(reqBuf, bodyReader) + if readErr != nil && readErr != io.EOF { + return readErr + } + if config.MaxRequestBytes > 0 { + // Drain any remaining body data to prevent connection issues + _, _ = io.Copy(io.Discard, c.Request().Body) + _ = c.Request().Body.Close() + } - if err = next(c); err != nil { - c.Error(err) + reqBody := make([]byte, reqBuf.Len()) + copy(reqBody, reqBuf.Bytes()) + c.Request().Body = io.NopCloser(bytes.NewReader(reqBody)) + + // response part + resBuf := bodyDumpBufferPool.Get().(*bytes.Buffer) + resBuf.Reset() + defer bodyDumpBufferPool.Put(resBuf) + + var respWriter io.Writer + if config.MaxResponseBytes > 0 { + respWriter = &limitedWriter{ + response: c.Response(), + dumpBuf: resBuf, + limit: config.MaxResponseBytes, + } + } else { + respWriter = io.MultiWriter(c.Response(), resBuf) } + writer := &bodyDumpResponseWriter{ + Writer: respWriter, + ResponseWriter: c.Response(), + } + c.SetResponse(writer) + + err := next(c) // Callback - config.Handler(c, reqBody, resBody.Bytes()) + config.Handler(c, reqBody, resBuf.Bytes(), err) - return + return err } - } + }, nil } func (w *bodyDumpResponseWriter) WriteHeader(code int) { @@ -115,3 +165,37 @@ func (w *bodyDumpResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (w *bodyDumpResponseWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter } + +var bodyDumpBufferPool = sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, +} + +type limitedWriter struct { + response http.ResponseWriter + dumpBuf *bytes.Buffer + dumped int64 + limit int64 +} + +func (w *limitedWriter) Write(b []byte) (n int, err error) { + // Always write full data to actual response (don't truncate client response) + n, err = w.response.Write(b) + if err != nil { + return n, err + } + + // Write to dump buffer only up to limit + if w.dumped < w.limit { + remaining := w.limit - w.dumped + toDump := int64(n) + if toDump > remaining { + toDump = remaining + } + w.dumpBuf.Write(b[:toDump]) + w.dumped += toDump + } + + return n, nil +} diff --git a/middleware/body_dump_test.go b/middleware/body_dump_test.go index 7a7dee3d9..f493e75c8 100644 --- a/middleware/body_dump_test.go +++ b/middleware/body_dump_test.go @@ -11,7 +11,7 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -21,7 +21,7 @@ func TestBodyDump(t *testing.T) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { + h := func(c *echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err @@ -31,10 +31,11 @@ func TestBodyDump(t *testing.T) { requestBody := "" responseBody := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { requestBody = string(reqBody) responseBody = string(resBody) - }) + }}.ToMiddleware() + assert.NoError(t, err) if assert.NoError(t, mw(h)(c)) { assert.Equal(t, requestBody, hw) @@ -43,51 +44,76 @@ func TestBodyDump(t *testing.T) { assert.Equal(t, hw, rec.Body.String()) } - // Must set default skipper - BodyDumpWithConfig(BodyDumpConfig{ - Skipper: nil, - Handler: func(c echo.Context, reqBody, resBody []byte) { - requestBody = string(reqBody) - responseBody = string(resBody) +} + +func TestBodyDump_skipper(t *testing.T) { + e := echo.New() + + isCalled := false + mw, err := BodyDumpConfig{ + Skipper: func(c *echo.Context) bool { + return true }, - }) + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + isCalled = true + }, + }.ToMiddleware() + assert.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("{}")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c *echo.Context) error { + return errors.New("some error") + } + + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.False(t, isCalled) } -func TestBodyDumpFails(t *testing.T) { +func TestBodyDump_fails(t *testing.T) { e := echo.New() hw := "Hello, World!" req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { + h := func(c *echo.Context) error { return errors.New("some error") } - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) {}) + mw, err := BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}.ToMiddleware() + assert.NoError(t, err) - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } + err = mw(h)(c) + assert.EqualError(t, err, "some error") + assert.Equal(t, http.StatusOK, rec.Code) +} + +func TestBodyDumpWithConfig_panic(t *testing.T) { assert.Panics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ + mw := BodyDumpWithConfig(BodyDumpConfig{ Skipper: nil, Handler: nil, }) + assert.NotNil(t, mw) }) assert.NotPanics(t, func() { - mw = BodyDumpWithConfig(BodyDumpConfig{ - Skipper: func(c echo.Context) bool { - return true - }, - Handler: func(c echo.Context, reqBody, resBody []byte) { - }, - }) + mw := BodyDumpWithConfig(BodyDumpConfig{Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}}) + assert.NotNil(t, mw) + }) +} - if !assert.Error(t, mw(h)(c)) { - t.FailNow() - } +func TestBodyDump_panic(t *testing.T) { + assert.Panics(t, func() { + mw := BodyDump(nil) + assert.NotNil(t, mw) + }) + + assert.NotPanics(t, func() { + BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) {}) }) } @@ -95,7 +121,6 @@ func TestBodyDumpResponseWriter_CanNotFlush(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: new(testResponseWriterNoFlushHijack), // this RW does not support flush } - assert.PanicsWithError(t, "response writer flushing is not supported", func() { bdrw.Flush() }) @@ -106,7 +131,6 @@ func TestBodyDumpResponseWriter_CanFlush(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, } - bdrw.Flush() assert.Equal(t, 1, trwu.unwrapCalled) } @@ -116,7 +140,6 @@ func TestBodyDumpResponseWriter_CanUnwrap(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: trwu, } - result := bdrw.Unwrap() assert.Equal(t, trwu, result) } @@ -126,7 +149,6 @@ func TestBodyDumpResponseWriter_CanHijack(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } - _, _, err := bdrw.Hijack() assert.EqualError(t, err, "can hijack") } @@ -136,7 +158,6 @@ func TestBodyDumpResponseWriter_CanNotHijack(t *testing.T) { bdrw := bodyDumpResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } - _, _, err := bdrw.Hijack() assert.EqualError(t, err, "feature not supported") } @@ -155,14 +176,14 @@ func TestBodyDump_ReadError(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { + h := func(c *echo.Context) error { // This handler should not be reached if body read fails body, _ := io.ReadAll(c.Request().Body) return c.String(http.StatusOK, string(body)) } requestBodyReceived := "" - mw := BodyDump(func(c echo.Context, reqBody, resBody []byte) { + mw := BodyDump(func(c *echo.Context, reqBody, resBody []byte, err error) { requestBodyReceived = string(reqBody) }) @@ -202,3 +223,359 @@ func (f *failingReadCloser) Read(p []byte) (n int, err error) { func (f *failingReadCloser) Close() error { return nil } + +func TestBodyDump_RequestWithinLimit(t *testing.T) { + e := echo.New() + requestData := "Hello, World!" + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: 1 * MB, // 1MB limit + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, requestData, requestBodyDumped, "Small request should be fully dumped") + assert.Equal(t, requestData, rec.Body.String(), "Handler should receive full request") +} + +func TestBodyDump_RequestExceedsLimit(t *testing.T) { + e := echo.New() + // Create 2KB of data but limit to 1KB + largeData := strings.Repeat("A", 2*1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + limit := int64(1024) // 1KB limit + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, int(limit), len(requestBodyDumped), "Dumped request should be truncated to limit") + assert.Equal(t, strings.Repeat("A", 1024), requestBodyDumped, "Dumped data should match first N bytes") + // Handler should receive truncated data (what was dumped) + assert.Equal(t, strings.Repeat("A", 1024), rec.Body.String()) +} + +func TestBodyDump_RequestAtExactLimit(t *testing.T) { + e := echo.New() + exactData := strings.Repeat("B", 1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(exactData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + limit := int64(1024) + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, int(limit), len(requestBodyDumped), "Exact limit should dump full data") + assert.Equal(t, exactData, requestBodyDumped) +} + +func TestBodyDump_ResponseWithinLimit(t *testing.T) { + e := echo.New() + responseData := "Response data" + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + return c.String(http.StatusOK, responseData) + } + + responseBodyDumped := "" + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, responseData, responseBodyDumped, "Small response should be fully dumped") + assert.Equal(t, responseData, rec.Body.String(), "Client should receive full response") +} + +func TestBodyDump_ResponseExceedsLimit(t *testing.T) { + e := echo.New() + largeResponse := strings.Repeat("X", 2*1024) // 2KB + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + return c.String(http.StatusOK, largeResponse) + } + + responseBodyDumped := "" + limit := int64(1024) // 1KB limit + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: limit, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + // Dump should be truncated + assert.Equal(t, int(limit), len(responseBodyDumped), "Dumped response should be truncated to limit") + assert.Equal(t, strings.Repeat("X", 1024), responseBodyDumped) + // Client should still receive full response! + assert.Equal(t, largeResponse, rec.Body.String(), "Client must receive full response despite dump truncation") +} + +func TestBodyDump_ClientGetsFullResponse(t *testing.T) { + e := echo.New() + // This is critical - even when dump is limited, client gets everything + largeResponse := strings.Repeat("DATA", 500) // 2KB + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + // Write response in chunks to test incremental writes + for i := 0; i < 4; i++ { + c.Response().Write([]byte(strings.Repeat("DATA", 125))) + } + return nil + } + + responseBodyDumped := "" + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 512, // Very small limit + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, 512, len(responseBodyDumped), "Dump should be limited") + assert.Equal(t, largeResponse, rec.Body.String(), "Client must get full response") +} + +func TestBodyDump_BothLimitsSimultaneous(t *testing.T) { + e := echo.New() + largeRequest := strings.Repeat("Q", 2*1024) + largeResponse := strings.Repeat("R", 2*1024) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeRequest)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + io.ReadAll(c.Request().Body) // Consume request + return c.String(http.StatusOK, largeResponse) + } + + requestBodyDumped := "" + responseBodyDumped := "" + limit := int64(1024) + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: limit, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, int(limit), len(requestBodyDumped), "Request dump should be limited") + assert.Equal(t, int(limit), len(responseBodyDumped), "Response dump should be limited") +} + +func TestBodyDump_DefaultConfig(t *testing.T) { + e := echo.New() + smallData := "test" + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(smallData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + // Use default config which should have 1MB limits + config := BodyDumpConfig{} + config.Handler = func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + } + mw, err := config.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, smallData, requestBodyDumped) +} + +func TestBodyDump_LargeRequestDosPrevention(t *testing.T) { + e := echo.New() + // Simulate a very large request (10MB) that could cause OOM + largeSize := 10 * 1024 * 1024 // 10MB + largeData := strings.Repeat("M", largeSize) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(largeData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + body, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(body)) + } + + requestBodyDumped := "" + limit := int64(1 * MB) // Only dump 1MB max + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + requestBodyDumped = string(reqBody) + }, + MaxRequestBytes: limit, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + // Verify only 1MB was dumped, not 10MB + assert.Equal(t, int(limit), len(requestBodyDumped), "Should only dump up to limit, preventing DoS") + assert.Less(t, len(requestBodyDumped), largeSize, "Dump should be much smaller than full request") +} + +func TestBodyDump_LargeResponseDosPrevention(t *testing.T) { + e := echo.New() + // Simulate a very large response (10MB) + largeSize := 10 * 1024 * 1024 // 10MB + largeResponse := strings.Repeat("R", largeSize) + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h := func(c *echo.Context) error { + return c.String(http.StatusOK, largeResponse) + } + + responseBodyDumped := "" + limit := int64(1 * MB) // Only dump 1MB max + mw, err := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + responseBodyDumped = string(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: limit, + }.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + assert.NoError(t, err) + // Verify only 1MB was dumped, not 10MB + assert.Equal(t, int(limit), len(responseBodyDumped), "Should only dump up to limit, preventing DoS") + assert.Less(t, len(responseBodyDumped), largeSize, "Dump should be much smaller than full response") + // Client still gets full response + assert.Equal(t, largeSize, rec.Body.Len(), "Client must receive full response") +} + +func BenchmarkBodyDump_WithLimit(b *testing.B) { + e := echo.New() + requestData := strings.Repeat("data", 256) // 1KB + responseData := strings.Repeat("resp", 256) // 1KB + + h := func(c *echo.Context) error { + io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, responseData) + } + + mw, _ := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) { + // Simulate logging + _ = len(reqBody) + len(resBody) + }, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(h)(c) + } +} + +func BenchmarkBodyDump_BufferPooling(b *testing.B) { + e := echo.New() + requestData := strings.Repeat("x", 1024) + responseData := "response" + + h := func(c *echo.Context) error { + io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, responseData) + } + + mw, _ := BodyDumpConfig{ + Handler: func(c *echo.Context, reqBody, resBody []byte, err error) {}, + MaxRequestBytes: 1 * MB, + MaxResponseBytes: 1 * MB, + }.ToMiddleware() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(requestData)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + mw(h)(c) + } +} diff --git a/middleware/body_limit.go b/middleware/body_limit.go index d13ad2c4e..4f1963e18 100644 --- a/middleware/body_limit.go +++ b/middleware/body_limit.go @@ -4,24 +4,20 @@ package middleware import ( - "fmt" "io" "net/http" "sync" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/bytes" + "github.com/labstack/echo/v5" ) -// BodyLimitConfig defines the config for BodyLimit middleware. +// BodyLimitConfig defines the config for BodyLimitWithConfig middleware. type BodyLimitConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // Maximum allowed size for a request body, it can be specified - // as `4x` or `4xB`, where x is one of the multiple from K, M, G, T or P. - Limit string `yaml:"limit"` - limit int64 + // LimitBytes is maximum allowed size in bytes for a request body + LimitBytes int64 } type limitedReader struct { @@ -30,50 +26,43 @@ type limitedReader struct { read int64 } -// DefaultBodyLimitConfig is the default BodyLimit middleware config. -var DefaultBodyLimitConfig = BodyLimitConfig{ - Skipper: DefaultSkipper, -} - // BodyLimit returns a BodyLimit middleware. // -// BodyLimit middleware sets the maximum allowed size for a request body, if the -// size exceeds the configured limit, it sends "413 - Request Entity Too Large" -// response. The BodyLimit is determined based on both `Content-Length` request +// BodyLimit middleware sets the maximum allowed size for a request body, if the size exceeds the configured limit, it +// sends "413 - Request Entity Too Large" response. The BodyLimit is determined based on both `Content-Length` request // header and actual content read, which makes it super secure. -// Limit can be specified as `4x` or `4xB`, where x is one of the multiple from K, M, -// G, T or P. -func BodyLimit(limit string) echo.MiddlewareFunc { - c := DefaultBodyLimitConfig - c.Limit = limit - return BodyLimitWithConfig(c) +func BodyLimit(limitBytes int64) echo.MiddlewareFunc { + return BodyLimitWithConfig(BodyLimitConfig{LimitBytes: limitBytes}) } -// BodyLimitWithConfig returns a BodyLimit middleware with config. -// See: `BodyLimit()`. +// BodyLimitWithConfig returns a BodyLimitWithConfig middleware. Middleware sets the maximum allowed size in bytes for +// a request body, if the size exceeds the configured limit, it sends "413 - Request Entity Too Large" response. +// The BodyLimitWithConfig is determined based on both `Content-Length` request header and actual content read, which +// makes it super secure. func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts BodyLimitConfig to middleware or returns an error for invalid configuration +func (config BodyLimitConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyLimitConfig.Skipper + config.Skipper = DefaultSkipper } - - limit, err := bytes.Parse(config.Limit) - if err != nil { - panic(fmt.Errorf("echo: invalid body-limit=%s", config.Limit)) + pool := sync.Pool{ + New: func() any { + return &limitedReader{BodyLimitConfig: config} + }, } - config.limit = limit - pool := limitedReaderPool(config) return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } - req := c.Request() // Based on content length - if req.ContentLength > config.limit { + if req.ContentLength > config.LimitBytes { return echo.ErrStatusRequestEntityTooLarge } @@ -88,13 +77,13 @@ func BodyLimitWithConfig(config BodyLimitConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } func (r *limitedReader) Read(b []byte) (n int, err error) { n, err = r.reader.Read(b) r.read += int64(n) - if r.read > r.limit { + if r.read > r.LimitBytes { return n, echo.ErrStatusRequestEntityTooLarge } return @@ -108,11 +97,3 @@ func (r *limitedReader) Reset(reader io.ReadCloser) { r.reader = reader r.read = 0 } - -func limitedReaderPool(c BodyLimitConfig) sync.Pool { - return sync.Pool{ - New: func() interface{} { - return &limitedReader{BodyLimitConfig: c} - }, - } -} diff --git a/middleware/body_limit_test.go b/middleware/body_limit_test.go index d14c2b649..5529f5d84 100644 --- a/middleware/body_limit_test.go +++ b/middleware/body_limit_test.go @@ -10,17 +10,17 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestBodyLimit(t *testing.T) { +func TestBodyLimitConfig_ToMiddleware(t *testing.T) { e := echo.New() hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := func(c echo.Context) error { + h := func(c *echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err @@ -29,41 +29,51 @@ func TestBodyLimit(t *testing.T) { } // Based on content length (within limit) - if assert.NoError(t, BodyLimit("2M")(h)(c)) { + mw, err := BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = mw(h)(c) + if assert.NoError(t, err) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, hw, rec.Body.Bytes()) } - // Based on content length (overlimit) - he := BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) + // Based on content read (overlimit) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he := mw(h)(c).(echo.HTTPStatusCoder) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) // Based on content read (within limit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - if assert.NoError(t, BodyLimit("2M")(h)(c)) { - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "Hello, World!", rec.Body.String()) - } + + mw, err = BodyLimitConfig{LimitBytes: 2 * MB}.ToMiddleware() + assert.NoError(t, err) + err = mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "Hello, World!", rec.Body.String()) // Based on content read (overlimit) req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) req.ContentLength = -1 rec = httptest.NewRecorder() c = e.NewContext(req, rec) - he = BodyLimit("2B")(h)(c).(*echo.HTTPError) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) + mw, err = BodyLimitConfig{LimitBytes: 2}.ToMiddleware() + assert.NoError(t, err) + he = mw(h)(c).(echo.HTTPStatusCoder) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) } func TestBodyLimitReader(t *testing.T) { hw := []byte("Hello, World!") config := BodyLimitConfig{ - Skipper: DefaultSkipper, - Limit: "2B", - limit: 2, + Skipper: DefaultSkipper, + LimitBytes: 2, } reader := &limitedReader{ BodyLimitConfig: config, @@ -72,8 +82,8 @@ func TestBodyLimitReader(t *testing.T) { // read all should return ErrStatusRequestEntityTooLarge _, err := io.ReadAll(reader) - he := err.(*echo.HTTPError) - assert.Equal(t, http.StatusRequestEntityTooLarge, he.Code) + he := err.(echo.HTTPStatusCoder) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) // reset reader and read two bytes must succeed bt := make([]byte, 2) @@ -83,91 +93,74 @@ func TestBodyLimitReader(t *testing.T) { assert.Equal(t, nil, err) } -func TestBodyLimitWithConfig_Skipper(t *testing.T) { +func TestBodyLimit_skipper(t *testing.T) { e := echo.New() - h := func(c echo.Context) error { + h := func(c *echo.Context) error { body, err := io.ReadAll(c.Request().Body) if err != nil { return err } return c.String(http.StatusOK, string(body)) } - mw := BodyLimitWithConfig(BodyLimitConfig{ - Skipper: func(c echo.Context) bool { + mw, err := BodyLimitConfig{ + Skipper: func(c *echo.Context) bool { return true }, - Limit: "2B", // if not skipped this limit would make request to fail limit check - }) + LimitBytes: 2, + }.ToMiddleware() + assert.NoError(t, err) hw := []byte("Hello, World!") req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - err := mw(h)(c) + err = mw(h)(c) assert.NoError(t, err) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, hw, rec.Body.Bytes()) } func TestBodyLimitWithConfig(t *testing.T) { - var testCases = []struct { - name string - givenLimit string - whenBody []byte - expectBody []byte - expectError string - }{ - { - name: "ok, body is less than limit", - givenLimit: "10B", - whenBody: []byte("123456789"), - expectBody: []byte("123456789"), - expectError: "", - }, - { - name: "nok, body is more than limit", - givenLimit: "9B", - whenBody: []byte("1234567890"), - expectBody: []byte(nil), - expectError: "code=413, message=Request Entity Too Large", - }, + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) } - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - h := func(c echo.Context) error { - body, err := io.ReadAll(c.Request().Body) - if err != nil { - return err - } - return c.String(http.StatusOK, string(body)) - } - mw := BodyLimitWithConfig(BodyLimitConfig{ - Limit: tc.givenLimit, - }) - - req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(tc.whenBody)) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - err := mw(h)(c) - if tc.expectError != "" { - assert.EqualError(t, err, tc.expectError) - } else { - assert.NoError(t, err) - } - // not testing status as middlewares return error instead of committing it and OK cases are anyway 200 - assert.Equal(t, tc.expectBody, rec.Body.Bytes()) - }) - } + mw := BodyLimitWithConfig(BodyLimitConfig{LimitBytes: 2 * MB}) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } -func TestBodyLimit_panicOnInvalidLimit(t *testing.T) { - assert.PanicsWithError( - t, - "echo: invalid body-limit=", - func() { BodyLimit("") }, - ) +func TestBodyLimit(t *testing.T) { + e := echo.New() + hw := []byte("Hello, World!") + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(hw)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := func(c *echo.Context) error { + body, err := io.ReadAll(c.Request().Body) + if err != nil { + return err + } + return c.String(http.StatusOK, string(body)) + } + + mw := BodyLimit(2 * MB) + + err := mw(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, hw, rec.Body.Bytes()) } diff --git a/middleware/compress.go b/middleware/compress.go index 48ccc9856..7754d5db8 100644 --- a/middleware/compress.go +++ b/middleware/compress.go @@ -7,13 +7,18 @@ import ( "bufio" "bytes" "compress/gzip" + "errors" "io" "net" "net/http" "strings" "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" +) + +const ( + gzipScheme = "gzip" ) // GzipConfig defines the config for Gzip middleware. @@ -23,7 +28,7 @@ type GzipConfig struct { // Gzip compression level. // Optional. Default value -1. - Level int `yaml:"level"` + Level int // Length threshold before gzip compression is applied. // Optional. Default value 0. @@ -50,42 +55,36 @@ type gzipResponseWriter struct { code int } -const ( - gzipScheme = "gzip" -) - -// DefaultGzipConfig is the default Gzip middleware config. -var DefaultGzipConfig = GzipConfig{ - Skipper: DefaultSkipper, - Level: -1, - MinLength: 0, -} - -// Gzip returns a middleware which compresses HTTP response using gzip compression -// scheme. +// Gzip returns a middleware which compresses HTTP response using gzip compression scheme. func Gzip() echo.MiddlewareFunc { - return GzipWithConfig(DefaultGzipConfig) + return GzipWithConfig(GzipConfig{}) } -// GzipWithConfig return Gzip middleware with config. -// See: `Gzip()`. +// GzipWithConfig returns a middleware which compresses HTTP response using gzip compression scheme. func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts GzipConfig to middleware or returns an error for invalid configuration +func (config GzipConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Level < -2 || config.Level > 9 { // these are consts: gzip.HuffmanOnly and gzip.BestCompression + return nil, errors.New("invalid gzip level") } if config.Level == 0 { - config.Level = DefaultGzipConfig.Level + config.Level = -1 } if config.MinLength < 0 { - config.MinLength = DefaultGzipConfig.MinLength + config.MinLength = 0 } pool := gzipCompressPool(config) bpool := bufferPool() return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -98,13 +97,18 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { if !ok { return echo.NewHTTPError(http.StatusInternalServerError, "invalid pool object") } - rw := res.Writer + rw := res w.Reset(rw) - buf := bpool.Get().(*bytes.Buffer) buf.Reset() - grw := &gzipResponseWriter{Writer: w, ResponseWriter: rw, minLength: config.MinLength, buffer: buf} + grw := &gzipResponseWriter{ + Writer: w, + ResponseWriter: rw, + minLength: config.MinLength, + buffer: buf, + } + c.SetResponse(grw) defer func() { // There are different reasons for cases when we have not yet written response to the client and now need to do so. // a) handler response had only response code and no response body (ala 404 or redirects etc). Response code need to be written now. @@ -119,26 +123,25 @@ func GzipWithConfig(config GzipConfig) echo.MiddlewareFunc { // We have to reset response to it's pristine state when // nothing is written to body or error is returned. // See issue #424, #407. - res.Writer = rw + c.SetResponse(rw) w.Reset(io.Discard) } else if !grw.minLengthExceeded { // Write uncompressed response - res.Writer = rw + c.SetResponse(rw) if grw.wroteHeader { grw.ResponseWriter.WriteHeader(grw.code) } - grw.buffer.WriteTo(rw) + _, _ = grw.buffer.WriteTo(rw) w.Reset(io.Discard) } - w.Close() + _ = w.Close() bpool.Put(buf) pool.Put(w) }() - res.Writer = grw } return next(c) } - } + }, nil } func (w *gzipResponseWriter) WriteHeader(code int) { @@ -186,7 +189,7 @@ func (w *gzipResponseWriter) Flush() { w.ResponseWriter.WriteHeader(w.code) } - w.Writer.Write(w.buffer.Bytes()) + _, _ = w.Writer.Write(w.buffer.Bytes()) } if gw, ok := w.Writer.(*gzip.Writer); ok { @@ -195,14 +198,14 @@ func (w *gzipResponseWriter) Flush() { _ = http.NewResponseController(w.ResponseWriter).Flush() } -func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { - return w.ResponseWriter -} - func (w *gzipResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { return http.NewResponseController(w.ResponseWriter).Hijack() } +func (w *gzipResponseWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} + func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { if p, ok := w.ResponseWriter.(http.Pusher); ok { return p.Push(target, opts) @@ -212,7 +215,7 @@ func (w *gzipResponseWriter) Push(target string, opts *http.PushOptions) error { func gzipCompressPool(config GzipConfig) sync.Pool { return sync.Pool{ - New: func() interface{} { + New: func() any { w, err := gzip.NewWriterLevel(io.Discard, config.Level) if err != nil { return err @@ -224,7 +227,7 @@ func gzipCompressPool(config GzipConfig) sync.Pool { func bufferPool() sync.Pool { return sync.Pool{ - New: func() interface{} { + New: func() any { b := &bytes.Buffer{} return b }, diff --git a/middleware/compress_test.go b/middleware/compress_test.go index c9083ee28..084ffc9c7 100644 --- a/middleware/compress_test.go +++ b/middleware/compress_test.go @@ -11,91 +11,216 @@ import ( "net/http/httptest" "os" "testing" + "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func TestGzip(t *testing.T) { +func TestGzip_NoAcceptEncodingHeader(t *testing.T) { + // Skip if no Accept-Encoding header + h := Gzip()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - // Skip if no Accept-Encoding header - h := Gzip()(func(c echo.Context) error { + err := h(c) + assert.NoError(t, err) + + assert.Equal(t, "test", rec.Body.String()) +} + +func TestMustGzipWithConfig_panics(t *testing.T) { + assert.Panics(t, func() { + GzipWithConfig(GzipConfig{Level: 999}) + }) +} + +func TestGzip_AcceptEncodingHeader(t *testing.T) { + h := Gzip()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - - assert.Equal(t, "test", rec.Body.String()) - // Gzip - req = httptest.NewRequest(http.MethodGet, "/", nil) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) + + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) assert.Contains(t, rec.Header().Get(echo.HeaderContentType), echo.MIMETextPlain) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(t, err) { - buf := new(bytes.Buffer) - defer r.Close() - buf.ReadFrom(r) - assert.Equal(t, "test", buf.String()) - } - chunkBuf := make([]byte, 5) + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) + buf := new(bytes.Buffer) + defer r.Close() + buf.ReadFrom(r) + assert.Equal(t, "test", buf.String()) +} - // Gzip chunked - req = httptest.NewRequest(http.MethodGet, "/", nil) +func TestGzip_chunked(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec = httptest.NewRecorder() + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - c = e.NewContext(req, rec) - Gzip()(func(c echo.Context) error { + chunkChan := make(chan struct{}) + waitChan := make(chan struct{}) + h := Gzip()(func(c *echo.Context) error { + rc := http.NewResponseController(c.Response()) c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data - c.Response().Write([]byte("test\n")) - c.Response().Flush() - - // Read the first part of the data - assert.True(t, rec.Flushed) - assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - r.Reset(rec.Body) + c.Response().Write([]byte("first\n")) + rc.Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(t, err) - assert.Equal(t, "test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write and flush the second part of the data - c.Response().Write([]byte("test\n")) - c.Response().Flush() + c.Response().Write([]byte("second\n")) + rc.Flush() - _, err = io.ReadFull(r, chunkBuf) - assert.NoError(t, err) - assert.Equal(t, "test\n", string(chunkBuf)) + chunkChan <- struct{}{} + <-waitChan // Write the final part of the data and return - c.Response().Write([]byte("test")) + c.Response().Write([]byte("third")) + + chunkChan <- struct{}{} return nil - })(c) + }) + + go func() { + err := h(c) + chunkChan <- struct{}{} + assert.NoError(t, err) + }() + <-chunkChan // wait for first write + waitChan <- struct{}{} + + <-chunkChan // wait for second write + waitChan <- struct{}{} + + <-chunkChan // wait for final write in handler + <-chunkChan // wait for return from handler + time.Sleep(5 * time.Millisecond) // to have time for flushing + + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + + r, err := gzip.NewReader(rec.Body) + assert.NoError(t, err) buf := new(bytes.Buffer) - defer r.Close() buf.ReadFrom(r) - assert.Equal(t, "test", buf.String()) + assert.Equal(t, "first\nsecond\nthird", buf.String()) +} + +func TestGzip_NoContent(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Gzip()(func(c *echo.Context) error { + return c.NoContent(http.StatusNoContent) + }) + if assert.NoError(t, h(c)) { + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) + assert.Equal(t, 0, len(rec.Body.Bytes())) + } +} + +func TestGzip_Empty(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + h := Gzip()(func(c *echo.Context) error { + return c.String(http.StatusOK, "") + }) + if assert.NoError(t, h(c)) { + assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) + assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + var buf bytes.Buffer + buf.ReadFrom(r) + assert.Equal(t, "", buf.String()) + } + } +} + +func TestGzip_ErrorReturned(t *testing.T) { + e := echo.New() + e.Use(Gzip()) + e.GET("/", func(c *echo.Context) error { + return echo.ErrNotFound + }) + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) + assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) +} + +func TestGzipWithConfig_invalidLevel(t *testing.T) { + mw, err := GzipConfig{Level: 12}.ToMiddleware() + assert.EqualError(t, err, "invalid gzip level") + assert.Nil(t, mw) +} + +// Issue #806 +func TestGzipWithStatic(t *testing.T) { + e := echo.New() + e.Filesystem = os.DirFS("../") + + e.Use(Gzip()) + e.Static("/test", "_fixture/images") + req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) + req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + // Data is written out in chunks when Content-Length == "", so only + // validate the content length if it's not set. + if cl := rec.Header().Get("Content-Length"); cl != "" { + assert.Equal(t, cl, rec.Body.Len()) + } + r, err := gzip.NewReader(rec.Body) + if assert.NoError(t, err) { + defer r.Close() + want, err := os.ReadFile("../_fixture/images/walle.png") + if assert.NoError(t, err) { + buf := new(bytes.Buffer) + buf.ReadFrom(r) + assert.Equal(t, want, buf.Bytes()) + } + } } func TestGzipWithMinLength(t *testing.T) { e := echo.New() // Minimal response length e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { c.Response().Write([]byte("foobarfoobar")) return nil }) @@ -118,7 +243,7 @@ func TestGzipWithMinLengthTooShort(t *testing.T) { e := echo.New() // Minimal response length e.Use(GzipWithConfig(GzipConfig{MinLength: 10})) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { c.Response().Write([]byte("test")) return nil }) @@ -134,7 +259,7 @@ func TestGzipWithResponseWithoutBody(t *testing.T) { e := echo.New() e.Use(Gzip()) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return c.Redirect(http.StatusMovedPermanently, "http://localhost") }) @@ -161,13 +286,14 @@ func TestGzipWithMinLengthChunked(t *testing.T) { var r *gzip.Reader = nil c := e.NewContext(req, rec) - GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + next := func(c *echo.Context) error { + rc := http.NewResponseController(c.Response()) c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Transfer-Encoding", "chunked") // Write and flush the first part of the data c.Response().Write([]byte("test\n")) - c.Response().Flush() + rc.Flush() // Read the first part of the data assert.True(t, rec.Flushed) @@ -183,7 +309,7 @@ func TestGzipWithMinLengthChunked(t *testing.T) { // Write and flush the second part of the data c.Response().Write([]byte("test\n")) - c.Response().Flush() + rc.Flush() _, err = io.ReadFull(r, chunkBuf) assert.NoError(t, err) @@ -192,8 +318,10 @@ func TestGzipWithMinLengthChunked(t *testing.T) { // Write the final part of the data and return c.Response().Write([]byte("test")) return nil - })(c) + } + err := GzipWithConfig(GzipConfig{MinLength: 10})(next)(c) + assert.NoError(t, err) assert.NotNil(t, r) buf := new(bytes.Buffer) @@ -210,7 +338,7 @@ func TestGzipWithMinLengthNoContent(t *testing.T) { req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c echo.Context) error { + h := GzipWithConfig(GzipConfig{MinLength: 10})(func(c *echo.Context) error { return c.NoContent(http.StatusNoContent) }) if assert.NoError(t, h(c)) { @@ -220,106 +348,11 @@ func TestGzipWithMinLengthNoContent(t *testing.T) { } } -func TestGzipNoContent(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := Gzip()(func(c echo.Context) error { - return c.NoContent(http.StatusNoContent) - }) - if assert.NoError(t, h(c)) { - assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) - assert.Equal(t, 0, len(rec.Body.Bytes())) - } -} - -func TestGzipEmpty(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - h := Gzip()(func(c echo.Context) error { - return c.String(http.StatusOK, "") - }) - if assert.NoError(t, h(c)) { - assert.Equal(t, gzipScheme, rec.Header().Get(echo.HeaderContentEncoding)) - assert.Equal(t, "text/plain; charset=UTF-8", rec.Header().Get(echo.HeaderContentType)) - r, err := gzip.NewReader(rec.Body) - if assert.NoError(t, err) { - var buf bytes.Buffer - buf.ReadFrom(r) - assert.Equal(t, "", buf.String()) - } - } -} - -func TestGzipErrorReturned(t *testing.T) { - e := echo.New() - e.Use(Gzip()) - e.GET("/", func(c echo.Context) error { - return echo.ErrNotFound - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusNotFound, rec.Code) - assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) -} - -func TestGzipErrorReturnedInvalidConfig(t *testing.T) { - e := echo.New() - // Invalid level - e.Use(GzipWithConfig(GzipConfig{Level: 12})) - e.GET("/", func(c echo.Context) error { - c.Response().Write([]byte("test")) - return nil - }) - req := httptest.NewRequest(http.MethodGet, "/", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, rec.Body.String(), `{"message":"invalid pool object"}`) -} - -// Issue #806 -func TestGzipWithStatic(t *testing.T) { - e := echo.New() - e.Use(Gzip()) - e.Static("/test", "../_fixture/images") - req := httptest.NewRequest(http.MethodGet, "/test/walle.png", nil) - req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - // Data is written out in chunks when Content-Length == "", so only - // validate the content length if it's not set. - if cl := rec.Header().Get("Content-Length"); cl != "" { - assert.Equal(t, cl, rec.Body.Len()) - } - r, err := gzip.NewReader(rec.Body) - if assert.NoError(t, err) { - defer r.Close() - want, err := os.ReadFile("../_fixture/images/walle.png") - if assert.NoError(t, err) { - buf := new(bytes.Buffer) - buf.ReadFrom(r) - assert.Equal(t, want, buf.Bytes()) - } - } -} - func TestGzipResponseWriter_CanUnwrap(t *testing.T) { trwu := &testResponseWriterUnwrapper{rw: httptest.NewRecorder()} bdrw := gzipResponseWriter{ ResponseWriter: trwu, } - result := bdrw.Unwrap() assert.Equal(t, trwu, result) } @@ -329,7 +362,6 @@ func TestGzipResponseWriter_CanHijack(t *testing.T) { bdrw := gzipResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } - _, _, err := bdrw.Hijack() assert.EqualError(t, err, "can hijack") } @@ -339,7 +371,6 @@ func TestGzipResponseWriter_CanNotHijack(t *testing.T) { bdrw := gzipResponseWriter{ ResponseWriter: &trwu, // this RW supports hijacking through unwrapping } - _, _, err := bdrw.Hijack() assert.EqualError(t, err, "feature not supported") } @@ -350,7 +381,7 @@ func BenchmarkGzip(b *testing.B) { req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderAcceptEncoding, gzipScheme) - h := Gzip()(func(c echo.Context) error { + h := Gzip()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) diff --git a/middleware/context_timeout.go b/middleware/context_timeout.go index 5d9ae9755..68465199a 100644 --- a/middleware/context_timeout.go +++ b/middleware/context_timeout.go @@ -8,51 +8,18 @@ import ( "errors" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -// ContextTimeout Middleware -// -// ContextTimeout provides request timeout functionality using Go's context mechanism. -// It is the recommended replacement for the deprecated Timeout middleware. -// -// -// Basic Usage: -// -// e.Use(middleware.ContextTimeout(30 * time.Second)) -// -// With Configuration: -// -// e.Use(middleware.ContextTimeoutWithConfig(middleware.ContextTimeoutConfig{ -// Timeout: 30 * time.Second, -// Skipper: middleware.DefaultSkipper, -// })) -// -// Handler Example: -// -// e.GET("/task", func(c echo.Context) error { -// ctx := c.Request().Context() -// -// result, err := performTaskWithContext(ctx) -// if err != nil { -// if errors.Is(err, context.DeadlineExceeded) { -// return echo.NewHTTPError(http.StatusServiceUnavailable, "timeout") -// } -// return err -// } -// -// return c.JSON(http.StatusOK, result) -// }) - // ContextTimeoutConfig defines the config for ContextTimeout middleware. type ContextTimeoutConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // ErrorHandler is a function when error arises in middleware execution. - ErrorHandler func(err error, c echo.Context) error + // ErrorHandler is a function when error arises in middeware execution. + ErrorHandler func(c *echo.Context, err error) error - // Timeout configures a timeout for the middleware, defaults to 0 for no timeout + // Timeout configures a timeout for the middleware Timeout time.Duration } @@ -64,11 +31,7 @@ func ContextTimeout(timeout time.Duration) echo.MiddlewareFunc { // ContextTimeoutWithConfig returns a Timeout middleware with config. func ContextTimeoutWithConfig(config ContextTimeoutConfig) echo.MiddlewareFunc { - mw, err := config.ToMiddleware() - if err != nil { - panic(err) - } - return mw + return toMiddlewareOrPanic(config) } // ToMiddleware converts Config to middleware. @@ -80,16 +43,16 @@ func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.Skipper = DefaultSkipper } if config.ErrorHandler == nil { - config.ErrorHandler = func(err error, c echo.Context) error { + config.ErrorHandler = func(c *echo.Context, err error) error { if err != nil && errors.Is(err, context.DeadlineExceeded) { - return echo.ErrServiceUnavailable.WithInternal(err) + return echo.ErrServiceUnavailable.Wrap(err) } return err } } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -100,7 +63,7 @@ func (config ContextTimeoutConfig) ToMiddleware() (echo.MiddlewareFunc, error) { c.SetRequest(c.Request().WithContext(timeoutContext)) if err := next(c); err != nil { - return config.ErrorHandler(err, c) + return config.ErrorHandler(c, err) } return nil } diff --git a/middleware/context_timeout_test.go b/middleware/context_timeout_test.go index e69bcd268..c7ba76beb 100644 --- a/middleware/context_timeout_test.go +++ b/middleware/context_timeout_test.go @@ -6,6 +6,7 @@ package middleware import ( "context" "errors" + "github.com/labstack/echo/v5" "net/http" "net/http/httptest" "net/url" @@ -13,14 +14,13 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" "github.com/stretchr/testify/assert" ) func TestContextTimeoutSkipper(t *testing.T) { t.Parallel() m := ContextTimeoutWithConfig(ContextTimeoutConfig{ - Skipper: func(context echo.Context) bool { + Skipper: func(context *echo.Context) bool { return true }, Timeout: 10 * time.Millisecond, @@ -32,7 +32,7 @@ func TestContextTimeoutSkipper(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { + err := m(func(c *echo.Context) error { if err := sleepWithContext(c.Request().Context(), time.Duration(20*time.Millisecond)); err != nil { return err } @@ -65,7 +65,7 @@ func TestContextTimeoutErrorOutInHandler(t *testing.T) { c := e.NewContext(req, rec) rec.Code = 1 // we want to be sure that even 200 will not be sent - err := m(func(c echo.Context) error { + err := m(func(c *echo.Context) error { // this error must not be written to the client response. Middlewares upstream of timeout middleware must be able // to handle returned error and this can be done only then handler has not yet committed (written status code) // the response. @@ -91,7 +91,7 @@ func TestContextTimeoutSuccessfulRequest(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { + err := m(func(c *echo.Context) error { return c.JSON(http.StatusCreated, map[string]string{"data": "ok"}) })(c) @@ -115,7 +115,7 @@ func TestContextTimeoutTestRequestClone(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { + err := m(func(c *echo.Context) error { // Cookie test cookie, err := c.Request().Cookie("cookie") if assert.NoError(t, err) { @@ -150,23 +150,24 @@ func TestContextTimeoutWithDefaultErrorMessage(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { + err := m(func(c *echo.Context) error { if err := sleepWithContext(c.Request().Context(), time.Duration(80*time.Millisecond)); err != nil { return err } return c.String(http.StatusOK, "Hello, World!") })(c) - assert.IsType(t, &echo.HTTPError{}, err) assert.Error(t, err) - assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) - assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) + if assert.IsType(t, &echo.HTTPError{}, err) { + assert.Equal(t, http.StatusServiceUnavailable, err.(*echo.HTTPError).Code) + assert.Equal(t, "Service Unavailable", err.(*echo.HTTPError).Message) + } } func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { t.Parallel() - timeoutErrorHandler := func(err error, c echo.Context) error { + timeoutErrorHandler := func(c *echo.Context, err error) error { if err != nil { if errors.Is(err, context.DeadlineExceeded) { return &echo.HTTPError{ @@ -191,7 +192,7 @@ func TestContextTimeoutCanHandleContextDeadlineOnNextHandler(t *testing.T) { e := echo.New() c := e.NewContext(req, rec) - err := m(func(c echo.Context) error { + err := m(func(c *echo.Context) error { // NOTE: Very short periods are not reliable for tests due to Go routine scheduling and the unpredictable order // for 1) request and 2) time goroutine. For most OS this works as expected, but MacOS seems most flaky. diff --git a/middleware/cors.go b/middleware/cors.go index a1f445321..96ed16985 100644 --- a/middleware/cors.go +++ b/middleware/cors.go @@ -4,12 +4,13 @@ package middleware import ( + "errors" + "fmt" "net/http" - "regexp" "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // CORSConfig defines the config for CORS middleware. @@ -19,29 +20,41 @@ type CORSConfig struct { // AllowOrigins determines the value of the Access-Control-Allow-Origin // response header. This header defines a list of origins that may access the - // resource. The wildcard characters '*' and '?' are supported and are - // converted to regex fragments '.*' and '.' accordingly. + // resource. + // + // Origin consist of following parts: `scheme + "://" + host + optional ":" + port` + // Wildcard can be used, but has to be set explicitly []string{"*"} + // Example: `https://example.com`, `http://example.com:8080`, `*` // // Security: use extreme caution when handling the origin, and carefully // validate any logic. Remember that attackers may register hostile domain names. // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html - // - // Optional. Default value []string{"*"}. - // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Origin - AllowOrigins []string `yaml:"allow_origins"` - - // AllowOriginFunc is a custom function to validate the origin. It takes the - // origin as an argument and returns true if allowed or false otherwise. If - // an error is returned, it is returned by the handler. If this option is - // set, AllowOrigins is ignored. + // + // Mandatory. + AllowOrigins []string + + // UnsafeAllowOriginFunc is an optional custom function to validate the origin. It takes the + // origin as an argument and returns + // - string, allowed origin + // - bool, true if allowed or false otherwise. + // - error, if an error is returned, it is returned immediately by the handler. + // If this option is set, AllowOrigins is ignored. // // Security: use extreme caution when handling the origin, and carefully - // validate any logic. Remember that attackers may register hostile domain names. + // validate any logic. Remember that attackers may register hostile (sub)domain names. // See https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // + // Sub-domain checks example: + // UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { + // if strings.HasSuffix(origin, ".example.com") { + // return origin, true, nil + // } + // return "", false, nil + // }, + // // Optional. - AllowOriginFunc func(origin string) (bool, error) `yaml:"-"` + UnsafeAllowOriginFunc func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) // AllowMethods determines the value of the Access-Control-Allow-Methods // response header. This header specified the list of methods allowed when @@ -53,16 +66,16 @@ type CORSConfig struct { // from `Allow` header that echo.Router set into context. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods - AllowMethods []string `yaml:"allow_methods"` + AllowMethods []string // AllowHeaders determines the value of the Access-Control-Allow-Headers // response header. This header is used in response to a preflight request to // indicate which HTTP headers can be used when making the actual request. // - // Optional. Default value []string{}. + // Optional. Defaults to empty list. No domains allowed for CORS. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers - AllowHeaders []string `yaml:"allow_headers"` + AllowHeaders []string // AllowCredentials determines the value of the // Access-Control-Allow-Credentials response header. This header indicates @@ -79,16 +92,7 @@ type CORSConfig struct { // https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Credentials - AllowCredentials bool `yaml:"allow_credentials"` - - // UnsafeWildcardOriginWithAllowCredentials UNSAFE/INSECURE: allows wildcard '*' origin to be used with AllowCredentials - // flag. In that case we consider any origin allowed and send it back to the client with `Access-Control-Allow-Origin` header. - // - // This is INSECURE and potentially leads to [cross-origin](https://portswigger.net/research/exploiting-cors-misconfigurations-for-bitcoins-and-bounties) - // attacks. See: https://github.com/labstack/echo/issues/2400 for discussion on the subject. - // - // Optional. Default value is false. - UnsafeWildcardOriginWithAllowCredentials bool `yaml:"unsafe_wildcard_origin_with_allow_credentials"` + AllowCredentials bool // ExposeHeaders determines the value of Access-Control-Expose-Headers, which // defines a list of headers that clients are allowed to access. @@ -96,7 +100,7 @@ type CORSConfig struct { // Optional. Default value []string{}, in which case the header is not set. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Expose-Header - ExposeHeaders []string `yaml:"expose_headers"` + ExposeHeaders []string // MaxAge determines the value of the Access-Control-Max-Age response header. // This header indicates how long (in seconds) the results of a preflight @@ -106,19 +110,16 @@ type CORSConfig struct { // Optional. Default value 0 - meaning header is not sent. // // See also: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Max-Age - MaxAge int `yaml:"max_age"` -} - -// DefaultCORSConfig is the default CORS middleware config. -var DefaultCORSConfig = CORSConfig{ - Skipper: DefaultSkipper, - AllowOrigins: []string{"*"}, - AllowMethods: []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete}, + MaxAge int } // CORS returns a Cross-Origin Resource Sharing (CORS) middleware. // See also [MDN: Cross-Origin Resource Sharing (CORS)]. // +// Origin consist of following parts: `scheme + "://" + host + optional ":" + port` +// Wildcard `*` can be used, but has to be set explicitly. +// Example: `https://example.com`, `http://example.com:8080`, `*` +// // Security: Poorly configured CORS can compromise security because it allows // relaxation of the browser's Same-Origin policy. See [Exploiting CORS // misconfigurations for Bitcoins and bounties] and [Portswigger: Cross-origin @@ -127,45 +128,29 @@ var DefaultCORSConfig = CORSConfig{ // [MDN: Cross-Origin Resource Sharing (CORS)]: https://developer.mozilla.org/en/docs/Web/HTTP/Access_control_CORS // [Exploiting CORS misconfigurations for Bitcoins and bounties]: https://blog.portswigger.net/2016/10/exploiting-cors-misconfigurations-for.html // [Portswigger: Cross-origin resource sharing (CORS)]: https://portswigger.net/web-security/cors -func CORS() echo.MiddlewareFunc { - return CORSWithConfig(DefaultCORSConfig) +func CORS(allowOrigins ...string) echo.MiddlewareFunc { + c := CORSConfig{ + AllowOrigins: allowOrigins, + } + return CORSWithConfig(c) } -// CORSWithConfig returns a CORS middleware with config. +// CORSWithConfig returns a CORS middleware with config or panics on invalid configuration. // See: [CORS]. func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CORSConfig to middleware or returns an error for invalid configuration +func (config CORSConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { - config.Skipper = DefaultCORSConfig.Skipper - } - if len(config.AllowOrigins) == 0 { - config.AllowOrigins = DefaultCORSConfig.AllowOrigins + config.Skipper = DefaultSkipper } hasCustomAllowMethods := true if len(config.AllowMethods) == 0 { hasCustomAllowMethods = false - config.AllowMethods = DefaultCORSConfig.AllowMethods - } - - allowOriginPatterns := make([]*regexp.Regexp, 0, len(config.AllowOrigins)) - for _, origin := range config.AllowOrigins { - if origin == "*" { - continue // "*" is handled differently and does not need regexp - } - pattern := regexp.QuoteMeta(origin) - pattern = strings.ReplaceAll(pattern, "\\*", ".*") - pattern = strings.ReplaceAll(pattern, "\\?", ".") - pattern = "^" + pattern + "$" - - re, err := regexp.Compile(pattern) - if err != nil { - // this is to preserve previous behaviour - invalid patterns were just ignored. - // If we would turn this to panic, users with invalid patterns - // would have applications crashing in production due unrecovered panic. - // TODO: this should be turned to error/panic in `v5` - continue - } - allowOriginPatterns = append(allowOriginPatterns, re) + config.AllowMethods = []string{http.MethodGet, http.MethodHead, http.MethodPut, http.MethodPatch, http.MethodPost, http.MethodDelete} } allowMethods := strings.Join(config.AllowMethods, ",") @@ -177,8 +162,29 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { maxAge = strconv.Itoa(config.MaxAge) } + allowOriginFunc := config.UnsafeAllowOriginFunc + if config.UnsafeAllowOriginFunc == nil { + if len(config.AllowOrigins) == 0 { + return nil, errors.New("at least one AllowOrigins is required or UnsafeAllowOriginFunc must be provided") + } + allowOriginFunc = config.defaultAllowOriginFunc + for _, origin := range config.AllowOrigins { + if origin == "*" { + if config.AllowCredentials { + return nil, fmt.Errorf("* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc") + } + allowOriginFunc = config.starAllowOriginFunc + break + } + if err := validateOrigin(origin, "allow origin"); err != nil { + return nil, err + } + } + config.AllowOrigins = append([]string(nil), config.AllowOrigins...) + } + return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -186,7 +192,6 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { req := c.Request() res := c.Response() origin := req.Header.Get(echo.HeaderOrigin) - allowOrigin := "" res.Header().Add(echo.HeaderVary, echo.HeaderOrigin) @@ -211,76 +216,51 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { // No Origin provided. This is (probably) not request from actual browser - proceed executing middleware chain if origin == "" { - if !preflight { - return next(c) + if preflight { // req.Method=OPTIONS + return c.NoContent(http.StatusNoContent) } - return c.NoContent(http.StatusNoContent) + return next(c) // let non-browser calls through } - if config.AllowOriginFunc != nil { - allowed, err := config.AllowOriginFunc(origin) - if err != nil { - return err - } - if allowed { - allowOrigin = origin - } - } else { - // Check allowed origins - for _, o := range config.AllowOrigins { - if o == "*" && config.AllowCredentials && config.UnsafeWildcardOriginWithAllowCredentials { - allowOrigin = origin - break - } - if o == "*" || o == origin { - allowOrigin = o - break - } - if matchSubdomain(origin, o) { - allowOrigin = origin - break - } - } - - checkPatterns := false - if allowOrigin == "" { - // to avoid regex cost by invalid (long) domains (253 is domain name max limit) - if len(origin) <= (253+3+5) && strings.Contains(origin, "://") { - checkPatterns = true - } - } - if checkPatterns { - for _, re := range allowOriginPatterns { - if match := re.MatchString(origin); match { - allowOrigin = origin - break - } - } - } + allowedOrigin, allowed, err := allowOriginFunc(c, origin) + if err != nil { + return err } - - // Origin not allowed - if allowOrigin == "" { - if !preflight { - return next(c) + if !allowed { + // Origin existed and was NOT allowed + if preflight { + // From: https://github.com/labstack/echo/issues/2767 + // If the request's origin isn't allowed by the CORS configuration, + // the middleware should simply omit the relevant CORS headers from the response + // and let the browser fail the CORS check (if any). + return c.NoContent(http.StatusNoContent) } - return c.NoContent(http.StatusNoContent) + // From: https://github.com/labstack/echo/issues/2767 + // no CORS middleware should block non-preflight requests; + // such requests should be let through. One reason is that not all requests that + // carry an Origin header participate in the CORS protocol. + return next(c) } - res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowOrigin) + // Origin existed and was allowed + + res.Header().Set(echo.HeaderAccessControlAllowOrigin, allowedOrigin) if config.AllowCredentials { res.Header().Set(echo.HeaderAccessControlAllowCredentials, "true") } - // Simple request + // Simple request will be let though if !preflight { if exposeHeaders != "" { res.Header().Set(echo.HeaderAccessControlExposeHeaders, exposeHeaders) } return next(c) } - - // Preflight request + // Below code is for Preflight (OPTIONS) request + // + // Preflight will end with c.NoContent(http.StatusNoContent) as we do not know if + // at the end of handler chain is actual OPTIONS route or 404/405 route which + // response code will confuse browsers res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestMethod) res.Header().Add(echo.HeaderVary, echo.HeaderAccessControlRequestHeaders) @@ -303,5 +283,18 @@ func CORSWithConfig(config CORSConfig) echo.MiddlewareFunc { } return c.NoContent(http.StatusNoContent) } + }, nil +} + +func (config CORSConfig) starAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { + return "*", true, nil +} + +func (config CORSConfig) defaultAllowOriginFunc(c *echo.Context, origin string) (string, bool, error) { + for _, allowedOrigin := range config.AllowOrigins { + if strings.EqualFold(allowedOrigin, origin) { + return allowedOrigin, true, nil + } } + return "", false, nil } diff --git a/middleware/cors_test.go b/middleware/cors_test.go index 5461e9362..5de4ca063 100644 --- a/middleware/cors_test.go +++ b/middleware/cors_test.go @@ -4,72 +4,87 @@ package middleware import ( + "cmp" "errors" "net/http" "net/http/httptest" + "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestCORS(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodOptions, "/", nil) // Preflight request + req.Header.Set(echo.HeaderOrigin, "http://example.com") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + mw := CORS("*") + handler := mw(func(c *echo.Context) error { + return nil + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusNoContent, rec.Code) + assert.Equal(t, "*", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) +} + +func TestCORSConfig(t *testing.T) { var testCases = []struct { name string - givenMW echo.MiddlewareFunc + givenConfig *CORSConfig whenMethod string whenHeaders map[string]string expectHeaders map[string]string notExpectHeaders map[string]string + expectErr string }{ { - name: "ok, wildcard origin", + name: "ok, wildcard origin", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + }, whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "*"}, }, { - name: "ok, wildcard AllowedOrigin with no Origin header in request", + name: "ok, wildcard AllowedOrigin with no Origin header in request", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + }, notExpectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: ""}, }, - { - name: "ok, invalid pattern is ignored", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{ - "\xff", // Invalid UTF-8 makes regexp.Compile to error - "*.example.com", - }, - }), - whenMethod: http.MethodOptions, - whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, - expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, - }, { name: "ok, specific AllowOrigins and AllowCredentials", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost", "http://localhost:8080"}, AllowCredentials: true, MaxAge: 3600, - }), - whenHeaders: map[string]string{echo.HeaderOrigin: "localhost"}, + }, + whenHeaders: map[string]string{echo.HeaderOrigin: "http://localhost"}, expectHeaders: map[string]string{ - echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowOrigin: "http://localhost", echo.HeaderAccessControlAllowCredentials: "true", }, }, { name: "ok, preflight request with matching origin for `AllowOrigins`", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: 3600, - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "localhost", + echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ - echo.HeaderAccessControlAllowOrigin: "localhost", + echo.HeaderAccessControlAllowOrigin: "http://localhost", echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", echo.HeaderAccessControlAllowCredentials: "true", echo.HeaderAccessControlMaxAge: "3600", @@ -77,14 +92,14 @@ func TestCORS(t *testing.T) { }, { name: "ok, preflight request when `Access-Control-Max-Age` is set", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: 1, - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "localhost", + echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ @@ -93,14 +108,14 @@ func TestCORS(t *testing.T) { }, { name: "ok, preflight request when `Access-Control-Max-Age` is set to 0 - not to cache response", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, MaxAge: -1, // forces `Access-Control-Max-Age: 0` - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "localhost", + echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, expectHeaders: map[string]string{ @@ -109,16 +124,16 @@ func TestCORS(t *testing.T) { }, { name: "ok, CORS check are skipped", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"localhost"}, + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://localhost"}, AllowCredentials: true, - Skipper: func(c echo.Context) bool { + Skipper: func(c *echo.Context) bool { return true }, - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ - echo.HeaderOrigin: "localhost", + echo.HeaderOrigin: "http://localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, notExpectHeaders: map[string]string{ @@ -129,31 +144,33 @@ func TestCORS(t *testing.T) { }, }, { - name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", - givenMW: CORSWithConfig(CORSConfig{ + name: "nok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", + givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: true, MaxAge: 3600, - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, - expectHeaders: map[string]string{ - echo.HeaderAccessControlAllowOrigin: "*", // Note: browsers will ignore and complain about responses having `*` - echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", - echo.HeaderAccessControlAllowCredentials: "true", - echo.HeaderAccessControlMaxAge: "3600", + expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`, + }, + { + name: "nok, preflight request with invalid `AllowOrigins` value", + givenConfig: &CORSConfig{ + AllowOrigins: []string{"http://server", "missing-scheme"}, }, + expectErr: `allow origin is missing scheme or host: missing-scheme`, }, { name: "ok, preflight request with wildcard `AllowOrigins` and `AllowCredentials` false", - givenMW: CORSWithConfig(CORSConfig{ + givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, AllowCredentials: false, // important for this testcase MaxAge: 3600, - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", @@ -170,29 +187,23 @@ func TestCORS(t *testing.T) { }, { name: "ok, INSECURE preflight request with wildcard `AllowOrigins` and `AllowCredentials` true", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"*"}, - AllowCredentials: true, - UnsafeWildcardOriginWithAllowCredentials: true, // important for this testcase - MaxAge: 3600, - }), + givenConfig: &CORSConfig{ + AllowOrigins: []string{"*"}, + AllowCredentials: true, + MaxAge: 3600, + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", echo.HeaderContentType: echo.MIMEApplicationJSON, }, - expectHeaders: map[string]string{ - echo.HeaderAccessControlAllowOrigin: "localhost", // This could end up as cross-origin attack - echo.HeaderAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", - echo.HeaderAccessControlAllowCredentials: "true", - echo.HeaderAccessControlMaxAge: "3600", - }, + expectErr: `* as allowed origin and AllowCredentials=true is insecure and not allowed. Use custom UnsafeAllowOriginFunc`, }, { name: "ok, preflight request with Access-Control-Request-Headers", - givenMW: CORSWithConfig(CORSConfig{ + givenConfig: &CORSConfig{ AllowOrigins: []string{"*"}, - }), + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{ echo.HeaderOrigin: "localhost", @@ -207,18 +218,28 @@ func TestCORS(t *testing.T) { }, { name: "ok, preflight request with `AllowOrigins` which allow all subdomains aaa with *", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"http://*.example.com"}, - }), + givenConfig: &CORSConfig{ + UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (allowedOrigin string, allowed bool, err error) { + if strings.HasSuffix(origin, ".example.com") { + allowed = true + } + return origin, allowed, nil + }, + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{echo.HeaderOrigin: "http://aaa.example.com"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://aaa.example.com"}, }, { name: "ok, preflight request with `AllowOrigins` which allow all subdomains bbb with *", - givenMW: CORSWithConfig(CORSConfig{ - AllowOrigins: []string{"http://*.example.com"}, - }), + givenConfig: &CORSConfig{ + UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { + if strings.HasSuffix(origin, ".example.com") { + return origin, true, nil + } + return "", false, nil + }, + }, whenMethod: http.MethodOptions, whenHeaders: map[string]string{echo.HeaderOrigin: "http://bbb.example.com"}, expectHeaders: map[string]string{echo.HeaderAccessControlAllowOrigin: "http://bbb.example.com"}, @@ -228,18 +249,26 @@ func TestCORS(t *testing.T) { t.Run(tc.name, func(t *testing.T) { e := echo.New() - mw := CORS() - if tc.givenMW != nil { - mw = tc.givenMW + var mw echo.MiddlewareFunc + var err error + if tc.givenConfig != nil { + mw, err = tc.givenConfig.ToMiddleware() + } else { + mw, err = CORSConfig{}.ToMiddleware() + } + if err != nil { + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + return + } + t.Fatal(err) } - h := mw(func(c echo.Context) error { + + h := mw(func(c *echo.Context) error { return nil }) - method := http.MethodGet - if tc.whenMethod != "" { - method = tc.whenMethod - } + method := cmp.Or(tc.whenMethod, http.MethodGet) req := httptest.NewRequest(method, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) @@ -247,7 +276,7 @@ func TestCORS(t *testing.T) { req.Header.Set(k, v) } - err := h(c) + err = h(c) assert.NoError(t, err) header := rec.Header() @@ -301,98 +330,7 @@ func Test_allowOriginScheme(t *testing.T) { cors := CORSWithConfig(CORSConfig{ AllowOrigins: []string{tt.pattern}, }) - h := cors(echo.NotFoundHandler) - h(c) - - if tt.expected { - assert.Equal(t, tt.domain, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) - } else { - assert.NotContains(t, rec.Header(), echo.HeaderAccessControlAllowOrigin) - } - } -} - -func Test_allowOriginSubdomain(t *testing.T) { - tests := []struct { - domain, pattern string - expected bool - }{ - { - domain: "http://aaa.example.com", - pattern: "http://*.example.com", - expected: true, - }, - { - domain: "http://bbb.aaa.example.com", - pattern: "http://*.example.com", - expected: true, - }, - { - domain: "http://bbb.aaa.example.com", - pattern: "http://*.aaa.example.com", - expected: true, - }, - { - domain: "http://aaa.example.com:8080", - pattern: "http://*.example.com:8080", - expected: true, - }, - - { - domain: "http://fuga.hoge.com", - pattern: "http://*.example.com", - expected: false, - }, - { - domain: "http://ccc.bbb.example.com", - pattern: "http://*.aaa.example.com", - expected: false, - }, - { - domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ - .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ - .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\ - .1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, - pattern: "http://*.example.com", - expected: false, - }, - { - domain: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`, - pattern: "http://*.example.com", - expected: false, - }, - { - domain: "http://ccc.bbb.example.com", - pattern: "http://example.com", - expected: false, - }, - { - domain: "https://prod-preview--aaa.bbb.com", - pattern: "https://*--aaa.bbb.com", - expected: true, - }, - { - domain: "http://ccc.bbb.example.com", - pattern: "http://*.example.com", - expected: true, - }, - { - domain: "http://ccc.bbb.example.com", - pattern: "http://foo.[a-z]*.example.com", - expected: false, - }, - } - - e := echo.New() - for _, tt := range tests { - req := httptest.NewRequest(http.MethodOptions, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - req.Header.Set(echo.HeaderOrigin, tt.domain) - cors := CORSWithConfig(CORSConfig{ - AllowOrigins: []string{tt.pattern}, - }) - h := cors(echo.NotFoundHandler) + h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) h(c) if tt.expected { @@ -405,50 +343,53 @@ func Test_allowOriginSubdomain(t *testing.T) { func TestCORSWithConfig_AllowMethods(t *testing.T) { var testCases = []struct { - name string - allowOrigins []string - allowContextKey string - - whenOrigin string - whenAllowMethods []string - + name string + givenAllowOrigins []string + givenAllowMethods []string + whenAllowContextKey string + whenOrigin string expectAllow string expectAccessControlAllowMethods string }{ { - name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", - allowContextKey: "OPTIONS, GET", - whenAllowMethods: []string{http.MethodGet, http.MethodHead}, - whenOrigin: "", - expectAllow: "OPTIONS, GET", + name: "custom AllowMethods, preflight, no origin, sets only allow header from context key", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenAllowContextKey: "OPTIONS, GET", + whenOrigin: "", + expectAllow: "OPTIONS, GET", }, { - name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", - allowContextKey: "", - whenAllowMethods: nil, - whenOrigin: "", - expectAllow: "", + name: "default AllowMethods, preflight, no origin, no allow header in context key and in response", + givenAllowOrigins: []string{"*"}, + givenAllowMethods: nil, + whenAllowContextKey: "", + whenOrigin: "", + expectAllow: "", }, { name: "custom AllowMethods, preflight, existing origin, sets both headers different values", - allowContextKey: "OPTIONS, GET", - whenAllowMethods: []string{http.MethodGet, http.MethodHead}, + givenAllowOrigins: []string{"*"}, + givenAllowMethods: []string{http.MethodGet, http.MethodHead}, + whenAllowContextKey: "OPTIONS, GET", whenOrigin: "http://google.com", expectAllow: "OPTIONS, GET", expectAccessControlAllowMethods: "GET,HEAD", }, { name: "default AllowMethods, preflight, existing origin, sets both headers", - allowContextKey: "OPTIONS, GET", - whenAllowMethods: nil, + givenAllowOrigins: []string{"*"}, + givenAllowMethods: nil, + whenAllowContextKey: "OPTIONS, GET", whenOrigin: "http://google.com", expectAllow: "OPTIONS, GET", expectAccessControlAllowMethods: "OPTIONS, GET", }, { name: "default AllowMethods, preflight, existing origin, no allows, sets only CORS allow methods", - allowContextKey: "", - whenAllowMethods: nil, + givenAllowOrigins: []string{"*"}, + givenAllowMethods: nil, + whenAllowContextKey: "", whenOrigin: "http://google.com", expectAllow: "", expectAccessControlAllowMethods: "GET,HEAD,PUT,PATCH,POST,DELETE", @@ -458,13 +399,13 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { e := echo.New() - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusOK, "OK") }) cors := CORSWithConfig(CORSConfig{ - AllowOrigins: tc.allowOrigins, - AllowMethods: tc.whenAllowMethods, + AllowOrigins: tc.givenAllowOrigins, + AllowMethods: tc.givenAllowMethods, }) req := httptest.NewRequest(http.MethodOptions, "/test", nil) @@ -472,11 +413,13 @@ func TestCORSWithConfig_AllowMethods(t *testing.T) { c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) - if tc.allowContextKey != "" { - c.Set(echo.ContextKeyHeaderAllow, tc.allowContextKey) + if tc.whenAllowContextKey != "" { + c.Set(echo.ContextKeyHeaderAllow, tc.whenAllowContextKey) } - h := cors(echo.NotFoundHandler) + h := cors(func(c *echo.Context) error { + return c.String(http.StatusOK, "OK") + }) h(c) assert.Equal(t, tc.expectAllow, rec.Header().Get(echo.HeaderAllow)) @@ -592,10 +535,10 @@ func TestCorsHeaders(t *testing.T) { //MaxAge: 3600, })) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return c.String(http.StatusOK, "OK") }) - e.POST("/", func(c echo.Context) error { + e.POST("/", func(c *echo.Context) error { return c.String(http.StatusCreated, "OK") }) @@ -639,17 +582,17 @@ func TestCorsHeaders(t *testing.T) { } func Test_allowOriginFunc(t *testing.T) { - returnTrue := func(origin string) (bool, error) { - return true, nil + returnTrue := func(c *echo.Context, origin string) (string, bool, error) { + return origin, true, nil } - returnFalse := func(origin string) (bool, error) { - return false, nil + returnFalse := func(c *echo.Context, origin string) (string, bool, error) { + return origin, false, nil } - returnError := func(origin string) (bool, error) { - return true, errors.New("this is a test error") + returnError := func(c *echo.Context, origin string) (string, bool, error) { + return origin, true, errors.New("this is a test error") } - allowOriginFuncs := []func(origin string) (bool, error){ + allowOriginFuncs := []func(c *echo.Context, origin string) (string, bool, error){ returnTrue, returnFalse, returnError, @@ -663,21 +606,21 @@ func Test_allowOriginFunc(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) req.Header.Set(echo.HeaderOrigin, origin) - cors := CORSWithConfig(CORSConfig{ - AllowOriginFunc: allowOriginFunc, - }) - h := cors(echo.NotFoundHandler) - err := h(c) + cors, err := CORSConfig{UnsafeAllowOriginFunc: allowOriginFunc}.ToMiddleware() + assert.NoError(t, err) + + h := cors(func(c *echo.Context) error { return echo.ErrNotFound }) + err = h(c) - expected, expectedErr := allowOriginFunc(origin) + allowedOrigin, allowed, expectedErr := allowOriginFunc(c, origin) if expectedErr != nil { assert.Equal(t, expectedErr, err) assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) continue } - if expected { - assert.Equal(t, origin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) + if allowed { + assert.Equal(t, allowedOrigin, rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } else { assert.Equal(t, "", rec.Header().Get(echo.HeaderAccessControlAllowOrigin)) } diff --git a/middleware/csrf.go b/middleware/csrf.go index f9d3293b0..33757b760 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -10,14 +10,13 @@ import ( "strings" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // CSRFConfig defines the config for CSRF middleware. type CSRFConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper - // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header // exactly matches the specified value. // Values should be formated as Origin header "scheme://host[:port]". @@ -32,10 +31,10 @@ type CSRFConfig struct { // - `same-site` same registrable domain (subdomain and/or different port) // - `cross-site` request originates from different site // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers - AllowSecFetchSiteFunc func(c echo.Context) (bool, error) + AllowSecFetchSiteFunc func(c *echo.Context) (bool, error) // TokenLength is the length of the generated token. - TokenLength uint8 `yaml:"token_length"` + TokenLength uint8 // Optional. Default value 32. // TokenLookup is a string in the form of ":" or ":,:" that is used @@ -49,47 +48,48 @@ type CSRFConfig struct { // - "header:X-CSRF-Token,query:csrf" TokenLookup string `yaml:"token_lookup"` + // Generator defines a function to generate token. + // Optional. Defaults tp randomString(TokenLength). + Generator func() string + // Context key to store generated CSRF token into context. // Optional. Default value "csrf". - ContextKey string `yaml:"context_key"` + ContextKey string // Name of the CSRF cookie. This cookie will store CSRF token. // Optional. Default value "csrf". - CookieName string `yaml:"cookie_name"` + CookieName string // Domain of the CSRF cookie. // Optional. Default value none. - CookieDomain string `yaml:"cookie_domain"` + CookieDomain string // Path of the CSRF cookie. // Optional. Default value none. - CookiePath string `yaml:"cookie_path"` + CookiePath string // Max age (in seconds) of the CSRF cookie. // Optional. Default value 86400 (24hr). - CookieMaxAge int `yaml:"cookie_max_age"` + CookieMaxAge int // Indicates if CSRF cookie is secure. // Optional. Default value false. - CookieSecure bool `yaml:"cookie_secure"` + CookieSecure bool // Indicates if CSRF cookie is HTTP only. // Optional. Default value false. - CookieHTTPOnly bool `yaml:"cookie_http_only"` + CookieHTTPOnly bool // Indicates SameSite mode of the CSRF cookie. // Optional. Default value SameSiteDefaultMode. - CookieSameSite http.SameSite `yaml:"cookie_same_site"` + CookieSameSite http.SameSite // ErrorHandler defines a function which is executed for returning custom errors. - ErrorHandler CSRFErrorHandler + ErrorHandler func(c *echo.Context, err error) error } -// CSRFErrorHandler is a function which is executed for creating custom errors. -type CSRFErrorHandler func(err error, c echo.Context) error - // ErrCSRFInvalid is returned when CSRF check fails -var ErrCSRFInvalid = echo.NewHTTPError(http.StatusForbidden, "invalid csrf token") +var ErrCSRFInvalid = &echo.HTTPError{Code: http.StatusForbidden, Message: "invalid csrf token"} // DefaultCSRFConfig is the default CSRF middleware config. var DefaultCSRFConfig = CSRFConfig{ @@ -105,25 +105,26 @@ var DefaultCSRFConfig = CSRFConfig{ // CSRF returns a Cross-Site Request Forgery (CSRF) middleware. // See: https://en.wikipedia.org/wiki/Cross-site_request_forgery func CSRF() echo.MiddlewareFunc { - c := DefaultCSRFConfig - return CSRFWithConfig(c) + return CSRFWithConfig(DefaultCSRFConfig) } -// CSRFWithConfig returns a CSRF middleware with config. -// See `CSRF()`. +// CSRFWithConfig returns a CSRF middleware with config or panics on invalid configuration. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return toMiddlewareOrPanic(config) } // ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { + // Defaults if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } if config.TokenLength == 0 { config.TokenLength = DefaultCSRFConfig.TokenLength } - + if config.Generator == nil { + config.Generator = createRandomStringGenerator(config.TokenLength) + } if config.TokenLookup == "" { config.TokenLookup = DefaultCSRFConfig.TokenLookup } @@ -140,19 +141,19 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { config.CookieSecure = true } if len(config.TrustedOrigins) > 0 { - if vErr := validateOrigins(config.TrustedOrigins, "trusted origin"); vErr != nil { - return nil, vErr + if err := validateOrigins(config.TrustedOrigins, "trusted origin"); err != nil { + return nil, err } config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) } - extractors, cErr := CreateExtractors(config.TokenLookup) + extractors, cErr := createExtractors(config.TokenLookup, 1) if cErr != nil { return nil, cErr } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -170,7 +171,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { token := "" if k, err := c.Cookie(config.CookieName); err != nil { - token = randomString(config.TokenLength) + token = config.Generator() // Generate token } else { token = k.Value // Reuse token } @@ -183,7 +184,7 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { var lastTokenErr error outer: for _, extractor := range extractors { - clientTokens, err := extractor(c) + clientTokens, _, err := extractor(c) if err != nil { lastExtractorErr = err continue @@ -202,22 +203,11 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if lastTokenErr != nil { finalErr = lastTokenErr } else if lastExtractorErr != nil { - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the query string") - } else if lastExtractorErr == errFormExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in the form parameter") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, "missing csrf token in request header") - } else { - lastExtractorErr = echo.NewHTTPError(http.StatusBadRequest, lastExtractorErr.Error()) - } - finalErr = lastExtractorErr + finalErr = echo.ErrBadRequest.Wrap(lastExtractorErr) } - if finalErr != nil { if config.ErrorHandler != nil { - return config.ErrorHandler(finalErr, c) + return config.ErrorHandler(c, finalErr) } return finalErr } @@ -258,7 +248,7 @@ func validateCSRFToken(token, clientToken string) bool { var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace} -func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) { +func (config CSRFConfig) checkSecFetchSiteRequest(c *echo.Context) (bool, error) { // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers // Sec-Fetch-Site values are: // - `same-origin` exact origin match - allow always @@ -291,13 +281,13 @@ func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) } // we are here when request is state-changing and `cross-site` or `same-site` - // Note: if you want to block `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` + // Note: if you want to allow `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` if config.AllowSecFetchSiteFunc != nil { return config.AllowSecFetchSiteFunc(c) } if secFetchSite == "same-site" { - return false, nil // fall back to legacy token + return false, echo.NewHTTPError(http.StatusForbidden, "same-site request blocked by CSRF") } return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF") } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 85b7f1077..ddecc10e3 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -11,7 +11,7 @@ import ( "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -57,6 +57,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenFormTokens: map[string][]string{ "csrf": {"invalid", "token"}, }, + expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from POST form", @@ -74,7 +75,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenFormTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the form parameter", + expectError: "code=400, message=Bad Request, err=missing value in the form", }, { name: "ok, token from POST header", @@ -86,13 +87,14 @@ func TestCSRF_tokenExtractors(t *testing.T) { }, }, { - name: "ok, token from POST header, second token passes", + name: "nok, token from POST header, tokens limited to 1, second token would pass", whenTokenLookup: "header:" + echo.HeaderXCSRFToken, givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{ echo.HeaderXCSRFToken: {"invalid", "token"}, }, + expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from POST header", @@ -110,7 +112,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPost, givenHeaderTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in request header", + expectError: "code=400, message=Bad Request, err=missing value in request header", }, { name: "ok, token from PUT query param", @@ -122,13 +124,14 @@ func TestCSRF_tokenExtractors(t *testing.T) { }, }, { - name: "ok, token from PUT query form, second token passes", + name: "nok, token from PUT query form, second token would pass", whenTokenLookup: "query:csrf", givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{ "csrf": {"invalid", "token"}, }, + expectError: "code=403, message=invalid csrf token", }, { name: "nok, invalid token from PUT query form", @@ -146,7 +149,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenCSRFCookie: "token", givenMethod: http.MethodPut, givenQueryTokens: map[string][]string{}, - expectError: "code=400, message=missing csrf token in the query string", + expectError: "code=400, message=Bad Request, err=missing value in the query string", }, { name: "nok, invalid TokenLookup", @@ -210,7 +213,7 @@ func TestCSRF_tokenExtractors(t *testing.T) { assert.NoError(t, err) } - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -255,7 +258,7 @@ func TestCSRFWithConfig(t *testing.T) { name: "nok, POST without token", whenMethod: http.MethodPost, expectEmptyBody: true, - expectErr: `code=400, message=missing csrf token in request header`, + expectErr: `code=400, message=Bad Request, err=missing value in request header`, }, { name: "nok, POST empty token", @@ -319,7 +322,7 @@ func TestCSRFWithConfig(t *testing.T) { } assert.NoError(t, err) - h := mw(func(c echo.Context) error { + h := mw(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -349,7 +352,7 @@ func TestCSRF(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) csrf := CSRF() - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -369,7 +372,7 @@ func TestCSRFSetSameSiteMode(t *testing.T) { CookieSameSite: http.SameSiteStrictMode, }) - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -386,7 +389,7 @@ func TestCSRFWithoutSameSiteMode(t *testing.T) { csrf := CSRFWithConfig(CSRFConfig{}) - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -405,7 +408,7 @@ func TestCSRFWithSameSiteDefaultMode(t *testing.T) { CookieSameSite: http.SameSiteDefaultMode, }) - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -425,7 +428,7 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { }.ToMiddleware() assert.NoError(t, err) - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -461,12 +464,12 @@ func TestCSRFConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) csrf := CSRFWithConfig(CSRFConfig{ - Skipper: func(c echo.Context) bool { + Skipper: func(c *echo.Context) bool { return tc.whenSkip }, }) - h := csrf(func(c echo.Context) error { + h := csrf(func(c *echo.Context) error { return c.String(http.StatusOK, "test") }) @@ -480,13 +483,13 @@ func TestCSRFConfig_skipper(t *testing.T) { func TestCSRFErrorHandling(t *testing.T) { cfg := CSRFConfig{ - ErrorHandler: func(err error, c echo.Context) error { + ErrorHandler: func(c *echo.Context, err error) error { return echo.NewHTTPError(http.StatusTeapot, "error_handler_executed") }, } e := echo.New() - e.POST("/", func(c echo.Context) error { + e.POST("/", func(c *echo.Context) error { return c.String(http.StatusNotImplemented, "should not end up here") }) @@ -559,7 +562,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodPost, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: ``, + expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "ok, unsafe POST + same-origin passes", @@ -617,7 +620,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodPut, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: ``, + expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "nok, unsafe DELETE + cross-site is blocked", @@ -633,7 +636,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { whenMethod: http.MethodDelete, whenSecFetchSite: "same-site", expectAllow: false, - expectErr: ``, + expectErr: `code=403, message=same-site request blocked by CSRF`, }, { name: "nok, unsafe PATCH + cross-site is blocked", @@ -746,7 +749,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "ok, unsafe POST + same-site + custom func allows", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return true, nil }, }, @@ -757,7 +760,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "ok, unsafe POST + cross-site + custom func allows", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return true, nil }, }, @@ -768,7 +771,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "nok, unsafe POST + same-site + custom func returns custom error", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func") }, }, @@ -780,7 +783,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { { name: "nok, unsafe POST + cross-site + custom func returns false with nil error", givenConfig: CSRFConfig{ - AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, nil }, }, @@ -801,7 +804,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, - AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "should not be called") }, }, @@ -814,7 +817,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks", givenConfig: CSRFConfig{ TrustedOrigins: []string{"https://trusted.example.com"}, - AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + AllowSecFetchSiteFunc: func(c *echo.Context) (bool, error) { return false, echo.NewHTTPError(http.StatusTeapot, "custom block") }, }, @@ -836,8 +839,7 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { } res := httptest.NewRecorder() - e := echo.New() - c := e.NewContext(req, res) + c := echo.NewContext(req, res) allow, err := tc.givenConfig.checkSecFetchSiteRequest(c) diff --git a/middleware/decompress.go b/middleware/decompress.go index 0c56176ee..a384af2ea 100644 --- a/middleware/decompress.go +++ b/middleware/decompress.go @@ -9,7 +9,7 @@ import ( "net/http" "sync" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // DecompressConfig defines the config for Decompress middleware. @@ -19,6 +19,13 @@ type DecompressConfig struct { // GzipDecompressPool defines an interface to provide the sync.Pool used to create/store Gzip readers GzipDecompressPool Decompressor + + // MaxDecompressedSize limits the maximum size of decompressed request body in bytes. + // If the decompressed body exceeds this limit, the middleware returns HTTP 413 error. + // This prevents zip bomb attacks where small compressed payloads decompress to huge sizes. + // Default: 100 * MB (104,857,600 bytes) + // Set to -1 to disable limits (not recommended in production). + MaxDecompressedSize int64 } // GZIPEncoding content-encoding header if set to "gzip", decompress body contents. @@ -29,39 +36,48 @@ type Decompressor interface { gzipDecompressPool() sync.Pool } -// DefaultDecompressConfig defines the config for decompress middleware -var DefaultDecompressConfig = DecompressConfig{ - Skipper: DefaultSkipper, - GzipDecompressPool: &DefaultGzipDecompressPool{}, -} - // DefaultGzipDecompressPool is the default implementation of Decompressor interface type DefaultGzipDecompressPool struct { } func (d *DefaultGzipDecompressPool) gzipDecompressPool() sync.Pool { - return sync.Pool{New: func() interface{} { return new(gzip.Reader) }} + return sync.Pool{New: func() any { return new(gzip.Reader) }} } // Decompress decompresses request body based if content encoding type is set to "gzip" with default config +// +// SECURITY: By default, this limits decompressed data to 100MB to prevent zip bomb attacks. +// To customize the limit, use DecompressWithConfig. To disable limits (not recommended in production), +// set MaxDecompressedSize to -1. func Decompress() echo.MiddlewareFunc { - return DecompressWithConfig(DefaultDecompressConfig) + return DecompressWithConfig(DecompressConfig{}) } -// DecompressWithConfig decompresses request body based if content encoding type is set to "gzip" with config +// DecompressWithConfig returns a decompress middleware with config or panics on invalid configuration. +// +// SECURITY: If MaxDecompressedSize is not set (zero value), it defaults to 100MB to prevent +// DoS attacks via zip bombs. Set to -1 to explicitly disable limits if needed for your use case. func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts DecompressConfig to middleware or returns an error for invalid configuration +func (config DecompressConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultGzipConfig.Skipper + config.Skipper = DefaultSkipper } if config.GzipDecompressPool == nil { - config.GzipDecompressPool = DefaultDecompressConfig.GzipDecompressPool + config.GzipDecompressPool = &DefaultGzipDecompressPool{} + } + // Apply secure default for decompression limit + if config.MaxDecompressedSize == 0 { + config.MaxDecompressedSize = 100 * MB } return func(next echo.HandlerFunc) echo.HandlerFunc { pool := config.GzipDecompressPool.gzipDecompressPool() - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -73,7 +89,10 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { i := pool.Get() gr, ok := i.(*gzip.Reader) if !ok || gr == nil { - return echo.NewHTTPError(http.StatusInternalServerError, i.(error).Error()) + if err, isErr := i.(error); isErr { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return echo.NewHTTPError(http.StatusInternalServerError, "unexpected type from gzip decompression pool") } defer pool.Put(gr) @@ -90,9 +109,47 @@ func DecompressWithConfig(config DecompressConfig) echo.MiddlewareFunc { // only Close gzip reader if it was set to a proper gzip source otherwise it will panic on close. defer gr.Close() - c.Request().Body = gr + // Apply decompression size limit to prevent zip bombs + if config.MaxDecompressedSize > 0 { + c.Request().Body = &limitedGzipReader{ + Reader: gr, + remaining: config.MaxDecompressedSize, + limit: config.MaxDecompressedSize, + } + } else { + // -1 means explicitly unlimited (not recommended) + c.Request().Body = gr + } return next(c) } + }, nil +} + +// limitedGzipReader wraps a gzip reader with size limiting to prevent zip bombs +type limitedGzipReader struct { + *gzip.Reader + remaining int64 + limit int64 +} + +func (r *limitedGzipReader) Read(p []byte) (n int, err error) { + if r.remaining <= 0 { + // Limit exceeded - return 413 error + return 0, echo.ErrStatusRequestEntityTooLarge + } + + // Limit the read to remaining bytes + if int64(len(p)) > r.remaining { + p = p[:r.remaining] } + + n, err = r.Reader.Read(p) + r.remaining -= int64(n) + + return n, err +} + +func (r *limitedGzipReader) Close() error { + return r.Reader.Close() } diff --git a/middleware/decompress_test.go b/middleware/decompress_test.go index 63b1a68f5..1823e94bb 100644 --- a/middleware/decompress_test.go +++ b/middleware/decompress_test.go @@ -14,61 +14,91 @@ import ( "sync" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestDecompress(t *testing.T) { e := echo.New() - req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - // Skip if no Content-Encoding header - h := Decompress()(func(c echo.Context) error { + h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) - - assert.Equal(t, "test", rec.Body.String()) - // Decompress + // Decompress request body body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) assert.Equal(t, body, string(b)) } -func TestDecompressDefaultConfig(t *testing.T) { +func TestDecompress_skippedIfNoHeader(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := DecompressWithConfig(DecompressConfig{})(func(c echo.Context) error { + // Skip if no Content-Encoding header + h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte("test")) // For Content-Type sniffing return nil }) - h(c) + + err := h(c) + assert.NoError(t, err) + assert.Equal(t, "test", rec.Body.String()) + +} + +func TestDecompressWithConfig_DefaultConfig_noDecode(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("test")) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + })(c) + assert.NoError(t, err) assert.Equal(t, "test", rec.Body.String()) +} + +func TestDecompressWithConfig_DefaultConfig(t *testing.T) { + e := echo.New() + + h := Decompress()(func(c *echo.Context) error { + c.Response().Write([]byte("test")) // For Content-Type sniffing + return nil + }) + // Decompress body := `{"name": "echo"}` gz, _ := gzipString(body) - req = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - h(c) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := h(c) + assert.NoError(t, err) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) @@ -83,7 +113,9 @@ func TestCompressRequestWithoutDecompressMiddleware(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) b, err := io.ReadAll(req.Body) assert.NoError(t, err) @@ -97,10 +129,13 @@ func TestDecompressNoContent(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Decompress()(func(c echo.Context) error { + h := Decompress()(func(c *echo.Context) error { return c.NoContent(http.StatusNoContent) }) - if assert.NoError(t, h(c)) { + + err := h(c) + + if assert.NoError(t, err) { assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) assert.Empty(t, rec.Header().Get(echo.HeaderContentType)) assert.Equal(t, 0, len(rec.Body.Bytes())) @@ -110,13 +145,15 @@ func TestDecompressNoContent(t *testing.T) { func TestDecompressErrorReturned(t *testing.T) { e := echo.New() e.Use(Decompress()) - e.GET("/", func(c echo.Context) error { + e.GET("/", func(c *echo.Context) error { return echo.ErrNotFound }) req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + assert.Equal(t, http.StatusNotFound, rec.Code) assert.Empty(t, rec.Header().Get(echo.HeaderContentEncoding)) } @@ -124,7 +161,7 @@ func TestDecompressErrorReturned(t *testing.T) { func TestDecompressSkipper(t *testing.T) { e := echo.New() e.Use(DecompressWithConfig(DecompressConfig{ - Skipper: func(c echo.Context) bool { + Skipper: func(c *echo.Context) bool { return c.Request().URL.Path == "/skip" }, })) @@ -133,7 +170,9 @@ func TestDecompressSkipper(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, rec.Header().Get(echo.HeaderContentType), echo.MIMEApplicationJSON) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -145,7 +184,7 @@ type TestDecompressPoolWithError struct { func (d *TestDecompressPoolWithError) gzipDecompressPool() sync.Pool { return sync.Pool{ - New: func() interface{} { + New: func() any { return errors.New("pool error") }, } @@ -162,7 +201,9 @@ func TestDecompressPoolError(t *testing.T) { req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) rec := httptest.NewRecorder() c := e.NewContext(req, rec) + e.ServeHTTP(rec, req) + assert.Equal(t, GZIPEncoding, req.Header.Get(echo.HeaderContentEncoding)) reqBody, err := io.ReadAll(c.Request().Body) assert.NoError(t, err) @@ -177,7 +218,7 @@ func BenchmarkDecompress(b *testing.B) { req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(string(gz))) req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) - h := Decompress()(func(c echo.Context) error { + h := Decompress()(func(c *echo.Context) error { c.Response().Write([]byte(body)) // For Content-Type sniffing return nil }) @@ -208,3 +249,260 @@ func gzipString(body string) ([]byte, error) { return buf.Bytes(), nil } + +func TestDecompress_WithinLimit(t *testing.T) { + e := echo.New() + body := strings.Repeat("test data ", 100) // Small payload ~1KB + gz, _ := gzipString(body) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, body, rec.Body.String()) +} + +func TestDecompress_ExceedsLimit(t *testing.T) { + e := echo.New() + // Create 2KB of data but limit to 1KB + largeBody := strings.Repeat("A", 2*1024) + gz, _ := gzipString(largeBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + _, readErr := io.ReadAll(c.Request().Body) + return readErr + })(c) + + // Should return 413 error + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func TestDecompress_AtExactLimit(t *testing.T) { + e := echo.New() + exactBody := strings.Repeat("B", 1024) // Exactly 1KB + gz, _ := gzipString(exactBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, exactBody, rec.Body.String()) +} + +func TestDecompress_ZipBomb(t *testing.T) { + e := echo.New() + // Create highly compressed data that expands to 2MB + // but limit is 1MB + largeBody := bytes.Repeat([]byte("A"), 2*1024*1024) // 2MB + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + gzWriter.Write(largeBody) + gzWriter.Close() + + req := httptest.NewRequest(http.MethodPost, "/", &buf) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + _, readErr := io.ReadAll(c.Request().Body) + return readErr + })(c) + + // Should return 413 error + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func TestDecompress_UnlimitedExplicit(t *testing.T) { + e := echo.New() + largeBody := strings.Repeat("X", 10*1024) // 10KB + gz, _ := gzipString(largeBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: -1}.ToMiddleware() // Unlimited + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, largeBody, rec.Body.String()) +} + +func TestDecompress_DefaultLimit(t *testing.T) { + e := echo.New() + smallBody := "test" + gz, _ := gzipString(smallBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Use zero value which should apply 100MB default + h, err := DecompressConfig{}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, smallBody, rec.Body.String()) +} + +func TestDecompress_SmallCustomLimit(t *testing.T) { + e := echo.New() + body := strings.Repeat("D", 512) // 512 bytes + gz, _ := gzipString(body) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + b, _ := io.ReadAll(c.Request().Body) + return c.String(http.StatusOK, string(b)) + })(c) + + assert.NoError(t, err) + assert.Equal(t, body, rec.Body.String()) +} + +func TestDecompress_MultipleReads(t *testing.T) { + e := echo.New() + // Test that limit is enforced across multiple Read() calls + largeBody := strings.Repeat("M", 2*1024) // 2KB + gz, _ := gzipString(largeBody) + + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1024}.ToMiddleware() // 1KB limit + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + // Read in small chunks + buf := make([]byte, 256) + total := 0 + for { + n, readErr := c.Request().Body.Read(buf) + total += n + if readErr != nil { + if readErr == io.EOF { + return nil + } + return readErr + } + } + })(c) + + // Should return 413 error from cumulative reads + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func TestDecompress_LargePayloadDosPrevention(t *testing.T) { + e := echo.New() + // Simulate a DoS attack with highly compressed large payload + largeSize := 10 * 1024 * 1024 // 10MB decompressed + largeBody := bytes.Repeat([]byte("Z"), largeSize) + var buf bytes.Buffer + gzWriter := gzip.NewWriter(&buf) + gzWriter.Write(largeBody) + gzWriter.Close() + + req := httptest.NewRequest(http.MethodPost, "/", &buf) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h, err := DecompressConfig{MaxDecompressedSize: 1 * MB}.ToMiddleware() + assert.NoError(t, err) + + err = h(func(c *echo.Context) error { + _, readErr := io.ReadAll(c.Request().Body) + return readErr + })(c) + + // Should prevent DoS by returning 413 + assert.Error(t, err) + he, ok := err.(echo.HTTPStatusCoder) + assert.True(t, ok) + assert.Equal(t, http.StatusRequestEntityTooLarge, he.StatusCode()) +} + +func BenchmarkDecompress_WithLimit(b *testing.B) { + e := echo.New() + body := strings.Repeat("benchmark data ", 1000) // ~15KB + gz, _ := gzipString(body) + + h, _ := DecompressConfig{MaxDecompressedSize: 100 * MB}.ToMiddleware() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader(gz)) + req.Header.Set(echo.HeaderContentEncoding, GZIPEncoding) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + h(func(c *echo.Context) error { + io.ReadAll(c.Request().Body) + return nil + })(c) + } +} diff --git a/middleware/extractor.go b/middleware/extractor.go index 3f2741407..abb603186 100644 --- a/middleware/extractor.go +++ b/middleware/extractor.go @@ -4,11 +4,11 @@ package middleware import ( - "errors" "fmt" - "github.com/labstack/echo/v4" "net/textproto" "strings" + + "github.com/labstack/echo/v5" ) const ( @@ -17,18 +17,44 @@ const ( extractorLimit = 20 ) -var errHeaderExtractorValueMissing = errors.New("missing value in request header") -var errHeaderExtractorValueInvalid = errors.New("invalid value in request header") -var errQueryExtractorValueMissing = errors.New("missing value in the query string") -var errParamExtractorValueMissing = errors.New("missing value in path params") -var errCookieExtractorValueMissing = errors.New("missing value in cookies") -var errFormExtractorValueMissing = errors.New("missing value in the form") +// ExtractorSource is type to indicate source for extracted value +type ExtractorSource string + +const ( + // ExtractorSourceHeader means value was extracted from request header + ExtractorSourceHeader ExtractorSource = "header" + // ExtractorSourceQuery means value was extracted from request query parameters + ExtractorSourceQuery ExtractorSource = "query" + // ExtractorSourcePathParam means value was extracted from route path parameters + ExtractorSourcePathParam ExtractorSource = "param" + // ExtractorSourceCookie means value was extracted from request cookies + ExtractorSourceCookie ExtractorSource = "cookie" + // ExtractorSourceForm means value was extracted from request form values + ExtractorSourceForm ExtractorSource = "form" +) + +// ValueExtractorError is error type when middleware extractor is unable to extract value from lookups +type ValueExtractorError struct { + message string +} + +// Error returns errors text +func (e *ValueExtractorError) Error() string { + return e.message +} + +var errHeaderExtractorValueMissing = &ValueExtractorError{message: "missing value in request header"} +var errHeaderExtractorValueInvalid = &ValueExtractorError{message: "invalid value in request header"} +var errQueryExtractorValueMissing = &ValueExtractorError{message: "missing value in the query string"} +var errParamExtractorValueMissing = &ValueExtractorError{message: "missing value in path params"} +var errCookieExtractorValueMissing = &ValueExtractorError{message: "missing value in cookies"} +var errFormExtractorValueMissing = &ValueExtractorError{message: "missing value in the form"} // ValuesExtractor defines a function for extracting values (keys/tokens) from the given context. -type ValuesExtractor func(c echo.Context) ([]string, error) +type ValuesExtractor func(c *echo.Context) ([]string, ExtractorSource, error) // CreateExtractors creates ValuesExtractors from given lookups. -// Lookups is a string in the form of ":" or ":,:" that is used +// lookups is a string in the form of ":" or ":,:" that is used // to extract key from the request. // Possible values: // - "header:" or "header::" @@ -43,14 +69,22 @@ type ValuesExtractor func(c echo.Context) ([]string, error) // // Multiple sources example: // - "header:Authorization,header:X-Api-Key" -func CreateExtractors(lookups string) ([]ValuesExtractor, error) { - return createExtractors(lookups, "") +// +// limit sets the maximum amount how many lookups can be returned. +func CreateExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { + return createExtractors(lookups, limit) } -func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, error) { +func createExtractors(lookups string, limit uint) ([]ValuesExtractor, error) { if lookups == "" { return nil, nil } + if limit == 0 { + limit = 1 + } else if limit > extractorLimit { + limit = extractorLimit + } + sources := strings.Split(lookups, ",") var extractors = make([]ValuesExtractor, 0) for _, source := range sources { @@ -61,28 +95,19 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err switch parts[0] { case "query": - extractors = append(extractors, valuesFromQuery(parts[1])) + extractors = append(extractors, valuesFromQuery(parts[1], limit)) case "param": - extractors = append(extractors, valuesFromParam(parts[1])) + extractors = append(extractors, valuesFromParam(parts[1], limit)) case "cookie": - extractors = append(extractors, valuesFromCookie(parts[1])) + extractors = append(extractors, valuesFromCookie(parts[1], limit)) case "form": - extractors = append(extractors, valuesFromForm(parts[1])) + extractors = append(extractors, valuesFromForm(parts[1], limit)) case "header": prefix := "" if len(parts) > 2 { prefix = parts[2] - } else if authScheme != "" && parts[1] == echo.HeaderAuthorization { - // backwards compatibility for JWT and KeyAuth: - // * we only apply this fix to Authorization as header we use and uses prefixes like "Bearer " etc - // * previously header extractor assumed that auth-scheme/prefix had a space as suffix we need to retain that - // behaviour for default values and Authorization header. - prefix = authScheme - if !strings.HasSuffix(prefix, " ") { - prefix += " " - } } - extractors = append(extractors, valuesFromHeader(parts[1], prefix)) + extractors = append(extractors, valuesFromHeader(parts[1], prefix, limit)) } } return extractors, nil @@ -94,28 +119,32 @@ func createExtractors(lookups string, authScheme string) ([]ValuesExtractor, err // note the space at the end. In case of basic authentication `Authorization: Basic ` prefix we want to remove // is `Basic `. In case of JWT tokens `Authorization: Bearer ` prefix is `Bearer `. // If prefix is left empty the whole value is returned. -func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { +func valuesFromHeader(header string, valuePrefix string, limit uint) ValuesExtractor { prefixLen := len(valuePrefix) // standard library parses http.Request header keys in canonical form but we may provide something else so fix this header = textproto.CanonicalMIMEHeaderKey(header) - return func(c echo.Context) ([]string, error) { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { values := c.Request().Header.Values(header) if len(values) == 0 { - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } + i := uint(0) result := make([]string, 0) - for i, value := range values { + for _, value := range values { if prefixLen == 0 { result = append(result, value) - if i >= extractorLimit-1 { + i++ + if i >= limit { break } - continue - } - if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { + } else if len(value) > prefixLen && strings.EqualFold(value[:prefixLen], valuePrefix) { result = append(result, value[prefixLen:]) - if i >= extractorLimit-1 { + i++ + if i >= limit { break } } @@ -123,85 +152,102 @@ func valuesFromHeader(header string, valuePrefix string) ValuesExtractor { if len(result) == 0 { if prefixLen > 0 { - return nil, errHeaderExtractorValueInvalid + return nil, ExtractorSourceHeader, errHeaderExtractorValueInvalid } - return nil, errHeaderExtractorValueMissing + return nil, ExtractorSourceHeader, errHeaderExtractorValueMissing } - return result, nil + return result, ExtractorSourceHeader, nil } } // valuesFromQuery returns a function that extracts values from the query string. -func valuesFromQuery(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { +func valuesFromQuery(param string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { result := c.QueryParams()[param] if len(result) == 0 { - return nil, errQueryExtractorValueMissing - } else if len(result) > extractorLimit-1 { - result = result[:extractorLimit] + return nil, ExtractorSourceQuery, errQueryExtractorValueMissing + } else if len(result) > int(limit)-1 { + result = result[:limit] } - return result, nil + return result, ExtractorSourceQuery, nil } } // valuesFromParam returns a function that extracts values from the url param string. -func valuesFromParam(param string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { +func valuesFromParam(param string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { result := make([]string, 0) - paramVales := c.ParamValues() - for i, p := range c.ParamNames() { - if param == p { - result = append(result, paramVales[i]) - if i >= extractorLimit-1 { - break - } + i := uint(0) + for _, p := range c.PathValues() { + if param != p.Name { + continue + } + result = append(result, p.Value) + i++ + if i >= limit { + break } } if len(result) == 0 { - return nil, errParamExtractorValueMissing + return nil, ExtractorSourcePathParam, errParamExtractorValueMissing } - return result, nil + return result, ExtractorSourcePathParam, nil } } // valuesFromCookie returns a function that extracts values from the named cookie. -func valuesFromCookie(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { +func valuesFromCookie(name string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { cookies := c.Cookies() if len(cookies) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } + i := uint(0) result := make([]string, 0) - for i, cookie := range cookies { - if name == cookie.Name { - result = append(result, cookie.Value) - if i >= extractorLimit-1 { - break - } + for _, cookie := range cookies { + if name != cookie.Name { + continue + } + result = append(result, cookie.Value) + i++ + if i >= limit { + break } } if len(result) == 0 { - return nil, errCookieExtractorValueMissing + return nil, ExtractorSourceCookie, errCookieExtractorValueMissing } - return result, nil + return result, ExtractorSourceCookie, nil } } // valuesFromForm returns a function that extracts values from the form field. -func valuesFromForm(name string) ValuesExtractor { - return func(c echo.Context) ([]string, error) { +func valuesFromForm(name string, limit uint) ValuesExtractor { + if limit == 0 { + limit = 1 + } + return func(c *echo.Context) ([]string, ExtractorSource, error) { if c.Request().Form == nil { - _ = c.Request().ParseMultipartForm(32 << 20) // same what `c.Request().FormValue(name)` does + _, _ = c.MultipartForm() // we want to trigger c.request.ParseMultipartForm(c.formParseMaxMemory) } values := c.Request().Form[name] if len(values) == 0 { - return nil, errFormExtractorValueMissing + return nil, ExtractorSourceForm, errFormExtractorValueMissing } - if len(values) > extractorLimit-1 { - values = values[:extractorLimit] + if len(values) > int(limit)-1 { + values = values[:limit] } result := append([]string{}, values...) - return result, nil + return result, ExtractorSourceForm, nil } } diff --git a/middleware/extractor_test.go b/middleware/extractor_test.go index 42cbcfeab..04cc7b829 100644 --- a/middleware/extractor_test.go +++ b/middleware/extractor_test.go @@ -6,39 +6,26 @@ package middleware import ( "bytes" "fmt" - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" "mime/multipart" "net/http" "net/http/httptest" "net/url" "strings" "testing" -) - -type pathParam struct { - name string - value string -} -func setPathParams(c echo.Context, params []pathParam) { - names := make([]string, 0, len(params)) - values := make([]string, 0, len(params)) - for _, pp := range params { - names = append(names, pp.name) - values = append(values, pp.value) - } - c.SetParamNames(names...) - c.SetParamValues(values...) -} + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/assert" +) func TestCreateExtractors(t *testing.T) { var testCases = []struct { name string givenRequest func() *http.Request - givenPathParams []pathParam - whenLoopups string + givenPathValues echo.PathValues + whenLookups string + whenLimit uint expectValues []string + expectSource ExtractorSource expectCreateError string expectError string }{ @@ -49,8 +36,9 @@ func TestCreateExtractors(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer token") return req }, - whenLoopups: "header:Authorization:Bearer ", + whenLookups: "header:Authorization:Bearer ", expectValues: []string{"token"}, + expectSource: ExtractorSourceHeader, }, { name: "ok, form", @@ -62,8 +50,9 @@ func TestCreateExtractors(t *testing.T) { req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) return req }, - whenLoopups: "form:name", + whenLookups: "form:name", expectValues: []string{"Jon Snow"}, + expectSource: ExtractorSourceForm, }, { name: "ok, cookie", @@ -72,16 +61,18 @@ func TestCreateExtractors(t *testing.T) { req.Header.Set(echo.HeaderCookie, "_csrf=token") return req }, - whenLoopups: "cookie:_csrf", + whenLookups: "cookie:_csrf", expectValues: []string{"token"}, + expectSource: ExtractorSourceCookie, }, { name: "ok, param", - givenPathParams: []pathParam{ - {name: "id", value: "123"}, + givenPathValues: echo.PathValues{ + {Name: "id", Value: "123"}, }, - whenLoopups: "param:id", + whenLookups: "param:id", expectValues: []string{"123"}, + expectSource: ExtractorSourcePathParam, }, { name: "ok, query", @@ -89,12 +80,13 @@ func TestCreateExtractors(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/?id=999", nil) return req }, - whenLoopups: "query:id", + whenLookups: "query:id", expectValues: []string{"999"}, + expectSource: ExtractorSourceQuery, }, { name: "nok, invalid lookup", - whenLoopups: "query", + whenLookups: "query", expectCreateError: "extractor source for lookup could not be split into needed parts: query", }, } @@ -109,11 +101,11 @@ func TestCreateExtractors(t *testing.T) { } rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + if tc.givenPathValues != nil { + c.SetPathValues(tc.givenPathValues) } - extractors, err := CreateExtractors(tc.whenLoopups) + extractors, err := CreateExtractors(tc.whenLookups, tc.whenLimit) if tc.expectCreateError != "" { assert.EqualError(t, err, tc.expectCreateError) return @@ -121,8 +113,9 @@ func TestCreateExtractors(t *testing.T) { assert.NoError(t, err) for _, e := range extractors { - values, eErr := e(c) + values, source, eErr := e(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, tc.expectSource, source) if tc.expectError != "" { assert.EqualError(t, eErr, tc.expectError) return @@ -143,6 +136,7 @@ func TestValuesFromHeader(t *testing.T) { givenRequest func(req *http.Request) whenName string whenValuePrefix string + whenLimit uint expectValues []string expectError string }{ @@ -168,6 +162,7 @@ func TestValuesFromHeader(t *testing.T) { }, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", + whenLimit: 2, expectValues: []string{"dXNlcjpwYXNzd29yZA==", "dGVzdDp0ZXN0"}, }, { @@ -213,6 +208,7 @@ func TestValuesFromHeader(t *testing.T) { }, whenName: echo.HeaderAuthorization, whenValuePrefix: "basic ", + whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -227,6 +223,7 @@ func TestValuesFromHeader(t *testing.T) { }, whenName: echo.HeaderAuthorization, whenValuePrefix: "", + whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -245,10 +242,11 @@ func TestValuesFromHeader(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix) + extractor := valuesFromHeader(tc.whenName, tc.whenValuePrefix, tc.whenLimit) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceHeader, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -263,6 +261,7 @@ func TestValuesFromQuery(t *testing.T) { name string givenQueryPart string whenName string + whenLimit uint expectValues []string expectError string }{ @@ -276,6 +275,7 @@ func TestValuesFromQuery(t *testing.T) { name: "ok, multiple value", givenQueryPart: "?id=123&id=456&name=test", whenName: "id", + whenLimit: 2, expectValues: []string{"123", "456"}, }, { @@ -290,7 +290,8 @@ func TestValuesFromQuery(t *testing.T) { "&id=1&id=2&id=3&id=4&id=5&id=6&id=7&id=8&id=9&id=10" + "&id=11&id=12&id=13&id=14&id=15&id=16&id=17&id=18&id=19&id=20" + "&id=21&id=22&id=23&id=24&id=25", - whenName: "id", + whenName: "id", + whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -306,10 +307,11 @@ func TestValuesFromQuery(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromQuery(tc.whenName) + extractor := valuesFromQuery(tc.whenName, tc.whenLimit) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceQuery, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -320,53 +322,56 @@ func TestValuesFromQuery(t *testing.T) { } func TestValuesFromParam(t *testing.T) { - examplePathParams := []pathParam{ - {name: "id", value: "123"}, - {name: "gid", value: "456"}, - {name: "gid", value: "789"}, + examplePathValues := echo.PathValues{ + {Name: "id", Value: "123"}, + {Name: "gid", Value: "456"}, + {Name: "gid", Value: "789"}, } - examplePathParams20 := make([]pathParam, 0) + examplePathValues20 := make(echo.PathValues, 0) for i := 1; i < 25; i++ { - examplePathParams20 = append(examplePathParams20, pathParam{name: "id", value: fmt.Sprintf("%v", i)}) + examplePathValues20 = append(examplePathValues20, echo.PathValue{Name: "id", Value: fmt.Sprintf("%v", i)}) } var testCases = []struct { name string - givenPathParams []pathParam + givenPathValues echo.PathValues whenName string + whenLimit uint expectValues []string expectError string }{ { name: "ok, single value", - givenPathParams: examplePathParams, + givenPathValues: examplePathValues, whenName: "id", expectValues: []string{"123"}, }, { name: "ok, multiple value", - givenPathParams: examplePathParams, + givenPathValues: examplePathValues, whenName: "gid", + whenLimit: 2, expectValues: []string{"456", "789"}, }, { name: "nok, no values", - givenPathParams: nil, + givenPathValues: nil, whenName: "nope", expectValues: nil, expectError: errParamExtractorValueMissing.Error(), }, { name: "nok, no matching value", - givenPathParams: examplePathParams, + givenPathValues: examplePathValues, whenName: "nope", expectValues: nil, expectError: errParamExtractorValueMissing.Error(), }, { name: "ok, cut values over extractorLimit", - givenPathParams: examplePathParams20, + givenPathValues: examplePathValues20, whenName: "id", + whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -381,14 +386,15 @@ func TestValuesFromParam(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - if tc.givenPathParams != nil { - setPathParams(c, tc.givenPathParams) + if tc.givenPathValues != nil { + c.SetPathValues(tc.givenPathValues) } - extractor := valuesFromParam(tc.whenName) + extractor := valuesFromParam(tc.whenName, tc.whenLimit) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourcePathParam, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -407,6 +413,7 @@ func TestValuesFromCookie(t *testing.T) { name string givenRequest func(req *http.Request) whenName string + whenLimit uint expectValues []string expectError string }{ @@ -423,6 +430,7 @@ func TestValuesFromCookie(t *testing.T) { req.Header.Add(echo.HeaderCookie, "_csrf=token2") }, whenName: "_csrf", + whenLimit: 2, expectValues: []string{"token", "token2"}, }, { @@ -446,7 +454,8 @@ func TestValuesFromCookie(t *testing.T) { req.Header.Add(echo.HeaderCookie, fmt.Sprintf("_csrf=%v", i)) } }, - whenName: "_csrf", + whenName: "_csrf", + whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -465,10 +474,11 @@ func TestValuesFromCookie(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromCookie(tc.whenName) + extractor := valuesFromCookie(tc.whenName, tc.whenLimit) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceCookie, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -527,6 +537,7 @@ func TestValuesFromForm(t *testing.T) { name string givenRequest *http.Request whenName string + whenLimit uint expectValues []string expectError string }{ @@ -542,6 +553,7 @@ func TestValuesFromForm(t *testing.T) { v.Add("emails[]", "snow@labstack.com") }), whenName: "emails[]", + whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { @@ -550,6 +562,7 @@ func TestValuesFromForm(t *testing.T) { w.WriteField("emails[]", "snow@labstack.com") }), whenName: "emails[]", + whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { @@ -564,6 +577,7 @@ func TestValuesFromForm(t *testing.T) { v.Add("emails[]", "snow@labstack.com") }), whenName: "emails[]", + whenLimit: 2, expectValues: []string{"jon@labstack.com", "snow@labstack.com"}, }, { @@ -579,7 +593,8 @@ func TestValuesFromForm(t *testing.T) { v.Add("id[]", fmt.Sprintf("%v", i)) } }), - whenName: "id[]", + whenName: "id[]", + whenLimit: extractorLimit, expectValues: []string{ "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", @@ -595,10 +610,11 @@ func TestValuesFromForm(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - extractor := valuesFromForm(tc.whenName) + extractor := valuesFromForm(tc.whenName, tc.whenLimit) - values, err := extractor(c) + values, source, err := extractor(c) assert.Equal(t, tc.expectValues, values) + assert.Equal(t, ExtractorSourceForm, source) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { diff --git a/middleware/key_auth.go b/middleware/key_auth.go index 79bee207c..e14bd9e2e 100644 --- a/middleware/key_auth.go +++ b/middleware/key_auth.go @@ -4,12 +4,18 @@ package middleware import ( + "cmp" "errors" - "github.com/labstack/echo/v4" + "fmt" "net/http" + + "github.com/labstack/echo/v5" ) // KeyAuthConfig defines the config for KeyAuth middleware. +// +// SECURITY: The Validator function is responsible for securely comparing API keys. +// See KeyAuthValidator documentation for guidance on preventing timing attacks. type KeyAuthConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper @@ -30,16 +36,22 @@ type KeyAuthConfig struct { // - "header:Authorization,header:X-Api-Key" KeyLookup string - // AuthScheme to be used in the Authorization header. - // Optional. Default value "Bearer". - AuthScheme string + // AllowedCheckLimit set how many KeyLookup values are allowed to be checked. This is + // useful environments like corporate test environments with application proxies restricting + // access to environment with their own auth scheme. + AllowedCheckLimit uint // Validator is a function to validate key. // Required. Validator KeyAuthValidator - // ErrorHandler defines a function which is executed for an invalid key. + // ErrorHandler defines a function which is executed when all lookups have been done and none of them passed Validator + // function. ErrorHandler is executed with last missing (ErrExtractionValueMissing) or an invalid key. // It may be used to define a custom error. + // + // Note: when error handler swallows the error (returns nil) middleware continues handler chain execution towards handler. + // This is useful in cases when portion of your site/api is publicly accessible and has extra features for authorized users + // In that case you can use ErrorHandler to set default public auth value to request and continue with handler chain. ErrorHandler KeyAuthErrorHandler // ContinueOnIgnoredError allows the next middleware/handler to be called when ErrorHandler decides to @@ -51,31 +63,55 @@ type KeyAuthConfig struct { } // KeyAuthValidator defines a function to validate KeyAuth credentials. -type KeyAuthValidator func(auth string, c echo.Context) (bool, error) +// +// SECURITY WARNING: To prevent timing attacks that could allow attackers to enumerate +// valid API keys, validator implementations MUST use constant-time comparison. +// Use crypto/subtle.ConstantTimeCompare instead of standard string equality (==) +// or switch statements. +// +// Example of SECURE implementation: +// +// import "crypto/subtle" +// +// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { +// // Fetch valid keys from database/config +// validKeys := []string{"key1", "key2", "key3"} +// +// for _, validKey := range validKeys { +// // Use constant-time comparison to prevent timing attacks +// if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { +// return true, nil +// } +// } +// return false, nil +// } +// +// Example of INSECURE implementation (DO NOT USE): +// +// // VULNERABLE TO TIMING ATTACKS - DO NOT USE +// validator := func(c *echo.Context, key string, source ExtractorSource) (bool, error) { +// switch key { // Timing leak! +// case "valid-key": +// return true, nil +// default: +// return false, nil +// } +// } +type KeyAuthValidator func(c *echo.Context, key string, source ExtractorSource) (bool, error) // KeyAuthErrorHandler defines a function which is executed for an invalid key. -type KeyAuthErrorHandler func(err error, c echo.Context) error +type KeyAuthErrorHandler func(c *echo.Context, err error) error -// ErrKeyAuthMissing is error type when KeyAuth middleware is unable to extract value from lookups -type ErrKeyAuthMissing struct { - Err error -} +// ErrKeyMissing denotes an error raised when key value could not be extracted from request +var ErrKeyMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing key") + +// ErrInvalidKey denotes an error raised when key value is invalid by validator +var ErrInvalidKey = echo.NewHTTPError(http.StatusUnauthorized, "invalid key") // DefaultKeyAuthConfig is the default KeyAuth middleware config. var DefaultKeyAuthConfig = KeyAuthConfig{ - Skipper: DefaultSkipper, - KeyLookup: "header:" + echo.HeaderAuthorization, - AuthScheme: "Bearer", -} - -// Error returns errors text -func (e *ErrKeyAuthMissing) Error() string { - return e.Err.Error() -} - -// Unwrap unwraps error -func (e *ErrKeyAuthMissing) Unwrap() error { - return e.Err + Skipper: DefaultSkipper, + KeyLookup: "header:" + echo.HeaderAuthorization + ":Bearer ", } // KeyAuth returns an KeyAuth middleware. @@ -89,31 +125,39 @@ func KeyAuth(fn KeyAuthValidator) echo.MiddlewareFunc { return KeyAuthWithConfig(c) } -// KeyAuthWithConfig returns an KeyAuth middleware with config. -// See `KeyAuth()`. +// KeyAuthWithConfig returns an KeyAuth middleware or panics if configuration is invalid. +// +// For first valid key it calls the next handler. +// For invalid key, it sends "401 - Unauthorized" response. +// For missing key, it sends "400 - Bad Request" response. func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts KeyAuthConfig to middleware or returns an error for invalid configuration +func (config KeyAuthConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultKeyAuthConfig.Skipper } - // Defaults - if config.AuthScheme == "" { - config.AuthScheme = DefaultKeyAuthConfig.AuthScheme - } if config.KeyLookup == "" { config.KeyLookup = DefaultKeyAuthConfig.KeyLookup } if config.Validator == nil { - panic("echo: key-auth middleware requires a validator function") + return nil, errors.New("echo key-auth middleware requires a validator function") } - extractors, cErr := createExtractors(config.KeyLookup, config.AuthScheme) + limit := cmp.Or(config.AllowedCheckLimit, 1) + + extractors, cErr := createExtractors(config.KeyLookup, limit) if cErr != nil { - panic(cErr) + return nil, fmt.Errorf("echo key-auth middleware could not create key extractor: %w", cErr) + } + if len(extractors) == 0 { + return nil, errors.New("echo key-auth middleware could not create extractors from KeyLookup string") } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -121,59 +165,41 @@ func KeyAuthWithConfig(config KeyAuthConfig) echo.MiddlewareFunc { var lastExtractorErr error var lastValidatorErr error for _, extractor := range extractors { - keys, err := extractor(c) - if err != nil { - lastExtractorErr = err + keys, source, extrErr := extractor(c) + if extrErr != nil { + lastExtractorErr = extrErr continue } for _, key := range keys { - valid, err := config.Validator(key, c) + valid, err := config.Validator(c, key, source) if err != nil { lastValidatorErr = err continue } - if valid { - return next(c) + if !valid { + lastValidatorErr = ErrInvalidKey + continue } - lastValidatorErr = errors.New("invalid key") + return next(c) } } - // we are here only when we did not successfully extract and validate any of keys + // prioritize validator errors over extracting errors err := lastValidatorErr - if err == nil { // prioritize validator errors over extracting errors - // ugly part to preserve backwards compatible errors. someone could rely on them - if lastExtractorErr == errQueryExtractorValueMissing { - err = errors.New("missing key in the query string") - } else if lastExtractorErr == errCookieExtractorValueMissing { - err = errors.New("missing key in cookies") - } else if lastExtractorErr == errFormExtractorValueMissing { - err = errors.New("missing key in the form") - } else if lastExtractorErr == errHeaderExtractorValueMissing { - err = errors.New("missing key in request header") - } else if lastExtractorErr == errHeaderExtractorValueInvalid { - err = errors.New("invalid key in the request header") - } else { - err = lastExtractorErr - } - err = &ErrKeyAuthMissing{Err: err} + if err == nil { + err = lastExtractorErr } - if config.ErrorHandler != nil { - tmpErr := config.ErrorHandler(err, c) + tmpErr := config.ErrorHandler(c, err) if config.ContinueOnIgnoredError && tmpErr == nil { return next(c) } return tmpErr } - if lastValidatorErr != nil { // prioritize validator errors over extracting errors - return &echo.HTTPError{ - Code: http.StatusUnauthorized, - Message: "Unauthorized", - Internal: lastValidatorErr, - } + if lastValidatorErr == nil { + return ErrKeyMissing.Wrap(err) } - return echo.NewHTTPError(http.StatusBadRequest, err.Error()) + return echo.ErrUnauthorized.Wrap(err) } - } + }, nil } diff --git a/middleware/key_auth_test.go b/middleware/key_auth_test.go index 447f0bee8..49a917ed3 100644 --- a/middleware/key_auth_test.go +++ b/middleware/key_auth_test.go @@ -4,30 +4,34 @@ package middleware import ( + "crypto/subtle" "errors" "net/http" "net/http/httptest" "strings" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) -func testKeyValidator(key string, c echo.Context) (bool, error) { - switch key { - case "valid-key": +func testKeyValidator(c *echo.Context, key string, source ExtractorSource) (bool, error) { + // Use constant-time comparison to prevent timing attacks + if subtle.ConstantTimeCompare([]byte(key), []byte("valid-key")) == 1 { return true, nil - case "error-key": + } + + // Special case for testing error handling + if key == "error-key" { // Error path doesn't need constant-time return false, errors.New("some user defined error") - default: - return false, nil } + + return false, nil } func TestKeyAuth(t *testing.T) { handlerCalled := false - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { handlerCalled = true return c.String(http.StatusOK, "test") } @@ -67,7 +71,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.Skipper = func(context echo.Context) bool { + conf.Skipper = func(context *echo.Context) bool { return true } }, @@ -79,7 +83,7 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer invalid-key") }, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, internal=invalid key", + expectError: "code=401, message=Unauthorized, err=code=401, message=invalid key", }, { name: "nok, defaults, invalid scheme in header", @@ -87,24 +91,13 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bear valid-key") }, expectHandlerCalled: false, - expectError: "code=400, message=invalid key in the request header", + expectError: "code=401, message=missing key, err=invalid value in request header", }, { name: "nok, defaults, missing header", givenRequest: func(req *http.Request) {}, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", - }, - { - name: "ok, custom key lookup from multiple places, query and header", - givenRequest: func(req *http.Request) { - req.URL.RawQuery = "key=invalid-key" - req.Header.Set("API-Key", "valid-key") - }, - whenConfig: func(conf *KeyAuthConfig) { - conf.KeyLookup = "query:key,header:API-Key" - }, - expectHandlerCalled: true, + expectError: "code=401, message=missing key, err=missing value in request header", }, { name: "ok, custom key lookup, header", @@ -124,7 +117,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "header:API-Key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in request header", + expectError: "code=401, message=missing key, err=missing value in request header", }, { name: "ok, custom key lookup, query", @@ -144,7 +137,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "query:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the query string", + expectError: "code=401, message=missing key, err=missing value in the query string", }, { name: "ok, custom key lookup, form", @@ -169,7 +162,7 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "form:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in the form", + expectError: "code=401, message=missing key, err=missing value in the form", }, { name: "ok, custom key lookup, cookie", @@ -193,20 +186,18 @@ func TestKeyAuthWithConfig(t *testing.T) { conf.KeyLookup = "cookie:key" }, expectHandlerCalled: false, - expectError: "code=400, message=missing key in cookies", + expectError: "code=401, message=missing key, err=missing value in cookies", }, { name: "nok, custom errorHandler, error from extractor", whenConfig: func(conf *KeyAuthConfig) { conf.KeyLookup = "header:token" - conf.ErrorHandler = func(err error, context echo.Context) error { - httpError := echo.NewHTTPError(http.StatusTeapot, "custom") - httpError.Internal = err - return httpError + conf.ErrorHandler = func(c *echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err) } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, internal=missing key in request header", + expectError: "code=418, message=custom, err=missing value in request header", }, { name: "nok, custom errorHandler, error from validator", @@ -214,14 +205,12 @@ func TestKeyAuthWithConfig(t *testing.T) { req.Header.Set(echo.HeaderAuthorization, "Bearer error-key") }, whenConfig: func(conf *KeyAuthConfig) { - conf.ErrorHandler = func(err error, context echo.Context) error { - httpError := echo.NewHTTPError(http.StatusTeapot, "custom") - httpError.Internal = err - return httpError + conf.ErrorHandler = func(c *echo.Context, err error) error { + return echo.NewHTTPError(http.StatusTeapot, "custom").Wrap(err) } }, expectHandlerCalled: false, - expectError: "code=418, message=custom, internal=some user defined error", + expectError: "code=418, message=custom, err=some user defined error", }, { name: "nok, defaults, error from validator", @@ -230,14 +219,33 @@ func TestKeyAuthWithConfig(t *testing.T) { }, whenConfig: func(conf *KeyAuthConfig) {}, expectHandlerCalled: false, - expectError: "code=401, message=Unauthorized, internal=some user defined error", + expectError: "code=401, message=Unauthorized, err=some user defined error", + }, + { + name: "ok, custom validator checks source", + givenRequest: func(req *http.Request) { + q := req.URL.Query() + q.Add("key", "valid-key") + req.URL.RawQuery = q.Encode() + }, + whenConfig: func(conf *KeyAuthConfig) { + conf.KeyLookup = "query:key" + conf.Validator = func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + if source == ExtractorSourceQuery { + return true, nil + } + return false, errors.New("invalid source") + } + + }, + expectHandlerCalled: true, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { handlerCalled := false - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { handlerCalled = true return c.String(http.StatusOK, "test") } @@ -272,108 +280,96 @@ func TestKeyAuthWithConfig(t *testing.T) { } } -func TestKeyAuthWithConfig_panicsOnInvalidLookup(t *testing.T) { - assert.PanicsWithError( - t, - "extractor source for lookup could not be split into needed parts: a", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - KeyLookup: "a", - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_panicsOnEmptyValidator(t *testing.T) { - assert.PanicsWithValue( - t, - "echo: key-auth middleware requires a validator function", - func() { - handler := func(c echo.Context) error { - return c.String(http.StatusOK, "test") - } - KeyAuthWithConfig(KeyAuthConfig{ - Validator: nil, - })(handler) - }, - ) -} - -func TestKeyAuthWithConfig_ContinueOnIgnoredError(t *testing.T) { +func TestKeyAuthWithConfig_errors(t *testing.T) { var testCases = []struct { - name string - whenContinueOnIgnoredError bool - givenKey string - expectStatus int - expectBody string + name string + whenConfig KeyAuthConfig + expectError string }{ { - name: "no error handler is called", - whenContinueOnIgnoredError: true, - givenKey: "valid-key", - expectStatus: http.StatusTeapot, - expectBody: "", + name: "ok, no error", + whenConfig: KeyAuthConfig{ + Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, }, { - name: "ContinueOnIgnoredError is false and error handler is called for missing token", - whenContinueOnIgnoredError: false, - givenKey: "", - // empty response with 200. This emulates previous behaviour when error handler swallowed the error - expectStatus: http.StatusOK, - expectBody: "", + name: "ok, missing validator func", + whenConfig: KeyAuthConfig{ + Validator: nil, + }, + expectError: "echo key-auth middleware requires a validator function", }, { - name: "error handler is called for missing token", - whenContinueOnIgnoredError: true, - givenKey: "", - expectStatus: http.StatusTeapot, - expectBody: "public-auth", + name: "ok, extractor source can not be split", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope", + Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create key extractor: extractor source for lookup could not be split into needed parts: nope", }, { - name: "error handler is called for invalid token", - whenContinueOnIgnoredError: true, - givenKey: "x.x.x", - expectStatus: http.StatusUnauthorized, - expectBody: "{\"message\":\"Unauthorized\"}\n", + name: "ok, no extractors", + whenConfig: KeyAuthConfig{ + KeyLookup: "nope:nope", + Validator: func(c *echo.Context, key string, source ExtractorSource) (bool, error) { + return false, nil + }, + }, + expectError: "echo key-auth middleware could not create extractors from KeyLookup string", }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e := echo.New() + mw, err := tc.whenConfig.ToMiddleware() + if tc.expectError != "" { + assert.Nil(t, mw) + assert.EqualError(t, err, tc.expectError) + } else { + assert.NotNil(t, mw) + assert.NoError(t, err) + } + }) + } +} - e.GET("/", func(c echo.Context) error { - testValue, _ := c.Get("test").(string) - return c.String(http.StatusTeapot, testValue) - }) +func TestMustKeyAuthWithConfig_panic(t *testing.T) { + assert.Panics(t, func() { + KeyAuthWithConfig(KeyAuthConfig{}) + }) +} - e.Use(KeyAuthWithConfig(KeyAuthConfig{ - Validator: testKeyValidator, - ErrorHandler: func(err error, c echo.Context) error { - if _, ok := err.(*ErrKeyAuthMissing); ok { - c.Set("test", "public-auth") - return nil - } - return echo.ErrUnauthorized - }, - KeyLookup: "header:X-API-Key", - ContinueOnIgnoredError: tc.whenContinueOnIgnoredError, - })) +func TestKeyAuth_errorHandlerSwallowsError(t *testing.T) { + handlerCalled := false + var authValue string + handler := func(c *echo.Context) error { + handlerCalled = true + authValue = c.Get("auth").(string) + return c.String(http.StatusOK, "test") + } + middlewareChain := KeyAuthWithConfig(KeyAuthConfig{ + Validator: testKeyValidator, + ErrorHandler: func(c *echo.Context, err error) error { + // could check error to decide if we can swallow the error + c.Set("auth", "public") + return nil + }, + ContinueOnIgnoredError: true, + })(handler) - req := httptest.NewRequest(http.MethodGet, "/", nil) - if tc.givenKey != "" { - req.Header.Set("X-API-Key", tc.givenKey) - } - res := httptest.NewRecorder() + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // no auth header this time + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) - e.ServeHTTP(res, req) + err := middlewareChain(c) - assert.Equal(t, tc.expectStatus, res.Code) - assert.Equal(t, tc.expectBody, res.Body.String()) - }) - } + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, "public", authValue) } diff --git a/middleware/logger.go b/middleware/logger.go deleted file mode 100644 index 59020955b..000000000 --- a/middleware/logger.go +++ /dev/null @@ -1,420 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package middleware - -import ( - "bytes" - "io" - "strconv" - "strings" - "sync" - "time" - - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/color" - "github.com/valyala/fasttemplate" -) - -// LoggerConfig defines the config for Logger middleware. -// -// # Configuration Examples -// -// ## Basic Usage with Default Settings -// -// e.Use(middleware.Logger()) -// -// This uses the default JSON format that logs all common request/response details. -// -// ## Custom Simple Format -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n", -// })) -// -// ## JSON Format with Custom Fields -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: `{"timestamp":"${time_rfc3339_nano}","level":"info","remote_ip":"${remote_ip}",` + -// `"method":"${method}","uri":"${uri}","status":${status},"latency":"${latency_human}",` + -// `"user_agent":"${user_agent}","error":"${error}"}` + "\n", -// })) -// -// ## Custom Time Format -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: "${time_custom} ${method} ${uri} ${status}\n", -// CustomTimeFormat: "2006-01-02 15:04:05", -// })) -// -// ## Logging Headers and Parameters -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: `{"time":"${time_rfc3339_nano}","method":"${method}","uri":"${uri}",` + -// `"status":${status},"auth":"${header:Authorization}","user":"${query:user}",` + -// `"form_data":"${form:action}","session":"${cookie:session_id}"}` + "\n", -// })) -// -// ## Custom Output (File Logging) -// -// file, err := os.OpenFile("app.log", os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666) -// if err != nil { -// log.Fatal(err) -// } -// defer file.Close() -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Output: file, -// })) -// -// ## Custom Tag Function -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: `{"time":"${time_rfc3339_nano}","user_id":"${custom}","method":"${method}"}` + "\n", -// CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { -// userID := getUserIDFromContext(c) // Your custom logic -// return buf.WriteString(strconv.Itoa(userID)) -// }, -// })) -// -// ## Conditional Logging (Skip Certain Requests) -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Skipper: func(c echo.Context) bool { -// // Skip logging for health check endpoints -// return c.Request().URL.Path == "/health" || c.Request().URL.Path == "/metrics" -// }, -// })) -// -// ## Integration with External Logging Service -// -// logBuffer := &SyncBuffer{} // Thread-safe buffer for external service -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: `{"timestamp":"${time_rfc3339_nano}","service":"my-api","level":"info",` + -// `"method":"${method}","uri":"${uri}","status":${status},"latency_ms":${latency},` + -// `"remote_ip":"${remote_ip}","user_agent":"${user_agent}","error":"${error}"}` + "\n", -// Output: logBuffer, -// })) -// -// # Available Tags -// -// ## Time Tags -// - time_unix: Unix timestamp (seconds) -// - time_unix_milli: Unix timestamp (milliseconds) -// - time_unix_micro: Unix timestamp (microseconds) -// - time_unix_nano: Unix timestamp (nanoseconds) -// - time_rfc3339: RFC3339 format (2006-01-02T15:04:05Z07:00) -// - time_rfc3339_nano: RFC3339 with nanoseconds -// - time_custom: Uses CustomTimeFormat field -// -// ## Request Information -// - id: Request ID from X-Request-ID header -// - remote_ip: Client IP address (respects proxy headers) -// - uri: Full request URI with query parameters -// - host: Host header value -// - method: HTTP method (GET, POST, etc.) -// - path: URL path without query parameters -// - route: Echo route pattern (e.g., /users/:id) -// - protocol: HTTP protocol version -// - referer: Referer header value -// - user_agent: User-Agent header value -// -// ## Response Information -// - status: HTTP status code -// - error: Error message if request failed -// - latency: Request processing time in nanoseconds -// - latency_human: Human-readable processing time -// - bytes_in: Request body size in bytes -// - bytes_out: Response body size in bytes -// -// ## Dynamic Tags -// - header:: Value of specific header (e.g., header:Authorization) -// - query:: Value of specific query parameter (e.g., query:user_id) -// - form:: Value of specific form field (e.g., form:username) -// - cookie:: Value of specific cookie (e.g., cookie:session_id) -// - custom: Output from CustomTagFunc -// -// # Troubleshooting -// -// ## Common Issues -// -// 1. **Missing logs**: Check if Skipper function is filtering out requests -// 2. **Invalid JSON**: Ensure CustomTagFunc outputs valid JSON content -// 3. **Performance issues**: Consider using a buffered writer for high-traffic applications -// 4. **File permission errors**: Ensure write permissions when logging to files -// -// ## Performance Tips -// -// - Use time_unix formats for better performance than time_rfc3339 -// - Minimize the number of dynamic tags (header:, query:, form:, cookie:) -// - Use Skipper to exclude high-frequency, low-value requests (health checks, etc.) -// - Consider async logging for very high-traffic applications -type LoggerConfig struct { - // Skipper defines a function to skip middleware. - // Use this to exclude certain requests from logging (e.g., health checks). - // - // Example: - // Skipper: func(c echo.Context) bool { - // return c.Request().URL.Path == "/health" - // }, - Skipper Skipper - - // Format defines the logging format using template tags. - // Tags are enclosed in ${} and replaced with actual values. - // See the detailed tag documentation above for all available options. - // - // Default: JSON format with common fields - // Example: "${time_rfc3339_nano} ${status} ${method} ${uri} ${latency_human}\n" - Format string `yaml:"format"` - - // CustomTimeFormat specifies the time format used by ${time_custom} tag. - // Uses Go's reference time: Mon Jan 2 15:04:05 MST 2006 - // - // Default: "2006-01-02 15:04:05.00000" - // Example: "2006-01-02 15:04:05" or "15:04:05.000" - CustomTimeFormat string `yaml:"custom_time_format"` - - // CustomTagFunc is called when ${custom} tag is encountered. - // Use this to add application-specific information to logs. - // The function should write valid content for your log format. - // - // Example: - // CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { - // userID := getUserFromContext(c) - // return buf.WriteString(`"user_id":"` + userID + `"`) - // }, - CustomTagFunc func(c echo.Context, buf *bytes.Buffer) (int, error) - - // Output specifies where logs are written. - // Can be any io.Writer: files, buffers, network connections, etc. - // - // Default: os.Stdout - // Example: Custom file, syslog, or external logging service - Output io.Writer - - template *fasttemplate.Template - colorer *color.Color - pool *sync.Pool - timeNow func() time.Time -} - -// DefaultLoggerConfig is the default Logger middleware config. -var DefaultLoggerConfig = LoggerConfig{ - Skipper: DefaultSkipper, - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}",` + - `"host":"${host}","method":"${method}","uri":"${uri}","user_agent":"${user_agent}",` + - `"status":${status},"error":"${error}","latency":${latency},"latency_human":"${latency_human}"` + - `,"bytes_in":${bytes_in},"bytes_out":${bytes_out}}` + "\n", - CustomTimeFormat: "2006-01-02 15:04:05.00000", - colorer: color.New(), - timeNow: time.Now, -} - -// Logger returns a middleware that logs HTTP requests using the default configuration. -// -// The default format logs requests as JSON with the following fields: -// - time: RFC3339 nano timestamp -// - id: Request ID from X-Request-ID header -// - remote_ip: Client IP address -// - host: Host header -// - method: HTTP method -// - uri: Request URI -// - user_agent: User-Agent header -// - status: HTTP status code -// - error: Error message (if any) -// - latency: Processing time in nanoseconds -// - latency_human: Human-readable processing time -// - bytes_in: Request body size -// - bytes_out: Response body size -// -// Example output: -// -// {"time":"2023-01-15T10:30:45.123456789Z","id":"","remote_ip":"127.0.0.1", -// "host":"localhost:8080","method":"GET","uri":"/users/123","user_agent":"curl/7.81.0", -// "status":200,"error":"","latency":1234567,"latency_human":"1.234567ms", -// "bytes_in":0,"bytes_out":42} -// -// For custom configurations, use LoggerWithConfig instead. -// -// Deprecated: please use middleware.RequestLogger or middleware.RequestLoggerWithConfig instead. -func Logger() echo.MiddlewareFunc { - return LoggerWithConfig(DefaultLoggerConfig) -} - -// LoggerWithConfig returns a Logger middleware with custom configuration. -// -// This function allows you to customize all aspects of request logging including: -// - Log format and fields -// - Output destination -// - Time formatting -// - Custom tags and logic -// - Request filtering -// -// See LoggerConfig documentation for detailed configuration examples and options. -// -// Example: -// -// e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ -// Format: "${time_rfc3339} ${status} ${method} ${uri} ${latency_human}\n", -// Output: customLogWriter, -// Skipper: func(c echo.Context) bool { -// return c.Request().URL.Path == "/health" -// }, -// })) -// -// Deprecated: please use middleware.RequestLoggerWithConfig instead. -func LoggerWithConfig(config LoggerConfig) echo.MiddlewareFunc { - // Defaults - if config.Skipper == nil { - config.Skipper = DefaultLoggerConfig.Skipper - } - if config.Format == "" { - config.Format = DefaultLoggerConfig.Format - } - writeString := func(buf *bytes.Buffer, in string) (int, error) { return buf.WriteString(in) } - if config.Format[0] == '{' { // format looks like JSON, so we need to escape invalid characters - writeString = writeJSONSafeString - } - - if config.Output == nil { - config.Output = DefaultLoggerConfig.Output - } - timeNow := DefaultLoggerConfig.timeNow - if config.timeNow != nil { - timeNow = config.timeNow - } - - config.template = fasttemplate.New(config.Format, "${", "}") - config.colorer = color.New() - config.colorer.SetOutput(config.Output) - config.pool = &sync.Pool{ - New: func() interface{} { - return bytes.NewBuffer(make([]byte, 256)) - }, - } - - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { - if config.Skipper(c) { - return next(c) - } - - req := c.Request() - res := c.Response() - start := time.Now() - if err = next(c); err != nil { - c.Error(err) - } - stop := time.Now() - buf := config.pool.Get().(*bytes.Buffer) - buf.Reset() - defer config.pool.Put(buf) - - if _, err = config.template.ExecuteFunc(buf, func(w io.Writer, tag string) (int, error) { - switch tag { - case "custom": - if config.CustomTagFunc == nil { - return 0, nil - } - return config.CustomTagFunc(c, buf) - case "time_unix": - return buf.WriteString(strconv.FormatInt(timeNow().Unix(), 10)) - case "time_unix_milli": - return buf.WriteString(strconv.FormatInt(timeNow().UnixMilli(), 10)) - case "time_unix_micro": - return buf.WriteString(strconv.FormatInt(timeNow().UnixMicro(), 10)) - case "time_unix_nano": - return buf.WriteString(strconv.FormatInt(timeNow().UnixNano(), 10)) - case "time_rfc3339": - return buf.WriteString(timeNow().Format(time.RFC3339)) - case "time_rfc3339_nano": - return buf.WriteString(timeNow().Format(time.RFC3339Nano)) - case "time_custom": - return buf.WriteString(timeNow().Format(config.CustomTimeFormat)) - case "id": - id := req.Header.Get(echo.HeaderXRequestID) - if id == "" { - id = res.Header().Get(echo.HeaderXRequestID) - } - return writeString(buf, id) - case "remote_ip": - return writeString(buf, c.RealIP()) - case "host": - return writeString(buf, req.Host) - case "uri": - return writeString(buf, req.RequestURI) - case "method": - return writeString(buf, req.Method) - case "path": - p := req.URL.Path - if p == "" { - p = "/" - } - return writeString(buf, p) - case "route": - return writeString(buf, c.Path()) - case "protocol": - return writeString(buf, req.Proto) - case "referer": - return writeString(buf, req.Referer()) - case "user_agent": - return writeString(buf, req.UserAgent()) - case "status": - n := res.Status - s := config.colorer.Green(n) - switch { - case n >= 500: - s = config.colorer.Red(n) - case n >= 400: - s = config.colorer.Yellow(n) - case n >= 300: - s = config.colorer.Cyan(n) - } - return buf.WriteString(s) - case "error": - if err != nil { - return writeJSONSafeString(buf, err.Error()) - } - case "latency": - l := stop.Sub(start) - return buf.WriteString(strconv.FormatInt(int64(l), 10)) - case "latency_human": - return buf.WriteString(stop.Sub(start).String()) - case "bytes_in": - cl := req.Header.Get(echo.HeaderContentLength) - if cl == "" { - cl = "0" - } - return writeString(buf, cl) - case "bytes_out": - return buf.WriteString(strconv.FormatInt(res.Size, 10)) - default: - switch { - case strings.HasPrefix(tag, "header:"): - return writeString(buf, c.Request().Header.Get(tag[7:])) - case strings.HasPrefix(tag, "query:"): - return writeString(buf, c.QueryParam(tag[6:])) - case strings.HasPrefix(tag, "form:"): - return writeString(buf, c.FormValue(tag[5:])) - case strings.HasPrefix(tag, "cookie:"): - cookie, err := c.Cookie(tag[7:]) - if err == nil { - return buf.Write([]byte(cookie.Value)) - } - } - } - return 0, nil - }); err != nil { - return - } - - if config.Output == nil { - _, err = c.Logger().Output().Write(buf.Bytes()) - return - } - _, err = config.Output.Write(buf.Bytes()) - return - } - } -} diff --git a/middleware/logger_strings.go b/middleware/logger_strings.go deleted file mode 100644 index 8476cb046..000000000 --- a/middleware/logger_strings.go +++ /dev/null @@ -1,242 +0,0 @@ -// SPDX-License-Identifier: BSD-3-Clause -// SPDX-FileCopyrightText: Copyright 2010 The Go Authors -// -// Copyright 2010 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -// -// -// Go LICENSE https://raw.githubusercontent.com/golang/go/36bca3166e18db52687a4d91ead3f98ffe6d00b8/LICENSE -/** -Copyright 2009 The Go Authors. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - * Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following disclaimer -in the documentation and/or other materials provided with the -distribution. - * Neither the name of Google LLC nor the names of its -contributors may be used to endorse or promote products derived from -this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -*/ - -package middleware - -import ( - "bytes" - "unicode/utf8" -) - -// This function is modified copy from Go standard library encoding/json/encode.go `appendString` function -// Source: https://github.com/golang/go/blob/36bca3166e18db52687a4d91ead3f98ffe6d00b8/src/encoding/json/encode.go#L999 -func writeJSONSafeString(buf *bytes.Buffer, src string) (int, error) { - const hex = "0123456789abcdef" - - written := 0 - start := 0 - for i := 0; i < len(src); { - if b := src[i]; b < utf8.RuneSelf { - if safeSet[b] { - i++ - continue - } - - n, err := buf.Write([]byte(src[start:i])) - written += n - if err != nil { - return written, err - } - switch b { - case '\\', '"': - n, err := buf.Write([]byte{'\\', b}) - written += n - if err != nil { - return written, err - } - case '\b': - n, err := buf.Write([]byte{'\\', 'b'}) - written += n - if err != nil { - return n, err - } - case '\f': - n, err := buf.Write([]byte{'\\', 'f'}) - written += n - if err != nil { - return written, err - } - case '\n': - n, err := buf.Write([]byte{'\\', 'n'}) - written += n - if err != nil { - return written, err - } - case '\r': - n, err := buf.Write([]byte{'\\', 'r'}) - written += n - if err != nil { - return written, err - } - case '\t': - n, err := buf.Write([]byte{'\\', 't'}) - written += n - if err != nil { - return written, err - } - default: - // This encodes bytes < 0x20 except for \b, \f, \n, \r and \t. - n, err := buf.Write([]byte{'\\', 'u', '0', '0', hex[b>>4], hex[b&0xF]}) - written += n - if err != nil { - return written, err - } - } - i++ - start = i - continue - } - srcN := min(len(src)-i, utf8.UTFMax) - c, size := utf8.DecodeRuneInString(src[i : i+srcN]) - if c == utf8.RuneError && size == 1 { - n, err := buf.Write([]byte(src[start:i])) - written += n - if err != nil { - return written, err - } - n, err = buf.Write([]byte(`\ufffd`)) - written += n - if err != nil { - return written, err - } - i += size - start = i - continue - } - i += size - } - n, err := buf.Write([]byte(src[start:])) - written += n - return written, err -} - -// safeSet holds the value true if the ASCII character with the given array -// position can be represented inside a JSON string without any further -// escaping. -// -// All values are true except for the ASCII control characters (0-31), the -// double quote ("), and the backslash character ("\"). -var safeSet = [utf8.RuneSelf]bool{ - ' ': true, - '!': true, - '"': false, - '#': true, - '$': true, - '%': true, - '&': true, - '\'': true, - '(': true, - ')': true, - '*': true, - '+': true, - ',': true, - '-': true, - '.': true, - '/': true, - '0': true, - '1': true, - '2': true, - '3': true, - '4': true, - '5': true, - '6': true, - '7': true, - '8': true, - '9': true, - ':': true, - ';': true, - '<': true, - '=': true, - '>': true, - '?': true, - '@': true, - 'A': true, - 'B': true, - 'C': true, - 'D': true, - 'E': true, - 'F': true, - 'G': true, - 'H': true, - 'I': true, - 'J': true, - 'K': true, - 'L': true, - 'M': true, - 'N': true, - 'O': true, - 'P': true, - 'Q': true, - 'R': true, - 'S': true, - 'T': true, - 'U': true, - 'V': true, - 'W': true, - 'X': true, - 'Y': true, - 'Z': true, - '[': true, - '\\': false, - ']': true, - '^': true, - '_': true, - '`': true, - 'a': true, - 'b': true, - 'c': true, - 'd': true, - 'e': true, - 'f': true, - 'g': true, - 'h': true, - 'i': true, - 'j': true, - 'k': true, - 'l': true, - 'm': true, - 'n': true, - 'o': true, - 'p': true, - 'q': true, - 'r': true, - 's': true, - 't': true, - 'u': true, - 'v': true, - 'w': true, - 'x': true, - 'y': true, - 'z': true, - '{': true, - '|': true, - '}': true, - '~': true, - '\u007f': true, -} diff --git a/middleware/logger_strings_test.go b/middleware/logger_strings_test.go deleted file mode 100644 index 3d66404c5..000000000 --- a/middleware/logger_strings_test.go +++ /dev/null @@ -1,288 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package middleware - -import ( - "bytes" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestWriteJSONSafeString(t *testing.T) { - testCases := []struct { - name string - whenInput string - expect string - expectN int - }{ - // Basic cases - { - name: "empty string", - whenInput: "", - expect: "", - expectN: 0, - }, - { - name: "simple ASCII without special chars", - whenInput: "hello", - expect: "hello", - expectN: 5, - }, - { - name: "single character", - whenInput: "a", - expect: "a", - expectN: 1, - }, - { - name: "alphanumeric", - whenInput: "Hello123World", - expect: "Hello123World", - expectN: 13, - }, - - // Special character escaping - { - name: "backslash", - whenInput: `path\to\file`, - expect: `path\\to\\file`, - expectN: 14, - }, - { - name: "double quote", - whenInput: `say "hello"`, - expect: `say \"hello\"`, - expectN: 13, - }, - { - name: "backslash and quote combined", - whenInput: `a\b"c`, - expect: `a\\b\"c`, - expectN: 7, - }, - { - name: "single backslash", - whenInput: `\`, - expect: `\\`, - expectN: 2, - }, - { - name: "single quote", - whenInput: `"`, - expect: `\"`, - expectN: 2, - }, - - // Control character escaping - { - name: "backspace", - whenInput: "hello\bworld", - expect: `hello\bworld`, - expectN: 12, - }, - { - name: "form feed", - whenInput: "hello\fworld", - expect: `hello\fworld`, - expectN: 12, - }, - { - name: "newline", - whenInput: "hello\nworld", - expect: `hello\nworld`, - expectN: 12, - }, - { - name: "carriage return", - whenInput: "hello\rworld", - expect: `hello\rworld`, - expectN: 12, - }, - { - name: "tab", - whenInput: "hello\tworld", - expect: `hello\tworld`, - expectN: 12, - }, - { - name: "multiple newlines", - whenInput: "line1\nline2\nline3", - expect: `line1\nline2\nline3`, - expectN: 19, - }, - - // Low control characters (< 0x20) - { - name: "null byte", - whenInput: "hello\x00world", - expect: `hello\u0000world`, - expectN: 16, - }, - { - name: "control character 0x01", - whenInput: "test\x01value", - expect: `test\u0001value`, - expectN: 15, - }, - { - name: "control character 0x0e", - whenInput: "test\x0evalue", - expect: `test\u000evalue`, - expectN: 15, - }, - { - name: "control character 0x1f", - whenInput: "test\x1fvalue", - expect: `test\u001fvalue`, - expectN: 15, - }, - { - name: "multiple control characters", - whenInput: "\x00\x01\x02", - expect: `\u0000\u0001\u0002`, - expectN: 18, - }, - - // UTF-8 handling - { - name: "valid UTF-8 Chinese", - whenInput: "hello 世界", - expect: "hello 世界", - expectN: 12, - }, - { - name: "valid UTF-8 emoji", - whenInput: "party 🎉 time", - expect: "party 🎉 time", - expectN: 15, - }, - { - name: "mixed ASCII and UTF-8", - whenInput: "Hello世界123", - expect: "Hello世界123", - expectN: 14, - }, - { - name: "UTF-8 with special chars", - whenInput: "世界\n\"test\"", - expect: `世界\n\"test\"`, - expectN: 16, - }, - - // Invalid UTF-8 - { - name: "invalid UTF-8 sequence", - whenInput: "hello\xff\xfeworld", - expect: `hello\ufffd\ufffdworld`, - expectN: 22, - }, - { - name: "incomplete UTF-8 sequence", - whenInput: "test\xc3value", - expect: `test\ufffdvalue`, - expectN: 15, - }, - - // Complex mixed cases - { - name: "all common escapes", - whenInput: "tab\there\nquote\"backslash\\", - expect: `tab\there\nquote\"backslash\\`, - expectN: 29, - }, - { - name: "mixed controls and UTF-8", - whenInput: "hello\t世界\ntest\"", - expect: `hello\t世界\ntest\"`, - expectN: 21, - }, - { - name: "all control characters", - whenInput: "\b\f\n\r\t", - expect: `\b\f\n\r\t`, - expectN: 10, - }, - { - name: "control and low ASCII", - whenInput: "a\nb\x00c", - expect: `a\nb\u0000c`, - expectN: 11, - }, - - // Edge cases - { - name: "starts with special char", - whenInput: "\\start", - expect: `\\start`, - expectN: 7, - }, - { - name: "ends with special char", - whenInput: "end\"", - expect: `end\"`, - expectN: 5, - }, - { - name: "consecutive special chars", - whenInput: "\\\\\"\"", - expect: `\\\\\"\"`, - expectN: 8, - }, - { - name: "only special characters", - whenInput: "\"\\\n\t", - expect: `\"\\\n\t`, - expectN: 8, - }, - { - name: "spaces and punctuation", - whenInput: "Hello, World! How are you?", - expect: "Hello, World! How are you?", - expectN: 26, - }, - { - name: "JSON-like string", - whenInput: "{\"key\":\"value\"}", - expect: `{\"key\":\"value\"}`, - expectN: 19, - }, - } - - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - buf := &bytes.Buffer{} - n, err := writeJSONSafeString(buf, tt.whenInput) - - assert.NoError(t, err) - assert.Equal(t, tt.expect, buf.String()) - assert.Equal(t, tt.expectN, n) - }) - } -} - -func BenchmarkWriteJSONSafeString(b *testing.B) { - testCases := []struct { - name string - input string - }{ - {"simple", "hello world"}, - {"with escapes", "tab\there\nquote\"backslash\\"}, - {"utf8", "hello 世界 🎉"}, - {"mixed", "Hello\t世界\ntest\"value\\path"}, - {"long simple", "abcdefghijklmnopqrstuvwxyz0123456789abcdefghijklmnopqrstuvwxyz0123456789"}, - {"long complex", "line1\nline2\tline3\"quote\\slash\x00null世界🎉"}, - } - - for _, tc := range testCases { - b.Run(tc.name, func(b *testing.B) { - buf := &bytes.Buffer{} - b.ResetTimer() - for i := 0; i < b.N; i++ { - buf.Reset() - writeJSONSafeString(buf, tc.input) - } - }) - } -} diff --git a/middleware/logger_test.go b/middleware/logger_test.go deleted file mode 100644 index e4b783db5..000000000 --- a/middleware/logger_test.go +++ /dev/null @@ -1,540 +0,0 @@ -// SPDX-License-Identifier: MIT -// SPDX-FileCopyrightText: © 2015 LabStack LLC and Echo contributors - -package middleware - -import ( - "bytes" - "cmp" - "encoding/json" - "errors" - "net/http" - "net/http/httptest" - "net/url" - "regexp" - "strings" - "testing" - "time" - "unsafe" - - "github.com/labstack/echo/v4" - "github.com/stretchr/testify/assert" -) - -func TestLoggerDefaultMW(t *testing.T) { - var testCases = []struct { - name string - whenHeader map[string]string - whenStatusCode int - whenResponse string - whenError error - expect string - }{ - { - name: "ok, status 200", - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", - }, - { - name: "ok, status 300", - whenStatusCode: http.StatusTemporaryRedirect, - whenResponse: "test", - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":307,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", - }, - { - name: "ok, handler error = status 500", - whenError: errors.New("error"), - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":36}` + "\n", - }, - { - name: "error with invalid UTF-8 sequences", - whenError: errors.New("invalid data: \xFF\xFE"), - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"invalid data: \ufffd\ufffd","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":36}` + "\n", - }, - { - name: "error with JSON special characters (quotes and backslashes)", - whenError: errors.New(`error with "quotes" and \backslash`), - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error with \"quotes\" and \\backslash","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":36}` + "\n", - }, - { - name: "error with control characters (newlines and tabs)", - whenError: errors.New("error\nwith\nnewlines\tand\ttabs"), - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":500,"error":"error\nwith\nnewlines\tand\ttabs","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":36}` + "\n", - }, - { - name: "ok, remote_ip from X-Real-Ip header", - whenHeader: map[string]string{echo.HeaderXRealIP: "127.0.0.1"}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", - }, - { - name: "ok, remote_ip from X-Forwarded-For header", - whenHeader: map[string]string{echo.HeaderXForwardedFor: "127.0.0.1"}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"127.0.0.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - req := httptest.NewRequest(http.MethodGet, "/", nil) - if len(tc.whenHeader) > 0 { - for k, v := range tc.whenHeader { - req.Header.Add(k, v) - } - } - - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - DefaultLoggerConfig.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() } - h := Logger()(func(c echo.Context) error { - if tc.whenError != nil { - return tc.whenError - } - return c.String(tc.whenStatusCode, tc.whenResponse) - }) - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - err := h(c) - assert.NoError(t, err) - - result := buf.String() - // handle everchanging latency numbers - result = regexp.MustCompile(`"latency":\d+,`).ReplaceAllString(result, `"latency":1,`) - result = regexp.MustCompile(`"latency_human":"[^"]+"`).ReplaceAllString(result, `"latency_human":"1µs"`) - - assert.Equal(t, tc.expect, result) - }) - } -} - -func TestLoggerWithLoggerConfig(t *testing.T) { - // to handle everchanging latency numbers - jsonLatency := map[string]*regexp.Regexp{ - `"latency":1,`: regexp.MustCompile(`"latency":\d+,`), - `"latency_human":"1µs"`: regexp.MustCompile(`"latency_human":"[^"]+"`), - } - - form := make(url.Values) - form.Set("csrf", "token") - form.Add("multiple", "1") - form.Add("multiple", "2") - - var testCases = []struct { - name string - givenConfig LoggerConfig - whenURI string - whenMethod string - whenHost string - whenPath string - whenRoute string - whenProto string - whenRequestURI string - whenHeader map[string]string - whenFormValues url.Values - whenStatusCode int - whenResponse string - whenError error - whenReplacers map[string]*regexp.Regexp - expect string - }{ - { - name: "ok, skipper", - givenConfig: LoggerConfig{ - Skipper: func(c echo.Context) bool { return true }, - }, - expect: ``, - }, - { // this is an example how format that does not seem to be JSON is not currently escaped - name: "ok, NON json string is not escaped: method", - givenConfig: LoggerConfig{Format: `method:"${method}"`}, - whenMethod: `","method":":D"`, - expect: `method:"","method":":D""`, - }, - { - name: "ok, json string escape: method", - givenConfig: LoggerConfig{Format: `{"method":"${method}"}`}, - whenMethod: `","method":":D"`, - expect: `{"method":"\",\"method\":\":D\""}`, - }, - { - name: "ok, json string escape: id", - givenConfig: LoggerConfig{Format: `{"id":"${id}"}`}, - whenHeader: map[string]string{echo.HeaderXRequestID: `\"127.0.0.1\"`}, - expect: `{"id":"\\\"127.0.0.1\\\""}`, - }, - { - name: "ok, json string escape: remote_ip", - givenConfig: LoggerConfig{Format: `{"remote_ip":"${remote_ip}"}`}, - whenHeader: map[string]string{echo.HeaderXForwardedFor: `\"127.0.0.1\"`}, - expect: `{"remote_ip":"\\\"127.0.0.1\\\""}`, - }, - { - name: "ok, json string escape: host", - givenConfig: LoggerConfig{Format: `{"host":"${host}"}`}, - whenHost: `\"127.0.0.1\"`, - expect: `{"host":"\\\"127.0.0.1\\\""}`, - }, - { - name: "ok, json string escape: path", - givenConfig: LoggerConfig{Format: `{"path":"${path}"}`}, - whenPath: `\","` + "\n", - expect: `{"path":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: route", - givenConfig: LoggerConfig{Format: `{"route":"${route}"}`}, - whenRoute: `\","` + "\n", - expect: `{"route":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: proto", - givenConfig: LoggerConfig{Format: `{"protocol":"${protocol}"}`}, - whenProto: `\","` + "\n", - expect: `{"protocol":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: referer", - givenConfig: LoggerConfig{Format: `{"referer":"${referer}"}`}, - whenHeader: map[string]string{"Referer": `\","` + "\n"}, - expect: `{"referer":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: user_agent", - givenConfig: LoggerConfig{Format: `{"user_agent":"${user_agent}"}`}, - whenHeader: map[string]string{"User-Agent": `\","` + "\n"}, - expect: `{"user_agent":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: bytes_in", - givenConfig: LoggerConfig{Format: `{"bytes_in":"${bytes_in}"}`}, - whenHeader: map[string]string{echo.HeaderContentLength: `\","` + "\n"}, - expect: `{"bytes_in":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: query param", - givenConfig: LoggerConfig{Format: `{"query":"${query:test}"}`}, - whenURI: `/?test=1","`, - expect: `{"query":"1\",\""}`, - }, - { - name: "ok, json string escape: header", - givenConfig: LoggerConfig{Format: `{"header":"${header:referer}"}`}, - whenHeader: map[string]string{"referer": `\","` + "\n"}, - expect: `{"header":"\\\",\"\n"}`, - }, - { - name: "ok, json string escape: form", - givenConfig: LoggerConfig{Format: `{"csrf":"${form:csrf}"}`}, - whenMethod: http.MethodPost, - whenFormValues: url.Values{"csrf": {`token","`}}, - expect: `{"csrf":"token\",\""}`, - }, - { - name: "nok, json string escape: cookie - will not accept invalid chars", - // net/cookie.go: validCookieValueByte function allows these byte in cookie value - // only `0x20 <= b && b < 0x7f && b != '"' && b != ';' && b != '\\'` - givenConfig: LoggerConfig{Format: `{"cookie":"${cookie:session}"}`}, - whenHeader: map[string]string{"Cookie": `_ga=GA1.2.000000000.0000000000; session=test\n`}, - expect: `{"cookie":""}`, - }, - { - name: "ok, format time_unix", - givenConfig: LoggerConfig{Format: `${time_unix}`}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `1588037200`, - }, - { - name: "ok, format time_unix_milli", - givenConfig: LoggerConfig{Format: `${time_unix_milli}`}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `1588037200000`, - }, - { - name: "ok, format time_unix_micro", - givenConfig: LoggerConfig{Format: `${time_unix_micro}`}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `1588037200000000`, - }, - { - name: "ok, format time_unix_nano", - givenConfig: LoggerConfig{Format: `${time_unix_nano}`}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `1588037200000000000`, - }, - { - name: "ok, format time_rfc3339", - givenConfig: LoggerConfig{Format: `${time_rfc3339}`}, - whenStatusCode: http.StatusOK, - whenResponse: "test", - expect: `2020-04-28T01:26:40Z`, - }, - { - name: "ok, status 200", - whenStatusCode: http.StatusOK, - whenResponse: "test", - whenReplacers: jsonLatency, - expect: `{"time":"2020-04-28T01:26:40Z","id":"","remote_ip":"192.0.2.1","host":"example.com","method":"GET","uri":"/","user_agent":"","status":200,"error":"","latency":1,"latency_human":"1µs","bytes_in":0,"bytes_out":4}` + "\n", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - e := echo.New() - - req := httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), nil) - if tc.whenFormValues != nil { - req = httptest.NewRequest(http.MethodGet, cmp.Or(tc.whenURI, "/"), strings.NewReader(tc.whenFormValues.Encode())) - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - } - - for k, v := range tc.whenHeader { - req.Header.Add(k, v) - } - if tc.whenHost != "" { - req.Host = tc.whenHost - } - if tc.whenMethod != "" { - req.Method = tc.whenMethod - } - if tc.whenProto != "" { - req.Proto = tc.whenProto - } - if tc.whenRequestURI != "" { - req.RequestURI = tc.whenRequestURI - } - if tc.whenPath != "" { - req.URL.Path = tc.whenPath - } - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - if tc.whenFormValues != nil { - c.FormValue("to trigger form parsing") - } - if tc.whenRoute != "" { - c.SetPath(tc.whenRoute) - } - - config := tc.givenConfig - if config.timeNow == nil { - config.timeNow = func() time.Time { return time.Unix(1588037200, 0).UTC() } - } - buf := new(bytes.Buffer) - if config.Output == nil { - e.Logger.SetOutput(buf) - } - - h := LoggerWithConfig(config)(func(c echo.Context) error { - if tc.whenError != nil { - return tc.whenError - } - return c.String(cmp.Or(tc.whenStatusCode, http.StatusOK), cmp.Or(tc.whenResponse, "test")) - }) - - err := h(c) - assert.NoError(t, err) - - result := buf.String() - - for replaceTo, replacer := range tc.whenReplacers { - result = replacer.ReplaceAllString(result, replaceTo) - } - - assert.Equal(t, tc.expect, result) - }) - } -} - -func TestLoggerTemplate(t *testing.T) { - buf := new(bytes.Buffer) - - e := echo.New() - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + - `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "route":"${route}", "referer":"${referer}",` + - `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + - `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", - Output: buf, - })) - - e.GET("/users/:id", func(c echo.Context) error { - return c.String(http.StatusOK, "Header Logged") - }) - - req := httptest.NewRequest(http.MethodGet, "/users/1?username=apagano-param&password=secret", nil) - req.RequestURI = "/" - req.Header.Add(echo.HeaderXRealIP, "127.0.0.1") - req.Header.Add("Referer", "google.com") - req.Header.Add("User-Agent", "echo-tests-agent") - req.Header.Add("X-Custom-Header", "AAA-CUSTOM-VALUE") - req.Header.Add("X-Request-ID", "6ba7b810-9dad-11d1-80b4-00c04fd430c8") - req.Header.Add("Cookie", "_ga=GA1.2.000000000.0000000000; session=ac08034cd216a647fc2eb62f2bcf7b810") - req.Form = url.Values{ - "username": []string{"apagano-form"}, - "password": []string{"secret-form"}, - } - - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - cases := map[string]bool{ - "apagano-param": true, - "apagano-form": true, - "AAA-CUSTOM-VALUE": true, - "BBB-CUSTOM-VALUE": false, - "secret-form": false, - "hexvalue": false, - "GET": true, - "127.0.0.1": true, - "\"path\":\"/users/1\"": true, - "\"route\":\"/users/:id\"": true, - "\"uri\":\"/\"": true, - "\"status\":200": true, - "\"bytes_in\":0": true, - "google.com": true, - "echo-tests-agent": true, - "6ba7b810-9dad-11d1-80b4-00c04fd430c8": true, - "ac08034cd216a647fc2eb62f2bcf7b810": true, - } - - for token, present := range cases { - assert.True(t, strings.Contains(buf.String(), token) == present, "Case: "+token) - } -} - -func TestLoggerCustomTimestamp(t *testing.T) { - buf := new(bytes.Buffer) - customTimeFormat := "2006-01-02 15:04:05.00000" - e := echo.New() - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `{"time":"${time_custom}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + - `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + - `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}",` + - `"us":"${query:username}", "cf":"${form:username}", "session":"${cookie:session}"}` + "\n", - CustomTimeFormat: customTimeFormat, - Output: buf, - })) - - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "custom time stamp test") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - var objs map[string]*json.RawMessage - if err := json.Unmarshal(buf.Bytes(), &objs); err != nil { - panic(err) - } - loggedTime := *(*string)(unsafe.Pointer(objs["time"])) - _, err := time.Parse(customTimeFormat, loggedTime) - assert.Error(t, err) -} - -func TestLoggerCustomTagFunc(t *testing.T) { - e := echo.New() - buf := new(bytes.Buffer) - e.Use(LoggerWithConfig(LoggerConfig{ - Format: `{"method":"${method}",${custom}}` + "\n", - CustomTagFunc: func(c echo.Context, buf *bytes.Buffer) (int, error) { - return buf.WriteString(`"tag":"my-value"`) - }, - Output: buf, - })) - - e.GET("/", func(c echo.Context) error { - return c.String(http.StatusOK, "custom time stamp test") - }) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - - assert.Equal(t, `{"method":"GET","tag":"my-value"}`+"\n", buf.String()) -} - -func BenchmarkLoggerWithConfig_withoutMapFields(b *testing.B) { - e := echo.New() - - buf := new(bytes.Buffer) - mw := LoggerWithConfig(LoggerConfig{ - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + - `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + - `"bytes_out":${bytes_out}, "protocol":"${protocol}"}` + "\n", - Output: buf, - })(func(c echo.Context) error { - c.Request().Header.Set(echo.HeaderXRequestID, "123") - c.FormValue("to force parse form") - return c.String(http.StatusTeapot, "OK") - }) - - f := make(url.Values) - f.Set("csrf", "token") - f.Add("multiple", "1") - f.Add("multiple", "2") - req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) - req.Header.Set("Referer", "https://echo.labstack.com/") - req.Header.Set("User-Agent", "curl/7.68.0") - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - mw(c) - buf.Reset() - } -} - -func BenchmarkLoggerWithConfig_withMapFields(b *testing.B) { - e := echo.New() - - buf := new(bytes.Buffer) - mw := LoggerWithConfig(LoggerConfig{ - Format: `{"time":"${time_rfc3339_nano}","id":"${id}","remote_ip":"${remote_ip}","host":"${host}","user_agent":"${user_agent}",` + - `"method":"${method}","uri":"${uri}","status":${status}, "latency":${latency},` + - `"latency_human":"${latency_human}","bytes_in":${bytes_in}, "path":"${path}", "referer":"${referer}",` + - `"bytes_out":${bytes_out},"ch":"${header:X-Custom-Header}", "protocol":"${protocol}"` + - `"us":"${query:username}", "cf":"${form:csrf}", "Referer2":"${header:Referer}"}` + "\n", - Output: buf, - })(func(c echo.Context) error { - c.Request().Header.Set(echo.HeaderXRequestID, "123") - c.FormValue("to force parse form") - return c.String(http.StatusTeapot, "OK") - }) - - f := make(url.Values) - f.Set("csrf", "token") - f.Add("multiple", "1") - f.Add("multiple", "2") - req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", strings.NewReader(f.Encode())) - req.Header.Set("Referer", "https://echo.labstack.com/") - req.Header.Set("User-Agent", "curl/7.68.0") - req.Header.Add(echo.HeaderContentType, echo.MIMEApplicationForm) - - b.ReportAllocs() - b.ResetTimer() - - for i := 0; i < b.N; i++ { - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - mw(c) - buf.Reset() - } -} diff --git a/middleware/method_override.go b/middleware/method_override.go index 3991e1029..25ec1f935 100644 --- a/middleware/method_override.go +++ b/middleware/method_override.go @@ -6,7 +6,7 @@ package middleware import ( "net/http" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // MethodOverrideConfig defines the config for MethodOverride middleware. @@ -20,7 +20,7 @@ type MethodOverrideConfig struct { } // MethodOverrideGetter is a function that gets overridden method from the request -type MethodOverrideGetter func(echo.Context) string +type MethodOverrideGetter func(c *echo.Context) string // DefaultMethodOverrideConfig is the default MethodOverride middleware config. var DefaultMethodOverrideConfig = MethodOverrideConfig{ @@ -37,9 +37,13 @@ func MethodOverride() echo.MiddlewareFunc { return MethodOverrideWithConfig(DefaultMethodOverrideConfig) } -// MethodOverrideWithConfig returns a MethodOverride middleware with config. -// See: `MethodOverride()`. +// MethodOverrideWithConfig returns a Method Override middleware with config or panics on invalid configuration. func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts MethodOverrideConfig to middleware or returns an error for invalid configuration +func (config MethodOverrideConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultMethodOverrideConfig.Skipper @@ -49,7 +53,7 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -63,13 +67,13 @@ func MethodOverrideWithConfig(config MethodOverrideConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } // MethodFromHeader is a `MethodOverrideGetter` that gets overridden method from // the request header. func MethodFromHeader(header string) MethodOverrideGetter { - return func(c echo.Context) string { + return func(c *echo.Context) string { return c.Request().Header.Get(header) } } @@ -77,7 +81,7 @@ func MethodFromHeader(header string) MethodOverrideGetter { // MethodFromForm is a `MethodOverrideGetter` that gets overridden method from the // form parameter. func MethodFromForm(param string) MethodOverrideGetter { - return func(c echo.Context) string { + return func(c *echo.Context) string { return c.FormValue(param) } } @@ -85,7 +89,7 @@ func MethodFromForm(param string) MethodOverrideGetter { // MethodFromQuery is a `MethodOverrideGetter` that gets overridden method from // the query parameter. func MethodFromQuery(param string) MethodOverrideGetter { - return func(c echo.Context) string { + return func(c *echo.Context) string { return c.QueryParam(param) } } diff --git a/middleware/method_override_test.go b/middleware/method_override_test.go index 0000d1d80..525ad10ba 100644 --- a/middleware/method_override_test.go +++ b/middleware/method_override_test.go @@ -9,14 +9,14 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestMethodOverride(t *testing.T) { e := echo.New() m := MethodOverride() - h := func(c echo.Context) error { + h := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } @@ -25,28 +25,68 @@ func TestMethodOverride(t *testing.T) { rec := httptest.NewRecorder() req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) c := e.NewContext(req, rec) - m(h)(c) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_formParam(t *testing.T) { + e := echo.New() + h := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + // Override with form parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromForm("_method")}) - req = httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) - rec = httptest.NewRecorder() + m, err := MethodOverrideConfig{Getter: MethodFromForm("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/", bytes.NewReader([]byte("_method="+http.MethodDelete))) + rec := httptest.NewRecorder() req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) - c = e.NewContext(req, rec) - m(h)(c) + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_queryParam(t *testing.T) { + e := echo.New() + h := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } // Override with query parameter - m = MethodOverrideWithConfig(MethodOverrideConfig{Getter: MethodFromQuery("_method")}) - req = httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - m(h)(c) + m, err := MethodOverrideConfig{Getter: MethodFromQuery("_method")}.ToMiddleware() + assert.NoError(t, err) + req := httptest.NewRequest(http.MethodPost, "/?_method="+http.MethodDelete, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err = m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodDelete, req.Method) +} + +func TestMethodOverride_ignoreGet(t *testing.T) { + e := echo.New() + m := MethodOverride() + h := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } // Ignore `GET` - req = httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set(echo.HeaderXHTTPMethodOverride, http.MethodDelete) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := m(h)(c) + assert.NoError(t, err) + assert.Equal(t, http.MethodGet, req.Method) } diff --git a/middleware/middleware.go b/middleware/middleware.go index 164e52b4c..4562d03b5 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -9,15 +9,14 @@ import ( "strconv" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) -// Skipper defines a function to skip middleware. Returning true skips processing -// the middleware. -type Skipper func(c echo.Context) bool +// Skipper defines a function to skip middleware. Returning true skips processing the middleware. +type Skipper func(c *echo.Context) bool // BeforeFunc defines a function which is executed just before the middleware. -type BeforeFunc func(c echo.Context) +type BeforeFunc func(c *echo.Context) func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { groups := pattern.FindAllStringSubmatch(input, -1) @@ -54,7 +53,7 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error return nil } - // Depending on how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. + // Depending how HTTP request is sent RequestURI could contain Scheme://Host/path or be just /path. // We only want to use path part for rewriting and therefore trim prefix if it exists rawURI := req.RequestURI if rawURI != "" && rawURI[0] != '/' { @@ -85,13 +84,11 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error } // DefaultSkipper returns false which processes the middleware. -func DefaultSkipper(echo.Context) bool { +func DefaultSkipper(c *echo.Context) bool { return false } -func toMiddlewareOrPanic(config interface { - ToMiddleware() (echo.MiddlewareFunc, error) -}) echo.MiddlewareFunc { +func toMiddlewareOrPanic(config echo.MiddlewareConfigurator) echo.MiddlewareFunc { mw, err := config.ToMiddleware() if err != nil { panic(err) diff --git a/middleware/middleware_test.go b/middleware/middleware_test.go index 7f3dc3866..28407ed5c 100644 --- a/middleware/middleware_test.go +++ b/middleware/middleware_test.go @@ -102,11 +102,9 @@ type testResponseWriterNoFlushHijack struct { func (w *testResponseWriterNoFlushHijack) WriteHeader(statusCode int) { } - func (w *testResponseWriterNoFlushHijack) Write([]byte) (int, error) { return 0, nil } - func (w *testResponseWriterNoFlushHijack) Header() http.Header { return nil } @@ -118,15 +116,12 @@ type testResponseWriterUnwrapper struct { func (w *testResponseWriterUnwrapper) WriteHeader(statusCode int) { } - func (w *testResponseWriterUnwrapper) Write([]byte) (int, error) { return 0, nil } - func (w *testResponseWriterUnwrapper) Header() http.Header { return nil } - func (w *testResponseWriterUnwrapper) Unwrap() http.ResponseWriter { w.unwrapCalled++ return w.rw diff --git a/middleware/proxy.go b/middleware/proxy.go index f26870077..1996032f7 100644 --- a/middleware/proxy.go +++ b/middleware/proxy.go @@ -6,6 +6,7 @@ package middleware import ( "context" "crypto/tls" + "errors" "fmt" "io" "math/rand" @@ -18,7 +19,7 @@ import ( "sync" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // TODO: Handle TLS proxy @@ -41,14 +42,14 @@ type ProxyConfig struct { // of previous retries is less than RetryCount. If the function returns true, the // request will be retried. The provided error indicates the reason for the request // failure. When the ProxyTarget is unavailable, the error will be an instance of - // echo.HTTPError with a Code of http.StatusBadGateway. In all other cases, the error + // echo.HTTPError with a code of http.StatusBadGateway. In all other cases, the error // will indicate an internal error in the Proxy middleware. When a RetryFilter is not // specified, all requests that fail with http.StatusBadGateway will be retried. A custom // RetryFilter can be provided to only retry specific requests. Note that RetryFilter is // only called when the request to the target fails, or an internal error in the Proxy // middleware has occurred. Successful requests that return a non-200 response code cannot // be retried. - RetryFilter func(c echo.Context, e error) bool + RetryFilter func(c *echo.Context, e error) bool // ErrorHandler defines a function which can be used to return custom errors from // the Proxy middleware. ErrorHandler is only invoked when there has been @@ -57,7 +58,7 @@ type ProxyConfig struct { // when a ProxyTarget returns a non-200 response. In these cases, the response // is already written so errors cannot be modified. ErrorHandler is only // invoked after all retry attempts have been exhausted. - ErrorHandler func(c echo.Context, err error) error + ErrorHandler func(c *echo.Context, err error) error // Rewrite defines URL path rewrite rules. The values captured in asterisk can be // retrieved by index e.g. $1, $2 and so on. @@ -91,20 +92,14 @@ type ProxyConfig struct { type ProxyTarget struct { Name string URL *url.URL - Meta echo.Map + Meta map[string]any } // ProxyBalancer defines an interface to implement a load balancing technique. type ProxyBalancer interface { - AddTarget(*ProxyTarget) bool - RemoveTarget(string) bool - Next(echo.Context) *ProxyTarget -} - -// TargetProvider defines an interface that gives the opportunity for balancer -// to return custom errors when selecting target. -type TargetProvider interface { - NextTarget(echo.Context) (*ProxyTarget, error) + AddTarget(target *ProxyTarget) bool + RemoveTarget(targetName string) bool + Next(c *echo.Context) (*ProxyTarget, error) } type commonBalancer struct { @@ -131,7 +126,7 @@ var DefaultProxyConfig = ProxyConfig{ ContextKey: "target", } -func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { +func proxyRaw(c *echo.Context, t *ProxyTarget, config ProxyConfig) http.Handler { var dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) if transport, ok := config.Transport.(*http.Transport); ok { if transport.TLSClientConfig != nil { @@ -147,12 +142,13 @@ func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - in, _, err := c.Response().Hijack() + in, _, err := http.NewResponseController(w).Hijack() if err != nil { c.Set("_error", fmt.Errorf("proxy raw, hijack error=%w, url=%s", err, t.URL)) return } defer in.Close() + out, err := dialFunc(c.Request().Context(), "tcp", t.URL.Host) if err != nil { c.Set("_error", echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw, dial error=%v, url=%s", err, t.URL))) @@ -192,7 +188,9 @@ func proxyRaw(t *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { func NewRandomBalancer(targets []*ProxyTarget) ProxyBalancer { b := randomBalancer{} b.targets = targets - b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + // G404 (CWE-338): Use of weak random number generator (math/rand or math/rand/v2 instead of crypto/rand) + // this random is used to select next target. I can not think of reason this must be cryptographically safe. If you can - please open PR. + b.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // #nosec G404 return &b } @@ -236,15 +234,15 @@ func (b *commonBalancer) RemoveTarget(name string) bool { // Next randomly returns an upstream target. // // Note: `nil` is returned in case upstream target list is empty. -func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { +func (b *randomBalancer) Next(c *echo.Context) (*ProxyTarget, error) { b.mutex.Lock() defer b.mutex.Unlock() if len(b.targets) == 0 { - return nil + return nil, nil } else if len(b.targets) == 1 { - return b.targets[0] + return b.targets[0], nil } - return b.targets[b.random.Intn(len(b.targets))] + return b.targets[b.random.Intn(len(b.targets))], nil } // Next returns an upstream target using round-robin technique. In the case @@ -255,13 +253,13 @@ func (b *randomBalancer) Next(c echo.Context) *ProxyTarget { // return the original failed target. // // Note: `nil` is returned in case upstream target list is empty. -func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { +func (b *roundRobinBalancer) Next(c *echo.Context) (*ProxyTarget, error) { b.mutex.Lock() defer b.mutex.Unlock() if len(b.targets) == 0 { - return nil + return nil, nil } else if len(b.targets) == 1 { - return b.targets[0] + return b.targets[0], nil } var i int @@ -283,9 +281,8 @@ func (b *roundRobinBalancer) Next(c echo.Context) *ProxyTarget { i = b.i b.i++ } - c.Set(lastIdxKey, i) - return b.targets[i] + return b.targets[i], nil } // Proxy returns a Proxy middleware. @@ -297,18 +294,26 @@ func Proxy(balancer ProxyBalancer) echo.MiddlewareFunc { return ProxyWithConfig(c) } -// ProxyWithConfig returns a Proxy middleware with config. -// See: `Proxy()` +// ProxyWithConfig returns a Proxy middleware or panics if configuration is invalid. +// +// Proxy middleware forwards the request to upstream server using a configured load balancing technique. func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { - if config.Balancer == nil { - panic("echo: proxy middleware requires balancer") - } - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts ProxyConfig to middleware or returns an error for invalid configuration +func (config ProxyConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultProxyConfig.Skipper } + if config.ContextKey == "" { + config.ContextKey = DefaultProxyConfig.ContextKey + } + if config.Balancer == nil { + return nil, errors.New("echo proxy middleware requires balancer") + } if config.RetryFilter == nil { - config.RetryFilter = func(c echo.Context, e error) bool { + config.RetryFilter = func(c *echo.Context, e error) bool { if httpErr, ok := e.(*echo.HTTPError); ok { return httpErr.Code == http.StatusBadGateway } @@ -316,10 +321,11 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } if config.ErrorHandler == nil { - config.ErrorHandler = func(c echo.Context, err error) error { + config.ErrorHandler = func(c *echo.Context, err error) error { return err } } + if config.Rewrite != nil { if config.RegexRewrite == nil { config.RegexRewrite = make(map[*regexp.Regexp]string) @@ -329,10 +335,8 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { } } - provider, isTargetProvider := config.Balancer.(TargetProvider) - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } @@ -358,15 +362,9 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { retries := config.RetryCount for { - var tgt *ProxyTarget - var err error - if isTargetProvider { - tgt, err = provider.NextTarget(c) - if err != nil { - return config.ErrorHandler(c, err) - } - } else { - tgt = config.Balancer.Next(c) + tgt, err := config.Balancer.Next(c) + if err != nil { + return config.ErrorHandler(c, err) } c.Set(config.ContextKey, tgt) @@ -385,9 +383,9 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // Proxy switch { case c.IsWebSocket(): - proxyRaw(tgt, c, config).ServeHTTP(res, req) + proxyRaw(c, tgt, config).ServeHTTP(res, req) default: // even SSE requests - proxyHTTP(tgt, c, config).ServeHTTP(res, req) + proxyHTTP(c, tgt, config).ServeHTTP(res, req) } err, hasError := c.Get("_error").(error) @@ -403,7 +401,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { retries-- } } - } + }, nil } // StatusCodeContextCanceled is a custom HTTP status code for situations @@ -413,7 +411,7 @@ func ProxyWithConfig(config ProxyConfig) echo.MiddlewareFunc { // 499 too instead of the more problematic 5xx, which does not allow to detect this situation const StatusCodeContextCanceled = 499 -func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handler { +func proxyHTTP(c *echo.Context, tgt *ProxyTarget, config ProxyConfig) http.Handler { proxy := httputil.NewSingleHostReverseProxy(tgt.URL) proxy.ErrorHandler = func(resp http.ResponseWriter, req *http.Request, err error) { desc := tgt.URL.String() @@ -423,15 +421,17 @@ func proxyHTTP(tgt *ProxyTarget, c echo.Context, config ProxyConfig) http.Handle // If the client canceled the request (usually by closing the connection), we can report a // client error (4xx) instead of a server error (5xx) to correctly identify the situation. // The Go standard library (at of late 2020) wraps the exported, standard - // context.Canceled error with unexported garbage value requiring a substring check, see + // context. Canceled error with unexported garbage value requiring a substring check, see // https://github.com/golang/go/blob/6965b01ea248cabb70c3749fd218b36089a21efb/src/net/net.go#L416-L430 - if err == context.Canceled || strings.Contains(err.Error(), "operation was canceled") { - httpError := echo.NewHTTPError(StatusCodeContextCanceled, fmt.Sprintf("client closed connection: %v", err)) - httpError.Internal = err + // From Caddy https://github.com/caddyserver/caddy/blob/afa778ae05503f563af0d1015cdf7e5e78b1eeec/modules/caddyhttp/reverseproxy/reverseproxy.go#L1352 + if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "operation was canceled") { + httpError := echo.NewHTTPError(StatusCodeContextCanceled, "client closed connection").Wrap(err) c.Set("_error", httpError) } else { - httpError := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("remote %s unreachable, could not forward: %v", desc, err)) - httpError.Internal = err + httpError := echo.NewHTTPError( + http.StatusBadGateway, + "remote server unreachable, could not proxy request", + ).Wrap(fmt.Errorf("server: %s, err: %w", desc, err)) c.Set("_error", httpError) } } diff --git a/middleware/proxy_test.go b/middleware/proxy_test.go index dbf07648b..420be3240 100644 --- a/middleware/proxy_test.go +++ b/middleware/proxy_test.go @@ -19,7 +19,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/net/websocket" ) @@ -37,6 +37,7 @@ func TestProxy(t *testing.T) { })) defer t2.Close() url2, _ := url.Parse(t2.URL) + targets := []*ProxyTarget{ { Name: "target 1", @@ -60,7 +61,7 @@ func TestProxy(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -82,7 +83,7 @@ func TestProxy(t *testing.T) { // Round-robin rrb := NewRoundRobinBalancer(targets) e = echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) @@ -112,68 +113,24 @@ func TestProxy(t *testing.T) { // ProxyTarget is set in context contextObserver := func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) (err error) { next(c) assert.Contains(t, targets, c.Get("target"), "target is not set in context") return nil } } - rrb1 := NewRoundRobinBalancer(targets) e = echo.New() e.Use(contextObserver) - e.Use(Proxy(rrb1)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: NewRoundRobinBalancer(targets)})) rec = httptest.NewRecorder() e.ServeHTTP(rec, req) } -type testProvider struct { - commonBalancer - target *ProxyTarget - err error -} - -func (p *testProvider) Next(c echo.Context) *ProxyTarget { - return &ProxyTarget{} -} - -func (p *testProvider) NextTarget(c echo.Context) (*ProxyTarget, error) { - return p.target, p.err -} - -func TestTargetProvider(t *testing.T) { - t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "target 1") - })) - defer t1.Close() - url1, _ := url.Parse(t1.URL) - - e := echo.New() - tp := &testProvider{} - tp.target = &ProxyTarget{Name: "target 1", URL: url1} - e.Use(Proxy(tp)) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - e.ServeHTTP(rec, req) - body := rec.Body.String() - assert.Equal(t, "target 1", body) -} - -func TestFailNextTarget(t *testing.T) { - url1, err := url.Parse("http://dummy:8080") - assert.Nil(t, err) - - e := echo.New() - tp := &testProvider{} - tp.target = &ProxyTarget{Name: "target 1", URL: url1} - tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") - - e.Use(Proxy(tp)) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - e.ServeHTTP(rec, req) - body := rec.Body.String() - assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +func TestMustProxyWithConfig_emptyBalancerPanics(t *testing.T) { + assert.Panics(t, func() { + ProxyWithConfig(ProxyConfig{Balancer: nil}) + }) } func TestProxyRealIPHeader(t *testing.T) { @@ -183,7 +140,7 @@ func TestProxyRealIPHeader(t *testing.T) { url, _ := url.Parse(upstream.URL) rrb := NewRoundRobinBalancer([]*ProxyTarget{{Name: "upstream", URL: url}}) e := echo.New() - e.Use(Proxy(rrb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rrb})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() @@ -388,7 +345,7 @@ func TestProxyError(t *testing.T) { // Random e := echo.New() - e.Use(Proxy(rb)) + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) req := httptest.NewRequest(http.MethodGet, "/", nil) // Remote unreachable @@ -399,8 +356,108 @@ func TestProxyError(t *testing.T) { assert.Equal(t, http.StatusBadGateway, rec.Code) } -func TestProxyRetries(t *testing.T) { +func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { + var timeoutStop sync.WaitGroup + timeoutStop.Add(1) + HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + timeoutStop.Wait() // wait until we have canceled the request + w.WriteHeader(http.StatusOK) + })) + defer HTTPTarget.Close() + targetURL, _ := url.Parse(HTTPTarget.URL) + target := &ProxyTarget{ + Name: "target", + URL: targetURL, + } + rb := NewRandomBalancer(nil) + assert.True(t, rb.AddTarget(target)) + e := echo.New() + e.Use(ProxyWithConfig(ProxyConfig{Balancer: rb})) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + ctx, cancel := context.WithCancel(req.Context()) + req = req.WithContext(ctx) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + e.ServeHTTP(rec, req) + timeoutStop.Done() + assert.Equal(t, 499, rec.Code) +} + +type testProvider struct { + commonBalancer + target *ProxyTarget + err error +} + +func (p *testProvider) Next(c *echo.Context) (*ProxyTarget, error) { + return p.target, p.err +} + +func TestTargetProvider(t *testing.T) { + t1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "target 1") + })) + defer t1.Close() + url1, _ := url.Parse(t1.URL) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "target 1", body) +} + +func TestFailNextTarget(t *testing.T) { + url1, err := url.Parse("http://dummy:8080") + assert.Nil(t, err) + + e := echo.New() + tp := &testProvider{} + tp.target = &ProxyTarget{Name: "target 1", URL: url1} + tp.err = echo.NewHTTPError(http.StatusInternalServerError, "method could not select target") + + e.Use(Proxy(tp)) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/", nil) + e.ServeHTTP(rec, req) + body := rec.Body.String() + assert.Equal(t, "{\"message\":\"method could not select target\"}\n", body) +} + +func TestRandomBalancerWithNoTargets(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Assert balancer with empty targets does return `nil` on `Next()` + rb := NewRandomBalancer(nil) + target, err := rb.Next(c) + assert.Nil(t, target) + assert.NoError(t, err) +} +func TestRoundRobinBalancerWithNoTargets(t *testing.T) { + // Assert balancer with empty targets does return `nil` on `Next()` + rrb := NewRoundRobinBalancer([]*ProxyTarget{}) + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/?id=1&name=Jon+Snow", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + target, err := rrb.Next(c) + assert.Nil(t, target) + assert.NoError(t, err) +} + +func TestProxyRetries(t *testing.T) { newServer := func(res int) (*url.URL, *httptest.Server) { server := httptest.NewServer( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -431,13 +488,13 @@ func TestProxyRetries(t *testing.T) { URL: targetURL, } - alwaysRetryFilter := func(c echo.Context, e error) bool { return true } - neverRetryFilter := func(c echo.Context, e error) bool { return false } + alwaysRetryFilter := func(c *echo.Context, e error) bool { return true } + neverRetryFilter := func(c *echo.Context, e error) bool { return false } testCases := []struct { name string retryCount int - retryFilters []func(c echo.Context, e error) bool + retryFilters []func(c *echo.Context, e error) bool targets []*ProxyTarget expectedResponse int }{ @@ -460,7 +517,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 1 does retry on handler return true", retryCount: 1, - retryFilters: []func(c echo.Context, e error) bool{ + retryFilters: []func(c *echo.Context, e error) bool{ alwaysRetryFilter, }, targets: []*ProxyTarget{ @@ -472,7 +529,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 1 does not retry on handler return false", retryCount: 1, - retryFilters: []func(c echo.Context, e error) bool{ + retryFilters: []func(c *echo.Context, e error) bool{ neverRetryFilter, }, targets: []*ProxyTarget{ @@ -484,7 +541,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 2 returns error when no more retries left", retryCount: 2, - retryFilters: []func(c echo.Context, e error) bool{ + retryFilters: []func(c *echo.Context, e error) bool{ alwaysRetryFilter, alwaysRetryFilter, }, @@ -499,7 +556,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 2 returns error when retries left but handler returns false", retryCount: 3, - retryFilters: []func(c echo.Context, e error) bool{ + retryFilters: []func(c *echo.Context, e error) bool{ alwaysRetryFilter, alwaysRetryFilter, neverRetryFilter, @@ -515,7 +572,7 @@ func TestProxyRetries(t *testing.T) { { name: "retry count 3 succeeds", retryCount: 3, - retryFilters: []func(c echo.Context, e error) bool{ + retryFilters: []func(c *echo.Context, e error) bool{ alwaysRetryFilter, alwaysRetryFilter, alwaysRetryFilter, @@ -543,7 +600,7 @@ func TestProxyRetries(t *testing.T) { t.Run(tc.name, func(t *testing.T) { retryFilterCall := 0 - retryFilter := func(c echo.Context, e error) bool { + retryFilter := func(c *echo.Context, e error) bool { if len(tc.retryFilters) == 0 { assert.FailNow(t, fmt.Sprintf("unexpected calls, %d, to retry handler", retryFilterCall)) } @@ -658,13 +715,13 @@ func TestProxyErrorHandler(t *testing.T) { testCases := []struct { name string target *ProxyTarget - errorHandler func(c echo.Context, e error) error + errorHandler func(c *echo.Context, e error) error expectFinalError func(t *testing.T, err error) }{ { name: "Error handler not invoked when request success", target: goodTarget, - errorHandler: func(c echo.Context, e error) error { + errorHandler: func(c *echo.Context, e error) error { assert.FailNow(t, "error handler should not be invoked") return e }, @@ -672,7 +729,7 @@ func TestProxyErrorHandler(t *testing.T) { { name: "Error handler invoked when request fails", target: badTarget, - errorHandler: func(c echo.Context, e error) error { + errorHandler: func(c *echo.Context, e error) error { httpErr, ok := e.(*echo.HTTPError) assert.True(t, ok, "expected http error to be passed to handler") assert.Equal(t, http.StatusBadGateway, httpErr.Code, "expected http bad gateway error to be passed to handler") @@ -695,10 +752,11 @@ func TestProxyErrorHandler(t *testing.T) { )) errorHandlerCalled := false - e.HTTPErrorHandler = func(err error, c echo.Context) { + dheh := echo.DefaultHTTPErrorHandler(false) + e.HTTPErrorHandler = func(c *echo.Context, err error) { errorHandlerCalled = true tc.expectFinalError(t, err) - e.DefaultHTTPErrorHandler(err, c) + dheh(c, err) } req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -714,47 +772,7 @@ func TestProxyErrorHandler(t *testing.T) { } } -func TestClientCancelConnectionResultsHTTPCode499(t *testing.T) { - var timeoutStop sync.WaitGroup - timeoutStop.Add(1) - HTTPTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - timeoutStop.Wait() // wait until we have canceled the request - w.WriteHeader(http.StatusOK) - })) - defer HTTPTarget.Close() - targetURL, _ := url.Parse(HTTPTarget.URL) - target := &ProxyTarget{ - Name: "target", - URL: targetURL, - } - rb := NewRandomBalancer(nil) - assert.True(t, rb.AddTarget(target)) - e := echo.New() - e.Use(Proxy(rb)) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodGet, "/", nil) - ctx, cancel := context.WithCancel(req.Context()) - req = req.WithContext(ctx) - go func() { - time.Sleep(10 * time.Millisecond) - cancel() - }() - e.ServeHTTP(rec, req) - timeoutStop.Done() - assert.Equal(t, 499, rec.Code) -} - -// Assert balancer with empty targets does return `nil` on `Next()` -func TestProxyBalancerWithNoTargets(t *testing.T) { - rb := NewRandomBalancer(nil) - assert.Nil(t, rb.Next(nil)) - - rrb := NewRoundRobinBalancer([]*ProxyTarget{}) - assert.Nil(t, rrb.Next(nil)) -} - type testContextKey string - type customBalancer struct { target *ProxyTarget } @@ -762,15 +780,14 @@ type customBalancer struct { func (b *customBalancer) AddTarget(target *ProxyTarget) bool { return false } - func (b *customBalancer) RemoveTarget(name string) bool { return false } -func (b *customBalancer) Next(c echo.Context) *ProxyTarget { +func (b *customBalancer) Next(c *echo.Context) (*ProxyTarget, error) { ctx := context.WithValue(c.Request().Context(), testContextKey("FROM_BALANCER"), "CUSTOM_BALANCER") c.SetRequest(c.Request().WithContext(ctx)) - return b.target + return b.target, nil } func TestModifyResponseUseContext(t *testing.T) { @@ -781,7 +798,6 @@ func TestModifyResponseUseContext(t *testing.T) { }), ) defer server.Close() - targetURL, _ := url.Parse(server.URL) e := echo.New() e.Use(ProxyWithConfig( @@ -802,12 +818,9 @@ func TestModifyResponseUseContext(t *testing.T) { }, }, )) - req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() - e.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "OK", rec.Body.String()) assert.Equal(t, "CUSTOM_BALANCER", rec.Header().Get("FROM_BALANCER")) diff --git a/middleware/rate_limiter.go b/middleware/rate_limiter.go index 2746a3de1..bdf933e87 100644 --- a/middleware/rate_limiter.go +++ b/middleware/rate_limiter.go @@ -4,18 +4,18 @@ package middleware import ( + "errors" "math" "net/http" "sync" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "golang.org/x/time/rate" ) // RateLimiterStore is the interface to be implemented by custom stores. type RateLimiterStore interface { - // Stores for the rate limiter have to implement the Allow method Allow(identifier string) (bool, error) } @@ -23,18 +23,18 @@ type RateLimiterStore interface { type RateLimiterConfig struct { Skipper Skipper BeforeFunc BeforeFunc - // IdentifierExtractor uses echo.Context to extract the identifier for a visitor + // IdentifierExtractor uses *echo.Context to extract the identifier for a visitor IdentifierExtractor Extractor // Store defines a store for the rate limiter Store RateLimiterStore // ErrorHandler provides a handler to be called when IdentifierExtractor returns an error - ErrorHandler func(context echo.Context, err error) error + ErrorHandler func(c *echo.Context, err error) error // DenyHandler provides a handler to be called when RateLimiter denies access - DenyHandler func(context echo.Context, identifier string, err error) error + DenyHandler func(c *echo.Context, identifier string, err error) error } -// Extractor is used to extract data from echo.Context -type Extractor func(context echo.Context) (string, error) +// Extractor is used to extract data from *echo.Context +type Extractor func(c *echo.Context) (string, error) // ErrRateLimitExceeded denotes an error raised when rate limit is exceeded var ErrRateLimitExceeded = echo.NewHTTPError(http.StatusTooManyRequests, "rate limit exceeded") @@ -45,23 +45,15 @@ var ErrExtractorError = echo.NewHTTPError(http.StatusForbidden, "error while ext // DefaultRateLimiterConfig defines default values for RateLimiterConfig var DefaultRateLimiterConfig = RateLimiterConfig{ Skipper: DefaultSkipper, - IdentifierExtractor: func(ctx echo.Context) (string, error) { + IdentifierExtractor: func(ctx *echo.Context) (string, error) { id := ctx.RealIP() return id, nil }, - ErrorHandler: func(context echo.Context, err error) error { - return &echo.HTTPError{ - Code: ErrExtractorError.Code, - Message: ErrExtractorError.Message, - Internal: err, - } + ErrorHandler: func(c *echo.Context, err error) error { + return ErrExtractorError.Wrap(err) }, - DenyHandler: func(context echo.Context, identifier string, err error) error { - return &echo.HTTPError{ - Code: ErrRateLimitExceeded.Code, - Message: ErrRateLimitExceeded.Message, - Internal: err, - } + DenyHandler: func(c *echo.Context, identifier string, err error) error { + return ErrRateLimitExceeded.Wrap(err) }, } @@ -72,7 +64,7 @@ RateLimiter returns a rate limiting middleware limiterStore := middleware.NewRateLimiterMemoryStore(20) - e.GET("/rate-limited", func(c echo.Context) error { + e.GET("/rate-limited", func(c *echo.Context) error { return c.String(http.StatusOK, "test") }, RateLimiter(limiterStore)) */ @@ -93,23 +85,28 @@ RateLimiterWithConfig returns a rate limiting middleware Store: middleware.NewRateLimiterMemoryStore( middleware.RateLimiterMemoryStoreConfig{Rate: 10, Burst: 30, ExpiresIn: 3 * time.Minute} ) - IdentifierExtractor: func(ctx echo.Context) (string, error) { + IdentifierExtractor: func(ctx *echo.Context) (string, error) { id := ctx.RealIP() return id, nil }, - ErrorHandler: func(context echo.Context, err error) error { + ErrorHandler: func(context *echo.Context, err error) error { return context.JSON(http.StatusTooManyRequests, nil) }, - DenyHandler: func(context echo.Context, identifier string) error { + DenyHandler: func(context *echo.Context, identifier string) error { return context.JSON(http.StatusForbidden, nil) }, } - e.GET("/rate-limited", func(c echo.Context) error { + e.GET("/rate-limited", func(c *echo.Context) error { return c.String(http.StatusOK, "test") }, middleware.RateLimiterWithConfig(config)) */ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RateLimiterConfig to middleware or returns an error for invalid configuration +func (config RateLimiterConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultRateLimiterConfig.Skipper } @@ -123,10 +120,10 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { config.DenyHandler = DefaultRateLimiterConfig.DenyHandler } if config.Store == nil { - panic("Store configuration must be provided") + return nil, errors.New("echo rate limiter store configuration must be provided") } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -136,25 +133,22 @@ func RateLimiterWithConfig(config RateLimiterConfig) echo.MiddlewareFunc { identifier, err := config.IdentifierExtractor(c) if err != nil { - c.Error(config.ErrorHandler(c, err)) - return nil + return config.ErrorHandler(c, err) } - if allow, err := config.Store.Allow(identifier); !allow { - c.Error(config.DenyHandler(c, identifier, err)) - return nil + if allow, allowErr := config.Store.Allow(identifier); !allow { + return config.DenyHandler(c, identifier, allowErr) } return next(c) } - } + }, nil } // RateLimiterMemoryStore is the built-in store implementation for RateLimiter type RateLimiterMemoryStore struct { - visitors map[string]*Visitor - mutex sync.Mutex - rate rate.Limit // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. - + visitors map[string]*Visitor + mutex sync.Mutex + rate float64 // for more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit burst int expiresIn time.Duration lastCleanup time.Time @@ -181,9 +175,9 @@ Example (with 20 requests/sec): limiterStore := middleware.NewRateLimiterMemoryStore(20) */ -func NewRateLimiterMemoryStore(rate rate.Limit) (store *RateLimiterMemoryStore) { +func NewRateLimiterMemoryStore(rateLimit float64) (store *RateLimiterMemoryStore) { return NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{ - Rate: rate, + Rate: rateLimit, }) } @@ -226,7 +220,7 @@ func NewRateLimiterMemoryStoreWithConfig(config RateLimiterMemoryStoreConfig) (s // RateLimiterMemoryStoreConfig represents configuration for RateLimiterMemoryStore type RateLimiterMemoryStoreConfig struct { - Rate rate.Limit // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. + Rate float64 // Rate of requests allowed to pass as req/s. For more info check out Limiter docs - https://pkg.go.dev/golang.org/x/time/rate#Limit. Burst int // Burst is maximum number of requests to pass at the same moment. It additionally allows a number of requests to pass when rate limit is reached. ExpiresIn time.Duration // ExpiresIn is the duration after that a rate limiter is cleaned up } @@ -242,13 +236,13 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { limiter, exists := store.visitors[identifier] if !exists { limiter = new(Visitor) - limiter.Limiter = rate.NewLimiter(store.rate, store.burst) + limiter.Limiter = rate.NewLimiter(rate.Limit(store.rate), store.burst) store.visitors[identifier] = limiter } now := store.timeNow() limiter.lastSeen = now if now.Sub(store.lastCleanup) > store.expiresIn { - store.cleanupStaleVisitors() + store.cleanupStaleVisitors(now) } allowed := limiter.AllowN(now, 1) store.mutex.Unlock() @@ -259,11 +253,11 @@ func (store *RateLimiterMemoryStore) Allow(identifier string) (bool, error) { cleanupStaleVisitors helps manage the size of the visitors map by removing stale records of users who haven't visited again after the configured expiry time has elapsed */ -func (store *RateLimiterMemoryStore) cleanupStaleVisitors() { +func (store *RateLimiterMemoryStore) cleanupStaleVisitors(now time.Time) { for id, visitor := range store.visitors { - if store.timeNow().Sub(visitor.lastSeen) > store.expiresIn { + if now.Sub(visitor.lastSeen) > store.expiresIn { delete(store.visitors, id) } } - store.lastCleanup = store.timeNow() + store.lastCleanup = now } diff --git a/middleware/rate_limiter_test.go b/middleware/rate_limiter_test.go index 655d4731d..c591d2b19 100644 --- a/middleware/rate_limiter_test.go +++ b/middleware/rate_limiter_test.go @@ -13,7 +13,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" "golang.org/x/time/rate" ) @@ -21,25 +21,25 @@ import ( func TestRateLimiter(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) - mw := RateLimiter(inMemoryStore) + mw := RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -49,20 +49,25 @@ func TestRateLimiter(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - _ = mw(handler)(c) - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } -func TestRateLimiter_panicBehaviour(t *testing.T) { +func TestMustRateLimiterWithConfig_panicBehaviour(t *testing.T) { var inMemoryStore = NewRateLimiterMemoryStoreWithConfig(RateLimiterMemoryStoreConfig{Rate: 1, Burst: 3}) assert.Panics(t, func() { - RateLimiter(nil) + RateLimiterWithConfig(RateLimiterConfig{}) }) assert.NotPanics(t, func() { - RateLimiter(inMemoryStore) + RateLimiterWithConfig(RateLimiterConfig{Store: inMemoryStore}) }) } @@ -71,26 +76,27 @@ func TestRateLimiterWithConfig(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ - IdentifierExtractor: func(c echo.Context) (string, error) { + mw, err := RateLimiterConfig{ + IdentifierExtractor: func(c *echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { return "", errors.New("invalid identifier") } return id, nil }, - DenyHandler: func(ctx echo.Context, identifier string, err error) error { + DenyHandler: func(ctx *echo.Context, identifier string, err error) error { return ctx.JSON(http.StatusForbidden, nil) }, - ErrorHandler: func(ctx echo.Context, err error) error { + ErrorHandler: func(ctx *echo.Context, err error) error { return ctx.JSON(http.StatusBadRequest, nil) }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { id string @@ -113,8 +119,9 @@ func TestRateLimiterWithConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) + err := mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, tc.code, rec.Code) } } @@ -124,12 +131,12 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ - IdentifierExtractor: func(c echo.Context) (string, error) { + mw, err := RateLimiterConfig{ + IdentifierExtractor: func(c *echo.Context) (string, error) { id := c.Request().Header.Get(echo.HeaderXRealIP) if id == "" { return "", errors.New("invalid identifier") @@ -137,19 +144,20 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { return id, nil }, Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"", http.StatusForbidden}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {expectErr: "code=403, message=error while extracting identifier, err=invalid identifier"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -160,9 +168,13 @@ func TestRateLimiterWithConfig_defaultDenyHandler(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } @@ -172,25 +184,26 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } - mw := RateLimiterWithConfig(RateLimiterConfig{ + mw, err := RateLimiterConfig{ Store: inMemoryStore, - }) + }.ToMiddleware() + assert.NoError(t, err) testCases := []struct { - id string - code int + id string + expectErr string }{ - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusOK}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, - {"127.0.0.1", http.StatusTooManyRequests}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, + {id: "127.0.0.1", expectErr: "code=429, message=rate limit exceeded"}, } for _, tc := range testCases { @@ -201,9 +214,13 @@ func TestRateLimiterWithConfig_defaultConfig(t *testing.T) { c := e.NewContext(req, rec) - _ = mw(handler)(c) - - assert.Equal(t, tc.code, rec.Code) + err := mw(handler)(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + assert.Equal(t, http.StatusOK, rec.Code) } } } @@ -212,7 +229,7 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { e := echo.New() var beforeFuncRan bool - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStore(5) @@ -224,21 +241,23 @@ func TestRateLimiterWithConfig_skipper(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ - Skipper: func(c echo.Context) bool { + mw, err := RateLimiterConfig{ + Skipper: func(c *echo.Context) bool { return true }, - BeforeFunc: func(c echo.Context) { + BeforeFunc: func(c *echo.Context) { beforeFuncRan = true }, Store: inMemoryStore, - IdentifierExtractor: func(ctx echo.Context) (string, error) { + IdentifierExtractor: func(ctx *echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, false, beforeFuncRan) } @@ -246,7 +265,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { e := echo.New() var beforeFuncRan bool - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } var inMemoryStore = NewRateLimiterMemoryStore(5) @@ -258,18 +277,19 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ - Skipper: func(c echo.Context) bool { + mw, err := RateLimiterConfig{ + Skipper: func(c *echo.Context) bool { return false }, - BeforeFunc: func(c echo.Context) { + BeforeFunc: func(c *echo.Context) { beforeFuncRan = true }, Store: inMemoryStore, - IdentifierExtractor: func(ctx echo.Context) (string, error) { + IdentifierExtractor: func(ctx *echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) _ = mw(handler)(c) @@ -279,7 +299,7 @@ func TestRateLimiterWithConfig_skipperNoSkip(t *testing.T) { func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { e := echo.New() - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } @@ -293,18 +313,20 @@ func TestRateLimiterWithConfig_beforeFunc(t *testing.T) { c := e.NewContext(req, rec) - mw := RateLimiterWithConfig(RateLimiterConfig{ - BeforeFunc: func(c echo.Context) { + mw, err := RateLimiterConfig{ + BeforeFunc: func(c *echo.Context) { beforeRan = true }, Store: inMemoryStore, - IdentifierExtractor: func(ctx echo.Context) (string, error) { + IdentifierExtractor: func(ctx *echo.Context) (string, error) { return "127.0.0.1", nil }, - }) + }.ToMiddleware() + assert.NoError(t, err) - _ = mw(handler)(c) + err = mw(handler)(c) + assert.NoError(t, err) assert.Equal(t, true, beforeRan) } @@ -372,7 +394,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { } inMemoryStore.Allow("D") - inMemoryStore.cleanupStaleVisitors() + inMemoryStore.cleanupStaleVisitors(time.Now()) var exists bool @@ -391,7 +413,7 @@ func TestRateLimiterMemoryStore_cleanupStaleVisitors(t *testing.T) { func TestNewRateLimiterMemoryStore(t *testing.T) { testCases := []struct { - rate rate.Limit + rate float64 burst int expiresIn time.Duration expectedExpiresIn time.Duration diff --git a/middleware/recover.go b/middleware/recover.go index e6a5940e4..c18032847 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -8,13 +8,9 @@ import ( "net/http" "runtime" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" ) -// LogErrorFunc defines a function for custom logging in the middleware. -type LogErrorFunc func(c echo.Context, err error, stack []byte) error - // RecoverConfig defines the config for Recover middleware. type RecoverConfig struct { // Skipper defines a function to skip middleware. @@ -22,41 +18,24 @@ type RecoverConfig struct { // Size of the stack to be printed. // Optional. Default value 4KB. - StackSize int `yaml:"stack_size"` + StackSize int // DisableStackAll disables formatting stack traces of all other goroutines // into buffer after the trace for the current goroutine. // Optional. Default value false. - DisableStackAll bool `yaml:"disable_stack_all"` + DisableStackAll bool // DisablePrintStack disables printing stack trace. // Optional. Default value as false. - DisablePrintStack bool `yaml:"disable_print_stack"` - - // LogLevel is log level to printing stack trace. - // Optional. Default value 0 (Print). - LogLevel log.Lvl - - // LogErrorFunc defines a function for custom logging in the middleware. - // If it's set you don't need to provide LogLevel for config. - // If this function returns nil, the centralized HTTPErrorHandler will not be called. - LogErrorFunc LogErrorFunc - - // DisableErrorHandler disables the call to centralized HTTPErrorHandler. - // The recovered error is then passed back to upstream middleware, instead of swallowing the error. - // Optional. Default value false. - DisableErrorHandler bool `yaml:"disable_error_handler"` + DisablePrintStack bool } // DefaultRecoverConfig is the default Recover middleware config. var DefaultRecoverConfig = RecoverConfig{ - Skipper: DefaultSkipper, - StackSize: 4 << 10, // 4 KB - DisableStackAll: false, - DisablePrintStack: false, - LogLevel: 0, - LogErrorFunc: nil, - DisableErrorHandler: false, + Skipper: DefaultSkipper, + StackSize: 4 << 10, // 4 KB + DisableStackAll: false, + DisablePrintStack: false, } // Recover returns a middleware which recovers from panics anywhere in the chain @@ -65,9 +44,13 @@ func Recover() echo.MiddlewareFunc { return RecoverWithConfig(DefaultRecoverConfig) } -// RecoverWithConfig returns a Recover middleware with config. -// See: `Recover()`. +// RecoverWithConfig returns a Recovery middleware with config or panics on invalid configuration. func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RecoverConfig to middleware or returns an error for invalid configuration +func (config RecoverConfig) ToMiddleware() (echo.MiddlewareFunc, error) { // Defaults if config.Skipper == nil { config.Skipper = DefaultRecoverConfig.Skipper @@ -77,7 +60,7 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (returnErr error) { + return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } @@ -87,47 +70,19 @@ func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc { if r == http.ErrAbortHandler { panic(r) } - err, ok := r.(error) + tmpErr, ok := r.(error) if !ok { - err = fmt.Errorf("%v", r) + tmpErr = fmt.Errorf("%v", r) } - var stack []byte - var length int - if !config.DisablePrintStack { - stack = make([]byte, config.StackSize) - length = runtime.Stack(stack, !config.DisableStackAll) - stack = stack[:length] - } - - if config.LogErrorFunc != nil { - err = config.LogErrorFunc(c, err, stack) - } else if !config.DisablePrintStack { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack[:length]) - switch config.LogLevel { - case log.DEBUG: - c.Logger().Debug(msg) - case log.INFO: - c.Logger().Info(msg) - case log.WARN: - c.Logger().Warn(msg) - case log.ERROR: - c.Logger().Error(msg) - case log.OFF: - // None. - default: - c.Logger().Print(msg) - } - } - - if err != nil && !config.DisableErrorHandler { - c.Error(err) - } else { - returnErr = err + stack := make([]byte, config.StackSize) + length := runtime.Stack(stack, !config.DisableStackAll) + tmpErr = fmt.Errorf("[PANIC RECOVER] %w %s", tmpErr, stack[:length]) } + err = tmpErr } }() return next(c) } - } + }, nil } diff --git a/middleware/recover_test.go b/middleware/recover_test.go index 8fa34fa5c..bf0d16531 100644 --- a/middleware/recover_test.go +++ b/middleware/recover_test.go @@ -5,43 +5,64 @@ package middleware import ( "bytes" - "errors" - "fmt" + "log/slog" "net/http" "net/http/httptest" "testing" - "github.com/labstack/echo/v4" - "github.com/labstack/gommon/log" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) func TestRecover(t *testing.T) { e := echo.New() buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + e.Logger = slog.New(&discardHandler{}) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c *echo.Context) error { panic("test") - })) + }) err := h(c) + assert.Contains(t, err.Error(), "[PANIC RECOVER] test goroutine") + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain + assert.Contains(t, buf.String(), "") // nothing is logged +} + +func TestRecover_skipper(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + config := RecoverConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + } + h := RecoverWithConfig(config)(func(c *echo.Context) error { + panic("testPANIC") + }) + + var err error + assert.Panics(t, func() { + err = h(c) + }) + assert.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain } func TestRecoverErrAbortHandler(t *testing.T) { e := echo.New() - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - h := Recover()(echo.HandlerFunc(func(c echo.Context) error { + h := Recover()(func(c *echo.Context) error { panic(http.ErrAbortHandler) - })) + }) defer func() { r := recover() if r == nil { @@ -55,135 +76,66 @@ func TestRecoverErrAbortHandler(t *testing.T) { } }() - h(c) + hErr := h(c) assert.Equal(t, http.StatusInternalServerError, rec.Code) - assert.NotContains(t, buf.String(), "PANIC RECOVER") + assert.NotContains(t, hErr.Error(), "PANIC RECOVER") } -func TestRecoverWithConfig_LogLevel(t *testing.T) { - tests := []struct { - logLevel log.Lvl - levelName string - }{{ - logLevel: log.DEBUG, - levelName: "DEBUG", - }, { - logLevel: log.INFO, - levelName: "INFO", - }, { - logLevel: log.WARN, - levelName: "WARN", - }, { - logLevel: log.ERROR, - levelName: "ERROR", - }, { - logLevel: log.OFF, - levelName: "OFF", - }} - - for _, tt := range tests { - tt := tt - t.Run(tt.levelName, func(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) +func TestRecoverWithConfig(t *testing.T) { + var testCases = []struct { + name string + givenNoPanic bool + whenConfig RecoverConfig + expectErrContain string + expectErr string + }{ + { + name: "ok, default config", + whenConfig: DefaultRecoverConfig, + expectErrContain: "[PANIC RECOVER] testPANIC goroutine", + }, + { + name: "ok, no panic", + givenNoPanic: true, + whenConfig: DefaultRecoverConfig, + expectErrContain: "", + }, + { + name: "ok, DisablePrintStack", + whenConfig: RecoverConfig{ + DisablePrintStack: true, + }, + expectErr: "testPANIC", + }, + } - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - config := DefaultRecoverConfig - config.LogLevel = tt.logLevel - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) - - h(c) + config := tc.whenConfig + h := RecoverWithConfig(config)(func(c *echo.Context) error { + if tc.givenNoPanic { + return nil + } + panic("testPANIC") + }) - assert.Equal(t, http.StatusInternalServerError, rec.Code) + err := h(c) - output := buf.String() - if tt.logLevel == log.OFF { - assert.Empty(t, output) + if tc.expectErrContain != "" { + assert.Contains(t, err.Error(), tc.expectErrContain) + } else if tc.expectErr != "" { + assert.Contains(t, err.Error(), tc.expectErr) } else { - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, fmt.Sprintf(`"level":"%s"`, tt.levelName)) + assert.NoError(t, err) } + assert.Equal(t, http.StatusOK, rec.Code) // status is still untouched. err is returned from middleware chain }) } } - -func TestRecoverWithConfig_LogErrorFunc(t *testing.T) { - e := echo.New() - e.Logger.SetLevel(log.DEBUG) - - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - testError := errors.New("test") - config := DefaultRecoverConfig - config.LogErrorFunc = func(c echo.Context, err error, stack []byte) error { - msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack) - if errors.Is(err, testError) { - c.Logger().Debug(msg) - } else { - c.Logger().Error(msg) - } - return err - } - - t.Run("first branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic(testError) - })) - - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"DEBUG"`) - }) - - t.Run("else branch case for LogErrorFunc", func(t *testing.T) { - buf.Reset() - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("other") - })) - - h(c) - assert.Equal(t, http.StatusInternalServerError, rec.Code) - - output := buf.String() - assert.Contains(t, output, "PANIC RECOVER") - assert.Contains(t, output, `"level":"ERROR"`) - }) -} - -func TestRecoverWithDisabled_ErrorHandler(t *testing.T) { - e := echo.New() - buf := new(bytes.Buffer) - e.Logger.SetOutput(buf) - req := httptest.NewRequest(http.MethodGet, "/", nil) - rec := httptest.NewRecorder() - c := e.NewContext(req, rec) - - config := DefaultRecoverConfig - config.DisableErrorHandler = true - h := RecoverWithConfig(config)(echo.HandlerFunc(func(c echo.Context) error { - panic("test") - })) - err := h(c) - - assert.Equal(t, http.StatusOK, rec.Code) - assert.Contains(t, buf.String(), "PANIC RECOVER") - assert.EqualError(t, err, "test") -} diff --git a/middleware/redirect.go b/middleware/redirect.go index b772ac131..bb7045cfe 100644 --- a/middleware/redirect.go +++ b/middleware/redirect.go @@ -4,10 +4,11 @@ package middleware import ( + "errors" "net/http" "strings" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RedirectConfig defines the config for Redirect middleware. @@ -17,7 +18,9 @@ type RedirectConfig struct { // Status code to be used when redirecting the request. // Optional. Default value http.StatusMovedPermanently. - Code int `yaml:"code"` + Code int + + redirect redirectLogic } // redirectLogic represents a function that given a scheme, host and uri @@ -27,29 +30,33 @@ type redirectLogic func(scheme, host, uri string) (ok bool, url string) const www = "www." -// DefaultRedirectConfig is the default Redirect middleware config. -var DefaultRedirectConfig = RedirectConfig{ - Skipper: DefaultSkipper, - Code: http.StatusMovedPermanently, -} +// RedirectHTTPSConfig is the HTTPS Redirect middleware config. +var RedirectHTTPSConfig = RedirectConfig{redirect: redirectHTTPS} + +// RedirectHTTPSWWWConfig is the HTTPS WWW Redirect middleware config. +var RedirectHTTPSWWWConfig = RedirectConfig{redirect: redirectHTTPSWWW} + +// RedirectNonHTTPSWWWConfig is the non HTTPS WWW Redirect middleware config. +var RedirectNonHTTPSWWWConfig = RedirectConfig{redirect: redirectNonHTTPSWWW} + +// RedirectWWWConfig is the WWW Redirect middleware config. +var RedirectWWWConfig = RedirectConfig{redirect: redirectWWW} + +// RedirectNonWWWConfig is the non WWW Redirect middleware config. +var RedirectNonWWWConfig = RedirectConfig{redirect: redirectNonWWW} // HTTPSRedirect redirects http requests to https. // For example, http://labstack.com will be redirect to https://labstack.com. // // Usage `Echo#Pre(HTTPSRedirect())` func HTTPSRedirect() echo.MiddlewareFunc { - return HTTPSRedirectWithConfig(DefaultRedirectConfig) + return HTTPSRedirectWithConfig(RedirectHTTPSConfig) } -// HTTPSRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSRedirect()`. +// HTTPSRedirectWithConfig returns a HTTPS redirect middleware with config or panics on invalid configuration. func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" { - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPS + return toMiddlewareOrPanic(config) } // HTTPSWWWRedirect redirects http requests to https www. @@ -57,18 +64,13 @@ func HTTPSRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSWWWRedirect())` func HTTPSWWWRedirect() echo.MiddlewareFunc { - return HTTPSWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSWWWRedirectWithConfig(RedirectHTTPSWWWConfig) } -// HTTPSWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSWWWRedirect()`. +// HTTPSWWWRedirectWithConfig returns a HTTPS WWW redirect middleware with config or panics on invalid configuration. func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if scheme != "https" && !strings.HasPrefix(host, www) { - return true, "https://www." + host + uri - } - return false, "" - }) + config.redirect = redirectHTTPSWWW + return toMiddlewareOrPanic(config) } // HTTPSNonWWWRedirect redirects http requests to https non www. @@ -76,19 +78,13 @@ func HTTPSWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(HTTPSNonWWWRedirect())` func HTTPSNonWWWRedirect() echo.MiddlewareFunc { - return HTTPSNonWWWRedirectWithConfig(DefaultRedirectConfig) + return HTTPSNonWWWRedirectWithConfig(RedirectNonHTTPSWWWConfig) } -// HTTPSNonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `HTTPSNonWWWRedirect()`. +// HTTPSNonWWWRedirectWithConfig returns a HTTPS Non-WWW redirect middleware with config or panics on invalid configuration. func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (ok bool, url string) { - if scheme != "https" { - host = strings.TrimPrefix(host, www) - return true, "https://" + host + uri - } - return false, "" - }) + config.redirect = redirectNonHTTPSWWW + return toMiddlewareOrPanic(config) } // WWWRedirect redirects non www requests to www. @@ -96,18 +92,13 @@ func HTTPSNonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(WWWRedirect())` func WWWRedirect() echo.MiddlewareFunc { - return WWWRedirectWithConfig(DefaultRedirectConfig) + return WWWRedirectWithConfig(RedirectWWWConfig) } -// WWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `WWWRedirect()`. +// WWWRedirectWithConfig returns a WWW redirect middleware with config or panics on invalid configuration. func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if !strings.HasPrefix(host, www) { - return true, scheme + "://www." + host + uri - } - return false, "" - }) + config.redirect = redirectWWW + return toMiddlewareOrPanic(config) } // NonWWWRedirect redirects www requests to non www. @@ -115,41 +106,79 @@ func WWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { // // Usage `Echo#Pre(NonWWWRedirect())` func NonWWWRedirect() echo.MiddlewareFunc { - return NonWWWRedirectWithConfig(DefaultRedirectConfig) + return NonWWWRedirectWithConfig(RedirectNonWWWConfig) } -// NonWWWRedirectWithConfig returns an HTTPSRedirect middleware with config. -// See `NonWWWRedirect()`. +// NonWWWRedirectWithConfig returns a Non-WWW redirect middleware with config or panics on invalid configuration. func NonWWWRedirectWithConfig(config RedirectConfig) echo.MiddlewareFunc { - return redirect(config, func(scheme, host, uri string) (bool, string) { - if strings.HasPrefix(host, www) { - return true, scheme + "://" + host[4:] + uri - } - return false, "" - }) + config.redirect = redirectNonWWW + return toMiddlewareOrPanic(config) } -func redirect(config RedirectConfig, cb redirectLogic) echo.MiddlewareFunc { +// ToMiddleware converts RedirectConfig to middleware or returns an error for invalid configuration +func (config RedirectConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRedirectConfig.Skipper + config.Skipper = DefaultSkipper } if config.Code == 0 { - config.Code = DefaultRedirectConfig.Code + config.Code = http.StatusMovedPermanently + } + if config.redirect == nil { + return nil, errors.New("redirectConfig is missing redirect function") } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } req, scheme := c.Request(), c.Scheme() host := req.Host - if ok, url := cb(scheme, host, req.RequestURI); ok { + if ok, url := config.redirect(scheme, host, req.RequestURI); ok { return c.Redirect(config.Code, url) } return next(c) } + }, nil +} + +var redirectHTTPS = func(scheme, host, uri string) (bool, string) { + if scheme != "https" { + return true, "https://" + host + uri + } + return false, "" +} + +var redirectHTTPSWWW = func(scheme, host, uri string) (bool, string) { + // Redirect if not HTTPS OR missing www prefix (needs either fix) + if scheme != "https" || !strings.HasPrefix(host, www) { + host = strings.TrimPrefix(host, www) // Remove www if present to avoid duplication + return true, "https://www." + host + uri + } + return false, "" +} + +var redirectNonHTTPSWWW = func(scheme, host, uri string) (ok bool, url string) { + // Redirect if not HTTPS OR has www prefix (needs either fix) + if scheme != "https" || strings.HasPrefix(host, www) { + host = strings.TrimPrefix(host, www) + return true, "https://" + host + uri + } + return false, "" +} + +var redirectWWW = func(scheme, host, uri string) (bool, string) { + if !strings.HasPrefix(host, www) { + return true, scheme + "://www." + host + uri + } + return false, "" +} + +var redirectNonWWW = func(scheme, host, uri string) (bool, string) { + if strings.HasPrefix(host, www) { + return true, scheme + "://" + host[4:] + uri } + return false, "" } diff --git a/middleware/redirect_test.go b/middleware/redirect_test.go index 88068ea2e..a127ca40c 100644 --- a/middleware/redirect_test.go +++ b/middleware/redirect_test.go @@ -8,7 +8,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -58,8 +58,8 @@ func TestRedirectHTTPSWWWRedirect(t *testing.T) { }, { whenHost: "www.labstack.com", - expectLocation: "", - expectStatusCode: http.StatusOK, + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, }, { whenHost: "a.com", @@ -74,6 +74,12 @@ func TestRedirectHTTPSWWWRedirect(t *testing.T) { { whenHost: "labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://www.labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "www.labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, @@ -114,6 +120,12 @@ func TestRedirectHTTPSNonWWWRedirect(t *testing.T) { { whenHost: "www.labstack.com", whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, + expectLocation: "https://labstack.com/", + expectStatusCode: http.StatusMovedPermanently, + }, + { + whenHost: "labstack.com", + whenHeader: map[string][]string{echo.HeaderXForwardedProto: {"https"}}, expectLocation: "", expectStatusCode: http.StatusOK, }, @@ -218,7 +230,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) { var testCases = []struct { name string givenCode int - givenSkipFunc func(c echo.Context) bool + givenSkipFunc func(c *echo.Context) bool whenHost string whenHeader http.Header expectLocation string @@ -232,7 +244,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) { }, { name: "redirect is skipped", - givenSkipFunc: func(c echo.Context) bool { + givenSkipFunc: func(c *echo.Context) bool { return true // skip always }, whenHost: "www.labstack.com", @@ -266,7 +278,7 @@ func TestNonWWWRedirectWithConfig(t *testing.T) { func redirectTest(fn middlewareGenerator, host string, header http.Header) *httptest.ResponseRecorder { e := echo.New() - next := func(c echo.Context) (err error) { + next := func(c *echo.Context) (err error) { return c.NoContent(http.StatusOK) } req := httptest.NewRequest(http.MethodGet, "/", nil) diff --git a/middleware/request_id.go b/middleware/request_id.go index 14bd4fd15..b3de40d19 100644 --- a/middleware/request_id.go +++ b/middleware/request_id.go @@ -4,7 +4,7 @@ package middleware import ( - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RequestIDConfig defines the config for RequestID middleware. @@ -13,43 +13,45 @@ type RequestIDConfig struct { Skipper Skipper // Generator defines a function to generate an ID. - // Optional. Defaults to generator for random string of length 32. + // Optional. Default value random.String(32). Generator func() string // RequestIDHandler defines a function which is executed for a request id. - RequestIDHandler func(echo.Context, string) + RequestIDHandler func(c *echo.Context, requestID string) - // TargetHeader defines what header to look for to populate the id + // TargetHeader defines what header to look for to populate the id. + // Optional. Default value is `X-Request-Id` TargetHeader string } -// DefaultRequestIDConfig is the default RequestID middleware config. -var DefaultRequestIDConfig = RequestIDConfig{ - Skipper: DefaultSkipper, - Generator: generator, - TargetHeader: echo.HeaderXRequestID, -} - -// RequestID returns a X-Request-ID middleware. +// RequestID returns a middleware that reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when +// the header value is empty, generates that value and sets request ID to response +// as RequestIDConfig.TargetHeader (`X-Request-Id`) value. func RequestID() echo.MiddlewareFunc { - return RequestIDWithConfig(DefaultRequestIDConfig) + return RequestIDWithConfig(RequestIDConfig{}) } -// RequestIDWithConfig returns a X-Request-ID middleware with config. +// RequestIDWithConfig returns a middleware with given valid config or panics on invalid configuration. +// The middleware reads RequestIDConfig.TargetHeader (`X-Request-ID`) header value or when the header value is empty, +// generates that value and sets request ID to response as RequestIDConfig.TargetHeader (`X-Request-Id`) value. func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts RequestIDConfig to middleware or returns an error for invalid configuration +func (config RequestIDConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultRequestIDConfig.Skipper + config.Skipper = DefaultSkipper } if config.Generator == nil { - config.Generator = generator + config.Generator = createRandomStringGenerator(32) } if config.TargetHeader == "" { config.TargetHeader = echo.HeaderXRequestID } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -67,9 +69,5 @@ func RequestIDWithConfig(config RequestIDConfig) echo.MiddlewareFunc { return next(c) } - } -} - -func generator() string { - return randomString(32) + }, nil } diff --git a/middleware/request_id_test.go b/middleware/request_id_test.go index 4e68b126a..465e6fc42 100644 --- a/middleware/request_id_test.go +++ b/middleware/request_id_test.go @@ -8,7 +8,7 @@ import ( "net/http/httptest" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -17,29 +17,108 @@ func TestRequestID(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } - rid := RequestIDWithConfig(RequestIDConfig{}) + rid := RequestID() + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) +} + +func TestMustRequestIDWithConfig_skipper(t *testing.T) { + e := echo.New() + e.GET("/", func(c *echo.Context) error { + return c.String(http.StatusTeapot, "test") + }) + + generatorCalled := false + e.Use(RequestIDWithConfig(RequestIDConfig{ + Skipper: func(c *echo.Context) bool { + return true + }, + Generator: func() string { + generatorCalled = true + return "customGenerator" + }, + })) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + res := httptest.NewRecorder() + e.ServeHTTP(res, req) + + assert.Equal(t, http.StatusTeapot, res.Code) + assert.Equal(t, "test", res.Body.String()) + + assert.Equal(t, res.Header().Get(echo.HeaderXRequestID), "") + assert.False(t, generatorCalled) +} + +func TestMustRequestIDWithConfig_customGenerator(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") +} + +func TestMustRequestIDWithConfig_RequestIDHandler(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + called := false + rid := RequestIDWithConfig(RequestIDConfig{ + Generator: func() string { return "customGenerator" }, + RequestIDHandler: func(c *echo.Context, s string) { + called = true + }, + }) + h := rid(handler) + err := h(c) + assert.NoError(t, err) + assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") + assert.True(t, called) +} + +func TestRequestIDWithConfig(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := func(c *echo.Context) error { + return c.String(http.StatusOK, "test") + } + + rid, err := RequestIDConfig{}.ToMiddleware() + assert.NoError(t, err) h := rid(handler) h(c) assert.Len(t, rec.Header().Get(echo.HeaderXRequestID), 32) - // Custom generator and handler - customID := "customGenerator" - calledHandler := false + // Custom generator rid = RequestIDWithConfig(RequestIDConfig{ - Generator: func() string { return customID }, - RequestIDHandler: func(_ echo.Context, id string) { - calledHandler = true - assert.Equal(t, customID, id) - }, + Generator: func() string { return "customGenerator" }, }) h = rid(handler) h(c) assert.Equal(t, rec.Header().Get(echo.HeaderXRequestID), "customGenerator") - assert.True(t, calledHandler) } func TestRequestID_IDNotAltered(t *testing.T) { @@ -49,7 +128,7 @@ func TestRequestID_IDNotAltered(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } @@ -64,7 +143,7 @@ func TestRequestIDConfigDifferentHeader(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) - handler := func(c echo.Context) error { + handler := func(c *echo.Context) error { return c.String(http.StatusOK, "test") } @@ -79,7 +158,7 @@ func TestRequestIDConfigDifferentHeader(t *testing.T) { rid = RequestIDWithConfig(RequestIDConfig{ Generator: func() string { return customID }, TargetHeader: echo.HeaderXCorrelationID, - RequestIDHandler: func(_ echo.Context, id string) { + RequestIDHandler: func(_ *echo.Context, id string) { calledHandler = true assert.Equal(t, customID, id) }, diff --git a/middleware/request_logger.go b/middleware/request_logger.go index 211abf464..76903c62a 100644 --- a/middleware/request_logger.go +++ b/middleware/request_logger.go @@ -10,7 +10,7 @@ import ( "net/http" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // Example for `slog` https://pkg.go.dev/log/slog @@ -18,9 +18,8 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, // LogURI: true, -// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", // slog.String("uri", v.URI), @@ -41,9 +40,8 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogStatus: true, // LogURI: true, -// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // fmt.Printf("REQUEST: uri: %v, status: %v\n", v.URI, v.Status) // } else { @@ -58,9 +56,8 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, -// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.Info(). // Str("URI", v.URI). @@ -82,9 +79,8 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, -// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // logger.Info("request", // zap.String("URI", v.URI), @@ -106,9 +102,8 @@ import ( // e.Use(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ // LogURI: true, // LogStatus: true, -// LogError: true, // HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code -// LogValuesFunc: func(c echo.Context, v middleware.RequestLoggerValues) error { +// LogValuesFunc: func(c *echo.Context, v middleware.RequestLoggerValues) error { // if v.Error == nil { // log.WithFields(logrus.Fields{ // "URI": v.URI, @@ -131,10 +126,10 @@ type RequestLoggerConfig struct { Skipper Skipper // BeforeNextFunc defines a function that is called before next middleware or handler is called in chain. - BeforeNextFunc func(c echo.Context) + BeforeNextFunc func(c *echo.Context) // LogValuesFunc defines a function that is called with values extracted by logger from request/response. // Mandatory. - LogValuesFunc func(c echo.Context, v RequestLoggerValues) error + LogValuesFunc func(c *echo.Context, v RequestLoggerValues) error // HandleError instructs logger to call global error handler when next middleware/handler returns an error. // This is useful when you have custom error handler that can decide to use different status codes. @@ -168,8 +163,6 @@ type RequestLoggerConfig struct { // LogStatus instructs logger to extract response status code. If handler chain returns an echo.HTTPError, // the status code is extracted from the echo.HTTPError returned LogStatus bool - // LogError instructs logger to extract error returned from executed handler chain. - LogError bool // LogContentLength instructs logger to extract content length header value. Note: this value could be different from // actual request body size as it could be spoofed etc. LogContentLength bool @@ -228,15 +221,15 @@ type RequestLoggerValues struct { // ResponseSize is response content length value. Note: when used with Gzip middleware this value may not be always correct. ResponseSize int64 // Headers are list of headers from request. Note: request can contain more than one header with same value so slice - // of values is been logger for each given header. + // of values is what will be returned/logged for each given header. // Note: header values are converted to canonical form with http.CanonicalHeaderKey as this how request parser converts header // names to. For example, the canonical key for "accept-encoding" is "Accept-Encoding". Headers map[string][]string // QueryParams are list of query parameters from request URI. Note: request can contain more than one query parameter - // with same name so slice of values is been logger for each given query param name. + // with same name so slice of values is what will be returned/logged for each given query param name. QueryParams map[string][]string // FormValues are list of form values from request body+URI. Note: request can contain more than one form value with - // same name so slice of values is been logger for each given form value name. + // same name so slice of values is what will be returned/logged for each given form value name. FormValues map[string][]string } @@ -249,72 +242,6 @@ func RequestLoggerWithConfig(config RequestLoggerConfig) echo.MiddlewareFunc { return mw } -// RequestLogger returns a RequestLogger middleware with default configuration which -// uses default slog.slog logger. -// -// To customize slog output format replace slog default logger: -// For JSON format: `slog.SetDefault(slog.New(slog.NewJSONHandler(os.Stdout, nil)))` -func RequestLogger() echo.MiddlewareFunc { - config := RequestLoggerConfig{ - LogLatency: true, - LogProtocol: false, - LogRemoteIP: true, - LogHost: true, - LogMethod: true, - LogURI: true, - LogURIPath: false, - LogRoutePath: false, - LogRequestID: true, - LogReferer: false, - LogUserAgent: true, - LogStatus: true, - LogError: true, - LogContentLength: true, - LogResponseSize: true, - LogHeaders: nil, - LogQueryParams: nil, - LogFormValues: nil, - HandleError: true, // forwards error to the global error handler, so it can decide appropriate status code - LogValuesFunc: func(c echo.Context, v RequestLoggerValues) error { - if v.Error == nil { - slog.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", - slog.String("method", v.Method), - slog.String("uri", v.URI), - slog.Int("status", v.Status), - slog.Duration("latency", v.Latency), - slog.String("host", v.Host), - slog.String("bytes_in", v.ContentLength), - slog.Int64("bytes_out", v.ResponseSize), - slog.String("user_agent", v.UserAgent), - slog.String("remote_ip", v.RemoteIP), - slog.String("request_id", v.RequestID), - ) - } else { - slog.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", - slog.String("method", v.Method), - slog.String("uri", v.URI), - slog.Int("status", v.Status), - slog.Duration("latency", v.Latency), - slog.String("host", v.Host), - slog.String("bytes_in", v.ContentLength), - slog.Int64("bytes_out", v.ResponseSize), - slog.String("user_agent", v.UserAgent), - slog.String("remote_ip", v.RemoteIP), - slog.String("request_id", v.RequestID), - - slog.String("error", v.Error.Error()), - ) - } - return nil - }, - } - mw, err := config.ToMiddleware() - if err != nil { - panic(err) - } - return mw -} - // ToMiddleware converts RequestLoggerConfig into middleware or returns an error for invalid configuration. func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { @@ -339,7 +266,7 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { logFormValues := len(config.LogFormValues) > 0 return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { + return func(c *echo.Context) error { if config.Skipper(c) { return next(c) } @@ -353,7 +280,9 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { } err := next(c) if err != nil && config.HandleError { - c.Error(err) + // When global error handler writes the error to the client the Response gets "committed". This state can be + // checked with `c.Response().Committed` field. + c.Echo().HTTPErrorHandler(c, err) } v := RequestLoggerValues{ @@ -400,25 +329,41 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.LogUserAgent { v.UserAgent = req.UserAgent() } + + var resp *echo.Response + if config.LogStatus || config.LogResponseSize { + if r, err := echo.UnwrapResponse(res); err != nil { + c.Logger().Error("can not determine response status and/or size. ResponseWriter in context does not implement unwrapper interface") + } else { + resp = r + } + } + if config.LogStatus { - v.Status = res.Status + v.Status = -1 + if resp != nil { + v.Status = resp.Status + } if err != nil && !config.HandleError { // this block should not be executed in case of HandleError=true as the global error handler will decide // the status code. In that case status code could be different from what err contains. - var httpErr *echo.HTTPError - if errors.As(err, &httpErr) { - v.Status = httpErr.Code + var hsc echo.HTTPStatusCoder + if errors.As(err, &hsc) { + v.Status = hsc.StatusCode() } } } - if config.LogError && err != nil { + if err != nil { v.Error = err } if config.LogContentLength { v.ContentLength = req.Header.Get(echo.HeaderContentLength) } if config.LogResponseSize { - v.ResponseSize = res.Size + v.ResponseSize = -1 + if resp != nil { + v.ResponseSize = resp.Size + } } if logHeaders { v.Headers = map[string][]string{} @@ -449,11 +394,69 @@ func (config RequestLoggerConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if errOnLog := config.LogValuesFunc(c, v); errOnLog != nil { return errOnLog } - // in case of HandleError=true we are returning the error that we already have handled with global error handler // this is deliberate as this error could be useful for upstream middlewares and default global error handler // will ignore that error when it bubbles up in middleware chain. + // Committed response can be checked in custom error handler with following logic + // + // if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { + // return + // } return err } }, nil } + +// RequestLogger creates Request Logger middleware with Echo default settings that uses Context.Logger() as logger. +func RequestLogger() echo.MiddlewareFunc { + return RequestLoggerWithConfig(RequestLoggerConfig{ + LogLatency: true, + LogRemoteIP: true, + LogHost: true, + LogMethod: true, + LogURI: true, + LogRequestID: true, + LogUserAgent: true, + LogStatus: true, + LogContentLength: true, + LogResponseSize: true, + // forwards error to the global error handler, so it can decide appropriate status code. + // NB: side-effect of that is - request is now "commited" written to the client. Middlewares up in chain can not + // change Response status code or response body. + HandleError: true, + LogValuesFunc: func(c *echo.Context, v RequestLoggerValues) error { + logger := c.Logger() + if v.Error == nil { + logger.LogAttrs(context.Background(), slog.LevelInfo, "REQUEST", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + ) + return nil + } + + logger.LogAttrs(context.Background(), slog.LevelError, "REQUEST_ERROR", + slog.String("method", v.Method), + slog.String("uri", v.URI), + slog.Int("status", v.Status), + slog.Duration("latency", v.Latency), + slog.String("host", v.Host), + slog.String("bytes_in", v.ContentLength), + slog.Int64("bytes_out", v.ResponseSize), + slog.String("user_agent", v.UserAgent), + slog.String("remote_ip", v.RemoteIP), + slog.String("request_id", v.RequestID), + + slog.String("error", v.Error.Error()), + ) + return nil + }, + }) +} diff --git a/middleware/request_logger_test.go b/middleware/request_logger_test.go index 510d34edd..af39eb32a 100644 --- a/middleware/request_logger_test.go +++ b/middleware/request_logger_test.go @@ -16,7 +16,7 @@ import ( "testing" "time" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -26,13 +26,12 @@ func TestRequestLoggerOK(t *testing.T) { slog.SetDefault(old) }) - buf := new(bytes.Buffer) - slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil))) - e := echo.New() + buf := new(bytes.Buffer) + e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) e.Use(RequestLogger()) - e.POST("/test", func(c echo.Context) error { + e.POST("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -76,13 +75,12 @@ func TestRequestLoggerError(t *testing.T) { slog.SetDefault(old) }) - buf := new(bytes.Buffer) - slog.SetDefault(slog.New(slog.NewJSONHandler(buf, nil))) - e := echo.New() + buf := new(bytes.Buffer) + e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) e.Use(RequestLogger()) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return errors.New("nope") }) req := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -121,13 +119,13 @@ func TestRequestLoggerWithConfig(t *testing.T) { e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogRoutePath: true, LogURI: true, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, })) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -153,16 +151,16 @@ func TestRequestLogger_skipper(t *testing.T) { loggerCalled := false e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ - Skipper: func(c echo.Context) bool { + Skipper: func(c *echo.Context) bool { return true }, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { loggerCalled = true return nil }, })) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -180,16 +178,16 @@ func TestRequestLogger_beforeNextFunc(t *testing.T) { var myLoggerInstance int e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ - BeforeNextFunc: func(c echo.Context) { + BeforeNextFunc: func(c *echo.Context) { c.Set("myLoggerInstance", 42) }, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { myLoggerInstance = c.Get("myLoggerInstance").(int) return nil }, })) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -207,15 +205,14 @@ func TestRequestLogger_logError(t *testing.T) { var actual RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ - LogError: true, LogStatus: true, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { actual = values return nil }, })) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return echo.NewHTTPError(http.StatusNotAcceptable, "nope") }) @@ -238,23 +235,22 @@ func TestRequestLogger_HandleError(t *testing.T) { return time.Unix(1631045377, 0).UTC() }, HandleError: true, - LogError: true, LogStatus: true, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { actual = values return nil }, })) // to see if "HandleError" works we create custom error handler that uses its own status codes - e.HTTPErrorHandler = func(err error, c echo.Context) { - if c.Response().Committed { + e.HTTPErrorHandler = func(c *echo.Context, err error) { + if r, _ := echo.UnwrapResponse(c.Response()); r != nil && r.Committed { return } c.JSON(http.StatusTeapot, "custom error handler") } - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return echo.NewHTTPError(http.StatusForbidden, "nope") }) @@ -278,15 +274,14 @@ func TestRequestLogger_LogValuesFuncError(t *testing.T) { var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ - LogError: true, LogStatus: true, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return echo.NewHTTPError(http.StatusNotAcceptable, "LogValuesFuncError") }, })) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { return c.String(http.StatusTeapot, "OK") }) @@ -327,13 +322,13 @@ func TestRequestLogger_ID(t *testing.T) { var expect RequestLoggerValues e.Use(RequestLoggerWithConfig(RequestLoggerConfig{ LogRequestID: true, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, })) - e.GET("/test", func(c echo.Context) error { + e.GET("/test", func(c *echo.Context) error { c.Response().Header().Set(echo.HeaderXRequestID, "321") return c.String(http.StatusTeapot, "OK") }) @@ -357,12 +352,12 @@ func TestRequestLogger_headerIsCaseInsensitive(t *testing.T) { var expect RequestLoggerValues mw := RequestLoggerWithConfig(RequestLoggerConfig{ - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, LogHeaders: []string{"referer", "User-Agent"}, - })(func(c echo.Context) error { + })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") @@ -387,7 +382,7 @@ func TestRequestLogger_allFields(t *testing.T) { isFirstNowCall := true var expect RequestLoggerValues mw := RequestLoggerWithConfig(RequestLoggerConfig{ - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { expect = values return nil }, @@ -403,7 +398,6 @@ func TestRequestLogger_allFields(t *testing.T) { LogReferer: true, LogUserAgent: true, LogStatus: true, - LogError: true, LogContentLength: true, LogResponseSize: true, LogHeaders: []string{"accept-encoding", "User-Agent"}, @@ -416,7 +410,7 @@ func TestRequestLogger_allFields(t *testing.T) { } return time.Unix(1631045377+10, 0) }, - })(func(c echo.Context) error { + })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") @@ -471,12 +465,86 @@ func TestRequestLogger_allFields(t *testing.T) { assert.Equal(t, []string{"1", "2"}, expect.FormValues["multiple"]) } +func TestTestRequestLogger(t *testing.T) { + var testCases = []struct { + name string + whenStatus int + whenError error + expectStatus string + expectError string + }{ + { + name: "ok", + whenStatus: http.StatusTeapot, + expectStatus: "418", + }, + { + name: "error", + whenError: echo.NewHTTPError(http.StatusBadGateway, "bad gw"), + expectStatus: "502", + expectError: `"error":"code=502, message=bad gw"`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + buf := new(bytes.Buffer) + e.Logger = slog.New(slog.NewJSONHandler(buf, nil)) + + e.Use(RequestLogger()) + e.POST("/test", func(c *echo.Context) error { + if tc.whenError != nil { + return tc.whenError + } + return c.String(tc.whenStatus, "OK") + }) + + f := make(url.Values) + f.Set("csrf", "token") + f.Set("multiple", "1") + f.Add("multiple", "2") + reader := strings.NewReader(f.Encode()) + req := httptest.NewRequest(http.MethodPost, "/test?lang=en&checked=1&checked=2", reader) + req.Header.Set("Referer", "https://echo.labstack.com/") + req.Header.Set("User-Agent", "curl/7.68.0") + req.Header.Set(echo.HeaderContentLength, strconv.Itoa(int(reader.Size()))) + req.Header.Set(echo.HeaderContentType, echo.MIMEApplicationForm) + req.Header.Set(echo.HeaderXRealIP, "8.8.8.8") + req.Header.Set(echo.HeaderXRequestID, "MY_ID") + + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + rawlog := buf.Bytes() + if tc.expectError != "" { + assert.Contains(t, string(rawlog), `"level":"ERROR"`) + assert.Contains(t, string(rawlog), `"msg":"REQUEST_ERROR"`) + assert.Contains(t, string(rawlog), tc.expectError) + } else { + assert.Contains(t, string(rawlog), `"level":"INFO"`) + assert.Contains(t, string(rawlog), `"msg":"REQUEST"`) + } + assert.Contains(t, string(rawlog), `"status":`+tc.expectStatus) + assert.Contains(t, string(rawlog), `"method":"POST"`) + assert.Contains(t, string(rawlog), `"uri":"/test?lang=en&checked=1&checked=2"`) + assert.Contains(t, string(rawlog), `"latency":`) // this value varies + assert.Contains(t, string(rawlog), `"request_id":"MY_ID"`) + assert.Contains(t, string(rawlog), `"remote_ip":"8.8.8.8"`) + assert.Contains(t, string(rawlog), `"host":"example.com"`) + assert.Contains(t, string(rawlog), `"user_agent":"curl/7.68.0"`) + assert.Contains(t, string(rawlog), `"bytes_in":"32"`) + assert.Contains(t, string(rawlog), `"bytes_out":2`) + }) + } +} + func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { e := echo.New() mw := RequestLoggerWithConfig(RequestLoggerConfig{ Skipper: nil, - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { return nil }, LogLatency: true, @@ -491,10 +559,9 @@ func BenchmarkRequestLogger_withoutMapFields(b *testing.B) { LogReferer: true, LogUserAgent: true, LogStatus: true, - LogError: true, LogContentLength: true, LogResponseSize: true, - })(func(c echo.Context) error { + })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") return c.String(http.StatusTeapot, "OK") }) @@ -517,7 +584,7 @@ func BenchmarkRequestLogger_withMapFields(b *testing.B) { e := echo.New() mw := RequestLoggerWithConfig(RequestLoggerConfig{ - LogValuesFunc: func(c echo.Context, values RequestLoggerValues) error { + LogValuesFunc: func(c *echo.Context, values RequestLoggerValues) error { return nil }, LogLatency: true, @@ -532,13 +599,12 @@ func BenchmarkRequestLogger_withMapFields(b *testing.B) { LogReferer: true, LogUserAgent: true, LogStatus: true, - LogError: true, LogContentLength: true, LogResponseSize: true, LogHeaders: []string{"accept-encoding", "User-Agent"}, LogQueryParams: []string{"lang", "checked"}, LogFormValues: []string{"csrf", "multiple"}, - })(func(c echo.Context) error { + })(func(c *echo.Context) error { c.Request().Header.Set(echo.HeaderXRequestID, "123") c.FormValue("to force parse form") return c.String(http.StatusTeapot, "OK") diff --git a/middleware/rewrite.go b/middleware/rewrite.go index 4c19cc1cc..ea58091b0 100644 --- a/middleware/rewrite.go +++ b/middleware/rewrite.go @@ -4,9 +4,10 @@ package middleware import ( + "errors" "regexp" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // RewriteConfig defines the config for Rewrite middleware. @@ -22,40 +23,39 @@ type RewriteConfig struct { // "/js/*": "/public/javascripts/$1", // "/users/*/orders/*": "/user/$1/order/$2", // Required. - Rules map[string]string `yaml:"rules"` + Rules map[string]string // RegexRules defines the URL path rewrite rules using regexp.Rexexp with captures // Every capture group in the values can be retrieved by index e.g. $1, $2 and so on. // Example: // "^/old/[0.9]+/": "/new", // "^/api/.+?/(.*)": "/v2/$1", - RegexRules map[*regexp.Regexp]string `yaml:"-"` -} - -// DefaultRewriteConfig is the default Rewrite middleware config. -var DefaultRewriteConfig = RewriteConfig{ - Skipper: DefaultSkipper, + RegexRules map[*regexp.Regexp]string } // Rewrite returns a Rewrite middleware. // // Rewrite middleware rewrites the URL path based on the provided rules. func Rewrite(rules map[string]string) echo.MiddlewareFunc { - c := DefaultRewriteConfig + c := RewriteConfig{} c.Rules = rules return RewriteWithConfig(c) } -// RewriteWithConfig returns a Rewrite middleware with config. -// See: `Rewrite()`. +// RewriteWithConfig returns a Rewrite middleware or panics on invalid configuration. +// +// Rewrite middleware rewrites the URL path based on the provided rules. func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { - // Defaults - if config.Rules == nil && config.RegexRules == nil { - panic("echo: rewrite middleware requires url path rewrite rules or regex rules") - } + return toMiddlewareOrPanic(config) +} +// ToMiddleware converts RewriteConfig to middleware or returns an error for invalid configuration +func (config RewriteConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { - config.Skipper = DefaultBodyDumpConfig.Skipper + config.Skipper = DefaultSkipper + } + if config.Rules == nil && config.RegexRules == nil { + return nil, errors.New("echo rewrite middleware requires url path rewrite rules or regex rules") } if config.RegexRules == nil { @@ -66,7 +66,7 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) (err error) { + return func(c *echo.Context) (err error) { if config.Skipper(c) { return next(c) } @@ -76,5 +76,5 @@ func RewriteWithConfig(config RewriteConfig) echo.MiddlewareFunc { } return next(c) } - } + }, nil } diff --git a/middleware/rewrite_test.go b/middleware/rewrite_test.go index d137b2d13..f45b8d98a 100644 --- a/middleware/rewrite_test.go +++ b/middleware/rewrite_test.go @@ -11,7 +11,7 @@ import ( "regexp" "testing" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" "github.com/stretchr/testify/assert" ) @@ -26,10 +26,10 @@ func TestRewriteAfterRouting(t *testing.T) { "/users/*/orders/*": "/user/$1/order/$2", }, })) - e.GET("/public/*", func(c echo.Context) error { + e.GET("/public/*", func(c *echo.Context) error { return c.String(http.StatusOK, c.Param("*")) }) - e.GET("/*", func(c echo.Context) error { + e.GET("/*", func(c *echo.Context) error { return c.String(http.StatusOK, c.Param("*")) }) @@ -93,20 +93,74 @@ func TestRewriteAfterRouting(t *testing.T) { } } +func TestMustRewriteWithConfig_emptyRulesPanics(t *testing.T) { + assert.Panics(t, func() { + RewriteWithConfig(RewriteConfig{}) + }) +} + +func TestMustRewriteWithConfig_skipper(t *testing.T) { + var testCases = []struct { + name string + givenSkipper func(c *echo.Context) bool + whenURL string + expectURL string + expectStatus int + }{ + { + name: "not skipped", + whenURL: "/old", + expectURL: "/new", + expectStatus: http.StatusOK, + }, + { + name: "skipped", + givenSkipper: func(c *echo.Context) bool { + return true + }, + whenURL: "/old", + expectURL: "/old", + expectStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + e.Pre(RewriteWithConfig( + RewriteConfig{ + Skipper: tc.givenSkipper, + Rules: map[string]string{"/old": "/new"}}, + )) + + e.GET("/new", func(c *echo.Context) error { + return c.NoContent(http.StatusOK) + }) + + req := httptest.NewRequest(http.MethodGet, tc.whenURL, nil) + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, tc.expectURL, req.URL.EscapedPath()) + assert.Equal(t, tc.expectStatus, rec.Code) + }) + } +} + // Issue #1086 func TestEchoRewritePreMiddleware(t *testing.T) { e := echo.New() - r := e.Router() // Rewrite old url to new one // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches - e.Pre(Rewrite(map[string]string{ - "/old": "/new", - }, - )) + e.Pre(RewriteWithConfig(RewriteConfig{ + Rules: map[string]string{"/old": "/new"}}), + ) // Route - r.Add(http.MethodGet, "/new", func(c echo.Context) error { + e.Add(http.MethodGet, "/new", func(c *echo.Context) error { return c.NoContent(http.StatusOK) }) @@ -120,7 +174,6 @@ func TestEchoRewritePreMiddleware(t *testing.T) { // Issue #1143 func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { e := echo.New() - r := e.Router() // middlewares added with `Pre()` are executed before routing is done and therefore change which handler matches e.Pre(RewriteWithConfig(RewriteConfig{ @@ -130,10 +183,10 @@ func TestRewriteWithConfigPreMiddleware_Issue1143(t *testing.T) { }, })) - r.Add(http.MethodGet, "/api/:version/hosts/:name", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/hosts/:name", func(c *echo.Context) error { return c.String(http.StatusOK, "hosts") }) - r.Add(http.MethodGet, "/api/:version/eng", func(c echo.Context) error { + e.Add(http.MethodGet, "/api/:version/eng", func(c *echo.Context) error { return c.String(http.StatusOK, "eng") }) diff --git a/middleware/secure.go b/middleware/secure.go index c904abf1a..bd389f7ae 100644 --- a/middleware/secure.go +++ b/middleware/secure.go @@ -6,7 +6,7 @@ package middleware import ( "fmt" - "github.com/labstack/echo/v4" + "github.com/labstack/echo/v5" ) // SecureConfig defines the config for Secure middleware. @@ -17,12 +17,12 @@ type SecureConfig struct { // XSSProtection provides protection against cross-site scripting attack (XSS) // by setting the `X-XSS-Protection` header. // Optional. Default value "1; mode=block". - XSSProtection string `yaml:"xss_protection"` + XSSProtection string // ContentTypeNosniff provides protection against overriding Content-Type // header by setting the `X-Content-Type-Options` header. // Optional. Default value "nosniff". - ContentTypeNosniff string `yaml:"content_type_nosniff"` + ContentTypeNosniff string // XFrameOptions can be used to indicate whether or not a browser should // be allowed to render a page in a ,