Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Move the common middlewares to go-mod-bootstrap #567

Merged
merged 3 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions bootstrap/handlers/common_middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//
// Copyright (C) 2023 IOTech Ltd
//
// SPDX-License-Identifier: Apache-2.0

package handlers

import (
"context"
"net/http"
"net/url"
"time"

"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger"
"github.com/edgexfoundry/go-mod-core-contracts/v3/common"
"github.com/edgexfoundry/go-mod-core-contracts/v3/models"

"github.com/google/uuid"
"github.com/labstack/echo/v4"
)

func ManageHeader(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
r := c.Request()
correlationID := r.Header.Get(common.CorrelationHeader)
if correlationID == "" {
correlationID = uuid.New().String()
}
// lint:ignore SA1029 legacy
// nolint:staticcheck // See golangci-lint #741
ctx := context.WithValue(r.Context(), common.CorrelationHeader, correlationID)

contentType := r.Header.Get(common.ContentType)
// lint:ignore SA1029 legacy
// nolint:staticcheck // See golangci-lint #741
ctx = context.WithValue(ctx, common.ContentType, contentType)

c.SetRequest(r.WithContext(ctx))

return next(c)
}
}

func LoggingMiddleware(lc logger.LoggingClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if lc.LogLevel() == models.TraceLog {
r := c.Request()
begin := time.Now()
correlationId := FromContext(r.Context())
lc.Trace("Begin request", common.CorrelationHeader, correlationId, "path", r.URL.Path)
err := next(c)
if err != nil {
lc.Errorf("failed to add the middleware: %v", err)
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
}
lc.Trace("Response complete", common.CorrelationHeader, correlationId, "duration", time.Since(begin).String())
return nil
}
return next(c)
}
}
}

// UrlDecodeMiddleware decode the path variables
// After invoking the router.UseEncodedPath() func, the path variables needs to decode before passing to the controller
func UrlDecodeMiddleware(lc logger.LoggingClient) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
var unescapedParams []string
// Retrieve all the url path param names
paramNames := c.ParamNames()

// Retrieve all the url path param values and decode
for k, v := range c.ParamValues() {
unescape, err := url.PathUnescape(v)
if err != nil {
lc.Debugf("failed to decode the %s from the value %s", paramNames[k], v)
return err
}
unescapedParams = append(unescapedParams, unescape)
}
c.SetParamValues(unescapedParams...)
return next(c)
}
}
}

func FromContext(ctx context.Context) string {
hdr, ok := ctx.Value(common.CorrelationHeader).(string)
if !ok {
hdr = ""
}
return hdr
}
88 changes: 88 additions & 0 deletions bootstrap/handlers/common_middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//
// Copyright (C) 2023 IOTech Ltd
//
// SPDX-License-Identifier: Apache-2.0

package handlers

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger"
"github.com/edgexfoundry/go-mod-core-contracts/v3/clients/logger/mocks"
"github.com/edgexfoundry/go-mod-core-contracts/v3/common"

"github.com/labstack/echo/v4"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

var expectedCorrelationId = "927e91d3-864c-4c26-852d-b68c39492d14"

var handler = func(c echo.Context) error {
return c.String(http.StatusOK, "OK")
}

func TestManageHeader(t *testing.T) {
e := echo.New()
e.GET("/", func(c echo.Context) error {
c.Response().Header().Set(common.CorrelationHeader, c.Request().Context().Value(common.CorrelationHeader).(string))
c.Response().Header().Set(common.ContentType, c.Request().Context().Value(common.ContentType).(string))
c.Response().WriteHeader(http.StatusOK)
return nil
})
e.Use(ManageHeader)

req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(common.CorrelationHeader, expectedCorrelationId)
expectedContentType := common.ContentTypeJSON
req.Header.Set(common.ContentType, expectedContentType)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

assert.Equal(t, http.StatusOK, res.Code)
assert.Equal(t, expectedCorrelationId, res.Header().Get(common.CorrelationHeader))
assert.Equal(t, expectedContentType, res.Header().Get(common.ContentType))
}

func TestLoggingMiddleware(t *testing.T) {
e := echo.New()
e.GET("/", handler)
lcMock := &mocks.LoggingClient{}
lcMock.On("Trace", "Begin request", common.CorrelationHeader, expectedCorrelationId, "path", "/")
lcMock.On("Trace", "Response complete", common.CorrelationHeader, expectedCorrelationId, "duration", mock.Anything)
lcMock.On("LogLevel").Return("TRACE")
e.Use(LoggingMiddleware(lcMock))

req := httptest.NewRequest(http.MethodGet, "/", nil)
// lint:ignore SA1029 legacy
// nolint:staticcheck // See golangci-lint #741
ctx := context.WithValue(req.Context(), common.CorrelationHeader, expectedCorrelationId)
req = req.WithContext(ctx)
res := httptest.NewRecorder()
e.ServeHTTP(res, req)

lcMock.AssertCalled(t, "Trace", "Begin request", common.CorrelationHeader, expectedCorrelationId, "path", "/")
lcMock.AssertCalled(t, "Trace", "Response complete", common.CorrelationHeader, expectedCorrelationId, "duration", mock.Anything)
assert.Equal(t, http.StatusOK, res.Code)
}

func TestUrlDecodeMiddleware(t *testing.T) {
e := echo.New()

req := httptest.NewRequest(http.MethodGet, "/", nil)
res := httptest.NewRecorder()
c := e.NewContext(req, res)
c.SetParamNames("foo")
c.SetParamValues("abc%2F123%25") // the decoded value is abc/123%

lc = logger.NewMockClient()
m := UrlDecodeMiddleware(lc)
err := m(handler)(c)

assert.NoError(t, err)
assert.Equal(t, "abc/123%", c.Param("foo"))
}
5 changes: 5 additions & 0 deletions bootstrap/handlers/httpserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ func (b *HttpServer) BootstrapHandler(
return false
}

// Use the common middlewares
b.router.Use(ManageHeader)
b.router.Use(LoggingMiddleware(lc))
cloudxxx8 marked this conversation as resolved.
Show resolved Hide resolved
b.router.Use(UrlDecodeMiddleware(lc))

timeout, err := time.ParseDuration(bootstrapConfig.Service.RequestTimeout)
if err != nil {
lc.Errorf("unable to parse RequestTimeout value of %s to a duration: %v", bootstrapConfig.Service.RequestTimeout, err)
Expand Down