mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-07 15:10:01 +08:00
fix: timeout 0s not working (#4932)
Signed-off-by: kevin <wanjunfeng@gmail.com>
This commit is contained in:
@@ -28,7 +28,10 @@ var ErrSignatureConfig = errors.New("bad config for Signature")
|
||||
type engine struct {
|
||||
conf RestConf
|
||||
routes []featuredRoutes
|
||||
// timeout is the max timeout of all routes
|
||||
// timeout is the max timeout of all routes,
|
||||
// and is used to set http.Server.ReadTimeout and http.Server.WriteTimeout.
|
||||
// this network timeout is used to avoid DoS attacks by sending data slowly
|
||||
// or receiving data slowly with many connections to exhaust server resources.
|
||||
timeout time.Duration
|
||||
unauthorizedCallback handler.UnauthorizedCallback
|
||||
unsignedCallback handler.UnsignedCallback
|
||||
@@ -60,11 +63,7 @@ func (ng *engine) addRoutes(r featuredRoutes) {
|
||||
}
|
||||
ng.routes = append(ng.routes, r)
|
||||
|
||||
// need to guarantee the timeout is the max of all routes
|
||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||
if r.timeout > ng.timeout {
|
||||
ng.timeout = r.timeout
|
||||
}
|
||||
ng.mightUpdateTimeout(r)
|
||||
}
|
||||
|
||||
func buildSSERoutes(routes []Route) []Route {
|
||||
@@ -192,11 +191,12 @@ func (ng *engine) checkedMaxBytes(bytes int64) int64 {
|
||||
return ng.conf.MaxBytes
|
||||
}
|
||||
|
||||
func (ng *engine) checkedTimeout(timeout time.Duration) time.Duration {
|
||||
if timeout > 0 {
|
||||
return timeout
|
||||
func (ng *engine) checkedTimeout(timeout *time.Duration) time.Duration {
|
||||
if timeout != nil {
|
||||
return *timeout
|
||||
}
|
||||
|
||||
// if timeout not set in featured routes, use global timeout
|
||||
return time.Duration(ng.conf.Timeout) * time.Millisecond
|
||||
}
|
||||
|
||||
@@ -232,6 +232,28 @@ func (ng *engine) hasTimeout() bool {
|
||||
return ng.conf.Middlewares.Timeout && ng.timeout > 0
|
||||
}
|
||||
|
||||
// mightUpdateTimeout checks if the route timeout is greater than the current,
|
||||
// and updates the engine's timeout accordingly.
|
||||
func (ng *engine) mightUpdateTimeout(r featuredRoutes) {
|
||||
// if global timeout is set to 0, it means no need to set read/write timeout
|
||||
// if route timeout is nil, no need to update ng.timeout
|
||||
if ng.timeout == 0 || r.timeout == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// if route timeout is 0 (means no timeout), cannot set read/write timeout
|
||||
if *r.timeout == 0 {
|
||||
ng.timeout = 0
|
||||
return
|
||||
}
|
||||
|
||||
// need to guarantee the timeout is the max of all routes
|
||||
// otherwise impossible to set http.Server.ReadTimeout & WriteTimeout
|
||||
if *r.timeout > ng.timeout {
|
||||
ng.timeout = *r.timeout
|
||||
}
|
||||
}
|
||||
|
||||
// notFoundHandler returns a middleware that handles 404 not found requests.
|
||||
func (ng *engine) notFoundHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -333,7 +355,7 @@ func (ng *engine) start(router httpx.Router, opts ...StartOption) error {
|
||||
}
|
||||
|
||||
// make sure user defined options overwrite default options
|
||||
opts = append([]StartOption{ng.withTimeout()}, opts...)
|
||||
opts = append([]StartOption{ng.withNetworkTimeout()}, opts...)
|
||||
|
||||
if len(ng.conf.CertFile) == 0 && len(ng.conf.KeyFile) == 0 {
|
||||
return internal.StartHttp(ng.conf.Host, ng.conf.Port, router, opts...)
|
||||
@@ -356,7 +378,7 @@ func (ng *engine) use(middleware Middleware) {
|
||||
ng.middlewares = append(ng.middlewares, middleware)
|
||||
}
|
||||
|
||||
func (ng *engine) withTimeout() internal.StartOption {
|
||||
func (ng *engine) withNetworkTimeout() internal.StartOption {
|
||||
return func(svr *http.Server) {
|
||||
if !ng.hasTimeout() {
|
||||
return
|
||||
|
||||
@@ -73,7 +73,17 @@ Verbose: true
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: time.Minute,
|
||||
timeout: ptrOfDuration(time.Minute),
|
||||
},
|
||||
{
|
||||
jwt: jwtSetting{},
|
||||
signature: signatureSetting{},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: ptrOfDuration(0),
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
@@ -84,7 +94,7 @@ Verbose: true
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
timeout: time.Second,
|
||||
timeout: ptrOfDuration(time.Second),
|
||||
},
|
||||
{
|
||||
priority: true,
|
||||
@@ -227,8 +237,12 @@ Verbose: true
|
||||
}))
|
||||
|
||||
timeout := time.Second * 3
|
||||
if route.timeout > timeout {
|
||||
timeout = route.timeout
|
||||
if route.timeout != nil {
|
||||
if *route.timeout == 0 {
|
||||
timeout = 0
|
||||
} else if *route.timeout > timeout {
|
||||
timeout = *route.timeout
|
||||
}
|
||||
}
|
||||
assert.Equal(t, timeout, ng.timeout)
|
||||
})
|
||||
@@ -236,10 +250,69 @@ Verbose: true
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewEngine_unsignedCallback(t *testing.T) {
|
||||
priKeyfile, err := fs.TempFilenameWithText(priKey)
|
||||
assert.Nil(t, err)
|
||||
defer os.Remove(priKeyfile)
|
||||
|
||||
yaml := `Name: foo
|
||||
Host: localhost
|
||||
Port: 0
|
||||
Middlewares:
|
||||
Log: false
|
||||
`
|
||||
route := featuredRoutes{
|
||||
priority: true,
|
||||
jwt: jwtSetting{
|
||||
enabled: true,
|
||||
},
|
||||
signature: signatureSetting{
|
||||
enabled: true,
|
||||
SignatureConf: SignatureConf{
|
||||
Strict: true,
|
||||
PrivateKeys: []PrivateKeyConf{
|
||||
{
|
||||
Fingerprint: "a",
|
||||
KeyFile: priKeyfile,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
routes: []Route{{
|
||||
Method: http.MethodGet,
|
||||
Path: "/",
|
||||
Handler: func(w http.ResponseWriter, r *http.Request) {},
|
||||
}},
|
||||
}
|
||||
|
||||
var index int32
|
||||
t.Run(fmt.Sprintf("%s-%v", yaml, route.routes), func(t *testing.T) {
|
||||
var cnf RestConf
|
||||
assert.Nil(t, conf.LoadFromYamlBytes([]byte(yaml), &cnf))
|
||||
ng := newEngine(cnf)
|
||||
if atomic.AddInt32(&index, 1)%2 == 0 {
|
||||
ng.setUnsignedCallback(func(w http.ResponseWriter, r *http.Request,
|
||||
next http.Handler, strict bool, code int) {
|
||||
})
|
||||
}
|
||||
ng.addRoutes(route)
|
||||
ng.use(func(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
assert.NotNil(t, ng.start(mockedRouter{}, func(svr *http.Server) {
|
||||
}))
|
||||
|
||||
assert.Equal(t, time.Duration(time.Second*3), ng.timeout)
|
||||
})
|
||||
}
|
||||
|
||||
func TestEngine_checkedTimeout(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
timeout *time.Duration
|
||||
expect time.Duration
|
||||
}{
|
||||
{
|
||||
@@ -248,17 +321,17 @@ func TestEngine_checkedTimeout(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "less",
|
||||
timeout: time.Millisecond * 500,
|
||||
timeout: ptrOfDuration(time.Millisecond * 500),
|
||||
expect: time.Millisecond * 500,
|
||||
},
|
||||
{
|
||||
name: "equal",
|
||||
timeout: time.Second,
|
||||
timeout: ptrOfDuration(time.Second),
|
||||
expect: time.Second,
|
||||
},
|
||||
{
|
||||
name: "more",
|
||||
timeout: time.Millisecond * 1500,
|
||||
timeout: ptrOfDuration(time.Millisecond * 1500),
|
||||
expect: time.Millisecond * 1500,
|
||||
},
|
||||
}
|
||||
@@ -401,7 +474,7 @@ func TestEngine_withTimeout(t *testing.T) {
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
ng.withNetworkTimeout()(svr)
|
||||
|
||||
assert.Equal(t, time.Duration(test.timeout)*time.Millisecond*4/5, svr.ReadTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||
@@ -451,7 +524,7 @@ func TestEngine_ReadWriteTimeout(t *testing.T) {
|
||||
},
|
||||
})
|
||||
svr := &http.Server{}
|
||||
ng.withTimeout()(svr)
|
||||
ng.withNetworkTimeout()(svr)
|
||||
|
||||
assert.Equal(t, time.Duration(0), svr.ReadHeaderTimeout)
|
||||
assert.Equal(t, time.Duration(0), svr.IdleTimeout)
|
||||
|
||||
@@ -283,14 +283,14 @@ func WithSignature(signature SignatureConf) RouteOption {
|
||||
func WithSSE() RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.sse = true
|
||||
r.timeout = 0
|
||||
r.timeout = ptrOfDuration(0)
|
||||
}
|
||||
}
|
||||
|
||||
// WithTimeout returns a RouteOption to set timeout with given value.
|
||||
func WithTimeout(timeout time.Duration) RouteOption {
|
||||
return func(r *featuredRoutes) {
|
||||
r.timeout = timeout
|
||||
r.timeout = &timeout
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,6 +325,10 @@ func handleError(err error) {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
func ptrOfDuration(d time.Duration) *time.Duration {
|
||||
return &d
|
||||
}
|
||||
|
||||
func validateSecret(secret string) {
|
||||
if len(secret) < 8 {
|
||||
panic("secret's length can't be less than 8")
|
||||
|
||||
@@ -345,7 +345,7 @@ func TestWithPriority(t *testing.T) {
|
||||
func TestWithTimeout(t *testing.T) {
|
||||
var fr featuredRoutes
|
||||
WithTimeout(time.Hour)(&fr)
|
||||
assert.Equal(t, time.Hour, fr.timeout)
|
||||
assert.Equal(t, time.Hour, *fr.timeout)
|
||||
}
|
||||
|
||||
func TestWithTLSConfig(t *testing.T) {
|
||||
|
||||
@@ -31,7 +31,7 @@ type (
|
||||
}
|
||||
|
||||
featuredRoutes struct {
|
||||
timeout time.Duration
|
||||
timeout *time.Duration
|
||||
priority bool
|
||||
jwt jwtSetting
|
||||
signature signatureSetting
|
||||
|
||||
Reference in New Issue
Block a user