【Go】基于GoFiber从零开始搭建一个GoWeb后台管理系统(二)日志输出中间件、校验token中间件、配置路由、基础工具函数。

2023-12-13 10:06:55

上一篇:【Go】基于GoFiber从零开始搭建一个GoWeb后台管理系统(一)搭建项目

在这里插入图片描述

上一篇我们搭好了项目框架,那接下来就可以开始写业务了,然后在写业务之前,我们还需要考虑几个问题:

  1. 首先一个就是日志问题,控制台日志输出、日志输出到文件;
  2. 第二个就是token校验,每次请求服务器时,在执行接口业务之前,需要先校验请求头的token是否有效;
  3. 第三个就是请求成功后,返回回去的参数需要统一。

日志

我们先来说说日志,我们需要控制全局日志输出,输出打印时间、请求的接口、错误信息、SQL语句等这些信息。除了将日志输出到控制台之外,我们还需要将日志输出到文件,按天生成,像这样:

在这里插入图片描述

全局日志我们可以用中间件来实现,中间件里输出的日志就是请求的接口、接口执行报错的错误信息等。

除了在中间件里拦截一些信息作为日志输出外,我们还可以自定义日志格式化输出。比如说你在代码的哪个位置想输出一句话,用fmt.Print() 也是可以直接打印的,但是我想要进行格式输出打印时间、具体的文件、具体的某一行,这就需要我们去自定义实现输出格式了。

还有就是除了上面两个日志输出,我操作数据库的时候,我也希望将具体的sql打印出来。

代码

日志输出中间件

middleware.go

// 统一的日志格式化输出中间件
func LoggerPrint() fiber.Handler {
	return func(c *fiber.Ctx) error {
		start := time.Now()
		// 处理请求
		err := c.Next()
		var logMessage string
		if err != nil {
			// 记录日志
			logMessage = fmt.Sprintf("[%s] %s %s - %s ==> [Error] %s\n", start.Format("2006-01-02 15:04:05"), c.Method(), c.Path(), time.Since(start), err.Error())
		} else {
			// 记录日志
			logMessage = fmt.Sprintf("[%s] %s %s - %s\n", start.Format("2006-01-02 15:04:05"), c.Method(), c.Path(), time.Since(start))
		}
		// 输出到控制台
		fmt.Print(logMessage)
		// 输出到文件
		filename := "logs/" + time.Now().Format("2006-01-02") + ".log"
		file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			mylog.Error("日志文件的打开错误 : " + err.Error())
			return err
		}
		defer file.Close()
		if _, err := file.WriteString(logMessage); err != nil {
			mylog.Error("写入日志文件错误 : " + err.Error())
		}
		return err
	}
}

自定义格式日志输出

log.go

package mylog

import (
	"fmt"
	"gorm.io/gorm/logger"
	"os"
	"path/filepath"
	"runtime"
	"strings"
	"time"
)

func Info(msg string) {
	logOut(msg, "Info")
}

func Debug(msg string) {
	logOut(msg, "Debug")
}

func Error(msg string) {
	logOut(msg, "Error")
}

// 日志格式化输出
func LogOut(msg string) {
	// 替换掉彩色打印符号
	msg = strings.ReplaceAll(msg, logger.Reset, "")
	msg = strings.ReplaceAll(msg, logger.Red, "")
	msg = strings.ReplaceAll(msg, logger.Green, "")
	msg = strings.ReplaceAll(msg, logger.Yellow, "")
	msg = strings.ReplaceAll(msg, logger.Blue, "")
	msg = strings.ReplaceAll(msg, logger.Magenta, "")
	msg = strings.ReplaceAll(msg, logger.Cyan, "")
	msg = strings.ReplaceAll(msg, logger.White, "")
	msg = strings.ReplaceAll(msg, logger.BlueBold, "")
	msg = strings.ReplaceAll(msg, logger.MagentaBold, "")
	msg = strings.ReplaceAll(msg, logger.RedBold, "")
	msg = strings.ReplaceAll(msg, logger.YellowBold, "")
	logOutFile(msg) // 输出到文件
}

// 日志格式化输出
func logOut(msg string, level string) {
	start := time.Now()
	// 获取调用的文件和行号
	_, file, line, _ := runtime.Caller(2)
	file = filepath.Base(file)
	// 使用日志包记录日志,并包括级别、文件名和行号
	logMsg := fmt.Sprintf("[%s] - %s ==> [%s:%d] [%s] %s\n",
		start.Format("2006-01-02 15:04:05"),
		time.Since(start),
		file,
		line,
		level,
		msg,
	)
	fmt.Print(logMsg)  // 打印到控制台
	logOutFile(logMsg) // 输出到文件
}

// 日志输出到文件
func logOutFile(msg string) {
	// 输出到文件
	filename := "logs/" + time.Now().Format("2006-01-02") + ".log"
	file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
	if err != nil {
		fmt.Println("日志文件的打开错误 :", err)
	}
	defer file.Close()
	if _, err := file.WriteString(msg); err != nil {
		fmt.Println("写入日志文件错误 :", err)
	}
}

gorm SQL日志输出到文件

控制 Gorm 的 SQL 日志输出,在 application.go 控制就行,然后我上一篇是有把 application.go 的代码写出来的,这里直接在上一篇的基础上把下面的代码复制过去就行。

application.go

type Writer struct{}

// 自定义的sql日志输出(到文件)
func (w Writer) Printf(format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...) + "\n"
	fmt.Println(msg)  // 打印到控制台
	mylog.LogOut(msg) // 输出到文件
}

func LoadMysql() {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=%s",
		Config.Get("database.username"),
		Config.Get("database.password"),
		Config.Get("database.host"),
		Config.GetInt("database.port"),
		Config.Get("database.database"),
		Config.Get("database.timeout"))
	// 设置操作数据库的日志输出到文件
	mylogger := logger.New(
		Writer{},
		logger.Config{
			SlowThreshold: time.Second, // 慢 SQL 阈值
			LogLevel:      logger.Info, // Log level
			Colorful:      true,        // 允许彩色打印
		},
	)
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		// 跳过默认事务:为了确保数据一致性,GORM 会在事务里执行写入操作(创建、更新、删除)。如果没有这方面的要求,您可以在初始化时禁用它,这样可以获得60%的性能提升
		//SkipDefaultTransaction: true,
		//Logger: logger.Default.LogMode(logger.Info), // sql全局日志
		Logger: mylogger, // 使用自定义的日志输出
		NamingStrategy: schema.NamingStrategy{
			//TablePrefix:   "sys_",  // 表名前缀
			SingularTable: true, // 单数表名
			//NoLowerCase:   false, // 关闭小写转换
		},
	})
	if err != nil {
		fmt.Println("无法连接到MySQL :", err)
	}
	DB = db
}

// 省略其他代码......

你别说日志输出这一块当时纠结了我挺久的。。。

token校验中间件

校验token和请求的IP。我这个系统的token不是用的jwt生成的,就是一串随机字符串生成的。

middleware.go

// token校验
func CheckToken(c *fiber.Ctx) error {
	parsedIP := c.IP() // 获取用户请求的ip
	// 校验用户 IP 是否在白名单中
	if !IsIPInWhitelist(parsedIP) {
		return c.Status(http.StatusOK).JSON(config.Error("非法访问"))
	}
	path := c.Path() // 这里获取的是接口的完整路径,如:/sys/login、/sys/user/list
	// 排除指定接口,不校验token
	if path == "/sys/login" || path == "/sys/getKey" || path == "/sys/getCode" {
		return c.Next()
	}
	// 获取请求头中的 Token
	token := c.Get(config.TokenHeader)
	// 如果 Token 为空,返回未授权状态
	if token == "" {
		return c.Status(http.StatusOK).JSON(config.ErrorCode(1003, "用户未登录"))
	}
	// 校验携带的token在redis中是否存在
	val := config.RedisConn.HExists(config.CachePrefix+token, "token").Val()
	if !val {
		return c.Status(http.StatusOK).JSON(config.ErrorCode(1003, "用户未登录"))
	}
	// 刷新有效期
	v := model.GetCreateTime(token) // 获取token的创建时间
	//判断token的创建时间是否大于2小时,如果是的话则需要刷新token
	s := time.Now().Unix() - v
	hour := s / 1000 / (60 * 60)
	if hour > 2 {
		// TODO 获取当前用户信息,重新登录,生成新的token,将新token设置到响应头中
		user := model.GetLoginUser(token)
		// TODO 这里重新登录(会把旧登录删除),生成新的token
		splits := strings.Split(token, "_")
		var newToken string
		if len(splits) > 1 {
			newToken = user.Login(splits[0], config.TokenExpire)
		} else {
			newToken = user.Login("", config.TokenExpire)
		}
		if len(newToken) == 0 {
			return c.Status(http.StatusOK).JSON(config.Error("登录失败"))
		}
		token = newToken
		c.Response().Header.Set(config.TokenHeader, newToken) // 将新的token设置到请求头中
	}
	timeOut := model.GetTimeOut(token) // 获取token的过期时间
	if timeOut != -1 {                 // token没过期,过期时间不是-1的时候,每次请求都刷新过期时间
		model.UpdateTimeOut(token, config.TokenExpire)
	}
	return c.Next()
}

// 辅助函数:检查 IP 是否在白名单中
func IsIPInWhitelist(ip string) bool {
	parsedIP := net.ParseIP(ip)
	for _, allowedIP := range config.AuthHost {
		if allowedIP == "*" {
			return true
		}
		if strings.Contains(allowedIP, "*") {
			// 将 * 转换为正则表达式
			regexPattern := strings.ReplaceAll(allowedIP, "*", ".*")
			if match, _ := regexp.MatchString("^"+regexPattern+"$", ip); match {
				return true
			}
		} else {
			// 非通配符的精确匹配
			if parsedIP.Equal(net.ParseIP(allowedIP)) {
				return true
			}
		}
	}
	return false
}

定义baseModel、全局常量、统一返回参数等

这个 base_model.go 文件就是定义一些全局需要用到的 结构体、常量、返回参数 等。这些放在一起找起来、看起来比较方便,不用在各个文件跳来跳去。只要你代码写好注释、规划好哪一块连续放在一起、风格统一,就不会显得杂乱。

base_model.go

package config

import "time"

// ===================================== 公共常量 =====================================
const (
	CachePrefix       = "go-web:login:"                                                  // 缓存前缀
	ERROR_COUNT       = "go-web:errorCount:"                                             // 密码错误次数缓存key
	TokenHeader       = "go-web"                                                         // request请求头属性
	TokenExpire       = time.Second * 1800                                               // token默认有效期(单位秒)
	UNKNOWN_EXCEPTION = "未知异常"                                                           // 全局异常 未知异常
	PARENT_VIEW       = "ParentView"                                                     // ParentView组件标识
	InitPassword      = "123456"                                                         // 初始密码
	RandomCharset     = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" // 随机字符串
	RandomCaptcha     = "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"                               // 验证码字符串
	DATA_SCOPE        = "go-web:dataScope:"                                              // 数据范围缓存
)

// ==================================== 公共model ====================================

// 一个公共的model
type BaseModel struct {
	Id         string    `query:"id" json:"id" form:"id"`
	CreatorId  *string   `json:"creatorId" form:"creatorId"`
	CreateTime time.Time `json:"createTime" form:"createTime"`
	UpdateId   *string   `json:"updateId" form:"updateId"`
	UpdateTime *string   `json:"updateTime" form:"updateTime"`
	Token      string    `gorm:"-" json:"token" form:"token"` // token
}

// 统一的返回参数格式
type Result struct {
	Code    int    `json:"code"`    // 统一的返回码,0 成功 -1 失败
	Message string `json:"message"` // 统一的返回信息
	Data    any    `json:"data"`    // 统一的返回数据
}

// 统一的树形结构格式
/*type TreeVo struct {
	Id       string   `json:"id"`       // 统一的返回码,0 成功 -1 失败
	Label    string   `json:"label"`    // 统一的返回信息
	ParentId string   `json:"parentId"` // 统一的返回数据
	Children []TreeVo `json:"children"` // 子级数据
}*/

// 将list转为统一的树形结构格式
/*func ConvertToTreeVo(list []interface{}) []TreeVo {
	result := []TreeVo{}
	for _, t := range list {
		item, _ := t.(map[string]interface{})
		id, _ := item["id"].(string)
		parentId, _ := item["parentId"].(string)
		label, _ := item["name"].(string)
		children, _ := item["children"].([]interface{})
		tree := TreeVo{Id: id, ParentId: parentId, Label: label}
		tree.Children = ConvertToTreeVo(children)
		result = append(result, tree)
	}
	return result
}*/

// 请求成功的默认返回
func Success(obj any) Result {
	return Result{0, "ok", obj}
}

// 请求失败的默认返回,code默认为-1
func Error(message string) Result {
	return ErrorCode(-1, message)
}

// 请求失败的默认返回
func ErrorCode(code int, message string) Result {
	return Result{code, message, nil}
}

// 分页结构体封装
type PageInfo struct {
	List  any   `json:"list"`  // 返回结果
	Total int64 `json:"total"` // 返回总数
}

路由

在go中,我们写的controller接口是需要全部手动注册到路由中的,在路由中写具体的接口名,在controller中写接口的具体实现函数或方法。

然后在注册路由之前,我们需要先注册中间件。

router.go

package router

import (
	"fmt"
	"github.com/gofiber/fiber/v2"
	"go-web2/app/common/config"
	"go-web2/app/common/middleware"
	api "go-web2/app/controller/sys" // 这种写法是,如果包名重复了,我们可以在前面定义这个包名用于区分重复的包
)

// 初始化路由
func InitRouter() *fiber.App {
	app := fiber.New()
	// 初始化yml配置
	_, err := config.InitConfig()
	if err != nil {
		panic(fmt.Errorf("加载yml配置文件错误: %s \n", err))
	}
	// 中间件
	app.Use(middleware.LoggerPrint) // 使用日志中间件
	app.Use(middleware.SetHeader)     // 设置统一的请求头
	app.Use(middleware.CheckToken)    // 应用 token 校验中间件到需要验证的路由
	//app.common.config.StartScheduledTask() // 调用定时任务
	// 注册路由
	loginRouter(app) // 登录
	logRouter(app)   // 日志管理
	safeRouter(app)  // 安全设置
	userRouter(app)  // 用户管理
	deptRouter(app)  // 部门管理
	roleRouter(app)  // 角色管理
	menuRouter(app)  // 菜单管理
	dictRouter(app)  // 字典管理
	return app
}

// 登录路由
func loginRouter(app *fiber.App) {
	controller := api.LoginController{}
	login := app.Group("/sys")
	{
		login.Get("/getKey", controller.GetKey)    // 获取RSA公钥
		login.Get("/getCode", controller.GetCode)  // 获取验证码
		login.Post("/login", controller.Login)     // 用户登录
		login.Delete("/logout", controller.Logout) // 用户退出
	}
}

// 日志管理路由
func logRouter(app *fiber.App) {
	log := app.Group("/sys/log")
	{
		log.Get("/list", api.LogController{}.GetPage) // 日志列表
	}
}

// 安全设置路由
func safeRouter(app *fiber.App) {
	controller := api.SafeController{}
	safe := app.Group("/sys/safe")
	{
		safe.Get("/getSafeSet", controller.GetSafeSet) // 获取安全设置
		safe.Post("/update", controller.Update)        // 修改安全设置
	}
}

// 用户管理路由
func userRouter(app *fiber.App) {
	controller := api.UserController{}
	user := app.Group("/sys/user")
	{
		user.Get("/getLoginUser", controller.GetLoginUser)      // 获取当前登录的用户
		user.Get("/list", controller.GetPage)                   // 用户列表
		user.Get("/getById/:id", controller.GetById)            // 根据id获取用户
		user.Post("/insert", controller.Insert)                 // 新增用户
		user.Post("/update", controller.Update)                 // 修改用户
		user.Delete("/delete", controller.Delete)               // 删除用户
		user.Post("/updatePassword", controller.UpdatePassword) // 修改密码
		user.Post("/resetPassword", controller.ResetPassword)   // 重置密码
		user.Post("/upload", controller.Upload)                 // 上传头像
	}
}

// 部门管理路由
func deptRouter(app *fiber.App) {
	controller := api.DeptController{}
	dept := app.Group("/sys/dept")
	{
		dept.Get("/list", controller.GetList)             // 部门树列表
		dept.Get("/getById/:id", controller.GetById)      // 根据id获取部门
		dept.Post("/insert", controller.Insert)           // 新增部门
		dept.Post("/update", controller.Update)           // 修改部门
		dept.Delete("/delete/:id", controller.Delete)     // 删除部门
		dept.Get("/deptSelect", controller.GetSelectList) // 部门下拉树列表
	}
}

// 角色管理路由
func roleRouter(app *fiber.App) {
	controller := api.RoleController{}
	role := app.Group("/sys/role")
	{
		role.Get("/list", controller.GetPage)              // 角色列表
		role.Get("/getById/:id", controller.GetById)       // 根据id获取角色
		role.Get("/createRoleCode", controller.CreateCode) // 生成角色编码
		role.Post("/insert", controller.Insert)            // 新增角色
		role.Post("/update", controller.Update)            // 修改角色
		role.Post("/updateState", controller.UpdateState)  // 修改角色状态
		role.Delete("/delete", controller.Delete)          // 删除角色
		role.Get("/roleSelect", controller.GetSelectList)  // 角色下拉框
	}
}

// 菜单管理路由
func menuRouter(app *fiber.App) {
	controller := api.MenuController{}
	menu := app.Group("/sys/menu")
	{
		menu.Get("/list", controller.GetList)                      // 菜单列表
		menu.Get("/getRouters", controller.GetRouters)             // 路由列表
		menu.Get("/getById/:id", controller.GetById)               // 根据id获取菜单
		menu.Get("/roleMenuTree/:roleId", controller.RoleMenuTree) // 获取对应角色菜单列表树
		menu.Post("/insert", controller.Insert)                    // 新增菜单
		menu.Post("/update", controller.Update)                    // 修改菜单
		menu.Delete("/delete/:id", controller.Delete)              // 删除菜单
	}
}

// 字典管理路由
func dictRouter(app *fiber.App) {
	controller := api.DictController{}
	dict := app.Group("/sys/dict")
	{
		dict.Get("/typeList", controller.GetTypeList)         // 获取字段类型列表
		dict.Get("/list", controller.GetPage)                 // 字段项列表分页
		dict.Get("/getById/:id", controller.GetById)          // 根据id获取字段
		dict.Get("/createDictCode", controller.CreateCode)    // 生成字典代码
		dict.Get("/hasDictByName", controller.HasByName)      // 字典名称是否存在
		dict.Get("/hasDictByCode", controller.HasByCode)      // 字典代码是否存在
		dict.Post("/insert", controller.Insert)               // 新增字典
		dict.Post("/update", controller.Update)               // 修改字典
		dict.Delete("/deleteType/:id", controller.DeleteType) // 删除字典类型
		dict.Delete("/delete", controller.Delete)             // 删除字典
		dict.Get("/getByTypeCode", controller.GetByTypeCode)  // 根据字典类型代码获取字典项列表
	}
}

然后main.go就改成下面的写法:

package main

import (
	"fmt"
	"go-web2/app/common/config"
	"go-web2/app/common/util"
	"go-web2/app/router"
)

func main() {
	app := router.InitRouter()
	util.GenerateKeyPair() // 初始化RSA密钥对
	app.Listen(fmt.Sprintf(":%d", config.HTTPPort))
}

最后

middleware.go 中间件完整代码

package middleware

import (
	"fmt"
	"github.com/gofiber/fiber/v2"
	"go-web2/app/common/config"
	"go-web2/app/common/mylog"
	model "go-web2/app/model/sys"
	"net"
	"net/http"
	"os"
	"regexp"
	"strings"
	"time"
)

// ================================================= 中间件合集 =================================================

// 统一的日志格式化输出中间件
func LoggerPrint() fiber.Handler {
	return func(c *fiber.Ctx) error {
		start := time.Now()
		// 处理请求
		err := c.Next()
		var logMessage string
		if err != nil {
			// 记录日志
			logMessage = fmt.Sprintf("[%s] %s %s - %s ==> [Error] %s\n", start.Format("2006-01-02 15:04:05"), c.Method(), c.Path(), time.Since(start), err.Error())
		} else {
			// 记录日志
			logMessage = fmt.Sprintf("[%s] %s %s - %s\n", start.Format("2006-01-02 15:04:05"), c.Method(), c.Path(), time.Since(start))
		}
		// 输出到控制台
		fmt.Print(logMessage)
		// 输出到文件
		filename := "logs/" + time.Now().Format("2006-01-02") + ".log"
		file, err := os.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
		if err != nil {
			mylog.Error("日志文件的打开错误 : " + err.Error())
			return err
		}
		defer file.Close()
		if _, err := file.WriteString(logMessage); err != nil {
			mylog.Error("写入日志文件错误 : " + err.Error())
		}
		return err
	}
}

// 设置请求头
func SetHeader(c *fiber.Ctx) error {
	c.Set("Set-Cookie", "name=value; SameSite=Strict; cookiename=httponlyTest;Path=/;Domain=domainvalue;Max-Age=seconds;Secure;HTTPOnly")
	c.Set("Content-Security-Policy", "default-src 'self'; script-src 'self'; frame-ancestors 'self'; object-src 'none'")
	c.Set("Access-Control-Allow-Credentials", "true")
	c.Set("Referrer-Policy", "no-referrer")
	c.Set("X-XSS-Protection", "1; mode=block") //1; mode=block:启用XSS保护,并在检查到XSS攻击时,停止渲染页面
	c.Set("X-Content-Type-Options", "nosniff") //互联网上的资源有各种类型,通常浏览器会根据响应头的Content-Type字段来分辨它们的类型。通过这个响应头可以禁用浏览器的类型猜测行为
	c.Set("X-Frame-Options", "SAMEORIGIN")     //SAMEORIGIN:不允许被本域以外的页面嵌入
	c.Set("X-DNS-Prefetch-Control", "off")
	c.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
	c.Set("Cache-Control", "no-cache, no-store, must-revalidate")
	c.Set("Pragma", "no-cache")
	c.Set("Expires", "0")
	return c.Next()
}

// token校验
func CheckToken(c *fiber.Ctx) error {
	parsedIP := c.IP() // 获取用户请求的ip
	// 校验用户 IP 是否在白名单中
	if !IsIPInWhitelist(parsedIP) {
		return c.Status(http.StatusOK).JSON(config.Error("非法访问"))
	}
	path := c.Path() // 这里获取的是接口的完整路径,如:/sys/login、/sys/user/list
	// 排除指定接口,不校验token
	if path == "/sys/login" || path == "/sys/getKey" || path == "/sys/getCode" {
		return c.Next()
	}
	// 获取请求头中的 Token
	token := c.Get(config.TokenHeader)
	// 如果 Token 为空,返回未授权状态
	if token == "" {
		return c.Status(http.StatusOK).JSON(config.ErrorCode(1003, "用户未登录"))
	}
	// 校验携带的token在redis中是否存在
	val := config.RedisConn.HExists(config.CachePrefix+token, "token").Val()
	if !val {
		return c.Status(http.StatusOK).JSON(config.ErrorCode(1003, "用户未登录"))
	}
	// 刷新有效期
	v := model.GetCreateTime(token) // 获取token的创建时间
	//判断token的创建时间是否大于2小时,如果是的话则需要刷新token
	s := time.Now().Unix() - v
	hour := s / 1000 / (60 * 60)
	if hour > 2 {
		// TODO 获取当前用户信息,重新登录,生成新的token,将新token设置到响应头中
		user := model.GetLoginUser(token)
		// TODO 这里重新登录(会把旧登录删除),生成新的token
		splits := strings.Split(token, "_")
		var newToken string
		if len(splits) > 1 {
			newToken = user.Login(splits[0], config.TokenExpire)
		} else {
			newToken = user.Login("", config.TokenExpire)
		}
		if len(newToken) == 0 {
			return c.Status(http.StatusOK).JSON(config.Error("登录失败"))
		}
		token = newToken
		c.Response().Header.Set(config.TokenHeader, newToken) // 将新的token设置到请求头中
	}
	timeOut := model.GetTimeOut(token) // 获取token的过期时间
	if timeOut != -1 {                 // token没过期,过期时间不是-1的时候,每次请求都刷新过期时间
		model.UpdateTimeOut(token, config.TokenExpire)
	}
	return c.Next()
}

// 辅助函数:检查 IP 是否在白名单中
func IsIPInWhitelist(ip string) bool {
	parsedIP := net.ParseIP(ip)
	for _, allowedIP := range config.AuthHost {
		if allowedIP == "*" {
			return true
		}
		if strings.Contains(allowedIP, "*") {
			// 将 * 转换为正则表达式
			regexPattern := strings.ReplaceAll(allowedIP, "*", ".*")
			if match, _ := regexp.MatchString("^"+regexPattern+"$", ip); match {
				return true
			}
		} else {
			// 非通配符的精确匹配
			if parsedIP.Equal(net.ParseIP(allowedIP)) {
				return true
			}
		}
	}
	return false
}

基础工具 base_util.go

生成随机token、生成登录验证码、MD5加密、RSA非对称加密解密、盐值加密、中文首字母转大写、保存文件等等。

package util

import (
	"crypto"
	"crypto/md5"
	crand "crypto/rand" // 当出现两个相同的包名但父级包不同时,可以给引用起别名进行区分
	"crypto/rsa"
	"crypto/sha256"
	"crypto/x509"
	"encoding/base64"
	"encoding/hex"
	"encoding/pem"
	"github.com/gofiber/fiber/v2/log"
	"github.com/mojocn/base64Captcha"
	"github.com/mozillazg/go-pinyin"
	"github.com/pkg/errors"
	"go-web2/app/common/config"
	"golang.org/x/crypto/bcrypt"
	"image/color"
	"io"
	mrand "math/rand"
	"mime/multipart"
	"os"
	"path/filepath"
	"strings"
	"time"
)

// 生成随机字符串作为令牌
func GenerateRandomToken(length int) string {
	tokenBytes := make([]byte, length)
	mrand.NewSource(time.Now().UnixNano())
	for i := range tokenBytes {
		tokenBytes[i] = config.RandomCharset[mrand.Intn(len(config.RandomCharset))]
	}
	return string(tokenBytes)
}

// 验证码,第一个参数是验证码的上限个数(最多存多少个验证码),第二个参数是验证码的有效时间
var captchaStore = base64Captcha.NewMemoryStore(500, 1*time.Minute)

// 生成验证码
func GenerateCaptcha(length, width, height int) (lid string, lb64s string) {
	var driver base64Captcha.Driver
	var driverString base64Captcha.DriverString
	// 配置验证码信息
	captchaConfig := base64Captcha.DriverString{
		Height:          height,
		Width:           width,
		NoiseCount:      0,
		ShowLineOptions: 2 | 4,
		Length:          length,
		Source:          config.RandomCaptcha,
		BgColor: &color.RGBA{
			R: 3,
			G: 102,
			B: 214,
			A: 125,
		},
		Fonts: []string{"wqy-microhei.ttc"},
	}
	driverString = captchaConfig
	driver = driverString.ConvertFonts()
	captcha := base64Captcha.NewCaptcha(driver, captchaStore)
	lid, lb64s, _ = captcha.Generate()
	//fmt.Println(lid)
	//fmt.Println(lb64s)
	return
}

// 验证captcha是否正确
func CaptVerify(id string, capt string) bool {
	capt = strings.ToUpper(capt) // 默认将验证码转为大写
	// 第三个参数为true,表示校验完之后,这个验证码删除(不管校验是否通过)
	if captchaStore.Verify(id, capt, true) {
		return true
	} else {
		return false
	}
}

// md5加密
func MD5(v string) string {
	d := []byte(v)
	m := md5.New()
	m.Write(d)
	return hex.EncodeToString(m.Sum(nil))
}

var (
	PrivateKey *rsa.PrivateKey
	PublicKey  rsa.PublicKey
)

// 生成随机密钥对
func GenerateKeyPair() (*rsa.PrivateKey, *rsa.PublicKey) {
	var err any
	PrivateKey, err = rsa.GenerateKey(crand.Reader, 2048) //生成私钥
	if err != nil {
		panic(err)
	}
	PublicKey = PrivateKey.PublicKey //生成公钥
	return PrivateKey, &PublicKey
}

// 获取公钥
func GetPublicKey() string {
	publicKeyBytes := x509.MarshalPKCS1PublicKey(&PublicKey)
	pemBlock := &pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: publicKeyBytes,
	}
	return base64.StdEncoding.EncodeToString(pem.EncodeToMemory(pemBlock))
}

// RSA加密
func RSAEncrypt(str string) string {
	//根据公钥加密
	encryptedBytes, err := rsa.EncryptOAEP(sha256.New(), crand.Reader, &PublicKey, []byte(str), nil)
	if err != nil {
		panic(err)
		return ""
	}
	// 加密后进行base64编码
	encryptBase64 := base64.StdEncoding.EncodeToString(encryptedBytes)
	return encryptBase64
}

// RSA解密
func RSADecrypt(str string) string {
	// base64解码
	decodedBase64, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		panic(err)
		log.Error(err.Error())
		return ""
	}
	//根据私钥解密
	decryptBytes, err := PrivateKey.Decrypt(nil, decodedBase64, &rsa.OAEPOptions{Hash: crypto.SHA256})
	if err != nil {
		panic(err)
		log.Error(err.Error())
		return ""
	}
	return string(decryptBytes)
}

// 盐值加密(根据明文密码,获取密文)
func GetEncryptedPassword(password string) (string, error) {
	// 加密密码,使用 bcrypt 包当中的 GenerateFromPassword 方法,bcrypt.DefaultCost 代表使用默认加密成本
	encryptPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		panic(err)
		log.Error(err.Error())
		return "", err
	}
	return string(encryptPassword), nil
}

// 判断密码是否正确(根据明文和密文对比)plaintextPassword 明文 encryptedPassword 密文
func AuthenticatePassword(plaintextPassword string, encryptedPassword string) bool {
	// 使用 bcrypt 当中的 CompareHashAndPassword 对比密码是否正确,第一个参数为密文,第二个参数为明文
	err := bcrypt.CompareHashAndPassword([]byte(encryptedPassword), []byte(plaintextPassword))
	// 对比密码是否正确会返回一个异常,按照官方的说法是只要异常是 nil 就证明密码正确
	return err == nil
}

// 判断数组中是否包含指定元素
func IsContain(items interface{}, item interface{}) bool {
	switch items.(type) {
	case []int:
		intArr := items.([]int)
		for _, value := range intArr {
			if value == item.(int) {
				return true
			}
		}
	case []string:
		strArr := items.([]string)
		for _, value := range strArr {
			if value == item.(string) {
				return true
			}
		}
	default:
		return false
	}
	return false
}

// 中文转拼音大写(首字母)
func ConvertToPinyin(text string, p pinyin.Args) string {
	var initials []string
	for _, r := range text {
		if r >= 0x4e00 && r <= 0x9fff { // 判断字符是否为中文字符
			pinyinResult := pinyin.Pinyin(string(r), p)
			initials = append(initials, strings.ToUpper(string(pinyinResult[0][0][0])))
		} else if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9' || r == '_') { // 判断字符是否为英文字母或数字
			initials = append(initials, strings.ToUpper(string(r)))
		}
	}
	return strings.Join(initials, "")
}

// 上传保存文件 form 文件数据、relative 文件保存的相对路径、fileName 文件名称
func SaveFile(form *multipart.Form, relative, fileName string) error {
	// 获取文件数据
	file, err := form.File["file"][0].Open()
	if err != nil {
		err = errors.New("上传文件失败:" + err.Error())
		return err
	}
	defer file.Close()
	content, err := io.ReadAll(file)
	if err != nil {
		err = errors.New("上传文件失败:" + err.Error())
		return err
	}
	// 获取当前项目路径
	currentDir, err := os.Getwd()
	if err != nil {
		err = errors.New("获取项目路径失败:" + err.Error())
		return err
	}
	// 文件上传的绝对路径
	absolute := filepath.Join(currentDir, relative)
	// 创建目录
	err = os.MkdirAll(absolute, os.ModePerm)
	if err != nil {
		err = errors.New("上传文件失败:" + err.Error())
		return err
	}
	// 在当前项目路径中的 /upload/20231208/ 下创建新文件
	newFile, err := os.Create(filepath.Join(absolute, fileName))
	if err != nil {
		err = errors.New("上传文件失败:" + err.Error())
		return err
	}
	defer newFile.Close()
	// 将文件内容写入新文件
	_, err = newFile.Write(content)
	if err != nil {
		err = errors.New("上传文件失败:" + err.Error())
		return err
	}
	return nil
}

ok,以上就是本篇文章的全部内容了,等我更完这个项目的全部文章,我会放出完整代码的地址,欢迎大家多多点赞支持下,最后可以关注我不迷路~

文章来源:https://blog.csdn.net/weixin_43165220/article/details/134923682
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。