feat: support file server in rest (#4244)

This commit is contained in:
Kevin Wan
2024-07-13 19:58:35 +08:00
committed by GitHub
parent e776b5d8ab
commit ec86f22cd6
8 changed files with 216 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
package fileserver
import (
"net/http"
"strings"
)
func Middleware(path, dir string) func(http.HandlerFunc) http.HandlerFunc {
fileServer := http.FileServer(http.Dir(dir))
pathWithTrailSlash := ensureTrailingSlash(path)
pathWithoutTrailSlash := ensureNoTrailingSlash(path)
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.HasPrefix(r.URL.Path, pathWithTrailSlash) {
r.URL.Path = strings.TrimPrefix(r.URL.Path, pathWithoutTrailSlash)
fileServer.ServeHTTP(w, r)
} else {
next(w, r)
}
}
}
}
func ensureTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path
}
return path + "/"
}
func ensureNoTrailingSlash(path string) string {
if strings.HasSuffix(path, "/") {
return path[:len(path)-1]
}
return path
}

View File

@@ -0,0 +1,99 @@
package fileserver
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMiddleware(t *testing.T) {
tests := []struct {
name string
path string
dir string
requestPath string
expectedStatus int
expectedContent string
}{
{
name: "Serve static file",
path: "/static/",
dir: "./testdata",
requestPath: "/static/example.txt",
expectedStatus: http.StatusOK,
expectedContent: "1",
},
{
name: "Pass through non-matching path",
path: "/static/",
dir: "./testdata",
requestPath: "/other/path",
expectedStatus: http.StatusNotFound,
},
{
name: "Directory with trailing slash",
path: "/assets",
dir: "testdata",
requestPath: "/assets/sample.txt",
expectedStatus: http.StatusOK,
expectedContent: "2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := Middleware(tt.path, tt.dir)
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
handlerToTest := middleware(nextHandler)
req := httptest.NewRequest("GET", tt.requestPath, nil)
rr := httptest.NewRecorder()
handlerToTest.ServeHTTP(rr, req)
assert.Equal(t, tt.expectedStatus, rr.Code)
if len(tt.expectedContent) > 0 {
assert.Equal(t, tt.expectedContent, rr.Body.String())
}
})
}
}
func TestEnsureTrailingSlash(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"path", "path/"},
{"path/", "path/"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := ensureTrailingSlash(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func TestEnsureNoTrailingSlash(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"path", "path"},
{"path/", "path"},
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := ensureNoTrailingSlash(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1 @@
2