Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
c37121ff87 Add DestroyAll() method to syncx.Pool with comprehensive tests
Co-authored-by: kevwan <1918356+kevwan@users.noreply.github.com>
2025-09-26 14:21:07 +00:00
copilot-swe-agent[bot]
f08d9329a8 Initial plan 2025-09-26 14:15:49 +00:00
16 changed files with 192 additions and 551 deletions

View File

@@ -100,6 +100,34 @@ func (p *Pool) Put(x any) {
p.cond.Signal()
}
// DestroyAll destroys all resources in the pool.
// It calls the destroy function on each resource and resets the pool state.
// This is useful when you need to forcefully clean up all resources, for example:
// - When removing an obsolete pool
// - When refreshing all resources after configuration changes
// - When avoiding resource leaks in dynamic pool scenarios
func (p *Pool) DestroyAll() {
p.lock.Lock()
defer p.lock.Unlock()
// Iterate through the linked list and destroy all resources
current := p.head
for current != nil {
next := current.next
if p.destroy != nil {
p.destroy(current.item)
}
current = next
}
// Reset pool state
p.head = nil
p.created = 0
// Wake up all waiting goroutines since the pool is now empty
p.cond.Broadcast()
}
// WithMaxAge returns a function to customize a Pool with given max age.
func WithMaxAge(duration time.Duration) PoolOption {
return func(pool *Pool) {

View File

@@ -107,6 +107,155 @@ func TestNewPoolPanics(t *testing.T) {
})
}
func TestPoolDestroyAll(t *testing.T) {
var destroyed []int
var destroyCount int32
destroyFunc := func(item any) {
destroyed = append(destroyed, item.(int))
atomic.AddInt32(&destroyCount, 1)
}
pool := NewPool(limit, create, destroyFunc)
// Put some resources into the pool
pool.Put(10)
pool.Put(20)
pool.Put(30)
// Destroy all resources
pool.DestroyAll()
// Verify all resources were destroyed
assert.Equal(t, int32(3), atomic.LoadInt32(&destroyCount))
assert.Contains(t, destroyed, 10)
assert.Contains(t, destroyed, 20)
assert.Contains(t, destroyed, 30)
// Verify pool is empty - next Get should create new resource
val := pool.Get()
assert.Equal(t, 1, val) // create() returns 1
}
func TestPoolDestroyAllEmpty(t *testing.T) {
var destroyCount int32
destroyFunc := func(_ any) {
atomic.AddInt32(&destroyCount, 1)
}
pool := NewPool(limit, create, destroyFunc)
// DestroyAll on empty pool should not panic
pool.DestroyAll()
// No resources should have been destroyed
assert.Equal(t, int32(0), atomic.LoadInt32(&destroyCount))
// Pool should still work normally
val := pool.Get()
assert.Equal(t, 1, val)
}
func TestPoolDestroyAllWithNilDestroy(t *testing.T) {
pool := NewPool(limit, create, nil)
// Put some resources into the pool
pool.Put(10)
pool.Put(20)
// DestroyAll with nil destroy function should not panic
pool.DestroyAll()
// Pool should be empty and work normally
val := pool.Get()
assert.Equal(t, 1, val)
}
func TestPoolDestroyAllConcurrency(t *testing.T) {
var destroyCount int32
var createCount int32
createFunc := func() any {
return atomic.AddInt32(&createCount, 1)
}
destroyFunc := func(_ any) {
atomic.AddInt32(&destroyCount, 1)
}
pool := NewPool(limit, createFunc, destroyFunc)
// Add some initial resources
for i := 0; i < 5; i++ {
pool.Put(i + 100)
}
var wg sync.WaitGroup
const goroutines = 10
// Concurrently perform various operations
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
switch id % 4 {
case 0:
// DestroyAll
pool.DestroyAll()
case 1:
// Get resources
val := pool.Get()
pool.Put(val)
case 2:
// Put resources
pool.Put(id + 1000)
case 3:
// Get and don't put back
pool.Get()
}
}(i)
}
wg.Wait()
// Final DestroyAll to clean up
pool.DestroyAll()
// Pool should work after concurrent operations
val := pool.Get()
assert.NotNil(t, val)
}
func TestPoolDestroyAllWakesWaitingGoroutines(t *testing.T) {
pool := NewPool(1, create, destroy) // Small pool size
// Fill the pool
resource := pool.Get()
assert.Equal(t, 1, resource)
var wg sync.WaitGroup
var gotResource bool
// Start a goroutine that will wait for a resource
wg.Add(1)
go func() {
defer wg.Done()
val := pool.Get() // This will block since pool is full
gotResource = true
assert.Equal(t, 1, val) // Should get a newly created resource after DestroyAll
}()
// Give the goroutine time to start waiting
time.Sleep(10 * time.Millisecond)
// DestroyAll should wake up the waiting goroutine
pool.DestroyAll()
wg.Wait()
assert.True(t, gotResource)
}
func create() any {
return 1
}

2
go.mod
View File

@@ -16,7 +16,7 @@ require (
github.com/jhump/protoreflect v1.17.0
github.com/pelletier/go-toml/v2 v2.2.2
github.com/prometheus/client_golang v1.21.1
github.com/redis/go-redis/v9 v9.15.0
github.com/redis/go-redis/v9 v9.14.0
github.com/spaolacci/murmur3 v1.1.0
github.com/stretchr/testify v1.11.1
go.etcd.io/etcd/api/v3 v3.5.15

4
go.sum
View File

@@ -154,8 +154,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/redis/go-redis/v9 v9.15.0 h1:2jdes0xJxer4h3NUZrZ4OGSntGlXp4WbXju2nOTRXto=
github.com/redis/go-redis/v9 v9.15.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/redis/go-redis/v9 v9.14.0 h1:u4tNCjXOyzfgeLN+vAZaW1xUooqWDqVEsZN0U01jfAE=
github.com/redis/go-redis/v9 v9.14.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=

View File

@@ -389,9 +389,7 @@ func buildSSERoutes(routes []Route) []Route {
// because SSE requires the connection to be kept alive indefinitely.
rc := http.NewResponseController(w)
if err := rc.SetWriteDeadline(time.Time{}); err != nil {
// Some ResponseWriter implementations (like timeoutWriter) don't support SetWriteDeadline.
// This is expected behavior and doesn't affect SSE functionality.
logc.Debugf(r.Context(), "unable to clear write deadline for SSE connection: %v", err)
logc.Errorf(r.Context(), "set conn write deadline failed: %v", err)
}
w.Header().Set(header.ContentType, header.ContentTypeEventStream)

View File

@@ -24,16 +24,12 @@ import (
)
const (
limitBodyBytes = 1024
limitDetailedBodyBytes = 4096
defaultSlowThreshold = time.Millisecond * 500
defaultSSESlowThreshold = time.Minute * 3
limitBodyBytes = 1024
limitDetailedBodyBytes = 4096
defaultSlowThreshold = time.Millisecond * 500
)
var (
slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
sseSlowThreshold = syncx.ForAtomicDuration(defaultSSESlowThreshold)
)
var slowThreshold = syncx.ForAtomicDuration(defaultSlowThreshold)
// LogHandler returns a middleware that logs http request and response.
func LogHandler(next http.Handler) http.Handler {
@@ -113,11 +109,6 @@ func SetSlowThreshold(threshold time.Duration) {
slowThreshold.Set(threshold)
}
// SetSSESlowThreshold sets the slow threshold for SSE requests.
func SetSSESlowThreshold(threshold time.Duration) {
sseSlowThreshold.Set(threshold)
}
func dumpRequest(r *http.Request) string {
reqContent, err := httputil.DumpRequest(r, true)
if err != nil {
@@ -127,14 +118,6 @@ func dumpRequest(r *http.Request) string {
return string(reqContent)
}
func getSlowThreshold(r *http.Request) time.Duration {
if r.Header.Get(headerAccept) == valueSSE {
return sseSlowThreshold.Load()
} else {
return slowThreshold.Load()
}
}
func isOkResponse(code int) bool {
// not server error
return code < http.StatusInternalServerError
@@ -146,8 +129,7 @@ func logBrief(r *http.Request, code int, timer *utils.ElapsedTimer, logs *intern
logger := logx.WithContext(r.Context()).WithDuration(duration)
buf.WriteString(fmt.Sprintf("[HTTP] %s - %s %s - %s - %s",
wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent()))
if duration > getSlowThreshold(r) {
if duration > slowThreshold.Load() {
logger.Slowf("[HTTP] %s - %s %s - %s - %s - slowcall(%s)",
wrapStatusCode(code), wrapMethod(r.Method), r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent(),
timex.ReprOfDuration(duration))
@@ -178,8 +160,7 @@ func logDetails(r *http.Request, response *detailLoggedResponseWriter, timer *ut
logger := logx.WithContext(r.Context())
buf.WriteString(fmt.Sprintf("[HTTP] %s - %d - %s - %s\n=> %s\n",
r.Method, code, r.RemoteAddr, timex.ReprOfDuration(duration), dumpRequest(r)))
if duration > getSlowThreshold(r) {
if duration > slowThreshold.Load() {
logger.Slowf("[HTTP] %s - %d - %s - slowcall(%s)\n=> %s\n", r.Method, code, r.RemoteAddr,
timex.ReprOfDuration(duration), dumpRequest(r))
}

View File

@@ -88,96 +88,6 @@ func TestLogHandlerSlow(t *testing.T) {
}
}
func TestLogHandlerSSE(t *testing.T) {
handlers := []func(handler http.Handler) http.Handler{
LogHandler,
DetailedLogHandler,
}
for _, logHandler := range handlers {
t.Run("SSE request with normal duration", func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
req.Header.Set(headerAccept, valueSSE)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(defaultSlowThreshold + time.Second)
w.WriteHeader(http.StatusOK)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
})
t.Run("SSE request exceeding SSE threshold", func(t *testing.T) {
originalThreshold := sseSlowThreshold.Load()
SetSSESlowThreshold(time.Millisecond * 100)
defer SetSSESlowThreshold(originalThreshold)
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
req.Header.Set(headerAccept, valueSSE)
handler := logHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 150)
w.WriteHeader(http.StatusOK)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
})
}
}
func TestLogHandlerThresholdSelection(t *testing.T) {
tests := []struct {
name string
acceptHeader string
expectedIsSSE bool
}{
{
name: "Regular HTTP request",
acceptHeader: "text/html",
expectedIsSSE: false,
},
{
name: "SSE request",
acceptHeader: valueSSE,
expectedIsSSE: true,
},
{
name: "No Accept header",
acceptHeader: "",
expectedIsSSE: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "http://localhost", http.NoBody)
if tt.acceptHeader != "" {
req.Header.Set(headerAccept, tt.acceptHeader)
}
SetSlowThreshold(time.Millisecond * 100)
SetSSESlowThreshold(time.Millisecond * 200)
defer func() {
SetSlowThreshold(defaultSlowThreshold)
SetSSESlowThreshold(defaultSSESlowThreshold)
}()
handler := LogHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 150)
w.WriteHeader(http.StatusOK)
}))
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
})
}
}
func TestDetailedLogHandler_LargeBody(t *testing.T) {
lbuf := logtest.NewCollector(t)
@@ -229,12 +139,6 @@ func TestSetSlowThreshold(t *testing.T) {
assert.Equal(t, time.Second, slowThreshold.Load())
}
func TestSetSSESlowThreshold(t *testing.T) {
assert.Equal(t, defaultSSESlowThreshold, sseSlowThreshold.Load())
SetSSESlowThreshold(time.Minute * 10)
assert.Equal(t, time.Minute*10, sseSlowThreshold.Load())
}
func TestWrapMethodWithColor(t *testing.T) {
// no tty
assert.Equal(t, http.MethodGet, wrapMethod(http.MethodGet))

View File

@@ -118,8 +118,6 @@ func DoGenProjectWithModule(apiFile, dir, moduleName, style string, withTest boo
if withTest {
logx.Must(genHandlersTest(dir, rootPkg, projectPkg, cfg, api))
logx.Must(genLogicTest(dir, rootPkg, projectPkg, cfg, api))
logx.Must(genServiceContextTest(dir, rootPkg, projectPkg, cfg, api))
logx.Must(genIntegrationTest(dir, rootPkg, projectPkg, cfg, api))
}
if err := backupAndSweep(apiFile); err != nil {

View File

@@ -1,42 +0,0 @@
package gogen
import (
_ "embed"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/util/format"
)
//go:embed integration_test.tpl
var integrationTestTemplate string
func genIntegrationTest(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
serviceName := api.Service.Name
if len(serviceName) == 0 {
serviceName = "server"
}
filename, err := format.FileNamingFormat(cfg.NamingFormat, serviceName)
if err != nil {
return err
}
return genFile(fileGenConfig{
dir: dir,
subdir: "",
filename: filename + "_test.go",
templateName: "integrationTestTemplate",
category: category,
templateFile: integrationTestTemplateFile,
builtinTemplate: integrationTestTemplate,
data: map[string]any{
"projectPkg": projectPkg,
"serviceName": serviceName,
"version": version.BuildVersion,
"hasRoutes": len(api.Service.Routes()) > 0,
"routes": api.Service.Routes(),
},
})
}

View File

@@ -1,34 +0,0 @@
package gogen
import (
_ "embed"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/zeromicro/go-zero/tools/goctl/config"
"github.com/zeromicro/go-zero/tools/goctl/internal/version"
"github.com/zeromicro/go-zero/tools/goctl/util/format"
)
//go:embed svc_test.tpl
var svcTestTemplate string
func genServiceContextTest(dir, rootPkg, projectPkg string, cfg *config.Config, api *spec.ApiSpec) error {
filename, err := format.FileNamingFormat(cfg.NamingFormat, contextFilename)
if err != nil {
return err
}
return genFile(fileGenConfig{
dir: dir,
subdir: contextDir,
filename: filename + "_test.go",
templateName: "svcTestTemplate",
category: category,
templateFile: svcTestTemplateFile,
builtinTemplate: svcTestTemplate,
data: map[string]any{
"projectPkg": projectPkg,
"version": version.BuildVersion,
},
})
}

View File

@@ -1,120 +0,0 @@
// Code scaffolded by goctl. Safe to edit.
// goctl {{.version}}
package main
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"{{.projectPkg}}/internal/config"
"{{.projectPkg}}/internal/handler"
"{{.projectPkg}}/internal/svc"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zeromicro/go-zero/rest"
)
func TestMain(m *testing.M) {
// TODO: Add setup/teardown logic here if needed
m.Run()
}
func TestServerIntegration(t *testing.T) {
// Create test server
c := config.Config{
RestConf: rest.RestConf{
Host: "127.0.0.1",
Port: 0, // Use random available port
},
}
server := rest.MustNewServer(c.RestConf)
defer server.Stop()
ctx := svc.NewServiceContext(c)
handler.RegisterHandlers(server, ctx)
// Start server in background
go func() {
server.Start()
}()
// Wait for server to start
time.Sleep(100 * time.Millisecond)
tests := []struct {
name string
method string
path string
body string
expectedStatus int
setup func()
}{
{
name: "health check",
method: "GET",
path: "/health",
expectedStatus: http.StatusNotFound, // Adjust based on actual routes
setup: func() {},
},
{{if .hasRoutes}}{{range .routes}}{
name: "{{.Method}} {{.Path}}",
method: "{{.Method}}",
path: "{{.Path}}",
expectedStatus: http.StatusOK, // TODO: Adjust expected status
setup: func() {
// TODO: Add setup logic for this endpoint
},
},
{{end}}{{end}}{
name: "not found route",
method: "GET",
path: "/nonexistent",
expectedStatus: http.StatusNotFound,
setup: func() {},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
req, err := http.NewRequest(tt.method, tt.path, nil)
require.NoError(t, err)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
// TODO: Add response body assertions
t.Logf("Response: %s", rr.Body.String())
})
}
}
func TestServerLifecycle(t *testing.T) {
c := config.Config{
RestConf: rest.RestConf{
Host: "127.0.0.1",
Port: 0,
},
}
server := rest.MustNewServer(c.RestConf)
// Test server can start and stop without errors
ctx := svc.NewServiceContext(c)
handler.RegisterHandlers(server, ctx)
// In a real integration test, you might start the server in a goroutine
// and test actual HTTP requests, but for scaffolding we keep it simple
server.Stop()
// TODO: Add more lifecycle tests as needed
assert.True(t, true, "Server lifecycle test passed")
}

View File

@@ -1,17 +0,0 @@
type Request {
Name string `path:"name,options=you|me"`
}
type Response {
Message string `json:"message"`
}
@server(
jwt: Auth
jwtTransition: Trans
middleware: TokenValidate
)
service A-api {
@handler GreetHandler
get /greet/from/:name(Request) returns (Response)
}

View File

@@ -1,60 +0,0 @@
// Code scaffolded by goctl. Safe to edit.
// goctl {{.version}}
package svc
import (
"testing"
"{{.projectPkg}}/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewServiceContext(t *testing.T) {
tests := []struct {
name string
config config.Config
setup func() config.Config
}{
{
name: "default config",
setup: func() config.Config {
return config.Config{}
},
},
{
name: "valid config",
setup: func() config.Config {
return config.Config{
// TODO: Add valid config values here
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := tt.setup()
svcCtx := NewServiceContext(c)
// Basic assertions
require.NotNil(t, svcCtx)
assert.Equal(t, c, svcCtx.Config)
// TODO: Add additional assertions for middleware and dependencies
})
}
}
func TestServiceContext_Initialization(t *testing.T) {
c := config.Config{}
svcCtx := NewServiceContext(c)
// Verify service context is properly initialized
assert.NotNil(t, svcCtx)
assert.Equal(t, c, svcCtx.Config)
// TODO: Add tests for middleware initialization if any
// TODO: Add tests for external dependencies if any
}

View File

@@ -22,8 +22,6 @@ const (
routesTemplateFile = "routes.tpl"
routesAdditionTemplateFile = "route-addition.tpl"
typesTemplateFile = "types.tpl"
svcTestTemplateFile = "svc_test.tpl"
integrationTestTemplateFile = "integration_test.tpl"
)
var templates = map[string]string{
@@ -41,8 +39,6 @@ var templates = map[string]string{
routesTemplateFile: routesTemplate,
routesAdditionTemplateFile: routesAdditionTemplate,
typesTemplateFile: typesTemplate,
svcTestTemplateFile: svcTestTemplate,
integrationTestTemplateFile: integrationTestTemplate,
}
// Category returns the category of the api files.

View File

@@ -70,40 +70,15 @@ func propertiesFromType(ctx Context, tp apiSpec.Type) (spec.SchemaProperties, []
switch sampleTypeFromGoType(ctx, member.Type) {
case swaggerTypeArray:
schema.Items = itemsFromGoType(ctx, member.Type)
// Special handling for arrays with useDefinitions
if ctx.UseDefinitions {
// For arrays, check if the array element (not the array itself) contains a struct
if arrayType, ok := member.Type.(apiSpec.ArrayType); ok {
if structName, containsStruct := containsStruct(arrayType.Value); containsStruct {
// Set the $ref inside the items, not at the schema level
schema.Items = &spec.SchemaOrArray{
Schema: &spec.Schema{
SchemaProps: spec.SchemaProps{
Ref: spec.MustCreateRef(getRefName(structName)),
},
},
}
}
}
}
case swaggerTypeObject:
p, r := propertiesFromType(ctx, member.Type)
schema.Properties = p
schema.Required = r
// For objects with useDefinitions, set $ref at schema level
if ctx.UseDefinitions {
structName, containsStruct := containsStruct(member.Type)
if containsStruct {
schema.SchemaProps.Ref = spec.MustCreateRef(getRefName(structName))
}
}
default:
// For non-array, non-object types, apply useDefinitions logic
if ctx.UseDefinitions {
structName, containsStruct := containsStruct(member.Type)
if containsStruct {
schema.SchemaProps.Ref = spec.MustCreateRef(getRefName(structName))
}
}
if ctx.UseDefinitions {
structName, containsStruct := containsStruct(member.Type)
if containsStruct {
schema.SchemaProps.Ref = spec.MustCreateRef(getRefName(structName))
}
}

View File

@@ -3,7 +3,6 @@ package swagger
import (
"testing"
"github.com/zeromicro/go-zero/tools/goctl/api/spec"
"github.com/stretchr/testify/assert"
)
@@ -24,117 +23,3 @@ func Test_pathVariable2SwaggerVariable(t *testing.T) {
assert.Equal(t, tc.expected, result)
}
}
func TestArrayDefinitionsBug(t *testing.T) {
// Test case for the bug where array of structs with useDefinitions
// generates incorrect swagger JSON structure
// Context with useDefinitions enabled
ctx := Context{
UseDefinitions: true,
}
// Create a test struct containing an array of structs
testStruct := spec.DefineStruct{
RawName: "TestStruct",
Members: []spec.Member{
{
Name: "ArrayField",
Type: spec.ArrayType{
Value: spec.DefineStruct{
RawName: "ItemStruct",
Members: []spec.Member{
{
Name: "ItemName",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `json:"itemName"`,
},
},
},
},
Tag: `json:"arrayField"`,
},
},
}
// Get properties from the struct
properties, _ := propertiesFromType(ctx, testStruct)
// Check that we have the array field
assert.Contains(t, properties, "arrayField")
arrayField := properties["arrayField"]
// Verify the array field has correct structure
assert.Equal(t, "array", arrayField.Type[0])
// Check that we have items
assert.NotNil(t, arrayField.Items, "Array should have items defined")
assert.NotNil(t, arrayField.Items.Schema, "Array items should have schema")
// The FIX: $ref should be inside items, not at schema level
hasRef := arrayField.Ref.String() != ""
assert.False(t, hasRef, "Schema level should NOT have $ref")
// The $ref should be in the items
hasItemsRef := arrayField.Items.Schema.Ref.String() != ""
assert.True(t, hasItemsRef, "Items should have $ref")
assert.Equal(t, "#/definitions/ItemStruct", arrayField.Items.Schema.Ref.String())
// Verify there are no other properties in the items when using $ref
assert.Nil(t, arrayField.Items.Schema.Properties, "Items with $ref should not have properties")
assert.Empty(t, arrayField.Items.Schema.Required, "Items with $ref should not have required")
assert.Empty(t, arrayField.Items.Schema.Type, "Items with $ref should not have type")
}
func TestArrayWithoutDefinitions(t *testing.T) {
// Test that arrays work correctly when useDefinitions is false
ctx := Context{
UseDefinitions: false, // This is the default
}
// Create the same test struct
testStruct := spec.DefineStruct{
RawName: "TestStruct",
Members: []spec.Member{
{
Name: "ArrayField",
Type: spec.ArrayType{
Value: spec.DefineStruct{
RawName: "ItemStruct",
Members: []spec.Member{
{
Name: "ItemName",
Type: spec.PrimitiveType{RawName: "string"},
Tag: `json:"itemName"`,
},
},
},
},
Tag: `json:"arrayField"`,
},
},
}
properties, _ := propertiesFromType(ctx, testStruct)
assert.Contains(t, properties, "arrayField")
arrayField := properties["arrayField"]
// Should be array type
assert.Equal(t, "array", arrayField.Type[0])
// Should have items with full schema, no $ref
assert.NotNil(t, arrayField.Items)
assert.NotNil(t, arrayField.Items.Schema)
// Should NOT have $ref at schema level
assert.Empty(t, arrayField.Ref.String(), "Schema should not have $ref when useDefinitions is false")
// Should NOT have $ref in items either
assert.Empty(t, arrayField.Items.Schema.Ref.String(), "Items should not have $ref when useDefinitions is false")
// Should have full schema properties in items
assert.Equal(t, "object", arrayField.Items.Schema.Type[0])
assert.Contains(t, arrayField.Items.Schema.Properties, "itemName")
assert.Equal(t, []string{"itemName"}, arrayField.Items.Schema.Required)
}