mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-13 18:00:00 +08:00
fix: timeout 0s not working (#4932)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
@@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
|
|||||||
type engine struct {
|
type engine struct {
|
||||||
conf RestConf
|
conf RestConf
|
||||||
routes []featuredRoutes
|
routes []featuredRoutes
|
||||||
// timeout is the max timeout of all routes
|
// timeout is the max timeout of all routes,
|
||||||
|
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
|
||||||
|
// this network timeout is used to avoid DoS attacks by sending data slowly
|
||||||
|
// or receiving data slowly with many connections to exhaust server resources.
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
unauthorizedCallback handler.UnauthorizedCallback
|
unauthorizedCallback handler.UnauthorizedCallback
|
||||||
unsignedCallback handler.UnsignedCallback
|
unsignedCallback handler.UnsignedCallback
|
||||||
@@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
|||||||
}
|
}
|
||||||
ng.routes = append(ng.routes, r)
|
ng.routes = append(ng.routes, r)
|
||||||
|
|
||||||
// need to guarantee the timeout is the max of all routes
|
ng.mightUpdateTimeout(r)
|
||||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
|
||||||
if r.timeout > ng.timeout {
|
|
||||||
ng.timeout = r.timeout
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildSSERoutes(routes []Route) []Route {
|
func buildSSERoutes(routes []Route) []Route {
|
||||||
@@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
|||||||
return ng.conf.MaxBytes
|
return ng.conf.MaxBytes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
|
||||||
if timeout > 0 {
|
if timeout != nil {
|
||||||
return timeout
|
return *timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// if timeout not set in featured routes, use global timeout
|
||||||
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,6 +232,28 @@ func (ng *engine) hasTimeout() bool {
|
|||||||
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mightUpdateTimeout checks if the route timeout is greater than the current,
|
||||||
|
// and updates the engine's timeout accordingly.
|
||||||
|
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
|
||||||
|
// if global timeout is set to 0, it means no need to set read/write timeout
|
||||||
|
// if route timeout is nil, no need to update ng.timeout
|
||||||
|
if ng.timeout == 0 || r.timeout == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// if route timeout is 0 (means no timeout), cannot set read/write timeout
|
||||||
|
if *r.timeout == 0 {
|
||||||
|
ng.timeout = 0
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// need to guarantee the timeout is the max of all routes
|
||||||
|
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||||
|
if *r.timeout > ng.timeout {
|
||||||
|
ng.timeout = *r.timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -333,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// make sure user defined options overwrite default options
|
// make sure user defined options overwrite default options
|
||||||
opts = append([]StartOption{ng.withTimeout()}, opts...)
|
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
|
||||||
|
|
||||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
||||||
@@ -356,7 +378,7 @@ func (ng *engine) use(middleware Middleware) {
|
|||||||
ng.middlewares = append(ng.middlewares, middleware)
|
ng.middlewares = append(ng.middlewares, middleware)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ng *engine) withTimeout() internal.StartOption {
|
func (ng *engine) withNetworkTimeout() internal.StartOption {
|
||||||
return func(svr *http.Server) {
|
return func(svr *http.Server) {
|
||||||
if !ng.hasTimeout() {
|
if !ng.hasTimeout() {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -73,7 +73,17 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
timeout: time.Minute,
|
timeout: ptrOfDuration(time.Minute),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
jwt: jwtSetting{},
|
||||||
|
signature: signatureSetting{},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
timeout: ptrOfDuration(0),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -84,7 +94,7 @@ Verbose: true
|
|||||||
Path: "/",
|
Path: "/",
|
||||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
}},
|
}},
|
||||||
timeout: time.Second,
|
timeout: ptrOfDuration(time.Second),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
priority: true,
|
priority: true,
|
||||||
@@ -227,8 +237,12 @@ Verbose: true
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
timeout := time.Second * 3
|
timeout := time.Second * 3
|
||||||
if route.timeout > timeout {
|
if route.timeout != nil {
|
||||||
timeout = route.timeout
|
if *route.timeout == 0 {
|
||||||
|
timeout = 0
|
||||||
|
} else if *route.timeout > timeout {
|
||||||
|
timeout = *route.timeout
|
||||||
|
}
|
||||||
}
|
}
|
||||||
assert.Equal(t, timeout, ng.timeout)
|
assert.Equal(t, timeout, ng.timeout)
|
||||||
})
|
})
|
||||||
@@ -236,10 +250,69 @@ Verbose: true
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestNewEngine_unsignedCallback(t *testing.T) {
|
||||||
|
priKeyfile, err := fs.TempFilenameWithText(priKey)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
defer os.Remove(priKeyfile)
|
||||||
|
|
||||||
|
yaml := `Name: foo
|
||||||
|
Host: localhost
|
||||||
|
Port: 0
|
||||||
|
Middlewares:
|
||||||
|
Log: false
|
||||||
|
`
|
||||||
|
route := featuredRoutes{
|
||||||
|
priority: true,
|
||||||
|
jwt: jwtSetting{
|
||||||
|
enabled: true,
|
||||||
|
},
|
||||||
|
signature: signatureSetting{
|
||||||
|
enabled: true,
|
||||||
|
SignatureConf: SignatureConf{
|
||||||
|
Strict: true,
|
||||||
|
PrivateKeys: []PrivateKeyConf{
|
||||||
|
{
|
||||||
|
Fingerprint: "a",
|
||||||
|
KeyFile: priKeyfile,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
routes: []Route{{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||||
|
}},
|
||||||
|
}
|
||||||
|
|
||||||
|
var index int32
|
||||||
|
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
|
||||||
|
var cnf RestConf
|
||||||
|
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
|
||||||
|
ng := newEngine(cnf)
|
||||||
|
if atomic.AddInt32(&index, 1)%2 == 0 {
|
||||||
|
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
|
||||||
|
next http.Handler, strict bool, code int) {
|
||||||
|
})
|
||||||
|
}
|
||||||
|
ng.addRoutes(route)
|
||||||
|
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
|
||||||
|
}))
|
||||||
|
|
||||||
|
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestEngine_checkedTimeout(t *testing.T) {
|
func TestEngine_checkedTimeout(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
timeout time.Duration
|
timeout *time.Duration
|
||||||
expect time.Duration
|
expect time.Duration
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "less",
|
name: "less",
|
||||||
timeout: time.Millisecond * 500,
|
timeout: ptrOfDuration(time.Millisecond * 500),
|
||||||
expect: time.Millisecond * 500,
|
expect: time.Millisecond * 500,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "equal",
|
name: "equal",
|
||||||
timeout: time.Second,
|
timeout: ptrOfDuration(time.Second),
|
||||||
expect: time.Second,
|
expect: time.Second,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "more",
|
name: "more",
|
||||||
timeout: time.Millisecond * 1500,
|
timeout: ptrOfDuration(time.Millisecond * 1500),
|
||||||
expect: time.Millisecond * 1500,
|
expect: time.Millisecond * 1500,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -401,7 +474,7 @@ func TestEngine_withTimeout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
svr := &http.Server{}
|
svr := &http.Server{}
|
||||||
ng.withTimeout()(svr)
|
ng.withNetworkTimeout()(svr)
|
||||||
|
|
||||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||||
@@ -451,7 +524,7 @@ func TestEngine_ReadWriteTimeout(t *testing.T) {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
svr := &http.Server{}
|
svr := &http.Server{}
|
||||||
ng.withTimeout()(svr)
|
ng.withNetworkTimeout()(svr)
|
||||||
|
|
||||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||||
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||||
|
|||||||
@@ -283,14 +283,14 @@ func WithSignature(signature SignatureConf) RouteOption {
|
|||||||
func WithSSE() RouteOption {
|
func WithSSE() RouteOption {
|
||||||
return func(r *featuredRoutes) {
|
return func(r *featuredRoutes) {
|
||||||
r.sse = true
|
r.sse = true
|
||||||
r.timeout = 0
|
r.timeout = ptrOfDuration(0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithTimeout returns a RouteOption to set timeout with given value.
|
// WithTimeout returns a RouteOption to set timeout with given value.
|
||||||
func WithTimeout(timeout time.Duration) RouteOption {
|
func WithTimeout(timeout time.Duration) RouteOption {
|
||||||
return func(r *featuredRoutes) {
|
return func(r *featuredRoutes) {
|
||||||
r.timeout = timeout
|
r.timeout = &timeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -325,6 +325,10 @@ func handleError(err error) {
|
|||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||||
|
return &d
|
||||||
|
}
|
||||||
|
|
||||||
func validateSecret(secret string) {
|
func validateSecret(secret string) {
|
||||||
if len(secret) < 8 {
|
if len(secret) < 8 {
|
||||||
panic("secret's length can't be less than 8")
|
panic("secret's length can't be less than 8")
|
||||||
|
|||||||
@@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) {
|
|||||||
func TestWithTimeout(t *testing.T) {
|
func TestWithTimeout(t *testing.T) {
|
||||||
var fr featuredRoutes
|
var fr featuredRoutes
|
||||||
WithTimeout(time.Hour)(&fr)
|
WithTimeout(time.Hour)(&fr)
|
||||||
assert.Equal(t, time.Hour, fr.timeout)
|
assert.Equal(t, time.Hour, *fr.timeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWithTLSConfig(t *testing.T) {
|
func TestWithTLSConfig(t *testing.T) {
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ type (
|
|||||||
}
|
}
|
||||||
|
|
||||||
featuredRoutes struct {
|
featuredRoutes struct {
|
||||||
timeout time.Duration
|
timeout *time.Duration
|
||||||
priority bool
|
priority bool
|
||||||
jwt jwtSetting
|
jwt jwtSetting
|
||||||
signature signatureSetting
|
signature signatureSetting
|
||||||
|
|||||||
Reference in New Issue
Block a user