mirror of
https://github.com/zeromicro/go-zero.git
synced 2026-05-11 16:59:59 +08:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
888551627c | ||
|
|
bd623aaac3 | ||
|
|
9e6c2ba2c0 | ||
|
|
c0db8d017d | ||
|
|
52b4f8ca91 | ||
|
|
4884a7b3c6 | ||
|
|
3c6951577d | ||
|
|
fcd15c9b17 | ||
|
|
155e6061cb | ||
|
|
dda7666097 | ||
|
|
c954568b61 | ||
|
|
c2acc43a52 | ||
|
|
1a1a6f5239 | ||
|
|
60c7edf8f8 | ||
|
|
7ad86a52f3 | ||
|
|
1e4e5a02b2 | ||
|
|
39540e21d2 | ||
|
|
b321622c95 | ||
|
|
a25cba5380 |
@@ -26,7 +26,8 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
||||
ctx, cancel := contextx.ShrinkDeadline(parentCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error)
|
||||
// create channel with buffer size 1 to avoid goroutine leak
|
||||
done := make(chan error, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
@@ -35,7 +36,6 @@ func DoWithTimeout(fn func() error, timeout time.Duration, opts ...DoOption) err
|
||||
}
|
||||
}()
|
||||
done <- fn()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
@@ -4,6 +4,7 @@ package logx
|
||||
type LogConf struct {
|
||||
ServiceName string `json:",optional"`
|
||||
Mode string `json:",default=console,options=console|file|volume"`
|
||||
TimeFormat string `json:",optional"`
|
||||
Path string `json:",default=logs"`
|
||||
Level string `json:",default=info,options=info|error|severe"`
|
||||
Compress bool `json:",optional"`
|
||||
|
||||
@@ -32,8 +32,6 @@ const (
|
||||
)
|
||||
|
||||
const (
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07"
|
||||
|
||||
accessFilename = "access.log"
|
||||
errorFilename = "error.log"
|
||||
severeFilename = "severe.log"
|
||||
@@ -64,6 +62,7 @@ var (
|
||||
// ErrLogServiceNameNotSet is an error that indicates that the service name is not set.
|
||||
ErrLogServiceNameNotSet = errors.New("log service name must be set")
|
||||
|
||||
timeFormat = "2006-01-02T15:04:05.000Z07"
|
||||
writeConsole bool
|
||||
logLevel uint32
|
||||
infoLog io.WriteCloser
|
||||
@@ -117,6 +116,10 @@ func MustSetup(c LogConf) {
|
||||
// we need to allow different service frameworks to initialize logx respectively.
|
||||
// the same logic for SetUp
|
||||
func SetUp(c LogConf) error {
|
||||
if len(c.TimeFormat) > 0 {
|
||||
timeFormat = c.TimeFormat
|
||||
}
|
||||
|
||||
switch c.Mode {
|
||||
case consoleMode:
|
||||
setupWithConsole(c)
|
||||
|
||||
@@ -43,11 +43,11 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func newCollection(collection *mgo.Collection) Collection {
|
||||
func newCollection(collection *mgo.Collection, brk breaker.Breaker) Collection {
|
||||
return &decoratedCollection{
|
||||
name: collection.FullName,
|
||||
collection: collection,
|
||||
brk: breaker.NewBreaker(),
|
||||
brk: brk,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ func TestNewCollection(t *testing.T) {
|
||||
Database: nil,
|
||||
Name: "foo",
|
||||
FullName: "bar",
|
||||
})
|
||||
}, breaker.GetBreaker("localhost"))
|
||||
assert.Equal(t, "bar", col.(*decoratedCollection).name)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/globalsign/mgo"
|
||||
"github.com/tal-tech/go-zero/core/breaker"
|
||||
)
|
||||
|
||||
type (
|
||||
@@ -20,6 +21,7 @@ type (
|
||||
session *concurrentSession
|
||||
db *mgo.Database
|
||||
collection string
|
||||
brk breaker.Breaker
|
||||
opts []Option
|
||||
}
|
||||
)
|
||||
@@ -46,6 +48,7 @@ func NewModel(url, collection string, opts ...Option) (*Model, error) {
|
||||
// If name is empty, the database name provided in the dialed URL is used instead
|
||||
db: session.DB(""),
|
||||
collection: collection,
|
||||
brk: breaker.GetBreaker(url),
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
@@ -66,7 +69,7 @@ func (mm *Model) FindId(id interface{}) (Query, error) {
|
||||
|
||||
// GetCollection returns a Collection with given session.
|
||||
func (mm *Model) GetCollection(session *mgo.Session) Collection {
|
||||
return newCollection(mm.db.C(mm.collection).With(session))
|
||||
return newCollection(mm.db.C(mm.collection).With(session), mm.brk)
|
||||
}
|
||||
|
||||
// Insert inserts docs into mm.
|
||||
|
||||
@@ -250,6 +250,21 @@ func (s *Redis) Eval(script string, keys []string, args ...interface{}) (val int
|
||||
return
|
||||
}
|
||||
|
||||
// EvalSha is the implementation of redis evalsha command.
|
||||
func (s *Redis) EvalSha(sha string, keys []string, args ...interface{}) (val interface{}, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val, err = conn.EvalSha(sha, keys, args...).Result()
|
||||
return err
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Exists is the implementation of redis exists command.
|
||||
func (s *Redis) Exists(key string) (val bool, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -449,14 +464,14 @@ func (s *Redis) GetBit(key string, offset int64) (val int, err error) {
|
||||
}
|
||||
|
||||
// Hdel is the implementation of redis hdel command.
|
||||
func (s *Redis) Hdel(key, field string) (val bool, err error) {
|
||||
func (s *Redis) Hdel(key string, fields ...string) (val bool, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.HDel(key, field).Result()
|
||||
v, err := conn.HDel(key, fields...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1032,6 +1047,16 @@ func (s *Redis) Scard(key string) (val int64, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// ScriptLoad is the implementation of redis script load command.
|
||||
func (s *Redis) ScriptLoad(script string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conn.ScriptLoad(script).Result()
|
||||
}
|
||||
|
||||
// Set is the implementation of redis set command.
|
||||
func (s *Redis) Set(key string, value string) error {
|
||||
return s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1101,26 +1126,6 @@ func (s *Redis) Sismember(key string, value interface{}) (val bool, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Srem is the implementation of redis srem command.
|
||||
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.SRem(key, values...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = int(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Smembers is the implementation of redis smembers command.
|
||||
func (s *Redis) Smembers(key string) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1166,6 +1171,31 @@ func (s *Redis) Srandmember(key string, count int) (val []string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Srem is the implementation of redis srem command.
|
||||
func (s *Redis) Srem(key string, values ...interface{}) (val int, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
v, err := conn.SRem(key, values...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
val = int(v)
|
||||
return nil
|
||||
}, acceptable)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// String returns the string representation of s.
|
||||
func (s *Redis) String() string {
|
||||
return s.Addr
|
||||
}
|
||||
|
||||
// Sunion is the implementation of redis sunion command.
|
||||
func (s *Redis) Sunion(keys ...string) (val []string, err error) {
|
||||
err = s.brk.DoWithAcceptable(func() error {
|
||||
@@ -1667,20 +1697,6 @@ func (s *Redis) Zunionstore(dest string, store ZStore, keys ...string) (val int6
|
||||
return
|
||||
}
|
||||
|
||||
// String returns the string representation of s.
|
||||
func (s *Redis) String() string {
|
||||
return s.Addr
|
||||
}
|
||||
|
||||
func (s *Redis) scriptLoad(script string) (string, error) {
|
||||
conn, err := getRedis(s)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return conn.ScriptLoad(script).Result()
|
||||
}
|
||||
|
||||
func acceptable(err error) bool {
|
||||
return err == nil || err == red.Nil
|
||||
}
|
||||
|
||||
@@ -947,13 +947,24 @@ func TestRedisString(t *testing.T) {
|
||||
func TestRedisScriptLoad(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
client.Ping()
|
||||
_, err := NewRedis(client.Addr, "").scriptLoad("foo")
|
||||
_, err := NewRedis(client.Addr, "").ScriptLoad("foo")
|
||||
assert.NotNil(t, err)
|
||||
_, err = client.scriptLoad("foo")
|
||||
_, err = client.ScriptLoad("foo")
|
||||
assert.NotNil(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisEvalSha(t *testing.T) {
|
||||
runOnRedis(t, func(client *Redis) {
|
||||
client.Ping()
|
||||
scriptHash, err := client.ScriptLoad(`return redis.call("EXISTS", KEYS[1])`)
|
||||
assert.Nil(t, err)
|
||||
result, err := client.EvalSha(scriptHash, []string{"key1"})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(0), result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisToPairs(t *testing.T) {
|
||||
pairs := toPairs([]red.Z{
|
||||
{
|
||||
|
||||
@@ -24,6 +24,7 @@ type (
|
||||
ResultHandler func(sql.Result, error)
|
||||
|
||||
// A BulkInserter is used to batch insert records.
|
||||
// Postgresql is not supported yet, because of the sql is formated with symbol `$`.
|
||||
BulkInserter struct {
|
||||
executor *executors.PeriodicalExecutor
|
||||
inserter *dbInserter
|
||||
|
||||
@@ -12,14 +12,10 @@ import (
|
||||
const slowThreshold = time.Millisecond * 500
|
||||
|
||||
func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := formatForPrint(q, args)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] exec: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -33,10 +29,10 @@ func exec(conn sessionConn, q string, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
stmt := fmt.Sprint(args...)
|
||||
startTime := timex.Now()
|
||||
result, err := conn.Exec(args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] execStmt: slowcall - %s", stmt)
|
||||
} else {
|
||||
@@ -50,14 +46,10 @@ func execStmt(conn stmtConn, args ...interface{}) (sql.Result, error) {
|
||||
}
|
||||
|
||||
func query(conn sessionConn, scanner func(*sql.Rows) error, q string, args ...interface{}) error {
|
||||
stmt, err := format(q, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
startTime := timex.Now()
|
||||
rows, err := conn.Query(q, args...)
|
||||
duration := timex.Since(startTime)
|
||||
stmt := fmt.Sprint(args...)
|
||||
if duration > slowThreshold {
|
||||
logx.WithDuration(duration).Slowf("[SQL] query: slowcall - %s", stmt)
|
||||
} else {
|
||||
|
||||
@@ -16,7 +16,6 @@ func TestStmt_exec(t *testing.T) {
|
||||
name string
|
||||
args []interface{}
|
||||
delay bool
|
||||
formatError bool
|
||||
hasError bool
|
||||
err error
|
||||
lastInsertId int64
|
||||
@@ -28,12 +27,6 @@ func TestStmt_exec(t *testing.T) {
|
||||
lastInsertId: 1,
|
||||
rowsAffected: 2,
|
||||
},
|
||||
{
|
||||
name: "wrong format",
|
||||
args: []interface{}{1, 2},
|
||||
formatError: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "exec error",
|
||||
args: []interface{}{1},
|
||||
@@ -70,18 +63,13 @@ func TestStmt_exec(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for i, fn := range fns {
|
||||
i := i
|
||||
for _, fn := range fns {
|
||||
fn := fn
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
res, err := fn(test.args...)
|
||||
if i == 0 && test.formatError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
if !test.formatError && test.hasError {
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
@@ -100,23 +88,16 @@ func TestStmt_exec(t *testing.T) {
|
||||
|
||||
func TestStmt_query(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []interface{}
|
||||
delay bool
|
||||
formatError bool
|
||||
hasError bool
|
||||
err error
|
||||
name string
|
||||
args []interface{}
|
||||
delay bool
|
||||
hasError bool
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "wrong format",
|
||||
args: []interface{}{1, 2},
|
||||
formatError: true,
|
||||
hasError: true,
|
||||
},
|
||||
{
|
||||
name: "query error",
|
||||
args: []interface{}{1},
|
||||
@@ -151,18 +132,13 @@ func TestStmt_query(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
for i, fn := range fns {
|
||||
i := i
|
||||
for _, fn := range fns {
|
||||
fn := fn
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
err := fn(test.args...)
|
||||
if i == 0 && test.formatError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
if !test.formatError && test.hasError {
|
||||
if test.hasError {
|
||||
assert.NotNil(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -45,6 +45,24 @@ func escape(input string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func formatForPrint(query string, args ...interface{}) string {
|
||||
if len(args) == 0 {
|
||||
return query
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, arg := range args {
|
||||
vals = append(vals, fmt.Sprintf("%q", mapping.Repr(arg)))
|
||||
}
|
||||
|
||||
var b strings.Builder
|
||||
b.WriteByte('[')
|
||||
b.WriteString(strings.Join(vals, ", "))
|
||||
b.WriteByte(']')
|
||||
|
||||
return strings.Join([]string{query, b.String()}, " ")
|
||||
}
|
||||
|
||||
func format(query string, args ...interface{}) (string, error) {
|
||||
numArgs := len(args)
|
||||
if numArgs == 0 {
|
||||
|
||||
@@ -28,3 +28,31 @@ func TestDesensitize_WithoutAccount(t *testing.T) {
|
||||
datasource = desensitize(datasource)
|
||||
assert.True(t, strings.Contains(datasource, "tcp(111.222.333.44:3306)"))
|
||||
}
|
||||
|
||||
func TestFormatForPrint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
args []interface{}
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "no args",
|
||||
query: "select user, name from table where id=?",
|
||||
expect: `select user, name from table where id=?`,
|
||||
},
|
||||
{
|
||||
name: "one arg",
|
||||
query: "select user, name from table where id=?",
|
||||
args: []interface{}{"kevin"},
|
||||
expect: `select user, name from table where id=? ["kevin"]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
actual := formatForPrint(test.query, test.args...)
|
||||
assert.Equal(t, test.expect, actual)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ GO111MODULE=on GOPROXY=https://goproxy.cn/,direct go get -u github.com/tal-tech/
|
||||
|
||||
* API 文档
|
||||
|
||||
[https://www.yuque.com/tal-tech/go-zero](https://www.yuque.com/tal-tech/go-zero)
|
||||
[https://zeromicro.github.io/go-zero](https://zeromicro.github.io/go-zero)
|
||||
|
||||
* awesome 系列(更多文章见『微服务实践』公众号)
|
||||
* [快速构建高并发微服务](https://github.com/tal-tech/zero-doc/blob/main/doc/shorturl.md)
|
||||
|
||||
@@ -210,6 +210,12 @@ go get -u github.com/tal-tech/go-zero
|
||||
* [Rapid development of microservice systems - multiple RPCs](https://github.com/tal-tech/zero-doc/blob/main/docs/zero/bookstore-en.md)
|
||||
* [Examples](https://github.com/zeromicro/zero-examples)
|
||||
|
||||
## 9. Chat group
|
||||
## 9. Important notes
|
||||
|
||||
* Use grpc 1.29.1, because etcd lib doesn’t support latter versions.
|
||||
|
||||
`google.golang.org/grpc v1.29.1`
|
||||
|
||||
## 10. Chat group
|
||||
|
||||
Join the chat via https://join.slack.com/t/go-zeroworkspace/shared_invite/zt-m39xssxc-kgIqERa7aVsujKNj~XuPKg
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
|
||||
@@ -138,6 +140,16 @@ func (grw *guardedResponseWriter) Header() http.Header {
|
||||
return grw.writer.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (grw *guardedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := grw.writer.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (grw *guardedResponseWriter) Write(body []byte) (int, error) {
|
||||
return grw.writer.Write(body)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -87,6 +89,26 @@ func TestAuthHandler_NilError(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthHandler_Flush(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
handler := newGuardedResponseWriter(resp)
|
||||
handler.Flush()
|
||||
assert.True(t, resp.Flushed)
|
||||
}
|
||||
|
||||
func TestAuthHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := newGuardedResponseWriter(resp)
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = newGuardedResponseWriter(mockedHijackable{resp})
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func buildToken(secretKey string, payloads map[string]interface{}, seconds int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
claims := make(jwt.MapClaims)
|
||||
@@ -101,3 +123,11 @@ func buildToken(secretKey string, payloads map[string]interface{}, seconds int64
|
||||
|
||||
return token.SignedString([]byte(secretKey))
|
||||
}
|
||||
|
||||
type mockedHijackable struct {
|
||||
*httptest.ResponseRecorder
|
||||
}
|
||||
|
||||
func (m mockedHijackable) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/codec"
|
||||
@@ -94,6 +96,16 @@ func (w *cryptionResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *cryptionResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (w *cryptionResponseWriter) Write(p []byte) (int, error) {
|
||||
return w.buf.Write(p)
|
||||
}
|
||||
|
||||
@@ -103,3 +103,16 @@ func TestCryptionHandlerFlush(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, base64.StdEncoding.EncodeToString(expect), recorder.Body.String())
|
||||
}
|
||||
|
||||
func TestCryptionHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := newCryptionResponseWriter(resp)
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = newCryptionResponseWriter(mockedHijackable{resp})
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"time"
|
||||
@@ -25,10 +28,26 @@ type loggedResponseWriter struct {
|
||||
code int
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *loggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := w.w.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.w.Write(bytes)
|
||||
}
|
||||
@@ -38,12 +57,6 @@ func (w *loggedResponseWriter) WriteHeader(code int) {
|
||||
w.code = code
|
||||
}
|
||||
|
||||
func (w *loggedResponseWriter) Flush() {
|
||||
if flusher, ok := w.w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// LogHandler returns a middleware that logs http request and response.
|
||||
func LogHandler(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -83,6 +96,16 @@ func (w *detailLoggedResponseWriter) Header() http.Header {
|
||||
return w.writer.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *detailLoggedResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacked, ok := w.writer.w.(http.Hijacker); ok {
|
||||
return hijacked.Hijack()
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("server doesn't support hijacking")
|
||||
}
|
||||
|
||||
func (w *detailLoggedResponseWriter) Write(bs []byte) (int, error) {
|
||||
w.buf.Write(bs)
|
||||
return w.writer.Write(bs)
|
||||
|
||||
@@ -62,6 +62,44 @@ func TestLogHandlerSlow(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := &loggedResponseWriter{
|
||||
w: resp,
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &loggedResponseWriter{
|
||||
w: mockedHijackable{resp},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func TestDetailedLogHandler_Hijack(t *testing.T) {
|
||||
resp := httptest.NewRecorder()
|
||||
writer := &detailLoggedResponseWriter{
|
||||
writer: &loggedResponseWriter{
|
||||
w: resp,
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
|
||||
writer = &detailLoggedResponseWriter{
|
||||
writer: &loggedResponseWriter{
|
||||
w: mockedHijackable{resp},
|
||||
},
|
||||
}
|
||||
assert.NotPanics(t, func() {
|
||||
writer.Hijack()
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkLogHandler(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package security
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// A WithCodeResponseWriter is a helper to delay sealing a http.ResponseWriter on writing code.
|
||||
type WithCodeResponseWriter struct {
|
||||
@@ -20,6 +24,12 @@ func (w *WithCodeResponseWriter) Header() http.Header {
|
||||
return w.Writer.Header()
|
||||
}
|
||||
|
||||
// Hijack implements the http.Hijacker interface.
|
||||
// This expands the Response to fulfill http.Hijacker if the underlying http.ResponseWriter supports it.
|
||||
func (w *WithCodeResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
return w.Writer.(http.Hijacker).Hijack()
|
||||
}
|
||||
|
||||
// Write writes bytes into w.
|
||||
func (w *WithCodeResponseWriter) Write(bytes []byte) (int, error) {
|
||||
return w.Writer.Write(bytes)
|
||||
|
||||
@@ -77,6 +77,8 @@ func (e *Server) AddRoute(r Route, opts ...RouteOption) {
|
||||
}
|
||||
|
||||
// Start starts the Server.
|
||||
// Graceful shutdown is enabled by default.
|
||||
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||
func (e *Server) Start() {
|
||||
handleError(e.opts.start(e.ngin))
|
||||
}
|
||||
|
||||
@@ -103,17 +103,9 @@ func genComponents(dir, packetName string, api *spec.ApiSpec) error {
|
||||
}
|
||||
|
||||
func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type) error {
|
||||
defineStruct, ok := ty.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return errors.New("unsupported type %s" + ty.Name())
|
||||
}
|
||||
|
||||
for _, item := range c.requestTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
defineStruct, done, err := c.checkStruct(ty)
|
||||
if done {
|
||||
return err
|
||||
}
|
||||
|
||||
modelFile := util.Title(ty.Name()) + ".java"
|
||||
@@ -181,6 +173,22 @@ func (c *componentsContext) createComponent(dir, packetName string, ty spec.Type
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *componentsContext) checkStruct(ty spec.Type) (spec.DefineStruct, bool, error) {
|
||||
defineStruct, ok := ty.(spec.DefineStruct)
|
||||
if !ok {
|
||||
return spec.DefineStruct{}, true, errors.New("unsupported type %s" + ty.Name())
|
||||
}
|
||||
|
||||
for _, item := range c.requestTypes {
|
||||
if item.Name() == defineStruct.Name() {
|
||||
if len(defineStruct.GetFormMembers())+len(defineStruct.GetBodyMembers()) == 0 {
|
||||
return spec.DefineStruct{}, true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return defineStruct, false, nil
|
||||
}
|
||||
|
||||
func (c *componentsContext) buildProperties(defineStruct spec.DefineStruct) (string, error) {
|
||||
var builder strings.Builder
|
||||
if err := c.writeType(&builder, defineStruct); err != nil {
|
||||
|
||||
@@ -95,17 +95,9 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
switch valueType {
|
||||
case "int":
|
||||
return "Integer[]", nil
|
||||
case "long":
|
||||
return "Long[]", nil
|
||||
case "float":
|
||||
return "Float[]", nil
|
||||
case "double":
|
||||
return "Double[]", nil
|
||||
case "boolean":
|
||||
return "Boolean[]", nil
|
||||
s := getBaseType(valueType)
|
||||
if len(s) == 0 {
|
||||
return s, errors.New("unsupported primitive type " + tp.Name())
|
||||
}
|
||||
|
||||
return fmt.Sprintf("java.util.ArrayList<%s>", util.Title(valueType)), nil
|
||||
@@ -118,6 +110,23 @@ func specTypeToJava(tp spec.Type) (string, error) {
|
||||
return "", errors.New("unsupported primitive type " + tp.Name())
|
||||
}
|
||||
|
||||
func getBaseType(valueType string) string {
|
||||
switch valueType {
|
||||
case "int":
|
||||
return "Integer[]"
|
||||
case "long":
|
||||
return "Long[]"
|
||||
case "float":
|
||||
return "Float[]"
|
||||
case "double":
|
||||
return "Double[]"
|
||||
case "boolean":
|
||||
return "Boolean[]"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func primitiveType(tp string) (string, bool) {
|
||||
switch tp {
|
||||
case "string":
|
||||
|
||||
@@ -33,7 +33,7 @@ func (v *ApiVisitor) VisitApi(ctx *api.ApiContext) interface{} {
|
||||
for _, each := range ctx.AllSpec() {
|
||||
root := each.Accept(v).(*Api)
|
||||
v.acceptSyntax(root, &final)
|
||||
v.accetpImport(root, &final)
|
||||
v.acceptImport(root, &final)
|
||||
v.acceptInfo(root, &final)
|
||||
v.acceptType(root, &final)
|
||||
v.acceptService(root, &final)
|
||||
@@ -133,7 +133,7 @@ func (v *ApiVisitor) acceptInfo(root *Api, final *Api) {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *ApiVisitor) accetpImport(root *Api, final *Api) {
|
||||
func (v *ApiVisitor) acceptImport(root *Api, final *Api) {
|
||||
for _, imp := range root.Import {
|
||||
if _, ok := final.importM[imp.Value.Text()]; ok {
|
||||
v.panic(imp.Import, fmt.Sprintf("duplicate import '%s'", imp.Value.Text()))
|
||||
|
||||
@@ -50,7 +50,7 @@ type AtDoc struct {
|
||||
Kv []*KvExpr
|
||||
}
|
||||
|
||||
// AtHandler describes service hander ast for api syntax
|
||||
// AtHandler describes service handler ast for api syntax
|
||||
type AtHandler struct {
|
||||
AtHandlerToken Expr
|
||||
Name Expr
|
||||
@@ -630,7 +630,7 @@ func (s *Service) Equal(v interface{}) bool {
|
||||
return s.ServiceApi.Equal(service.ServiceApi)
|
||||
}
|
||||
|
||||
// Get returns the tergate KV by specified key
|
||||
// Get returns the target KV by specified key
|
||||
func (kv KV) Get(key string) Expr {
|
||||
for _, each := range kv {
|
||||
if each.Key.Text() == key {
|
||||
|
||||
@@ -17,7 +17,7 @@ type (
|
||||
NameExpr() Expr
|
||||
}
|
||||
|
||||
// TypeAlias describes alias ast for api syatax
|
||||
// TypeAlias describes alias ast for api syntax
|
||||
TypeAlias struct {
|
||||
Name Expr
|
||||
Assign Expr
|
||||
@@ -26,7 +26,7 @@ type (
|
||||
CommentExpr Expr
|
||||
}
|
||||
|
||||
// TypeStruct describes structure ast for api syatax
|
||||
// TypeStruct describes structure ast for api syntax
|
||||
TypeStruct struct {
|
||||
Name Expr
|
||||
Struct Expr
|
||||
@@ -225,7 +225,7 @@ func (v *ApiVisitor) VisitTypeBlockAlias(ctx *api.TypeBlockAliasContext) interfa
|
||||
alias.DocExpr = v.getDoc(ctx)
|
||||
alias.CommentExpr = v.getComment(ctx)
|
||||
// todo: reopen if necessary
|
||||
v.panic(alias.Name, "unsupport alias")
|
||||
v.panic(alias.Name, "unsupported alias")
|
||||
return &alias
|
||||
}
|
||||
|
||||
@@ -238,7 +238,7 @@ func (v *ApiVisitor) VisitTypeAlias(ctx *api.TypeAliasContext) interface{} {
|
||||
alias.DocExpr = v.getDoc(ctx)
|
||||
alias.CommentExpr = v.getComment(ctx)
|
||||
// todo: reopen if necessary
|
||||
v.panic(alias.Name, "unsupport alias")
|
||||
v.panic(alias.Name, "unsupported alias")
|
||||
return &alias
|
||||
}
|
||||
|
||||
@@ -319,7 +319,7 @@ func (v *ApiVisitor) VisitDataType(ctx *api.DataTypeContext) interface{} {
|
||||
if ctx.GetTime() != nil {
|
||||
// todo: reopen if it is necessary
|
||||
timeExpr := v.newExprWithToken(ctx.GetTime())
|
||||
v.panic(timeExpr, "unsupport time.Time")
|
||||
v.panic(timeExpr, "unsupported time.Time")
|
||||
return &Time{Literal: timeExpr}
|
||||
}
|
||||
if ctx.PointerType() != nil {
|
||||
|
||||
@@ -219,9 +219,9 @@ func (p parser) fillService() error {
|
||||
|
||||
for _, astRoute := range item.ServiceApi.ServiceRoute {
|
||||
route := spec.Route{
|
||||
Annotation: spec.Annotation{},
|
||||
Method: astRoute.Route.Method.Text(),
|
||||
Path: astRoute.Route.Path.Text(),
|
||||
AtServerAnnotation: spec.Annotation{},
|
||||
Method: astRoute.Route.Method.Text(),
|
||||
Path: astRoute.Route.Path.Text(),
|
||||
}
|
||||
if astRoute.AtHandler != nil {
|
||||
route.Handler = astRoute.AtHandler.Name.Text()
|
||||
@@ -275,7 +275,7 @@ func (p parser) fillRouteAtServer(astRoute *ast.ServiceRoute, route *spec.Route)
|
||||
for _, kv := range astRoute.AtServer.Kv {
|
||||
properties[kv.Key.Text()] = kv.Value.Text()
|
||||
}
|
||||
route.Annotation.Properties = properties
|
||||
route.AtServerAnnotation.Properties = properties
|
||||
if len(route.Handler) == 0 {
|
||||
route.Handler = properties["handler"]
|
||||
}
|
||||
|
||||
@@ -11,10 +11,11 @@ import (
|
||||
const (
|
||||
bodyTagKey = "json"
|
||||
formTagKey = "form"
|
||||
pathTagKey = "path"
|
||||
defaultSummaryKey = "summary"
|
||||
)
|
||||
|
||||
var definedKeys = []string{bodyTagKey, formTagKey, "path"}
|
||||
var definedKeys = []string{bodyTagKey, formTagKey, pathTagKey}
|
||||
|
||||
// Routes returns all routes in api service
|
||||
func (s Service) Routes() []Route {
|
||||
@@ -25,7 +26,7 @@ func (s Service) Routes() []Route {
|
||||
return result
|
||||
}
|
||||
|
||||
// Tags retuens all tags in Member
|
||||
// Tags returns all tags in Member
|
||||
func (m Member) Tags() []*Tag {
|
||||
tags, err := Parse(m.Tag)
|
||||
if err != nil {
|
||||
@@ -141,7 +142,7 @@ func (t DefineStruct) GetFormMembers() []Member {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetNonBodyMembers retruns all have no tag fields
|
||||
// GetNonBodyMembers returns all have no tag fields
|
||||
func (t DefineStruct) GetNonBodyMembers() []Member {
|
||||
var result []Member
|
||||
for _, member := range t.Members {
|
||||
@@ -162,16 +163,16 @@ func (r Route) JoinedDoc() string {
|
||||
return strings.TrimSpace(doc)
|
||||
}
|
||||
|
||||
// GetAnnotation returns the value by specified key
|
||||
// GetAnnotation returns the value by specified key from @server
|
||||
func (r Route) GetAnnotation(key string) string {
|
||||
if r.Annotation.Properties == nil {
|
||||
if r.AtServerAnnotation.Properties == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return r.Annotation.Properties[key]
|
||||
return r.AtServerAnnotation.Properties[key]
|
||||
}
|
||||
|
||||
// GetAnnotation returns the value by specified key
|
||||
// GetAnnotation returns the value by specified key from @server
|
||||
func (g Group) GetAnnotation(key string) string {
|
||||
if g.Annotation.Properties == nil {
|
||||
return ""
|
||||
|
||||
@@ -63,14 +63,14 @@ type (
|
||||
|
||||
// Route describes api route
|
||||
Route struct {
|
||||
Annotation Annotation
|
||||
Method string
|
||||
Path string
|
||||
RequestType Type
|
||||
ResponseType Type
|
||||
Docs Doc
|
||||
Handler string
|
||||
AtDoc AtDoc
|
||||
AtServerAnnotation Annotation
|
||||
Method string
|
||||
Path string
|
||||
RequestType Type
|
||||
ResponseType Type
|
||||
Docs Doc
|
||||
Handler string
|
||||
AtDoc AtDoc
|
||||
}
|
||||
|
||||
// Service describes api service
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
// TsCommand provides the entry to generting typescript codes
|
||||
// TsCommand provides the entry to generate typescript codes
|
||||
func TsCommand(c *cli.Context) error {
|
||||
apiFile := c.String("api")
|
||||
dir := c.String("dir")
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/tal-tech/go-zero/tools/goctl/configgen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/kube"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/mongo"
|
||||
model "github.com/tal-tech/go-zero/tools/goctl/model/sql/command"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/plugin"
|
||||
rpc "github.com/tal-tech/go-zero/tools/goctl/rpc/cli"
|
||||
@@ -28,7 +29,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
buildVersion = "1.1.5"
|
||||
buildVersion = "1.1.6"
|
||||
commands = []cli.Command{
|
||||
{
|
||||
Name: "upgrade",
|
||||
@@ -447,6 +448,29 @@ var (
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "mongo",
|
||||
Usage: `generate mongo model`,
|
||||
Flags: []cli.Flag{
|
||||
cli.StringSliceFlag{
|
||||
Name: "type, t",
|
||||
Usage: "specified model type name",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "cache, c",
|
||||
Usage: "generate code with cache [optional]",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "dir, d",
|
||||
Usage: "the target dir",
|
||||
},
|
||||
cli.StringFlag{
|
||||
Name: "style",
|
||||
Usage: "the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]",
|
||||
},
|
||||
},
|
||||
Action: mongo.Action,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
69
tools/goctl/model/mongo/generate/generate.go
Normal file
69
tools/goctl/model/mongo/generate/generate.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package generate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/format"
|
||||
)
|
||||
|
||||
// Context defines the model generation data what they needs
|
||||
type Context struct {
|
||||
Types []string
|
||||
Cache bool
|
||||
Output string
|
||||
Cfg *config.Config
|
||||
}
|
||||
|
||||
// Do executes model template and output the result into the specified file path
|
||||
func Do(ctx *Context) error {
|
||||
if ctx.Cfg == nil {
|
||||
return errors.New("missing config")
|
||||
}
|
||||
|
||||
err := generateModel(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return generateError(ctx)
|
||||
}
|
||||
|
||||
func generateModel(ctx *Context) error {
|
||||
for _, t := range ctx.Types {
|
||||
fn, err := format.FileNamingFormat(ctx.Cfg.NamingFormat, t+"_model")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
text, err := util.LoadTemplate(category, modelTemplateFile, template.Text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output := filepath.Join(ctx.Output, fn+".go")
|
||||
err = util.With("model").Parse(text).GoFmt(true).SaveTo(map[string]interface{}{
|
||||
"Type": t,
|
||||
"Cache": ctx.Cache,
|
||||
}, output, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateError(ctx *Context) error {
|
||||
text, err := util.LoadTemplate(category, errTemplateFile, template.Error)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
output := filepath.Join(ctx.Output, "error.go")
|
||||
|
||||
return util.With("error").Parse(text).GoFmt(true).SaveTo(ctx, output, false)
|
||||
}
|
||||
34
tools/goctl/model/mongo/generate/generate_test.go
Normal file
34
tools/goctl/model/mongo/generate/generate_test.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package generate
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||
)
|
||||
|
||||
var testTypes = `
|
||||
type User struct{}
|
||||
type Class struct{}
|
||||
`
|
||||
|
||||
func TestDo(t *testing.T) {
|
||||
cfg, err := config.NewConfig(config.DefaultFormat)
|
||||
assert.Nil(t, err)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
typesfile := filepath.Join(tempDir, "types.go")
|
||||
err = ioutil.WriteFile(typesfile, []byte(testTypes), 0666)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = Do(&Context{
|
||||
Types: []string{"User", "Class"},
|
||||
Cache: false,
|
||||
Output: tempDir,
|
||||
Cfg: cfg,
|
||||
})
|
||||
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
55
tools/goctl/model/mongo/generate/template.go
Normal file
55
tools/goctl/model/mongo/generate/template.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package generate
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
const (
|
||||
category = "mongo"
|
||||
modelTemplateFile = "model.tpl"
|
||||
errTemplateFile = "err.tpl"
|
||||
)
|
||||
|
||||
var templates = map[string]string{
|
||||
modelTemplateFile: template.Text,
|
||||
errTemplateFile: template.Error,
|
||||
}
|
||||
|
||||
// Category returns the mongo category.
|
||||
func Category() string {
|
||||
return category
|
||||
}
|
||||
|
||||
// Clean cleans the mongo templates.
|
||||
func Clean() error {
|
||||
return util.Clean(category)
|
||||
}
|
||||
|
||||
// Templates initializes the mongo templates.
|
||||
func Templates(_ *cli.Context) error {
|
||||
return util.InitTemplates(category, templates)
|
||||
}
|
||||
|
||||
// RevertTemplate reverts the given template.
|
||||
func RevertTemplate(name string) error {
|
||||
content, ok := templates[name]
|
||||
if !ok {
|
||||
return fmt.Errorf("%s: no such file name", name)
|
||||
}
|
||||
|
||||
return util.CreateTemplate(category, name, content)
|
||||
}
|
||||
|
||||
// Update cleans and updates the templates.
|
||||
func Update() error {
|
||||
err := Clean()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return util.InitTemplates(category, templates)
|
||||
}
|
||||
39
tools/goctl/model/mongo/mongo.go
Normal file
39
tools/goctl/model/mongo/mongo.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/mongo/generate"
|
||||
"github.com/urfave/cli"
|
||||
)
|
||||
|
||||
// Action provides the entry for goctl mongo code generation.
|
||||
func Action(ctx *cli.Context) error {
|
||||
tp := ctx.StringSlice("type")
|
||||
c := ctx.Bool("cache")
|
||||
o := strings.TrimSpace(ctx.String("dir"))
|
||||
s := ctx.String("style")
|
||||
if len(tp) == 0 {
|
||||
return errors.New("missing type")
|
||||
}
|
||||
|
||||
cfg, err := config.NewConfig(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
a, err := filepath.Abs(o)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return generate.Do(&generate.Context{
|
||||
Types: tp,
|
||||
Cache: c,
|
||||
Output: a,
|
||||
Cfg: cfg,
|
||||
})
|
||||
}
|
||||
210
tools/goctl/model/mongo/readme.md
Normal file
210
tools/goctl/model/mongo/readme.md
Normal file
@@ -0,0 +1,210 @@
|
||||
# mongo生成model
|
||||
|
||||
## 背景
|
||||
|
||||
在业务务开发中,model(dao)数据访问层是一个服务必不可缺的一层,因此数据库访问的CURD也是必须要对外提供的访问方法, 而CURD在go-zero中就仅存在两种情况
|
||||
|
||||
* 带缓存model
|
||||
* 不带缓存model
|
||||
|
||||
从代码结构上来看,C-U-R-D四个方法就是固定的结构,因此我们可以将其交给goctl工具去完成,帮助我们提升开发效率。
|
||||
|
||||
## 方案设计
|
||||
|
||||
mongo的生成不同于mysql,mysql可以从scheme_information库中读取到一张表的信息(字段名称,数据类型,索引等),
|
||||
而mongo是文档型数据库,我们暂时无法从db中读取某一条记录来实现字段信息获取,就算有也不一定是完整信息(某些字段可能是omitempty修饰,可有可无), 这里采用type自己编写+代码生成方式实现
|
||||
|
||||
## 使用示例
|
||||
|
||||
假设我们需要生成一个usermodel.go的代码文件,其包含用户信息字段有
|
||||
|
||||
|字段名称|字段类型|
|
||||
|---|---|
|
||||
|_id|bson.ObejctId|
|
||||
|name|string|
|
||||
|
||||
### 编写types.go
|
||||
|
||||
```shell
|
||||
$ vim types.go
|
||||
```
|
||||
|
||||
```golang
|
||||
package model
|
||||
|
||||
//go:generate goctl model mongo -t User
|
||||
import "github.com/globalsign/mgo/bson"
|
||||
|
||||
type User struct {
|
||||
ID bson.ObjectId `bson:"_id"`
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
```
|
||||
|
||||
### 生成代码
|
||||
|
||||
生成代码的方式有两种
|
||||
|
||||
* 命令行生成 在types.go所在文件夹执行命令
|
||||
```shell
|
||||
$ goctl model mongo -t User -style gozero
|
||||
```
|
||||
* 在types.go中添加`//go:generate`,然后点击执行按钮即可生成,内容示例如下:
|
||||
```golang
|
||||
//go:generate goctl model mongo -t User
|
||||
```
|
||||
|
||||
### 生成示例代码
|
||||
|
||||
* usermodel.go
|
||||
|
||||
```golang
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
cachec "github.com/tal-tech/go-zero/core/stores/cache"
|
||||
"github.com/tal-tech/go-zero/core/stores/mongoc"
|
||||
)
|
||||
|
||||
type UserModel interface {
|
||||
Insert(data *User, ctx context.Context) error
|
||||
FindOne(id string, ctx context.Context) (*User, error)
|
||||
Update(data *User, ctx context.Context) error
|
||||
Delete(id string, ctx context.Context) error
|
||||
}
|
||||
|
||||
type defaultUserModel struct {
|
||||
*mongoc.Model
|
||||
}
|
||||
|
||||
func NewUserModel(url, collection string, c cachec.CacheConf) UserModel {
|
||||
return &defaultUserModel{
|
||||
Model: mongoc.MustNewModel(url, collection, c),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) Insert(data *User, ctx context.Context) error {
|
||||
if !data.ID.Valid() {
|
||||
data.ID = bson.NewObjectId()
|
||||
}
|
||||
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
return m.GetCollection(session).Insert(data)
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) FindOne(id string, ctx context.Context) (*User, error) {
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidObjectId
|
||||
}
|
||||
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
var data User
|
||||
|
||||
err = m.GetCollection(session).FindOneIdNoCache(&data, bson.ObjectIdHex(id))
|
||||
switch err {
|
||||
case nil:
|
||||
return &data, nil
|
||||
case mongoc.ErrNotFound:
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) Update(data *User, ctx context.Context) error {
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
|
||||
return m.GetCollection(session).UpdateIdNoCache(data.ID, data)
|
||||
}
|
||||
|
||||
func (m *defaultUserModel) Delete(id string, ctx context.Context) error {
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
|
||||
return m.GetCollection(session).RemoveIdNoCache(bson.ObjectIdHex(id))
|
||||
}
|
||||
```
|
||||
|
||||
* error.go
|
||||
|
||||
```golang
|
||||
package model
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
var ErrInvalidObjectId = errors.New("invalid objectId")
|
||||
```
|
||||
|
||||
### 文件目录预览
|
||||
|
||||
```text
|
||||
.
|
||||
├── error.go
|
||||
├── types.go
|
||||
└── usermodel.go
|
||||
|
||||
```
|
||||
|
||||
## 命令预览
|
||||
|
||||
```text
|
||||
NAME:
|
||||
goctl model - generate model code
|
||||
|
||||
USAGE:
|
||||
goctl model command [command options] [arguments...]
|
||||
|
||||
COMMANDS:
|
||||
mysql generate mysql model
|
||||
mongo generate mongo model
|
||||
|
||||
OPTIONS:
|
||||
--help, -h show help
|
||||
```
|
||||
|
||||
```text
|
||||
NAME:
|
||||
goctl model mongo - generate mongo model
|
||||
|
||||
USAGE:
|
||||
goctl model mongo [command options] [arguments...]
|
||||
|
||||
OPTIONS:
|
||||
--type value, -t value specified model type name
|
||||
--cache, -c generate code with cache [optional]
|
||||
--dir value, -d value the target dir
|
||||
--style value the file naming format, see [https://github.com/tal-tech/go-zero/tree/master/tools/goctl/config/readme.md]
|
||||
|
||||
```
|
||||
|
||||
> 温馨提示
|
||||
>
|
||||
> `--type` 支持slice传值,示例 `goctl model mongo -t=User -t=Class`
|
||||
## 注意事项
|
||||
|
||||
types.go本质上与xxxmodel.go无关,只是将type定义部分交给开发人员自己编写了,在xxxmodel.go中,mongo文档的存储结构必须包含
|
||||
`_id`字段,对应到types中的field为`ID`,model中的findOne,update均以data.ID来进行操作的,当然,如果不符合你的命名风格,你也 可以修改模板,只要保证`id`
|
||||
在types中的field名称和模板中一致就行。
|
||||
112
tools/goctl/model/mongo/template/template.go
Normal file
112
tools/goctl/model/mongo/template/template.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package template
|
||||
|
||||
// Text provides the default template for model to generate
|
||||
var Text = `package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/globalsign/mgo/bson"
|
||||
cachec "github.com/tal-tech/go-zero/core/stores/cache"
|
||||
"github.com/tal-tech/go-zero/core/stores/mongoc"
|
||||
)
|
||||
|
||||
{{if .Cache}}var prefix{{.Type}}CacheKey = "cache#{{.Type}}#"{{end}}
|
||||
|
||||
type {{.Type}}Model interface{
|
||||
Insert(ctx context.Context,data *{{.Type}}) error
|
||||
FindOne(ctx context.Context,id string) (*{{.Type}}, error)
|
||||
Update(ctx context.Context,data *{{.Type}}) error
|
||||
Delete(ctx context.Context,id string) error
|
||||
}
|
||||
|
||||
type default{{.Type}}Model struct {
|
||||
*mongoc.Model
|
||||
}
|
||||
|
||||
func New{{.Type}}Model(url, collection string, c cachec.CacheConf) {{.Type}}Model {
|
||||
return &default{{.Type}}Model{
|
||||
Model: mongoc.MustNewModel(url, collection, c),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func (m *default{{.Type}}Model) Insert(ctx context.Context, data *{{.Type}}) error {
|
||||
if !data.ID.Valid() {
|
||||
data.ID = bson.NewObjectId()
|
||||
}
|
||||
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
return m.GetCollection(session).Insert(data)
|
||||
}
|
||||
|
||||
func (m *default{{.Type}}Model) FindOne(ctx context.Context, id string) (*{{.Type}}, error) {
|
||||
if !bson.IsObjectIdHex(id) {
|
||||
return nil, ErrInvalidObjectId
|
||||
}
|
||||
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
var data {{.Type}}
|
||||
{{if .Cache}}key := prefix{{.Type}}CacheKey + id
|
||||
err = m.GetCollection(session).FindOneId(&data, key, bson.ObjectIdHex(id))
|
||||
{{- else}}
|
||||
err = m.GetCollection(session).FindOneIdNoCache(&data, bson.ObjectIdHex(id))
|
||||
{{- end}}
|
||||
switch err {
|
||||
case nil:
|
||||
return &data,nil
|
||||
case mongoc.ErrNotFound:
|
||||
return nil,ErrNotFound
|
||||
default:
|
||||
return nil,err
|
||||
}
|
||||
}
|
||||
|
||||
func (m *default{{.Type}}Model) Update(ctx context.Context, data *{{.Type}}) error {
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
{{if .Cache}}key := prefix{{.Type}}CacheKey + data.ID.Hex()
|
||||
return m.GetCollection(session).UpdateId(data.ID, data, key)
|
||||
{{- else}}
|
||||
return m.GetCollection(session).UpdateIdNoCache(data.ID, data)
|
||||
{{- end}}
|
||||
}
|
||||
|
||||
func (m *default{{.Type}}Model) Delete(ctx context.Context, id string) error {
|
||||
session, err := m.TakeSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer m.PutSession(session)
|
||||
{{if .Cache}}key := prefix{{.Type}}CacheKey + id
|
||||
return m.GetCollection(session).RemoveId(bson.ObjectIdHex(id), key)
|
||||
{{- else}}
|
||||
return m.GetCollection(session).RemoveIdNoCache(bson.ObjectIdHex(id))
|
||||
{{- end}}
|
||||
}
|
||||
`
|
||||
|
||||
// Error provides the default template for error definition in mongo code generation.
|
||||
var Error = `
|
||||
package model
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrNotFound = errors.New("not found")
|
||||
var ErrInvalidObjectId = errors.New("invalid objectId")
|
||||
`
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/config"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
@@ -19,7 +21,10 @@ var (
|
||||
)
|
||||
|
||||
func TestFromDDl(t *testing.T) {
|
||||
err := fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
||||
err := gen.Clean()
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = fromDDl("./user.sql", t.TempDir(), cfg, true, false)
|
||||
assert.Equal(t, errNotMatched, err)
|
||||
|
||||
// case dir is not exists
|
||||
|
||||
@@ -25,27 +25,7 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
||||
var list []string
|
||||
camelTableName := table.Name.ToCamel()
|
||||
for _, key := range table.UniqueCacheKey {
|
||||
var inJoin, paramJoin, argJoin Join
|
||||
for _, f := range key.Fields {
|
||||
param := stringx.From(f.Name.ToCamel()).Untitle()
|
||||
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
||||
paramJoin = append(paramJoin, param)
|
||||
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
||||
}
|
||||
var in string
|
||||
if len(inJoin) > 0 {
|
||||
in = inJoin.With(", ").Source()
|
||||
}
|
||||
|
||||
var paramJoinString string
|
||||
if len(paramJoin) > 0 {
|
||||
paramJoinString = paramJoin.With(",").Source()
|
||||
}
|
||||
|
||||
var originalFieldString string
|
||||
if len(argJoin) > 0 {
|
||||
originalFieldString = argJoin.With(" and ").Source()
|
||||
}
|
||||
in, paramJoinString, originalFieldString := convertJoin(key)
|
||||
|
||||
output, err := t.Execute(map[string]interface{}{
|
||||
"upperStartCamelObject": camelTableName,
|
||||
@@ -125,3 +105,25 @@ func genFindOneByField(table Table, withCache bool) (*findOneCode, error) {
|
||||
findOneInterfaceMethod: strings.Join(listMethod, util.NL),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertJoin(key Key) (in, paramJoinString, originalFieldString string) {
|
||||
var inJoin, paramJoin, argJoin Join
|
||||
for _, f := range key.Fields {
|
||||
param := stringx.From(f.Name.ToCamel()).Untitle()
|
||||
inJoin = append(inJoin, fmt.Sprintf("%s %s", param, f.DataType))
|
||||
paramJoin = append(paramJoin, param)
|
||||
argJoin = append(argJoin, fmt.Sprintf("%s = ?", wrapWithRawString(f.Name.Source())))
|
||||
}
|
||||
if len(inJoin) > 0 {
|
||||
in = inJoin.With(", ").Source()
|
||||
}
|
||||
|
||||
if len(paramJoin) > 0 {
|
||||
paramJoinString = paramJoin.With(",").Source()
|
||||
}
|
||||
|
||||
if len(argJoin) > 0 {
|
||||
originalFieldString = argJoin.With(" and ").Source()
|
||||
}
|
||||
return in, paramJoinString, originalFieldString
|
||||
}
|
||||
|
||||
@@ -11,15 +11,15 @@ import (
|
||||
|
||||
// Key describes cache key
|
||||
type Key struct {
|
||||
// VarLeft describes the varible of cache key expression which likes cacheUserIdPrefix
|
||||
// VarLeft describes the variable of cache key expression which likes cacheUserIdPrefix
|
||||
VarLeft string
|
||||
// VarRight describes the value of cache key expression which likes "cache#user#id#"
|
||||
VarRight string
|
||||
// VarExpression describes the cache key expression which likes cacheUserIdPrefix = "cache#user#id#"
|
||||
VarExpression string
|
||||
// KeyLeft describes the varible of key definiation expression which likes userKey
|
||||
// KeyLeft describes the variable of key definition expression which likes userKey
|
||||
KeyLeft string
|
||||
// KeyRight describes the value of key definiation expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user)
|
||||
// KeyRight describes the value of key definition expression which likes fmt.Sprintf("%s%v", cacheUserPrefix, user)
|
||||
KeyRight string
|
||||
// DataKeyRight describes data key likes fmt.Sprintf("%s%v", cacheUserPrefix, data.User)
|
||||
DataKeyRight string
|
||||
|
||||
@@ -3,6 +3,7 @@ package gen
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/collection"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/model/sql/template"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util/stringx"
|
||||
@@ -23,6 +24,15 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
|
||||
expressionValues = append(expressionValues, "data."+camel)
|
||||
}
|
||||
|
||||
keySet := collection.NewSet()
|
||||
keyVariableSet := collection.NewSet()
|
||||
keySet.AddStr(table.PrimaryCacheKey.DataKeyExpression)
|
||||
keyVariableSet.AddStr(table.PrimaryCacheKey.KeyLeft)
|
||||
for _, key := range table.UniqueCacheKey {
|
||||
keySet.AddStr(key.DataKeyExpression)
|
||||
keyVariableSet.AddStr(key.KeyLeft)
|
||||
}
|
||||
|
||||
expressionValues = append(expressionValues, "data."+table.PrimaryKey.Name.ToCamel())
|
||||
camelTableName := table.Name.ToCamel()
|
||||
text, err := util.LoadTemplate(category, updateTemplateFile, template.Update)
|
||||
@@ -35,6 +45,8 @@ func genUpdate(table Table, withCache bool) (string, string, error) {
|
||||
Execute(map[string]interface{}{
|
||||
"withCache": withCache,
|
||||
"upperStartCamelObject": camelTableName,
|
||||
"keys": strings.Join(keySet.KeysStr(), "\n"),
|
||||
"keyValues": strings.Join(keyVariableSet.KeysStr(), ", "),
|
||||
"primaryCacheKey": table.PrimaryCacheKey.DataKeyExpression,
|
||||
"primaryKeyVariable": table.PrimaryCacheKey.KeyLeft,
|
||||
"lowerStartCamelObject": stringx.From(camelTableName).Untitle(),
|
||||
|
||||
@@ -102,6 +102,17 @@ func Parse(ddl string) (*Table, error) {
|
||||
}
|
||||
}
|
||||
|
||||
checkDuplicateUniqueIndex(uniqueIndex, tableName, normalIndex)
|
||||
return &Table{
|
||||
Name: stringx.From(tableName),
|
||||
PrimaryKey: primaryKey,
|
||||
UniqueIndex: uniqueIndex,
|
||||
NormalIndex: normalIndex,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func checkDuplicateUniqueIndex(uniqueIndex map[string][]*Field, tableName string, normalIndex map[string][]*Field) {
|
||||
log := console.NewColorConsole()
|
||||
uniqueSet := collection.NewSet()
|
||||
for k, i := range uniqueIndex {
|
||||
@@ -136,14 +147,6 @@ func Parse(ddl string) (*Table, error) {
|
||||
|
||||
normalIndexSet.Add(joinRet)
|
||||
}
|
||||
|
||||
return &Table{
|
||||
Name: stringx.From(tableName),
|
||||
PrimaryKey: primaryKey,
|
||||
UniqueIndex: uniqueIndex,
|
||||
NormalIndex: normalIndex,
|
||||
Fields: fields,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func convertColumns(columns []*sqlparser.ColumnDefinition, primaryColumn string) (Primary, map[string]*Field, error) {
|
||||
@@ -289,27 +292,9 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
AutoIncrement: strings.Contains(table.PrimaryKey.Extra, "auto_increment"),
|
||||
}
|
||||
|
||||
fieldM := make(map[string]*Field)
|
||||
for _, each := range table.Columns {
|
||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columnSeqInIndex := 0
|
||||
if each.Index != nil {
|
||||
columnSeqInIndex = each.Index.SeqInIndex
|
||||
}
|
||||
|
||||
field := &Field{
|
||||
Name: stringx.From(each.Name),
|
||||
DataBaseType: each.DataType,
|
||||
DataType: dt,
|
||||
Comment: each.Comment,
|
||||
SeqInIndex: columnSeqInIndex,
|
||||
OrdinalPosition: each.OrdinalPosition,
|
||||
}
|
||||
fieldM[each.Name] = field
|
||||
fieldM, err := getTableFields(table)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, each := range fieldM {
|
||||
@@ -379,3 +364,29 @@ func ConvertDataType(table *model.Table) (*Table, error) {
|
||||
|
||||
return &reply, nil
|
||||
}
|
||||
|
||||
func getTableFields(table *model.Table) (map[string]*Field, error) {
|
||||
fieldM := make(map[string]*Field)
|
||||
for _, each := range table.Columns {
|
||||
isDefaultNull := each.ColumnDefault == nil && each.IsNullAble == "YES"
|
||||
dt, err := converter.ConvertDataType(each.DataType, isDefaultNull)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
columnSeqInIndex := 0
|
||||
if each.Index != nil {
|
||||
columnSeqInIndex = each.Index.SeqInIndex
|
||||
}
|
||||
|
||||
field := &Field{
|
||||
Name: stringx.From(each.Name),
|
||||
DataBaseType: each.DataType,
|
||||
DataType: dt,
|
||||
Comment: each.Comment,
|
||||
SeqInIndex: columnSeqInIndex,
|
||||
OrdinalPosition: each.OrdinalPosition,
|
||||
}
|
||||
fieldM[each.Name] = field
|
||||
}
|
||||
return fieldM, nil
|
||||
}
|
||||
|
||||
@@ -3,11 +3,11 @@ package template
|
||||
// Update defines a template for generating update codes
|
||||
var Update = `
|
||||
func (m *default{{.upperStartCamelObject}}Model) Update(data {{.upperStartCamelObject}}) error {
|
||||
{{if .withCache}}{{.primaryCacheKey}}
|
||||
{{if .withCache}}{{.keys}}
|
||||
_, err := m.Exec(func(conn sqlx.SqlConn) (result sql.Result, err error) {
|
||||
query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
||||
return conn.Exec(query, {{.expressionValues}})
|
||||
}, {{.primaryKeyVariable}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
||||
}, {{.keyValues}}){{else}}query := fmt.Sprintf("update %s set %s where {{.originalPrimaryKey}} = ?", m.table, {{.lowerStartCamelObject}}RowsWithPlaceHolder)
|
||||
_,err:=m.conn.Exec(query, {{.expressionValues}}){{end}}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/tal-tech/go-zero/tools/goctl/api/gogen"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/docker"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/kube"
|
||||
mongogen "github.com/tal-tech/go-zero/tools/goctl/model/mongo/generate"
|
||||
modelgen "github.com/tal-tech/go-zero/tools/goctl/model/sql/gen"
|
||||
rpcgen "github.com/tal-tech/go-zero/tools/goctl/rpc/generator"
|
||||
"github.com/tal-tech/go-zero/tools/goctl/util"
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
|
||||
const templateParentPath = "/"
|
||||
|
||||
// GenTemplates wtites the latest template text into file which is not exists
|
||||
// GenTemplates writes the latest template text into file which is not exists
|
||||
func GenTemplates(ctx *cli.Context) error {
|
||||
if err := errorx.Chain(
|
||||
func() error {
|
||||
@@ -34,6 +35,9 @@ func GenTemplates(ctx *cli.Context) error {
|
||||
func() error {
|
||||
return kube.GenTemplates(ctx)
|
||||
},
|
||||
func() error {
|
||||
return mongogen.Templates(ctx)
|
||||
},
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -61,6 +65,15 @@ func CleanTemplates(_ *cli.Context) error {
|
||||
func() error {
|
||||
return rpcgen.Clean()
|
||||
},
|
||||
func() error {
|
||||
return docker.Clean()
|
||||
},
|
||||
func() error {
|
||||
return kube.Clean()
|
||||
},
|
||||
func() error {
|
||||
return mongogen.Clean()
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -70,8 +83,8 @@ func CleanTemplates(_ *cli.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTemplates wtites the latest template text into file,
|
||||
// it will delete the oldler templates if there are exists
|
||||
// UpdateTemplates writes the latest template text into file,
|
||||
// it will delete the older templates if there are exists
|
||||
func UpdateTemplates(ctx *cli.Context) (err error) {
|
||||
category := ctx.String("category")
|
||||
defer func() {
|
||||
@@ -90,6 +103,8 @@ func UpdateTemplates(ctx *cli.Context) (err error) {
|
||||
return rpcgen.Update()
|
||||
case modelgen.Category():
|
||||
return modelgen.Update()
|
||||
case mongogen.Category():
|
||||
return mongogen.Update()
|
||||
default:
|
||||
err = fmt.Errorf("unexpected category: %s", category)
|
||||
return
|
||||
@@ -116,6 +131,8 @@ func RevertTemplates(ctx *cli.Context) (err error) {
|
||||
return rpcgen.RevertTemplate(filename)
|
||||
case modelgen.Category():
|
||||
return modelgen.RevertTemplate(filename)
|
||||
case mongogen.Category():
|
||||
return mongogen.RevertTemplate(filename)
|
||||
default:
|
||||
err = fmt.Errorf("unexpected category: %s", category)
|
||||
return
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
type (
|
||||
// Console wraps from the fmt.Sprintf,
|
||||
// by default, it implemented the colorConsole to provide the colorful output to the consle
|
||||
// by default, it implemented the colorConsole to provide the colorful output to the console
|
||||
// and the ideaConsole to output with prefix for the plugin of intellij
|
||||
Console interface {
|
||||
Success(format string, a ...interface{})
|
||||
@@ -81,7 +81,7 @@ func (c *colorConsole) Must(err error) {
|
||||
}
|
||||
}
|
||||
|
||||
// NewIdeaConsole returns a instace of ideaConsole
|
||||
// NewIdeaConsole returns a instance of ideaConsole
|
||||
func NewIdeaConsole() Console {
|
||||
return &ideaConsole{}
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func CreateIfNotExist(file string) (*os.File, error) {
|
||||
return os.Create(file)
|
||||
}
|
||||
|
||||
// RemoveIfExist deletes the specficed file if it is exists
|
||||
// RemoveIfExist deletes the specified file if it is exists
|
||||
func RemoveIfExist(filename string) error {
|
||||
if !FileExists(filename) {
|
||||
return nil
|
||||
@@ -36,7 +36,7 @@ func RemoveIfExist(filename string) error {
|
||||
return os.Remove(filename)
|
||||
}
|
||||
|
||||
// RemoveOrQuit deletes the specficed file if read a permit command from stdin
|
||||
// RemoveOrQuit deletes the specified file if read a permit command from stdin
|
||||
func RemoveOrQuit(filename string) error {
|
||||
if !FileExists(filename) {
|
||||
return nil
|
||||
@@ -49,7 +49,7 @@ func RemoveOrQuit(filename string) error {
|
||||
return os.Remove(filename)
|
||||
}
|
||||
|
||||
// FileExists returns true if the specficed file is exists
|
||||
// FileExists returns true if the specified file is exists
|
||||
func FileExists(file string) bool {
|
||||
_, err := os.Stat(file)
|
||||
return err == nil
|
||||
|
||||
@@ -18,7 +18,7 @@ const (
|
||||
upper
|
||||
)
|
||||
|
||||
// ErrNamingFormat defines an error for unknown fomat
|
||||
// ErrNamingFormat defines an error for unknown format
|
||||
var ErrNamingFormat = errors.New("unsupported format")
|
||||
|
||||
type (
|
||||
|
||||
@@ -33,7 +33,7 @@ func MkdirIfNotExist(dir string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PathFromGoSrc returns the path whihout slash where has been trim the prefix $GOPATH
|
||||
// PathFromGoSrc returns the path without slash where has been trim the prefix $GOPATH
|
||||
func PathFromGoSrc() (string, error) {
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// String provides for coverting the source text into other spell case,like lower,snake,camel
|
||||
// String provides for converting the source text into other spell case,like lower,snake,camel
|
||||
type String struct {
|
||||
source string
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ type DefaultTemplate struct {
|
||||
savePath string
|
||||
}
|
||||
|
||||
// With returns a instace of DefaultTemplate
|
||||
// With returns a instance of DefaultTemplate
|
||||
func With(name string) *DefaultTemplate {
|
||||
return &DefaultTemplate{
|
||||
name: name,
|
||||
@@ -30,7 +30,7 @@ func (t *DefaultTemplate) Parse(text string) *DefaultTemplate {
|
||||
return t
|
||||
}
|
||||
|
||||
// GoFmt sets the value to goFmt and marks the generated codes will be formated or not
|
||||
// GoFmt sets the value to goFmt and marks the generated codes will be formatted or not
|
||||
func (t *DefaultTemplate) GoFmt(format bool) *DefaultTemplate {
|
||||
t.goFmt = format
|
||||
return t
|
||||
|
||||
@@ -3,7 +3,7 @@ package vars
|
||||
const (
|
||||
// ProjectName the const value of zero
|
||||
ProjectName = "zero"
|
||||
// ProjectOpenSourceURL the githb url of go-zero
|
||||
// ProjectOpenSourceURL the github url of go-zero
|
||||
ProjectOpenSourceURL = "github.com/tal-tech/go-zero"
|
||||
// OsWindows windows os
|
||||
OsWindows = "windows"
|
||||
|
||||
@@ -18,6 +18,27 @@ func TimeoutInterceptor(timeout time.Duration) grpc.UnaryClientInterceptor {
|
||||
|
||||
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
||||
defer cancel()
|
||||
return invoker(ctx, method, req, reply, cc, opts...)
|
||||
|
||||
// create channel with buffer size 1 to avoid goroutine leak
|
||||
done := make(chan error, 1)
|
||||
panicChan := make(chan interface{}, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicChan <- p
|
||||
}
|
||||
}()
|
||||
|
||||
done <- invoker(ctx, method, req, reply, cc, opts...)
|
||||
}()
|
||||
|
||||
select {
|
||||
case p := <-panicChan:
|
||||
panic(p)
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,3 +48,40 @@ func TestTimeoutInterceptor_timeout(t *testing.T) {
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestTimeoutInterceptor_timeoutExpire(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := TimeoutInterceptor(timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
cc := new(grpc.ClientConn)
|
||||
err := interceptor(ctx, "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
defer wg.Done()
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
return nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
}
|
||||
|
||||
func TestTimeoutInterceptor_panic(t *testing.T) {
|
||||
timeouts := []time.Duration{0, time.Millisecond * 10}
|
||||
for _, timeout := range timeouts {
|
||||
t.Run(strconv.FormatInt(int64(timeout), 10), func(t *testing.T) {
|
||||
interceptor := TimeoutInterceptor(timeout)
|
||||
cc := new(grpc.ClientConn)
|
||||
assert.Panics(t, func() {
|
||||
_ = interceptor(context.Background(), "/foo", nil, nil, cc,
|
||||
func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn,
|
||||
opts ...grpc.CallOption) error {
|
||||
panic("any")
|
||||
},
|
||||
)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package serverinterceptors
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tal-tech/go-zero/core/contextx"
|
||||
@@ -11,9 +12,38 @@ import (
|
||||
// UnaryTimeoutInterceptor returns a func that sets timeout to incoming unary requests.
|
||||
func UnaryTimeoutInterceptor(timeout time.Duration) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
handler grpc.UnaryHandler) (interface{}, error) {
|
||||
ctx, cancel := contextx.ShrinkDeadline(ctx, timeout)
|
||||
defer cancel()
|
||||
return handler(ctx, req)
|
||||
|
||||
var resp interface{}
|
||||
var err error
|
||||
var lock sync.Mutex
|
||||
done := make(chan struct{})
|
||||
// create channel with buffer size 1 to avoid goroutine leak
|
||||
panicChan := make(chan interface{}, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
panicChan <- p
|
||||
}
|
||||
}()
|
||||
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
resp, err = handler(ctx, req)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case p := <-panicChan:
|
||||
panic(p)
|
||||
case <-done:
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
return resp, err
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,6 +20,17 @@ func TestUnaryTimeoutInterceptor(t *testing.T) {
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_panic(t *testing.T) {
|
||||
interceptor := UnaryTimeoutInterceptor(time.Millisecond * 10)
|
||||
assert.Panics(t, func() {
|
||||
_, _ = interceptor(context.Background(), nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
panic("any")
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := UnaryTimeoutInterceptor(timeout)
|
||||
@@ -39,3 +50,21 @@ func TestUnaryTimeoutInterceptor_timeout(t *testing.T) {
|
||||
wg.Wait()
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestUnaryTimeoutInterceptor_timeoutExpire(t *testing.T) {
|
||||
const timeout = time.Millisecond * 10
|
||||
interceptor := UnaryTimeoutInterceptor(timeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
|
||||
defer cancel()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
_, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{
|
||||
FullMethod: "/",
|
||||
}, func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
time.Sleep(time.Millisecond * 50)
|
||||
return nil, nil
|
||||
})
|
||||
wg.Wait()
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
}
|
||||
|
||||
@@ -79,6 +79,8 @@ func (rs *RpcServer) AddUnaryInterceptors(interceptors ...grpc.UnaryServerInterc
|
||||
}
|
||||
|
||||
// Start starts the RpcServer.
|
||||
// Graceful shutdown is enabled by default.
|
||||
// Use proc.SetTimeToForceQuit to customize the graceful shutdown period.
|
||||
func (rs *RpcServer) Start() {
|
||||
if err := rs.server.Start(rs.register); err != nil {
|
||||
logx.Error(err)
|
||||
|
||||
Reference in New Issue
Block a user