feat(*): go 后端项目脚手架
This commit is contained in:
18
pkg/common/rand.go
Normal file
18
pkg/common/rand.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Rand 生成指定长度的随机字符串,字符集为 [0-9a-zA-Z]
|
||||
func Rand(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[seededRand.Intn(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
35
pkg/common/response.go
Normal file
35
pkg/common/response.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Response[T any] struct {
|
||||
Code int `json:"code" example:"0"`
|
||||
Msg string `json:"msg" example:"success"`
|
||||
Data T `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func Succ(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response[interface{}]{
|
||||
Code: 0,
|
||||
Msg: "success",
|
||||
Data: data,
|
||||
})
|
||||
}
|
||||
|
||||
func Ok(c *gin.Context, msg string) {
|
||||
c.JSON(http.StatusOK, Response[interface{}]{
|
||||
Code: 0,
|
||||
Msg: msg,
|
||||
})
|
||||
}
|
||||
|
||||
func Fail(c *gin.Context, msg string) {
|
||||
c.JSON(http.StatusBadRequest, Response[interface{}]{
|
||||
Code: -1,
|
||||
Msg: msg,
|
||||
})
|
||||
}
|
||||
10
pkg/common/sha1.go
Normal file
10
pkg/common/sha1.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Sha1(pwd string) string {
|
||||
return hex.EncodeToString(sha1.New().Sum([]byte(pwd)))
|
||||
}
|
||||
130
pkg/config/config.go
Normal file
130
pkg/config/config.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/go-yaml/yaml"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/zhilv666/navsite/pkg/consts"
|
||||
)
|
||||
|
||||
var (
|
||||
globalViper *viper.Viper
|
||||
globalViperOnce sync.Once
|
||||
configMu sync.RWMutex
|
||||
cachedConfig *Config
|
||||
)
|
||||
|
||||
// NewConfig 初始化配置
|
||||
func NewConfig() *Config {
|
||||
globalViperOnce.Do(func() {
|
||||
globalViper = viper.New()
|
||||
})
|
||||
|
||||
configPath := loadConfigFile()
|
||||
if configPath == "" {
|
||||
log.Println("⚠️ 未找到配置文件,使用默认配置并导出为 config.yaml")
|
||||
cachedConfig = DefaultConfig()
|
||||
exportConfig(cachedConfig, "yaml")
|
||||
return cachedConfig
|
||||
}
|
||||
|
||||
conf, err := parseConfig()
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 解析配置失败,使用默认配置并导出为 config.yaml: %v", err)
|
||||
conf = DefaultConfig()
|
||||
exportConfig(conf, "yaml")
|
||||
}
|
||||
|
||||
cachedConfig = conf
|
||||
return conf
|
||||
}
|
||||
|
||||
// loadConfigFile 加载配置文件
|
||||
func loadConfigFile() string {
|
||||
files := []string{
|
||||
fmt.Sprintf("%s.yaml", consts.ConfigName),
|
||||
fmt.Sprintf("%s.yml", consts.ConfigName),
|
||||
fmt.Sprintf("%s.json", consts.ConfigName),
|
||||
}
|
||||
|
||||
for _, file := range files {
|
||||
if _, err := os.Stat(file); err == nil {
|
||||
ext := filepath.Ext(file)[1:]
|
||||
globalViper.SetConfigFile(file)
|
||||
globalViper.SetConfigType(ext)
|
||||
globalViper.AddConfigPath(".")
|
||||
|
||||
if err := globalViper.ReadInConfig(); err != nil {
|
||||
log.Printf("⚠️ 读取配置文件 %s 失败: %v", file, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Printf("✅ 成功加载配置文件: %s", file)
|
||||
return file
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parseConfig 将配置解析为结构体
|
||||
func parseConfig() (*Config, error) {
|
||||
configMu.Lock()
|
||||
defer configMu.Unlock()
|
||||
|
||||
var conf Config
|
||||
if err := globalViper.Unmarshal(&conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
// exportConfig 导出配置文件(默认保存为 config.yaml)
|
||||
func exportConfig(conf *Config, format string) {
|
||||
var (
|
||||
data []byte
|
||||
err error
|
||||
)
|
||||
|
||||
switch format {
|
||||
case "json":
|
||||
data, err = json.MarshalIndent(conf, "", " ")
|
||||
case "yaml", "yml", "":
|
||||
format = "yaml"
|
||||
data, err = yaml.Marshal(conf)
|
||||
default:
|
||||
log.Printf("⚠️ 不支持的导出格式: %s", format)
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Printf("⚠️ 导出 %s 格式配置失败: %v", format, err)
|
||||
return
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("%s.%s", consts.ConfigName, format)
|
||||
if err := os.WriteFile(filename, data, 0644); err != nil {
|
||||
log.Printf("⚠️ 写入配置文件 %s 失败: %v", filename, err)
|
||||
} else {
|
||||
log.Printf("✅ 已导出配置文件: %s", filename)
|
||||
}
|
||||
|
||||
log.Printf("📋 当前 %s 配置:\n%s", format, data)
|
||||
}
|
||||
|
||||
// GetConfig 获取当前配置(线程安全)
|
||||
func GetConfig() *Config {
|
||||
configMu.RLock()
|
||||
defer configMu.RUnlock()
|
||||
|
||||
if cachedConfig == nil {
|
||||
return DefaultConfig()
|
||||
}
|
||||
conf := *cachedConfig
|
||||
return &conf
|
||||
}
|
||||
76
pkg/config/types.go
Normal file
76
pkg/config/types.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/zhilv666/navsite/pkg/common"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Debug bool `mapstructure:"debug"`
|
||||
Port int `mapstructure:"port"`
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Level string `mapstructure:"level"`
|
||||
Filepath string `mapstructure:"filepath"`
|
||||
MaxSizeMB int `mapstructure:"max_size_mb"` // 单个日志文件最大(MB)
|
||||
MaxAgeDay int `mapstructure:"max_age_day"` // 日志文件最大保存天数
|
||||
Backups int `mapstructure:"backups"` // 保留的旧文件个数
|
||||
Compress bool `mapstructure:"compress"` // 是否压缩
|
||||
}
|
||||
|
||||
type Database struct {
|
||||
Driver string `mapstructure:"driver"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
DbName string `mapstructure:"db_name"`
|
||||
SqliteDbPath string `mapstructure:"sqlite_db_path"`
|
||||
}
|
||||
|
||||
type JWT struct {
|
||||
SecretKey string `mapstructure:"secret_key"`
|
||||
ExpireDurationHour time.Duration `mapstructure:"expire_duration_hour"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Server Server `mapstructure:"server"`
|
||||
Log Log `mapstructure:"log"`
|
||||
Database Database `mapstructure:"database"`
|
||||
JWT JWT `mapstructure:"jwt"`
|
||||
}
|
||||
|
||||
func DefaultConfig() *Config {
|
||||
logPath := filepath.Join("data", "log.log")
|
||||
dbPath := filepath.Join("data", "sqlite.db")
|
||||
return &Config{
|
||||
Server: Server{
|
||||
Debug: true,
|
||||
Port: 8080,
|
||||
},
|
||||
Log: Log{
|
||||
Level: "debug",
|
||||
Filepath: logPath,
|
||||
MaxSizeMB: 10,
|
||||
MaxAgeDay: 7,
|
||||
Backups: 3,
|
||||
Compress: true,
|
||||
},
|
||||
Database: Database{
|
||||
Driver: "sqlite",
|
||||
User: "",
|
||||
Password: "",
|
||||
Host: "",
|
||||
Port: 0,
|
||||
DbName: "",
|
||||
SqliteDbPath: dbPath,
|
||||
},
|
||||
JWT: JWT{
|
||||
SecretKey: common.Rand(16),
|
||||
ExpireDurationHour: 24,
|
||||
},
|
||||
}
|
||||
}
|
||||
29
pkg/consts/consts.go
Normal file
29
pkg/consts/consts.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package consts
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// 通用系统级常量
|
||||
const (
|
||||
AppName = "NavSite"
|
||||
AppAuthor = "zhilv666"
|
||||
AppVersion = "1.0.0"
|
||||
)
|
||||
|
||||
// 默认服务端口
|
||||
const DefaultPort = 8080
|
||||
|
||||
// 通用时间格式
|
||||
const (
|
||||
TimeFormatDate = "2006-01-02"
|
||||
TimeFormatDateTime = "2006-01-02 15:04:05"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fmt.Printf("🚀 启动 %s v%s by %s\n",
|
||||
AppName,
|
||||
AppVersion,
|
||||
AppAuthor,
|
||||
)
|
||||
}
|
||||
6
pkg/consts/paths.go
Normal file
6
pkg/consts/paths.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package consts
|
||||
|
||||
// 项目路径与文件定义
|
||||
var (
|
||||
ConfigName = "config"
|
||||
)
|
||||
6
pkg/consts/version.go
Normal file
6
pkg/consts/version.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package consts
|
||||
|
||||
var (
|
||||
BuildTime string
|
||||
GitCommit string
|
||||
)
|
||||
25
pkg/db/connect.go
Normal file
25
pkg/db/connect.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/driver/sqlserver"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func connectMySQL(dsn string) gorm.Dialector {
|
||||
return mysql.Open(dsn)
|
||||
}
|
||||
|
||||
func connectSQLite(filepath string) gorm.Dialector {
|
||||
return sqlite.Open(filepath + "?cache=shared&_fk=1&_driver=modernc.org/sqlite")
|
||||
}
|
||||
|
||||
func connectPostgres(dsn string) gorm.Dialector {
|
||||
return postgres.Open(dsn)
|
||||
}
|
||||
|
||||
func connectSQLServer(dsn string) gorm.Dialector {
|
||||
return sqlserver.Open(dsn)
|
||||
}
|
||||
64
pkg/db/db.go
Normal file
64
pkg/db/db.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/zhilv666/navsite/pkg/config"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
|
||||
func init() {
|
||||
if _, err := os.Stat("./data"); os.IsNotExist(err) {
|
||||
_ = os.Mkdir("./data", os.ModePerm)
|
||||
}
|
||||
}
|
||||
|
||||
func GetDB() *gorm.DB {
|
||||
if db == nil {
|
||||
panic("数据库未初始化,请先调用 db.InitDB()")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func InitDB(cfg config.Database, log *zap.Logger) {
|
||||
go func() {
|
||||
if err := recover(); err != nil {
|
||||
panic(fmt.Sprintf("Init DB Error: %v", err))
|
||||
}
|
||||
}()
|
||||
|
||||
var dsn string
|
||||
var dialector gorm.Dialector
|
||||
|
||||
switch cfg.Driver {
|
||||
case "mysql":
|
||||
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DbName)
|
||||
dialector = connectMySQL(dsn)
|
||||
case "postgres":
|
||||
dsn = fmt.Sprintf("host=%s port=%d user=%s dbname=%s password=%s sslmode=disable",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.DbName, cfg.Password)
|
||||
dialector = connectPostgres(dsn)
|
||||
case "sqlserver":
|
||||
dsn = fmt.Sprintf("sqlserver://%s:%s@%s:%d?database=%s",
|
||||
cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.DbName)
|
||||
dialector = connectSQLServer(dsn)
|
||||
case "sqlite", "sqlite3":
|
||||
fmt.Println(cfg.SqliteDbPath)
|
||||
dialector = connectSQLite(cfg.SqliteDbPath)
|
||||
default:
|
||||
panic(fmt.Errorf("unsupport database driver: %s", cfg.Driver))
|
||||
}
|
||||
var err error
|
||||
db, err = gorm.Open(dialector, &gorm.Config{
|
||||
Logger: NewZapGormLogger(log).LogMode(logger.Info),
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
81
pkg/db/gorm_logger.go
Normal file
81
pkg/db/gorm_logger.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
type ZapGormLogger struct {
|
||||
ZapLogger *zap.Logger
|
||||
LogLevel logger.LogLevel
|
||||
SlowThreshold time.Duration
|
||||
}
|
||||
|
||||
func NewZapGormLogger(zapLogger *zap.Logger) *ZapGormLogger {
|
||||
return &ZapGormLogger{
|
||||
ZapLogger: zapLogger,
|
||||
LogLevel: logger.Info,
|
||||
SlowThreshold: time.Second, // 1s 慢查询阈值
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ZapGormLogger) LogMode(level logger.LogLevel) logger.Interface {
|
||||
newlogger := *l
|
||||
newlogger.LogLevel = level
|
||||
return &newlogger
|
||||
}
|
||||
|
||||
func (l *ZapGormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= logger.Info {
|
||||
l.ZapLogger.Sugar().Infof(msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ZapGormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= logger.Warn {
|
||||
l.ZapLogger.Sugar().Warnf(msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ZapGormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
||||
if l.LogLevel >= logger.Error {
|
||||
l.ZapLogger.Sugar().Errorf(msg, data...)
|
||||
}
|
||||
}
|
||||
|
||||
func (l *ZapGormLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
|
||||
if l.LogLevel <= logger.Silent {
|
||||
return
|
||||
}
|
||||
elapsed := time.Since(begin)
|
||||
sql, rows := fc()
|
||||
|
||||
switch {
|
||||
case err != nil && l.LogLevel >= logger.Error:
|
||||
l.ZapLogger.Error("SQL Error",
|
||||
zap.Error(err),
|
||||
zap.String("sql", sql),
|
||||
zap.Int64("rows", rows),
|
||||
zap.Duration("elapsed", elapsed),
|
||||
zap.String("file", utils.FileWithLineNum()),
|
||||
)
|
||||
case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= logger.Warn:
|
||||
l.ZapLogger.Warn("Slow SQL",
|
||||
zap.Duration("elapsed", elapsed),
|
||||
zap.String("sql", sql),
|
||||
zap.Int64("rows", rows),
|
||||
zap.String("file", utils.FileWithLineNum()),
|
||||
)
|
||||
case l.LogLevel >= logger.Info:
|
||||
l.ZapLogger.Info("SQL",
|
||||
zap.String("sql", sql),
|
||||
zap.Int64("rows", rows),
|
||||
zap.Duration("elapsed", elapsed),
|
||||
zap.String("file", utils.FileWithLineNum()),
|
||||
)
|
||||
}
|
||||
}
|
||||
27
pkg/logger/func.go
Normal file
27
pkg/logger/func.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package logger
|
||||
|
||||
import "go.uber.org/zap"
|
||||
|
||||
func GetLogger() *zap.Logger {
|
||||
return logger
|
||||
}
|
||||
|
||||
func Debug(msg string, fields ...zap.Field) {
|
||||
logger.Debug(msg, fields...)
|
||||
}
|
||||
|
||||
func Info(msg string, fields ...zap.Field) {
|
||||
logger.Info(msg, fields...)
|
||||
}
|
||||
|
||||
func Warn(msg string, fields ...zap.Field) {
|
||||
logger.Warn(msg, fields...)
|
||||
}
|
||||
|
||||
func Error(msg string, fields ...zap.Field) {
|
||||
logger.Error(msg, fields...)
|
||||
}
|
||||
|
||||
func Sync() {
|
||||
_ = logger.Sync()
|
||||
}
|
||||
87
pkg/logger/logger.go
Normal file
87
pkg/logger/logger.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/zhilv666/navsite/pkg/config"
|
||||
"github.com/zhilv666/navsite/pkg/consts"
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
var logger *zap.Logger
|
||||
|
||||
func Init(log config.Log) {
|
||||
var zapLevel zapcore.Level
|
||||
|
||||
// 日志等级解析
|
||||
switch log.Level {
|
||||
case "debug":
|
||||
zapLevel = zap.DebugLevel
|
||||
case "info":
|
||||
zapLevel = zap.InfoLevel
|
||||
case "warning":
|
||||
zapLevel = zap.WarnLevel
|
||||
case "error":
|
||||
zapLevel = zap.ErrorLevel
|
||||
default:
|
||||
zapLevel = zap.InfoLevel
|
||||
}
|
||||
|
||||
// lumberjack 日志切割配置
|
||||
writeSyncer := zapcore.AddSync(&lumberjack.Logger{
|
||||
Filename: log.Filepath,
|
||||
MaxSize: log.MaxSizeMB,
|
||||
MaxBackups: log.Backups,
|
||||
MaxAge: log.MaxAgeDay,
|
||||
Compress: log.Compress,
|
||||
})
|
||||
|
||||
// 日志编码格式
|
||||
encoderConfigColor := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "Stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.CapitalColorLevelEncoder, // 彩色等级输出(终端)
|
||||
// EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
EncodeTime: timeEncoder,
|
||||
EncodeDuration: zapcore.SecondsDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
encoderConfig := zapcore.EncoderConfig{
|
||||
TimeKey: "time",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "Stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
// EncodeLevel: zapcore.CapitalColorLevelEncoder, // 彩色等级输出(终端)
|
||||
EncodeLevel: zapcore.CapitalLevelEncoder,
|
||||
EncodeTime: timeEncoder,
|
||||
EncodeDuration: zapcore.SecondsDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
|
||||
encoderConsole := zapcore.NewConsoleEncoder(encoderConfigColor)
|
||||
encoderJson := zapcore.NewJSONEncoder(encoderConfig)
|
||||
|
||||
core := zapcore.NewTee(
|
||||
zapcore.NewCore(encoderJson, writeSyncer, zapLevel),
|
||||
zapcore.NewCore(encoderConsole, zapcore.AddSync(os.Stdout), zapLevel))
|
||||
|
||||
logger = zap.New(core, zap.AddCaller(), zap.AddCallerSkip(1))
|
||||
zap.ReplaceGlobals(logger)
|
||||
}
|
||||
|
||||
// 时间格式
|
||||
func timeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
|
||||
enc.AppendString(t.Format(consts.TimeFormatDateTime))
|
||||
}
|
||||
81
pkg/utils/jwt.go
Normal file
81
pkg/utils/jwt.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/zhilv666/navsite/pkg/config"
|
||||
"github.com/zhilv666/navsite/pkg/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type MyCustomClaims struct {
|
||||
ID uint `json:"id"`
|
||||
Email string `json:"Email"`
|
||||
SsoID string `json:"sso_id"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateJwtToken 生成 JWT Token
|
||||
func GenerateJwtToken(id uint, email, sso_id string) string {
|
||||
jwtConf := config.GetConfig().JWT
|
||||
claims := MyCustomClaims{
|
||||
id,
|
||||
email,
|
||||
sso_id,
|
||||
jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(jwtConf.ExpireDurationHour * time.Hour)),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
ss, err := token.SignedString([]byte(jwtConf.SecretKey))
|
||||
if err != nil {
|
||||
logger.Error("generate jwt token error", zap.String("jwt", err.Error()))
|
||||
}
|
||||
return ss
|
||||
}
|
||||
|
||||
// ParseJwtToken 解析 JWT 并返回自定义 Claims,同时返回错误信息
|
||||
func ParseJwtToken(tokenString string) (*MyCustomClaims, error) {
|
||||
jwtConf := config.GetConfig().JWT
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return []byte(jwtConf.SecretKey), nil
|
||||
}, jwt.WithLeeway(5*time.Second))
|
||||
|
||||
if err != nil {
|
||||
logger.Error("parse jwt token error", zap.String("jwt", err.Error()))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*MyCustomClaims)
|
||||
if !ok || !token.Valid {
|
||||
logger.Error("invalid jwt claims")
|
||||
return nil, errors.New("jwt token invalid")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// ValidJwtToken 验证 JWT Token
|
||||
func ValidJwtToken(tokenString string) bool {
|
||||
jwtConf := config.GetConfig().JWT
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (any, error) {
|
||||
return []byte(jwtConf.SecretKey), nil
|
||||
}, jwt.WithLeeway(5*time.Second))
|
||||
|
||||
if err != nil {
|
||||
logger.Error("parse jwt token error", zap.String("jwt", err.Error()))
|
||||
return false
|
||||
}
|
||||
|
||||
_, ok := token.Claims.(*MyCustomClaims)
|
||||
if !ok || !token.Valid {
|
||||
logger.Error("invalid jwt claims")
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
25
pkg/utils/utils_test.go
Normal file
25
pkg/utils/utils_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zhilv666/navsite/pkg/config"
|
||||
"github.com/zhilv666/navsite/pkg/logger"
|
||||
"github.com/zhilv666/navsite/pkg/utils"
|
||||
)
|
||||
|
||||
func TestJwt(t *testing.T) {
|
||||
conf := config.NewConfig()
|
||||
logger.Init(conf.Log)
|
||||
toekn := utils.GenerateJwtToken(11, "zhilv666@qq.com", "")
|
||||
t.Log("jwt token: ", toekn)
|
||||
|
||||
claimas, err := utils.ParseJwtToken("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6MTEsIkVtYWlsIjoiemhpbHY2NjZAcXEuY29tIiwic3NvX2lkIjoiIiwiZXhwIjoxNzYyMDkxNzYyfQ.bQeIyXvkOExxD4DAy5Eyjgwj9FbjE-AO6FCLF-YFGVA")
|
||||
if err != nil {
|
||||
t.Log("parse jwt token err")
|
||||
return
|
||||
}
|
||||
t.Log(claimas)
|
||||
|
||||
t.Log(utils.ValidJwtToken("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpZCI6MTEsIkVtYWlsIjoiemhpbHY2NjZAcXEuY29tIiwic3NvX2lkIjoiIiwiZXhwIjoxNzYyMDkyNTIxfQ.QynKGZmUSOXGgVVsqf-IMYBb11UPC6DT56p1UaNgHC0"))
|
||||
}
|
||||
Reference in New Issue
Block a user