diff --git a/rest/engine.go b/rest/engine.go index 1fb0f2f40..f7a53987e 100644 --- a/rest/engine.go +++ b/rest/engine.go @@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature") type engine struct { conf RestConf routes []featuredRoutes - // timeout is the max timeout of all routes + // timeout is the max timeout of all routes, + // and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout. + // this network timeout is used to avoid DoS attacks by sending data slowly + // or receiving data slowly with many connections to exhaust server resources. timeout time.Duration unauthorizedCallback handler.UnauthorizedCallback unsignedCallback handler.UnsignedCallback @@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) { } ng.routes = append(ng.routes, r) - // need to guarantee the timeout is the max of all routes - // otherwise impossible to set http.Server.ReadTimeout & WriteTimeout - if r.timeout > ng.timeout { - ng.timeout = r.timeout - } + ng.mightUpdateTimeout(r) } func buildSSERoutes(routes []Route) []Route { @@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 { return ng.conf.MaxBytes } -func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration { - if timeout > 0 { - return timeout +func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration { + if timeout != nil { + return *timeout } + // if timeout not set in featured routes, use global timeout return time.Duration(ng.conf.Timeout) * time.Millisecond } @@ -232,6 +232,28 @@ func (ng *engine) hasTimeout() bool { return ng.conf.Middlewares.Timeout && ng.timeout > 0 } +// mightUpdateTimeout checks if the route timeout is greater than the current, +// and updates the engine's timeout accordingly. +func (ng *engine) mightUpdateTimeout(r featuredRoutes) { + // if global timeout is set to 0, it means no need to set read/write timeout + // if route timeout is nil, no need to update ng.timeout + if ng.timeout == 0 || r.timeout == nil { + return + } + + // if route timeout is 0 (means no timeout), cannot set read/write timeout + if *r.timeout == 0 { + ng.timeout = 0 + return + } + + // need to guarantee the timeout is the max of all routes + // otherwise impossible to set http.Server.ReadTimeout & WriteTimeout + if *r.timeout > ng.timeout { + ng.timeout = *r.timeout + } +} + // notFoundHandler returns a middleware that handles 404 not found requests. func (ng *engine) notFoundHandler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -333,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error { } // make sure user defined options overwrite default options - opts = append([]StartOption{ng.withTimeout()}, opts...) + opts = append([]StartOption{ng.withNetworkTimeout()}, opts...) if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 { return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...) @@ -356,7 +378,7 @@ func (ng *engine) use(middleware Middleware) { ng.middlewares = append(ng.middlewares, middleware) } -func (ng *engine) withTimeout() internal.StartOption { +func (ng *engine) withNetworkTimeout() internal.StartOption { return func(svr *http.Server) { if !ng.hasTimeout() { return diff --git a/rest/engine_test.go b/rest/engine_test.go index e7a08eb43..6f461294b 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -73,7 +73,17 @@ Verbose: true Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, - timeout: time.Minute, + timeout: ptrOfDuration(time.Minute), + }, + { + jwt: jwtSetting{}, + signature: signatureSetting{}, + routes: []Route{{ + Method: http.MethodGet, + Path: "/", + Handler: func(w http.ResponseWriter, r *http.Request) {}, + }}, + timeout: ptrOfDuration(0), }, { priority: true, @@ -84,7 +94,7 @@ Verbose: true Path: "/", Handler: func(w http.ResponseWriter, r *http.Request) {}, }}, - timeout: time.Second, + timeout: ptrOfDuration(time.Second), }, { priority: true, @@ -227,8 +237,12 @@ Verbose: true })) timeout := time.Second * 3 - if route.timeout > timeout { - timeout = route.timeout + if route.timeout != nil { + if *route.timeout == 0 { + timeout = 0 + } else if *route.timeout > timeout { + timeout = *route.timeout + } } assert.Equal(t, timeout, ng.timeout) }) @@ -236,10 +250,69 @@ Verbose: true } } +func TestNewEngine_unsignedCallback(t *testing.T) { + priKeyfile, err := fs.TempFilenameWithText(priKey) + assert.Nil(t, err) + defer os.Remove(priKeyfile) + + yaml := `Name: foo +Host: localhost +Port: 0 +Middlewares: + Log: false +` + route := featuredRoutes{ + priority: true, + jwt: jwtSetting{ + enabled: true, + }, + signature: signatureSetting{ + enabled: true, + SignatureConf: SignatureConf{ + Strict: true, + PrivateKeys: []PrivateKeyConf{ + { + Fingerprint: "a", + KeyFile: priKeyfile, + }, + }, + }, + }, + routes: []Route{{ + Method: http.MethodGet, + Path: "/", + Handler: func(w http.ResponseWriter, r *http.Request) {}, + }}, + } + + var index int32 + t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) { + var cnf RestConf + assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf)) + ng := newEngine(cnf) + if atomic.AddInt32(&index, 1)%2 == 0 { + ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request, + next http.Handler, strict bool, code int) { + }) + } + ng.addRoutes(route) + ng.use(func(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + } + }) + + assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) { + })) + + assert.Equal(t, time.Duration(time.Second*3), ng.timeout) + }) +} + func TestEngine_checkedTimeout(t *testing.T) { tests := []struct { name string - timeout time.Duration + timeout *time.Duration expect time.Duration }{ { @@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) { }, { name: "less", - timeout: time.Millisecond * 500, + timeout: ptrOfDuration(time.Millisecond * 500), expect: time.Millisecond * 500, }, { name: "equal", - timeout: time.Second, + timeout: ptrOfDuration(time.Second), expect: time.Second, }, { name: "more", - timeout: time.Millisecond * 1500, + timeout: ptrOfDuration(time.Millisecond * 1500), expect: time.Millisecond * 1500, }, } @@ -401,7 +474,7 @@ func TestEngine_withTimeout(t *testing.T) { }, }) svr := &http.Server{} - ng.withTimeout()(svr) + ng.withNetworkTimeout()(svr) assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout) assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout) @@ -451,7 +524,7 @@ func TestEngine_ReadWriteTimeout(t *testing.T) { }, }) svr := &http.Server{} - ng.withTimeout()(svr) + ng.withNetworkTimeout()(svr) assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout) assert.Equal(t, time.Duration(0), svr.IdleTimeout) diff --git a/rest/server.go b/rest/server.go index b97ec5aa5..df025ebfc 100644 --- a/rest/server.go +++ b/rest/server.go @@ -283,14 +283,14 @@ func WithSignature(signature SignatureConf) RouteOption { func WithSSE() RouteOption { return func(r *featuredRoutes) { r.sse = true - r.timeout = 0 + r.timeout = ptrOfDuration(0) } } // WithTimeout returns a RouteOption to set timeout with given value. func WithTimeout(timeout time.Duration) RouteOption { return func(r *featuredRoutes) { - r.timeout = timeout + r.timeout = &timeout } } @@ -325,6 +325,10 @@ func handleError(err error) { panic(err) } +func ptrOfDuration(d time.Duration) *time.Duration { + return &d +} + func validateSecret(secret string) { if len(secret) < 8 { panic("secret's length can't be less than 8") diff --git a/rest/server_test.go b/rest/server_test.go index 9a513ddc3..9676b3f87 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) { func TestWithTimeout(t *testing.T) { var fr featuredRoutes WithTimeout(time.Hour)(&fr) - assert.Equal(t, time.Hour, fr.timeout) + assert.Equal(t, time.Hour, *fr.timeout) } func TestWithTLSConfig(t *testing.T) { diff --git a/rest/types.go b/rest/types.go index 6d0bcab54..c69e750a9 100644 --- a/rest/types.go +++ b/rest/types.go @@ -31,7 +31,7 @@ type ( } featuredRoutes struct { - timeout time.Duration + timeout *time.Duration priority bool jwt jwtSetting signature signatureSetting