Skip to content

Commit

Permalink
Merge pull request #110 from zeripath/sanitize-to-writer
Browse files Browse the repository at this point in the history
Add function to sanitize to writer directly
  • Loading branch information
David Kitchen authored Jun 17, 2021
2 parents 7f2aa2d + 6a218ba commit 9eef462
Showing 1 changed file with 94 additions and 34 deletions.
128 changes: 94 additions & 34 deletions sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ package bluemonday

import (
"bytes"
"fmt"
"io"
"net/url"
"regexp"
Expand Down Expand Up @@ -61,7 +62,7 @@ func (p *Policy) Sanitize(s string) string {
return s
}

return p.sanitize(strings.NewReader(s)).String()
return p.sanitizeWithBuff(strings.NewReader(s)).String()
}

// SanitizeBytes takes a []byte that contains a HTML fragment or document and applies
Expand All @@ -75,7 +76,7 @@ func (p *Policy) SanitizeBytes(b []byte) []byte {
return b
}

return p.sanitize(bytes.NewReader(b)).Bytes()
return p.sanitizeWithBuff(bytes.NewReader(b)).Bytes()
}

// SanitizeReader takes an io.Reader that contains a HTML fragment or document
Expand All @@ -84,17 +85,23 @@ func (p *Policy) SanitizeBytes(b []byte) []byte {
// It returns a bytes.Buffer containing the HTML that has been sanitized by the
// policy. Errors during sanitization will merely return an empty result.
func (p *Policy) SanitizeReader(r io.Reader) *bytes.Buffer {
return p.sanitize(r)
return p.sanitizeWithBuff(r)
}

// SanitizeReaderToWriter takes an io.Reader that contains a HTML fragment or
// document, applies the given policy whitelist, and writes the sanitized
// output to the provided io.Writer.
//
// It returns whatever error the underlying sanitize call reports (nil on
// success). Output is written to w as the input is processed, so w may
// already have received partial output when a non-nil error is returned.
func (p *Policy) SanitizeReaderToWriter(r io.Reader, w io.Writer) error {
return p.sanitize(r, w)
}

const escapedURLChars = "'<>\"\r"

func escapeUrlComponent(val string) string {
w := bytes.NewBufferString("")
func escapeUrlComponent(w stringWriterWriter, val string) error {
i := strings.IndexAny(val, escapedURLChars)
for i != -1 {
if _, err := w.WriteString(val[:i]); err != nil {
return w.String()
return err
}
var esc string
switch val[i] {
Expand All @@ -115,12 +122,12 @@ func escapeUrlComponent(val string) string {
}
val = val[i+1:]
if _, err := w.WriteString(esc); err != nil {
return w.String()
return err
}
i = strings.IndexAny(val, escapedURLChars)
}
w.WriteString(val)
return w.String()
_, err := w.WriteString(val)
return err
}

// Query represents a query
Expand Down Expand Up @@ -206,15 +213,16 @@ func sanitizedURL(val string) (string, error) {
return u.String(), nil
}

func (p *Policy) writeLinkableBuf(buff *bytes.Buffer, token *html.Token) {
func (p *Policy) writeLinkableBuf(buff stringWriterWriter, token *html.Token) (int, error) {
// do not escape multiple query parameters
tokenBuff := bytes.NewBufferString("")
tokenBuff.WriteString("<")
tokenBuff := bytes.NewBuffer(make([]byte, 0, 1024)) // This should stay on the stack unless it gets too big

tokenBuff.WriteByte('<')
tokenBuff.WriteString(token.Data)
for _, attr := range token.Attr {
tokenBuff.WriteByte(' ')
tokenBuff.WriteString(attr.Key)
tokenBuff.WriteString(`="`)
tokenBuff.Write([]byte{'=', '"'})
switch attr.Key {
case "href", "src":
u, ok := p.validURL(attr.Val)
Expand All @@ -239,12 +247,32 @@ func (p *Policy) writeLinkableBuf(buff *bytes.Buffer, token *html.Token) {
tokenBuff.WriteString("/")
}
tokenBuff.WriteString(">")
buff.WriteString(tokenBuff.String())
return buff.Write(tokenBuff.Bytes())
}

// Performs the actual sanitization process.
func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
// sanitizeWithBuff runs the policy over r and collects the sanitized output
// in a freshly allocated bytes.Buffer. If sanitization reports an error, any
// partial output is discarded and a new, empty buffer is returned instead.
func (p *Policy) sanitizeWithBuff(r io.Reader) *bytes.Buffer {
	out := new(bytes.Buffer)
	if err := p.sanitize(r, out); err == nil {
		return out
	}
	// Hide partially written output from the caller on failure.
	return new(bytes.Buffer)
}

// stringWriterWriter is the destination type used while sanitizing: a sink
// that accepts both raw bytes (Write) and strings (WriteString).
type stringWriterWriter interface {
	io.StringWriter
	io.Writer
}

// asStringWriter wraps an io.Writer so that it also offers a WriteString
// method, for writers that do not provide one themselves.
type asStringWriter struct {
	io.Writer
}

// WriteString converts s to a byte slice and hands it to the wrapped
// writer's Write method, returning that call's byte count and error.
func (a *asStringWriter) WriteString(s string) (int, error) {
	payload := []byte(s)
	return a.Writer.Write(payload)
}

func (p *Policy) sanitize(r io.Reader, w io.Writer) error {
// It is possible that the developer has created the policy via:
// p := bluemonday.Policy{}
// rather than:
Expand All @@ -253,8 +281,12 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
// would initialize the maps, then we need to do that.
p.init()

buff, ok := w.(stringWriterWriter)
if !ok {
buff = &asStringWriter{w}
}

var (
buff bytes.Buffer
skipElementContent bool
skippingElementsCount int64
skipClosingTag bool
Expand All @@ -268,11 +300,11 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
err := tokenizer.Err()
if err == io.EOF {
// End of input means end of processing
return &buff
return nil
}

// Raw tokenizer error
return &bytes.Buffer{}
return err
}

token := tokenizer.Token()
Expand Down Expand Up @@ -308,7 +340,9 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
skippingElementsCount++
}
if p.addSpaces {
buff.WriteString(" ")
if _, err := buff.WriteString(" "); err != nil {
return err
}
}
break
}
Expand All @@ -323,7 +357,9 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
skipClosingTag = true
closingTagToSkipStack = append(closingTagToSkipStack, token.Data)
if p.addSpaces {
buff.WriteString(" ")
if _, err := buff.WriteString(" "); err != nil {
return err
}
}
break
}
Expand All @@ -332,9 +368,13 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
if !skipElementContent {
// do not escape multiple query parameters
if linkable(token.Data) {
p.writeLinkableBuf(&buff, &token)
if _, err := p.writeLinkableBuf(buff, &token); err != nil {
return err
}
} else {
buff.WriteString(token.String())
if _, err := buff.WriteString(token.String()); err != nil {
return err
}
}
}

Expand All @@ -350,7 +390,9 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
skipClosingTag = false
}
if p.addSpaces {
buff.WriteString(" ")
if _, err := buff.WriteString(" "); err != nil {
return err
}
}
break
}
Expand All @@ -371,14 +413,18 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
}
if !match {
if p.addSpaces {
buff.WriteString(" ")
if _, err := buff.WriteString(" "); err != nil {
return err
}
}
break
}
}

if !skipElementContent {
buff.WriteString(token.String())
if _, err := buff.WriteString(token.String()); err != nil {
return err
}
}

case html.SelfClosingTagToken:
Expand All @@ -388,7 +434,9 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
aa, matched := p.matchRegex(token.Data)
if !matched {
if p.addSpaces && !matched {
buff.WriteString(" ")
if _, err := buff.WriteString(" "); err != nil {
return err
}
}
break
}
Expand All @@ -401,16 +449,22 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {

if len(token.Attr) == 0 && !p.allowNoAttrs(token.Data) {
if p.addSpaces {
buff.WriteString(" ")
if _, err := buff.WriteString(" "); err != nil {
return err
}
break
}
}
if !skipElementContent {
// do not escape multiple query parameters
if linkable(token.Data) {
p.writeLinkableBuf(&buff, &token)
if _, err := p.writeLinkableBuf(buff, &token); err != nil {
return err
}
} else {
buff.WriteString(token.String())
if _, err := buff.WriteString(token.String()); err != nil {
return err
}
}
}

Expand All @@ -421,20 +475,26 @@ func (p *Policy) sanitize(r io.Reader) *bytes.Buffer {
case `script`:
// not encouraged, but if a policy allows JavaScript we
// should not HTML escape it as that would break the output
buff.WriteString(token.Data)
case `style`:
if _, err := buff.WriteString(token.Data); err != nil {
return err
}
case "style":
// not encouraged, but if a policy allows CSS styles we
// should not HTML escape it as that would break the output
buff.WriteString(token.Data)
if _, err := buff.WriteString(token.Data); err != nil {
return err
}
default:
// HTML escape the text
buff.WriteString(token.String())
if _, err := buff.WriteString(token.String()); err != nil {
return err
}
}
}

default:
// A token that didn't exist in the html package when we wrote this
return &bytes.Buffer{}
return fmt.Errorf("unknown token: %v", token)
}
}
}
Expand Down

0 comments on commit 9eef462

Please sign in to comment.