Compare commits

...

15 Commits

Author SHA1 Message Date
kesonan
aeceb3cfbe Update goctl version to 1.8.3-beta (#4812) 2025-04-27 15:53:49 +00:00
kesonan
15ea07aad1 goctl: support custom swagger authentication (#4811) 2025-04-27 15:43:37 +00:00
shaouai
98bebbc74f feat(swagger): allow users to specify the generated swagger file name (#4809) 2025-04-27 15:34:52 +00:00
kesonan
eafd11d949 goctl: supported api types group for EXPERIMENTAL(实验性功能:支持 api type 结构体按照分组名称拆分文件) (#4810) 2025-04-27 15:18:33 +00:00
Kevin Wan
b251ce346e feat: mcp server sdk (#4794)
Signed-off-by: kevin <wanjunfeng@gmail.com>
2025-04-27 23:06:37 +08:00
kesonan
812140ba36 fix: goctl swagger missing security definition and submit json body data error (#4808) 2025-04-25 14:58:45 +00:00
kesonan
44735e949c fix array schmea generation incorrect (#4801) 2025-04-23 23:59:01 +00:00
kesonan
bf313c3c56 fix: swagger separator incorrect in Windows OS (#4799) 2025-04-23 13:51:02 +00:00
kesonan
94e7753262 fix: the parameter "required" in the Swagger document generated for repair is incorrect (#4791) 2025-04-21 04:18:02 +00:00
kesonan
9c478626d2 feature/goctl-api-swagger (#4780) 2025-04-17 14:38:55 +00:00
Kevin Wan
801c283478 Delete issue-translator.yml 2025-04-10 12:01:21 +08:00
Kevin Wan
2a54faf997 chore: coding style (#4771) 2025-04-10 09:28:42 +08:00
Hanggang Z
ecd98f3653 chore: add more orm_test (#4766) 2025-04-09 13:49:00 +00:00
soasurs
61641581eb fix: form fields of request optional (#4755)
Signed-off-by: soasurs <soasurs@gmail.com>
2025-04-08 13:05:21 +00:00
Kevin Wan
6f2730d5ae chore: update goctl version (#4754) 2025-04-06 19:09:02 +08:00
55 changed files with 8619 additions and 79 deletions

View File

@@ -1,18 +0,0 @@
name: 'issue-translator'
on:
issue_comment:
types: [created]
issues:
types: [opened]
jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: usthe/issues-translate-action@v2.7
with:
IS_MODIFY_TITLE: true
# not require, default false, . Decide whether to modify the issue title
# if true, the robot account @Issues-translate-bot must have modification permissions, invite @Issues-translate-bot to your project or use your custom bot.
CUSTOM_BOT_NOTE: Bot detected the issue body's language is not English, translate it automatically. 👯👭🏻🧑‍🤝‍🧑👫🧑🏿‍🤝‍🧑🏻👩🏾‍🤝‍👨🏿👬🏿
# not require. Customize the translation robot prefix message.

View File

@@ -267,6 +267,20 @@ func TestUnmarshalRowStruct(t *testing.T) {
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
Name string
age int
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotMatchDestination)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
rs := sqlmock.NewRows([]string{"value"}).FromCSVString("8")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
@@ -310,6 +324,20 @@ func TestUnmarshalRowStructWithTags(t *testing.T) {
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
value := new(struct {
age int `db:"age"`
Name string `db:"name"`
})
rs := sqlmock.NewRows([]string{"name", "age"}).FromCSVString("liao,5")
mock.ExpectQuery("select (.+) from users where user=?").WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRow(value, rows, true)
}, "select name, age from users where user=?", "anyone"), ErrNotReadableValue)
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
var value struct {
Age *int `db:"age"`
@@ -1307,25 +1335,25 @@ func TestAnonymousStructPr(t *testing.T) {
}
func TestAnonymousStructPrError(t *testing.T) {
type Score struct {
Discipline string `db:"discipline"`
score uint `db:"score"`
}
type ClassType struct {
Grade sql.NullString `db:"grade"`
ClassName *string `db:"class_name"`
}
type Class struct {
*ClassType
Score
}
var value []*struct {
Age int64 `db:"age"`
Class
Name string `db:"name"`
}
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
type Score struct {
Discipline string `db:"discipline"`
score uint `db:"score"`
}
type ClassType struct {
Grade sql.NullString `db:"grade"`
ClassName *string `db:"class_name"`
}
type Class struct {
*ClassType
Score
}
var value []*struct {
Age int64 `db:"age"`
Class
Name string `db:"name"`
}
rs := sqlmock.NewRows([]string{
"name",
"age",
@@ -1338,10 +1366,50 @@ func TestAnonymousStructPrError(t *testing.T) {
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
mock.ExpectQuery("select (.+) from users where user=?").
WithArgs("anyone").WillReturnRows(rs)
assert.Error(t, query(context.Background(), db, func(rows *sql.Rows) error {
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone"))
"anyone"), ErrNotReadableValue)
if len(value) > 0 {
assert.Equal(t, value[0].score, 0)
}
})
dbtest.RunTest(t, func(db *sql.DB, mock sqlmock.Sqlmock) {
type Score struct {
Discipline string
score uint
}
type ClassType struct {
Grade sql.NullString
ClassName *string
}
type Class struct {
*ClassType
Score
}
var value []*struct {
Age int64
Class
Name string
}
rs := sqlmock.NewRows([]string{
"name",
"age",
"grade",
"discipline",
"class_name",
"score",
}).
AddRow("first", 2, nil, "math", "experimental class", 100).
AddRow("second", 3, "grade one", "chinese", "class three grade two", 99)
mock.ExpectQuery("select (.+) from users where user=?").
WithArgs("anyone").WillReturnRows(rs)
assert.ErrorIs(t, query(context.Background(), db, func(rows *sql.Rows) error {
return unmarshalRows(&value, rows, true)
}, "select name, age, grade, discipline, class_name, score from users where user=?",
"anyone"), ErrNotMatchDestination)
if len(value) > 0 {
assert.Equal(t, value[0].score, 0)
}

40
mcp/config.go Normal file
View File

@@ -0,0 +1,40 @@
package mcp
import (
"time"
"github.com/zeromicro/go-zero/rest"
)
// McpConf defines the configuration for an MCP server.
// It embeds rest.RestConf for HTTP server settings
// and adds MCP-specific configuration options.
type McpConf struct {
rest.RestConf
Mcp struct {
// Name is the server name reported in initialize responses
Name string `json:",optional"`
// Version is the server version reported in initialize responses
Version string `json:",default=1.0.0"`
// ProtocolVersion is the MCP protocol version implemented
ProtocolVersion string `json:",default=2024-11-05"`
// BaseUrl is the base URL for the server, used in SSE endpoint messages
// If not set, defaults to http://localhost:{Port}
BaseUrl string `json:",optional"`
// SseEndpoint is the path for Server-Sent Events connections
SseEndpoint string `json:",default=/sse"`
// MessageEndpoint is the path for JSON-RPC requests
MessageEndpoint string `json:",default=/message"`
// Cors contains allowed CORS origins
Cors []string `json:",optional"`
// ToolTimeout is the maximum time allowed for tool execution
ToolTimeout time.Duration `json:",default=30s"`
}
}

63
mcp/config_test.go Normal file
View File

@@ -0,0 +1,63 @@
package mcp
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/conf"
)
func TestMcpConfDefaults(t *testing.T) {
// Test default values are set correctly when unmarshalled from JSON
jsonConfig := `name: test-service
port: 8080
mcp:
name: test-mcp-server
version: 1.0.0
`
var c McpConf
err := conf.LoadFromYamlBytes([]byte(jsonConfig), &c)
assert.NoError(t, err)
// Check default values
assert.Equal(t, "test-mcp-server", c.Mcp.Name)
assert.Equal(t, "1.0.0", c.Mcp.Version, "Default version should be 1.0.0")
assert.Equal(t, "2024-11-05", c.Mcp.ProtocolVersion, "Default protocol version should be 2024-11-05")
assert.Equal(t, "/sse", c.Mcp.SseEndpoint, "Default SSE endpoint should be /sse")
assert.Equal(t, "/message", c.Mcp.MessageEndpoint, "Default message endpoint should be /message")
assert.Equal(t, 30*time.Second, c.Mcp.ToolTimeout, "Default tool timeout should be 30s")
}
func TestMcpConfCustomValues(t *testing.T) {
// Test custom values can be set
jsonConfig := `{
"Name": "test-service",
"Port": 8080,
"Mcp": {
"Name": "test-mcp-server",
"Version": "2.0.0",
"ProtocolVersion": "2025-01-01",
"BaseUrl": "http://example.com",
"SseEndpoint": "/custom-sse",
"MessageEndpoint": "/custom-message",
"Cors": ["http://localhost:3000", "http://example.com"],
"ToolTimeout": "60s"
}
}`
var c McpConf
err := conf.LoadFromJsonBytes([]byte(jsonConfig), &c)
assert.NoError(t, err)
// Check custom values
assert.Equal(t, "test-mcp-server", c.Mcp.Name, "Name should be inherited from RestConf")
assert.Equal(t, "2.0.0", c.Mcp.Version, "Version should be customizable")
assert.Equal(t, "2025-01-01", c.Mcp.ProtocolVersion, "Protocol version should be customizable")
assert.Equal(t, "http://example.com", c.Mcp.BaseUrl, "BaseUrl should be customizable")
assert.Equal(t, "/custom-sse", c.Mcp.SseEndpoint, "SSE endpoint should be customizable")
assert.Equal(t, "/custom-message", c.Mcp.MessageEndpoint, "Message endpoint should be customizable")
assert.Equal(t, []string{"http://localhost:3000", "http://example.com"}, c.Mcp.Cors, "CORS settings should be customizable")
assert.Equal(t, 60*time.Second, c.Mcp.ToolTimeout, "Tool timeout should be customizable")
}

443
mcp/integration_test.go Normal file
View File

@@ -0,0 +1,443 @@
package mcp
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// syncResponseRecorder is a thread-safe wrapper around httptest.ResponseRecorder
type syncResponseRecorder struct {
*httptest.ResponseRecorder
mu sync.Mutex
}
// Create a new synchronized response recorder
func newSyncResponseRecorder() *syncResponseRecorder {
return &syncResponseRecorder{
ResponseRecorder: httptest.NewRecorder(),
}
}
// Override Write method to synchronize access
func (srr *syncResponseRecorder) Write(p []byte) (int, error) {
srr.mu.Lock()
defer srr.mu.Unlock()
return srr.ResponseRecorder.Write(p)
}
// Override WriteHeader method to synchronize access
func (srr *syncResponseRecorder) WriteHeader(statusCode int) {
srr.mu.Lock()
defer srr.mu.Unlock()
srr.ResponseRecorder.WriteHeader(statusCode)
}
// Override Result method to synchronize access
func (srr *syncResponseRecorder) Result() *http.Response {
srr.mu.Lock()
defer srr.mu.Unlock()
return srr.ResponseRecorder.Result()
}
// TestHTTPHandlerIntegration tests the HTTP handlers with a real server instance
func TestHTTPHandlerIntegration(t *testing.T) {
// Skip in short test mode
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// Create a test configuration
conf := McpConf{}
conf.Mcp.Name = "test-integration"
conf.Mcp.Version = "1.0.0-test"
conf.Mcp.ToolTimeout = 1 * time.Second
// Create a mock server directly
server := &sseMcpServer{
conf: conf,
clients: make(map[string]*mcpClient),
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Register a test tool
err := server.RegisterTool(Tool{
Name: "echo",
Description: "Echo tool for testing",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{
"message": map[string]any{
"type": "string",
"description": "Message to echo",
},
},
},
Handler: func(params map[string]any) (any, error) {
if msg, ok := params["message"].(string); ok {
return fmt.Sprintf("Echo: %s", msg), nil
}
return "Echo: no message provided", nil
},
})
require.NoError(t, err)
// Create a test HTTP request to the SSE endpoint
req := httptest.NewRequest("GET", "/sse", nil)
w := newSyncResponseRecorder()
// Create a done channel to signal completion of test
done := make(chan bool)
// Start the SSE handler in a goroutine
go func() {
// lock.Lock()
server.handleSSE(w, req)
// lock.Unlock()
done <- true
}()
// Allow time for the handler to process
select {
case <-time.After(100 * time.Millisecond):
// Expected - handler would normally block indefinitely
case <-done:
// This shouldn't happen immediately - the handler should block
t.Error("SSE handler returned unexpectedly")
}
// Check the initial headers
resp := w.Result()
assert.Equal(t, "chunked", resp.Header.Get("Transfer-Encoding"))
resp.Body.Close()
// The handler creates a client and sends the endpoint message
var sessionId string
// Give the handler time to set up the client
time.Sleep(50 * time.Millisecond)
// Check that a client was created
server.clientsLock.Lock()
assert.Equal(t, 1, len(server.clients))
for id := range server.clients {
sessionId = id
}
server.clientsLock.Unlock()
require.NotEmpty(t, sessionId, "Expected a session ID to be created")
// Now that we have a session ID, we can test the message endpoint
messageBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 1,
Method: methodInitialize,
Params: json.RawMessage(`{}`),
})
// Create a message request
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, sessionId)
msgReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(messageBody))
msgW := newSyncResponseRecorder()
// Process the message
server.handleRequest(msgW, msgReq)
// Check the response
msgResp := msgW.Result()
assert.Equal(t, http.StatusAccepted, msgResp.StatusCode)
msgResp.Body.Close() // Ensure response body is closed
}
// TestHandlerResponseFlow tests the flow of a full request/response cycle
func TestHandlerResponseFlow(t *testing.T) {
// Create a mock server for testing
server := &sseMcpServer{
conf: McpConf{},
clients: map[string]*mcpClient{
"test-session": {
id: "test-session",
channel: make(chan string, 10),
initialized: true,
},
},
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Register test resources
server.RegisterTool(Tool{
Name: "test.tool",
Description: "Test tool",
InputSchema: InputSchema{Type: "object"},
Handler: func(params map[string]any) (any, error) {
return "tool result", nil
},
})
server.RegisterPrompt(Prompt{
Name: "test.prompt",
Description: "Test prompt",
})
server.RegisterResource(Resource{
Name: "test.resource",
URI: "http://example.com",
Description: "Test resource",
})
// Create a request with session ID parameter
reqURL := fmt.Sprintf("/message?%s=%s", sessionIdKey, "test-session")
// Test tools/list request
toolsListBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 1,
Method: methodToolsList,
Params: json.RawMessage(`{}`),
})
toolsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(toolsListBody))
toolsW := newSyncResponseRecorder()
// Process the request
server.handleRequest(toolsW, toolsReq)
// Check the response code
toolsResp := toolsW.Result()
assert.Equal(t, http.StatusAccepted, toolsResp.StatusCode)
toolsResp.Body.Close()
// Check the channel message
client := server.clients["test-session"]
select {
case message := <-client.channel:
assert.Contains(t, message, `"tools":[{"name":"test.tool"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tools/list response")
}
// Test prompts/list request
promptsListBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 2,
Method: methodPromptsList,
Params: json.RawMessage(`{}`),
})
promptsReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(promptsListBody))
promptsW := newSyncResponseRecorder()
// Process the request
server.handleRequest(promptsW, promptsReq)
// Check the response code
promptsResp := promptsW.Result()
assert.Equal(t, http.StatusAccepted, promptsResp.StatusCode)
promptsResp.Body.Close()
// Check the channel message
select {
case message := <-client.channel:
assert.Contains(t, message, `"prompts":[{"name":"test.prompt"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompts/list response")
}
// Test resources/list request
resourcesListBody, _ := json.Marshal(Request{
JsonRpc: "2.0",
ID: 3,
Method: methodResourcesList,
Params: json.RawMessage(`{}`),
})
resourcesReq := httptest.NewRequest("POST", reqURL, bytes.NewReader(resourcesListBody))
resourcesW := newSyncResponseRecorder()
// Process the request
server.handleRequest(resourcesW, resourcesReq)
// Check the response code
resourcesResp := resourcesW.Result()
assert.Equal(t, http.StatusAccepted, resourcesResp.StatusCode)
resourcesResp.Body.Close()
// Check the channel message
select {
case message := <-client.channel:
assert.Contains(t, message, `"name":"test.resource"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for resources/list response")
}
}
// TestProcessListMethods tests the list processing methods with pagination
func TestProcessListMethods(t *testing.T) {
server := &sseMcpServer{
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Add some test data
for i := 1; i <= 5; i++ {
tool := Tool{
Name: fmt.Sprintf("tool%d", i),
Description: fmt.Sprintf("Tool %d", i),
InputSchema: InputSchema{Type: "object"},
}
server.tools[tool.Name] = tool
prompt := Prompt{
Name: fmt.Sprintf("prompt%d", i),
Description: fmt.Sprintf("Prompt %d", i),
}
server.prompts[prompt.Name] = prompt
resource := Resource{
Name: fmt.Sprintf("resource%d", i),
URI: fmt.Sprintf("http://example.com/%d", i),
Description: fmt.Sprintf("Resource %d", i),
}
server.resources[resource.Name] = resource
}
// Create a test client
client := &mcpClient{
id: "test-client",
channel: make(chan string, 10),
initialized: true,
}
// Test processListTools
req := Request{
JsonRpc: "2.0",
ID: 1,
Method: methodToolsList,
Params: json.RawMessage(`{"cursor": "", "_meta": {"progressToken": "token1"}}`),
}
server.processListTools(client, req)
// Read response
select {
case response := <-client.channel:
assert.Contains(t, response, `"tools":`)
assert.Contains(t, response, `"progressToken":"token1"`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for tools/list response")
}
// Test processListPrompts
req.ID = 2
req.Method = methodPromptsList
req.Params = json.RawMessage(`{"cursor": "next"}`)
server.processListPrompts(client, req)
// Read response
select {
case response := <-client.channel:
assert.Contains(t, response, `"prompts":`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for prompts/list response")
}
// Test processListResources
req.ID = 3
req.Method = methodResourcesList
req.Params = json.RawMessage(`{"cursor": "next"}`)
server.processListResources(client, req)
// Read response
select {
case response := <-client.channel:
assert.Contains(t, response, `"resources":`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for resources/list response")
}
}
// TestErrorResponseHandling tests error handling in the server
func TestErrorResponseHandling(t *testing.T) {
server := &sseMcpServer{
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// Create a test client
client := &mcpClient{
id: "test-client",
channel: make(chan string, 10),
initialized: true,
}
// Test invalid method
req := Request{
JsonRpc: "2.0",
ID: 1,
Method: "invalid_method",
Params: json.RawMessage(`{}`),
}
// Mock handleRequest by directly calling error handler
server.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
// Check response
select {
case response := <-client.channel:
assert.Contains(t, response, `"error":{"code":-32601,"message":"Method not found"}`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for error response")
}
// Test invalid tool
toolReq := Request{
JsonRpc: "2.0",
ID: 2,
Method: methodToolsCall,
Params: json.RawMessage(`{"name":"non_existent_tool"}`),
}
// Call process method directly
server.processToolCall(client, toolReq)
// Check response
select {
case response := <-client.channel:
assert.Contains(t, response, `"error":{"code":-32602,"message":"Tool 'non_existent_tool' not found"}`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for error response")
}
// Test invalid prompt
promptReq := Request{
JsonRpc: "2.0",
ID: 3,
Method: methodPromptsGet,
Params: json.RawMessage(`{"name":"non_existent_prompt"}`),
}
// Call process method directly
server.processGetPrompt(client, promptReq)
// Check response
select {
case response := <-client.channel:
assert.Contains(t, response, `"error":{"code":-32602,"message":"Prompt 'non_existent_prompt' not found"}`)
case <-time.After(100 * time.Millisecond):
t.Fatal("Timed out waiting for error response")
}
}

62
mcp/readme.md Normal file
View File

@@ -0,0 +1,62 @@
# Model Context Protocol (MCP) SDK Implementation
## Overview
This package implements a Model Context Protocol (MCP) server in Go that facilitates real-time communication between AI models and clients using Server-Sent Events (SSE). The implementation provides a framework for building AI-assisted applications with bidirectional communication capabilities.
## Core Components
### Server-Sent Events (SSE) Communication
- **Real-time Communication**: Robust SSE-based communication system that maintains persistent connections with clients
- **Connection Management**: Client registration, message broadcasting, and client cleanup mechanisms
- **Event Handling**: Event types for tools, prompts, and resources changes
### JSON-RPC Implementation
- **Request Processing**: Complete JSON-RPC request processor for handling MCP protocol methods
- **Response Formatting**: Proper response formatting according to JSON-RPC specifications
- **Error Handling**: Comprehensive error handling with appropriate error codes
### Tool Management
- **Tool Registration**: System to register custom tools with handlers
- **Tool Execution**: Mechanism to execute tool functions with proper timeout handling
- **Result Handling**: Flexible result handling supporting various return types (string, JSON, images)
### Prompt System
- **Prompt Registration**: System for registering both static and dynamic prompts
- **Argument Validation**: Validation for required arguments and default values for optional ones
- **Message Generation**: Handlers that generate properly formatted conversation messages
### Resource Management
- **Resource Registration**: System for managing and accessing external resources
- **Content Delivery**: Handlers for delivering resource content to clients on demand
- **Resource Subscription**: Mechanisms for clients to subscribe to resource updates
### Protocol Features
- **Initialization Sequence**: Proper handshaking with capability negotiation
- **Notification Handling**: Support for both standard and client-specific notifications
- **Message Routing**: Intelligent routing of requests to appropriate handlers
## Technical Highlights
### Configuration System
- **Flexible Configuration**: Configuration system with sensible defaults and customization options
- **CORS Support**: Configurable CORS settings for cross-origin requests
- **Server Information**: Proper server identification and versioning
### Client Session Management
- **Session Tracking**: Client session tracking with unique identifiers
- **Connection Health**: Ping/pong mechanism to maintain connection health
- **Initialization State**: Client initialization state tracking
### Content Handling
- **Multi-format Content**: Support for text, code, and binary content
- **MIME Type Support**: Proper MIME type identification for various content types
- **Audience Annotations**: Content audience annotations for user/assistant targeting
## Usage
To create and use an MCP server, see the examples directory for practical implementation examples including:
- Tool registration and execution
- Static and dynamic prompt creation
- Resource handling with proper URI identification
- Embedded resources in prompt responses
- Client connection management

974
mcp/server.go Normal file
View File

@@ -0,0 +1,974 @@
package mcp
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/rest"
)
func NewMcpServer(c McpConf) McpServer {
var server *rest.Server
if len(c.Mcp.Cors) == 0 {
server = rest.MustNewServer(c.RestConf)
} else {
server = rest.MustNewServer(c.RestConf, rest.WithCors(c.Mcp.Cors...))
}
if len(c.Mcp.Name) == 0 {
c.Mcp.Name = c.Name
}
if len(c.Mcp.BaseUrl) == 0 {
c.Mcp.BaseUrl = fmt.Sprintf("http://localhost:%d", c.Port)
}
s := &sseMcpServer{
conf: c,
server: server,
clients: make(map[string]*mcpClient),
tools: make(map[string]Tool),
prompts: make(map[string]Prompt),
resources: make(map[string]Resource),
}
// SSE endpoint for real-time updates
s.server.AddRoute(rest.Route{
Method: http.MethodGet,
Path: s.conf.Mcp.SseEndpoint,
Handler: s.handleSSE,
}, rest.WithSSE())
// JSON-RPC message endpoint for regular requests
s.server.AddRoute(rest.Route{
Method: http.MethodPost,
Path: s.conf.Mcp.MessageEndpoint,
Handler: s.handleRequest,
})
return s
}
// RegisterPrompt registers a new prompt with the server
func (s *sseMcpServer) RegisterPrompt(prompt Prompt) {
s.promptsLock.Lock()
s.prompts[prompt.Name] = prompt
s.promptsLock.Unlock()
// Notify clients about the new prompt
s.broadcast(eventPromptsListChanged, map[string][]Prompt{keyPrompts: {prompt}})
}
// RegisterResource registers a new resource with the server
func (s *sseMcpServer) RegisterResource(resource Resource) {
s.resourcesLock.Lock()
s.resources[resource.URI] = resource
s.resourcesLock.Unlock()
// Notify clients about the new resource
s.broadcast(eventResourcesListChanged, map[string][]Resource{keyResources: {resource}})
}
// RegisterTool registers a new tool with the server
func (s *sseMcpServer) RegisterTool(tool Tool) error {
if tool.Handler == nil {
return fmt.Errorf("tool '%s' has no handler function", tool.Name)
}
s.toolsLock.Lock()
s.tools[tool.Name] = tool
s.toolsLock.Unlock()
// Notify clients about the new tool
s.broadcast(eventToolsListChanged, map[string][]Tool{keyTools: {tool}})
return nil
}
// Start implements McpServer.
func (s *sseMcpServer) Start() {
s.server.Start()
}
func (s *sseMcpServer) Stop() {
s.server.Stop()
}
// broadcast sends a message to all connected clients
// It uses Server-Sent Events (SSE) format for real-time communication
func (s *sseMcpServer) broadcast(event string, data any) {
jsonData, err := json.Marshal(data)
if err != nil {
logx.Errorf("Failed to marshal broadcast data: %v", err)
return
}
// Lock only while reading the clients map
s.clientsLock.Lock()
clients := make([]*mcpClient, 0, len(s.clients))
for _, client := range s.clients {
clients = append(clients, client)
}
s.clientsLock.Unlock()
clientCount := len(clients)
if clientCount == 0 {
return
}
logx.Infof("Broadcasting event '%s' to %d clients", event, clientCount)
// Use CRLF line endings as per SSE specification
message := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(jsonData))
// Send messages without holding the lock
for _, client := range clients {
select {
case client.channel <- message:
// Message sent successfully
default:
// Channel buffer is full, log warning and continue
logx.Errorf("Client channel buffer full, dropping message for client %s", client.id)
}
}
}
// cleanupClient removes a client from the active clients map
func (s *sseMcpServer) cleanupClient(sessionId string) {
s.clientsLock.Lock()
defer s.clientsLock.Unlock()
if client, exists := s.clients[sessionId]; exists {
// Close the channel to signal any goroutines waiting on it
close(client.channel)
// Remove from active clients
delete(s.clients, sessionId)
logx.Infof("Cleaned up client %s (remaining clients: %d)", sessionId, len(s.clients))
}
}
// handleRequest handles MCP JSON-RPC requests
func (s *sseMcpServer) handleRequest(w http.ResponseWriter, r *http.Request) {
// Extract sessionId from query parameters
sessionId := r.URL.Query().Get(sessionIdKey)
if len(sessionId) == 0 {
http.Error(w, fmt.Sprintf("Missing %s", sessionIdKey), http.StatusBadRequest)
return
}
// Check if the client with this sessionId exists
s.clientsLock.Lock()
client, exists := s.clients[sessionId]
s.clientsLock.Unlock()
if !exists {
http.Error(w, fmt.Sprintf("Invalid or expired %s", sessionIdKey), http.StatusBadRequest)
return
}
var req Request
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Invalid request", http.StatusBadRequest)
return
}
w.WriteHeader(http.StatusAccepted)
// For notification methods (no ID), we don't send a response
isNotification := req.ID == 0
// Special handling for initialization sequence
// Always allow initialize and notifications/initialized regardless of client state
if req.Method == methodInitialize {
logx.Infof("Processing initialize request with ID: %d", req.ID)
s.processInitialize(client, req)
logx.Infof("Sent initialize response for ID: %d, waiting for notifications/initialized", req.ID)
return
} else if req.Method == methodNotificationsInitialized {
// Handle initialized notification
logx.Info("Received notifications/initialized notification")
if !isNotification {
s.sendErrorResponse(client, req.ID, "Method should be used as a notification", errCodeInvalidRequest)
return
}
s.processNotificationInitialized(client)
return
} else if !client.initialized && req.Method != methodNotificationsCancelled {
// Block most requests until client is initialized (except for cancellations)
s.sendErrorResponse(client, req.ID, "Client not fully initialized, waiting for notifications/initialized",
errCodeClientNotInitialized)
return
}
// Process normal requests only after initialization
switch req.Method {
case methodToolsCall:
logx.Infof("Received tools call request with ID: %d", req.ID)
s.processToolCall(client, req)
logx.Infof("Sent tools call response for ID: %d", req.ID)
case methodToolsList:
logx.Infof("Processing tools/list request with ID: %d", req.ID)
s.processListTools(client, req)
logx.Infof("Sent tools/list response for ID: %d", req.ID)
case methodPromptsList:
logx.Infof("Processing prompts/list request with ID: %d", req.ID)
s.processListPrompts(client, req)
logx.Infof("Sent prompts/list response for ID: %d", req.ID)
case methodPromptsGet:
logx.Infof("Processing prompts/get request with ID: %d", req.ID)
s.processGetPrompt(client, req)
logx.Infof("Sent prompts/get response for ID: %d", req.ID)
case methodResourcesList:
logx.Infof("Processing resources/list request with ID: %d", req.ID)
s.processListResources(client, req)
logx.Infof("Sent resources/list response for ID: %d", req.ID)
case methodResourcesRead:
logx.Infof("Processing resources/read request with ID: %d", req.ID)
s.processResourcesRead(client, req)
logx.Infof("Sent resources/read response for ID: %d", req.ID)
case methodResourcesSubscribe:
logx.Infof("Processing resources/subscribe request with ID: %d", req.ID)
s.processResourceSubscribe(client, req)
logx.Infof("Sent resources/subscribe response for ID: %d", req.ID)
case methodPing:
logx.Infof("Processing ping request with ID: %d", req.ID)
s.processPing(client, req)
case methodNotificationsCancelled:
logx.Infof("Received notifications/cancelled notification: %v", req.Params)
s.processNotificationCancelled(client, req)
default:
logx.Infof("Unknown method: %s", req.Method)
s.sendErrorResponse(client, req.ID, "Method not found", errCodeMethodNotFound)
}
}
// handleSSE handles Server-Sent Events connections
func (s *sseMcpServer) handleSSE(w http.ResponseWriter, r *http.Request) {
// Generate a unique session ID for this client
sessionId := uuid.New().String()
// Create new client with buffered channel to prevent blocking
client := &mcpClient{
id: sessionId,
channel: make(chan string, eventChanSize),
}
// Add client to active clients map
s.clientsLock.Lock()
s.clients[sessionId] = client
activeClients := len(s.clients)
s.clientsLock.Unlock()
logx.Infof("New SSE connection established for client %s (active clients: %d)",
sessionId, activeClients)
// Set proper SSE headers
w.Header().Set("Transfer-Encoding", "chunked")
// Enable streaming
flusher, ok := w.(http.Flusher)
if !ok {
logx.Error("Streaming not supported by the underlying http.ResponseWriter")
http.Error(w, "Streaming not supported", http.StatusInternalServerError)
return
}
// Send the message endpoint URL to the client
endpoint := fmt.Sprintf("%s%s?%s=%s",
s.conf.Mcp.BaseUrl, s.conf.Mcp.MessageEndpoint, sessionIdKey, sessionId)
// Format and send the endpoint message
endpointMsg := formatSSEMessage(eventEndpoint, []byte(endpoint))
if _, err := fmt.Fprint(w, endpointMsg); err != nil {
logx.Errorf("Failed to send endpoint message to client %s: %v", sessionId, err)
s.cleanupClient(sessionId)
return
}
flusher.Flush()
// Set up keep-alive ping and client cleanup
ticker := time.NewTicker(pingInterval.Load())
defer func() {
ticker.Stop()
s.cleanupClient(sessionId)
logx.Infof("SSE connection closed for client %s", sessionId)
}()
// Message processing loop
for {
select {
case message, ok := <-client.channel:
if !ok {
// Channel was closed, end connection
logx.Infof("Client channel was closed for %s", sessionId)
return
}
// Write message to the response
if _, err := fmt.Fprint(w, message); err != nil {
logx.Infof("Failed to write message to client %s: %v", sessionId, err)
return
}
flusher.Flush()
case <-ticker.C:
// Send keep-alive ping to maintain connection
ping := fmt.Sprintf(`{"type":"ping","timestamp":"%s"}`, time.Now().String())
pingMsg := formatSSEMessage("ping", []byte(ping))
if _, err := fmt.Fprint(w, pingMsg); err != nil {
logx.Errorf("Failed to send ping to client %s, closing connection: %v", sessionId, err)
return
}
flusher.Flush()
case <-r.Context().Done():
// Client disconnected or request was canceled
logx.Infof("Client %s disconnected: context done", sessionId)
return
}
}
}
// processInitialize processes the initialize request
func (s *sseMcpServer) processInitialize(client *mcpClient, req Request) {
// Create a proper JSON-RPC response that preserves the client's request ID
result := initializationResponse{
ProtocolVersion: s.conf.Mcp.ProtocolVersion,
Capabilities: capabilities{
Prompts: struct {
ListChanged bool `json:"listChanged"`
}{
ListChanged: true,
},
Resources: struct {
Subscribe bool `json:"subscribe"`
ListChanged bool `json:"listChanged"`
}{
Subscribe: true,
ListChanged: true,
},
Tools: struct {
ListChanged bool `json:"listChanged"`
}{
ListChanged: true,
},
},
ServerInfo: serverInfo{
Name: s.conf.Mcp.Name,
Version: s.conf.Mcp.Version,
},
}
// Mark client as initialized
client.initialized = true
// Send response with client's original request ID
s.sendResponse(client, req.ID, result)
}
// processListTools processes the tools/list request
func (s *sseMcpServer) processListTools(client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
var progressToken any
// Extract meta data including progress token
if req.Params != nil {
var metaParams struct {
Cursor string `json:"cursor"`
Meta struct {
ProgressToken any `json:"progressToken"`
} `json:"_meta"`
}
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
if len(metaParams.Cursor) > 0 {
nextCursor = metaParams.Cursor
}
progressToken = metaParams.Meta.ProgressToken
}
}
var toolsList []Tool
s.toolsLock.Lock()
for _, tool := range s.tools {
toolsList = append(toolsList, tool)
}
s.toolsLock.Unlock()
result := ListToolsResult{
PaginatedResult: PaginatedResult{
Result: Result{},
NextCursor: Cursor(nextCursor),
},
Tools: toolsList,
}
// Add meta information if progress token was provided
if progressToken != nil {
result.Result.Meta = map[string]any{
"progressToken": progressToken,
}
}
s.sendResponse(client, req.ID, result)
}
// processListPrompts processes the prompts/list request
func (s *sseMcpServer) processListPrompts(client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
if req.Params != nil {
var cursorParams struct {
Cursor string `json:"cursor"`
}
if err := json.Unmarshal(req.Params, &cursorParams); err == nil && cursorParams.Cursor != "" {
// If we have a valid cursor, we could use it for pagination
// For now, we're not actually implementing pagination, so this is just
// to show how it would be extracted from the request
_ = cursorParams.Cursor
}
}
// Prepare prompt list
var promptsList []Prompt
s.promptsLock.Lock()
for _, prompt := range s.prompts {
promptsList = append(promptsList, prompt)
}
s.promptsLock.Unlock()
// In a real implementation, you'd handle pagination here
// For now, we'll return all prompts at once
result := struct {
Prompts []Prompt `json:"prompts"`
NextCursor string `json:"nextCursor,omitempty"`
Meta *struct{} `json:"_meta,omitempty"`
}{
Prompts: promptsList,
NextCursor: nextCursor,
}
s.sendResponse(client, req.ID, result)
}
// processListResources processes the resources/list request
func (s *sseMcpServer) processListResources(client *mcpClient, req Request) {
// Extract pagination params if any
var nextCursor string
var progressToken any
// Extract meta information including progress token if available
if req.Params != nil {
var metaParams PaginatedParams
if err := json.Unmarshal(req.Params, &metaParams); err == nil {
if len(metaParams.Cursor) > 0 {
nextCursor = metaParams.Cursor
}
progressToken = metaParams.Meta.ProgressToken
}
}
var resourcesList []Resource
s.resourcesLock.Lock()
for _, resource := range s.resources {
// Create a copy without the handler function which shouldn't be sent to clients
resourceCopy := Resource{
URI: resource.URI,
Name: resource.Name,
Description: resource.Description,
MimeType: resource.MimeType,
}
resourcesList = append(resourcesList, resourceCopy)
}
s.resourcesLock.Unlock()
// Create proper ResourcesListResult according to MCP specification
result := ResourcesListResult{
PaginatedResult: PaginatedResult{
Result: Result{},
NextCursor: Cursor(nextCursor),
},
Resources: resourcesList,
}
// Add meta information if progress token was provided
if progressToken != nil {
result.Result.Meta = map[string]any{
"progressToken": progressToken,
}
}
s.sendResponse(client, req.ID, result)
}
// processGetPrompt processes the prompts/get request
func (s *sseMcpServer) processGetPrompt(client *mcpClient, req Request) {
type GetPromptParams struct {
Name string `json:"name"`
Arguments map[string]string `json:"arguments,omitempty"`
}
var params GetPromptParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
// Check if prompt exists
s.promptsLock.Lock()
prompt, exists := s.prompts[params.Name]
s.promptsLock.Unlock()
if !exists {
message := fmt.Sprintf("Prompt '%s' not found", params.Name)
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
return
}
logx.Infof("Processing prompt request: %s with %d arguments", prompt.Name, len(params.Arguments))
// Validate required arguments
missingArgs := validatePromptArguments(prompt, params.Arguments)
if len(missingArgs) > 0 {
message := fmt.Sprintf("Missing required arguments: %s", strings.Join(missingArgs, ", "))
s.sendErrorResponse(client, req.ID, message, errCodeInvalidParams)
return
}
// Apply default values for missing optional arguments
args := applyDefaultArguments(prompt, params.Arguments)
// Generate messages using handler or static content
var messages []PromptMessage
var err error
if prompt.Handler != nil {
// Use dynamic handler to generate messages
logx.Info("Using prompt handler to generate content")
messages, err = prompt.Handler(args)
if err != nil {
logx.Errorf("Error from prompt handler: %v", err)
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error generating prompt content: %v", err), errCodeInternalError)
return
}
} else {
// No handler, generate messages from static content
var messageText string
if prompt.Content != "" {
messageText = prompt.Content
// Apply argument substitutions to static content
for key, value := range args {
placeholder := fmt.Sprintf("{{%s}}", key)
messageText = strings.Replace(messageText, placeholder, value, -1)
}
} else {
// No content, use a default fallback
topic := "this topic"
if t, ok := args["topic"]; ok && t != "" {
topic = t
}
messageText = fmt.Sprintf("Tell me about %s", topic)
}
// Create a single user message with the content
messages = []PromptMessage{
{
Role: roleUser,
Content: TextContent{
Type: contentTypeText,
Text: messageText,
},
},
}
}
// Construct the response according to MCP spec
result := struct {
Description string `json:"description,omitempty"`
Messages []PromptMessage `json:"messages"`
}{
Description: prompt.Description,
Messages: messages,
}
s.sendResponse(client, req.ID, result)
}
// validatePromptArguments checks if all required arguments are provided
// Returns a list of missing required arguments
func validatePromptArguments(prompt Prompt, providedArgs map[string]string) []string {
var missingArgs []string
for _, arg := range prompt.Arguments {
if arg.Required {
if value, exists := providedArgs[arg.Name]; !exists || len(value) == 0 {
missingArgs = append(missingArgs, arg.Name)
}
}
}
return missingArgs
}
// applyDefaultArguments adds default values for missing optional arguments
func applyDefaultArguments(prompt Prompt, providedArgs map[string]string) map[string]string {
result := make(map[string]string)
// Copy all provided arguments
for k, v := range providedArgs {
result[k] = v
}
// Add defaults for missing arguments
for _, arg := range prompt.Arguments {
if _, exists := result[arg.Name]; !exists && arg.Default != "" {
result[arg.Name] = arg.Default
}
}
return result
}
// processToolCall processes the tools/call request
func (s *sseMcpServer) processToolCall(client *mcpClient, req Request) {
var toolCallParams struct {
Name string `json:"name"`
Arguments map[string]any `json:"arguments,omitempty"`
Meta struct {
ProgressToken any `json:"progressToken"`
} `json:"_meta,omitempty"`
}
// Handle different types of req.Params
// If it's a RawMessage (JSON), unmarshal it
if err := json.Unmarshal(req.Params, &toolCallParams); err != nil {
logx.Errorf("Failed to unmarshal tool call params: %v", err)
s.sendErrorResponse(client, req.ID, "Invalid tool call parameters", errCodeInvalidParams)
return
}
// Extract progress token if available
progressToken := toolCallParams.Meta.ProgressToken
// Find the requested tool
s.toolsLock.Lock()
tool, exists := s.tools[toolCallParams.Name]
s.toolsLock.Unlock()
if !exists {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Tool '%s' not found",
toolCallParams.Name), errCodeInvalidParams)
return
}
// Create a context with the configured timeout
ctx, cancel := context.WithTimeout(context.Background(), s.conf.Mcp.ToolTimeout)
defer cancel()
// Log parameters before execution
logx.Infof("Executing tool '%s' with arguments: %#v", toolCallParams.Name, toolCallParams.Arguments)
// Execute the tool handler with timeout handling
var result any
var err error
// Create a channel to receive the result
resultCh := make(chan struct {
result any
err error
}, 1)
// Execute the tool handler in a goroutine
go func() {
toolResult, toolErr := tool.Handler(toolCallParams.Arguments)
resultCh <- struct {
result any
err error
}{
result: toolResult,
err: toolErr,
}
}()
// Wait for either the result or a timeout
select {
case res := <-resultCh:
result = res.result
err = res.err
case <-ctx.Done():
// Handle timeout
logx.Errorf("Tool execution timed out after %v: %s", s.conf.Mcp.ToolTimeout, toolCallParams.Name)
s.sendErrorResponse(client, req.ID, "Tool execution timed out", errCodeTimeout)
return
}
// Create the base result structure with metadata
callToolResult := CallToolResult{
Result: Result{},
Content: []any{},
IsError: false,
}
// Add meta information if progress token was provided
if progressToken != nil {
callToolResult.Result.Meta = map[string]any{
"progressToken": progressToken,
}
}
// Check if there was an error during tool execution
if err != nil {
// According to the spec, for tool-level errors (as opposed to protocol-level errors),
// we should report them inside the result with isError=true
logx.Errorf("Tool execution reported error: %v", err)
callToolResult.Content = []any{
TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("Error: %v", err),
},
}
callToolResult.IsError = true
s.sendResponse(client, req.ID, callToolResult)
return
}
// Format the response according to the CallToolResult schema
switch v := result.(type) {
case string:
// Simple string becomes text content
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: v,
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
},
})
case map[string]any:
// JSON-like object becomes formatted JSON text
jsonStr, err := json.Marshal(v)
if err != nil {
jsonStr = []byte(err.Error())
}
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: string(jsonStr),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
},
})
case TextContent:
// Direct TextContent object
callToolResult.Content = append(callToolResult.Content, v)
case ImageContent:
// Direct ImageContent object
callToolResult.Content = append(callToolResult.Content, v)
case []any:
// Array of content items
callToolResult.Content = v
case ToolResult:
// Handle legacy ToolResult type
switch v.Type {
case contentTypeText:
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("%v", v.Content),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
},
})
case contentTypeImage:
if imgData, ok := v.Content.(map[string]any); ok {
callToolResult.Content = append(callToolResult.Content, ImageContent{
Type: contentTypeImage,
Data: fmt.Sprintf("%v", imgData["data"]),
MimeType: fmt.Sprintf("%v", imgData["mimeType"]),
})
}
default:
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("%v", v.Content),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
},
})
}
default:
// For any other type, convert to string
callToolResult.Content = append(callToolResult.Content, TextContent{
Type: contentTypeText,
Text: fmt.Sprintf("%v", v),
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
},
})
}
logx.Infof("Tool call result: %#v", callToolResult)
s.sendResponse(client, req.ID, callToolResult)
}
// processResourcesRead processes the resources/read request
func (s *sseMcpServer) processResourcesRead(client *mcpClient, req Request) {
var params ResourceReadParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
// Find resource that matches the URI
s.resourcesLock.Lock()
resource, exists := s.resources[params.URI]
s.resourcesLock.Unlock()
if !exists {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
params.URI), errCodeResourceNotFound)
return
}
// If no handler is provided, return an empty content array
if resource.Handler == nil {
result := ResourceReadResult{
Contents: []ResourceContent{
{
URI: params.URI,
MimeType: resource.MimeType,
Text: "",
},
},
}
s.sendResponse(client, req.ID, result)
return
}
// Execute the resource handler
content, err := resource.Handler()
if err != nil {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Error reading resource: %v", err),
errCodeInternalError)
return
}
// Ensure the URI is set if not already provided by the handler
if len(content.URI) == 0 {
content.URI = params.URI
}
// Ensure MimeType is set if available from the resource definition
if len(content.MimeType) == 0 && resource.MimeType != "" {
content.MimeType = resource.MimeType
}
// Create response with contents from the handler
// The MCP specification requires a contents array
result := ResourceReadResult{
Contents: []ResourceContent{content},
}
s.sendResponse(client, req.ID, result)
}
// processResourceSubscribe processes the resources/subscribe request
func (s *sseMcpServer) processResourceSubscribe(client *mcpClient, req Request) {
var params ResourceSubscribeParams
if err := json.Unmarshal(req.Params, &params); err != nil {
s.sendErrorResponse(client, req.ID, "Invalid parameters", errCodeInvalidParams)
return
}
// Check if the resource exists
s.resourcesLock.Lock()
_, exists := s.resources[params.URI]
s.resourcesLock.Unlock()
if !exists {
s.sendErrorResponse(client, req.ID, fmt.Sprintf("Resource with URI '%s' not found",
params.URI), errCodeResourceNotFound)
return
}
// Send success response for the subscription
s.sendResponse(client, req.ID, struct{}{})
logx.Infof("Client %s subscribed to resource '%s'", client.id, params.URI)
}
// processNotificationCancelled processes the notifications/cancelled notification
func (s *sseMcpServer) processNotificationCancelled(client *mcpClient, req Request) {
// Extract the requestId that was canceled
type CancelParams struct {
RequestId int64 `json:"requestId"`
Reason string `json:"reason"`
}
var params CancelParams
if err := json.Unmarshal(req.Params, &params); err != nil {
logx.Errorf("Failed to parse cancellation params: %v", err)
return
}
logx.Infof("Request %d was cancelled by client. Reason: %s", params.RequestId, params.Reason)
}
// processNotificationInitialized processes the notifications/initialized notification
func (s *sseMcpServer) processNotificationInitialized(client *mcpClient) {
// Mark the client as properly initialized
client.initialized = true
logx.Infof("Client %s is now fully initialized and ready for normal operations", client.id)
}
// processPing processes the ping request and responds immediately
func (s *sseMcpServer) processPing(client *mcpClient, req Request) {
// A ping request should simply respond with an empty result to confirm the server is alive
logx.Infof("Received ping request with ID: %d", req.ID)
// Send an empty response with client's original request ID
s.sendResponse(client, req.ID, struct{}{})
logx.Infof("Sent ping response for ID: %d", req.ID)
}
// sendErrorResponse sends an error response via the SSE channel
func (s *sseMcpServer) sendErrorResponse(client *mcpClient, id int64, message string, code int) {
errorResponse := struct {
JsonRpc string `json:"jsonrpc"`
ID int64 `json:"id"`
Error errorMessage `json:"error"`
}{
JsonRpc: jsonRpcVersion,
ID: id,
Error: errorMessage{
Code: code,
Message: message,
},
}
// all fields are primitive types, impossible to fail
jsonData, _ := json.Marshal(errorResponse)
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending error for ID %d: %s", id, sseMessage)
client.channel <- sseMessage
}
// sendResponse sends a success response via the SSE channel
func (s *sseMcpServer) sendResponse(client *mcpClient, id int64, result any) {
response := Response{
JsonRpc: jsonRpcVersion,
ID: id,
Result: result,
}
jsonData, err := json.Marshal(response)
if err != nil {
s.sendErrorResponse(client, id, "Failed to marshal response", errCodeInternalError)
return
}
// Use CRLF line endings as requested
sseMessage := fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", eventMessage, string(jsonData))
logx.Infof("Sending response for ID %d: %s", id, sseMessage)
client.channel <- sseMessage
}

3418
mcp/server_test.go Normal file

File diff suppressed because it is too large Load Diff

294
mcp/types.go Normal file
View File

@@ -0,0 +1,294 @@
package mcp
import (
"encoding/json"
"sync"
"github.com/zeromicro/go-zero/rest"
)
// Cursor is an opaque token used for pagination
type Cursor string
// Request represents a generic MCP request following JSON-RPC 2.0 specification
type Request struct {
SessionId string `form:"session_id"` // Session identifier for client tracking
JsonRpc string `json:"jsonrpc"` // Must be "2.0" per JSON-RPC spec
ID int64 `json:"id"` // Request identifier for matching responses
Method string `json:"method"` // Method name to invoke
Params json.RawMessage `json:"params"` // Parameters for the method
}
type PaginatedParams struct {
Cursor string `json:"cursor"`
Meta struct {
ProgressToken any `json:"progressToken"`
} `json:"_meta"`
}
// Result is the base interface for all results
type Result struct {
Meta map[string]any `json:"_meta,omitempty"` // Optional metadata
}
// PaginatedResult is a base for results that support pagination
type PaginatedResult struct {
Result
NextCursor Cursor `json:"nextCursor,omitempty"` // Opaque token for fetching next page
}
// ListToolsResult represents the response to a tools/list request
type ListToolsResult struct {
PaginatedResult
Tools []Tool `json:"tools"` // List of available tools
}
// Message Content Types
// roleType represents the sender or recipient of messages in a conversation
type roleType string
// PromptArgument defines a single argument that can be passed to a prompt
type PromptArgument struct {
Name string `json:"name"` // Argument name
Description string `json:"description,omitempty"` // Human-readable description
Required bool `json:"required,omitempty"` // Whether this argument is required
Default string `json:"default,omitempty"` // Default value if not provided
}
// PromptHandler is a function that dynamically generates prompt content
type PromptHandler func(args map[string]string) ([]PromptMessage, error)
// Prompt represents an MCP Prompt definition
type Prompt struct {
Name string `json:"name"` // Unique identifier for the prompt
Description string `json:"description,omitempty"` // Human-readable description
Arguments []PromptArgument `json:"arguments,omitempty"` // Arguments for customization
Content string `json:"-"` // Static content (internal use only)
Handler PromptHandler `json:"-"` // Handler for dynamic content generation
}
// PromptMessage represents a message in a conversation
type PromptMessage struct {
Role roleType `json:"role"` // Message sender role
Content any `json:"content"` // Message content (TextContent, ImageContent, etc.)
}
// TextContent represents text content in a message
type TextContent struct {
Type string `json:"type"` // Always "text"
Text string `json:"text"` // The text content
Annotations *Annotations `json:"annotations,omitempty"` // Optional annotations
}
// ImageContent represents image data in a message
type ImageContent struct {
Type string `json:"type"` // Always "image"
Data string `json:"data"` // Base64-encoded image data
MimeType string `json:"mimeType"` // MIME type (e.g., "image/png")
}
// AudioContent represents audio data in a message
type AudioContent struct {
Type string `json:"type"` // Always "audio"
Data string `json:"data"` // Base64-encoded audio data
MimeType string `json:"mimeType"` // MIME type (e.g., "audio/mp3")
}
// FileContent represents file content
type FileContent struct {
URI string `json:"uri"` // URI identifying the file
MimeType string `json:"mimeType"` // MIME type of the file
Text string `json:"text"` // File content as text
}
// EmbeddedResource represents a resource embedded in a message
type EmbeddedResource struct {
Type string `json:"type"` // Always "resource"
Resource struct {
URI string `json:"uri"` // Resource URI
MimeType string `json:"mimeType"` // MIME type of the resource
Text string `json:"text,omitempty"` // Text content (if available)
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
} `json:"resource"` // The resource data
}
// Annotations provides additional metadata for content
type Annotations struct {
Audience []roleType `json:"audience,omitempty"` // Who should see this content
Priority *float64 `json:"priority,omitempty"` // Optional priority (0-1)
}
// Tool-related Types
// Tool Definition Types
// ToolHandler is a function that handles tool calls
type ToolHandler func(params map[string]any) (any, error)
// Tool represents a Model Context Protocol Tool definition
type Tool struct {
Name string `json:"name"` // Unique identifier for the tool
Description string `json:"description"` // Human-readable description
InputSchema InputSchema `json:"inputSchema"` // JSON Schema for parameters
Handler ToolHandler `json:"-"` // Not sent to clients
}
// InputSchema represents tool's input schema in JSON Schema format
type InputSchema struct {
Type string `json:"type"` // Always "object" for tool inputs
Properties map[string]any `json:"properties"` // Property definitions
Required []string `json:"required,omitempty"` // List of required properties
}
// CallToolResult represents a tool call result that conforms to the MCP schema
type CallToolResult struct {
Result
Content []interface{} `json:"content"` // Content items (text, images, etc.)
IsError bool `json:"isError,omitempty"` // True if tool execution failed
}
// Resource represents a Model Context Protocol Resource definition
type Resource struct {
URI string `json:"uri"` // Unique resource identifier (RFC3986)
Name string `json:"name"` // Human-readable name
Description string `json:"description,omitempty"` // Optional description
MimeType string `json:"mimeType,omitempty"` // Optional MIME type
Handler ResourceHandler `json:"-"` // Internal handler not sent to clients
}
// ResourceHandler is a function that handles resource read requests
type ResourceHandler func() (ResourceContent, error)
// ResourceContent represents the content of a resource
type ResourceContent struct {
URI string `json:"uri"` // Resource URI (required)
MimeType string `json:"mimeType,omitempty"` // MIME type of the resource
Text string `json:"text,omitempty"` // Text content (if available)
Blob string `json:"blob,omitempty"` // Base64 encoded blob data (if available)
}
// ResourcesListResult represents the response to a resources/list request
type ResourcesListResult struct {
PaginatedResult
Resources []Resource `json:"resources"` // List of available resources
}
// ResourceReadParams contains parameters for a resources/read request
type ResourceReadParams struct {
URI string `json:"uri"` // URI of the resource to read
}
// ResourceReadResult contains the result of a resources/read request
type ResourceReadResult struct {
Result
Contents []ResourceContent `json:"contents"` // Array of resource content
}
// ResourceSubscribeParams contains parameters for a resources/subscribe request
type ResourceSubscribeParams struct {
URI string `json:"uri"` // URI of the resource to subscribe to
}
// ResourceUpdateNotification represents a notification about a resource update
type ResourceUpdateNotification struct {
URI string `json:"uri"` // URI of the updated resource
Content ResourceContent `json:"content"` // New resource content
}
// Client and Server Types
// mcpClient represents an SSE client connection
type mcpClient struct {
id string // Unique client identifier
channel chan string // Channel for sending SSE messages
initialized bool // Tracks if client has sent notifications/initialized
}
// McpServer defines the interface for Model Context Protocol servers
type McpServer interface {
Start()
Stop()
RegisterTool(tool Tool) error
RegisterPrompt(prompt Prompt)
RegisterResource(resource Resource)
}
// sseMcpServer implements the McpServer interface using SSE
type sseMcpServer struct {
conf McpConf
server *rest.Server
clients map[string]*mcpClient
clientsLock sync.Mutex
tools map[string]Tool
toolsLock sync.Mutex
prompts map[string]Prompt
promptsLock sync.Mutex
resources map[string]Resource
resourcesLock sync.Mutex
}
// Response Types
// errorObj represents a JSON-RPC error object
type errorObj struct {
Code int `json:"code"` // Error code
Message string `json:"message"` // Error message
}
// Response represents a JSON-RPC response
type Response struct {
JsonRpc string `json:"jsonrpc"` // Always "2.0"
ID int64 `json:"id"` // Same as request ID
Result any `json:"result"` // Result object (null if error)
Error *errorObj `json:"error,omitempty"` // Error object (null if success)
}
// Server Information Types
// serverInfo provides information about the server
type serverInfo struct {
Name string `json:"name"` // Server name
Version string `json:"version"` // Server version
}
// capabilities describes the server's capabilities
type capabilities struct {
Logging struct{} `json:"logging"`
Prompts struct {
ListChanged bool `json:"listChanged"` // Server will notify on prompt changes
} `json:"prompts"`
Resources struct {
Subscribe bool `json:"subscribe"` // Server supports resource subscriptions
ListChanged bool `json:"listChanged"` // Server will notify on resource changes
} `json:"resources"`
Tools struct {
ListChanged bool `json:"listChanged"` // Server will notify on tool changes
} `json:"tools"`
}
// initializationResponse is sent in response to an initialize request
type initializationResponse struct {
ProtocolVersion string `json:"protocolVersion"` // Protocol version
Capabilities capabilities `json:"capabilities"` // Server capabilities
ServerInfo serverInfo `json:"serverInfo"` // Server information
}
// ToolCallParams contains the parameters for a tool call
type ToolCallParams struct {
Name string `json:"name"` // Tool name
Parameters map[string]any `json:"parameters"` // Tool parameters
}
// ToolResult contains the result of a tool execution
type ToolResult struct {
Type string `json:"type"` // Content type (text, image, etc.)
Content any `json:"content"` // Result content
}
// errorMessage represents a detailed error message
type errorMessage struct {
Code int `json:"code"` // Error code
Message string `json:"message"` // Error message
Data any `json:",omitempty"` // Additional error data
}

212
mcp/types_test.go Normal file
View File

@@ -0,0 +1,212 @@
package mcp
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
)
func TestResponseMarshaling(t *testing.T) {
// Test that the Response struct marshals correctly
resp := Response{
JsonRpc: "2.0",
ID: 123,
Result: map[string]string{
"key": "value",
},
}
data, err := json.Marshal(resp)
assert.NoError(t, err)
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
assert.Contains(t, string(data), `"id":123`)
assert.Contains(t, string(data), `"result":{"key":"value"}`)
// Test response with error
respWithError := Response{
JsonRpc: "2.0",
ID: 456,
Error: &errorObj{
Code: errCodeInvalidRequest,
Message: "Invalid Request",
},
}
data, err = json.Marshal(respWithError)
assert.NoError(t, err)
assert.Contains(t, string(data), `"jsonrpc":"2.0"`)
assert.Contains(t, string(data), `"id":456`)
assert.Contains(t, string(data), `"error":{"code":-32600,"message":"Invalid Request"}`)
}
func TestRequestUnmarshaling(t *testing.T) {
// Test that the Request struct unmarshals correctly
jsonStr := `{
"jsonrpc": "2.0",
"id": 789,
"method": "test_method",
"params": {"key": "value"}
}`
var req Request
err := json.Unmarshal([]byte(jsonStr), &req)
assert.NoError(t, err)
assert.Equal(t, "2.0", req.JsonRpc)
assert.Equal(t, int64(789), req.ID)
assert.Equal(t, "test_method", req.Method)
// Check params unmarshaled correctly
var params map[string]string
err = json.Unmarshal(req.Params, &params)
assert.NoError(t, err)
assert.Equal(t, "value", params["key"])
}
func TestToolStructs(t *testing.T) {
// Test Tool struct
tool := Tool{
Name: "test.tool",
Description: "A test tool",
InputSchema: InputSchema{
Type: "object",
Properties: map[string]any{
"input": map[string]any{
"type": "string",
"description": "Input parameter",
},
},
Required: []string{"input"},
},
Handler: func(params map[string]any) (any, error) {
return "result", nil
},
}
// Verify fields are correct
assert.Equal(t, "test.tool", tool.Name)
assert.Equal(t, "A test tool", tool.Description)
assert.Equal(t, "object", tool.InputSchema.Type)
assert.Contains(t, tool.InputSchema.Properties, "input")
propMap, ok := tool.InputSchema.Properties["input"].(map[string]any)
assert.True(t, ok, "Property should be a map")
assert.Equal(t, "string", propMap["type"])
assert.NotNil(t, tool.Handler)
// Verify JSON marshalling (which should exclude Handler function)
data, err := json.Marshal(tool)
assert.NoError(t, err)
assert.Contains(t, string(data), `"name":"test.tool"`)
assert.Contains(t, string(data), `"description":"A test tool"`)
assert.Contains(t, string(data), `"inputSchema":`)
assert.NotContains(t, string(data), `"Handler":`)
}
func TestPromptStructs(t *testing.T) {
// Test Prompt struct
prompt := Prompt{
Name: "test.prompt",
Description: "A test prompt description",
}
// Verify fields are correct
assert.Equal(t, "test.prompt", prompt.Name)
assert.Equal(t, "A test prompt description", prompt.Description)
// Verify JSON marshalling
data, err := json.Marshal(prompt)
assert.NoError(t, err)
assert.Contains(t, string(data), `"name":"test.prompt"`)
assert.Contains(t, string(data), `"description":"A test prompt description"`)
}
func TestResourceStructs(t *testing.T) {
// Test Resource struct
resource := Resource{
Name: "test.resource",
URI: "http://example.com/resource",
Description: "A test resource",
}
// Verify fields are correct
assert.Equal(t, "test.resource", resource.Name)
assert.Equal(t, "http://example.com/resource", resource.URI)
assert.Equal(t, "A test resource", resource.Description)
// Verify JSON marshalling
data, err := json.Marshal(resource)
assert.NoError(t, err)
assert.Contains(t, string(data), `"name":"test.resource"`)
assert.Contains(t, string(data), `"uri":"http://example.com/resource"`)
assert.Contains(t, string(data), `"description":"A test resource"`)
}
func TestContentTypes(t *testing.T) {
// Test TextContent
textContent := TextContent{
Type: "text",
Text: "Sample text",
Annotations: &Annotations{
Audience: []roleType{roleUser, roleAssistant},
Priority: ptr(1.0),
},
}
data, err := json.Marshal(textContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"type":"text"`)
assert.Contains(t, string(data), `"text":"Sample text"`)
assert.Contains(t, string(data), `"audience":["user","assistant"]`)
assert.Contains(t, string(data), `"priority":1`)
// Test ImageContent
imageContent := ImageContent{
Type: "image",
Data: "base64data",
MimeType: "image/png",
}
data, err = json.Marshal(imageContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"type":"image"`)
assert.Contains(t, string(data), `"data":"base64data"`)
assert.Contains(t, string(data), `"mimeType":"image/png"`)
// Test AudioContent
audioContent := AudioContent{
Type: "audio",
Data: "base64audio",
MimeType: "audio/mp3",
}
data, err = json.Marshal(audioContent)
assert.NoError(t, err)
assert.Contains(t, string(data), `"type":"audio"`)
assert.Contains(t, string(data), `"data":"base64audio"`)
assert.Contains(t, string(data), `"mimeType":"audio/mp3"`)
}
func TestCallToolResult(t *testing.T) {
// Test CallToolResult
result := CallToolResult{
Result: Result{
Meta: map[string]any{
"progressToken": "token123",
},
},
Content: []interface{}{
TextContent{
Type: "text",
Text: "Sample result",
},
},
IsError: false,
}
data, err := json.Marshal(result)
assert.NoError(t, err)
assert.Contains(t, string(data), `"_meta":{"progressToken":"token123"}`)
assert.Contains(t, string(data), `"content":[{"type":"text","text":"Sample result"}]`)
assert.NotContains(t, string(data), `"isError":`)
}

15
mcp/util.go Normal file
View File

@@ -0,0 +1,15 @@
package mcp
import (
"fmt"
)
// ptr is a helper function to get a pointer to a value
func ptr[T any](v T) *T {
return &v
}
// formatSSEMessage formats a Server-Sent Event message with proper CRLF line endings
func formatSSEMessage(event string, data []byte) string {
return fmt.Sprintf("event: %s\r\ndata: %s\r\n\r\n", event, string(data))
}

63
mcp/util_test.go Normal file
View File

@@ -0,0 +1,63 @@
package mcp
import (
"bufio"
"encoding/json"
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
func TestPtr(t *testing.T) {
tests := []struct {
name string
v interface{}
}{
{"string", "test"},
{"int", 42},
{"bool", true},
{"float", 3.14},
{"struct", struct{ Name string }{"test"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ptr(tt.v)
assert.NotNil(t, got, "ptr() should not return nil")
assert.Equal(t, tt.v, *got, "dereferenced pointer should equal input value")
})
}
}
type Event struct {
Type string
Data map[string]any
}
func parseEvent(input string) (*Event, error) {
var evt Event
var dataStr string
scanner := bufio.NewScanner(strings.NewReader(input))
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "event:") {
evt.Type = strings.TrimSpace(strings.TrimPrefix(line, "event:"))
} else if strings.HasPrefix(line, "data:") {
dataStr = strings.TrimSpace(strings.TrimPrefix(line, "data:"))
}
}
if err := scanner.Err(); err != nil {
return nil, err
}
if len(dataStr) > 0 {
if err := json.Unmarshal([]byte(dataStr), &evt.Data); err != nil {
return nil, fmt.Errorf("failed to parse data: %w", err)
}
}
return &evt, nil
}

137
mcp/vars.go Normal file
View File

@@ -0,0 +1,137 @@
package mcp
import (
"time"
"github.com/zeromicro/go-zero/core/syncx"
)
// Protocol constants
const (
// JSON-RPC version as defined in the specification
jsonRpcVersion = "2.0"
// Session identifier key used in request URLs
sessionIdKey = "session_id"
)
// Server-Sent Events (SSE) event types
const (
// Standard message event for JSON-RPC responses
eventMessage = "message"
// Endpoint event for sending endpoint URL to clients
eventEndpoint = "endpoint"
)
// Content type identifiers
const (
// Text content type
contentTypeText = "text"
// Image content type
contentTypeImage = "image"
)
// Collection keys for broadcast events
const (
// Key for prompts collection
keyPrompts = "prompts"
// Key for resources collection
keyResources = "resources"
// Key for tools collection
keyTools = "tools"
)
// JSON-RPC error codes
// Standard error codes from JSON-RPC 2.0 spec
const (
// Invalid JSON was received by the server
errCodeInvalidRequest = -32600
// The method does not exist / is not available
errCodeMethodNotFound = -32601
// Invalid method parameter(s)
errCodeInvalidParams = -32602
// Internal JSON-RPC error
errCodeInternalError = -32603
// Tool execution timed out
errCodeTimeout = -32001
// Resource not found error
errCodeResourceNotFound = -32002
// Client hasn't completed initialization
errCodeClientNotInitialized = -32800
)
// User and assistant role definitions
const (
// The "user" role - the entity asking questions
roleUser roleType = "user"
// The "assistant" role - the entity providing responses
roleAssistant roleType = "assistant"
)
// Method names as defined in the MCP specification
const (
// Initialize the connection between client and server
methodInitialize = "initialize"
// List available tools
methodToolsList = "tools/list"
// Call a specific tool
methodToolsCall = "tools/call"
// List available prompts
methodPromptsList = "prompts/list"
// Get a specific prompt
methodPromptsGet = "prompts/get"
// List available resources
methodResourcesList = "resources/list"
// Read a specific resource
methodResourcesRead = "resources/read"
// Subscribe to resource updates
methodResourcesSubscribe = "resources/subscribe"
// Simple ping to check server availability
methodPing = "ping"
// Notification that client is fully initialized
methodNotificationsInitialized = "notifications/initialized"
// Notification that a request was canceled
methodNotificationsCancelled = "notifications/cancelled"
)
// Event names for Server-Sent Events (SSE)
const (
// Notification of tool list changes
eventToolsListChanged = "tools/list_changed"
// Notification of prompt list changes
eventPromptsListChanged = "prompts/list_changed"
// Notification of resource list changes
eventResourcesListChanged = "resources/list_changed"
)
var (
// Default channel size for events
eventChanSize = 10
// Default ping interval for checking connection availability
// use syncx.ForAtomicDuration to ensure atomicity in test race
pingInterval = syncx.ForAtomicDuration(30 * time.Second)
)

214
mcp/vars_test.go Normal file
View File

@@ -0,0 +1,214 @@
// filepath: /Users/kevin/Develop/go/opensource/go-zero/mcp/vars_test.go
package mcp
import (
"encoding/json"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// TestErrorCodes ensures error codes are applied correctly in error responses
func TestErrorCodes(t *testing.T) {
testCases := []struct {
name string
code int
message string
expected string
}{
{
name: "invalid request error",
code: errCodeInvalidRequest,
message: "Invalid request",
expected: `"code":-32600`,
},
{
name: "method not found error",
code: errCodeMethodNotFound,
message: "Method not found",
expected: `"code":-32601`,
},
{
name: "invalid params error",
code: errCodeInvalidParams,
message: "Invalid parameters",
expected: `"code":-32602`,
},
{
name: "internal error",
code: errCodeInternalError,
message: "Internal server error",
expected: `"code":-32603`,
},
{
name: "timeout error",
code: errCodeTimeout,
message: "Operation timed out",
expected: `"code":-32001`,
},
{
name: "resource not found error",
code: errCodeResourceNotFound,
message: "Resource not found",
expected: `"code":-32002`,
},
{
name: "client not initialized error",
code: errCodeClientNotInitialized,
message: "Client not initialized",
expected: `"code":-32800`,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
resp := Response{
JsonRpc: jsonRpcVersion,
ID: int64(1),
Error: &errorObj{
Code: tc.code,
Message: tc.message,
},
}
data, err := json.Marshal(resp)
assert.NoError(t, err)
assert.Contains(t, string(data), tc.expected, "Error code should match expected value")
assert.Contains(t, string(data), tc.message, "Error message should be included")
assert.Contains(t, string(data), jsonRpcVersion, "JSON-RPC version should be included")
})
}
}
// TestJsonRpcVersion ensures the correct JSON-RPC version is used
func TestJsonRpcVersion(t *testing.T) {
assert.Equal(t, "2.0", jsonRpcVersion, "JSON-RPC version should be 2.0")
// Test that it's used in responses
resp := Response{
JsonRpc: jsonRpcVersion,
ID: int64(1),
Result: "test",
}
data, err := json.Marshal(resp)
assert.NoError(t, err)
assert.Contains(t, string(data), `"jsonrpc":"2.0"`, "Response should use correct JSON-RPC version")
// Test that it's expected in requests
reqStr := `{"jsonrpc":"2.0","id":1,"method":"test"}`
var req Request
err = json.Unmarshal([]byte(reqStr), &req)
assert.NoError(t, err)
assert.Equal(t, jsonRpcVersion, req.JsonRpc, "Request should parse correct JSON-RPC version")
}
// TestSessionIdKey ensures session ID extraction works correctly
func TestSessionIdKey(t *testing.T) {
// Create a mock server implementation
mock := newMockMcpServer(t)
defer mock.shutdown()
// Verify the key constant
assert.Equal(t, "session_id", sessionIdKey, "Session ID key should be 'session_id'")
// Test that session ID is extracted correctly
mockR := httptest.NewRequest("GET", "/?"+sessionIdKey+"=test-session", nil)
// Since the mock server is using the same session key logic,
// we can test this by accessing the request query parameters directly
sessionID := mockR.URL.Query().Get(sessionIdKey)
assert.Equal(t, "test-session", sessionID, "Session ID should be extracted correctly")
}
// TestEventTypes ensures event types are set correctly in SSE responses
func TestEventTypes(t *testing.T) {
// Test message event
assert.Equal(t, "message", eventMessage, "Message event should be 'message'")
// Test endpoint event
assert.Equal(t, "endpoint", eventEndpoint, "Endpoint event should be 'endpoint'")
// Verify them in an actual SSE format string
messageEvent := "event: " + eventMessage + "\ndata: test\n\n"
assert.Contains(t, messageEvent, "event: message", "Message event should format correctly")
endpointEvent := "event: " + eventEndpoint + "\ndata: test\n\n"
assert.Contains(t, endpointEvent, "event: endpoint", "Endpoint event should format correctly")
}
// TestCollectionKeys checks that collection keys are used correctly
func TestCollectionKeys(t *testing.T) {
// Verify collection key constants
assert.Equal(t, "prompts", keyPrompts, "Prompts key should be 'prompts'")
assert.Equal(t, "resources", keyResources, "Resources key should be 'resources'")
assert.Equal(t, "tools", keyTools, "Tools key should be 'tools'")
}
// TestRoleTypes checks that role types are used correctly
func TestRoleTypes(t *testing.T) {
// Verify role type constants
assert.Equal(t, "user", string(roleUser), "User role should be 'user'")
assert.Equal(t, "assistant", string(roleAssistant), "Assistant role should be 'assistant'")
// Test in annotations
annotations := Annotations{
Audience: []roleType{roleUser, roleAssistant},
}
data, err := json.Marshal(annotations)
assert.NoError(t, err)
assert.Contains(t, string(data), `"audience":["user","assistant"]`, "Role types should marshal correctly")
}
// TestMethodNames checks that method names are used correctly
func TestMethodNames(t *testing.T) {
// Verify method name constants
methods := map[string]string{
"initialize": methodInitialize,
"tools/list": methodToolsList,
"tools/call": methodToolsCall,
"prompts/list": methodPromptsList,
"prompts/get": methodPromptsGet,
"resources/list": methodResourcesList,
"resources/read": methodResourcesRead,
"resources/subscribe": methodResourcesSubscribe,
"ping": methodPing,
"notifications/initialized": methodNotificationsInitialized,
"notifications/cancelled": methodNotificationsCancelled,
}
for expected, actual := range methods {
assert.Equal(t, expected, actual, "Method name should be "+expected)
}
// Test in a request
for methodName := range methods {
req := Request{
JsonRpc: jsonRpcVersion,
ID: int64(1),
Method: methodName,
}
data, err := json.Marshal(req)
assert.NoError(t, err)
assert.Contains(t, string(data), `"method":"`+methodName+`"`, "Method name should be used in requests")
}
}
// TestEventNames checks that event names are used correctly
func TestEventNames(t *testing.T) {
// Verify event name constants
events := map[string]string{
"tools/list_changed": eventToolsListChanged,
"prompts/list_changed": eventPromptsListChanged,
"resources/list_changed": eventResourcesListChanged,
}
for expected, actual := range events {
assert.Equal(t, expected, actual, "Event name should be "+expected)
}
// Test event names in SSE format
for _, eventName := range events {
sseEvent := "event: " + eventName + "\ndata: test\n\n"
assert.Contains(t, sseEvent, "event: "+eventName, "Event name should format correctly in SSE")
}
}

View File

@@ -2,6 +2,7 @@ package api
import (
"github.com/spf13/cobra"
"github.com/zeromicro/go-zero/tools/goctl/api/apigen"
"github.com/zeromicro/go-zero/tools/goctl/api/dartgen"
"github.com/zeromicro/go-zero/tools/goctl/api/docgen"
@@ -10,10 +11,12 @@ import (
"github.com/zeromicro/go-zero/tools/goctl/api/javagen"
"github.com/zeromicro/go-zero/tools/goctl/api/ktgen"
"github.com/zeromicro/go-zero/tools/goctl/api/new"
"github.com/zeromicro/go-zero/tools/goctl/api/swagger"
"github.com/zeromicro/go-zero/tools/goctl/api/tsgen"
"github.com/zeromicro/go-zero/tools/goctl/api/validate"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/internal/cobrax"
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
"github.com/zeromicro/go-zero/tools/goctl/plugin"
)
@@ -31,6 +34,7 @@ var (
ktCmd = cobrax.NewCommand("kt", cobrax.WithRunE(ktgen.KtCommand))
pluginCmd = cobrax.NewCommand("plugin", cobrax.WithRunE(plugin.PluginCommand))
tsCmd = cobrax.NewCommand("ts", cobrax.WithRunE(tsgen.TsCommand))
swaggerCmd = cobrax.NewCommand("swagger", cobrax.WithRunE(swagger.Command))
)
func init() {
@@ -46,6 +50,7 @@ func init() {
pluginCmdFlags = pluginCmd.Flags()
tsCmdFlags = tsCmd.Flags()
validateCmdFlags = validateCmd.Flags()
swaggerCmdFlags = swaggerCmd.Flags()
)
apiCmdFlags.StringVar(&apigen.VarStringOutput, "o")
@@ -97,8 +102,16 @@ func init() {
tsCmdFlags.StringVar(&tsgen.VarStringCaller, "caller")
tsCmdFlags.BoolVar(&tsgen.VarBoolUnWrap, "unwrap")
swaggerCmdFlags.StringVar(&swagger.VarStringAPI, "api")
swaggerCmdFlags.StringVar(&swagger.VarStringDir, "dir")
swaggerCmdFlags.StringVar(&swagger.VarStringFilename, "filename")
swaggerCmdFlags.BoolVar(&swagger.VarBoolYaml, "yaml")
validateCmdFlags.StringVar(&validate.VarStringAPI, "api")
// Add sub-commands
Cmd.AddCommand(dartCmd, docCmd, formatCmd, goCmd, javaCmd, ktCmd, newCmd, pluginCmd, tsCmd, validateCmd)
if env.UseExperimental() {
Cmd.AddCommand(swaggerCmd)
}
}

View File

@@ -6,12 +6,14 @@ import (
"io"
"os"
"path"
"sort"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
apiutil "github.com/zeromicro/go-zero/tools/goctl/api/util"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/pkg/env"
"github.com/zeromicro/go-zero/tools/goctl/util"
"github.com/zeromicro/go-zero/tools/goctl/util/format"
)
@@ -39,20 +41,89 @@ func BuildTypes(types []spec.Type) (string, error) {
return builder.String(), nil
}
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
val, err := BuildTypes(api.Types)
func removeTypeFromDefault(tp spec.Type, group string, groupTypes map[string]map[string]spec.Type) map[string]map[string]spec.Type {
switch val := tp.(type) {
case spec.DefineStruct:
typeName := util.Title(tp.Name())
defaultGroups, ok := groupTypes[groupTypeDefault]
if ok {
delete(defaultGroups, typeName)
types, ok := groupTypes[group]
if !ok {
types = make(map[string]spec.Type)
}
types[typeName] = tp
groupTypes[group] = types
}
groupTypes[groupTypeDefault] = defaultGroups
case spec.PointerType:
groupTypes = removeTypeFromDefault(val.Type, group, groupTypes)
case spec.ArrayType:
groupTypes = removeTypeFromDefault(val.Value, group, groupTypes)
}
return groupTypes
}
func genTypesWithGroup(dir string, cfg *config.Config, api *spec.ApiSpec) error {
groupTypes := make(map[string]map[string]spec.Type)
for _, v := range api.Types {
types, ok := groupTypes[groupTypeDefault]
if !ok {
types = make(map[string]spec.Type)
}
types[util.Title(v.Name())] = v
groupTypes[groupTypeDefault] = types
}
for _, v := range api.Service.Groups {
group := v.GetAnnotation(groupProperty)
if len(group) == 0 {
continue
}
for _, v := range v.Routes {
if v.RequestType != nil {
groupTypes = removeTypeFromDefault(v.RequestType, group, groupTypes)
}
if v.ResponseType != nil {
groupTypes = removeTypeFromDefault(v.ResponseType, group, groupTypes)
}
}
}
for group, typeGroup := range groupTypes {
var types []spec.Type
for _, v := range typeGroup {
types = append(types, v)
}
sort.Slice(types, func(i, j int) bool {
return types[i].Name() < types[j].Name()
})
if err := writeTypes(dir, group, cfg, types); err != nil {
return err
}
}
return nil
}
func writeTypes(dir, baseFilename string, cfg *config.Config, types []spec.Type) error {
if len(types) == 0 {
return nil
}
val, err := BuildTypes(types)
if err != nil {
return err
}
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, typesFile)
typeFilename, err := format.FileNamingFormat(cfg.NamingFormat, baseFilename)
if err != nil {
return err
}
typeFilename = typeFilename + ".go"
filename := path.Join(dir, typesDir, typeFilename)
os.Remove(filename)
_ = os.Remove(filename)
return genFile(fileGenConfig{
dir: dir,
@@ -70,6 +141,13 @@ func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
})
}
func genTypes(dir string, cfg *config.Config, api *spec.ApiSpec) error {
if env.UseExperimental() {
return genTypesWithGroup(dir, cfg, api)
}
return writeTypes(dir, typesFile, cfg, api.Types)
}
func writeType(writer io.Writer, tp spec.Type) error {
structType, ok := tp.(spec.DefineStruct)
if !ok {

View File

@@ -10,4 +10,6 @@ const (
middlewareDir = internal + "middleware"
typesDir = internal + typesPacket
groupProperty = "group"
groupTypeDefault="types"
)

View File

@@ -57,13 +57,13 @@ func (m Member) Tags() []*Tag {
// IsOptional returns true if tag is optional
func (m Member) IsOptional() bool {
if !m.IsBodyMember() {
if !m.IsBodyMember() && !m.IsFormMember() {
return false
}
tag := m.Tags()
for _, item := range tag {
if item.Key == bodyTagKey {
if item.Key == bodyTagKey || item.Key == formTagKey {
if stringx.Contains(item.Options, "optional") {
return true
}

View File

@@ -21,7 +21,7 @@ type (
// ApiSpec describes an api file
ApiSpec struct {
Info Info // Deprecated: useless expression
Info Info
Syntax ApiSyntax // Deprecated: useless expression
Imports []Import // Deprecated: useless expression
Types []Type

View File

@@ -0,0 +1,71 @@
package swagger
import (
"strconv"
"github.com/zeromicro/go-zero/tools/goctl/util"
"google.golang.org/grpc/metadata"
)
func hasKey(properties map[string]string, key string) bool {
if len(properties) == 0 {
return false
}
md := metadata.New(properties)
_, ok := md[key]
return ok
}
func getBoolFromKVOrDefault(properties map[string]string, key string, def bool) bool {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
str := util.Unquote(val[0])
if len(str) == 0 {
return def
}
res, _ := strconv.ParseBool(str)
return res
}
func getStringFromKVOrDefault(properties map[string]string, key string, def string) string {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
str := util.Unquote(val[0])
if len(str) == 0 {
return def
}
return str
}
func getListFromInfoOrDefault(properties map[string]string, key string, def []string) []string {
if len(properties) == 0 {
return def
}
md := metadata.New(properties)
val := md.Get(key)
if len(val) == 0 {
return def
}
str := util.Unquote(val[0])
if len(str) == 0 {
return def
}
resp := util.FieldsAndTrimSpace(str, commaRune)
if len(resp) == 0 {
return def
}
return resp
}

View File

@@ -0,0 +1,53 @@
package swagger
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_getBoolFromKVOrDefault(t *testing.T) {
properties := map[string]string{
"enabled": `"true"`,
"disabled": `"false"`,
"invalid": `"notabool"`,
"empty_value": `""`,
}
assert.True(t, getBoolFromKVOrDefault(properties, "enabled", false))
assert.False(t, getBoolFromKVOrDefault(properties, "disabled", true))
assert.False(t, getBoolFromKVOrDefault(properties, "invalid", false))
assert.True(t, getBoolFromKVOrDefault(properties, "missing", true))
assert.False(t, getBoolFromKVOrDefault(properties, "empty_value", false))
assert.False(t, getBoolFromKVOrDefault(nil, "nil", false))
assert.False(t, getBoolFromKVOrDefault(map[string]string{}, "empty", false))
}
func Test_getStringFromKVOrDefault(t *testing.T) {
properties := map[string]string{
"name": `"example"`,
"empty": `""`,
}
assert.Equal(t, "example", getStringFromKVOrDefault(properties, "name", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "empty", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(properties, "missing", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(nil, "nil", "default"))
assert.Equal(t, "default", getStringFromKVOrDefault(map[string]string{}, "empty", "default"))
}
func Test_getListFromInfoOrDefault(t *testing.T) {
properties := map[string]string{
"list": `"a, b, c"`,
"empty": `""`,
}
assert.Equal(t, []string{"a", "b", "c"}, getListFromInfoOrDefault(properties, "list", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "empty", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(properties, "missing", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(nil, "nil", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{}, "empty", []string{"default"}))
assert.Equal(t, []string{"default"}, getListFromInfoOrDefault(map[string]string{
"foo": ",,",
}, "foo", []string{"default"}))
}

View File

@@ -0,0 +1,138 @@
package swagger
import "github.com/zeromicro/go-zero/tools/goctl/api/spec"
func fillAllStructs(api *spec.ApiSpec) {
var (
tps []spec.Type
structTypes = make(map[string]spec.DefineStruct)
groups []spec.Group
)
for _, tp := range api.Types {
structTypes[tp.Name()] = tp.(spec.DefineStruct)
}
for _, tp := range api.Types {
filledTP := fillStruct("", tp, structTypes)
tps = append(tps, filledTP)
structTypes[filledTP.Name()] = filledTP.(spec.DefineStruct)
}
for _, group := range api.Service.Groups {
var routes []spec.Route
for _, route := range group.Routes {
route.RequestType = fillStruct("", route.RequestType, structTypes)
route.ResponseType = fillStruct("", route.ResponseType, structTypes)
routes = append(routes, route)
}
group.Routes = routes
groups = append(groups, group)
}
api.Service.Groups = groups
api.Types = tps
}
func fillStruct(parent string, tp spec.Type, allTypes map[string]spec.DefineStruct) spec.Type {
switch val := tp.(type) {
case spec.DefineStruct:
var members []spec.Member
for _, member := range val.Members {
switch memberType := member.Type.(type) {
case spec.PointerType:
member.Type = spec.PointerType{
RawName: memberType.RawName,
Type: fillStruct(val.Name(), memberType.Type, allTypes),
}
case spec.ArrayType:
member.Type = spec.ArrayType{
RawName: memberType.RawName,
Value: fillStruct(val.Name(), memberType.Value, allTypes),
}
case spec.MapType:
member.Type = spec.MapType{
RawName: memberType.RawName,
Key: memberType.Key,
Value: fillStruct(val.Name(), memberType.Value, allTypes),
}
case spec.DefineStruct:
if parent != memberType.Name() { // avoid recursive struct
if st, ok := allTypes[memberType.Name()]; ok {
member.Type = fillStruct("", st, allTypes)
}
}
case spec.NestedStruct:
member.Type = fillStruct("", member.Type, allTypes)
}
members = append(members, member)
}
if len(members) == 0 {
st, ok := allTypes[val.RawName]
if ok {
members = st.Members
}
}
val.Members = members
return val
case spec.NestedStruct:
var members []spec.Member
for _, member := range val.Members {
switch memberType := member.Type.(type) {
case spec.PointerType:
member.Type = spec.PointerType{
RawName: memberType.RawName,
Type: fillStruct(val.Name(), memberType.Type, allTypes),
}
case spec.ArrayType:
member.Type = spec.ArrayType{
RawName: memberType.RawName,
Value: fillStruct(val.Name(), memberType.Value, allTypes),
}
case spec.MapType:
member.Type = spec.MapType{
RawName: memberType.RawName,
Key: memberType.Key,
Value: fillStruct(val.Name(), memberType.Value, allTypes),
}
case spec.DefineStruct:
if parent != memberType.Name() { // avoid recursive struct
if st, ok := allTypes[memberType.Name()]; ok {
member.Type = fillStruct("", st, allTypes)
}
}
case spec.NestedStruct:
if parent != memberType.Name() {
if st, ok := allTypes[memberType.Name()]; ok {
member.Type = fillStruct("", st, allTypes)
}
}
}
members = append(members, member)
}
if len(members) == 0 {
st, ok := allTypes[val.RawName]
if ok {
members = st.Members
}
}
val.Members = members
return val
case spec.PointerType:
return spec.PointerType{
RawName: val.RawName,
Type: fillStruct(parent, val.Type, allTypes),
}
case spec.ArrayType:
return spec.ArrayType{
RawName: val.RawName,
Value: fillStruct(parent, val.Value, allTypes),
}
case spec.MapType:
return spec.MapType{
RawName: val.RawName,
Key: val.Key,
Value: fillStruct(parent, val.Value, allTypes),
}
default:
return tp
}
}

View File

@@ -0,0 +1,88 @@
package swagger
import (
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
"github.com/spf13/cobra"
"gopkg.in/yaml.v2"
"github.com/zeromicro/go-zero/tools/goctl/pkg/parser/api/parser"
"github.com/zeromicro/go-zero/tools/goctl/util/pathx"
)
var (
// VarStringAPI specifies the API filename.
VarStringAPI string
// VarStringDir specifies the directory to generate swagger file.
VarStringDir string
// VarStringFilename specifies the generated swagger file name without the extension.
VarStringFilename string
// VarBoolYaml specifies whether to generate a YAML file.
VarBoolYaml bool
)
func Command(_ *cobra.Command, _ []string) error {
if len(VarStringAPI) == 0 {
return errors.New("missing -api")
}
if len(VarStringDir) == 0 {
return errors.New("missing -dir")
}
api, err := parser.Parse(VarStringAPI, "")
if err != nil {
return err
}
fillAllStructs(api)
if err := api.Validate(); err != nil {
return err
}
swagger, err := spec2Swagger(api)
if err != nil {
return err
}
data, err := json.MarshalIndent(swagger, "", " ")
if err != nil {
return err
}
err = pathx.MkdirIfNotExist(VarStringDir)
if err != nil {
return err
}
filename := VarStringFilename
if filename == "" {
base := filepath.Base(VarStringAPI)
filename = strings.TrimSuffix(base, filepath.Ext(base))
}
if VarBoolYaml {
filePath := filepath.Join(VarStringDir, filename+".yaml")
var jsonObj interface{}
if err := yaml.Unmarshal(data, &jsonObj); err != nil {
return err
}
data, err := yaml.Marshal(jsonObj)
if err != nil {
return err
}
return os.WriteFile(filePath, data, 0644)
}
// generate json swagger file
filePath := filepath.Join(VarStringDir, filename+".json")
return os.WriteFile(filePath, data, 0644)
}

View File

@@ -0,0 +1,32 @@
package swagger
const (
tagHeader = "header"
tagPath = "path"
tagForm = "form"
tagJson = "json"
defFlag = "default="
enumFlag = "options="
rangeFlag = "range="
exampleFlag = "example="
paramsInHeader = "header"
paramsInPath = "path"
paramsInQuery = "query"
paramsInBody = "body"
paramsInForm = "formData"
swaggerTypeInteger = "integer"
swaggerTypeNumber = "number"
swaggerTypeString = "string"
swaggerTypeBoolean = "boolean"
swaggerTypeArray = "array"
swaggerTypeObject = "object"
swaggerVersion = "2.0"
applicationJson = "application/json"
applicationForm = "application/x-www-form-urlencoded"
schemeHttps = "https"
defaultHost = "127.0.0.1"
defaultBasePath = "/"
)

View File

@@ -0,0 +1,25 @@
package swagger
import (
"net/http"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func consumesFromTypeOrDef(method string, tp spec.Type) []string {
if strings.EqualFold(method, http.MethodGet) {
return []string{}
}
if tp == nil {
return []string{}
}
structType, ok := tp.(spec.DefineStruct)
if !ok {
return []string{}
}
if typeContainsTag(structType, tagJson) {
return []string{applicationJson}
}
return []string{applicationForm}
}

View File

@@ -0,0 +1,68 @@
package swagger
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func TestConsumesFromTypeOrDef(t *testing.T) {
tests := []struct {
name string
method string
tp spec.Type
expected []string
}{
{
name: "GET method with nil type",
method: http.MethodGet,
tp: nil,
expected: []string{},
},
{
name: "post nil",
method: http.MethodPost,
tp: nil,
expected: []string{},
},
{
name: "json tag",
method: http.MethodPost,
tp: spec.DefineStruct{
Members: []spec.Member{
{
Tag: `json:"example"`,
},
},
},
expected: []string{applicationJson},
},
{
name: "form tag",
method: http.MethodPost,
tp: spec.DefineStruct{
Members: []spec.Member{
{
Tag: `form:"example"`,
},
},
},
expected: []string{applicationForm},
},
{
name: "Non struct type",
method: http.MethodPost,
tp: spec.ArrayType{},
expected: []string{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := consumesFromTypeOrDef(tt.method, tt.tp)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -0,0 +1,4 @@
*.json
*.yaml
bin
output

View File

@@ -0,0 +1,240 @@
syntax = "v1"
info (
title: "Demo API" // title corresponding to Swagger
description: "Generating Swagger files using the API demo." // description corresponding to Swagger
version: "v1" // version corresponding to Swagger
termsOfService: "https://github.com/zeromicro/go-zero" // termsOfService corresponding to Swagger
contactName: "keson.an" // contactName corresponding to Swagger
contactURL: "https://github.com/zeromicro/go-zero" // contactURL corresponding to Swagger
contactEmail: "example@gmail.com" // contactEmail corresponding to Swagger
licenseName: "MIT" // licenseName corresponding to Swagger
licenseURL: "https://github.com/zeromicro/go-zero" // licenseURL corresponding to Swagger
consumes: "application/json" // consumes corresponding to Swagger,default value is `application/json`
produces: "application/json" // produces corresponding to Swagger,default value is `application/json`
schemes: "https" // schemes corresponding to Swagger,default value is `https``
host: "example.com" // host corresponding to Swagger,default value is `127.0.0.1`
basePath: "/v1" // basePath corresponding to Swagger,default value is `/`
wrapCodeMsg: "true" // to wrap in the universal code-msg structure, like {"code":0,"msg":"OK","data":$data}
bizCodeEnumDescription: "1001-User not login<br>1002-User permission denied" // enums of business error codes, in JSON format, with the key being the business error code and the value being the description of that error code. This only takes effect when wrapCodeMsg is set to true.
// securityDefinitionsFromJson is a custom authentication configuration, and the JSON content will be directly inserted into the securityDefinitions of Swagger.
// Format reference: https://swagger.io/specification/v2/#security-definitions-object
// You can declare authType in the @server of the API to specify the authentication type used for its routes.
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey type description","type":"apiKey","name":"x-api-key","in":"header"}}`
)
type (
QueryReq {
Id int `form:"id,range=[1:10000],example=10"`
Name string `form:"name,example=keson.an"`
Avatar string `form:"avatar,optional,example=https://example.com/avatar.png"`
}
QueryResp {
Id int `json:"id,example=10"`
Name string `json:"name,example=keson.an"`
}
PathQueryReq {
Id int `path:"id,range=[1:10000],example=10"`
Name string `form:"name,example=keson.an"`
}
PathQueryResp {
Id int `json:"id,example=10"`
Name string `json:"name,example=keson.an"`
}
)
@server (
tags: "query" // tags corresponding to Swagger
summary: "query API set" // summary corresponding to Swagger
prefix: v1
authType: apiKey // Specifies the authentication type used for this route, which is the name defined in securityDefinitionsFromJson.
)
service Swagger {
@doc (
description: "query demo"
)
@handler query
get /query (QueryReq) returns (QueryResp)
@doc (
description: "show path query demo"
)
@handler queryPath
get /query/:id (PathQueryReq) returns (PathQueryResp)
}
type (
FormReq {
Id int `form:"id,range=[1:10000],example=10"`
Name string `form:"name,example=keson.an"`
}
FormResp {
Id int `json:"id,example=10"`
Name string `json:"name,example=keson.an"`
}
)
@server (
tags: "form" // tags corresponding to Swagger
summary: "form API set" // summary corresponding to Swagger
)
service Swagger {
@doc (
description: "form demo"
)
@handler form
post /form (FormReq) returns (FormResp)
}
type (
JsonReq {
Id int `json:"id,range=[1:10000],example=10"`
Name string `json:"name,example=keson.an"`
Avatar string `json:"avatar,optional"`
Language string `json:"language,options=golang|java|python|typescript|rust"`
Gender string `json:"gender,default=male,options=male|female,example=male"`
}
JsonResp {
Id int `json:"id"`
Name string `json:"name"`
Avatar string `json:"avatar"`
Language string `json:"language"`
Gender string `json:"gender"`
}
ComplexJsonLevel2 {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
}
ComplexJsonLevel1 {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
// Object
Object ComplexJsonLevel2 `json:"object"`
PointerObject *ComplexJsonLevel2 `json:"pointerObject"`
}
ComplexJsonReq {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
// basic array
ArrayInteger []int `json:"arrayInteger"`
ArrayNumber []float64 `json:"arrayNumber"`
ArrayBoolean []bool `json:"arrayBoolean"`
ArrayString []string `json:"arrayString"`
// basic array array
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
ArrayArrayString [][]string `json:"arrayArrayString"`
// basic map
MapInteger map[string]int `json:"mapInteger"`
MapNumber map[string]float64 `json:"mapNumber"`
MapBoolean map[string]bool `json:"mapBoolean"`
MapString map[string]string `json:"mapString"`
// basic map array
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
MapArrayString map[string][]string `json:"mapArrayString"`
// basic map map
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
MapMapString map[string]map[string]string `json:"mapMapString"`
// Object
Object ComplexJsonLevel1 `json:"object"`
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
// Object array
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
// Object map
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
// Object array array
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
// Object array map
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
// Object map array
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
}
ComplexJsonResp {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
// basic array
ArrayInteger []int `json:"arrayInteger"`
ArrayNumber []float64 `json:"arrayNumber"`
ArrayBoolean []bool `json:"arrayBoolean"`
ArrayString []string `json:"arrayString"`
// basic array array
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
ArrayArrayString [][]string `json:"arrayArrayString"`
// basic map
MapInteger map[string]int `json:"mapInteger"`
MapNumber map[string]float64 `json:"mapNumber"`
MapBoolean map[string]bool `json:"mapBoolean"`
MapString map[string]string `json:"mapString"`
// basic map array
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
MapArrayString map[string][]string `json:"mapArrayString"`
// basic map map
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
MapMapString map[string]map[string]string `json:"mapMapString"`
// Object
Object ComplexJsonLevel1 `json:"object"`
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
// Object array
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
// Object map
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
// Object array array
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
// Object array map
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
// Object map array
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
}
)
@server (
tags: "postJson" // tags corresponding to Swagger
summary: "json API set" // summary corresponding to Swagger
)
service Swagger {
@doc (
description: "simple json request body API"
)
@handler jsonSimple
post /json/simple (JsonReq) returns (JsonResp)
@doc (
description: "complex json request body API"
)
@handler jsonComplex
post /json/complex (ComplexJsonReq) returns (ComplexJsonResp)
}

View File

@@ -0,0 +1,244 @@
syntax = "v1"
info (
title: "演示 API" // 对应 swagger 的 title
description: "演示 api 生成 swagger 文件的 api 完整写法" // 对应 swagger 的 description
version: "v1" // 对应 swagger 的 version
termsOfService: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 termsOfService
contactName: "keson.an" // 对应 swagger 的 contactName
contactURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 contactURL
contactEmail: "example@gmail.com" // 对应 swagger 的 contactEmail
licenseName: "MIT" // 对应 swagger 的 licenseName
licenseURL: "https://github.com/zeromicro/go-zero" // 对应 swagger 的 licenseURL
consumes: "application/json" // 对应 swagger 的 consumes,不填默认为 application/json
produces: "application/json" // 对应 swagger 的 produces,不填默认为 application/json
schemes: "https" // 对应 swagger 的 schemes,不填默认为 https
host: "example.com" // 对应 swagger 的 host,不填默认为 127.0.0.1
basePath: "/v1" // 对应 swagger 的 basePath,不填默认为 /
wrapCodeMsg: "true" // 是否用 code-msg 通用响应体,如果开启,则以格式 {"code":0,"msg":"OK","data":$data} 包括响应体
bizCodeEnumDescription: "1001-未登录<br>1002-无权限操作" // 业务错误码枚举描述json 格式,key 为业务错误码value 为该错误码的描述,仅当 wrapCodeMsg 为 true 时生效
// securityDefinitionsFromJson 为自定义鉴权配置json 内容将直接放入 swagger 的 securityDefinitions 中,
// 格式参考 https://swagger.io/specification/v2/#security-definitions-object
// 在 api 的 @server 中可声明 authType 来指定其路由使用的鉴权类型
securityDefinitionsFromJson: `{"apiKey":{"description":"apiKey 类型鉴权自定义","type":"apiKey","name":"x-api-key","in":"header"}}`
)
type (
QueryReq {
Id int `form:"id,range=[1:10000],example=10"`
Name string `form:"name,example=keson.an"`
Avatar string `form:"avatar,optional,example=https://example.com/avatar.png"`
}
QueryResp {
Id int `json:"id,example=10"`
Name string `json:"name,example=keson.an"`
}
PathQueryReq {
Id int `path:"id,range=[1:10000],example=10"`
Name string `form:"name,example=keson.an"`
}
PathQueryResp {
Id int `json:"id,example=10"`
Name string `json:"name,example=keson.an"`
}
)
@server (
tags: "query 演示" // 对应 swagger 的 tags,可以对 swagger 中的 api 进行分组
summary: "query 类型接口集合" // 对应 swagger 的 summary
prefix: v1
authType: apiKey // 指定该路由使用的鉴权类型,值为 securityDefinitionsFromJson 中定义的名称
)
service Swagger {
@doc (
description: "query 接口"
)
@handler query
get /query (QueryReq) returns (QueryResp)
@doc (
description: "query path 中包含 id 字段接口"
)
@handler queryPath
get /query/:id (PathQueryReq) returns (PathQueryResp)
}
type (
FormReq {
Id int `form:"id,range=[1:10000],example=10"`
Name string `form:"name,example=keson.an"`
}
FormResp {
Id int `json:"id,example=10"`
Name string `json:"name,example=keson.an"`
}
)
@server (
tags: "form 表单 api 演示" // 对应 swagger 的 tags,可以对 swagger 中的 api 进行分组
summary: "form 表单类型接口集合" // 对应 swagger 的 summary
)
service Swagger {
@doc (
description: "form 接口"
)
@handler form
post /form (FormReq) returns (FormResp)
}
type (
JsonReq {
Id int `json:"id,range=[1:10000],example=10"`
Name string `json:"name,example=keson.an"`
Avatar string `json:"avatar,optional"`
Language string `json:"language,options=golang|java|python|typescript|rust"`
Gender string `json:"gender,default=male,options=male|female,example=male"`
}
JsonResp {
Id int `json:"id"`
Name string `json:"name"`
Avatar string `json:"avatar"`
Language string `json:"language"`
Gender string `json:"gender"`
}
ComplexJsonLevel2 {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
}
ComplexJsonLevel1 {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
// Object
Object ComplexJsonLevel2 `json:"object"`
PointerObject *ComplexJsonLevel2 `json:"pointerObject"`
}
ComplexJsonReq {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
// basic array
ArrayInteger []int `json:"arrayInteger"`
ArrayNumber []float64 `json:"arrayNumber"`
ArrayBoolean []bool `json:"arrayBoolean"`
ArrayString []string `json:"arrayString"`
// basic array array
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
ArrayArrayString [][]string `json:"arrayArrayString"`
// basic map
MapInteger map[string]int `json:"mapInteger"`
MapNumber map[string]float64 `json:"mapNumber"`
MapBoolean map[string]bool `json:"mapBoolean"`
MapString map[string]string `json:"mapString"`
// basic map array
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
MapArrayString map[string][]string `json:"mapArrayString"`
// basic map map
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
MapMapString map[string]map[string]string `json:"mapMapString"`
MapMapObject map[string]map[string]ComplexJsonLevel1 `json:"mapMapObject"`
MapMapPointerObject map[string]map[string]*ComplexJsonLevel1 `json:"mapMapPointerObject"`
// Object
Object ComplexJsonLevel1 `json:"object"`
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
// Object array
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
// Object map
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
// Object array array
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
// Object array map
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
// Object map array
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
}
ComplexJsonResp {
// basic
Integer int `json:"integer,example=1"`
Number float64 `json:"number,example=1.1"`
Boolean bool `json:"boolean,options=true|false,example=true"`
String string `json:"string,example=some text"`
// basic array
ArrayInteger []int `json:"arrayInteger"`
ArrayNumber []float64 `json:"arrayNumber"`
ArrayBoolean []bool `json:"arrayBoolean"`
ArrayString []string `json:"arrayString"`
// basic array array
ArrayArrayInteger [][]int `json:"arrayArrayInteger"`
ArrayArrayNumber [][]float64 `json:"arrayArrayNumber"`
ArrayArrayBoolean [][]bool `json:"arrayArrayBoolean"`
ArrayArrayString [][]string `json:"arrayArrayString"`
// basic map
MapInteger map[string]int `json:"mapInteger"`
MapNumber map[string]float64 `json:"mapNumber"`
MapBoolean map[string]bool `json:"mapBoolean"`
MapString map[string]string `json:"mapString"`
// basic map array
MapArrayInteger map[string][]int `json:"mapArrayInteger"`
MapArrayNumber map[string][]float64 `json:"mapArrayNumber"`
MapArrayBoolean map[string][]bool `json:"mapArrayBoolean"`
MapArrayString map[string][]string `json:"mapArrayString"`
// basic map map
MapMapInteger map[string]map[string]int `json:"mapMapInteger"`
MapMapNumber map[string]map[string]float64 `json:"mapMapNumber"`
MapMapBoolean map[string]map[string]bool `json:"mapMapBoolean"`
MapMapString map[string]map[string]string `json:"mapMapString"`
MapMapObject map[string]map[string]ComplexJsonLevel1 `json:"mapMapObject"`
MapMapPointerObject map[string]map[string]*ComplexJsonLevel1 `json:"mapMapPointerObject"`
// Object
Object ComplexJsonLevel1 `json:"object"`
PointerObject *ComplexJsonLevel1 `json:"pointerObject"`
// Object array
ArrayObject []ComplexJsonLevel1 `json:"arrayObject"`
ArrayPointerObject []*ComplexJsonLevel1 `json:"arrayPointerObject"`
// Object map
MapObject map[string]ComplexJsonLevel1 `json:"mapObject"`
MapPointerObject map[string]*ComplexJsonLevel1 `json:"mapPointerObject"`
// Object array array
ArrayArrayObject [][]ComplexJsonLevel1 `json:"arrayArrayObject"`
ArrayArrayPointerObject [][]*ComplexJsonLevel1 `json:"arrayArrayPointerObject"`
// Object array map
ArrayMapObject []map[string]ComplexJsonLevel1 `json:"arrayMapObject"`
ArrayMapPointerObject []map[string]*ComplexJsonLevel1 `json:"arrayMapPointerObject"`
// Object map array
MapArrayObject map[string][]ComplexJsonLevel1 `json:"mapArrayObject"`
MapArrayPointerObject map[string][]*ComplexJsonLevel1 `json:"mapArrayPointerObject"`
}
)
@server (
tags: "post json api 演示" // 对应 swagger 的 tags,可以对 swagger 中的 api 进行分组
summary: "json 请求类型接口集合" // 对应 swagger 的 summary
)
service Swagger {
@doc (
description: "简单的 json 请求体接口"
)
@handler jsonSimple
post /json/simple (JsonReq) returns (JsonResp)
@doc (
description: "复杂的 json 请求体接口"
)
@handler jsonComplex
post /json/complex (ComplexJsonReq) returns (ComplexJsonResp)
}

View File

@@ -0,0 +1,39 @@
#!/bin/bash
# 1. 检查并安装 swagger
if ! command -v swagger &> /dev/null; then
echo "swagger 未安装,正在从 GitHub 安装..."
# 这里使用 go-swagger 的安装方式
go install github.com/go-swagger/go-swagger/cmd/swagger@latest
if [ $? -ne 0 ]; then
echo "安装 swagger 失败"
exit 1
fi
echo "swagger 安装成功"
else
echo "swagger 已安装"
fi
mkdir bin output
export GOBIN=$(pwd)/bin
# 2. 安装最新版 goctl
go install ../../..
if [ $? -ne 0 ]; then
echo "安装 goctl 失败"
exit 1
fi
echo "goctl 安装成功"
# 3. 生成 swagger 文件
echo "正在生成 swagger 文件..."
./bin/goctl api swagger --api example_cn.api --dir output
if [ $? -ne 0 ]; then
echo "生成 swagger 文件失败"
exit 1
fi
# 4. 启动 swagger 服务
echo "启动 swagger 服务..."
swagger serve ./output/example_cn.json

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.6 MiB

View File

@@ -0,0 +1,39 @@
#!/bin/bash
# 1. Check and install swagger if not exists
if ! command -v swagger &> /dev/null; then
echo "swagger not found, installing from GitHub..."
# Using go-swagger installation method
go install github.com/go-swagger/go-swagger/cmd/swagger@latest
if [ $? -ne 0 ]; then
echo "Failed to install swagger"
exit 1
fi
echo "swagger installed successfully"
else
echo "swagger already installed"
fi
mkdir bin output
export GOBIN=$(pwd)/bin
# 2. Install latest goctl version
go install ../../..
if [ $? -ne 0 ]; then
echo "Failed to install goctl"
exit 1
fi
echo "goctl installed successfully"
# 3. Generate swagger files
echo "Generating swagger files..."
./bin/goctl api swagger --api example.api --dir output
if [ $? -ne 0 ]; then
echo "Failed to generate swagger files"
exit 1
fi
# 4. Start swagger server
echo "Starting swagger server..."
swagger serve ./output/example.json

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 MiB

View File

@@ -0,0 +1,91 @@
#!/bin/bash
# 检查Docker是否运行的函数
is_docker_running() {
if ! docker info >/dev/null 2>&1; then
return 1 # Docker未运行
else
return 0 # Docker正在运行
fi
}
mkdir bin output
export GOBIN=$(pwd)/bin
# 1. 检查并安装Docker如果不存在
if ! command -v docker &> /dev/null; then
echo "未检测到Docker正在尝试安装..."
# 使用官方脚本安装Docker
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
rm get-docker.sh
# 验证安装
if ! command -v docker &> /dev/null; then
echo "Docker安装失败"
exit 1
fi
# 将当前用户加入docker组可能需要重新登录
sudo usermod -aG docker $USER
echo "Docker安装成功。您可能需要注销并重新登录使更改生效。"
else
echo "Docker已安装"
fi
# 2. 安装最新版goctl
go install ../../..
if [ $? -ne 0 ]; then
echo "goctl安装失败"
exit 1
fi
echo "goctl 安装成功"
# 3. 生成swagger文件
echo "正在生成swagger文件..."
./bin/goctl api swagger --api example_cn.api --dir output
if [ $? -ne 0 ]; then
echo "swagger文件生成失败"
exit 1
fi
# 检查Docker是否运行
if ! is_docker_running; then
echo "Docker未运行请先启动Docker服务"
exit 1
fi
# 4. 清理现有的swagger-ui容器
echo "正在清理现有的swagger-ui容器..."
docker rm -f swagger-ui 2>/dev/null && echo "已移除现有的swagger-ui容器"
# 5. 在Docker中运行swagger-ui
echo "正在启动swagger-ui容器..."
docker run -d --name swagger-ui -p 8080:8080 \
-e SWAGGER_JSON=/tmp/example.json \
-v $(pwd)/output/example_cn.json:/tmp/example.json \
swaggerapi/swagger-ui
if [ $? -ne 0 ]; then
echo "swagger-ui容器启动失败"
exit 1
fi
# 等待1秒确保服务就绪
echo "等待swagger-ui初始化..."
sleep 1
# 显示访问信息并尝试打开浏览器
SWAGGER_URL="http://localhost:8080"
echo -e "\nSwagger UI 已准备就绪,访问地址: \033[1;34m${SWAGGER_URL}\033[0m"
echo "正在尝试在默认浏览器中打开..."
# 跨平台打开浏览器
case "$(uname -s)" in
Linux*) xdg-open "$SWAGGER_URL";;
Darwin*) open "$SWAGGER_URL";;
CYGWIN*|MINGW*|MSYS*) start "$SWAGGER_URL";;
*) echo "无法在当前操作系统自动打开浏览器";;
esac

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.4 MiB

View File

@@ -0,0 +1,80 @@
#!/bin/bash
is_docker_running() {
if ! docker info >/dev/null 2>&1; then
return 1 # Docker is not running
else
return 0 # Docker is running
fi
}
mkdir bin output
export GOBIN=$(pwd)/bin
# 1. Check and install Docker if not exists
if ! command -v docker &> /dev/null; then
echo "Docker not found, attempting to install..."
# Install Docker using official installation script
curl -fsSL https://get.docker.com -o get-docker.sh
sudo sh get-docker.sh
rm get-docker.sh
# Verify installation
if ! command -v docker &> /dev/null; then
echo "Failed to install Docker"
exit 1
fi
# Add current user to docker group (may require logout/login)
sudo usermod -aG docker $USER
echo "Docker installed successfully. You may need to logout and login again for changes to take effect."
else
echo "Docker already installed"
fi
# 2. Install latest goctl version
go install ../../..
if [ $? -ne 0 ]; then
echo "Failed to install goctl"
exit 1
fi
echo "goctl installed successfully"
# 3. Generate swagger files
echo "Generating swagger files..."
./bin/goctl api swagger --api example.api --dir output
if [ $? -ne 0 ]; then
echo "Failed to generate swagger files"
exit 1
fi
if ! is_docker_running; then
echo "Docker is not running, Pls start Docker first"
fi
# 4. Clean up any existing swagger-ui container
echo "Cleaning up existing swagger-ui containers..."
docker rm -f swagger-ui 2>/dev/null && echo "Removed existing swagger-ui container"
# 5. Run swagger-ui in Docker
echo "Starting swagger-ui in Docker..."
docker run -d --name swagger-ui -p 8080:8080 -e SWAGGER_JSON=/tmp/example.json -v $(pwd)/output/example.json:/tmp/example.json swaggerapi/swagger-ui
if [ $? -ne 0 ]; then
echo "Failed to start swagger-ui container"
exit 1
fi
echo "Waiting for swagger-ui to initialize..."
sleep 1
SWAGGER_URL="http://localhost:8080"
echo -e "\nSwagger UI is ready at: \033[1;34m${SWAGGER_URL}\033[0m"
echo "Opening in default browser..."
case "$(uname -s)" in
Linux*) xdg-open "$SWAGGER_URL";;
Darwin*) open "$SWAGGER_URL";;
CYGWIN*|MINGW*|MSYS*) start "$SWAGGER_URL";;
*) echo "System not supported";;
esac

View File

@@ -0,0 +1,123 @@
package swagger
import (
"strconv"
"strings"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/util"
)
func rangeValueFromOptions(options []string) (minimum *float64, maximum *float64, exclusiveMinimum bool, exclusiveMaximum bool) {
if len(options) == 0 {
return nil, nil, false, false
}
for _, option := range options {
if strings.HasPrefix(option, rangeFlag) {
val := option[6:]
start, end := val[0], val[len(val)-1]
if start != '[' && start != '(' {
return nil, nil, false, false
}
if end != ']' && end != ')' {
return nil, nil, false, false
}
exclusiveMinimum = start == '('
exclusiveMaximum = end == ')'
content := val[1 : len(val)-1]
idxColon := strings.Index(content, ":")
if idxColon < 0 {
return nil, nil, false, false
}
var (
minStr, maxStr string
minVal, maxVal *float64
)
minStr = util.TrimWhiteSpace(content[:idxColon])
if len(val) >= idxColon+1 {
maxStr = util.TrimWhiteSpace(content[idxColon+1:])
}
if len(minStr) > 0 {
min, err := strconv.ParseFloat(minStr, 64)
if err != nil {
return nil, nil, false, false
}
minVal = &min
}
if len(maxStr) > 0 {
max, err := strconv.ParseFloat(maxStr, 64)
if err != nil {
return nil, nil, false, false
}
maxVal = &max
}
return minVal, maxVal, exclusiveMinimum, exclusiveMaximum
}
}
return nil, nil, false, false
}
func enumsValueFromOptions(options []string) []any {
if len(options) == 0 {
return []any{}
}
for _, option := range options {
if strings.HasPrefix(option, enumFlag) {
var resp = make([]any, 0)
val := option[8:]
fields := util.FieldsAndTrimSpace(val, func(r rune) bool {
return r == '|'
})
for _, field := range fields {
resp = append(resp, field)
}
return resp
}
}
return []any{}
}
func defValueFromOptions(options []string, apiType spec.Type) any {
tp := sampleTypeFromGoType(apiType)
return valueFromOptions(options, defFlag, tp)
}
func exampleValueFromOptions(options []string, apiType spec.Type) any {
tp := sampleTypeFromGoType(apiType)
val := valueFromOptions(options, exampleFlag, tp)
if val != nil {
return val
}
return defValueFromOptions(options, apiType)
}
func valueFromOptions(options []string, key string, tp string) any {
if len(options) == 0 {
return nil
}
for _, option := range options {
if strings.HasPrefix(option, key) {
s := option[len(key):]
switch tp {
case "integer":
val, _ := strconv.ParseInt(s, 10, 64)
return val
case "boolean":
val, _ := strconv.ParseBool(s)
return val
case "number":
val, _ := strconv.ParseFloat(s, 64)
return val
case "string":
return s
default:
return nil
}
}
}
return nil
}

View File

@@ -0,0 +1,258 @@
package swagger
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func TestRangeValueFromOptions(t *testing.T) {
tests := []struct {
name string
options []string
expectedMin *float64
expectedMax *float64
expectedExclMin bool
expectedExclMax bool
}{
{
name: "Valid range with inclusive bounds",
options: []string{"range=[1.0:10.0]"},
expectedMin: floatPtr(1.0),
expectedMax: floatPtr(10.0),
expectedExclMin: false,
expectedExclMax: false,
},
{
name: "Valid range with exclusive bounds",
options: []string{"range=(1.0:10.0)"},
expectedMin: floatPtr(1.0),
expectedMax: floatPtr(10.0),
expectedExclMin: true,
expectedExclMax: true,
},
{
name: "Invalid range format",
options: []string{"range=1.0:10.0"},
expectedMin: nil,
expectedMax: nil,
expectedExclMin: false,
expectedExclMax: false,
},
{
name: "Invalid range start",
options: []string{"range=[a:1.0)"},
expectedMin: nil,
expectedMax: nil,
expectedExclMin: false,
expectedExclMax: false,
},
{
name: "Missing range end",
options: []string{"range=[1.0:)"},
expectedMin: floatPtr(1.0),
expectedMax: nil,
expectedExclMin: false,
expectedExclMax: true,
},
{
name: "Missing range start and end",
options: []string{"range=[:)"},
expectedMin: nil,
expectedMax: nil,
expectedExclMin: false,
expectedExclMax: true,
},
{
name: "Missing range start",
options: []string{"range=[:1.0)"},
expectedMin: nil,
expectedMax: floatPtr(1.0),
expectedExclMin: false,
expectedExclMax: true,
},
{
name: "Invalid range end",
options: []string{"range=[1.0:b)"},
expectedMin: nil,
expectedMax: nil,
expectedExclMin: false,
expectedExclMax: false,
},
{
name: "Empty options",
options: []string{},
expectedMin: nil,
expectedMax: nil,
expectedExclMin: false,
expectedExclMax: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
min, max, exclMin, exclMax := rangeValueFromOptions(tt.options)
assert.Equal(t, tt.expectedMin, min)
assert.Equal(t, tt.expectedMax, max)
assert.Equal(t, tt.expectedExclMin, exclMin)
assert.Equal(t, tt.expectedExclMax, exclMax)
})
}
}
func TestEnumsValueFromOptions(t *testing.T) {
tests := []struct {
name string
options []string
expected []any
}{
{
name: "Valid enums",
options: []string{"options=a|b|c"},
expected: []any{"a", "b", "c"},
},
{
name: "Empty enums",
options: []string{"options="},
expected: []any{},
},
{
name: "No enum option",
options: []string{},
expected: []any{},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := enumsValueFromOptions(tt.options)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefValueFromOptions(t *testing.T) {
tests := []struct {
name string
options []string
apiType spec.Type
expected any
}{
{
name: "Default integer value",
options: []string{"default=42"},
apiType: spec.PrimitiveType{RawName: "int"},
expected: int64(42),
},
{
name: "Default string value",
options: []string{"default=hello"},
apiType: spec.PrimitiveType{RawName: "string"},
expected: "hello",
},
{
name: "No default value",
options: []string{},
apiType: spec.PrimitiveType{RawName: "string"},
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := defValueFromOptions(tt.options, tt.apiType)
assert.Equal(t, tt.expected, result)
})
}
}
func TestExampleValueFromOptions(t *testing.T) {
tests := []struct {
name string
options []string
apiType spec.Type
expected any
}{
{
name: "Example value present",
options: []string{"example=3.14"},
apiType: spec.PrimitiveType{RawName: "float"},
expected: 3.14,
},
{
name: "Fallback to default value",
options: []string{"default=42"},
apiType: spec.PrimitiveType{RawName: "int"},
expected: int64(42),
},
{
name: "Fallback to default value",
options: []string{"default="},
apiType: spec.PrimitiveType{RawName: "int"},
expected: int64(0),
},
{
name: "No example or default value",
options: []string{},
apiType: spec.PrimitiveType{RawName: "string"},
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
exampleValueFromOptions(tt.options, tt.apiType)
})
}
}
func TestValueFromOptions(t *testing.T) {
tests := []struct {
name string
options []string
key string
tp string
expected any
}{
{
name: "Integer value",
options: []string{"default=42"},
key: "default=",
tp: "integer",
expected: int64(42),
},
{
name: "Boolean value",
options: []string{"default=true"},
key: "default=",
tp: "boolean",
expected: true,
},
{
name: "Number value",
options: []string{"default=1.1"},
key: "default=",
tp: "number",
expected: 1.1,
},
{
name: "No matching key",
options: []string{"example=42"},
key: "default=",
tp: "integer",
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := valueFromOptions(tt.options, tt.key, tt.tp)
assert.Equal(t, tt.expected, result)
})
}
}
func floatPtr(f float64) *float64 {
return &f
}

View File

@@ -0,0 +1,184 @@
package swagger
import (
"net/http"
"strings"
"github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func parametersFromType(method string, tp apiSpec.Type) []spec.Parameter {
if tp == nil {
return []spec.Parameter{}
}
structType, ok := tp.(apiSpec.DefineStruct)
if !ok {
return []spec.Parameter{}
}
var (
resp []spec.Parameter
properties = map[string]spec.Schema{}
requiredFields []string
)
rangeMemberAndDo(structType, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
headerTag, _ := tag.Get(tagHeader)
hasHeader := headerTag != nil
pathParameterTag, _ := tag.Get(tagPath)
hasPathParameter := pathParameterTag != nil
formTag, _ := tag.Get(tagForm)
hasForm := formTag != nil
jsonTag, _ := tag.Get(tagJson)
hasJson := jsonTag != nil
if hasHeader {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(headerTag.Options)
resp = append(resp, spec.Parameter{
CommonValidations: spec.CommonValidations{
Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum,
Enum: enumsValueFromOptions(headerTag.Options),
},
SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type),
Default: defValueFromOptions(headerTag.Options, member.Type),
Example: exampleValueFromOptions(headerTag.Options, member.Type),
Items: sampleItemsFromGoType(member.Type),
},
ParamProps: spec.ParamProps{
In: paramsInHeader,
Name: headerTag.Name,
Description: formatComment(member.Comment),
Required: required,
},
})
}
if hasPathParameter {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(pathParameterTag.Options)
resp = append(resp, spec.Parameter{
CommonValidations: spec.CommonValidations{
Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum,
Enum: enumsValueFromOptions(pathParameterTag.Options),
},
SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type),
Default: defValueFromOptions(pathParameterTag.Options, member.Type),
Example: exampleValueFromOptions(pathParameterTag.Options, member.Type),
Items: sampleItemsFromGoType(member.Type),
},
ParamProps: spec.ParamProps{
In: paramsInPath,
Name: pathParameterTag.Name,
Description: formatComment(member.Comment),
Required: required,
},
})
}
if hasForm {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(formTag.Options)
if strings.EqualFold(method, http.MethodGet) {
resp = append(resp, spec.Parameter{
CommonValidations: spec.CommonValidations{
Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum,
Enum: enumsValueFromOptions(formTag.Options),
},
SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type),
Default: defValueFromOptions(formTag.Options, member.Type),
Example: exampleValueFromOptions(formTag.Options, member.Type),
Items: sampleItemsFromGoType(member.Type),
},
ParamProps: spec.ParamProps{
In: paramsInQuery,
Name: formTag.Name,
Description: formatComment(member.Comment),
Required: required,
AllowEmptyValue: !required,
},
})
} else {
resp = append(resp, spec.Parameter{
CommonValidations: spec.CommonValidations{
Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum,
Enum: enumsValueFromOptions(formTag.Options),
},
SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(member.Type),
Default: defValueFromOptions(formTag.Options, member.Type),
Example: exampleValueFromOptions(formTag.Options, member.Type),
Items: sampleItemsFromGoType(member.Type),
},
ParamProps: spec.ParamProps{
In: paramsInForm,
Name: formTag.Name,
Description: formatComment(member.Comment),
Required: required,
AllowEmptyValue: !required,
},
})
}
}
if hasJson {
minimum, maximum, exclusiveMinimum, exclusiveMaximum := rangeValueFromOptions(jsonTag.Options)
if required {
requiredFields = append(requiredFields, jsonTag.Name)
}
var schema = spec.Schema{
SwaggerSchemaProps: spec.SwaggerSchemaProps{
Example: exampleValueFromOptions(jsonTag.Options, member.Type),
},
SchemaProps: spec.SchemaProps{
Description: formatComment(member.Comment),
Type: typeFromGoType(member.Type),
Default: defValueFromOptions(jsonTag.Options, member.Type),
Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum,
Enum: enumsValueFromOptions(jsonTag.Options),
AdditionalProperties: mapFromGoType(member.Type),
},
}
switch sampleTypeFromGoType(member.Type) {
case swaggerTypeArray:
schema.Items = itemsFromGoType(member.Type)
case swaggerTypeObject:
p, r := propertiesFromType(member.Type)
schema.Properties = p
schema.Required = r
}
properties[jsonTag.Name] = schema
}
})
if len(properties) > 0 {
resp = append(resp, spec.Parameter{
ParamProps: spec.ParamProps{
In: paramsInBody,
Name: paramsInBody,
Required: true,
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Type: typeFromGoType(structType),
Properties: properties,
Required: requiredFields,
},
},
},
})
}
return resp
}

View File

@@ -0,0 +1,115 @@
package swagger
import (
"net/http"
"path"
"strings"
"github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func spec2Paths(info apiSpec.Info, srv apiSpec.Service) *spec.Paths {
paths := &spec.Paths{
Paths: make(map[string]spec.PathItem),
}
for _, group := range srv.Groups {
prefix := path.Clean(strings.TrimPrefix(group.GetAnnotation("prefix"), "/"))
for _, route := range group.Routes {
routPath := pathVariable2SwaggerVariable(route.Path)
if len(prefix) > 0 && prefix != "." {
routPath = "/" + path.Clean(prefix) + routPath
}
pathItem := spec2Path(info, group, route)
existPathItem, ok := paths.Paths[routPath]
if !ok {
paths.Paths[routPath] = pathItem
} else {
paths.Paths[routPath] = mergePathItem(existPathItem, pathItem)
}
}
}
return paths
}
func mergePathItem(old, new spec.PathItem) spec.PathItem {
if new.Get != nil {
old.Get = new.Get
}
if new.Put != nil {
old.Put = new.Put
}
if new.Post != nil {
old.Post = new.Post
}
if new.Delete != nil {
old.Delete = new.Delete
}
if new.Options != nil {
old.Options = new.Options
}
if new.Head != nil {
old.Head = new.Head
}
if new.Patch != nil {
old.Patch = new.Patch
}
if new.Parameters != nil {
old.Parameters = new.Parameters
}
return old
}
func spec2Path(info apiSpec.Info, group apiSpec.Group, route apiSpec.Route) spec.PathItem {
authType := getStringFromKVOrDefault(group.Annotation.Properties, "authType", "")
var security []map[string][]string
if len(authType) > 0 {
security = []map[string][]string{
{
authType: []string{},
},
}
}
op := &spec.Operation{
OperationProps: spec.OperationProps{
Description: getStringFromKVOrDefault(route.AtDoc.Properties, "description", ""),
Consumes: consumesFromTypeOrDef(route.Method, route.RequestType),
Produces: getListFromInfoOrDefault(route.AtDoc.Properties, "produces", []string{applicationJson}),
Schemes: getListFromInfoOrDefault(route.AtDoc.Properties, "schemes", []string{schemeHttps}),
Tags: getListFromInfoOrDefault(group.Annotation.Properties, "tags", []string{""}),
Summary: getStringFromKVOrDefault(route.AtDoc.Properties, "summary", ""),
Deprecated: getBoolFromKVOrDefault(route.AtDoc.Properties, "deprecated", false),
Parameters: parametersFromType(route.Method, route.RequestType),
Responses: jsonResponseFromType(info, route.ResponseType),
Security: security,
},
}
externalDocsDescription := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsDescription", "")
externalDocsURL := getStringFromKVOrDefault(route.AtDoc.Properties, "externalDocsURL", "")
if len(externalDocsDescription) > 0 || len(externalDocsURL) > 0 {
op.ExternalDocs = &spec.ExternalDocumentation{
Description: externalDocsDescription,
URL: externalDocsURL,
}
}
item := spec.PathItem{}
switch strings.ToUpper(route.Method) {
case http.MethodGet:
item.Get = op
case http.MethodHead:
item.Head = op
case http.MethodPost:
item.Post = op
case http.MethodPut:
item.Put = op
case http.MethodPatch:
item.Patch = op
case http.MethodDelete:
item.Delete = op
case http.MethodOptions:
item.Options = op
default: // [http.MethodConnect,http.MethodTrace] not supported
}
return item
}

View File

@@ -0,0 +1,69 @@
package swagger
import (
"github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func propertiesFromType(tp apiSpec.Type) (spec.SchemaProperties, []string) {
var (
properties = map[string]spec.Schema{}
requiredFields []string
)
switch val := tp.(type) {
case apiSpec.PointerType:
return propertiesFromType(val.Type)
case apiSpec.ArrayType:
return propertiesFromType(val.Value)
case apiSpec.DefineStruct, apiSpec.NestedStruct:
rangeMemberAndDo(val, func(tag *apiSpec.Tags, required bool, member apiSpec.Member) {
var (
jsonTagString = member.Name
minimum, maximum *float64
exclusiveMinimum, exclusiveMaximum bool
example, defaultValue any
enum []any
)
jsonTag, _ := tag.Get(tagJson)
if jsonTag != nil {
jsonTagString = jsonTag.Name
minimum, maximum, exclusiveMinimum, exclusiveMaximum = rangeValueFromOptions(jsonTag.Options)
example = exampleValueFromOptions(jsonTag.Options, member.Type)
defaultValue = defValueFromOptions(jsonTag.Options, member.Type)
enum = enumsValueFromOptions(jsonTag.Options)
}
if required {
requiredFields = append(requiredFields, jsonTagString)
}
var schema = spec.Schema{
SwaggerSchemaProps: spec.SwaggerSchemaProps{
Example: example,
},
SchemaProps: spec.SchemaProps{
Description: formatComment(member.Comment),
Type: typeFromGoType(member.Type),
Default: defaultValue,
Maximum: maximum,
ExclusiveMaximum: exclusiveMaximum,
Minimum: minimum,
ExclusiveMinimum: exclusiveMinimum,
Enum: enum,
AdditionalProperties: mapFromGoType(member.Type),
},
}
switch sampleTypeFromGoType(member.Type) {
case swaggerTypeArray:
schema.Items = itemsFromGoType(member.Type)
case swaggerTypeObject:
p, r := propertiesFromType(member.Type)
schema.Properties = p
schema.Required = r
}
properties[jsonTagString] = schema
})
}
return properties, requiredFields
}

View File

@@ -0,0 +1,28 @@
package swagger
import (
"github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
)
func jsonResponseFromType(info apiSpec.Info, tp apiSpec.Type) *spec.Responses {
p, _ := propertiesFromType(tp)
props := spec.SchemaProps{
Type: typeFromGoType(tp),
Properties: p,
AdditionalProperties: mapFromGoType(tp),
Items: itemsFromGoType(tp),
}
return &spec.Responses{
ResponsesProps: spec.ResponsesProps{
Default: &spec.Response{
ResponseProps: spec.ResponseProps{
Schema: &spec.Schema{
SchemaProps: wrapCodeMsgProps(props, info),
},
},
},
},
}
}

View File

@@ -0,0 +1,321 @@
package swagger
import (
"encoding/json"
"strings"
"time"
"github.com/go-openapi/spec"
apiSpec "github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/util"
)
func spec2Swagger(api *apiSpec.ApiSpec) (*spec.Swagger, error) {
extensions, info := specExtensions(api.Info)
var securityDefinitions spec.SecurityDefinitions
securityDefinitionsFromJson := getStringFromKVOrDefault(api.Info.Properties, "securityDefinitionsFromJson", `{}`)
_ = json.Unmarshal([]byte(securityDefinitionsFromJson), &securityDefinitions)
swagger := &spec.Swagger{
VendorExtensible: spec.VendorExtensible{
Extensions: extensions,
},
SwaggerProps: spec.SwaggerProps{
Consumes: getListFromInfoOrDefault(api.Info.Properties, "consumes", []string{applicationJson}),
Produces: getListFromInfoOrDefault(api.Info.Properties, "produces", []string{applicationJson}),
Schemes: getListFromInfoOrDefault(api.Info.Properties, "schemes", []string{schemeHttps}),
Swagger: swaggerVersion,
Info: info,
Host: getStringFromKVOrDefault(api.Info.Properties, "host", defaultHost),
BasePath: getStringFromKVOrDefault(api.Info.Properties, "basePath", defaultBasePath),
Paths: spec2Paths(api.Info, api.Service),
SecurityDefinitions: securityDefinitions,
},
}
return swagger, nil
}
func formatComment(comment string) string {
s := strings.TrimPrefix(comment, "//")
return strings.TrimSpace(s)
}
func sampleItemsFromGoType(tp apiSpec.Type) *spec.Items {
val, ok := tp.(apiSpec.ArrayType)
if !ok {
return nil
}
item := val.Value
switch item.(type) {
case apiSpec.PrimitiveType:
return &spec.Items{
SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(item),
},
}
case apiSpec.ArrayType:
return &spec.Items{
SimpleSchema: spec.SimpleSchema{
Type: sampleTypeFromGoType(item),
Items: sampleItemsFromGoType(item),
},
}
default: // unsupported type
}
return nil
}
// itemsFromGoType returns the schema or array of the type, just for non json body parameters.
func itemsFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
array, ok := tp.(apiSpec.ArrayType)
if !ok {
return nil
}
return itemFromGoType(array.Value)
}
func mapFromGoType(tp apiSpec.Type) *spec.SchemaOrBool {
mapType, ok := tp.(apiSpec.MapType)
if !ok {
return nil
}
var schema = &spec.Schema{
SchemaProps: spec.SchemaProps{
Type: typeFromGoType(mapType.Value),
AdditionalProperties: mapFromGoType(mapType.Value),
},
}
switch sampleTypeFromGoType(mapType.Value) {
case swaggerTypeArray:
schema.Items = itemsFromGoType(mapType.Value)
case swaggerTypeObject:
p, r := propertiesFromType(mapType.Value)
schema.Properties = p
schema.Required = r
}
return &spec.SchemaOrBool{
Allows: true,
Schema: schema,
}
}
// itemFromGoType returns the schema or array of the type, just for non json body parameters.
func itemFromGoType(tp apiSpec.Type) *spec.SchemaOrArray {
switch itemType := tp.(type) {
case apiSpec.PrimitiveType:
return &spec.SchemaOrArray{
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Type: typeFromGoType(tp),
},
},
}
case apiSpec.DefineStruct, apiSpec.NestedStruct:
properties, requiredFields := propertiesFromType(itemType)
return &spec.SchemaOrArray{
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Type: typeFromGoType(itemType),
Items: itemsFromGoType(itemType),
Properties: properties,
Required: requiredFields,
AdditionalProperties: mapFromGoType(itemType),
},
},
}
case apiSpec.PointerType:
return itemFromGoType(itemType.Type)
case apiSpec.ArrayType:
return &spec.SchemaOrArray{
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Type: typeFromGoType(itemType),
Items: itemsFromGoType(itemType),
},
},
}
}
return nil
}
func typeFromGoType(tp apiSpec.Type) []string {
switch val := tp.(type) {
case apiSpec.PrimitiveType:
res, ok := tpMapper[val.RawName]
if ok {
return []string{res}
}
case apiSpec.ArrayType:
return []string{swaggerTypeArray}
case apiSpec.DefineStruct, apiSpec.MapType:
return []string{swaggerTypeObject}
case apiSpec.PointerType:
return typeFromGoType(val.Type)
}
return nil
}
func sampleTypeFromGoType(tp apiSpec.Type) string {
switch val := tp.(type) {
case apiSpec.PrimitiveType:
return tpMapper[val.RawName]
case apiSpec.ArrayType:
return swaggerTypeArray
case apiSpec.DefineStruct, apiSpec.MapType, apiSpec.NestedStruct:
return swaggerTypeObject
case apiSpec.PointerType:
return sampleTypeFromGoType(val.Type)
default:
return ""
}
}
func typeContainsTag(structType apiSpec.DefineStruct, tag string) bool {
for _, field := range structType.Members {
tags, _ := apiSpec.Parse(field.Tag)
for _, t := range tags.Tags() {
if t.Key == tag {
return true
}
}
}
return false
}
func expandMembers(tp apiSpec.Type) []apiSpec.Member {
var members []apiSpec.Member
switch val := tp.(type) {
case apiSpec.DefineStruct:
for _, v := range val.Members {
if v.IsInline {
members = append(members, expandMembers(v.Type)...)
continue
}
members = append(members, v)
}
case apiSpec.NestedStruct:
for _, v := range val.Members {
if v.IsInline {
members = append(members, expandMembers(v.Type)...)
continue
}
members = append(members, v)
}
}
return members
}
func rangeMemberAndDo(structType apiSpec.Type, do func(tag *apiSpec.Tags, required bool, member apiSpec.Member)) {
var members = expandMembers(structType)
for _, field := range members {
tags, _ := apiSpec.Parse(field.Tag)
required := isRequired(tags)
do(tags, required, field)
}
}
func isRequired(tags *apiSpec.Tags) bool {
tag, err := tags.Get(tagJson)
if err == nil {
return !isOptional(tag.Options)
}
tag, err = tags.Get(tagForm)
if err == nil {
return !isOptional(tag.Options)
}
tag, err = tags.Get(tagPath)
if err == nil {
return !isOptional(tag.Options)
}
return false
}
func isOptional(options []string) bool {
for _, option := range options {
if option == "optional" {
return true
}
}
return false
}
func pathVariable2SwaggerVariable(path string) string {
pathItems := strings.FieldsFunc(path, slashRune)
var resp []string
for _, v := range pathItems {
if strings.HasPrefix(v, ":") {
resp = append(resp, "{"+v[1:]+"}")
} else {
resp = append(resp, v)
}
}
return "/" + strings.Join(resp, "/")
}
func wrapCodeMsgProps(properties spec.SchemaProps, api apiSpec.Info) spec.SchemaProps {
wrapCodeMsg := getBoolFromKVOrDefault(api.Properties, "wrapCodeMsg", false)
if !wrapCodeMsg {
return properties
}
return spec.SchemaProps{
Type: []string{swaggerTypeObject},
Properties: spec.SchemaProperties{
"code": {
SwaggerSchemaProps: spec.SwaggerSchemaProps{
Example: 0,
},
SchemaProps: spec.SchemaProps{
Type: []string{swaggerTypeInteger},
Description: getStringFromKVOrDefault(api.Properties, "bizCodeEnumDescription", "business code"),
},
},
"msg": {
SwaggerSchemaProps: spec.SwaggerSchemaProps{
Example: "ok",
},
SchemaProps: spec.SchemaProps{
Type: []string{swaggerTypeString},
Description: "business message",
},
},
"data": {
SchemaProps: properties,
},
},
}
}
func specExtensions(api apiSpec.Info) (spec.Extensions, *spec.Info) {
ext := spec.Extensions{}
ext.Add("x-goctl-version", version.BuildVersion)
ext.Add("x-description", "This is a goctl generated swagger file.")
ext.Add("x-date", time.Now().Format("2006-01-02 15:04:05"))
ext.Add("x-github", "https://github.com/zeromicro/go-zero")
ext.Add("x-go-zero-doc", "https://go-zero.dev/")
info := &spec.Info{}
info.Description = util.Unquote(api.Properties["description"])
info.Title = util.Unquote(api.Properties["title"])
info.TermsOfService = util.Unquote(api.Properties["termsOfService"])
info.Version = util.Unquote(api.Properties["version"])
contactInfo := spec.ContactInfo{}
contactInfo.Name = util.Unquote(api.Properties["contactName"])
contactInfo.URL = util.Unquote(api.Properties["contactURL"])
contactInfo.Email = util.Unquote(api.Properties["contactEmail"])
if len(contactInfo.Name) > 0 || len(contactInfo.URL) > 0 || len(contactInfo.Email) > 0 {
info.Contact = &contactInfo
}
license := &spec.License{}
license.Name = util.Unquote(api.Properties["licenseName"])
license.URL = util.Unquote(api.Properties["licenseURL"])
if len(license.Name) > 0 || len(license.URL) > 0 {
info.License = license
}
return ext, info
}

View File

@@ -0,0 +1,25 @@
package swagger
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_pathVariable2SwaggerVariable(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{input: "/api/:id", expected: "/api/{id}"},
{input: "/api/:id/details", expected: "/api/{id}/details"},
{input: "/:version/api/:id", expected: "/{version}/api/{id}"},
{input: "/api/v1", expected: "/api/v1"},
{input: "/api/:id/:action", expected: "/api/{id}/{action}"},
}
for _, tc := range testCases {
result := pathVariable2SwaggerVariable(tc.input)
assert.Equal(t, tc.expected, result)
}
}

View File

@@ -0,0 +1,27 @@
package swagger
var (
tpMapper = map[string]string{
"uint8": swaggerTypeInteger,
"uint16": swaggerTypeInteger,
"uint32": swaggerTypeInteger,
"uint64": swaggerTypeInteger,
"int8": swaggerTypeInteger,
"int16": swaggerTypeInteger,
"int32": swaggerTypeInteger,
"int64": swaggerTypeInteger,
"int": swaggerTypeInteger,
"uint": swaggerTypeInteger,
"byte": swaggerTypeInteger,
"float32": swaggerTypeNumber,
"float64": swaggerTypeNumber,
"string": swaggerTypeString,
"bool": swaggerTypeBoolean,
}
commaRune = func(r rune) bool {
return r == ','
}
slashRune = func(r rune) bool {
return r == '/'
}
)

View File

@@ -6,6 +6,7 @@ require (
github.com/DATA-DOG/go-sqlmock v1.5.2
github.com/emicklei/proto v1.14.0
github.com/fatih/structtag v1.2.0
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e
github.com/go-sql-driver/mysql v1.9.0
github.com/gookit/color v1.5.4
github.com/iancoleman/strcase v0.3.0
@@ -15,7 +16,7 @@ require (
github.com/withfig/autocomplete-tools/integrations/cobra v1.2.1
github.com/zeromicro/antlr v0.0.1
github.com/zeromicro/ddl-parser v1.0.5
github.com/zeromicro/go-zero v1.8.1
github.com/zeromicro/go-zero v1.8.2
golang.org/x/text v0.22.0
google.golang.org/grpc v1.65.0
google.golang.org/protobuf v1.36.5
@@ -38,9 +39,9 @@ require (
github.com/fatih/color v1.18.0 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/swag v0.22.4 // indirect
github.com/go-openapi/jsonpointer v0.21.1 // indirect
github.com/go-openapi/jsonreference v0.21.0 // indirect
github.com/go-openapi/swag v0.23.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/mock v1.6.0 // indirect
github.com/golang/protobuf v1.5.4 // indirect
@@ -52,13 +53,13 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/pgx/v5 v5.7.2 // indirect
github.com/jackc/pgx/v5 v5.7.4 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/logrusorgru/aurora v2.0.3+incompatible // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mailru/easyjson v0.9.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -67,11 +68,11 @@ require (
github.com/openzipkin/zipkin-go v0.4.3 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.21.0 // indirect
github.com/prometheus/client_golang v1.21.1 // indirect
github.com/prometheus/client_model v0.6.1 // indirect
github.com/prometheus/common v0.62.0 // indirect
github.com/prometheus/procfs v0.15.1 // indirect
github.com/redis/go-redis/v9 v9.7.1 // indirect
github.com/redis/go-redis/v9 v9.7.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
github.com/yuin/gopher-lua v1.1.1 // indirect

View File

@@ -25,7 +25,6 @@ github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03V
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -45,13 +44,14 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE=
github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs=
github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE=
github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k=
github.com/go-openapi/swag v0.22.3/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU=
github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14=
github.com/go-openapi/jsonpointer v0.21.1 h1:whnzv/pNXtK2FbX/W9yJfRmE2gsmkfahjMKB0fZvcic=
github.com/go-openapi/jsonpointer v0.21.1/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk=
github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ=
github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4=
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e h1:auobAirzhPsLHMso0NVMqK0QunuLDYCK83KnaVUM/RU=
github.com/go-openapi/spec v0.21.1-0.20250328170532-a3928469592e/go.mod h1:NAKTe9SplQBxIUlHlsuId1jk1I7bWTVV/2q/GtdRi6g=
github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU=
github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0=
github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo=
github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
@@ -89,8 +89,8 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.2 h1:mLoDLV6sonKlvjIEsV56SkWNCnuNv531l94GaIzO+XI=
github.com/jackc/pgx/v5 v5.7.2/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg=
github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
@@ -102,19 +102,16 @@ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+o
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8=
github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
@@ -141,18 +138,18 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g=
github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U=
github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA=
github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk=
github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg=
github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E=
github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY=
github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io=
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.7.1 h1:4LhKRCIduqXqtvCUlaq9c8bdHOkICjDMrr1+Zb3osAc=
github.com/redis/go-redis/v9 v9.7.1/go.mod h1:f6zhXITC7JUJIlPEiBOTXxJgPLdZcA93GewI7inzyWw=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
@@ -169,7 +166,6 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
@@ -187,8 +183,8 @@ github.com/zeromicro/antlr v0.0.1 h1:CQpIn/dc0pUjgGQ81y98s/NGOm2Hfru2NNio2I9mQgk
github.com/zeromicro/antlr v0.0.1/go.mod h1:nfpjEwFR6Q4xGDJMcZnCL9tEfQRgszMwu3rDz2Z+p5M=
github.com/zeromicro/ddl-parser v1.0.5 h1:LaVqHdzMTjasua1yYpIYaksxKqRzFrEukj2Wi2EbWaQ=
github.com/zeromicro/ddl-parser v1.0.5/go.mod h1:ISU/8NuPyEpl9pa17Py9TBPetMjtsiHrb9f5XGiYbo8=
github.com/zeromicro/go-zero v1.8.1 h1:iUYQEMQzS9Pb8ebzJtV3FGtv/YTjZxAh/NvLW/316wo=
github.com/zeromicro/go-zero v1.8.1/go.mod h1:gc54Ad4qt7OJ0PbKajnYsSKsZBYN4JLRIXKlqDX2A2I=
github.com/zeromicro/go-zero v1.8.2 h1:AbJckBoojbr1lqCN1dkvURTIHOau7yvKReEd7ZmjuCk=
github.com/zeromicro/go-zero v1.8.2/go.mod h1:G5dF+jzCEuq0t1j8qdrtVAy30QMgctGcKSfqFIGsvSg=
go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk=
go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM=
go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA=

View File

@@ -71,6 +71,13 @@
"api": "{{.goctl.api.api}}",
"caller": "The web api caller",
"unwrap": "Unwrap the webapi caller for import"
},
"swagger": {
"short": "Generate swagger file from api",
"dir": "{{.goctl.api.dir}}",
"api": "{{.goctl.api.api}}",
"filename": "The generated swagger file name without the extension",
"yaml": "Generate swagger yaml file, default to json"
}
},
"bug": {

View File

@@ -6,7 +6,7 @@ import (
)
// BuildVersion is the version of goctl.
const BuildVersion = "1.8.1"
const BuildVersion = "1.8.3-beta"
var tag = map[string]int{"pre-alpha": 0, "alpha": 1, "pre-bata": 2, "beta": 3, "released": 4, "": 5}

View File

@@ -1342,7 +1342,7 @@ func (p *Parser) parseKVExpression() *ast.KVExpr {
expr.Colon = p.curTokenNode()
// token STRING
if !p.advanceIfPeekTokenIs(token.STRING) {
if !p.advanceIfPeekTokenIs(token.STRING, token.RAW_STRING) {
return nil
}

View File

@@ -121,3 +121,34 @@ func IsEmptyStringOrWhiteSpace(s string) bool {
v := TrimWhiteSpace(s)
return len(v) == 0
}
func FieldsAndTrimSpace(s string, f func(r rune) bool) []string {
fields := strings.FieldsFunc(s, f)
var resp []string
for _, v := range fields {
val := TrimWhiteSpace(v)
if len(val) > 0 {
resp = append(resp, v)
}
}
return resp
}
func Unquote(s string) string {
if len(s) == 0 {
return s
}
left := s[0]
if left == '`' || left == '"' {
s = s[1:len(s)]
}
if len(s) == 0 {
return s
}
right := s[len(s)-1]
if right == '`' || right == '"' {
s = s[0 : len(s)-1]
}
return s
}

View File

@@ -3,6 +3,7 @@ package util
import (
"strings"
"testing"
"unicode"
"github.com/stretchr/testify/assert"
)
@@ -72,3 +73,67 @@ func TestEscapeGoKeyword(t *testing.T) {
assert.False(t, isGolangKeyword(strings.Title(k)))
}
}
func TestFieldsAndTrimSpace(t *testing.T) {
testCases := []struct {
name string
input string
delimiter func(r rune) bool
expected []string
}{
{
name: "Comma-separated values",
input: "a, b, c",
delimiter: func(r rune) bool { return r == ',' },
expected: []string{"a", " b", " c"},
},
{
name: "Space-separated values",
input: "a b c",
delimiter: unicode.IsSpace,
expected: []string{"a", "b", "c"},
},
{
name: "Mixed whitespace",
input: "a\tb\nc",
delimiter: unicode.IsSpace,
expected: []string{"a", "b", "c"},
},
{
name: "Empty input",
input: "",
delimiter: unicode.IsSpace,
expected: []string(nil),
},
{
name: "Trailing and leading spaces",
input: " a , b , c ",
delimiter: func(r rune) bool { return r == ',' },
expected: []string{" a ", " b ", " c "},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
result := FieldsAndTrimSpace(tc.input, tc.delimiter)
assert.Equal(t, tc.expected, result)
})
}
}
func TestUnquote(t *testing.T) {
testCases := []struct {
input string
expected string
}{
{input: `"hello"`, expected: `hello`},
{input: "`world`", expected: `world`},
{input: `"foo'bar"`, expected: `foo'bar`},
{input: "", expected: ""},
}
for _, tc := range testCases {
result := Unquote(tc.input)
assert.Equal(t, tc.expected, result)
}
}