Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 bug: Fix square bracket notation in Multipart FormData #3235

Merged
merged 13 commits into from
Dec 31, 2024
103 changes: 100 additions & 3 deletions bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"net/http/httptest"
"reflect"
"testing"
Expand Down Expand Up @@ -886,7 +887,8 @@ func Test_Bind_Body(t *testing.T) {
reqBody := []byte(`{"name":"john"}`)

type Demo struct {
Name string `json:"name" xml:"name" form:"name" query:"name"`
Name string `json:"name" xml:"name" form:"name" query:"name"`
Names []string `json:"names" xml:"names" form:"names" query:"names"`
}

// Helper function to test compressed bodies
Expand Down Expand Up @@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) {
Data []Demo `query:"data"`
}

t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data.0.name", "john"))
require.NoError(t, writer.WriteField("data.1.name", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data[0][name]", "john"))
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
require.NoError(t, writer.Close())

c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))

cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})

t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()
Expand Down Expand Up @@ -1192,9 +1236,57 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
Name string `form:"name"`
}

body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

b.ReportAllocs()
b.ResetTimer()

for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
var err error

app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})

type Person struct {
Name string `form:"name"`
Age int `form:"age"`
}

type Demo struct {
Name string `form:"name"`
Persons []Person `form:"persons"`
}

buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.WriteField("persons.0.name", "john"))
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
require.NoError(b, writer.WriteField("persons.1.age", "20"))
require.NoError(b, writer.Close())
body := buf.Bytes()

c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)

Expand All @@ -1204,8 +1296,13 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}

require.NoError(b, err)
require.Equal(b, "john", d.Name)
require.Equal(b, "john", d.Persons[0].Name)
require.Equal(b, 10, d.Persons[0].Age)
require.Equal(b, "doe", d.Persons[1].Name)
require.Equal(b, 20, d.Persons[1].Age)
}

// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4
Expand Down
13 changes: 1 addition & 12 deletions binder/cookie.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
Expand Down
36 changes: 22 additions & 14 deletions binder/form.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,7 @@

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
})

if err != nil {
Expand All @@ -66,7 +54,27 @@
return err
}

return parse(b.Name(), out, data.Value)
temp := make(map[string][]string)
for key, values := range data.Value {
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved
if strings.Contains(key, "[") {
k, err := parseParamSquareBrackets(key)
if err != nil {
return err
}

Check warning on line 63 in binder/form.go

View check run for this annotation

Codecov / codecov/patch

binder/form.go#L62-L63

Added lines #L62 - L63 were not covered by tests

key = k // We have to update key in case bracket notation and slice type are used at the same time
}

for _, v := range values {
if strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, key) {
temp[key] = strings.Split(v, ",")
} else {
temp[key] = append(temp[key], v)
}
}
}

return parse(b.Name(), out, temp)
}

// Reset resets the FormBinding binder.
Expand Down
16 changes: 15 additions & 1 deletion binder/form_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
}
require.Equal(t, "form", b.Name())

type Post struct {
Title string `form:"title"`
}

type User struct {
Name string `form:"name"`
Names []string `form:"names"`
Posts []Post `form:"posts"`
Age int `form:"age"`
}
var user User
Expand All @@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
mw := multipart.NewWriter(buf)

require.NoError(t, mw.WriteField("name", "john"))
require.NoError(t, mw.WriteField("names", "john"))
require.NoError(t, mw.WriteField("names", "john,eric"))
require.NoError(t, mw.WriteField("names", "doe"))
require.NoError(t, mw.WriteField("age", "42"))
require.NoError(t, mw.WriteField("posts[0][title]", "post1"))
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))

require.NoError(t, mw.Close())

req.Header.SetContentType(mw.FormDataContentType())
Expand All @@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
require.Equal(t, 42, user.Age)
require.Contains(t, user.Names, "john")
require.Contains(t, user.Names, "doe")
require.Contains(t, user.Names, "eric")
require.Len(t, user.Posts, 3)
require.Equal(t, "post1", user.Posts[0].Title)
require.Equal(t, "post2", user.Posts[1].Title)
require.Equal(t, "post3", user.Posts[2].Title)
}

func Benchmark_FormBinder_BindMultipart(b *testing.B) {
Expand Down
22 changes: 10 additions & 12 deletions binder/header.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -21,20 +18,21 @@
// Bind parses the request header and returns the result.
func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error {
data := make(map[string][]string)
var err error
req.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}

Check warning on line 25 in binder/header.go

View check run for this annotation

Codecov / codecov/patch

binder/header.go#L24-L25

Added lines #L24 - L25 were not covered by tests

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
return err
}

Check warning on line 34 in binder/header.go

View check run for this annotation

Codecov / codecov/patch

binder/header.go#L33-L34

Added lines #L33 - L34 were not covered by tests

return parse(b.Name(), out, data)
}

Expand Down
18 changes: 18 additions & 0 deletions binder/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,3 +249,21 @@ func FilterFlags(content string) string {
}
return content
}

func formatBindData(out any, data map[string][]string, key, value string, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
var err error
if supportBracketNotation && strings.Contains(key, "[") {
key, err = parseParamSquareBrackets(key)
ReneWerner87 marked this conversation as resolved.
Show resolved Hide resolved
}

if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) {
values := strings.Split(value, ",")
for i := 0; i < len(values); i++ {
data[key] = append(data[key], values[i])
}
} else {
data[key] = append(data[key], value)
}

return err
}
17 changes: 1 addition & 16 deletions binder/query.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error {

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
})

if err != nil {
Expand Down
23 changes: 11 additions & 12 deletions binder/resp_header.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package binder

import (
"reflect"
"strings"

"github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp"
)
Expand All @@ -21,20 +18,22 @@
// Bind parses the response header and returns the result.
func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error {
data := make(map[string][]string)
var err error

resp.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}

Check warning on line 26 in binder/resp_header.go

View check run for this annotation

Codecov / codecov/patch

binder/resp_header.go#L25-L26

Added lines #L25 - L26 were not covered by tests

k := utils.UnsafeString(key)
v := utils.UnsafeString(val)

if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
})

if err != nil {
return err
}

Check warning on line 35 in binder/resp_header.go

View check run for this annotation

Codecov / codecov/patch

binder/resp_header.go#L34-L35

Added lines #L34 - L35 were not covered by tests

return parse(b.Name(), out, data)
}

Expand Down
Loading