feat: support embed file system to serve files in rest (#4253)

This commit is contained in:
Kevin Wan
2024-07-17 16:21:08 +08:00
committed by GitHub
parent a00c956776
commit 5dd6f2a43a
4 changed files with 33 additions and 10 deletions

View File

@@ -5,8 +5,9 @@ import (
"strings" "strings"
) )
func Middleware(path, dir string) func(http.HandlerFunc) http.HandlerFunc { // Middleware returns a middleware that serves files from the given file system.
fileServer := http.FileServer(http.Dir(dir)) func Middleware(path string, fs http.FileSystem) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(fs)
pathWithTrailSlash := ensureTrailingSlash(path) pathWithTrailSlash := ensureTrailingSlash(path)
pathWithoutTrailSlash := ensureNoTrailingSlash(path) pathWithoutTrailSlash := ensureNoTrailingSlash(path)

View File

@@ -44,7 +44,7 @@ func TestMiddleware(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, tt.dir) middleware := Middleware(tt.path, http.Dir(tt.dir))
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)
}) })

View File

@@ -172,9 +172,9 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt
} }
// WithFileServer returns a RunOption to serve files from given dir with given path. // WithFileServer returns a RunOption to serve files from given dir with given path.
func WithFileServer(path, dir string) RunOption { func WithFileServer(path string, fs http.FileSystem) RunOption {
return func(server *Server) { return func(server *Server) {
server.router = newFileServingRouter(server.router, path, dir) server.router = newFileServingRouter(server.router, path, fs)
} }
} }
@@ -351,10 +351,10 @@ type fileServingRouter struct {
middleware Middleware middleware Middleware
} }
func newFileServingRouter(router httpx.Router, path, dir string) httpx.Router { func newFileServingRouter(router httpx.Router, path string, fs http.FileSystem) httpx.Router {
return &fileServingRouter{ return &fileServingRouter{
Router: router, Router: router,
middleware: fileserver.Middleware(path, dir), middleware: fileserver.Middleware(path, fs),
} }
} }

View File

@@ -2,8 +2,10 @@ package rest
import ( import (
"crypto/tls" "crypto/tls"
"embed"
"fmt" "fmt"
"io" "io"
"io/fs"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -21,6 +23,11 @@ import (
"github.com/zeromicro/go-zero/rest/router" "github.com/zeromicro/go-zero/rest/router"
) )
const (
exampleContent = "example content"
sampleContent = "sample content"
)
func TestNewServer(t *testing.T) { func TestNewServer(t *testing.T) {
logtest.Discard(t) logtest.Discard(t)
@@ -199,7 +206,7 @@ func TestWithFileServerMiddleware(t *testing.T) {
dir: "./testdata", dir: "./testdata",
requestPath: "/assets/example.txt", requestPath: "/assets/example.txt",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedContent: "example content", expectedContent: exampleContent,
}, },
{ {
name: "Pass through non-matching path", name: "Pass through non-matching path",
@@ -214,13 +221,13 @@ func TestWithFileServerMiddleware(t *testing.T) {
dir: "testdata", dir: "testdata",
requestPath: "/static/sample.txt", requestPath: "/static/sample.txt",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedContent: "sample content", expectedContent: sampleContent,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
server := MustNewServer(RestConf{}, WithFileServer(tt.path, tt.dir)) server := MustNewServer(RestConf{}, WithFileServer(tt.path, http.Dir(tt.dir)))
req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil) req := httptest.NewRequest(http.MethodGet, tt.requestPath, nil)
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
@@ -688,3 +695,18 @@ Port: 54321
}) })
} }
} }
//go:embed testdata
var content embed.FS
func TestServerEmbedFileSystem(t *testing.T) {
filesys, err := fs.Sub(content, "testdata")
assert.NoError(t, err)
server := MustNewServer(RestConf{}, WithFileServer("/assets", http.FS(filesys)))
req, err := http.NewRequest(http.MethodGet, "/assets/sample.txt", http.NoBody)
assert.Nil(t, err)
rr := httptest.NewRecorder()
server.ServeHTTP(rr, req)
assert.Equal(t, sampleContent, rr.Body.String())
}