commit 8265df0dcd007712dc2495447c5aaa6002f96755 Author: zhilv Date: Thu Jan 8 23:32:29 2026 +0800 feat: 初始提交 - Code Server Bridge完整实现 - OAuth认证系统(Gitea + Lua扩展) - Git自动化操作(本地/SSH远程) - 实时进度WebSocket推送 - 现代化Tab界面UI - Cobra CLI命令行(init/version/serve) - 完整构建系统(Makefile + Taskfile) - UPX压缩支持(体积减少70%) diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ff3984c --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +# 配置文件 +config.yaml + +# 数据目录 +data/ + +# 构建产物 +bin/ +releases/ +*.exe +code-server-bridge +code-server-bridge-* + +# 测试覆盖率 +coverage.out +coverage.html + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# 操作系统 +.DS_Store +Thumbs.db + +# 临时文件 +*.log +tmp/ \ No newline at end of file diff --git a/BUILD.md b/BUILD.md new file mode 100644 index 0000000..1197cd8 --- /dev/null +++ b/BUILD.md @@ -0,0 +1,68 @@ +# Code Server Bridge + +## 构建项目 + +项目提供两种构建方式: + +### 使用 Make + +```bash +# 查看所有可用命令 +make help + +# 构建当前平台 +make build + +# 构建所有平台 +make build-all + +# 打包发布版本 +make package + +# 开发模式运行 +make dev + +# 清理构建产物 +make clean +``` + +### 使用 Task + +安装Task: https://taskfile.dev/installation/ + +```bash +# 查看所有任务 +task --list + +# 构建当前平台 +task build + +# 构建所有平台 +task build:all + +# 打包发布版本 +task package + +# 开发模式运行 +task dev + +# 清理构建产物 +task clean + +# 运行测试 +task test + +# 代码格式化 +task fmt +``` + +## 构建产物 + +- `code-server-bridge` - 当前平台可执行文件 +- `bin/` - 多平台构建输出 +- `releases/` - 打包的发布文件 + +## 支持平台 + +- Linux: amd64, arm64 +- Windows: amd64 diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..8294c5e --- /dev/null +++ b/Makefile @@ -0,0 +1,118 @@ +.PHONY: all build build-linux build-windows clean dev help package upx + +APP_NAME = code-server-bridge +ENTRY = ./cmd/server +OUTPUT_DIR = bin +VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +BUILD_TIME = $(shell date -u '+%Y-%m-%d_%H:%M:%S') +GIT_COMMIT = $(shell git rev-parse --short HEAD 2>/dev/null || echo "unknown") + +# Go build flags with version injection +LDFLAGS = -ldflags "-s -w \ + -X 'main.Version=$(VERSION)' \ + -X 'main.BuildTime=$(BUILD_TIME)' \ + -X 'main.GitCommit=$(GIT_COMMIT)'" + +# UPX settings +UPX = upx +UPX_FLAGS = --best --lzma + +# Platforms +LINUX_ARCHS = amd64 arm64 +WINDOWS_ARCHS = amd64 + +all: clean build-all upx + +# Build for current platform +build: + @echo "Building $(APP_NAME) for current platform..." + go build $(LDFLAGS) -o $(APP_NAME) $(ENTRY) + @echo "Build complete: $(APP_NAME)" + +# Build for Linux platforms +build-linux: + @echo "Building for Linux platforms..." + @mkdir -p $(OUTPUT_DIR) + @for arch in $(LINUX_ARCHS); do \ + echo "Building linux/$$arch..."; \ + GOOS=linux GOARCH=$$arch go build $(LDFLAGS) -o $(OUTPUT_DIR)/$(APP_NAME)-linux-$$arch $(ENTRY); \ + done + @echo "Linux builds complete" + +# Build for Windows platforms +build-windows: + @echo "Building for Windows platforms..." + @mkdir -p $(OUTPUT_DIR) + @for arch in $(WINDOWS_ARCHS); do \ + echo "Building windows/$$arch..."; \ + GOOS=windows GOARCH=$$arch go build $(LDFLAGS) -o $(OUTPUT_DIR)/$(APP_NAME)-windows-$$arch.exe $(ENTRY); \ + done + @echo "Windows builds complete" + +# Build for all platforms +build-all: build-linux build-windows + @echo "All platform builds complete" + +# Package release +package: build-all upx + @echo "Packaging releases..." + @mkdir -p releases + @for arch in $(LINUX_ARCHS); do \ + tar -czf releases/$(APP_NAME)-$(VERSION)-linux-$$arch.tar.gz -C $(OUTPUT_DIR) $(APP_NAME)-linux-$$arch; \ + done + @for arch in $(WINDOWS_ARCHS); do \ + cd $(OUTPUT_DIR) && zip -q ../releases/$(APP_NAME)-$(VERSION)-windows-$$arch.zip $(APP_NAME)-windows-$$arch.exe && cd ..; \ + done + @echo "Packaging complete: releases/" + +# Development mode +dev: + @echo "Running in development mode..." + go run $(ENTRY) + +# Clean build artifacts +clean: + @echo "Cleaning build artifacts..." + @rm -rf $(OUTPUT_DIR) releases $(APP_NAME) $(APP_NAME).exe + @echo "Clean complete" + +# Run tests +test: + @echo "Running tests..." + go test -v ./... + +# Install dependencies +deps: + @echo "Installing dependencies..." + go mod download + go mod tidy + +# UPX Compression +upx: + @echo "Compressing binaries with UPX..." + @if command -v $(UPX) > /dev/null 2>&1; then \ + for file in $(OUTPUT_DIR)/*; do \ + if [ -f "$$file" ]; then \ + echo "Compressing $$file..."; \ + $(UPX) $(UPX_FLAGS) "$$file" 2>/dev/null || echo "⚠️ UPX压缩失败: $$file"; \ + fi; \ + done; \ + echo "✅ UPX compression complete"; \ + else \ + echo "⚠️ UPX not found, skipping compression"; \ + echo "Install UPX: https://upx.github.io/"; \ + fi + +# Display help +help: + @echo "Available targets:" + @echo " make build - Build for current platform" + @echo " make build-linux - Build for Linux (amd64, arm64)" + @echo " make build-windows - Build for Windows (amd64)" + @echo " make build-all - Build for all platforms" + @echo " make upx - Compress binaries with UPX" + @echo " make package - Build, compress and package releases" + @echo " make dev - Run in development mode" + @echo " make test - Run tests" + @echo " make deps - Install dependencies" + @echo " make clean - Clean build artifacts" diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..7d636f9 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,102 @@ +version: '3' + +vars: + APP_NAME: code-server-bridge + ENTRY: ./cmd/server + OUTPUT_DIR: bin + VERSION: + sh: git describe --tags --always --dirty 2>/dev/null || echo "dev" + BUILD_TIME: + sh: date -u '+%Y-%m-%d_%H:%M:%S' + GIT_COMMIT: + sh: git rev-parse --short HEAD 2>/dev/null || echo "unknown" + +tasks: + default: + desc: Build for current platform + cmds: + - task: build + + build: + desc: Build for current platform + vars: + LDFLAGS: '-s -w -X "main.Version={{.VERSION}}" -X "main.BuildTime={{.BUILD_TIME}}" -X "main.GitCommit={{.GIT_COMMIT}}"' + cmds: + - echo "Building {{.APP_NAME}} v{{.VERSION}}..." + - go build -ldflags "{{.LDFLAGS}}" -o {{.APP_NAME}} {{.ENTRY}} + - echo "Build complete - {{.APP_NAME}}" + + build-linux: + desc: Build for Linux (amd64 arm64) + vars: + LDFLAGS: '-s -w -X "main.Version={{.VERSION}}" -X "main.BuildTime={{.BUILD_TIME}}" -X "main.GitCommit={{.GIT_COMMIT}}"' + cmds: + - mkdir -p {{.OUTPUT_DIR}} + - echo "Building Linux amd64..." + - cmd: GOOS=linux GOARCH=amd64 go build -ldflags "{{.LDFLAGS}}" -o {{.OUTPUT_DIR}}/{{.APP_NAME}}-linux-amd64 {{.ENTRY}} + - echo "Building Linux arm64..." + - cmd: GOOS=linux GOARCH=arm64 go build -ldflags "{{.LDFLAGS}}" -o {{.OUTPUT_DIR}}/{{.APP_NAME}}-linux-arm64 {{.ENTRY}} + - echo "Linux builds complete" + + build-windows: + desc: Build for Windows (amd64) + vars: + LDFLAGS: '-s -w -X "main.Version={{.VERSION}}" -X "main.BuildTime={{.BUILD_TIME}}" -X "main.GitCommit={{.GIT_COMMIT}}"' + cmds: + - mkdir -p {{.OUTPUT_DIR}} + - echo "Building Windows amd64..." + - cmd: GOOS=windows GOARCH=amd64 go build -ldflags "{{.LDFLAGS}}" -o {{.OUTPUT_DIR}}/{{.APP_NAME}}-windows-amd64.exe {{.ENTRY}} + - echo "Windows builds complete" + + build-all: + desc: Build for all platforms + deps: + - build-linux + - build-windows + cmds: + - echo "All platform builds complete" + + upx: + desc: Compress binaries with UPX + cmds: + - | + if command -v upx > /dev/null 2>&1; then + for file in {{.OUTPUT_DIR}}/*; do + if [ -f "$file" ]; then + echo "Compressing $file..." + upx --best --lzma "$file" 2>/dev/null || echo "UPX compression failed: $file" + fi + done + echo "UPX compression complete" + else + echo "UPX not found, skipping compression" + echo "Install UPX: https://upx.github.io/" + fi + + package: + desc: Build, compress and package releases + deps: + - build-all + - upx + cmds: + - mkdir -p releases + - tar -czf releases/{{.APP_NAME}}-{{.VERSION}}-linux-amd64.tar.gz -C {{.OUTPUT_DIR}} {{.APP_NAME}}-linux-amd64 + - tar -czf releases/{{.APP_NAME}}-{{.VERSION}}-linux-arm64.tar.gz -C {{.OUTPUT_DIR}} {{.APP_NAME}}-linux-arm64 + - cd {{.OUTPUT_DIR}} && zip -q ../releases/{{.APP_NAME}}-{{.VERSION}}-windows-amd64.zip {{.APP_NAME}}-windows-amd64.exe + - echo "Packaging complete - releases/" + + clean: + desc: Clean build artifacts + cmds: + - rm -rf {{.OUTPUT_DIR}} releases {{.APP_NAME}} {{.APP_NAME}}.exe + - echo "Clean complete" + + dev: + desc: Run in development mode + cmds: + - go run {{.ENTRY}} + + test: + desc: Run tests + cmds: + - go test -v ./... diff --git a/cmd/server/cmd/init.go b/cmd/server/cmd/init.go new file mode 100644 index 0000000..2720ffc --- /dev/null +++ b/cmd/server/cmd/init.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "cs-bridge/internal/config" + "fmt" + "os" + "path/filepath" + + "github.com/go-yaml/yaml" + "github.com/spf13/cobra" +) + +var initCmd = &cobra.Command{ + Use: "init", + Short: "初始化配置文件", + Long: `生成默认的config.yaml配置文件和scripts目录`, + Run: func(cmd *cobra.Command, args []string) { + configPath, _ := cmd.Flags().GetString("output") + force, _ := cmd.Flags().GetBool("force") + + // 检查文件是否存在 + if _, err := os.Stat(configPath); err == nil && !force { + fmt.Printf("❌ 配置文件已存在: %s\n", configPath) + fmt.Println("使用 --force 强制覆盖") + return + } + + // 创建默认配置文件 + cfg := config.DefaultConfig() + defaultConfig, err := yaml.Marshal(cfg) + if err != nil { + fmt.Printf("❌ 生成配置失败: %v\n", err) + return + } + + // 写入配置文件 + if err := os.WriteFile(configPath, defaultConfig, 0644); err != nil { + fmt.Printf("❌ 创建配置文件失败: %v\n", err) + return + } + fmt.Printf("✅ 配置文件已创建: %s\n", configPath) + + // 创建scripts目录和示例Lua脚本 + scriptsDir := "scripts/oauth" + if err := os.MkdirAll(scriptsDir, 0755); err != nil { + fmt.Printf("⚠️ 创建scripts目录失败: %v\n", err) + } else { + // 创建示例Lua脚本 + luaScript := `-- Gitea OAuth Provider +-- 此脚本定义了Gitea OAuth认证流程 + +function auth_url(cfg, state) + local params = { + "client_id=" .. util.url_encode(cfg.client_id), + "redirect_uri=" .. util.url_encode(cfg.redirect_uri), + "response_type=code", + "scope=read:user", + "state=" .. util.url_encode(state) + } + + return cfg.base_url .. "/login/oauth/authorize?" .. table.concat(params, "&") +end + +function exchange(cfg, code) + local body = { + client_id = cfg.client_id, + client_secret = cfg.client_secret, + code = code, + grant_type = "authorization_code", + redirect_uri = cfg.redirect_uri + } + + return util.http_post_json(cfg.base_url .. "/login/oauth/access_token", body) +end + +function user_info(cfg, token) + local headers = { + Authorization = "Bearer " .. token + } + + return util.http_get_json(cfg.base_url .. "/api/v1/user", headers) +end +` + luaPath := filepath.Join(scriptsDir, "gitea.lua") + if err := os.WriteFile(luaPath, []byte(luaScript), 0644); err != nil { + fmt.Printf("⚠️ 创建Lua脚本失败: %v\n", err) + } else { + fmt.Printf("✅ 示例Lua脚本已创建: %s\n", luaPath) + } + } + + fmt.Println("\n📝 下一步:") + fmt.Printf("1. 编辑配置文件: %s\n", configPath) + fmt.Println("2. 配置OAuth客户端信息") + fmt.Println("3. 运行服务: code-server-bridge") + }, +} + +func init() { + initCmd.Flags().StringP("output", "o", "config.yaml", "配置文件输出路径") + initCmd.Flags().BoolP("force", "f", false, "强制覆盖已存在的文件") + rootCmd.AddCommand(initCmd) +} diff --git a/cmd/server/cmd/root.go b/cmd/server/cmd/root.go new file mode 100644 index 0000000..4895a96 --- /dev/null +++ b/cmd/server/cmd/root.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +) + +var ( + // 版本信息,构建时注入 + Version = "dev" + BuildTime = "unknown" + GitCommit = "unknown" +) + +var rootCmd = &cobra.Command{ + Use: "code-server-bridge", + Short: "Code Server Bridge - OAuth认证和Git仓库准备工具", + Long: `Code Server Bridge 是一个为code-server提供OAuth认证和Git仓库自动准备的桥接服务。 + +功能特性: + - OAuth认证 (支持Gitea等) + - Git仓库自动克隆/更新 + - 实时进度显示 + - Workspace自动管理 + +使用 'code-server-bridge serve' 启动服务器 +使用 'code-server-bridge init' 初始化配置文件`, +} + +func Execute() { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func init() { + // 全局flags + rootCmd.PersistentFlags().StringP("config", "c", "config.yaml", "配置文件路径") +} diff --git a/cmd/server/cmd/serve.go b/cmd/server/cmd/serve.go new file mode 100644 index 0000000..d9dbe02 --- /dev/null +++ b/cmd/server/cmd/serve.go @@ -0,0 +1,80 @@ +package cmd + +import ( + "cs-bridge/internal/config" + "cs-bridge/internal/db" + "cs-bridge/internal/http" + "cs-bridge/internal/oauth" + "cs-bridge/pkg/httpclient" + "cs-bridge/pkg/logger" + "fmt" + "os" + "strings" + + "github.com/spf13/cobra" +) + +var serverCmd = &cobra.Command{ + Use: "serve", + Short: "启动服务器", + Long: `启动Code Server Bridge服务器`, + Run: func(cmd *cobra.Command, args []string) { + configPath, _ := rootCmd.PersistentFlags().GetString("config") + + // 加载配置 + os.Setenv("CONFIG_PATH", configPath) + cfg := config.NewConfig() + + logger.Init(cfg.Log) + db.InitDB(cfg.Database, logger.GetLogger()) + + httpclient.Init() + + mgr := oauth.NewManager() + for _, p := range cfg.OAuth.Providers { + redirect := cfg.OAuth.BaseURL + "/oauth/" + p.Name + "/callback" + + switch strings.ToLower(p.Type) { + case "gitea": + // 优先尝试使用Lua Provider + scriptPath := fmt.Sprintf("scripts/oauth/%s.lua", p.Name) + if _, err := os.Stat(scriptPath); err == nil { + // Lua脚本存在,使用Lua Provider + logger.GetLogger().Info(fmt.Sprintf("使用Lua脚本加载OAuth Provider: %s (%s)", p.Name, scriptPath)) + + luaProvider, err := oauth.NewLuaProvider( + oauth.ProviderModel{ + Name: p.Name, + Type: p.Type, + BaseURL: p.BaseURL, + ClientID: p.ClientID, + ClientSecret: p.ClientSecret, + AuthorizeURL: p.AuthorizeURL, + TokenURL: p.TokenURL, + UserURL: p.UserURL, + }, + redirect, + scriptPath, + ) + if err != nil { + logger.GetLogger().Error(fmt.Sprintf("加载Lua Provider失败: %v,回退到原生实现", err)) + mgr.Register(oauth.NewGitea(p, redirect)) + } else { + mgr.Register(luaProvider) + } + } else { + // Lua脚本不存在,使用原生Go实现 + logger.GetLogger().Info(fmt.Sprintf("使用原生Go实现加载OAuth Provider: %s", p.Name)) + mgr.Register(oauth.NewGitea(p, redirect)) + } + } + } + r := http.NewRouter(cfg, mgr) + + r.ListenAndServe() + }, +} + +func init() { + rootCmd.AddCommand(serverCmd) +} diff --git a/cmd/server/cmd/version.go b/cmd/server/cmd/version.go new file mode 100644 index 0000000..c66c0e6 --- /dev/null +++ b/cmd/server/cmd/version.go @@ -0,0 +1,23 @@ +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "显示版本信息", + Long: `显示程序的版本号、构建时间和Git提交信息`, + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf("Code Server Bridge\n\n") + fmt.Printf("Version: %s\n", Version) + fmt.Printf("Build Time: %s\n", BuildTime) + fmt.Printf("Git Commit: %s\n", GitCommit) + }, +} + +func init() { + rootCmd.AddCommand(versionCmd) +} diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..af48e93 --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "cs-bridge/cmd/server/cmd" + "fmt" +) + +// 版本信息,编译时通过ldflags注入 +var ( + Version = "dev" + BuildTime = "unknown" + GitCommit = "unknown" +) + +func main() { + // 设置版本信息到cmd包 + cmd.Version = Version + cmd.BuildTime = BuildTime + cmd.GitCommit = GitCommit + + // 启动信息 + fmt.Printf("🚀 启动 code-server-bridge v%s by zhilv666\n", Version) + + cmd.Execute() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..76a5465 --- /dev/null +++ b/go.mod @@ -0,0 +1,49 @@ +module cs-bridge + +go 1.24.6 + +require ( + github.com/glebarez/sqlite v1.11.0 + github.com/go-chi/chi/v5 v5.2.3 + github.com/go-yaml/yaml v2.1.0+incompatible + github.com/google/uuid v1.3.0 + github.com/gorilla/sessions v1.4.0 + github.com/gorilla/websocket v1.5.3 + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 + github.com/yuin/gopher-lua v1.1.1 + go.uber.org/zap v1.27.1 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 + gorm.io/gorm v1.31.1 + resty.dev/v3 v3.0.0-beta.6 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/glebarez/go-sqlite v1.21.2 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect + github.com/mattn/go-isatty v0.0.17 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + modernc.org/libc v1.22.5 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/sqlite v1.23.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..31207bb --- /dev/null +++ b/go.sum @@ -0,0 +1,112 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= +github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= +github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= +github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-chi/chi/v5 v5.2.3 h1:WQIt9uxdsAbgIYgid+BpYc+liqQZGMHRaUwp0JUcvdE= +github.com/go-chi/chi/v5 v5.2.3/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go-yaml/yaml v2.1.0+incompatible h1:RYi2hDdss1u4YE7GwixGzWwVo47T8UQwnTLB6vQiq+o= +github.com/go-yaml/yaml v2.1.0+incompatible/go.mod h1:w2MrLa16VYP0jy6N7M5kHaCkaLENm+P+Tv+MfurjSw0= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= +github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-isatty v0.0.17 h1:BTarxUcIeDqL27Mc+vyvdWYSL28zpIhv3RoTdsLMPng= +github.com/mattn/go-isatty v0.0.17/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +modernc.org/libc v1.22.5 h1:91BNch/e5B0uPbJFgqbxXuOnxBQjlS//icfQEGmvyjE= +modernc.org/libc v1.22.5/go.mod h1:jj+Z7dTNX8fBScMVNRAYZ/jF91K8fdT2hYMThc3YjBY= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/sqlite v1.23.1 h1:nrSBg4aRQQwq59JpvGEQ15tNxoO5pX/kUjcRNwSAGQM= +modernc.org/sqlite v1.23.1/go.mod h1:OrDj17Mggn6MhE+iPbBNf7RGKODDE9NFT0f3EwDzJqk= +resty.dev/v3 v3.0.0-beta.6 h1:ghRdNpoE8/wBCv+kTKIOauW1aCrSIeTq7GxtfYgtevU= +resty.dev/v3 v3.0.0-beta.6/go.mod h1:NTOerrC/4T7/FE6tXIZGIysXXBdgNqwMZuKtxpea9NM= diff --git a/internal/auth/identify.go b/internal/auth/identify.go new file mode 100644 index 0000000..ef293a8 --- /dev/null +++ b/internal/auth/identify.go @@ -0,0 +1,8 @@ +package auth + +type Identify struct { + Provider string `lua:"provider"` + UserId string `lua:"uid"` + Username string `lua:"username"` + Avatar string `lua:"avatar"` +} diff --git a/internal/auth/token.go b/internal/auth/token.go new file mode 100644 index 0000000..36a1b0b --- /dev/null +++ b/internal/auth/token.go @@ -0,0 +1,87 @@ +package auth + +import ( + "crypto/rand" + "cs-bridge/internal/db" + "encoding/hex" + "errors" + "fmt" + "time" + + "gorm.io/gorm" +) + +var ( + ErrTokenNotFound = errors.New("token not found") + ErrTokenExpired = errors.New("token expired") + ErrTokenUsed = errors.New("token already used") +) + +// GenerateWorkspaceToken 生成workspace访问token +// workspacePath: workspace的路径 +// ttl: token有效期 +// 返回生成的token字符串和可能的错误 +func GenerateWorkspaceToken(workspacePath string, ttl time.Duration) (string, error) { + // 生成随机token (32字节 = 64个十六进制字符) + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return "", err + } + token := hex.EncodeToString(tokenBytes) + + // 计算过期时间 + expiresAt := time.Now().Add(time.Second * ttl) + + fmt.Println("GenerateWorkspaceToken: expiresAt: ", expiresAt) + + // 保存到数据库 + workspaceToken := db.WorkspaceToken{ + Token: token, + WorkspacePath: workspacePath, + ExpiresAt: expiresAt, + Used: false, + } + + database := db.GetDB() + if err := database.Create(&workspaceToken).Error; err != nil { + return "", err + } + + return token, nil +} + +// ValidateToken 验证token (可多次使用,直到过期) +// token: 要验证的token字符串 +// 返回workspace路径和可能的错误 +func ValidateAndConsumeToken(token string) (string, error) { + database := db.GetDB() + + var workspaceToken db.WorkspaceToken + + // 查找token + if err := database.Where("token = ?", token).First(&workspaceToken).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return "", ErrTokenNotFound + } + return "", err + } + + // 检查是否过期 + if time.Now().After(workspaceToken.ExpiresAt) { + return "", ErrTokenExpired + } + + // Token有效,可以多次使用直到过期 + return workspaceToken.WorkspacePath, nil +} + +// CleanExpiredTokens 清理过期的token +// 删除所有已过期的token记录 +func CleanExpiredTokens() error { + database := db.GetDB() + + // 删除所有过期的token + result := database.Where("expires_at < ?", time.Now()).Delete(&db.WorkspaceToken{}) + + return result.Error +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..1c36690 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,129 @@ +package config + +import ( + "cs-bridge/internal/consts" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "sync" + + "github.com/go-yaml/yaml" + "github.com/spf13/viper" +) + +var ( + globalViper *viper.Viper + globalViperOnce sync.Once + configMu sync.RWMutex + cachedConfig *Config +) + +// NewConfig 初始化配置 +func NewConfig() *Config { + globalViperOnce.Do(func() { + globalViper = viper.New() + }) + + configPath := loadConfigFile() + if configPath == "" { + log.Println("⚠️ 未找到配置文件,使用默认配置并导出为 config.yaml") + cachedConfig = DefaultConfig() + exportConfig(cachedConfig, "yaml") + return cachedConfig + } + + conf, err := parseConfig() + if err != nil { + log.Printf("⚠️ 解析配置失败,使用默认配置并导出为 config.yaml: %v", err) + conf = DefaultConfig() + exportConfig(conf, "yaml") + } + + cachedConfig = conf + return conf +} + +// loadConfigFile 加载配置文件 +func loadConfigFile() string { + files := []string{ + fmt.Sprintf("%s.yaml", consts.ConfigName), + fmt.Sprintf("%s.yml", consts.ConfigName), + } + + for _, file := range files { + if _, err := os.Stat(file); err == nil { + ext := filepath.Ext(file)[1:] + globalViper.SetConfigFile(file) + globalViper.SetConfigType(ext) + globalViper.AddConfigPath(".") + + if err := globalViper.ReadInConfig(); err != nil { + log.Printf("⚠️ 读取配置文件 %s 失败: %v", file, err) + continue + } + + log.Printf("✅ 成功加载配置文件: %s", file) + return file + } + } + return "" +} + +// parseConfig 将配置解析为结构体 +func parseConfig() (*Config, error) { + configMu.Lock() + defer configMu.Unlock() + + var conf Config + if err := globalViper.Unmarshal(&conf); err != nil { + return nil, err + } + return &conf, nil +} + +// exportConfig 导出配置文件(默认保存为 config.yaml) +func exportConfig(conf *Config, format string) { + var ( + data []byte + err error + ) + + switch format { + case "json": + data, err = json.MarshalIndent(conf, "", " ") + case "yaml", "yml", "": + format = "yaml" + data, err = yaml.Marshal(conf) + default: + log.Printf("⚠️ 不支持的导出格式: %s", format) + return + } + + if err != nil { + log.Printf("⚠️ 导出 %s 格式配置失败: %v", format, err) + return + } + + filename := fmt.Sprintf("%s.%s", consts.ConfigName, format) + if err := os.WriteFile(filename, data, 0644); err != nil { + log.Printf("⚠️ 写入配置文件 %s 失败: %v", filename, err) + } else { + log.Printf("✅ 已导出配置文件: %s", filename) + } + + log.Printf("📋 当前 %s 配置:\n%s", format, data) +} + +// GetConfig 获取当前配置(线程安全) +func GetConfig() *Config { + configMu.RLock() + defer configMu.RUnlock() + + if cachedConfig == nil { + return DefaultConfig() + } + conf := *cachedConfig + return &conf +} diff --git a/internal/config/type.go b/internal/config/type.go new file mode 100644 index 0000000..2be42f6 --- /dev/null +++ b/internal/config/type.go @@ -0,0 +1,116 @@ +package config + +import ( + "cs-bridge/internal/consts" + "path/filepath" + "time" +) + +type Server struct { + Debug bool `mapstructure:"debug"` + Port int `mapstructure:"port"` +} + +type Log struct { + Level string `mapstructure:"level" yaml:"level"` + Filepath string `mapstructure:"filepath" yaml:"filepath"` + MaxSizeMB int `mapstructure:"max_size_mb" yaml:"max_size_mb"` // 单个日志文件最大(MB) + MaxAgeDay int `mapstructure:"max_age_day" yaml:"max_age_day"` // 日志文件最大保存天数 + Backups int `mapstructure:"backups" yaml:"backups"` // 保留的旧文件个数 + Compress bool `mapstructure:"compress" yaml:"compress"` // 是否压缩 +} + +type Database struct { + SqliteDbPath string `mapstructure:"sqlite_db_path" yaml:"sqlite_db_path"` +} + +type Secret struct { + TokenTTL time.Duration `mapstructure:"token_ttl" yaml:"token_ttl"` + WSTTL time.Duration `mapstructure:"ws_ttl" yaml:"ws_ttl"` +} + +type Provider struct { + Name string `mapstructure:"name" yaml:"name"` + Type string `mapstructure:"type" yaml:"type"` + ClientID string `mapstructure:"client_id" yaml:"client_id"` + ClientSecret string `mapstructure:"client_secret" yaml:"client_secret"` + BaseURL string `mapstructure:"base_url" yaml:"base_url"` + AuthorizeURL string `mapstructure:"authorize_url" yaml:"authorize_url"` + TokenURL string `mapstructure:"token_url" yaml:"token_url"` + UserURL string `mapstructure:"user_url" yaml:"user_url"` +} + +type OAuth struct { + BaseURL string `mapstructure:"base_url" yaml:"base_url"` + Providers []Provider `mapstructure:"providers" yaml:"providers"` +} + +type CodeServer struct { + BaseURL string `mapstructure:"base_url" yaml:"base_url"` + WorkspaceRoot string `mapstructure:"workspace_root" yaml:"workspace_root"` // 容器内路径,传递给code-server + // SSH配置 - 用于远程执行git操作 + SSHHost string `mapstructure:"ssh_host" yaml:"ssh_host"` // SSH服务器地址 + SSHPort int `mapstructure:"ssh_port" yaml:"ssh_port"` // SSH端口,默认22 + SSHUser string `mapstructure:"ssh_user" yaml:"ssh_user"` // SSH用户名(专用账号) + SSHKeyPath string `mapstructure:"ssh_key_path" yaml:"ssh_key_path"` // SSH私钥路径 + SSHWorkspaceRoot string `mapstructure:"ssh_workspace_root" yaml:"ssh_workspace_root"` // SSH服务器上的workspace路径(容器映射到宿主机的路径) +} + +type Config struct { + Server Server `mapstructure:"server" yaml:"server"` + Log Log `mapstructure:"log" yaml:"log"` + Database Database `mapstructure:"database" yaml:"database"` + Secret Secret `mapstructure:"secret" yaml:"secret"` + OAuth OAuth `mapstructure:"oauth" yaml:"oauth"` + CodeServer CodeServer `mapstructure:"code_server" yaml:"code_server"` +} + +func DefaultConfig() *Config { + logPath := filepath.Join(consts.DataDir, "log.log") + dbPath := filepath.Join(consts.DataDir, "sqlite.db") + return &Config{ + Server: Server{ + Debug: true, + Port: 8080, + }, + Log: Log{ + Level: "debug", + Filepath: logPath, + MaxSizeMB: 10, + MaxAgeDay: 7, + Backups: 3, + Compress: true, + }, + Database: Database{ + SqliteDbPath: dbPath, + }, + Secret: Secret{ + TokenTTL: time.Millisecond * 600, + WSTTL: time.Millisecond * 3600, + }, + OAuth: OAuth{ + BaseURL: "http://localhost:8080", + Providers: []Provider{ + { + Name: "gitea", + Type: "gitea", + ClientID: "xxx", + ClientSecret: "xxx", + BaseURL: "https://xxx", + AuthorizeURL: "/login/oauth/authorize", + TokenURL: "/login/oauth/access_token", + UserURL: "/api/v1/user", + }, + }, + }, + CodeServer: CodeServer{ + BaseURL: "xxx", + WorkspaceRoot: "/config/workspace", // 容器内路径 + SSHHost: "", + SSHPort: 22, + SSHUser: "git", + SSHKeyPath: "", + SSHWorkspaceRoot: "/root/code-server/data/workspace", // SSH服务器实际路径 + }, + } +} diff --git a/internal/consts/const.go b/internal/consts/const.go new file mode 100644 index 0000000..83690fb --- /dev/null +++ b/internal/consts/const.go @@ -0,0 +1,41 @@ +package consts + +import ( + "fmt" +) + +// 通用系统级常量 +const ( + AppName = "cs-bridge" + AppAuthor = "zhilv666" + AppVersion = "0.0.1" +) + +var ( + BuildTime string + GitCommit string +) + +const DataDir = "data" + +// 默认服务端口 +const DefaultPort = 8080 + +// 通用时间格式 +const ( + TimeFormatDate = "2006-01-02" + TimeFormatDateTime = "2006-01-02 15:04:05" +) + +func init() { + fmt.Printf("🚀 启动 %s v%s by %s\n", + AppName, + AppVersion, + AppAuthor, + ) +} + +// 项目路径与文件定义 +var ( + ConfigName = "config" +) diff --git a/internal/db/connect.go b/internal/db/connect.go new file mode 100644 index 0000000..702b460 --- /dev/null +++ b/internal/db/connect.go @@ -0,0 +1,10 @@ +package db + +import ( + "github.com/glebarez/sqlite" + "gorm.io/gorm" +) + +func connectSQLite(filepath string) gorm.Dialector { + return sqlite.Open(filepath + "?cache=shared&_fk=1&_driver=modernc.org/sqlite") +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..4bdf8e7 --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,50 @@ +package db + +import ( + "cs-bridge/internal/config" + "fmt" + "os" + + "go.uber.org/zap" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +var db *gorm.DB + +func init() { + if _, err := os.Stat("./data"); os.IsNotExist(err) { + _ = os.Mkdir("./data", os.ModePerm) + } +} + +func GetDB() *gorm.DB { + if db == nil { + panic("数据库未初始化,请先调用 db.InitDB()") + } + return db +} + +func InitDB(cfg config.Database, log *zap.Logger) { + go func() { + if err := recover(); err != nil { + panic(fmt.Sprintf("Init DB Error: %v", err)) + } + }() + + fmt.Println(cfg.SqliteDbPath) + dialector := connectSQLite(cfg.SqliteDbPath) + + var err error + db, err = gorm.Open(dialector, &gorm.Config{ + Logger: NewZapGormLogger(log).LogMode(logger.Info), + }) + if err != nil { + panic(err) + } + + // 自动迁移数据库模型 + if err := db.AutoMigrate(&WorkspaceToken{}); err != nil { + panic(fmt.Sprintf("Failed to migrate WorkspaceToken model: %v", err)) + } +} diff --git a/internal/db/gorm_logger.go b/internal/db/gorm_logger.go new file mode 100644 index 0000000..365d9ce --- /dev/null +++ b/internal/db/gorm_logger.go @@ -0,0 +1,81 @@ +package db + +import ( + "context" + "time" + + "go.uber.org/zap" + "gorm.io/gorm/logger" + "gorm.io/gorm/utils" +) + +type ZapGormLogger struct { + ZapLogger *zap.Logger + LogLevel logger.LogLevel + SlowThreshold time.Duration +} + +func NewZapGormLogger(zapLogger *zap.Logger) *ZapGormLogger { + return &ZapGormLogger{ + ZapLogger: zapLogger, + LogLevel: logger.Info, + SlowThreshold: time.Second, // 1s 慢查询阈值 + } +} + +func (l *ZapGormLogger) LogMode(level logger.LogLevel) logger.Interface { + newlogger := *l + newlogger.LogLevel = level + return &newlogger +} + +func (l *ZapGormLogger) Info(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= logger.Info { + l.ZapLogger.Sugar().Infof(msg, data...) + } +} + +func (l *ZapGormLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= logger.Warn { + l.ZapLogger.Sugar().Warnf(msg, data...) + } +} + +func (l *ZapGormLogger) Error(ctx context.Context, msg string, data ...interface{}) { + if l.LogLevel >= logger.Error { + l.ZapLogger.Sugar().Errorf(msg, data...) + } +} + +func (l *ZapGormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + if l.LogLevel <= logger.Silent { + return + } + elapsed := time.Since(begin) + sql, rows := fc() + + switch { + case err != nil && l.LogLevel >= logger.Error: + l.ZapLogger.Error("SQL Error", + zap.Error(err), + zap.String("sql", sql), + zap.Int64("rows", rows), + zap.Duration("elapsed", elapsed), + zap.String("file", utils.FileWithLineNum()), + ) + case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn: + l.ZapLogger.Warn("Slow SQL", + zap.Duration("elapsed", elapsed), + zap.String("sql", sql), + zap.Int64("rows", rows), + zap.String("file", utils.FileWithLineNum()), + ) + case l.LogLevel >= logger.Info: + l.ZapLogger.Info("SQL", + zap.String("sql", sql), + zap.Int64("rows", rows), + zap.Duration("elapsed", elapsed), + zap.String("file", utils.FileWithLineNum()), + ) + } +} diff --git a/internal/db/workspace_token.go b/internal/db/workspace_token.go new file mode 100644 index 0000000..723b8da --- /dev/null +++ b/internal/db/workspace_token.go @@ -0,0 +1,16 @@ +package db + +import ( + "time" +) + +// WorkspaceToken 工作区访问令牌模型 +// 用于存储一次性/限时的workspace访问token +type WorkspaceToken struct { + ID uint `gorm:"primarykey"` + Token string `gorm:"uniqueIndex;not null"` // Token哈希值 + WorkspacePath string `gorm:"not null"` // 工作区路径 + CreatedAt time.Time // 创建时间 + ExpiresAt time.Time `gorm:"index"` // 过期时间 + Used bool `gorm:"default:false;index"` // 是否已使用 +} diff --git a/internal/http/handlers/auth_cs.go b/internal/http/handlers/auth_cs.go new file mode 100644 index 0000000..3a9b193 --- /dev/null +++ b/internal/http/handlers/auth_cs.go @@ -0,0 +1,85 @@ +package handlers + +import ( + "cs-bridge/internal/auth" + "cs-bridge/pkg/logger" + "encoding/json" + "fmt" + "net/http" + + "go.uber.org/zap" +) + +// ValidateTokenResponse Token验证响应结构 +type ValidateTokenResponse struct { + Success bool `json:"success"` + Workspace string `json:"workspace,omitempty"` + Error string `json:"error,omitempty"` +} + +// ValidateWorkspaceToken Token验证处理器(供nginx调用) +// nginx通过此接口验证token并获取workspace路径 +func ValidateWorkspaceToken() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + log := logger.GetLogger() + + // 获取token参数 - 优先从query获取,否则从header获取(nginx auth_request场景) + token := r.URL.Query().Get("token") + if token == "" { + token = r.Header.Get("X-Auth-Token") + } + + log.Info("[Auth] 收到验证请求", + zap.String("url", r.URL.String()), + zap.Bool("has_token", token != "")) + + if token == "" { + // 无token时返回200,允许请求通过(不设置X-Workspace) + // 这样静态资源等请求可以正常访问 + log.Debug("[Auth] 无token, 允许通过") + w.WriteHeader(http.StatusOK) + return + } + + // 验证token (可多次使用,直到过期) + log.Info(fmt.Sprintf("[Auth] 开始验证token: %s...", token[:min(16, len(token))])) + workspacePath, err := auth.ValidateAndConsumeToken(token) + if err != nil { + log.Error(fmt.Sprintf("[Auth] Token验证失败: %v", err)) + w.Header().Set("Content-Type", "application/json") + + // 根据错误类型返回不同的状态码 + switch err { + case auth.ErrTokenNotFound: + w.WriteHeader(http.StatusNotFound) + case auth.ErrTokenExpired: + w.WriteHeader(http.StatusGone) + default: + w.WriteHeader(http.StatusInternalServerError) + } + + json.NewEncoder(w).Encode(ValidateTokenResponse{ + Success: false, + Error: err.Error(), + }) + return + } + + log.Info(fmt.Sprintf("[Auth] Token验证成功, Workspace: %s", workspacePath)) + + // 返回成功响应 + w.Header().Set("X-Workspace", workspacePath) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ValidateTokenResponse{ + Success: true, + Workspace: workspacePath, + }) + } +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/internal/http/handlers/entry.go b/internal/http/handlers/entry.go new file mode 100644 index 0000000..63c08a5 --- /dev/null +++ b/internal/http/handlers/entry.go @@ -0,0 +1,25 @@ +package handlers + +import ( + "cs-bridge/internal/config" + "net/http" + "net/url" + "path" + + "github.com/google/uuid" +) + +func Entry(cfg *config.CodeServer) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + workspaceID := uuid.NewString() + workspacePath := path.Join(cfg.WorkspaceRoot, workspaceID) + + redirectURL, _ := url.Parse(cfg.BaseURL) + + q := redirectURL.Query() + q.Set("folder", workspacePath) + redirectURL.RawQuery = q.Encode() + + http.Redirect(w, r, redirectURL.String(), http.StatusFound) + } +} diff --git a/internal/http/handlers/oauth.go b/internal/http/handlers/oauth.go new file mode 100644 index 0000000..f5d674c --- /dev/null +++ b/internal/http/handlers/oauth.go @@ -0,0 +1,84 @@ +package handlers + +import ( + "cs-bridge/internal/http/middleware" + "cs-bridge/internal/oauth" + "net/http" + + "github.com/go-chi/chi/v5" +) + +func OauthLogin(mgr *oauth.Manager) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "provider") + p, err := mgr.Get(name) + if err != nil { + http.Error(w, err.Error(), 404) + return + } + + state := oauth.NewState() + session, err := middleware.GetSession(r) + session.Values["oauth_state"] = state + session.Values["oauth_provider"] = name + session.Save(r, w) + + redirectURL, _ := p.AuthURL(state) + + http.Redirect(w, r, redirectURL, http.StatusFound) + } +} + +func OauthCallBack(mgr *oauth.Manager) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "provider") + p, err := mgr.Get(name) + if err != nil { + http.Error(w, err.Error(), 404) + return + } + + session, _ := middleware.GetSession(r) + expectedState, ok := session.Values["oauth_state"].(string) + if !ok { + http.Error(w, "missing oauth state", 400) + return + } + + goState := r.URL.Query().Get("state") + if goState != expectedState { + http.Error(w, "invaild oauth state2", 400) + return + } + + delete(session.Values, "oauth_state") + + code := r.URL.Query().Get("code") + token, err := p.Exchange(code) + if err != nil { + http.Error(w, err.Error(), 404) + return + } + + userInfo, err := p.UserInfo(token) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + + // 只保存用户ID,避免session过大 + session.Values["uid"] = userInfo.UserId + session.Values["username"] = userInfo.Username + + // 获取登录前保存的URL + redirectURL := "/" + if savedURL, ok := session.Values["redirect_after_login"].(string); ok && savedURL != "" { + redirectURL = savedURL + delete(session.Values, "redirect_after_login") // 使用后删除 + } + + session.Save(r, w) + + http.Redirect(w, r, redirectURL, http.StatusFound) + } +} diff --git a/internal/http/handlers/prepare.go b/internal/http/handlers/prepare.go new file mode 100644 index 0000000..d35b008 --- /dev/null +++ b/internal/http/handlers/prepare.go @@ -0,0 +1,630 @@ +package handlers + +import ( + "crypto/sha1" + "cs-bridge/internal/config" + gitpkg "cs-bridge/pkg/git" + "cs-bridge/pkg/logger" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "path" + "strings" +) + +// GitProgressPage 显示git进度页面 +func GitProgressPage() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // 获取repo参数 + repo := r.URL.Query().Get("repo") + if repo == "" { + http.Error(w, "Missing repo parameter", http.StatusBadRequest) + return + } + + // 返回HTML页面 + w.Header().Set("Content-Type", "text/html; charset=utf-8") + logger.GetLogger().Info(fmt.Sprintf("[GitProgressPage] Repo参数: %s", repo)) + // 使用字符串替换避免fmt格式化问题 + html := strings.Replace(gitProgressHTML, "{{REPO_URL}}", repo, 1) + w.Write([]byte(html)) + } +} + +// ExecuteGitRequest 执行git操作的请求结构 +type ExecuteGitRequest struct { + RepoURL string `json:"repo_url"` +} + +// ExecuteGitResponse 执行git操作的响应结构 +type ExecuteGitResponse struct { + Success bool `json:"success"` + Message string `json:"message"` + RedirectURL string `json:"redirect_url,omitempty"` + Error string `json:"error,omitempty"` +} + +// ExecuteGitOperations 执行git操作(克隆或更新) +func ExecuteGitOperations(cfg *config.Config) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // 解析请求 + var req ExecuteGitRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + SendProgressJSON(w, "error", "Invalid request", 0) + return + } + + if req.RepoURL == "" { + SendProgressJSON(w, "error", "Missing repo_url", 0) + return + } + + log := logger.GetLogger() + log.Info(fmt.Sprintf("[ExecuteGit] 开始执行Git操作, RepoURL: %s, ClientIP: %s", req.RepoURL, r.RemoteAddr)) + + // 生成workspace路径 - 使用客户端IP和仓库URL的SHA1哈希作为workspaceID + repoName := gitpkg.GetRepoName(req.RepoURL) + + // 基于客户端IP和仓库URL生成稳定的workspaceID + // hashInput := r.RemoteAddr + ":" + req.RepoURL + hashInput := req.RepoURL + hasher := sha1.New() + hasher.Write([]byte(hashInput)) + workspaceID := hex.EncodeToString(hasher.Sum(nil))[:16] // 取前16位 + + log.Info(fmt.Sprintf("[ExecuteGit] 生成workspaceID: %s (基于: %s)", workspaceID, hashInput)) + + // 容器内路径(传递给code-server) - 使用path.Join确保Linux风格路径 + containerWorkspacePath := path.Join(cfg.CodeServer.WorkspaceRoot, workspaceID, repoName) + + // SSH服务器上的实际路径(用于git操作) - 使用path.Join确保Linux风格路径 + sshWorkspacePath := containerWorkspacePath + if cfg.CodeServer.SSHWorkspaceRoot != "" { + sshWorkspacePath = path.Join(cfg.CodeServer.SSHWorkspaceRoot, workspaceID, repoName) + } + + log.Info(fmt.Sprintf("[ExecuteGit] 路径配置 - Container: %s, SSH: %s", containerWorkspacePath, sshWorkspacePath)) + + // 进度回调函数 + progressCallback := func(message string, percent int) { + SendProgressToClients("processing", message, percent) + } + + // 检查是否使用SSH远程执行 + useSSH := cfg.CodeServer.SSHHost != "" + var sshCfg gitpkg.SSHConfig + if useSSH { + sshCfg = gitpkg.SSHConfig{ + Host: cfg.CodeServer.SSHHost, + Port: cfg.CodeServer.SSHPort, + User: cfg.CodeServer.SSHUser, + KeyPath: cfg.CodeServer.SSHKeyPath, + } + } + + // 检查仓库是否存在 + SendProgressToClients("checking", "正在检查仓库...", 5) + + var err error + var repoExists bool + + if useSSH { + repoExists = gitpkg.CheckRepoExistsRemote(sshCfg, sshWorkspacePath) + } else { + repoExists = gitpkg.CheckRepoExists(containerWorkspacePath) + } + + if repoExists { + // 仓库存在,执行pull + SendProgressToClients("pulling", "仓库已存在,正在更新...", 20) + if useSSH { + err = gitpkg.PullRepoRemote(sshCfg, sshWorkspacePath, progressCallback) + } else { + err = gitpkg.PullRepo(containerWorkspacePath, progressCallback) + } + } else { + // 仓库不存在,执行clone + SendProgressToClients("cloning", "仓库不存在,正在克隆...", 20) + if useSSH { + err = gitpkg.CloneRepoRemote(sshCfg, req.RepoURL, sshWorkspacePath, progressCallback) + } else { + err = gitpkg.CloneRepo(req.RepoURL, containerWorkspacePath, progressCallback) + } + } + + if err != nil { + SendProgressToClients("error", fmt.Sprintf("Git操作失败: %v", err), 0) + json.NewEncoder(w).Encode(ExecuteGitResponse{ + Success: false, + Error: err.Error(), + }) + return + } + + // Git操作成功,发送100%进度 + SendProgressToClients("success", "操作完成,即将跳转...", 100) + + // 直接重定向到code-server,带folder参数 + redirectURL := fmt.Sprintf("%s?folder=%s", cfg.CodeServer.BaseURL, containerWorkspacePath) + + log.Info(fmt.Sprintf("[ExecuteGit] 准备重定向到: %s", redirectURL)) + + // 返回成功响应 + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ExecuteGitResponse{ + Success: true, + Message: "Git操作完成", + RedirectURL: redirectURL, + }) + } +} + +// gitProgressHTML 是git进度页面的HTML模板 +const gitProgressHTML = ` + + + + + 准备工作区 - Loading... + + + +
+

🚀 准备工作区

+ + +
+ + +
+ + +
+
+
+
+
+ + 正在连接... +
+
+
等待服务器响应...
+
+
+ + +
+
+
+
工作区准备完成!
+
即将跳转到 Code Server...
+
3
+ +
+
+
+ + + +` diff --git a/internal/http/handlers/ws.go b/internal/http/handlers/ws.go new file mode 100644 index 0000000..2cb6496 --- /dev/null +++ b/internal/http/handlers/ws.go @@ -0,0 +1,149 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + // 允许所有来源(生产环境应该做更严格的检查) + return true + }, +} + +// ProgressMessage WebSocket进度消息结构 +type ProgressMessage struct { + Status string `json:"status"` // 状态: checking, cloning, pulling, success, error + Message string `json:"message"` // 消息文本 + Percent int `json:"percent"` // 进度百分比 (0-100) +} + +// ProgressHub 进度广播中心 +type ProgressHub struct { + clients map[string]*websocket.Conn + clientsMux sync.RWMutex + broadcast chan ProgressMessage +} + +var ( + hub *ProgressHub + hubOnce sync.Once +) + +// GetProgressHub 获取进度广播中心单例 +func GetProgressHub() *ProgressHub { + hubOnce.Do(func() { + hub = &ProgressHub{ + clients: make(map[string]*websocket.Conn), + broadcast: make(chan ProgressMessage, 10), + } + go hub.run() + }) + return hub +} + +// run 运行广播循环 +func (h *ProgressHub) run() { + for msg := range h.broadcast { + h.clientsMux.RLock() + for clientID, conn := range h.clients { + err := conn.WriteJSON(msg) + if err != nil { + // 写入失败,移除客户端 + conn.Close() + h.clientsMux.RUnlock() + h.removeClient(clientID) + h.clientsMux.RLock() + } + } + h.clientsMux.RUnlock() + } +} + +// addClient 添加客户端 +func (h *ProgressHub) addClient(clientID string, conn *websocket.Conn) { + h.clientsMux.Lock() + defer h.clientsMux.Unlock() + h.clients[clientID] = conn +} + +// removeClient 移除客户端 +func (h *ProgressHub) removeClient(clientID string) { + h.clientsMux.Lock() + defer h.clientsMux.Unlock() + delete(h.clients, clientID) +} + +// SendProgress 发送进度消息 +func (h *ProgressHub) SendProgress(msg ProgressMessage) { + select { + case h.broadcast <- msg: + case <-time.After(1 * time.Second): + // 超时,丢弃消息 + } +} + +// GitProgressWS WebSocket处理器 +func GitProgressWS() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // 升级HTTP连接到WebSocket + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + http.Error(w, "Failed to upgrade to WebSocket", http.StatusInternalServerError) + return + } + defer conn.Close() + + // 生成客户端ID + clientID := r.RemoteAddr + "-" + time.Now().Format("20060102150405") + + // 添加到hub + hub := GetProgressHub() + hub.addClient(clientID, conn) + defer hub.removeClient(clientID) + + // 发送欢迎消息 + welcomeMsg := ProgressMessage{ + Status: "connected", + Message: "WebSocket连接已建立", + Percent: 0, + } + if err := conn.WriteJSON(welcomeMsg); err != nil { + return + } + + // 保持连接,读取客户端消息(主要用于心跳) + for { + _, _, err := conn.ReadMessage() + if err != nil { + // 连接断开 + break + } + } + } +} + +// SendProgressToClients 辅助函数:向所有客户端发送进度 +func SendProgressToClients(status, message string, percent int) { + hub := GetProgressHub() + hub.SendProgress(ProgressMessage{ + Status: status, + Message: message, + Percent: percent, + }) +} + +// SendProgressJSON 辅助函数:发送JSON格式的进度消息 +func SendProgressJSON(w http.ResponseWriter, status, message string, percent int) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(ProgressMessage{ + Status: status, + Message: message, + Percent: percent, + }) +} diff --git a/internal/http/middleware/session.go b/internal/http/middleware/session.go new file mode 100644 index 0000000..0f9207d --- /dev/null +++ b/internal/http/middleware/session.go @@ -0,0 +1,42 @@ +package middleware + +import ( + "net/http" + + "github.com/gorilla/sessions" +) + +var store = sessions.NewCookieStore([]byte("123456")) + +func GetSession(r *http.Request) (*sessions.Session, error) { + return store.Get(r, "transit") +} + +// RequireLogin 登录验证中间件 +// 检查用户是否已登录,未登录则重定向到OAuth登录页面 +func RequireLogin(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, err := GetSession(r) + if err != nil { + // Session读取失败,保存原始URL后重定向到登录 + session, _ = store.New(r, "transit") + session.Values["redirect_after_login"] = r.URL.RequestURI() + session.Save(r, w) + http.Redirect(w, r, "/oauth/gitea", http.StatusFound) + return + } + + // 检查session中是否有用户信息 + userID := session.Values["uid"] + if userID == nil { + // 未登录,保存原始URL后重定向到OAuth登录 + session.Values["redirect_after_login"] = r.URL.RequestURI() + session.Save(r, w) + http.Redirect(w, r, "/oauth/gitea", http.StatusFound) + return + } + + // 已登录,继续处理请求 + next.ServeHTTP(w, r) + }) +} diff --git a/internal/http/router.go b/internal/http/router.go new file mode 100644 index 0000000..85f4d09 --- /dev/null +++ b/internal/http/router.go @@ -0,0 +1,45 @@ +package http + +import ( + "cs-bridge/internal/config" + "cs-bridge/internal/http/handlers" + "cs-bridge/internal/http/middleware" + "cs-bridge/internal/oauth" + "cs-bridge/pkg/logger" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" +) + +type Router struct { + r http.Handler + cfg config.Server +} + +func NewRouter(cfg *config.Config, mgr *oauth.Manager) *Router { + r := chi.NewRouter() + + r.Get("/health", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("ok")) + }) + + r.Get("/", handlers.Entry(&cfg.CodeServer)) + r.Get("/oauth/{provider}", handlers.OauthLogin(mgr)) + r.Get("/oauth/{provider}/callback", handlers.OauthCallBack(mgr)) + + // Git准备相关路由 - 需要登录 + r.With(middleware.RequireLogin).Get("/tz", handlers.GitProgressPage()) + r.Get("/ws/git-progress", handlers.GitProgressWS()) + r.Post("/api/git/execute", handlers.ExecuteGitOperations(cfg)) + + return &Router{ + r: r, + cfg: cfg.Server, + } +} + +func (r *Router) ListenAndServe() { + logger.GetLogger().Info(fmt.Sprintf("server is running as: http://localhost:%d", r.cfg.Port)) + http.ListenAndServe(fmt.Sprintf(":%d", r.cfg.Port), r.r) +} diff --git a/internal/oauth/gitea.go b/internal/oauth/gitea.go new file mode 100644 index 0000000..ca16e3f --- /dev/null +++ b/internal/oauth/gitea.go @@ -0,0 +1,111 @@ +package oauth + +import ( + "cs-bridge/internal/auth" + "cs-bridge/internal/config" + "cs-bridge/pkg/httpclient" + "encoding/json" + "fmt" + "net/url" + "time" +) + +type Gitea struct { + cfg config.Provider + redirectURI string +} + +// Name implements [Provider]. +func (g Gitea) Name() string { + return "gitea" +} + +// AuthURL implements [Provider]. +func (g Gitea) AuthURL(state string) (string, error) { + u, _ := url.Parse(g.cfg.BaseURL + g.cfg.AuthorizeURL) + q := u.Query() + q.Set("client_id", g.cfg.ClientID) + q.Set("redirect_uri", g.redirectURI) + q.Set("response_type", "code") + q.Set("scope", "read:user") + q.Set("state", state) + u.RawQuery = q.Encode() + return u.String(), nil +} + +// Exchange implements [Provider]. +func (g Gitea) Exchange(code string) (string, error) { + resp, err := httpclient.Default.R(). + SetHeader("Accept", "application/json"). + SetFormData(map[string]string{ + "client_id": g.cfg.ClientID, + "client_secret": g.cfg.ClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": g.redirectURI, + }). + Post(fmt.Sprintf("%s%s", g.cfg.BaseURL, g.cfg.TokenURL)) + if err != nil { + return "", err + } + var out struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn string `json:"expires_in"` + RefreshToken string `json:"refresh_token"` + } + json.NewDecoder(resp.Body).Decode(&out) + + return out.AccessToken, nil +} + +// UserInfo implements [Provider]. +func (g Gitea) UserInfo(token string) (auth.Identify, error) { + resp, err := httpclient.Default.R(). + SetHeader("Authorization", fmt.Sprintf("bearer %s", token)). + Get(fmt.Sprintf("%s%s", g.cfg.BaseURL, g.cfg.UserURL)) + if err != nil { + return auth.Identify{}, err + } + var raw raw + json.NewDecoder(resp.Body).Decode(&raw) + return auth.Identify{ + Provider: g.Name(), + UserId: fmt.Sprint(raw.ID), + Username: raw.Login, + Avatar: raw.AvatarURL, + }, nil +} + +func NewGitea(cfg config.Provider, redirectURI string) Gitea { + return Gitea{ + cfg: cfg, + redirectURI: redirectURI, + } +} + +type raw struct { + ID int `json:"id"` + Login string `json:"login"` + LoginName string `json:"login_name"` + SourceID int `json:"source_id"` + FullName string `json:"full_name"` + Email string `json:"email"` + AvatarURL string `json:"avatar_url"` + HTMLURL string `json:"html_url"` + Language string `json:"language"` + IsAdmin bool `json:"is_admin"` + LastLogin time.Time `json:"last_login"` + Created time.Time `json:"created"` + Restricted bool `json:"restricted"` + Active bool `json:"active"` + ProhibitLogin bool `json:"prohibit_login"` + Location string `json:"location"` + Website string `json:"website"` + Description string `json:"description"` + Visibility string `json:"visibility"` + FollowersCount int `json:"followers_count"` + FollowingCount int `json:"following_count"` + StarredReposCount int `json:"starred_repos_count"` + Username string `json:"username"` +} diff --git a/internal/oauth/lua_provider.go b/internal/oauth/lua_provider.go new file mode 100644 index 0000000..5435076 --- /dev/null +++ b/internal/oauth/lua_provider.go @@ -0,0 +1,86 @@ +package oauth + +import ( + "cs-bridge/internal/auth" + "cs-bridge/pkg/lua" + "fmt" + + lua2 "github.com/yuin/gopher-lua" +) + +type LuaProvider struct { + model ProviderModel + engine *lua.Engine + redirectURI string +} + +// NewLuaProvider creates a new Lua-based provider +func NewLuaProvider(model ProviderModel, redirectURI, scriptPath string) (*LuaProvider, error) { + engine := lua.New() + + // Register all API modules (http, json, log, util) + engine.RegisterAPI() + + // Load the Lua script + if err := engine.LoadFile(scriptPath); err != nil { + engine.Close() + return nil, err + } + + return &LuaProvider{ + model: model, + engine: engine, + redirectURI: redirectURI, + }, nil +} + +// Close cleans up the Lua engine +func (p *LuaProvider) Close() { + p.engine.Close() +} + +// Name implements [Provider]. +func (p *LuaProvider) Name() string { + return p.model.Name +} + +// luaConfig constructs the config table passed to Lua functions +func (p *LuaProvider) luaConfig() *lua2.LTable { + L := p.engine.L + config := L.NewTable() + config.RawSetString("client_id", lua2.LString(p.model.ClientID)) + config.RawSetString("client_secret", lua2.LString(p.model.ClientSecret)) + config.RawSetString("base_url", lua2.LString(p.model.BaseURL)) + config.RawSetString("authorize_url", lua2.LString(p.model.AuthorizeURL)) + config.RawSetString("token_url", lua2.LString(p.model.TokenURL)) + config.RawSetString("user_url", lua2.LString(p.model.UserURL)) + config.RawSetString("redirect_uri", lua2.LString(p.redirectURI)) + return config +} + +// AuthURL implements [Provider]. +func (p *LuaProvider) AuthURL(state string) (string, error) { + cfg := p.luaConfig() + return p.engine.CallString("auth_url", cfg, lua2.LString(state)) +} + +// Exchange implements [Provider]. +func (p *LuaProvider) Exchange(code string) (string, error) { + cfg := p.luaConfig() + return p.engine.CallString("exchange", cfg, lua2.LString(code)) +} + +// UserInfo implements [Provider]. +func (p *LuaProvider) UserInfo(token string) (auth.Identify, error) { + cfg := p.luaConfig() + + var userInfo auth.Identify + if err := p.engine.CallStruct("user_info", &userInfo, cfg, lua2.LString(token)); err != nil { + return auth.Identify{}, fmt.Errorf("failed to get user info: %w", err) + } + + // Set the provider name + userInfo.Provider = p.model.Name + + return userInfo, nil +} diff --git a/internal/oauth/manager.go b/internal/oauth/manager.go new file mode 100644 index 0000000..3b2d9a0 --- /dev/null +++ b/internal/oauth/manager.go @@ -0,0 +1,25 @@ +package oauth + +import "fmt" + +type Manager struct { + provider map[string]Provider +} + +func NewManager() *Manager { + return &Manager{ + provider: map[string]Provider{}, + } +} + +func (m *Manager) Register(p Provider) { + m.provider[p.Name()] = p +} + +func (m *Manager) Get(name string) (Provider, error) { + p, ok := m.provider[name] + if !ok { + return nil, fmt.Errorf("oauth provider not found: %s", name) + } + return p, nil +} diff --git a/internal/oauth/model.go b/internal/oauth/model.go new file mode 100644 index 0000000..29184e1 --- /dev/null +++ b/internal/oauth/model.go @@ -0,0 +1,22 @@ +package oauth + +import "time" + +type ProviderModel struct { + ID uint `gorm:"primaryKey"` + Name string `gorm:"uniqueIndex"` + Type string + BaseURL string + ClientID string + ClientSecret string + + AuthorizeURL string + TokenURL string + UserURL string + + LuaScript string + Enabled bool + + CreatedAt time.Time + UpdatedAt time.Time +} diff --git a/internal/oauth/provider.go b/internal/oauth/provider.go new file mode 100644 index 0000000..b5514d2 --- /dev/null +++ b/internal/oauth/provider.go @@ -0,0 +1,16 @@ +package oauth + +import "cs-bridge/internal/auth" + +type UserInfo struct { + ID string `json:"id"` + Username string `json:"username"` + Avatar string `json:"avatar_url"` +} + +type Provider interface { + Name() string + AuthURL(state string) (string, error) + Exchange(code string) (string, error) + UserInfo(token string) (auth.Identify, error) +} diff --git a/internal/oauth/state.go b/internal/oauth/state.go new file mode 100644 index 0000000..15992c2 --- /dev/null +++ b/internal/oauth/state.go @@ -0,0 +1,12 @@ +package oauth + +import ( + "crypto/rand" + "encoding/hex" +) + +func NewState() string { + b := make([]byte, 16) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} diff --git a/pkg/git/git.go b/pkg/git/git.go new file mode 100644 index 0000000..a227fe4 --- /dev/null +++ b/pkg/git/git.go @@ -0,0 +1,238 @@ +package git + +import ( + "cs-bridge/pkg/logger" + "fmt" + "os" + "os/exec" + "path" + "path/filepath" + "strings" +) + +// ProgressCallback 进度回调函数类型 +// 用于向外部报告git操作的进度 +type ProgressCallback func(message string, percent int) + +// SSHConfig SSH连接配置 +type SSHConfig struct { + Host string // SSH服务器地址 + Port int // SSH端口 + User string // SSH用户名 + KeyPath string // SSH私钥路径 +} + +// CheckRepoExists 检查指定路径是否存在git仓库 +// path: 要检查的路径 +// 返回true表示存在.git目录,false表示不存在 +func CheckRepoExists(path string) bool { + log := logger.GetLogger() + gitDir := filepath.Join(path, ".git") + info, err := os.Stat(gitDir) + if err != nil { + log.Debug(fmt.Sprintf("[Git] 本地仓库不存在: %s", path)) + return false + } + exists := info.IsDir() + log.Info(fmt.Sprintf("[Git] 本地仓库检查 - Path: %s, Exists: %v", path, exists)) + return exists +} + +// CheckRepoExistsRemote 通过SSH检查远程服务器上是否存在git仓库 +func CheckRepoExistsRemote(sshCfg SSHConfig, path string) bool { + log := logger.GetLogger() + log.Info(fmt.Sprintf("[Git] SSH检查远程仓库 - Host: %s, Path: %s", sshCfg.Host, path)) + + cmd := buildSSHCommand(sshCfg, fmt.Sprintf("test -d '%s/.git' && echo 'exists' || echo 'not_exists'", path)) + output, err := cmd.CombinedOutput() + if err != nil { + log.Error(fmt.Sprintf("[Git] SSH检查失败: %v, Output: %s", err, string(output))) + return false + } + exists := strings.TrimSpace(string(output)) == "exists" + log.Info(fmt.Sprintf("[Git] SSH远程仓库检查结果 - Exists: %v", exists)) + return exists +} + +// CloneRepo 克隆git仓库(本地执行) +func CloneRepo(repoURL, destPath string, callback ProgressCallback) error { + log := logger.GetLogger() + log.Info(fmt.Sprintf("[Git] 开始本地克隆 - URL: %s, Dest: %s", repoURL, destPath)) + + // 确保父目录存在 + parentDir := filepath.Dir(destPath) + if err := os.MkdirAll(parentDir, 0755); err != nil { + log.Error(fmt.Sprintf("[Git] 创建父目录失败: %v", err)) + return fmt.Errorf("failed to create parent directory: %w", err) + } + + // 如果目标目录存在但不是git仓库,删除它 + if _, err := os.Stat(destPath); err == nil { + if !CheckRepoExists(destPath) { + log.Info(fmt.Sprintf("[Git] 清理非git目录: %s", destPath)) + if err := os.RemoveAll(destPath); err != nil { + log.Error(fmt.Sprintf("[Git] 清理目录失败: %v", err)) + return fmt.Errorf("failed to clean destination directory: %w", err) + } + } + } + + if callback != nil { + callback("正在克隆仓库...", 10) + } + + // 执行git clone命令 + cmd := exec.Command("git", "clone", repoURL, destPath) + output, err := cmd.CombinedOutput() + if err != nil { + log.Error(fmt.Sprintf("[Git] 克隆失败: %v, Output: %s", err, string(output))) + return fmt.Errorf("git clone failed: %w, output: %s", err, string(output)) + } + + log.Info(fmt.Sprintf("[Git] 克隆成功: %s", destPath)) + if callback != nil { + callback("仓库克隆完成", 100) + } + return nil +} + +// CloneRepoRemote 通过SSH在远程服务器上克隆git仓库 +func CloneRepoRemote(sshCfg SSHConfig, repoURL, destPath string, callback ProgressCallback) error { + log := logger.GetLogger() + log.Info(fmt.Sprintf("[Git] 开始SSH远程克隆 - Host: %s, URL: %s, Dest: %s", sshCfg.Host, repoURL, destPath)) + + if callback != nil { + callback("正在通过SSH连接服务器...", 5) + } + + // 确保父目录存在 + parentDir := path.Dir(destPath) + log.Debug(fmt.Sprintf("[Git] 创建远程父目录: %s", parentDir)) + mkdirCmd := buildSSHCommand(sshCfg, fmt.Sprintf("mkdir -p '%s'", parentDir)) + if output, err := mkdirCmd.CombinedOutput(); err != nil { + log.Error(fmt.Sprintf("[Git] SSH创建目录失败: %v, Output: %s", err, string(output))) + return fmt.Errorf("failed to create parent directory: %w, output: %s", err, string(output)) + } + + // 如果目标目录存在但不是git仓库,删除它 + log.Debug(fmt.Sprintf("[Git] 检查并清理远程目录: %s", destPath)) + cleanCmd := buildSSHCommand(sshCfg, fmt.Sprintf( + "if [ -d '%s' ] && [ ! -d '%s/.git' ]; then rm -rf '%s'; fi", + destPath, destPath, destPath, + )) + if output, err := cleanCmd.CombinedOutput(); err != nil { + log.Error(fmt.Sprintf("[Git] SSH清理目录失败: %v, Output: %s", err, string(output))) + return fmt.Errorf("failed to clean destination: %w, output: %s", err, string(output)) + } + + if callback != nil { + callback("正在克隆仓库...", 20) + } + + // 执行git clone + log.Info(fmt.Sprintf("[Git] 执行SSH git clone: %s -> %s", repoURL, destPath)) + cloneCmd := buildSSHCommand(sshCfg, fmt.Sprintf("git clone '%s' '%s'", repoURL, destPath)) + output, err := cloneCmd.CombinedOutput() + if err != nil { + log.Error(fmt.Sprintf("[Git] SSH克隆失败: %v, Output: %s", err, string(output))) + return fmt.Errorf("git clone failed: %w, output: %s", err, string(output)) + } + + log.Info(fmt.Sprintf("[Git] SSH克隆成功: %s", destPath)) + if callback != nil { + callback("仓库克隆完成", 100) + } + return nil +} + +// PullRepo 更新git仓库(本地执行) +func PullRepo(repoPath string, callback ProgressCallback) error { + log := logger.GetLogger() + log.Info(fmt.Sprintf("[Git] 开始本地更新仓库: %s", repoPath)) + + if !CheckRepoExists(repoPath) { + return fmt.Errorf("not a git repository: %s", repoPath) + } + + if callback != nil { + callback("正在更新仓库...", 10) + } + + cmd := exec.Command("git", "pull") + cmd.Dir = repoPath + output, err := cmd.CombinedOutput() + if err != nil { + log.Error(fmt.Sprintf("[Git] 更新失败: %v, Output: %s", err, string(output))) + return fmt.Errorf("git pull failed: %w, output: %s", err, string(output)) + } + + log.Info(fmt.Sprintf("[Git] 更新成功: %s", strings.TrimSpace(string(output)))) + if callback != nil { + callback(fmt.Sprintf("仓库更新完成: %s", strings.TrimSpace(string(output))), 100) + } + return nil +} + +// PullRepoRemote 通过SSH在远程服务器上更新git仓库 +func PullRepoRemote(sshCfg SSHConfig, repoPath string, callback ProgressCallback) error { + log := logger.GetLogger() + log.Info(fmt.Sprintf("[Git] 开始SSH远程更新仓库 - Host: %s, Path: %s", sshCfg.Host, repoPath)) + + if callback != nil { + callback("正在通过SSH连接服务器...", 5) + } + + if callback != nil { + callback("正在更新仓库...", 20) + } + + pullCmd := buildSSHCommand(sshCfg, fmt.Sprintf("cd '%s' && git pull", repoPath)) + output, err := pullCmd.CombinedOutput() + if err != nil { + log.Error(fmt.Sprintf("[Git] SSH更新失败: %v, Output: %s", err, string(output))) + return fmt.Errorf("git pull failed: %w, output: %s", err, string(output)) + } + + log.Info(fmt.Sprintf("[Git] SSH更新成功: %s", strings.TrimSpace(string(output)))) + if callback != nil { + callback(fmt.Sprintf("仓库更新完成: %s", strings.TrimSpace(string(output))), 100) + } + return nil +} + +// GetRepoName 从仓库URL中提取仓库名称 +func GetRepoName(repoURL string) string { + base := filepath.Base(repoURL) + if len(base) > 4 && base[len(base)-4:] == ".git" { + return base[:len(base)-4] + } + return base +} + +// buildSSHCommand 构建SSH命令 +func buildSSHCommand(sshCfg SSHConfig, remoteCmd string) *exec.Cmd { + log := logger.GetLogger() + + // SSH密钥路径需要转换为Unix风格(SSH命令总是在Linux上执行) + // 即使在Windows上编译,SSH密钥路径传给ssh命令时也必须用正斜杠 + keyPath := filepath.ToSlash(sshCfg.KeyPath) + + args := []string{ + "-o", "StrictHostKeyChecking=no", + "-o", "BatchMode=yes", + } + + if sshCfg.Port != 0 && sshCfg.Port != 22 { + args = append(args, "-p", fmt.Sprintf("%d", sshCfg.Port)) + } + + if keyPath != "" { // Use the converted keyPath + args = append(args, "-i", keyPath) + } + + args = append(args, fmt.Sprintf("%s@%s", sshCfg.User, sshCfg.Host), remoteCmd) + + log.Debug(fmt.Sprintf("[Git] SSH命令: ssh %s", strings.Join(args, " "))) + + return exec.Command("ssh", args...) +} diff --git a/pkg/httpclient/resty.go b/pkg/httpclient/resty.go new file mode 100644 index 0000000..df74097 --- /dev/null +++ b/pkg/httpclient/resty.go @@ -0,0 +1,17 @@ +package httpclient + +import ( + "time" + + "resty.dev/v3" +) + +var Default *resty.Client + +func Init() { + c := resty.New() + // TODO: 将代理配置移到config中 + // c.SetProxy("http://127.0.0.1:9000") + c.SetTimeout(10 * time.Second) + Default = c +} diff --git a/pkg/logger/func.go b/pkg/logger/func.go new file mode 100644 index 0000000..a81c291 --- /dev/null +++ b/pkg/logger/func.go @@ -0,0 +1,27 @@ +package logger + +import "go.uber.org/zap" + +func GetLogger() *zap.Logger { + return logger +} + +func Debug(msg string, fields ...zap.Field) { + logger.Debug(msg, fields...) +} + +func Info(msg string, fields ...zap.Field) { + logger.Info(msg, fields...) +} + +func Warn(msg string, fields ...zap.Field) { + logger.Warn(msg, fields...) +} + +func Error(msg string, fields ...zap.Field) { + logger.Error(msg, fields...) +} + +func Sync() { + _ = logger.Sync() +} diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go new file mode 100644 index 0000000..e423174 --- /dev/null +++ b/pkg/logger/logger.go @@ -0,0 +1,87 @@ +package logger + +import ( + "cs-bridge/internal/config" + "cs-bridge/internal/consts" + "os" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +var logger *zap.Logger + +func Init(log config.Log) { + var zapLevel zapcore.Level + + // 日志等级解析 + switch log.Level { + case "debug": + zapLevel = zap.DebugLevel + case "info": + zapLevel = zap.InfoLevel + case "warning": + zapLevel = zap.WarnLevel + case "error": + zapLevel = zap.ErrorLevel + default: + zapLevel = zap.InfoLevel + } + + // lumberjack 日志切割配置 + writeSyncer := zapcore.AddSync(&lumberjack.Logger{ + Filename: log.Filepath, + MaxSize: log.MaxSizeMB, + MaxBackups: log.Backups, + MaxAge: log.MaxAgeDay, + Compress: log.Compress, + }) + + // 日志编码格式 + encoderConfigColor := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + MessageKey: "msg", + StacktraceKey: "Stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalColorLevelEncoder, // 彩色等级输出(终端) + // EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: timeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + encoderConfig := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + MessageKey: "msg", + StacktraceKey: "Stacktrace", + LineEnding: zapcore.DefaultLineEnding, + // EncodeLevel: zapcore.CapitalColorLevelEncoder, // 彩色等级输出(终端) + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: timeEncoder, + EncodeDuration: zapcore.SecondsDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + encoderConsole := zapcore.NewConsoleEncoder(encoderConfigColor) + encoderJson := zapcore.NewJSONEncoder(encoderConfig) + + core := zapcore.NewTee( + zapcore.NewCore(encoderJson, writeSyncer, zapLevel), + zapcore.NewCore(encoderConsole, zapcore.AddSync(os.Stdout), zapLevel)) + + logger = zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1)) + zap.ReplaceGlobals(logger) +} + +// 时间格式 +func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) { + enc.AppendString(t.Format(consts.TimeFormatDateTime)) +} diff --git a/pkg/lua/api.go b/pkg/lua/api.go new file mode 100644 index 0000000..8558e8a --- /dev/null +++ b/pkg/lua/api.go @@ -0,0 +1,393 @@ +package lua + +import ( + "cs-bridge/pkg/httpclient" + "cs-bridge/pkg/logger" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/url" + + lua "github.com/yuin/gopher-lua" + "resty.dev/v3" +) + +// RegisterHTTPModule registers the http module in Lua +func RegisterHTTPModule(L *lua.LState) { + httpMod := L.NewTable() + + // http.get(url, headers) + httpMod.RawSetString("get", L.NewFunction(luaHTTPGet)) + + // http.post(url, data, headers) + httpMod.RawSetString("post", L.NewFunction(luaHTTPPost)) + + // http.put(url, data, headers) + httpMod.RawSetString("put", L.NewFunction(luaHTTPPut)) + + // http.delete(url, headers) + httpMod.RawSetString("delete", L.NewFunction(luaHTTPDelete)) + + L.SetGlobal("http", httpMod) +} + +// luaHTTPGet implements http.get(url, headers) +func luaHTTPGet(L *lua.LState) int { + url := L.CheckString(1) + headers := L.OptTable(2, nil) + + req := httpclient.Default.R() + + // Set headers if provided + if headers != nil { + headers.ForEach(func(key, value lua.LValue) { + req.SetHeader(key.String(), value.String()) + }) + } + + resp, err := req.Get(url) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + result := parseHTTPResponse(L, resp) + L.Push(result) + return 1 +} + +// luaHTTPPost implements http.post(url, data, headers) +func luaHTTPPost(L *lua.LState) int { + url := L.CheckString(1) + data := L.CheckTable(2) + headers := L.OptTable(3, nil) + + req := httpclient.Default.R() + + // Convert Lua table to map for request body + bodyMap := luaTableToMap(data) + req.SetBody(bodyMap) + req.SetHeader("Content-Type", "application/json") + + // Set additional headers if provided + if headers != nil { + headers.ForEach(func(key, value lua.LValue) { + req.SetHeader(key.String(), value.String()) + }) + } + + resp, err := req.Post(url) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + result := parseHTTPResponse(L, resp) + L.Push(result) + return 1 +} + +// luaHTTPPut implements http.put(url, data, headers) +func luaHTTPPut(L *lua.LState) int { + url := L.CheckString(1) + data := L.CheckTable(2) + headers := L.OptTable(3, nil) + + req := httpclient.Default.R() + + bodyMap := luaTableToMap(data) + req.SetBody(bodyMap) + req.SetHeader("Content-Type", "application/json") + + if headers != nil { + headers.ForEach(func(key, value lua.LValue) { + req.SetHeader(key.String(), value.String()) + }) + } + + resp, err := req.Put(url) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + result := parseHTTPResponse(L, resp) + L.Push(result) + return 1 +} + +// luaHTTPDelete implements http.delete(url, headers) +func luaHTTPDelete(L *lua.LState) int { + url := L.CheckString(1) + headers := L.OptTable(2, nil) + + req := httpclient.Default.R() + + if headers != nil { + headers.ForEach(func(key, value lua.LValue) { + req.SetHeader(key.String(), value.String()) + }) + } + + resp, err := req.Delete(url) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + result := parseHTTPResponse(L, resp) + L.Push(result) + return 1 +} + +// parseHTTPResponse parses HTTP response and returns Lua table +func parseHTTPResponse(L *lua.LState, resp *resty.Response) lua.LValue { + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + logger.GetLogger().Error(fmt.Sprintf("读取响应体失败: %v", err)) + // Return error table + result := L.NewTable() + result.RawSetString("error", lua.LString(err.Error())) + result.RawSetString("status", lua.LNumber(resp.StatusCode())) + return result + } + defer resp.Body.Close() + + // Log response for debugging + logger.GetLogger().Debug(fmt.Sprintf("HTTP响应 [%d]: %s", resp.StatusCode(), string(body))) + + // Try to parse as JSON first + var jsonData any + if err := json.Unmarshal(body, &jsonData); err == nil { + logger.GetLogger().Debug(fmt.Sprintf("JSON解析成功,类型: %T, 值: %+v", jsonData, jsonData)) + result := GoToLua(L, jsonData) + logger.GetLogger().Debug(fmt.Sprintf("转换为Lua后类型: %s", result.Type().String())) + + // Try to access a test field if it's a table + if tbl, ok := result.(*lua.LTable); ok { + testVal := tbl.RawGetString("access_token") + logger.GetLogger().Debug(fmt.Sprintf("测试访问access_token: %s (type: %s)", testVal.String(), testVal.Type().String())) + } + + return result + } + + // If not JSON, return as string in a table + logger.GetLogger().Debug(fmt.Sprintf("响应不是有效JSON: %v", err)) + result := L.NewTable() + result.RawSetString("body", lua.LString(string(body))) + result.RawSetString("status", lua.LNumber(resp.StatusCode())) + return result +} + +// luaTableToMap converts a Lua table to a Go map +func luaTableToMap(t *lua.LTable) map[string]any { + result := make(map[string]any) + t.ForEach(func(key, value lua.LValue) { + result[key.String()] = luaValueToGo(value) + }) + return result +} + +// luaValueToGo converts a Lua value to a Go value +func luaValueToGo(lv lua.LValue) any { + switch v := lv.(type) { + case *lua.LNilType: + return nil + case lua.LBool: + return bool(v) + case lua.LNumber: + return float64(v) + case lua.LString: + return string(v) + case *lua.LTable: + // Check if it's an array or map + if v.Len() > 0 { + // Array + arr := make([]any, 0, v.Len()) + for i := 1; i <= v.Len(); i++ { + arr = append(arr, luaValueToGo(v.RawGetInt(i))) + } + return arr + } + // Map + m := make(map[string]any) + v.ForEach(func(key, value lua.LValue) { + m[key.String()] = luaValueToGo(value) + }) + return m + default: + return nil + } +} + +// RegisterJSONModule registers the json module in Lua +func RegisterJSONModule(L *lua.LState) { + jsonMod := L.NewTable() + + // json.encode(table) + jsonMod.RawSetString("encode", L.NewFunction(luaJSONEncode)) + + // json.decode(string) + jsonMod.RawSetString("decode", L.NewFunction(luaJSONDecode)) + + L.SetGlobal("json", jsonMod) +} + +// luaJSONEncode implements json.encode(value) +func luaJSONEncode(L *lua.LState) int { + value := L.CheckAny(1) + + goValue := luaValueToGo(value) + jsonBytes, err := json.Marshal(goValue) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + L.Push(lua.LString(string(jsonBytes))) + return 1 +} + +// luaJSONDecode implements json.decode(jsonString) +func luaJSONDecode(L *lua.LState) int { + jsonStr := L.CheckString(1) + + var data any + if err := json.Unmarshal([]byte(jsonStr), &data); err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + L.Push(GoToLua(L, data)) + return 1 +} + +// RegisterLogModule registers the log module in Lua +func RegisterLogModule(L *lua.LState) { + logMod := L.NewTable() + + logMod.RawSetString("debug", L.NewFunction(luaLogDebug)) + logMod.RawSetString("info", L.NewFunction(luaLogInfo)) + logMod.RawSetString("warn", L.NewFunction(luaLogWarn)) + logMod.RawSetString("error", L.NewFunction(luaLogError)) + + L.SetGlobal("log", logMod) +} + +// luaLogDebug implements log.debug(message, ...) +func luaLogDebug(L *lua.LState) int { + msg := formatLogMessage(L) + logger.GetLogger().Debug(msg) + return 0 +} + +// luaLogInfo implements log.info(message, ...) +func luaLogInfo(L *lua.LState) int { + msg := formatLogMessage(L) + logger.GetLogger().Info(msg) + return 0 +} + +// luaLogWarn implements log.warn(message, ...) +func luaLogWarn(L *lua.LState) int { + msg := formatLogMessage(L) + logger.GetLogger().Warn(msg) + return 0 +} + +// luaLogError implements log.error(message, ...) +func luaLogError(L *lua.LState) int { + msg := formatLogMessage(L) + logger.GetLogger().Error(msg) + return 0 +} + +// formatLogMessage formats log message from Lua arguments +func formatLogMessage(L *lua.LState) string { + n := L.GetTop() + if n == 0 { + return "" + } + + if n == 1 { + return L.CheckString(1) + } + + // Format string with arguments + format := L.CheckString(1) + args := make([]any, n-1) + for i := 2; i <= n; i++ { + args[i-2] = luaValueToGo(L.Get(i)) + } + return fmt.Sprintf(format, args...) +} + +// RegisterUtilModule registers utility functions in Lua +func RegisterUtilModule(L *lua.LState) { + utilMod := L.NewTable() + + // base64.encode(string) + utilMod.RawSetString("base64_encode", L.NewFunction(luaBase64Encode)) + + // base64.decode(string) + utilMod.RawSetString("base64_decode", L.NewFunction(luaBase64Decode)) + + // url.encode(string) + utilMod.RawSetString("url_encode", L.NewFunction(luaURLEncode)) + + // url.decode(string) + utilMod.RawSetString("url_decode", L.NewFunction(luaURLDecode)) + + L.SetGlobal("util", utilMod) +} + +// luaBase64Encode implements base64_encode(str) +func luaBase64Encode(L *lua.LState) int { + str := L.CheckString(1) + encoded := base64.StdEncoding.EncodeToString([]byte(str)) + L.Push(lua.LString(encoded)) + return 1 +} + +// luaBase64Decode implements base64_decode(str) +func luaBase64Decode(L *lua.LState) int { + str := L.CheckString(1) + decoded, err := base64.StdEncoding.DecodeString(str) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + L.Push(lua.LString(string(decoded))) + return 1 +} + +// luaURLEncode implements url_encode(str) +func luaURLEncode(L *lua.LState) int { + str := L.CheckString(1) + encoded := url.QueryEscape(str) + L.Push(lua.LString(encoded)) + return 1 +} + +// luaURLDecode implements url_decode(str) +func luaURLDecode(L *lua.LState) int { + str := L.CheckString(1) + decoded, err := url.QueryUnescape(str) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + L.Push(lua.LString(decoded)) + return 1 +} diff --git a/pkg/lua/engine.go b/pkg/lua/engine.go new file mode 100644 index 0000000..6dd599c --- /dev/null +++ b/pkg/lua/engine.go @@ -0,0 +1,307 @@ +package lua + +import ( + "errors" + "fmt" + "reflect" + + lua "github.com/yuin/gopher-lua" +) + +type Engine struct { + L *lua.LState +} + +func New() *Engine { + L := lua.NewState() + return &Engine{L: L} +} + +func (e *Engine) Close() { + e.L.Close() +} + +// RegisterAPI registers all Lua API modules (http, json, log, etc.) +func (e *Engine) RegisterAPI() { + RegisterHTTPModule(e.L) + RegisterJSONModule(e.L) + RegisterLogModule(e.L) + RegisterUtilModule(e.L) +} + +func (e *Engine) LoadFile(path string) error { + return e.L.DoFile(path) +} + +func (e *Engine) CallString(fn string, args ...lua.LValue) (string, error) { + L := e.L + + if err := L.CallByParam(lua.P{ + Fn: L.GetGlobal(fn), + NRet: 1, + Protect: true, + }, args...); err != nil { + return "", err + } + + ret := L.Get(-1) + L.Pop(1) + return ret.String(), nil +} + +func (e *Engine) CallStruct(fn string, out any, args ...lua.LValue) error { + L := e.L + + f := L.GetGlobal(fn) + if f.Type() != lua.LTFunction { + return fmt.Errorf("lua function %s not found", fn) + } + + if err := L.CallByParam(lua.P{ + Fn: f, + NRet: 1, + Protect: true, + }, args...); err != nil { + return err + } + + ret := L.Get(-1) + L.Pop(1) + + table, ok := ret.(*lua.LTable) + if !ok { + return fmt.Errorf("lua function %s must return table", fn) + } + + return luaTableToStruct(table, out) +} + +func luaTableToStruct(t *lua.LTable, out any) error { + v := reflect.ValueOf(out) + if v.Kind() != reflect.Ptr { + return errors.New("out must be pointer") + } + + v = v.Elem() + + for i := 0; i < v.NumField(); i++ { + field := v.Type().Field(i) + key := field.Tag.Get("lua") + if key == "" { + continue + } + + lv := t.RawGetString(key) + if lv == lua.LNil { + continue + } + + fv := v.Field(i) + if err := setFieldFromLua(fv, lv); err != nil { + return fmt.Errorf("field %s: %w", field.Name, err) + } + } + return nil +} + +// setFieldFromLua sets a reflect.Value field from a lua.LValue +func setFieldFromLua(fv reflect.Value, lv lua.LValue) error { + if !fv.CanSet() { + return errors.New("field cannot be set") + } + + switch fv.Kind() { + case reflect.String: + fv.SetString(lv.String()) + + case reflect.Bool: + fv.SetBool(lua.LVAsBool(lv)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if n, ok := lv.(lua.LNumber); ok { + fv.SetInt(int64(n)) + } + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if n, ok := lv.(lua.LNumber); ok { + fv.SetUint(uint64(n)) + } + + case reflect.Float32, reflect.Float64: + if n, ok := lv.(lua.LNumber); ok { + fv.SetFloat(float64(n)) + } + + case reflect.Slice: + if ltable, ok := lv.(*lua.LTable); ok { + return setSliceFromLuaTable(fv, ltable) + } + + case reflect.Map: + if ltable, ok := lv.(*lua.LTable); ok { + return setMapFromLuaTable(fv, ltable) + } + + case reflect.Struct: + if ltable, ok := lv.(*lua.LTable); ok { + return luaTableToStruct(ltable, fv.Addr().Interface()) + } + + case reflect.Ptr: + if fv.IsNil() { + fv.Set(reflect.New(fv.Type().Elem())) + } + return setFieldFromLua(fv.Elem(), lv) + } + + return nil +} + +// setSliceFromLuaTable converts a Lua table to a Go slice +func setSliceFromLuaTable(fv reflect.Value, t *lua.LTable) error { + length := t.Len() + slice := reflect.MakeSlice(fv.Type(), length, length) + + for i := 1; i <= length; i++ { + lv := t.RawGetInt(i) + if lv == lua.LNil { + continue + } + + elem := slice.Index(i - 1) + if err := setFieldFromLua(elem, lv); err != nil { + return fmt.Errorf("index %d: %w", i, err) + } + } + + fv.Set(slice) + return nil +} + +// setMapFromLuaTable converts a Lua table to a Go map +func setMapFromLuaTable(fv reflect.Value, t *lua.LTable) error { + mapType := fv.Type() + newMap := reflect.MakeMap(mapType) + + var convErr error + t.ForEach(func(key, value lua.LValue) { + if convErr != nil { + return + } + + // Convert key + k := reflect.New(mapType.Key()).Elem() + if err := setFieldFromLua(k, key); err != nil { + convErr = fmt.Errorf("key conversion: %w", err) + return + } + + // Convert value + v := reflect.New(mapType.Elem()).Elem() + if err := setFieldFromLua(v, value); err != nil { + convErr = fmt.Errorf("value conversion: %w", err) + return + } + + newMap.SetMapIndex(k, v) + }) + + if convErr != nil { + return convErr + } + + fv.Set(newMap) + return nil +} + +// GoToLua converts a Go value to a Lua value +func GoToLua(L *lua.LState, v any) lua.LValue { + if v == nil { + return lua.LNil + } + + val := reflect.ValueOf(v) + return goValueToLua(L, val) +} + +// goValueToLua converts a reflect.Value to lua.LValue +func goValueToLua(L *lua.LState, v reflect.Value) lua.LValue { + if !v.IsValid() { + return lua.LNil + } + + // Dereference pointers and interfaces + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + if v.IsNil() { + return lua.LNil + } + v = v.Elem() + } + + switch v.Kind() { + case reflect.Bool: + return lua.LBool(v.Bool()) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return lua.LNumber(v.Int()) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return lua.LNumber(v.Uint()) + + case reflect.Float32, reflect.Float64: + return lua.LNumber(v.Float()) + + case reflect.String: + return lua.LString(v.String()) + + case reflect.Slice, reflect.Array: + table := L.NewTable() + for i := 0; i < v.Len(); i++ { + table.Append(goValueToLua(L, v.Index(i))) + } + return table + + case reflect.Map: + table := L.NewTable() + iter := v.MapRange() + for iter.Next() { + key := iter.Key() + val := goValueToLua(L, iter.Value()) + + // For string keys, use RawSetString for better Lua compatibility + if key.Kind() == reflect.String { + table.RawSetString(key.String(), val) + } else { + // For non-string keys, convert and use RawSet + luaKey := goValueToLua(L, key) + table.RawSet(luaKey, val) + } + } + return table + + case reflect.Struct: + table := L.NewTable() + t := v.Type() + for i := 0; i < v.NumField(); i++ { + field := t.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get lua tag or use field name + luaKey := field.Tag.Get("lua") + if luaKey == "" { + luaKey = field.Name + } + + fieldValue := goValueToLua(L, v.Field(i)) + table.RawSetString(luaKey, fieldValue) + } + return table + + default: + return lua.LNil + } +} diff --git a/scripts/oauth/gitea.lua b/scripts/oauth/gitea.lua new file mode 100644 index 0000000..65eabb0 --- /dev/null +++ b/scripts/oauth/gitea.lua @@ -0,0 +1,94 @@ +-- Gitea OAuth Provider +-- 使用完整的Lua API功能 + +-- 1. 生成授权URL +function auth_url(cfg, state) + log.info("Gitea: 生成授权URL, state=%s", state) + + -- 构建查询参数 + local params = { + "client_id=" .. util.url_encode(cfg.client_id), + "redirect_uri=" .. util.url_encode(cfg.redirect_uri), + "response_type=code", + "scope=read:user", + "state=" .. util.url_encode(state) + } + + -- 拼接URL + local url = cfg.base_url .. cfg.authorize_url .. "?" .. table.concat(params, "&") + log.debug("Gitea: 授权URL生成完成") + + return url +end + +-- 2. 交换授权码获取访问令牌 +function exchange(cfg, code) + log.info("Gitea: 开始交换授权码获取访问令牌") + + -- 准备请求数据 + local data = { + client_id = cfg.client_id, + client_secret = cfg.client_secret, + code = code, + grant_type = "authorization_code", + redirect_uri = cfg.redirect_uri + } + + local url = cfg.base_url .. cfg.token_url + log.debug("Gitea: 请求token URL: %s", url) + + -- 发送POST请求 + local resp, err = http.post(url, data, { + Accept = "application/json" + }) + + if err then + log.error("Gitea: 获取token失败: %s", err) + return nil, err + end + + -- 检查响应 + if not resp.access_token then + log.error("Gitea: 响应中没有access_token") + log.debug("Gitea: 响应内容: %s", json.encode(resp)) + return "" + end + + log.info("Gitea: 成功获取访问令牌") + return resp.access_token +end + +-- 3. 获取用户信息 +function user_info(cfg, token) + log.info("Gitea: 开始获取用户信息") + + local url = cfg.base_url .. cfg.user_url + log.debug("Gitea: 请求用户信息URL: %s", url) + + -- 发送GET请求 + local resp, err = http.get(url, { + Authorization = "Bearer " .. token, + Accept = "application/json" + }) + + if err then + log.error("Gitea: 获取用户信息失败: %s", err) + return nil, err + end + + -- 验证响应 + if not resp.id then + log.error("Gitea: 用户信息响应格式错误") + log.debug("Gitea: 响应内容: %s", json.encode(resp)) + return {} + end + + log.info("Gitea: 成功获取用户信息, 用户名=%s", resp.login or "unknown") + + -- 返回标准化的用户信息 + return { + uid = tostring(resp.id), + username = resp.login, + avatar = resp.avatar_url + } +end