Gin
gin框架路由详解
gin框架使用的是定制版本的httprouter,其路由的原理是大量使用公共前缀的树结构,它基本上是一个紧凑的Trie tree(或者只是Radix Tree)。具有公共前缀的节点也共享一个公共父节点。
Radix Tree
基数树(Radix Tree)又称为PAT位树(Patricia Trie or crit bit tree),是一种更节省空间的前缀树(Trie Tree)。对于基数树的每个节点,如果该节点是唯一的子树的话,就和父节点合并。下图为一个基数树示例:
Radix Tree
可以被认为是一棵简洁版的前缀树。我们注册路由的过程就是构造前缀树的过程,具有公共前缀的节点也共享一个公共父节点。假设我们现在注册有以下路由信息:
1
2
3
4
5
6
7
8
9
10
|
r := gin.Default()
r.GET("/", func1)
r.GET("/search/", func2)
r.GET("/support/", func3)
r.GET("/blog/", func4)
r.GET("/blog/:post/", func5)
r.GET("/about-us/", func6)
r.GET("/about-us/team/", func7)
r.GET("/contact/", func8)
|
那么我们会得到一个GET
方法对应的路由树,具体结构如下:
1
2
3
4
5
6
7
8
9
10
11
|
Priority Path Handle
9 \ *<1>
3 ├s nil
2 |├earch\ *<2>
1 |└upport\ *<3>
2 ├blog\ *<4>
1 | └:post nil
1 | └\ *<5>
2 ├about-us\ *<6>
1 | └team\ *<7>
1 └contact\ *<8>
|
上面最右边那一列每个*<数字>
表示Handle处理函数的内存地址(一个指针)。从根节点遍历到叶子节点我们就能得到完整的路由表。
例如:blog/:post
其中:post
只是实际文章名称的占位符(参数)。与hash-maps
不同,这种树结构还允许我们使用像:post
参数这种动态部分,因为我们实际上是根据路由模式进行匹配,而不仅仅是比较哈希值。
由于URL路径具有层次结构,并且只使用有限的一组字符(字节值),所以很可能有许多常见的前缀。这使我们可以很容易地将路由简化为更小的问题。此外,路由器为每种请求方法管理一棵单独的树。一方面,它比在每个节点中都保存一个method-> handle map更加节省空间,它还使我们甚至可以在开始在前缀树中查找之前大大减少路由问题。
为了获得更好的可伸缩性,每个树级别上的子节点都按Priority(优先级)
排序,其中优先级(最左列)就是在子节点(子节点、子子节点等等)中注册的句柄的数量。这样做有两个好处:
- 首先优先匹配被大多数路由路径包含的节点。这样可以让尽可能多的路由快速被定位。
- 类似于成本补偿。最长的路径可以被优先匹配,补偿体现在最长的路径需要花费更长的时间来定位,如果最长路径的节点能被优先匹配(即每次拿子节点都命中),那么路由匹配所花的时间不一定比短路径的路由长。下面展示了节点(每个
-
可以看做一个节点)匹配的路径:从左到右,从上到下。
1
2
3
4
5
6
7
|
├------------
├---------
├-----
├----
├--
├--
└-
|
路由树节点
路由树是由一个个节点构成的,gin框架路由树的节点由node
结构体表示,它有以下字段:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
// tree.go
type node struct {
// 节点路径,比如上面的s,earch,和upport
path string
// 和children字段对应, 保存的是分裂的分支的第一个字符
// 例如search和support, 那么s节点的indices对应的"eu"
// 代表有两个分支, 分支的首字母分别是e和u
indices string
// 儿子节点
children []*node
// 处理函数链条(切片)
handlers HandlersChain
// 优先级,子节点、子子节点等注册的handler数量
priority uint32
// 节点类型,包括static, root, param, catchAll
// static: 静态节点(默认),比如上面的s,earch等节点
// root: 树的根节点
// catchAll: 有*匹配的节点
// param: 参数节点
nType nodeType
// 路径上最大参数个数
maxParams uint8
// 节点是否是参数节点,比如上面的:post
wildChild bool
// 完整路径
fullPath string
}
|
请求方法树
在gin的路由中,每一个HTTP Method
(GET、POST、PUT、DELETE…)都对应了一棵 radix tree
,我们注册路由的时候会调用下面的addRoute
函数:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
// gin.go
func (engine *Engine) addRoute(method, path string, handlers HandlersChain) {
// liwenzhou.com...
// 获取请求方法对应的树
root := engine.trees.get(method)
if root == nil {
// 如果没有就创建一个
root = new(node)
root.fullPath = "/"
engine.trees = append(engine.trees, methodTree{method: method, root: root})
}
root.addRoute(path, handlers)
}
|
从上面的代码中我们可以看到在注册路由的时候都是先根据请求方法获取对应的树,也就是gin框架会为每一个请求方法创建一棵对应的树。只不过需要注意到一个细节是gin框架中保存请求方法对应树关系并不是使用的map而是使用的切片,engine.trees
的类型是methodTrees
,其定义如下:
1
2
3
4
5
6
|
type methodTree struct {
method string
root *node
}
type methodTrees []methodTree // slice
|
而获取请求方法对应树的get方法定义如下:
1
2
3
4
5
6
7
8
|
func (trees methodTrees) get(method string) *node {
for _, tree := range trees {
if tree.method == method {
return tree.root
}
}
return nil
}
|
为什么使用切片而不是map来存储请求方法->树
的结构呢?我猜是出于节省内存的考虑吧,毕竟HTTP请求方法的数量是固定的,而且常用的就那几种,所以即使使用切片存储查询起来效率也足够了。顺着这个思路,我们可以看一下gin框架中engine
的初始化方法中,确实对tress
字段做了一次内存申请:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
func New() *Engine {
debugPrintWARNINGNew()
engine := &Engine{
RouterGroup: RouterGroup{
Handlers: nil,
basePath: "/",
root: true,
},
// liwenzhou.com ...
// 初始化容量为9的切片(HTTP1.1请求方法共9种)
trees: make(methodTrees, 0, 9),
// liwenzhou.com...
}
engine.RouterGroup.engine = engine
engine.pool.New = func() interface{} {
return engine.allocateContext()
}
return engine
}
|
注册路由
注册路由的逻辑主要有addRoute
函数和insertChild
方法。
addRoute
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
|
// tree.go
// addRoute 将具有给定句柄的节点添加到路径中。
// 不是并发安全的
func (n *node) addRoute(path string, handlers HandlersChain) {
fullPath := path
n.priority++
numParams := countParams(path) // 数一下参数个数
// 空树就直接插入当前节点
if len(n.path) == 0 && len(n.children) == 0 {
n.insertChild(numParams, path, fullPath, handlers)
n.nType = root
return
}
parentFullPathIndex := 0
walk:
for {
// 更新当前节点的最大参数个数
if numParams > n.maxParams {
n.maxParams = numParams
}
// 找到最长的通用前缀
// 这也意味着公共前缀不包含“:”"或“*” /
// 因为现有键不能包含这些字符。
i := longestCommonPrefix(path, n.path)
// 分裂边缘(此处分裂的是当前树节点)
// 例如一开始path是search,新加入support,s是他们通用的最长前缀部分
// 那么会将s拿出来作为parent节点,增加earch和upport作为child节点
if i < len(n.path) {
child := node{
path: n.path[i:], // 公共前缀后的部分作为子节点
wildChild: n.wildChild,
indices: n.indices,
children: n.children,
handlers: n.handlers,
priority: n.priority - 1, //子节点优先级-1
fullPath: n.fullPath,
}
// Update maxParams (max of all children)
for _, v := range child.children {
if v.maxParams > child.maxParams {
child.maxParams = v.maxParams
}
}
n.children = []*node{&child}
// []byte for proper unicode char conversion, see #65
n.indices = string([]byte{n.path[i]})
n.path = path[:i]
n.handlers = nil
n.wildChild = false
n.fullPath = fullPath[:parentFullPathIndex+i]
}
// 将新来的节点插入新的parent节点作为子节点
if i < len(path) {
path = path[i:]
if n.wildChild { // 如果是参数节点
parentFullPathIndex += len(n.path)
n = n.children[0]
n.priority++
// Update maxParams of the child node
if numParams > n.maxParams {
n.maxParams = numParams
}
numParams--
// 检查通配符是否匹配
if len(path) >= len(n.path) && n.path == path[:len(n.path)] {
// 检查更长的通配符, 例如 :name and :names
if len(n.path) >= len(path) || path[len(n.path)] == '/' {
continue walk
}
}
pathSeg := path
if n.nType != catchAll {
pathSeg = strings.SplitN(path, "/", 2)[0]
}
prefix := fullPath[:strings.Index(fullPath, pathSeg)] + n.path
panic("'" + pathSeg +
"' in new path '" + fullPath +
"' conflicts with existing wildcard '" + n.path +
"' in existing prefix '" + prefix +
"'")
}
// 取path首字母,用来与indices做比较
c := path[0]
// 处理参数后加斜线情况
if n.nType == param && c == '/' && len(n.children) == 1 {
parentFullPathIndex += len(n.path)
n = n.children[0]
n.priority++
continue walk
}
// 检查路path下一个字节的子节点是否存在
// 比如s的子节点现在是earch和upport,indices为eu
// 如果新加一个路由为super,那么就是和upport有匹配的部分u,将继续分列现在的upport节点
for i, max := 0, len(n.indices); i < max; i++ {
if c == n.indices[i] {
parentFullPathIndex += len(n.path)
i = n.incrementChildPrio(i)
n = n.children[i]
continue walk
}
}
// 否则就插入
if c != ':' && c != '*' {
// []byte for proper unicode char conversion, see #65
// 注意这里是直接拼接第一个字符到n.indices
n.indices += string([]byte{c})
child := &node{
maxParams: numParams,
fullPath: fullPath,
}
// 追加子节点
n.children = append(n.children, child)
n.incrementChildPrio(len(n.indices) - 1)
n = child
}
n.insertChild(numParams, path, fullPath, handlers)
return
}
// 已经注册过的节点
if n.handlers != nil {
panic("handlers are already registered for path '" + fullPath + "'")
}
n.handlers = handlers
return
}
}
|
其实上面的代码很好理解,大家可以参照动画尝试将以下情形代入上面的代码逻辑,体味整个路由树构造的详细过程:
- 第一次注册路由,例如注册search
- 继续注册一条没有公共前缀的路由,例如blog
- 注册一条与先前注册的路由有公共前缀的路由,例如support
insertChild
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
// tree.go
func (n *node) insertChild(numParams uint8, path string, fullPath string, handlers HandlersChain) {
// 找到所有的参数
for numParams > 0 {
// 查找前缀直到第一个通配符
wildcard, i, valid := findWildcard(path)
if i < 0 { // 没有发现通配符
break
}
// 通配符的名称必须包含':' 和 '*'
if !valid {
panic("only one wildcard per path segment is allowed, has: '" +
wildcard + "' in path '" + fullPath + "'")
}
// 检查通配符是否有名称
if len(wildcard) < 2 {
panic("wildcards must be named with a non-empty name in path '" + fullPath + "'")
}
// 检查这个节点是否有已经存在的子节点
// 如果我们在这里插入通配符,这些子节点将无法访问
if len(n.children) > 0 {
panic("wildcard segment '" + wildcard +
"' conflicts with existing children in path '" + fullPath + "'")
}
if wildcard[0] == ':' { // param
if i > 0 {
// 在当前通配符之前插入前缀
n.path = path[:i]
path = path[i:]
}
n.wildChild = true
child := &node{
nType: param,
path: wildcard,
maxParams: numParams,
fullPath: fullPath,
}
n.children = []*node{child}
n = child
n.priority++
numParams--
// 如果路径没有以通配符结束
// 那么将有另一个以'/'开始的非通配符子路径。
if len(wildcard) < len(path) {
path = path[len(wildcard):]
child := &node{
maxParams: numParams,
priority: 1,
fullPath: fullPath,
}
n.children = []*node{child}
n = child // 继续下一轮循环
continue
}
// 否则我们就完成了。将处理函数插入新叶子中
n.handlers = handlers
return
}
// catchAll
if i+len(wildcard) != len(path) || numParams > 1 {
panic("catch-all routes are only allowed at the end of the path in path '" + fullPath + "'")
}
if len(n.path) > 0 && n.path[len(n.path)-1] == '/' {
panic("catch-all conflicts with existing handle for the path segment root in path '" + fullPath + "'")
}
// currently fixed width 1 for '/'
i--
if path[i] != '/' {
panic("no / before catch-all in path '" + fullPath + "'")
}
n.path = path[:i]
// 第一个节点:路径为空的catchAll节点
child := &node{
wildChild: true,
nType: catchAll,
maxParams: 1,
fullPath: fullPath,
}
// 更新父节点的maxParams
if n.maxParams < 1 {
n.maxParams = 1
}
n.children = []*node{child}
n.indices = string('/')
n = child
n.priority++
// 第二个节点:保存变量的节点
child = &node{
path: path[i:],
nType: catchAll,
maxParams: 1,
handlers: handlers,
priority: 1,
fullPath: fullPath,
}
n.children = []*node{child}
return
}
// 如果没有找到通配符,只需插入路径和句柄
n.path = path
n.handlers = handlers
n.fullPath = fullPath
}
|
insertChild
函数是根据path
本身进行分割,将/
分开的部分分别作为节点保存,形成一棵树结构。参数匹配中的:
和*
的区别是,前者是匹配一个字段而后者是匹配后面所有的路径。
路由匹配
我们先来看gin框架处理请求的入口函数ServeHTTP
:
1
2
3
4
5
6
7
8
9
10
11
12
13
|
// gin.go
func (engine *Engine) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// 这里使用了对象池
c := engine.pool.Get().(*Context)
// 这里有一个细节就是Get对象后做初始化
c.writermem.reset(w)
c.Request = req
c.reset()
engine.handleHTTPRequest(c) // 我们要找的处理HTTP请求的函数
engine.pool.Put(c) // 处理完请求后将对象放回池子
}
|
函数很长,这里省略了部分代码,只保留相关逻辑代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
// gin.go
func (engine *Engine) handleHTTPRequest(c *Context) {
// liwenzhou.com...
// 根据请求方法找到对应的路由树
t := engine.trees
for i, tl := 0, len(t); i < tl; i++ {
if t[i].method != httpMethod {
continue
}
root := t[i].root
// 在路由树中根据path查找
value := root.getValue(rPath, c.Params, unescape)
if value.handlers != nil {
c.handlers = value.handlers
c.Params = value.params
c.fullPath = value.fullPath
c.Next() // 执行函数链条
c.writermem.WriteHeaderNow()
return
}
// liwenzhou.com...
c.handlers = engine.allNoRoute
serveError(c, http.StatusNotFound, default404Body)
}
|
路由匹配是由节点的 getValue
方法实现的。getValue
根据给定的路径(键)返回nodeValue
值,保存注册的处理函数和匹配到的路径参数数据。
如果找不到任何处理函数,则会尝试TSR(尾随斜杠重定向)。
代码虽然很长,但还算比较工整。大家可以借助注释看一下路由查找及参数匹配的逻辑。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
// tree.go
type nodeValue struct {
handlers HandlersChain
params Params // []Param
tsr bool
fullPath string
}
// liwenzhou.com...
func (n *node) getValue(path string, po Params, unescape bool) (value nodeValue) {
value.params = po
walk: // Outer loop for walking the tree
for {
prefix := n.path
if path == prefix {
// 我们应该已经到达包含处理函数的节点。
// 检查该节点是否注册有处理函数
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return
}
if path == "/" && n.wildChild && n.nType != root {
value.tsr = true
return
}
// 没有找到处理函数 检查这个路径末尾+/ 是否存在注册函数
indices := n.indices
for i, max := 0, len(indices); i < max; i++ {
if indices[i] == '/' {
n = n.children[i]
value.tsr = (len(n.path) == 1 && n.handlers != nil) ||
(n.nType == catchAll && n.children[0].handlers != nil)
return
}
}
return
}
if len(path) > len(prefix) && path[:len(prefix)] == prefix {
path = path[len(prefix):]
// 如果该节点没有通配符(param或catchAll)子节点
// 我们可以继续查找下一个子节点
if !n.wildChild {
c := path[0]
indices := n.indices
for i, max := 0, len(indices); i < max; i++ {
if c == indices[i] {
n = n.children[i] // 遍历树
continue walk
}
}
// 没找到
// 如果存在一个相同的URL但没有末尾/的叶子节点
// 我们可以建议重定向到那里
value.tsr = path == "/" && n.handlers != nil
return
}
// 根据节点类型处理通配符子节点
n = n.children[0]
switch n.nType {
case param:
// find param end (either '/' or path end)
end := 0
for end < len(path) && path[end] != '/' {
end++
}
// 保存通配符的值
if cap(value.params) < int(n.maxParams) {
value.params = make(Params, 0, n.maxParams)
}
i := len(value.params)
value.params = value.params[:i+1] // 在预先分配的容量内扩展slice
value.params[i].Key = n.path[1:]
val := path[:end]
if unescape {
var err error
if value.params[i].Value, err = url.QueryUnescape(val); err != nil {
value.params[i].Value = val // fallback, in case of error
}
} else {
value.params[i].Value = val
}
// 继续向下查询
if end < len(path) {
if len(n.children) > 0 {
path = path[end:]
n = n.children[0]
continue walk
}
// ... but we can't
value.tsr = len(path) == end+1
return
}
if value.handlers = n.handlers; value.handlers != nil {
value.fullPath = n.fullPath
return
}
if len(n.children) == 1 {
// 没有找到处理函数. 检查此路径末尾加/的路由是否存在注册函数
// 用于 TSR 推荐
n = n.children[0]
value.tsr = n.path == "/" && n.handlers != nil
}
return
case catchAll:
// 保存通配符的值
if cap(value.params) < int(n.maxParams) {
value.params = make(Params, 0, n.maxParams)
}
i := len(value.params)
value.params = value.params[:i+1] // 在预先分配的容量内扩展slice
value.params[i].Key = n.path[2:]
if unescape {
var err error
if value.params[i].Value, err = url.QueryUnescape(path); err != nil {
value.params[i].Value = path // fallback, in case of error
}
} else {
value.params[i].Value = path
}
value.handlers = n.handlers
value.fullPath = n.fullPath
return
default:
panic("invalid node type")
}
}
// 找不到,如果存在一个在当前路径最后添加/的路由
// 我们会建议重定向到那里
value.tsr = (path == "/") ||
(len(prefix) == len(path)+1 && prefix[len(path)] == '/' &&
path == prefix[:len(prefix)-1] && n.handlers != nil)
return
}
}
|
gin框架中间件详解
gin框架涉及中间件相关有4个常用的方法,它们分别是c.Next()
、c.Abort()
、c.Set()
、c.Get()
。
中间件的注册
gin框架中的中间件设计很巧妙,我们可以首先从我们最常用的r := gin.Default()
的Default
函数开始看,它内部构造一个新的engine
之后就通过Use()
函数注册了Logger
中间件和Recovery
中间件:
1
2
3
4
5
6
|
func Default() *Engine {
debugPrintWARNINGDefault()
engine := New()
engine.Use(Logger(), Recovery()) // 默认注册的两个中间件
return engine
}
|
继续往下查看一下Use()
函数的代码:
1
2
3
4
5
6
|
func (engine *Engine) Use(middleware ...HandlerFunc) IRoutes {
engine.RouterGroup.Use(middleware...) // 实际上还是调用的RouterGroup的Use函数
engine.rebuild404Handlers()
engine.rebuild405Handlers()
return engine
}
|
从下方的代码可以看出,注册中间件其实就是将中间件函数追加到group.Handlers
中:
1
2
3
4
|
func (group *RouterGroup) Use(middleware ...HandlerFunc) IRoutes {
group.Handlers = append(group.Handlers, middleware...)
return group.returnObj()
}
|
而我们注册路由时会将对应路由的函数和之前的中间件函数结合到一起:
1
2
3
4
5
6
|
func (group *RouterGroup) handle(httpMethod, relativePath string, handlers HandlersChain) IRoutes {
absolutePath := group.calculateAbsolutePath(relativePath)
handlers = group.combineHandlers(handlers) // 将处理请求的函数与中间件函数结合
group.engine.addRoute(httpMethod, absolutePath, handlers)
return group.returnObj()
}
|
其中结合操作的函数内容如下,注意观察这里是如何实现拼接两个切片得到一个新切片的。
1
2
3
4
5
6
7
8
9
10
11
12
|
const abortIndex int8 = math.MaxInt8 / 2
func (group *RouterGroup) combineHandlers(handlers HandlersChain) HandlersChain {
finalSize := len(group.Handlers) + len(handlers)
if finalSize >= int(abortIndex) { // 这里有一个最大限制
panic("too many handlers")
}
mergedHandlers := make(HandlersChain, finalSize)
copy(mergedHandlers, group.Handlers)
copy(mergedHandlers[len(group.Handlers):], handlers)
return mergedHandlers
}
|
也就是说,我们会将一个路由的中间件函数和处理函数结合到一起组成一条处理函数链条HandlersChain
,而它本质上就是一个由HandlerFunc
组成的切片:
1
|
type HandlersChain []HandlerFunc
|
中间件的执行
我们在上面路由匹配的时候见过如下逻辑:
1
2
3
4
5
6
7
8
9
|
value := root.getValue(rPath, c.Params, unescape)
if value.handlers != nil {
c.handlers = value.handlers
c.Params = value.params
c.fullPath = value.fullPath
c.Next() // 执行函数链条
c.writermem.WriteHeaderNow()
return
}
|
其中c.Next()
就是很关键的一步,它的代码很简单:
1
2
3
4
5
6
7
|
func (c *Context) Next() {
c.index++
for c.index < int8(len(c.handlers)) {
c.handlers[c.index](c)
c.index++
}
}
|
从上面的代码可以看到,这里通过索引遍历HandlersChain
链条,从而实现依次调用该路由的每一个函数(中间件或处理请求的函数)。
我们可以在中间件函数中通过再次调用c.Next()
实现嵌套调用(func1中调用func2;func2中调用func3),
或者通过调用c.Abort()
中断整个调用链条,从当前函数返回。
1
2
3
|
func (c *Context) Abort() {
c.index = abortIndex // 直接将索引置为最大限制值,从而退出循环
}
|
c.Set()/c.Get()
c.Set()
和c.Get()
这两个方法多用于在多个函数之间通过c
传递数据的,比如我们可以在认证中间件中获取当前请求的相关信息(userID等)通过c.Set()
存入c
,然后在后续处理业务逻辑的函数中通过c.Get()
来获取当前请求的用户。c
就像是一根绳子,将该次请求相关的所有的函数都串起来了。
总结
- gin框架路由使用前缀树,路由注册的过程是构造前缀树的过程,路由匹配的过程就是查找前缀树的过程。
- gin框架的中间件函数和处理函数是以切片形式的调用链条存在的,我们可以顺序调用也可以借助
c.Next()
方法实现嵌套调用。
- 借助
c.Set()
和c.Get()
方法我们能够在不同的中间件函数中传递数据。
Gin连接mysql
驱动依赖
go get -u github.com/go-sql-driver/mysql
如何使用
里面的user password可改
注意defer得卸载err的下面,不可以写成这样
为什么不能写在这个位置?
因为如果出错了,我们获得的open会是一个nil,在最后调用他的close,会出现空指针异常。
所以我们应该在判断当err == nil才去关闭数据库。
Open函数实际上只是验证dsn参数是否正确,并不是真正和数据库连接。如果要检查数据源的名字是否有效,应该调用ping方法。
当然一般不可能写在main函数里面,一般应该使模块化的操作。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
package main
import (
"database/sql"
"fmt"
_ "github.com/go-sql-driver/mysql"
)
var db *sql.DB
//这里用了一个匿名参数,我们可以不用在函数中声明err,但是会自己返回这个err参数。非常的优雅。
func initMysql() (err error) {
dsn := "root:abc123456@tcp(127.0.0.1:3306)/go"
db, err = sql.Open("mysql", dsn) //这里就不能用 := 因为db是公共变量
if err != nil {
panic(err)
}
err = db.Ping()
if err != nil {
fmt.Printf("connect fail , err : %v\n", err)
panic(err)
}
fmt.Printf("connect successful")
return
}
func main() {
if err := initMysql();err!=nil{
fmt.Printf("connect fail")
}
defer db.Close()
}
|
两个常用配置
SetMaxOpenCoons() 设置和mysql的最大连接数
SetMaxIdleCoons() 设置连接池的最大空闲连接数
浅浅的研究一下源码
看一看驱动里面的源码
init方法
点进去Regist
此时可以发现是跳到了内置的标准库database里面
在源码中,regist的是map中的值,比如map[“mysql”] = 驱动
看一看open函数里做了什么事情
可以发现我们是调用的openDB函数,返回的一个DB,有一些基本信息
我们的用户名,密码都在c里面
增删改查
queryRow(查询一列)
1
2
3
4
5
6
7
8
9
10
|
func queryRowDemo() {
sql := "select id,name,age from test_user where id = ?"
var u user
err := db.QueryRow(sql, 1).Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Printf("scan failed , err : %v", err)
return
}
fmt.Printf("id: %v,name: %v,age: %v", u.id, u.name, u.age)
}
|
这边值得注意的是,在调用了queryRow之后一定要调用scan,不然不会释放连接。
query(查询多行)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
func queryMultiRowDemo() {
sql := "select id,name,age from test_user where id > ? "
query, err := db.Query(sql, 0)
if err != nil {
fmt.Printf("query failed err: %v\n", err)
return
}
defer query.Close()//这里需要关闭
for query.Next() {
var u user
err = query.Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Printf("scan failed err: %v\n", err)
return
}
fmt.Printf("id: %d,name: %s,age: %d\n", u.id, u.name, u.age)
}
}
|
为什么在中间需要close?不是在for中调用了scan,scan应该可以帮我们close嘛?因为不一定能够进去for循环。得手动调用close。
exec(插入和更新和删除)
插入
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
func insertRowDemo() {
sql := "insert into test_user(name,age)values(?,?)"
exec, err := db.Exec(sql, "peter", 99)
if err != nil {
fmt.Printf("insert failed err: %v\n", err)
return
}
id, err := exec.LastInsertId() //可以获得新插入的id
if err != nil {
fmt.Printf("get id failed err: %v\n", err)
return
}
fmt.Printf("insert success last id is %d \n", id)
}
|
更新
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
func updateRowDemo() {
sql := "update test_user set name = ? where id = ?"
exec, err := db.Exec(sql, "john", 3)
if err != nil {
fmt.Printf("update failed err: %v\n", err)
return
}
id, err := exec.RowsAffected() //可以影响的行数
if err != nil {
fmt.Printf("get row count failed err: %v\n", err)
return
}
fmt.Printf("update success row count is %d \n", id)
}
|
删除
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
func deleteRowDemo() {
sql := "delete from test_user where id = ?"
exec, err := db.Exec(sql, 3)
if err != nil {
fmt.Printf("delete failed err: %v\n", err)
return
}
id, err := exec.RowsAffected() //可以影响的行数
if err != nil {
fmt.Printf("get row count failed err: %v\n", err)
return
}
fmt.Printf("delete success row count is %d \n", id)
}
|
Mysql预处理与Sql注入
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
func prepareDemo() {
sql := "select id,name,age from test_user where id > ?"
prepare, err := db.Prepare(sql)
if err != nil {
fmt.Printf("prepare failed err: %v\n", err)
}
defer prepare.Close()
query, err := prepare.Query(0)
if err != nil {
fmt.Printf("query failed err: %v\n", err)
}
for query.Next() {
var u user
err = query.Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Printf("scan failed err: %v\n", err)
}
fmt.Printf("id: %d,name: %s,age: %d\n", u.id, u.name, u.age)
}
}
|
sql注入
试试sql注入吧
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
func sqlInjection(name string) {
sql := fmt.Sprintf("select id,name,age from test_user where name = %s", name)
fmt.Println(sql)
query, err := db.Query(sql)
if err != nil {
fmt.Printf("scan failed err: %v\n", err)
return
}
for query.Next() {
var u user
err = query.Scan(&u.id, &u.name, &u.age)
if err != nil {
fmt.Printf("scan failed err: %v\n", err)
}
fmt.Printf("id: %d,name: %s,age: %d\n", u.id, u.name, u.age)
}
}
|
Mysql的事务
sqlX 强大的工具
安装
go get github.com/jmoiron/sqlx
连接
1
2
3
4
5
6
7
8
9
10
11
12
13
|
var db1 *sqlx.DB
func initDB() (err error) {
dsn := "root:abc123456@tcp(127.0.0.1:3306)/go"
db1, err = sqlx.Connect("mysql", dsn)
if err != nil {
fmt.Printf("err : %v\n", err)
return
}
db.SetMaxOpenConns(200)
db.SetMaxIdleConns(10)
return
}
|
SQLX的基本使用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
type user1 struct {
Id int `db:"id"`
Name string `db:"name"`
Age int `db:"age"`
}
func queryRowDemo1() {
sql := "select id,name,age from test_user where id = ?"
var u user1
err := db1.Get(&u, sql, 1)
if err != nil {
fmt.Printf("get failed err: %v", err)
return
}
fmt.Printf("id: %v,name: %v,age: %v", u.Id, u.Name, u.Age)
}
|
注意结构体的值不可以写为小写,db1的get会通过反射给对象的属性赋值。如果为小写的话其他的包访问不到。
查询多个
1
2
3
4
5
6
7
8
9
10
11
|
func queryRowMultiDemo1() {
sql := "select id,name,age from test_user where id > ?"
var u []user1
err := db1.Select(&u, sql, 0)
if err != nil {
fmt.Printf("get failed err: %v", err)
return
}
fmt.Printf("users:%v", u)
}
|
增删改
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
// 插入数据
func insertRowDemo() {
sqlStr := "insert into user(name, age) values (?,?)"
ret, err := db.Exec(sqlStr, "沙河小王子", 19)
if err != nil {
fmt.Printf("insert failed, err:%v\n", err)
return
}
theID, err := ret.LastInsertId() // 新插入数据的id
if err != nil {
fmt.Printf("get lastinsert ID failed, err:%v\n", err)
return
}
fmt.Printf("insert success, the id is %d.\n", theID)
}
// 更新数据
func updateRowDemo() {
sqlStr := "update user set age=? where id = ?"
ret, err := db.Exec(sqlStr, 39, 6)
if err != nil {
fmt.Printf("update failed, err:%v\n", err)
return
}
n, err := ret.RowsAffected() // 操作影响的行数
if err != nil {
fmt.Printf("get RowsAffected failed, err:%v\n", err)
return
}
fmt.Printf("update success, affected rows:%d\n", n)
}
// 删除数据
func deleteRowDemo() {
sqlStr := "delete from user where id = ?"
ret, err := db.Exec(sqlStr, 6)
if err != nil {
fmt.Printf("delete failed, err:%v\n", err)
return
}
n, err := ret.RowsAffected() // 操作影响的行数
if err != nil {
fmt.Printf("get RowsAffected failed, err:%v\n", err)
return
}
fmt.Printf("delete success, affected rows:%d\n", n)
}
|
go-redis
1
|
go get -u github.com/go-redis/redis
|
连接redis
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
func initRedis() (err error) {
rdb = redis.NewClient(
&redis.Options{
Addr: "localhost:6379",
Password: "abc123456",
DB: 0,
PoolSize: 100, //连接池大小
})
_, err = rdb.Ping().Result()
if err != nil {
return err
}
return nil
}
func main() {
if err := initRedis(); err != nil {
fmt.Printf("err : %v", err)
return
}
fmt.Printf("connect success\n")
}
|
连接哨兵模式
基本使用
get/set
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
func redisDemo() {
err := rdb.Set("score", 100, 1).Err()
if err != nil {
fmt.Printf("set error err:%v\n", err)
}
result, err := rdb.Get("score").Result()
if err != nil {
fmt.Printf("get error err:%v\n", err)
return
}
fmt.Println(result)
s, err := rdb.Get("token").Result()
//优先判断错误是不是属于redis没有对应的key
if err == redis.Nil {
fmt.Printf("the key is not exist\n")
return
} else if err != nil {
fmt.Printf("err : %v", err)
return
} else {
fmt.Printf("%s\n", s)
}
}
|
hset,hmset,hget
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
func hgetDemo() (err error) {
rdb.HMSet("school", map[string]interface{}{
"name": "scuec",
"addr": "laofang",
"isLike": false,
})
err = rdb.HSet("user", "name", "phm").Err()
err = rdb.HSet("user", "age", "1").Err()
if err != nil {
fmt.Printf("err :%v\n", err)
return
}
//得到全部字段
result, err := rdb.HGetAll("user").Result()
if err == redis.Nil {
fmt.Printf("err: %v\n", err)
return err
} else if err != nil {
fmt.Printf("err: %v\n", err)
return err
} else {
fmt.Printf("result :%v\n", result)
}
val := rdb.HGet("user", "name").Val()
fmt.Printf("val :%v\n", val)
return
}
|
zset
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
|
func zSetDemo() {
key := "language_rank"
languages := []redis.Z{
redis.Z{Score: 100, Member: "Java"},
redis.Z{Score: 99, Member: "Golang"},
redis.Z{Score: 89, Member: "Python"},
redis.Z{Score: 80, Member: "C"},
redis.Z{Score: 0, Member: "C++"},
}
//language切片被打散一一add到zset中
result, err := rdb.ZAdd(key, languages...).Result()
if err != nil {
fmt.Printf("err :%v\n", err)
return
}
fmt.Printf("zadd %d success!", result)
//把Golang的分数+10
f, err := rdb.ZIncrBy(key, 10, "Golang").Result()
if err != nil {
fmt.Printf("err :%v\n", err)
return
}
fmt.Printf("new Score is %d\n", f)
//取最高的三个分数
zs, err := rdb.ZRevRangeWithScores(key, 0, 2).Result()
if err != nil {
fmt.Printf("err :%v\n", err)
return
}
for i, z := range zs {
fmt.Println(i, z.Member, z.Score)
}
//取90 - 100分的
ranges := redis.ZRangeBy{
Min: "90",
Max: "100",
}
strings, err := rdb.ZRevRangeByScoreWithScores(key, ranges).Result()
if err != nil {
fmt.Printf("err :%v\n", err)
return
}
if err != nil {
fmt.Printf("err :%v\n", err)
return
}
for i, z := range strings {
fmt.Println(i, z.Member, z.Score)
}
}
|
pipeline
事务
1
2
3
4
5
6
7
8
9
10
11
12
13
|
// 监视watch_count的值,并在值不变的前提下将其值+1
key := "watch_count"
err = client.Watch(func(tx *redis.Tx) error {
n, err := tx.Get(key).Int()
if err != nil && err != redis.Nil {
return err
}
_, err = tx.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Set(key, n+1, 0)
return nil
})
return err
}, key)
|
Zap日志库
Gologger
Go logger的优劣?
Zap的优点
使用
获得
go get -u go.uber.org/zap
配置zap
Logger
- 通过调用
zap.NewProduction()
/zap.NewDevelopment()
或者zap.Example()
创建一个Logger。
- 上面的每一个函数都将创建一个logger。唯一的区别在于它将记录的信息不同。例如production logger默认记录调用函数信息、日期和时间等。
- 通过Logger调用Info/Error等。
- 默认情况下日志都会打印到应用程序的console界面。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
package main
import (
"go.uber.org/zap"
"net/http"
)
var logger *zap.Logger
func main() {
InitLogg()
defer logger.Sync() //将logger的日志刷到磁盘
simpleHttp("www.google.com")
simpleHttp("http://www.google.com")
}
func InitLogg() {
logger, _ = zap.NewProduction()
}
func simpleHttp(url string) {
get, err := http.Get(url)
if err != nil {
logger.Error("Error fetching url..", zap.String("url", url), zap.Error(err))
} else {
logger.Info("success..", zap.String("statusCode", get.Status), zap.String("url", url))
get.Body.Close()
}
}
|
自定义配置项
将日志写入文件而不是终端
我们要做的第一个更改是把日志写入文件,而不是打印到应用程序控制台。
- 我们将使用
zap.New(…)
方法来手动传递所有配置,而不是使用像zap.NewProduction()
这样的预置方法来创建logger。
1
|
func New(core zapcore.Core, options ...Option) *Logger
|
zapcore.Core
需要三个配置——Encoder
,WriteSyncer
,LogLevel
。
1.Encoder:编码器(如何写入日志)。我们将使用开箱即用的NewJSONEncoder()
,并使用预先设置的ProductionEncoderConfig()
。
1
|
zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig())
|
2.WriterSyncer :指定日志将写到哪里去。我们使用zapcore.AddSync()
函数并且将打开的文件句柄传进去。
1
2
|
file, _ := os.Create("./logger.log")
writeSyncer := zapcore.AddSync(file)
|
3.Log Level:哪种级别的日志将被写入。
我们将修改上述部分中的Logger代码,并重写InitLogger()
方法。其余方法—main()
/SimpleHttpGet()
保持不变。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
func InitLogg() {
writerSync := getLogWriter()
encoding := getEncoder()
core := zapcore.NewCore(encoding, writerSync, zapcore.DebugLevel)
logger = zap.New(core)
}
func getEncoder() zapcore.Encoder {
return zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig())
}
func getLogWriter() zapcore.WriteSyncer {
f, _ := os.Create("./logger.log")
return zapcore.AddSync(f)
}
|
当使用这些修改过的logger配置调用上述部分的main()
函数时,以下输出将打印在文件——test.log
中。
1
2
3
4
|
{"level":"debug","ts":1572160754.994731,"msg":"Trying to hit GET request for www.sogo.com"}
{"level":"error","ts":1572160754.994982,"msg":"Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme \"\""}
{"level":"debug","ts":1572160754.994996,"msg":"Trying to hit GET request for http://www.sogo.com"}
{"level":"info","ts":1572160757.3755069,"msg":"Success! statusCode = 200 OK for URL http://www.sogo.com"}
|
将JSON Encoder更改为普通的Log Encoder
现在,我们希望将编码器从JSON Encoder更改为普通Encoder。为此,我们需要将NewJSONEncoder()
更改为NewConsoleEncoder()
。
1
|
return zapcore.NewConsoleEncoder(zap.NewProductionEncoderConfig())
|
当使用这些修改过的logger配置调用上述部分的main()
函数时,以下输出将打印在文件——test.log
中。
1
2
3
4
|
1.572161051846623e+09 debug Trying to hit GET request for www.sogo.com
1.572161051846828e+09 error Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme ""
1.5721610518468401e+09 debug Trying to hit GET request for http://www.sogo.com
1.572161052068744e+09 info Success! statusCode = 200 OK for URL http://www.sogo.com
|
更改时间编码并添加调用者详细信息
鉴于我们对配置所做的更改,有下面两个问题:
- 时间是以非人类可读的方式展示,例如1.572161051846623e+09
- 调用方函数的详细信息没有显示在日志中
我们要做的第一件事是覆盖默认的ProductionConfig()
,并进行以下更改:
- 修改时间编码器
- 在日志文件中使用大写字母记录日志级别
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
func getEncoder() zapcore.Encoder {
return zapcore.NewConsoleEncoder(
zapcore.EncoderConfig{
TimeKey: "ts",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
FunctionKey: zapcore.OmitKey,
MessageKey: "msg",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.LowercaseLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.SecondsDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
})
}
|
接下来,我们将修改zap logger代码,添加将调用函数信息记录到日志中的功能。为此,我们将在zap.New(..)
函数中添加一个Option
。
1
|
logger := zap.New(core, zap.AddCaller())
|
当使用这些修改过的logger配置调用上述部分的main()
函数时,以下输出将打印在文件——test.log
中。
1
2
3
4
|
2019-10-27T15:33:29.855+0800 DEBUG logic/temp2.go:47 Trying to hit GET request for www.sogo.com
2019-10-27T15:33:29.855+0800 ERROR logic/temp2.go:50 Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme ""
2019-10-27T15:33:29.856+0800 DEBUG logic/temp2.go:47 Trying to hit GET request for http://www.sogo.com
2019-10-27T15:33:30.125+0800 INFO logic/temp2.go:52 Success! statusCode = 200 OK for URL http://www.sogo.com
|
使用Lumberjack进行日志切割归档
这个日志程序中唯一缺少的就是日志切割归档功能。
Zap本身不支持切割归档日志文件
为了添加日志切割归档功能,我们将使用第三方库Lumberjack来实现。
安装
执行下面的命令安装Lumberjack
1
|
go get -u github.com/natefinch/lumberjack
|
zap logger中加入Lumberjack
要在zap中加入Lumberjack支持,我们需要修改WriteSyncer
代码。我们将按照下面的代码修改getLogWriter()
函数:
1
2
3
4
5
6
7
8
9
10
|
func getLogWriter() zapcore.WriteSyncer {
lumberJackLogger := &lumberjack.Logger{
Filename: "./test.log",
MaxSize: 10,
MaxBackups: 5,
MaxAge: 30,
Compress: false,
}
return zapcore.AddSync(lumberJackLogger)
}
|
Lumberjack Logger采用以下属性作为输入:
- Filename: 日志文件的位置
- MaxSize:在进行切割之前,日志文件的最大大小(以MB为单位)
- MaxBackups:保留旧文件的最大个数
- MaxAges:保留旧文件的最大天数
- Compress:是否压缩/归档旧文件
测试所有功能
最终,使用Zap/Lumberjack logger的完整示例代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
|
package main
import (
"net/http"
"github.com/natefinch/lumberjack"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
var sugarLogger *zap.SugaredLogger
func main() {
InitLogger()
defer sugarLogger.Sync()
simpleHttpGet("www.sogo.com")
simpleHttpGet("http://www.sogo.com")
}
func InitLogger() {
writeSyncer := getLogWriter()
encoder := getEncoder()
core := zapcore.NewCore(encoder, writeSyncer, zapcore.DebugLevel)
logger := zap.New(core, zap.AddCaller())
sugarLogger = logger.Sugar()
}
func getEncoder() zapcore.Encoder {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
return zapcore.NewConsoleEncoder(encoderConfig)
}
func getLogWriter() zapcore.WriteSyncer {
lumberJackLogger := &lumberjack.Logger{
Filename: "./test.log",
MaxSize: 1,
MaxBackups: 5,
MaxAge: 30,
Compress: false,
}
return zapcore.AddSync(lumberJackLogger)
}
func simpleHttpGet(url string) {
sugarLogger.Debugf("Trying to hit GET request for %s", url)
resp, err := http.Get(url)
if err != nil {
sugarLogger.Errorf("Error fetching URL %s : Error = %s", url, err)
} else {
sugarLogger.Infof("Success! statusCode = %s for URL %s", resp.Status, url)
resp.Body.Close()
}
}
|
执行上述代码,下面的内容会输出到文件——test.log中。
1
2
3
4
|
2019-10-27T15:50:32.944+0800 DEBUG logic/temp2.go:48 Trying to hit GET request for www.sogo.com
2019-10-27T15:50:32.944+0800 ERROR logic/temp2.go:51 Error fetching URL www.sogo.com : Error = Get www.sogo.com: unsupported protocol scheme ""
2019-10-27T15:50:32.944+0800 DEBUG logic/temp2.go:48 Trying to hit GET request for http://www.sogo.com
2019-10-27T15:50:33.165+0800 INFO logic/temp2.go:53 Success! statusCode = 200 OK for URL http://www.sogo.com
|
同时,可以在main
函数中循环记录日志,测试日志文件是否会自动切割和归档(日志文件每1MB会切割并且在当前目录下最多保存5个备份)。
至此,我们总结了如何将Zap日志程序集成到Go应用程序项目中。
gin整合zap
go get -u github.com/gin-gonic/gin
首先我们来看一个最简单的gin项目:
1
2
3
4
5
6
7
|
func main() {
r := gin.Default()
r.GET("/hello", func(c *gin.Context) {
c.String("hello liwenzhou.com!")
})
r.Run(
}
|
接下来我们看一下gin.Default()
的源码:
1
2
3
4
5
6
|
func Default() *Engine {
debugPrintWARNINGDefault()
engine := New()
engine.Use(Logger(), Recovery())
return engine
}
|
也就是我们在使用gin.Default()
的同时是用到了gin框架内的两个默认中间件Logger()
和Recovery()
。
其中Logger()
是把gin框架本身的日志输出到标准输出(我们本地开发调试时在终端输出的那些日志就是它的功劳),而Recovery()
是在程序出现panic的时候恢复现场并写入500响应的。
基于zap的中间件
我们可以模仿Logger()
和Recovery()
的实现,使用我们的日志库来接收gin框架默认输出的日志。
这里以zap为例,我们实现两个中间件如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
|
// GinLogger 接收gin框架默认的日志
func GinLogger(logger *zap.Logger) gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
c.Next()
cost := time.Since(start)
logger.Info(path,
zap.Int("status", c.Writer.Status()),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("ip", c.ClientIP()),
zap.String("user-agent", c.Request.UserAgent()),
zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()),
zap.Duration("cost", cost),
)
}
}
// GinRecovery recover掉项目可能出现的panic
func GinRecovery(logger *zap.Logger, stack bool) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
httpRequest, _ := httputil.DumpRequest(c.Request, false)
if brokenPipe {
logger.Error(c.Request.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) // nolint: errcheck
c.Abort()
return
}
if stack {
logger.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else {
logger.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
c.AbortWithStatus(http.StatusInternalServerError)
}
}()
c.Next()
}
}
|
如果不想自己实现,可以使用github上有别人封装好的https://github.com/gin-contrib/zap。
这样我们就可以在gin框架中使用我们上面定义好的两个中间件来代替gin框架默认的Logger()
和Recovery()
了。
1
2
|
r := gin.New()
r.Use(GinLogger(), GinRecovery())
|
在gin项目中使用zap
最后我们再加入我们项目中常用的日志切割,完整版的logger.go
代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
|
package logger
import (
"gin_zap_demo/config"
"net"
"net/http"
"net/http/httputil"
"os"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/natefinch/lumberjack"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
var lg *zap.Logger
// InitLogger 初始化Logger
func InitLogger(cfg *config.LogConfig) (err error) {
writeSyncer := getLogWriter(cfg.Filename, cfg.MaxSize, cfg.MaxBackups, cfg.MaxAge)
encoder := getEncoder()
var l = new(zapcore.Level)
err = l.UnmarshalText([]byte(cfg.Level))
if err != nil {
return
}
core := zapcore.NewCore(encoder, writeSyncer, l)
lg = zap.New(core, zap.AddCaller())
zap.ReplaceGlobals(lg) // 替换zap包中全局的logger实例,后续在其他包中只需使用zap.L()调用即可
return
}
func getEncoder() zapcore.Encoder {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.TimeKey = "time"
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
encoderConfig.EncodeDuration = zapcore.SecondsDurationEncoder
encoderConfig.EncodeCaller = zapcore.ShortCallerEncoder
return zapcore.NewJSONEncoder(encoderConfig)
}
func getLogWriter(filename string, maxSize, maxBackup, maxAge int) zapcore.WriteSyncer {
lumberJackLogger := &lumberjack.Logger{
Filename: filename,
MaxSize: maxSize,
MaxBackups: maxBackup,
MaxAge: maxAge,
}
return zapcore.AddSync(lumberJackLogger)
}
// GinLogger 接收gin框架默认的日志
func GinLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
c.Next()
cost := time.Since(start)
lg.Info(path,
zap.Int("status", c.Writer.Status()),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("ip", c.ClientIP()),
zap.String("user-agent", c.Request.UserAgent()),
zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()),
zap.Duration("cost", cost),
)
}
}
// GinRecovery recover掉项目可能出现的panic,并使用zap记录相关日志
func GinRecovery(stack bool) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
httpRequest, _ := httputil.DumpRequest(c.Request, false)
if brokenPipe {
lg.Error(c.Request.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) // nolint: errcheck
c.Abort()
return
}
if stack {
lg.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else {
lg.Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
c.AbortWithStatus(http.StatusInternalServerError)
}
}()
c.Next()
}
}
|
然后定义日志相关配置:
1
2
3
4
5
6
7
|
type LogConfig struct {
Level string `json:"level"`
Filename string `json:"filename"`
MaxSize int `json:"maxsize"`
MaxAge int `json:"max_age"`
MaxBackups int `json:"max_backups"`
}
|
在项目中先从配置文件加载配置信息,再调用logger.InitLogger(config.Conf.LogConfig)
即可完成logger实例的初识化。其中,通过r.Use(logger.GinLogger(), logger.GinRecovery(true))
注册我们的中间件来使用zap接收gin框架自身的日志,在项目中需要的地方通过使用zap.L().Xxx()
方法来记录自定义日志信息。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
package main
import (
"fmt"
"gin_zap_demo/config"
"gin_zap_demo/logger"
"net/http"
"os"
"go.uber.org/zap"
"github.com/gin-gonic/gin"
)
func main() {
// load config from config.json
if len(os.Args) < 1 {
return
}
if err := config.Init(os.Args[1]); err != nil {
panic(err)
}
// init logger
if err := logger.InitLogger(config.Conf.LogConfig); err != nil {
fmt.Printf("init logger failed, err:%v\n", err)
return
}
gin.SetMode(config.Conf.Mode)
r := gin.Default()
// 注册zap相关中间件
r.Use(logger.GinLogger(), logger.GinRecovery(true))
r.GET("/hello", func(c *gin.Context) {
// 假设你有一些数据需要记录到日志中
var (
name = "q1mi"
age = 18
)
// 记录日志并使用zap.Xxx(key, val)记录相关字段
zap.L().Debug("this is hello func", zap.String("user", name), zap.Int("age", age))
c.String(http.StatusOK, "hello liwenzhou.com!")
})
addr := fmt.Sprintf(":%v", config.Conf.Port)
r.Run(addr)
}
|
Viper是适用于Go应用程序的完整配置解决方案。它被设计用于在应用程序中工作,并且可以处理所有类型的配置需求和格式。
Viper
Viper是适用于Go应用程序的完整配置解决方案。它被设计用于在应用程序中工作,并且可以处理所有类型的配置需求和格式。
鉴于viper
库本身的README已经写得十分详细,这里就将其翻译成中文,并在最后附上两个项目中使用viper
的示例代码以供参考。
安装
1
|
go get github.com/spf13/viper
|
什么是Viper?
Viper是适用于Go应用程序(包括Twelve-Factor App
)的完整配置解决方案。它被设计用于在应用程序中工作,并且可以处理所有类型的配置需求和格式。它支持以下特性:
- 设置默认值
- 从
JSON
、TOML
、YAML
、HCL
、envfile
和Java properties
格式的配置文件读取配置信息
- 实时监控和重新读取配置文件(可选)
- 从环境变量中读取
- 从远程配置系统(etcd或Consul)读取并监控配置变化
- 从命令行参数读取配置
- 从buffer读取配置
- 显式配置值
为什么选择Viper?
在构建现代应用程序时,你无需担心配置文件格式;你想要专注于构建出色的软件。Viper的出现就是为了在这方面帮助你的。
Viper能够为你执行下列操作:
- 查找、加载和反序列化
JSON
、TOML
、YAML
、HCL
、INI
、envfile
和Java properties
格式的配置文件。
- 提供一种机制为你的不同配置选项设置默认值。
- 提供一种机制来通过命令行参数覆盖指定选项的值。
- 提供别名系统,以便在不破坏现有代码的情况下轻松重命名参数。
- 当用户提供了与默认值相同的命令行或配置文件时,可以很容易地分辨出它们之间的区别。
Viper会按照下面的优先级。每个项目的优先级都高于它下面的项目:
- 显示调用
Set
设置值
- 命令行参数(flag)
- 环境变量
- 配置文件
- key/value存储
- 默认值
重要: 目前Viper配置的键(Key)是大小写不敏感的。目前正在讨论是否将这一选项设为可选。
把值存入Viper
建立默认值
一个好的配置系统应该支持默认值。键不需要默认值,但如果没有通过配置文件、环境变量、远程配置或命令行标志(flag)设置键,则默认值非常有用。
例如:
1
2
3
|
viper.SetDefault("ContentDir", "content")
viper.SetDefault("LayoutDir", "layouts")
viper.SetDefault("Taxonomies", map[string]string{"tag": "tags", "category": "categories"})
|
读取配置文件
Viper需要最少知道在哪里查找配置文件的配置。Viper支持JSON
、TOML
、YAML
、HCL
、envfile
和Java properties
格式的配置文件。Viper可以搜索多个路径,但目前单个Viper实例只支持单个配置文件。Viper不默认任何配置搜索路径,将默认决策留给应用程序。
下面是一个如何使用Viper搜索和读取配置文件的示例。不需要任何特定的路径,但是至少应该提供一个配置文件预期出现的路径。
1
2
3
4
5
6
7
8
9
10
|
viper.SetConfigFile("./config.yaml") // 指定配置文件路径
viper.SetConfigName("config") // 配置文件名称(无扩展名)
viper.SetConfigType("yaml") // 如果配置文件的名称中没有扩展名,则需要配置此项
viper.AddConfigPath("/etc/appname/") // 查找配置文件所在的路径
viper.AddConfigPath("$HOME/.appname") // 多次调用以添加多个搜索路径
viper.AddConfigPath(".") // 还可以在工作目录中查找配置
err := viper.ReadInConfig() // 查找并读取配置文件
if err != nil { // 处理读取配置文件的错误
panic(fmt.Errorf("Fatal error config file: %s \n", err))
}
|
在加载配置文件出错时,你可以像下面这样处理找不到配置文件的特定情况:
1
2
3
4
5
6
7
8
9
|
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
// 配置文件未找到错误;如果需要可以忽略
} else {
// 配置文件被找到,但产生了另外的错误
}
}
// 配置文件找到并成功解析
|
注意[自1.6起]: 你也可以有不带扩展名的文件,并以编程方式指定其格式。对于位于用户$HOME
目录中的配置文件没有任何扩展名,如.bashrc
。
这里补充两个问题供读者解答并自行验证
当你使用如下方式读取配置时,viper会从./conf
目录下查找任何以config
为文件名的配置文件,如果同时存在./conf/config.json
和./conf/config.yaml
两个配置文件的话,viper
会从哪个配置文件加载配置呢?
1
2
|
viper.SetConfigName("config")
viper.AddConfigPath("./conf")
|
在上面两个语句下搭配使用viper.SetConfigType("yaml")
指定配置文件类型可以实现预期的效果吗?
写入配置文件
从配置文件中读取配置文件是有用的,但是有时你想要存储在运行时所做的所有修改。为此,可以使用下面一组命令,每个命令都有自己的用途:
- WriteConfig - 将当前的
viper
配置写入预定义的路径并覆盖(如果存在的话)。如果没有预定义的路径,则报错。
- SafeWriteConfig - 将当前的
viper
配置写入预定义的路径。如果没有预定义的路径,则报错。如果存在,将不会覆盖当前的配置文件。
- WriteConfigAs - 将当前的
viper
配置写入给定的文件路径。将覆盖给定的文件(如果它存在的话)。
- SafeWriteConfigAs - 将当前的
viper
配置写入给定的文件路径。不会覆盖给定的文件(如果它存在的话)。
根据经验,标记为safe
的所有方法都不会覆盖任何文件,而是直接创建(如果不存在),而默认行为是创建或截断。
一个小示例:
1
2
3
4
5
|
viper.WriteConfig() // 将当前配置写入“viper.AddConfigPath()”和“viper.SetConfigName”设置的预定义路径
viper.SafeWriteConfig()
viper.WriteConfigAs("/path/to/my/.config")
viper.SafeWriteConfigAs("/path/to/my/.config") // 因为该配置文件写入过,所以会报错
viper.SafeWriteConfigAs("/path/to/my/.other_config")
|
监控并重新读取配置文件
Viper支持在运行时实时读取配置文件的功能。
需要重新启动服务器以使配置生效的日子已经一去不复返了,viper驱动的应用程序可以在运行时读取配置文件的更新,而不会错过任何消息。
只需告诉viper实例watchConfig。可选地,你可以为Viper提供一个回调函数,以便在每次发生更改时运行。
确保在调用WatchConfig()
之前添加了所有的配置路径。
1
2
3
4
5
|
viper.WatchConfig()
viper.OnConfigChange(func(e fsnotify.Event) {
// 配置文件发生变更之后会调用的回调函数
fmt.Println("Config file changed:", e.Name)
})
|
从io.Reader读取配置
Viper预先定义了许多配置源,如文件、环境变量、标志和远程K/V存储,但你不受其约束。你还可以实现自己所需的配置源并将其提供给viper。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
viper.SetConfigType("yaml") // 或者 viper.SetConfigType("YAML")
// 任何需要将此配置添加到程序中的方法。
var yamlExample = []byte(`
Hacker: true
name: steve
hobbies:
- skateboarding
- snowboarding
- go
clothing:
jacket: leather
trousers: denim
age: 35
eyes : brown
beard: true
`)
viper.ReadConfig(bytes.NewBuffer(yamlExample))
viper.Get("name") // 这里会得到 "steve"
|
覆盖设置
这些可能来自命令行标志,也可能来自你自己的应用程序逻辑。
1
2
|
viper.Set("Verbose", true)
viper.Set("LogFile", LogFile)
|
注册和使用别名
别名允许多个键引用单个值
1
2
3
4
5
6
7
|
viper.RegisterAlias("loud", "Verbose") // 注册别名(此处loud和Verbose建立了别名)
viper.Set("verbose", true) // 结果与下一行相同
viper.Set("loud", true) // 结果与前一行相同
viper.GetBool("loud") // true
viper.GetBool("verbose") // true
|
使用环境变量
Viper完全支持环境变量。这使Twelve-Factor App
开箱即用。有五种方法可以帮助与ENV协作:
AutomaticEnv()
BindEnv(string...) : error
SetEnvPrefix(string)
SetEnvKeyReplacer(string...) *strings.Replacer
AllowEmptyEnv(bool)
使用ENV变量时,务必要意识到Viper将ENV变量视为区分大小写。
Viper提供了一种机制来确保ENV变量是惟一的。通过使用SetEnvPrefix
,你可以告诉Viper在读取环境变量时使用前缀。BindEnv
和AutomaticEnv
都将使用这个前缀。
BindEnv
使用一个或两个参数。第一个参数是键名称,第二个是环境变量的名称。环境变量的名称区分大小写。如果没有提供ENV变量名,那么Viper将自动假设ENV变量与以下格式匹配:前缀+ “_” +键名全部大写。当你显式提供ENV变量名(第二个参数)时,它 不会 自动添加前缀。例如,如果第二个参数是“id”,Viper将查找环境变量“ID”。
在使用ENV变量时,需要注意的一件重要事情是,每次访问该值时都将读取它。Viper在调用BindEnv
时不固定该值。
AutomaticEnv
是一个强大的助手,尤其是与SetEnvPrefix
结合使用时。调用时,Viper会在发出viper.Get
请求时随时检查环境变量。它将应用以下规则。它将检查环境变量的名称是否与键匹配(如果设置了EnvPrefix
)。
SetEnvKeyReplacer
允许你使用strings.Replacer
对象在一定程度上重写 Env 键。如果你希望在Get()
调用中使用-
或者其他什么符号,但是环境变量里使用_
分隔符,那么这个功能是非常有用的。可以在viper_test.go
中找到它的使用示例。
或者,你可以使用带有NewWithOptions
工厂函数的EnvKeyReplacer
。与SetEnvKeyReplacer
不同,它接受StringReplacer
接口,允许你编写自定义字符串替换逻辑。
默认情况下,空环境变量被认为是未设置的,并将返回到下一个配置源。若要将空环境变量视为已设置,请使用AllowEmptyEnv
方法。
Env 示例:
1
2
3
4
5
6
|
SetEnvPrefix("spf") // 将自动转为大写
BindEnv("id")
os.Setenv("SPF_ID", "13") // 通常是在应用程序之外完成的
id := Get("id") // 13
|
使用Flags
Viper 具有绑定到标志的能力。具体来说,Viper支持Cobra库中使用的Pflag
。
与BindEnv
类似,该值不是在调用绑定方法时设置的,而是在访问该方法时设置的。这意味着你可以根据需要尽早进行绑定,即使在init()
函数中也是如此。
对于单个标志,BindPFlag()
方法提供此功能。
例如:
1
2
|
serverCmd.Flags().Int("port", 1138, "Port to run Application server on")
viper.BindPFlag("port", serverCmd.Flags().Lookup("port"))
|
你还可以绑定一组现有的pflags (pflag.FlagSet):
举个例子:
1
2
3
4
5
6
|
pflag.Int("flagname", 1234, "help message for flagname")
pflag.Parse()
viper.BindPFlags(pflag.CommandLine)
i := viper.GetInt("flagname") // 从viper而不是从pflag检索值
|
在 Viper 中使用 pflag 并不阻碍其他包中使用标准库中的 flag 包。pflag 包可以通过导入这些 flags 来处理flag包定义的flags。这是通过调用pflag包提供的便利函数AddGoFlagSet()
来实现的。
例如:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
package main
import (
"flag"
"github.com/spf13/pflag"
)
func main() {
// 使用标准库 "flag" 包
flag.Int("flagname", 1234, "help message for flagname")
pflag.CommandLine.AddGoFlagSet(flag.CommandLine)
pflag.Parse()
viper.BindPFlags(pflag.CommandLine)
i := viper.GetInt("flagname") // 从 viper 检索值
...
}
|
flag接口
如果你不使用Pflag
,Viper 提供了两个Go接口来绑定其他 flag 系统。
FlagValue
表示单个flag。这是一个关于如何实现这个接口的非常简单的例子:
1
2
3
4
5
|
type myFlag struct {}
func (f myFlag) HasChanged() bool { return false }
func (f myFlag) Name() string { return "my-flag-name" }
func (f myFlag) ValueString() string { return "my-flag-value" }
func (f myFlag) ValueType() string { return "string" }
|
一旦你的 flag 实现了这个接口,你可以很方便地告诉Viper绑定它:
1
|
viper.BindFlagValue("my-flag-name", myFlag{})
|
FlagValueSet
代表一组 flags 。这是一个关于如何实现这个接口的非常简单的例子:
1
2
3
4
5
6
7
8
9
|
type myFlagSet struct {
flags []myFlag
}
func (f myFlagSet) VisitAll(fn func(FlagValue)) {
for _, flag := range flags {
fn(flag)
}
}
|
一旦你的flag set实现了这个接口,你就可以很方便地告诉Viper绑定它:
1
2
3
4
|
fSet := myFlagSet{
flags: []myFlag{myFlag{}, myFlag{}},
}
viper.BindFlagValues("my-flags", fSet)
|
远程Key/Value存储支持
在Viper中启用远程支持,需要在代码中匿名导入viper/remote
这个包。
1
|
import _ "github.com/spf13/viper/remote"
|
Viper将读取从Key/Value存储(例如etcd或Consul)中的路径检索到的配置字符串(如JSON
、TOML
、YAML
、HCL
、envfile
和Java properties
格式)。这些值的优先级高于默认值,但是会被从磁盘、flag或环境变量检索到的配置值覆盖。(译注:也就是说Viper加载配置值的优先级为:磁盘上的配置文件>命令行标志位>环境变量>远程Key/Value存储>默认值。)
Viper使用crypt从K/V存储中检索配置,这意味着如果你有正确的gpg密匙,你可以将配置值加密存储并自动解密。加密是可选的。
你可以将远程配置与本地配置结合使用,也可以独立使用。
crypt
有一个命令行助手,你可以使用它将配置放入K/V存储中。crypt
默认使用在http://127.0.0.1:4001的etcd。
1
2
|
$ go get github.com/bketelsen/crypt/bin/crypt
$ crypt set -plaintext /config/hugo.json /Users/hugo/settings/config.json
|
确认值已经设置:
1
|
$ crypt get -plaintext /config/hugo.json
|
有关如何设置加密值或如何使用Consul的示例,请参见crypt
文档。
远程Key/Value存储示例-未加密
etcd
1
2
3
|
viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001","/config/hugo.json")
viper.SetConfigType("json") // 因为在字节流中没有文件扩展名,所以这里需要设置下类型。支持的扩展名有 "json", "toml", "yaml", "yml", "properties", "props", "prop", "env", "dotenv"
err := viper.ReadRemoteConfig()
|
Consul
你需要 Consul Key/Value存储中设置一个Key保存包含所需配置的JSON值。例如,创建一个keyMY_CONSUL_KEY
将下面的值存入Consul key/value 存储:
1
2
3
4
5
6
7
8
9
10
|
{
"port": 8080,
"hostname": "liwenzhou.com"
}
viper.AddRemoteProvider("consul", "localhost:8500", "MY_CONSUL_KEY")
viper.SetConfigType("json") // 需要显示设置成json
err := viper.ReadRemoteConfig()
fmt.Println(viper.Get("port")) // 8080
fmt.Println(viper.Get("hostname")) // liwenzhou.com
|
Firestore
1
2
3
|
viper.AddRemoteProvider("firestore", "google-cloud-project-id", "collection/document")
viper.SetConfigType("json") // 配置的格式: "json", "toml", "yaml", "yml"
err := viper.ReadRemoteConfig()
|
当然,你也可以使用SecureRemoteProvider
。
远程Key/Value存储示例-加密
1
2
3
|
viper.AddSecureRemoteProvider("etcd","http://127.0.0.1:4001","/config/hugo.json","/etc/secrets/mykeyring.gpg")
viper.SetConfigType("json") // 因为在字节流中没有文件扩展名,所以这里需要设置下类型。支持的扩展名有 "json", "toml", "yaml", "yml", "properties", "props", "prop", "env", "dotenv"
err := viper.ReadRemoteConfig()
|
监控etcd中的更改-未加密
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
// 或者你可以创建一个新的viper实例
var runtime_viper = viper.New()
runtime_viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001", "/config/hugo.yml")
runtime_viper.SetConfigType("yaml") // 因为在字节流中没有文件扩展名,所以这里需要设置下类型。支持的扩展名有 "json", "toml", "yaml", "yml", "properties", "props", "prop", "env", "dotenv"
// 第一次从远程读取配置
err := runtime_viper.ReadRemoteConfig()
// 反序列化
runtime_viper.Unmarshal(&runtime_conf)
// 开启一个单独的goroutine一直监控远端的变更
go func(){
for {
time.Sleep(time.Second * 5) // 每次请求后延迟一下
// 目前只测试了etcd支持
err := runtime_viper.WatchRemoteConfig()
if err != nil {
log.Errorf("unable to read remote config: %v", err)
continue
}
// 将新配置反序列化到我们运行时的配置结构体中。你还可以借助channel实现一个通知系统更改的信号
runtime_viper.Unmarshal(&runtime_conf)
}
}()
|
从Viper获取值
在Viper中,有几种方法可以根据值的类型获取值。存在以下功能和方法:
Get(key string) : interface{}
GetBool(key string) : bool
GetFloat64(key string) : float64
GetInt(key string) : int
GetIntSlice(key string) : []int
GetString(key string) : string
GetStringMap(key string) : map[string]interface{}
GetStringMapString(key string) : map[string]string
GetStringSlice(key string) : []string
GetTime(key string) : time.Time
GetDuration(key string) : time.Duration
IsSet(key string) : bool
AllSettings() : map[string]interface{}
需要认识到的一件重要事情是,每一个Get方法在找不到值的时候都会返回零值。为了检查给定的键是否存在,提供了IsSet()
方法。
例如:
1
2
3
4
|
viper.GetString("logfile") // 不区分大小写的设置和获取
if viper.GetBool("verbose") {
fmt.Println("verbose enabled")
}
|
访问嵌套的键
访问器方法也接受深度嵌套键的格式化路径。例如,如果加载下面的JSON文件:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
{
"host": {
"address": "localhost",
"port": 5799
},
"datastore": {
"metric": {
"host": "127.0.0.1",
"port": 3099
},
"warehouse": {
"host": "198.0.0.1",
"port": 2112
}
}
}
|
Viper可以通过传入.
分隔的路径来访问嵌套字段:
1
|
GetString("datastore.metric.host") // (返回 "127.0.0.1")
|
这遵守上面建立的优先规则;搜索路径将遍历其余配置注册表,直到找到为止。(译注:因为Viper支持从多种配置来源,例如磁盘上的配置文件>命令行标志位>环境变量>远程Key/Value存储>默认值,我们在查找一个配置的时候如果在当前配置源中没找到,就会继续从后续的配置源查找,直到找到为止。)
例如,在给定此配置文件的情况下,datastore.metric.host
和datastore.metric.port
均已定义(并且可以被覆盖)。如果另外在默认值中定义了datastore.metric.protocol
,Viper也会找到它。
然而,如果datastore.metric
被直接赋值覆盖(被flag,环境变量,set()
方法等等…),那么datastore.metric
的所有子键都将变为未定义状态,它们被高优先级配置级别“遮蔽”(shadowed)了。
最后,如果存在与分隔的键路径匹配的键,则返回其值。例如:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
|
{
"datastore.metric.host": "0.0.0.0",
"host": {
"address": "localhost",
"port": 5799
},
"datastore": {
"metric": {
"host": "127.0.0.1",
"port": 3099
},
"warehouse": {
"host": "198.0.0.1",
"port": 2112
}
}
}
GetString("datastore.metric.host") // 返回 "0.0.0.0"
|
提取子树
从Viper中提取子树。
例如,viper
实例现在代表了以下配置:
1
2
3
4
5
6
7
|
app:
cache1:
max-items: 100
item-size: 64
cache2:
max-items: 200
item-size: 80
|
执行后:
1
|
subv := viper.Sub("app.cache1")
|
subv
现在就代表:
1
2
|
max-items: 100
item-size: 64
|
假设我们现在有这么一个函数:
1
|
func NewCache(cfg *Viper) *Cache {...}
|
它基于subv
格式的配置信息创建缓存。现在,可以轻松地分别创建这两个缓存,如下所示:
1
2
3
4
5
|
cfg1 := viper.Sub("app.cache1")
cache1 := NewCache(cfg1)
cfg2 := viper.Sub("app.cache2")
cache2 := NewCache(cfg2)
|
反序列化
你还可以选择将所有或特定的值解析到结构体、map等。
有两种方法可以做到这一点:
Unmarshal(rawVal interface{}) : error
UnmarshalKey(key string, rawVal interface{}) : error
举个例子:
1
2
3
4
5
6
7
8
9
10
11
12
|
type config struct {
Port int
Name string
PathMap string `mapstructure:"path_map"`
}
var C config
err := viper.Unmarshal(&C)
if err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}
|
如果你想要解析那些键本身就包含.
(默认的键分隔符)的配置,你需要修改分隔符:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
v := viper.NewWithOptions(viper.KeyDelimiter("::"))
v.SetDefault("chart::values", map[string]interface{}{
"ingress": map[string]interface{}{
"annotations": map[string]interface{}{
"traefik.frontend.rule.type": "PathPrefix",
"traefik.ingress.kubernetes.io/ssl-redirect": "true",
},
},
})
type config struct {
Chart struct{
Values map[string]interface{}
}
}
var C config
v.Unmarshal(&C)
|
Viper还支持解析到嵌入的结构体:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
|
/*
Example config:
module:
enabled: true
token: 89h3f98hbwf987h3f98wenf89ehf
*/
type config struct {
Module struct {
Enabled bool
moduleConfig `mapstructure:",squash"`
}
}
// moduleConfig could be in a module specific package
type moduleConfig struct {
Token string
}
var C config
err := viper.Unmarshal(&C)
if err != nil {
t.Fatalf("unable to decode into struct, %v", err)
}
|
Viper在后台使用github.com/mitchellh/mapstructure来解析值,其默认情况下使用mapstructure
tag。
注意 当我们需要将viper读取的配置反序列到我们定义的结构体变量中时,一定要使用mapstructure
tag哦!
序列化成字符串
你可能需要将viper中保存的所有设置序列化到一个字符串中,而不是将它们写入到一个文件中。你可以将自己喜欢的格式的序列化器与AllSettings()
返回的配置一起使用。
1
2
3
4
5
6
7
8
9
10
11
12
13
|
import (
yaml "gopkg.in/yaml.v2"
// ...
)
func yamlStringSettings() string {
c := viper.AllSettings()
bs, err := yaml.Marshal(c)
if err != nil {
log.Fatalf("unable to marshal config to YAML: %v", err)
}
return string(bs)
}
|
使用单个还是多个Viper实例?
Viper是开箱即用的。你不需要配置或初始化即可开始使用Viper。由于大多数应用程序都希望使用单个中央存储库管理它们的配置信息,所以viper包提供了这个功能。它类似于单例模式。
在上面的所有示例中,它们都以其单例风格的方法演示了如何使用viper。
使用多个viper实例
你还可以在应用程序中创建许多不同的viper实例。每个都有自己独特的一组配置和值。每个人都可以从不同的配置文件,key value存储区等读取数据。每个都可以从不同的配置文件、键值存储等中读取。viper包支持的所有功能都被镜像为viper实例的方法。
例如:
1
2
3
4
5
6
7
|
x := viper.New()
y := viper.New()
x.SetDefault("ContentDir", "content")
y.SetDefault("ContentDir", "foobar")
//...
|
当使用多个viper实例时,由用户来管理不同的viper实例。
使用Viper示例
假设我们的项目现在有一个./conf/config.yaml
配置文件,内容如下:
1
2
|
port: 8123
version: "v1.2.3"
|
接下来通过示例代码演示两种在项目中使用viper
管理项目配置信息的方式。
直接使用viper管理配置
这里用一个demo演示如何在gin框架搭建的web项目中使用viper
,使用viper加载配置文件中的信息,并在代码中直接使用viper.GetXXX()
方法获取对应的配置值。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
package main
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
func main() {
viper.SetConfigFile("./conf/config.yaml") // 指定配置文件路径
err := viper.ReadInConfig() // 读取配置信息
if err != nil { // 读取配置信息失败
panic(fmt.Errorf("Fatal error config file: %s \n", err))
}
// 监控配置文件变化
viper.WatchConfig()
r := gin.Default()
// 访问/version的返回值会随配置文件的变化而变化
r.GET("/version", func(c *gin.Context) {
c.String(http.StatusOK, viper.GetString("version"))
})
if err := r.Run(
fmt.Sprintf(":%d", viper.GetInt("port"))); err != nil {
panic(err)
}
}
|
使用结构体变量保存配置信息
除了上面的用法外,我们还可以在项目中定义与配置文件对应的结构体,viper
加载完配置信息后使用结构体变量保存配置信息。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
|
package main
import (
"fmt"
"net/http"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
type Config struct {
Port int `mapstructure:"port"`
Version string `mapstructure:"version"`
}
var Conf = new(Config)
func main() {
viper.SetConfigFile("./conf/config.yaml") // 指定配置文件路径
err := viper.ReadInConfig() // 读取配置信息
if err != nil { // 读取配置信息失败
panic(fmt.Errorf("Fatal error config file: %s \n", err))
}
// 将读取的配置信息保存至全局变量Conf
if err := viper.Unmarshal(Conf); err != nil {
panic(fmt.Errorf("unmarshal conf failed, err:%s \n", err))
}
// 监控配置文件变化
viper.WatchConfig()
// 注意!!!配置文件发生变化后要同步到全局变量Conf
viper.OnConfigChange(func(in fsnotify.Event) {
fmt.Println("夭寿啦~配置文件被人修改啦...")
if err := viper.Unmarshal(Conf); err != nil {
panic(fmt.Errorf("unmarshal conf failed, err:%s \n", err))
}
})
r := gin.Default()
// 访问/version的返回值会随配置文件的变化而变化
r.GET("/version", func(c *gin.Context) {
c.String(http.StatusOK, Conf.Version)
})
if err := r.Run(fmt.Sprintf(":%d", Conf.Port)); err != nil {
panic(err)
}
}
|
优雅的关机
我们编写的Web项目部署之后,经常会因为需要进行配置变更或功能迭代而重启服务,单纯的kill -9 pid
的方式会强制关闭进程,这样就会导致服务端当前正在处理的请求失败,那有没有更优雅的方式来实现关机或重启呢?
阅读本文需要了解一些UNIX系统中信号
的概念,请提前查阅资料预习。
优雅地关机
什么是优雅关机?
优雅关机就是服务端关机命令发出后不是立即关机,而是等待当前还在处理的请求全部处理完毕后再退出程序,是一种对客户端友好的关机方式。而执行Ctrl+C
关闭服务端时,会强制结束进程导致正在访问的请求出现问题。
如何实现优雅关机?
Go 1.8版本之后, http.Server 内置的 Shutdown() 方法就支持优雅地关机,具体示例如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
|
// +build go1.8
package main
import (
"context"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
)
func main() {
router := gin.Default()
router.GET("/", func(c *gin.Context) {
time.Sleep(5 * time.Second)
c.String(http.StatusOK, "Welcome Gin Server")
})
srv := &http.Server{
Addr: ":8080",
Handler: router,
}
go func() {
// 开启一个goroutine启动服务
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("listen: %s\n", err)
}
}()
// 等待中断信号来优雅地关闭服务器,为关闭服务器操作设置一个5秒的超时
quit := make(chan os.Signal, 1) // 创建一个接收信号的通道
// kill 默认会发送 syscall.SIGTERM 信号
// kill -2 发送 syscall.SIGINT 信号,我们常用的Ctrl+C就是触发系统SIGINT信号
// kill -9 发送 syscall.SIGKILL 信号,但是不能被捕获,所以不需要添加它
// signal.Notify把收到的 syscall.SIGINT或syscall.SIGTERM 信号转发给quit
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) // 此处不会阻塞
<-quit // 阻塞在此,当接收到上述两种信号时才会往下执行
log.Println("Shutdown Server ...")
// 创建一个5秒超时的context
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 5秒内优雅关闭服务(将未处理完的请求处理完再关闭服务),超过5秒就超时退出
if err := srv.Shutdown(ctx); err != nil {
log.Fatal("Server Shutdown: ", err)
}
log.Println("Server exiting")
}
|
如何验证优雅关机的效果呢?
上面的代码运行后会在本地的8080
端口开启一个web服务,它只注册了一条路由/
,后端服务会先sleep 5秒钟然后才返回响应信息。
我们按下Ctrl+C
时会发送syscall.SIGINT
来通知程序优雅关机,具体做法如下:
- 打开终端,编译并执行上面的代码
- 打开一个浏览器,访问
127.0.0.1:8080/
,此时浏览器白屏等待服务端返回响应。
- 在终端迅速执行
Ctrl+C
命令给程序发送syscall.SIGINT
信号
- 此时程序并不立即退出而是等我们第2步的响应返回之后再退出,从而实现优雅关机。
优雅地重启
优雅关机实现了,那么该如何实现优雅重启呢?
我们可以使用 fvbock/endless 来替换默认的 ListenAndServe
启动服务来实现, 示例代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
package main
import (
"log"
"net/http"
"time"
"github.com/fvbock/endless"
"github.com/gin-gonic/gin"
)
func main() {
router := gin.Default()
router.GET("/", func(c *gin.Context) {
time.Sleep(5 * time.Second)
c.String(http.StatusOK, "hello gin!")
})
// 默认endless服务器会监听下列信号:
// syscall.SIGHUP,syscall.SIGUSR1,syscall.SIGUSR2,syscall.SIGINT,syscall.SIGTERM和syscall.SIGTSTP
// 接收到 SIGHUP 信号将触发`fork/restart` 实现优雅重启(kill -1 pid会发送SIGHUP信号)
// 接收到 syscall.SIGINT或syscall.SIGTERM 信号将触发优雅关机
// 接收到 SIGUSR2 信号将触发HammerTime
// SIGUSR1 和 SIGTSTP 被用来触发一些用户自定义的hook函数
if err := endless.ListenAndServe(":8080", router); err!=nil{
log.Fatalf("listen: %s\n", err)
}
log.Println("Server exiting")
}
|
如何验证优雅重启的效果呢?
我们通过执行kill -1 pid
命令发送syscall.SIGINT
来通知程序优雅重启,具体做法如下:
- 打开终端,
go build -o graceful_restart
编译并执行./graceful_restart
,终端输出当前pid(假设为43682)
- 将代码中处理请求函数返回的
hello gin!
修改为hello q1mi!
,再次编译go build -o graceful_restart
- 打开一个浏览器,访问
127.0.0.1:8080/
,此时浏览器白屏等待服务端返回响应。
- 在终端迅速执行
kill -1 43682
命令给程序发送syscall.SIGHUP
信号
- 等第3步浏览器收到响应信息
hello gin!
后再次访问127.0.0.1:8080/
会收到hello q1mi!
的响应。
- 在不影响当前未处理完请求的同时完成了程序代码的替换,实现了优雅重启。
但是需要注意的是,此时程序的PID变化了,因为endless
是通过fork
子进程处理新请求,待原进程处理完当前请求后再退出的方式实现优雅重启的。所以当你的项目是使用类似supervisor
的软件管理进程时就不适用这种方式了。
搭建一个通用的脚手架工具
v1.0
mysql.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
package mysql
import (
"fmt"
"github.com/spf13/viper"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
)
var db *sqlx.DB
func Init() (err error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True",
viper.GetString("mysql.user"),
viper.GetString("mysql.password"),
viper.GetString("mysql.host"),
viper.GetInt("mysql.port"),
viper.GetString("mysql.dbName"),
)
fmt.Println(dsn)
db, err = sqlx.Connect("mysql", dsn)
if err != nil {
fmt.Printf("mysql connect error,err : %v\n", err)
return
}
db.SetMaxOpenConns(viper.GetInt("mysql.maxOpen"))
db.SetMaxIdleConns(viper.GetInt("mysql.maxIdle"))
return
}
func Close() {
db.Close()
}
|
redis.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
package redis
import (
"fmt"
"github.com/go-redis/redis"
"github.com/spf13/viper"
)
var rdb *redis.Client
func Init() (err error) {
rdb = redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%d", viper.GetString("redis.host"), viper.GetInt("redis.port")),
Password: viper.GetString("redis.password"),
DB: viper.GetInt("redis.db"),
PoolSize: viper.GetInt("redis.poolSize"), //连接池大小
})
_, err = rdb.Ping().Result()
if err != nil {
fmt.Printf("redis ping error,err: %v", err)
return
}
return
}
func Close() {
rdb.Close()
}
|
log.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
|
package logger
import (
"net"
"net/http"
"net/http/httputil"
"os"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/natefinch/lumberjack"
"github.com/spf13/viper"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
func Init() (err error) {
writer := getLoggerWriter(
viper.GetString("log.logName"),
viper.GetInt("log.max-size"),
viper.GetInt("log.max-age"),
viper.GetInt("log.max-backup"))
encoder := getEncoder()
var l = new(zapcore.Level)
err = l.UnmarshalText([]byte(viper.GetString("log.level")))
if err != nil {
return
}
core := zapcore.NewCore(encoder, writer, l)
logger := zap.New(core, zap.AddCaller())
//替换zap库中的全局log对象
zap.ReplaceGlobals(logger)
return
}
func getEncoder() zapcore.Encoder {
encoderConfig := zap.NewProductionEncoderConfig()
encoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
return zapcore.NewConsoleEncoder(encoderConfig)
}
func getLoggerWriter(filePath string, maxSize int, maxAge int, maxBackUp int) zapcore.WriteSyncer {
cfg := lumberjack.Logger{
Filename: filePath,
MaxSize: maxSize,
MaxBackups: maxBackUp,
MaxAge: maxAge,
Compress: false,
}
return zapcore.AddSync(&cfg)
}
//接受gin框架的默认日志
func GinLogger() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
path := c.Request.URL.Path
query := c.Request.URL.RawQuery
c.Next() //执行后面的中间件,然后计算cost
cost := time.Since(start)
zap.L().Info(path,
zap.Int("status", c.Writer.Status()),
zap.String("method", c.Request.Method),
zap.String("path", path),
zap.String("query", query),
zap.String("ip", c.ClientIP()),
zap.String("user-agent", c.Request.UserAgent()),
zap.String("errors", c.Errors.ByType(gin.ErrorTypePrivate).String()),
zap.Duration("cost", cost))
}
}
// GinRecovery recover掉项目可能出现的panic,并使用zap记录相关日志
func GinRecovery(stack bool) gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
// Check for a broken connection, as it is not really a
// condition that warrants a panic stack trace.
var brokenPipe bool
if ne, ok := err.(*net.OpError); ok {
if se, ok := ne.Err.(*os.SyscallError); ok {
if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") {
brokenPipe = true
}
}
}
httpRequest, _ := httputil.DumpRequest(c.Request, false)
if brokenPipe {
zap.L().Error(c.Request.URL.Path,
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
// If the connection is dead, we can't write a status to it.
c.Error(err.(error)) // nolint: errcheck
c.Abort()
return
}
if stack {
zap.L().Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
zap.String("stack", string(debug.Stack())),
)
} else {
zap.L().Error("[Recovery from panic]",
zap.Any("error", err),
zap.String("request", string(httpRequest)),
)
}
c.AbortWithStatus(http.StatusInternalServerError)
}
}()
c.Next()
}
}
|
route.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
package route
import (
"comman_web_structure/logger"
"net/http"
"github.com/gin-gonic/gin"
)
func SetUp() *gin.Engine {
r := gin.New()
r.Use(logger.GinLogger(), logger.GinRecovery(true))
r.GET("/", func(context *gin.Context) {
context.String(http.StatusOK, "hello")
})
return r
}
|
setting.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
package settings
import (
"fmt"
"github.com/fsnotify/fsnotify"
"github.com/spf13/viper"
)
func Init() (err error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./settings")
err = viper.ReadInConfig()
if err != nil {
fmt.Printf("read error,err:%v\n", err)
return
}
viper.WatchConfig()
viper.OnConfigChange(func(in fsnotify.Event) {
fmt.Printf("配置文件修改了..")
})
return
}
|
v2.0(将配置信息写道结构体)
优化后的settings
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
var Conf = new(AppConfig)
type AppConfig struct {
Name string `mapstructure:"name"`
Port string `mapstructure:"port"`
Mode string `mapstructure:"mode"`
*MysqlConfig `mapstructure:"mysql"`
*RedisConfig `mapstructure:"redis"`
*LoggerConfig `mapstructure:"log"`
}
type MysqlConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DbName string `mapstructure:"dbName"`
MaxOpen int `mapstructure:"maxOpen"`
MaxIdle int `mapstructure:"maxIdle"`
}
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
Db int `mapstructure:"db"`
PoolSize int `mapstructure:"poolSize"`
}
type LoggerConfig struct {
Level string `mapstructure:"level"`
LogName string `mapstructure:"logName"`
MaxSize int `mapstructure:"max-size"`
MaxAge int `mapstructure:"max-age"`
MaxBackup int `mapstructure:"max-backup"`
}
func Init() (err error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./settings")
err = viper.ReadInConfig()
if err != nil {
fmt.Printf("read error,err:%v\n", err)
return
}
if err = viper.Unmarshal(Conf); err != nil {
fmt.Printf("unMarshal error, err : %v\n", err)
}
viper.WatchConfig()
viper.OnConfigChange(func(in fsnotify.Event) {
fmt.Printf("配置文件修改了..")
})
return
}
|
v3.0(用命令行指定文件)
用系统内置的os
1
2
3
4
5
6
7
8
9
|
if len(os.Args) < 2 {
fmt.Println("need config file..")
return
}
//1.加载配置 通过args
if err := settings.Init(os.Args[1]); err != nil {
fmt.Printf("init settings error,err:%v", err)
return
}
|
用第三方库flag
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
|
func main() {
//定义命令行参数方式1
var name string
var age int
var married bool
var delay time.Duration
flag.StringVar(&name, "name", "张三", "姓名")
flag.IntVar(&age, "age", 18, "年龄")
flag.BoolVar(&married, "married", false, "婚否")
flag.DurationVar(&delay, "d", 0, "延迟的时间间隔")
//解析命令行参数
flag.Parse()
fmt.Println(name, age, married, delay)
//返回命令行参数后的其他参数
fmt.Println(flag.Args())
//返回命令行参数后的其他参数个数
fmt.Println(flag.NArg())
//返回使用的命令行参数个数
fmt.Println(flag.NFlag())
}
|
项目搭建
分布式ID
雪花算法
如何使用
1
|
go get github.com/bwmarrin/snowflake
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
func Init(startTime string, machineId int64) (err error) {
var t time.Time
//先将当前时间格式转换
t, err = time.Parse("2006-01-02", startTime)
if err != nil {
return
}
snowflake.Epoch = t.UnixNano() / 100000
node, err = snowflake.NewNode(machineId)
return
}
func getId() int64 {
return node.Generate().Int64()
}
|
注册逻辑
在controller层,一般做参数校验已经通过service进行业务处理最后返回。
普通的参数判断
1
2
3
4
5
6
|
if len(par.Username) == 0 || len(par.Password) == 0 || len(par.RePassword) == 0 || par.RePassword != par.Password {
zap.L().Error("Signup with invalid param")
context.JSON(http.StatusOK, gin.H{
"msg": "请求参数有误!",
})
}
|
用第三方库进行参数校验
在web开发中一个不可避免的环节就是对请求参数进行校验,通常我们会在代码中定义与请求参数相对应的模型(结构体),借助模型绑定快捷地解析请求中的参数,例如 gin 框架中的Bind
和ShouldBind
系列方法。本文就以 gin 框架的请求参数校验为例,介绍一些validator
库的实用技巧。
gin框架使用github.com/go-playground/validator进行参数校验,目前已经支持github.com/go-playground/validator/v10
了,我们需要在定义结构体时使用 binding
tag标识相关校验规则,可以查看validator文档查看支持的所有 tag。
用一个简单的binding即可校验
validator返回错误信息的国际化
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
package controller
import (
"fmt"
"reflect"
"strings"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/en"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
enTranslations "github.com/go-playground/validator/v10/translations/en"
zhTranslations "github.com/go-playground/validator/v10/translations/zh"
)
// 定义一个全局翻译器T
var trans ut.Translator
// InitTrans 初始化翻译器
func InitTrans(locale string) (err error) {
// 修改gin框架中的Validator引擎属性,实现自定制
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
// 注册一个获取json tag的自定义方法
v.RegisterTagNameFunc(func(fld reflect.StructField) string {
name := strings.SplitN(fld.Tag.Get("json"), ",", 2)[0]
if name == "-" {
return ""
}
return name
})
zhT := zh.New() // 中文翻译器
enT := en.New() // 英文翻译器
// 第一个参数是备用(fallback)的语言环境
// 后面的参数是应该支持的语言环境(支持多个)
// uni := ut.New(zhT, zhT) 也是可以的
uni := ut.New(enT, zhT, enT)
// locale 通常取决于 http 请求头的 'Accept-Language'
var ok bool
// 也可以使用 uni.FindTranslator(...) 传入多个locale进行查找
trans, ok = uni.GetTranslator(locale)
if !ok {
return fmt.Errorf("uni.GetTranslator(%s) failed", locale)
}
// 注册翻译器
switch locale {
case "en":
err = enTranslations.RegisterDefaultTranslations(v, trans)
case "zh":
err = zhTranslations.RegisterDefaultTranslations(v, trans)
default:
err = enTranslations.RegisterDefaultTranslations(v, trans)
}
return
}
return
}
//去除提示信息的结构体
func removeTopStruct(fields map[string]string) map[string]string {
res := map[string]string{}
for field, err := range fields {
res[field[strings.Index(field, ".")+1:]] = err
}
return res
}
|
日志输出到控制台和文件(dev环境下方便调试)
1
2
3
4
5
6
7
8
9
10
11
|
if settings.Conf.Mode == "dev" {
consoleEncoder := zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig())
//通过tee,得到了两个输出
core = zapcore.NewTee(
zapcore.NewCore(encoder, writer, l),
//向终端输出
zapcore.NewCore(consoleEncoder, zapcore.Lock(os.Stdout), zapcore.DebugLevel),
)
} else {
core = zapcore.NewCore(encoder, writer, l)
}
|
通过配置文件的mode,获得当前的开发环境,然后如果是dev,就通过tee得到两个core,一个输出到文件,一个输出到控制台,方便调试。
封装返回的json
我截取一段 现在的代码,大家可以发现返回的时候都要调用context.json很麻烦。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
func LoginUser(context *gin.Context) {
var p = new(model.LoginParam)
if err := context.ShouldBindJSON(p); err != nil {
errors, ok := err.(validator.ValidationErrors)
zap.L().Error("login with invalid json", zap.Error(err))
if ok {
context.JSON(http.StatusOK, gin.H{
"msg": removeTopStruct(errors.Translate(trans)),
})
} else {
context.JSON(http.StatusOK, gin.H{
"msg": err.Error(),
})
}
return
}
if err := service.Login(p); err != nil {
zap.L().Error("login error", zap.Error(err))
context.JSON(http.StatusOK, gin.H{
"msg": "用户名或者密码错误",
})
return
}
context.JSON(http.StatusOK, gin.H{
"msg": "登陆成功",
})
}
|
我们可以将返回响应封装为一个函数。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
package controller
import (
"net/http"
"github.com/gin-gonic/gin"
)
type Response struct {
Code ResCode `json:"code"`
Msg interface{} `json:"msg"`
Data interface{} `json:"data"`
}
func ResponseError(code ResCode, c *gin.Context) {
c.JSON(http.StatusOK, &Response{
Code: code,
Msg: Msg(code),
Data: nil,
})
}
func ResponseErrorWithMsg(code ResCode, c *gin.Context, msg interface{}) {
c.JSON(http.StatusOK, &Response{
Code: code,
Msg: msg,
Data: nil,
})
}
func ResponseSuccess(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, &Response{
Code: CodeSuccess,
Msg: Msg(CodeSuccess),
Data: data,
})
}
|
对应的Code类型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
|
package controller
type ResCode int64
const (
CodeSuccess ResCode = 1000 + iota
CodeInvalidParam
CodeUserExist
CodeUserNotExist
CodeInvalidPassword
CodeServeBusy
)
var codeMap = map[ResCode]string{
CodeSuccess: "请求成功",
CodeInvalidParam: "请求参数异常",
CodeUserExist: "用户已存在",
CodeUserNotExist: "用户不存在",
CodeInvalidPassword: "密码错误",
CodeServeBusy: "服务器繁忙",
}
func Msg(code ResCode) string {
s, ok := codeMap[code]
if !ok {
return codeMap[CodeServeBusy]
} else {
return s
}
}
|
将errors.New封装成对象,方便controller判断
这些错误我们的controller不好判断,最好把它换成下图这样
改造后,变得优雅了许多
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
|
package controller
import (
"comman_web_structure/dao/mysql"
"comman_web_structure/model"
"comman_web_structure/service"
"errors"
"github.com/go-playground/validator/v10"
"go.uber.org/zap"
"github.com/gin-gonic/gin"
)
func SingUpUser(context *gin.Context) {
var par model.SignUpParam
if err := context.ShouldBindJSON(&par); err != nil {
zap.L().Error("Signup with invalid json", zap.Error(err))
errors, ok := err.(validator.ValidationErrors)
if !ok {
ResponseError(CodeInvalidParam, context)
return
} else {
ResponseErrorWithMsg(CodeInvalidParam, context, removeTopStruct(errors.Translate(trans)))
}
return
}
if err := service.SignUp(&par); err != nil {
zap.L().Error("Signup error", zap.Error(err))
if errors.Is(err, mysql.ErrorUserExist) {
ResponseError(CodeUserExist, context)
return
}
ResponseError(CodeServeBusy, context)
return
}
ResponseSuccess(context, nil)
}
func LoginUser(context *gin.Context) {
var p = new(model.LoginParam)
if err := context.ShouldBindJSON(p); err != nil {
errors, ok := err.(validator.ValidationErrors)
zap.L().Error("login with invalid json", zap.Error(err))
if ok {
ResponseErrorWithMsg(CodeInvalidParam, context, removeTopStruct(errors.Translate(trans)))
} else {
ResponseError(CodeInvalidParam, context)
}
return
}
if err := service.Login(p); err != nil {
zap.L().Error("login error", zap.Error(err))
if errors.Is(err, mysql.ErrorUserNotExist) {
ResponseError(CodeUserNotExist, context)
} else if errors.Is(err, mysql.ErrorUserPasswordWrong) {
ResponseError(CodeInvalidPassword, context)
} else {
ResponseError(CodeServeBusy, context)
}
return
}
ResponseSuccess(context,nil)
}
|
利用Jwt的token进行登录
一般还会有另一种中方案:Cookie + Session,但是会有一点风险,因为会有安全的危险,假如有一个钓鱼网站得到了你的cookie,就会发生安全问题。并且session是需要存储在服务器端的。
1
|
go get -u github.com/dgrijalva/jwt-go
|
通过token进行校验
校验方法我们会放在中间件中,若某个接口需要登陆,可以直接将中间键的方法放到路由里面即可。
方法大概是取出头部的token,对token校验,若出错则返回错误,否则会将用户的id放到context中。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
func JWTAuthMiddleware() func(c *gin.Context) {
return func(c *gin.Context) {
// 客户端携带Token有三种方式 1.放在请求头 2.放在请求体 3.放在URI
// 这里假设Token放在Header的Authorization中,并使用Bearer开头
// 这里的具体实现方式要依据你的实际业务情况决定
authHeader := c.Request.Header.Get("Authorization")
if authHeader == "" {
c.JSON(http.StatusOK, gin.H{
"code": 2003,
"msg": "请求头中auth为空",
})
c.Abort()
return
}
// 按空格分割
parts := strings.SplitN(authHeader, " ", 2)
if !(len(parts) == 2 && parts[0] == "Bearer") {
controller.ResponseError(controller.CodeEmptyRequestHeader, c)
c.Abort()
return
}
// parts[1]是获取到的tokenString,我们使用之前定义好的解析JWT的函数来解析它
mc, err := jwt.ParseToken(parts[1])
if err != nil {
controller.ResponseError(controller.CodeInvalidAuth, c)
c.Abort()
return
}
// 将当前请求的username信息保存到请求的上下文c上
c.Set(ContextUserName, mc.Username)
c.Set(ContextUserID, mc.UserId)
c.Next() // 后续的处理函数可以用过c.Get("username")来获取当前请求的用户信息
}
}
|
但是现在我去go build,他却告诉我有循环引用,啥情况?
首先分析一下,下面的日志告诉我们是controller和middle发生循环引用了。
request是在controller层的,他用于解析出request头部中头部所带的payload,再看看鉴权方法是存在与middle包的。
明确了这两点来看看代码
(middle->controller)
(controller->middle)
解决方法很简单,我们不要让controller引用middle的id就可以了,把id放在controller层
现在就变成了controlle不引用,middle->controller,不存在循环引用。
结果:
refresh token
以前在java项目也用过,双令牌,即时间长的令牌如果存在,会给用户一个新的access token,如果不存在那只能重新登录,然后去申请双令牌了。
如何限制一台电脑只能登录一个账号
首先,我们知道当我们发布token的时候,一定一定,账号登录是没有问题的,那么我们如何来保存token的唯一性呢?(一个账号只能登录一个设备,换句话说同一时间只能存在一个有效token)
我们可以将token存到redis,并且存储userId和token对应,由于redis的string,key是一定要唯一的,我们可以拿userId为key,token为val,在调用接口进行认证的时候,先取出token中的userId,再去redis里面取出token,进行对比,如果token不一致说明该账号有两个不同机器正在登陆,以redis为准即可。
为go项目编写make file
| Golang
|总阅读量:11560次
借助Makefile
我们在编译过程中不再需要每次手动输入编译的命令和编译的参数,可以极大简化项目编译过程。
make介绍
make
是一个构建自动化工具,会在当前目录下寻找Makefile
或makefile
文件。如果存在相应的文件,它就会依据其中定义好的规则完成构建任务。
Makefile介绍
我们可以把Makefile
简单理解为它定义了一个项目文件的编译规则。借助Makefile
我们在编译过程中不再需要每次手动输入编译的命令和编译的参数,可以极大简化项目编译过程。同时使用Makefile
也可以在项目中确定具体的编译规则和流程,很多开源项目中都会定义Makefile
文件。
本文不会详细介绍Makefile
的各种规则,只会给出Go项目中常用的Makefile
示例。关于Makefile
的详细内容推荐阅读Makefile教程。
规则概述
Makefile
由多条规则组成,每条规则主要由两个部分组成,分别是依赖的关系和执行的命令。
其结构如下所示:
1
2
3
4
|
[target] ... : [prerequisites] ...
<tab>[command]
...
...
|
其中:
- targets:规则的目标
- prerequisites:可选的要生成 targets 需要的文件或者是目标。
- command:make 需要执行的命令(任意的 shell 命令)。可以有多条命令,每一条命令占一行。
举个例子:
1
2
|
build:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o xx
|
示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
.PHONY: all build run gotool clean help
BINARY="bluebell"
all: gotool build
build:
CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o ${BINARY}
run:
@go run ./
gotool:
go fmt ./
go vet ./
clean:
@if [ -f ${BINARY} ] ; then rm ${BINARY} ; fi
help:
@echo "make - 格式化 Go 代码, 并编译生成二进制文件"
@echo "make build - 编译 Go 代码, 生成二进制文件"
@echo "make run - 直接运行 Go 代码"
@echo "make clean - 移除二进制文件和 vim swap files"
@echo "make gotool - 运行 Go 工具 'fmt' and 'vet'"
|
其中:
BINARY="bluebell"
是定义变量。
.PHONY
用来定义伪目标。不创建目标文件,而是去执行这个目标下面的命令。
air(大大提高开发效率)
go get -u github.com/cosmtrek/air
Go内存对齐
这两个结构体虽然属性一样,但是最后new出来一个对象,占用内存是不一样的,因为内存对齐。
1
2
3
4
5
6
7
8
9
10
11
|
type s1 struct{
i1 int8
i2 int64
i3 int32
}
type s1 struct{
i1 int8
i3 int32
i2 int64
}
|
对查询结果的优化
当我第一次编写接口,发现有一个问题,当我返回帖子的细节的时候,由于数据库及采用的外键关联,所以返回的数据并不是作者名字,而是作者ID,这样非常不友好,我想把关联信息也一并返回。
遇到的问题:返回整个结果的时候发生空指针。。(粗心了)
可以看到p是没有初始化的,所以我们new一下就好了。查出来就没啥毛病
这里会发现,数据其实有点挤在一起,看着不舒服,我们想把社区的放一边,作者放一边,怎么办呢?
改造一下
舒服多了
分页查询
从前端后的页码和size即可。
1
2
3
4
5
6
|
func GetPostList(offset int64, size int64) (list []*model.Post, err error) {
list = make([]*model.Post, 0, size)
sqlStr := "select post_id,author_id,community_id,title,content,create_time from post limit ?,?"
err = db.Select(&list, sqlStr, offset, size)
return list, err
}
|
后端用select查询数据库封装到切片里面返回就好了。
其实我们可能会有很多需要分页的项目,所以我们为了避免重复书写代码,可以将获得页码的代码给抽出来。放到Request.go中
解决id发送给前端失真的问题
由于go后端这边的id用的为int64类型,数据大小在2^-63 + 1~2^63 - 1,但是JS的Number数据类型的数据大小会在2^-53 + 1~2^53 - 1.会导致数据精度有问题。
我们可以将这个值变成string类型传输即可。
但是由于前端会传给我们后端一个字符串,如果结构体仍然用以前的结构体,如下
1
2
3
4
|
type Community struct {
ID int64 `json:"community_id"`
Name string `json:"community_name"`
}
|
是无法调用Unmarshal的,因为前端发的是string,go这边用的int64,我们稍微修改一下就可以了
1
2
3
4
|
type Community struct {
ID int64 `json:"community_id,string"`
Name string `json:"community_name"`
}
|
更多golang json的用法:https://www.liwenzhou.com/posts/Go/json_tricks_in_go/
投票功能
我们在写投票功能之前,首先明确我们想实现什么功能,首先得记录每个用户投的是赞成票还是反对票,其次我们想实现一个关于时间和,热度的排行榜。
我们可以利用redis中的Zset来实现,对于记录用户我们可以为每一个帖子做一个zset,key是某一个key+帖子ID,里面存储的是userId对应是否投票(-1:投反对,1投赞成)
对于排行榜就很简单,直接对应的帖子给上响应的权值即可。
我们如何计算热度呢,我们这里比较简单,赞成票就+432,反对就-432。
将可能发生的情况穷举出来:
db:1 request:1 0 abs = 0
db:1 request:0 -432 abs = 1
db:1 request:-1 -432 * 2 abs = 2
db:0 request:1 432 abs = 1
db:0 request:0 0 abs = 0
db:0 request:-1 -432 abs = 1
db:-1 request:1 + 432* 2 abs = 2
db:-1 request:-1 0 abs = 0
db:-1 request:0 + 432 abs = 0
我们可以通过abs来判断当前应该改变多少积分
可以看到当abs=0积分不变,abs=1积分变化432(+或者-),当abs=2积分变化2*432(+或者-)
又可以发现当request > 0 就是+,否则为-
现在逻辑就很清晰。
那么问题来了,一开始我们好像并没有排行榜所需的那两个zset,我们需要创建zset
那再看,其实这应该是一个原子性的操作,不能一个成功一个失败,所以应该放到事务里面,上面曾讲过事务pipeline,太完美了。
会看到dao层的投票功能,里面会对用户投票记录和分数记录,这两个也是应该是原子性的
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
func VoteForPost(userId, postId string, flag float64) (err error) {
score, err := rdb.ZScore(getKey(TimeZSet), postId).Result()
if err != nil {
return err
}
if float64(time.Now().Unix())-score > weekUnix {
return ErrorTimeExpire
}
//得到该用户在redis中是否投票过
score = rdb.ZScore(getKey(FlagZSet+postId), userId).Val()
dif := math.Abs(score - flag)
pipeline := rdb.TxPipeline()
if score < flag {
pipeline.IncrBy(getKey(ScoreZSet), int64(dif*1*perFlagScore))
} else {
pipeline.IncrBy(getKey(ScoreZSet), int64(dif*-1*perFlagScore))
}
//更新记录用户投票的zset
if flag == 0 {
pipeline.ZRem(getKey(FlagZSet+postId), userId)
} else {
pipeline.ZAdd(getKey(FlagZSet+postId), redis.Z{
Score: flag,
Member: userId,
})
}
_, err = pipeline.Exec()
return
}
|
swagger接口文档
如何得到:
go get -u github.com/swaggo/swag/cmd/swag
想要使用gin-swagger
为你的代码自动生成接口文档,一般需要下面三个步骤:
- 按照swagger要求给接口代码添加声明式注释,具体参照声明式注释格式。
- 使用swag工具扫描代码自动生成API接口文档数据
- 使用gin-swagger渲染在线接口文档页面
第一步:添加注释
在程序入口main函数上以注释的方式写下项目相关介绍信息。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
package main
// @title 这里写标题
// @version 1.0
// @description 这里写描述信息
// @termsOfService http://swagger.io/terms/
// @contact.name 这里写联系人信息
// @contact.url http://www.swagger.io/support
// @contact.email support@swagger.io
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @host 这里写接口服务的host
// @BasePath 这里写base path
func main() {
r := gin.New()
// liwenzhou.com ...
r.Run()
}
|
在你代码中处理请求的接口函数(通常位于controller层)按如下方式写上注释:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
|
// GetPostListHandler2 升级版帖子列表接口
// @Summary 升级版帖子列表接口
// @Description 可按社区按时间或分数排序查询帖子列表接口
// @Tags 帖子相关接口
// @Accept application/json
// @Produce application/json
// @Param Authorization header string false "Bearer 用户令牌"
// @Param object query models.ParamPostList false "查询参数"
// @Security ApiKeyAuth
// @Success 200 {object} _ResponsePostList
// @Router /posts2 [get]
func GetPostListHandler2(c *gin.Context) {
// GET请求参数(query string):/api/v1/posts2?page=1&size=10&order=time
// 初始化结构体时指定初始参数
p := &models.ParamPostList{
Page: 1,
Size: 10,
Order: models.OrderTime,
}
if err := c.ShouldBindQuery(p); err != nil {
zap.L().Error("GetPostListHandler2 with invalid params", zap.Error(err))
ResponseError(c, CodeInvalidParam)
return
}
data, err := logic.GetPostListNew(p)
// 获取数据
if err != nil {
zap.L().Error("logic.GetPostList() failed", zap.Error(err))
ResponseError(c, CodeServerBusy)
return
}
ResponseSuccess(c, data)
// 返回响应
}
|
上面注释中参数类型使用了object
,models.ParamPostList
具体定义如下:
1
2
3
4
5
6
7
8
9
|
// bluebell/models/params.go
// ParamPostList 获取帖子列表query string参数
type ParamPostList struct {
CommunityID int64 `json:"community_id" form:"community_id"` // 可以为空
Page int64 `json:"page" form:"page" example:"1"` // 页码
Size int64 `json:"size" form:"size" example:"10"` // 每页数据量
Order string `json:"order" form:"order" example:"score"` // 排序依据
}
|
响应数据类型也使用的object
,我个人习惯在controller层专门定义一个docs_models.go
文件来存储文档中使用的响应数据model。
1
2
3
4
5
6
7
8
|
// bluebell/controller/docs_models.go
// _ResponsePostList 帖子列表接口响应数据
type _ResponsePostList struct {
Code ResCode `json:"code"` // 业务响应状态码
Message string `json:"message"` // 提示信息
Data []*models.ApiPostDetail `json:"data"` // 数据
}
|
第二步:生成接口文档数据
编写完注释后,使用以下命令安装swag工具:
1
|
go get -u github.com/swaggo/swag/cmd/swag
|
在项目根目录执行以下命令,使用swag工具生成接口文档数据。
执行完上述命令后,如果你写的注释格式没问题,此时你的项目根目录下会多出一个docs
文件夹。
1
2
3
4
|
./docs
├── docs.go
├── swagger.json
└── swagger.yaml
|
第三步:引入gin-swagger渲染文档数据
然后在项目代码中注册路由的地方按如下方式引入gin-swagger
相关内容:
1
2
3
4
5
6
7
8
9
10
|
import (
// liwenzhou.com ...
_ "bluebell/docs" // 千万不要忘了导入把你上一步生成的docs
gs "github.com/swaggo/gin-swagger"
"github.com/swaggo/gin-swagger/swaggerFiles"
"github.com/gin-gonic/gin"
)
|
注册swagger api相关路由
1
|
r.GET("/swagger/*any", gs.WrapHandler(swaggerFiles.Handler))
|
把你的项目程序运行起来,打开浏览器访问http://localhost:8080/swagger/index.html就能看到Swagger 2.0 Api文档了。
gin-swagger
同时还提供了DisablingWrapHandler
函数,方便我们通过设置某些环境变量来禁用Swagger。例如:
1
|
r.GET("/swagger/*any", gs.DisablingWrapHandler(swaggerFiles.Handler, "NAME_OF_ENV_VARIABLE"))
|
此时如果将环境变量NAME_OF_ENV_VARIABLE
设置为任意值,则/swagger/*any
将返回404响应,就像未指定路由时一样。
这是Go语言单元测试从零到溜系列教程的第0篇,主要讲解在Go语言中如何编写单元测试以及介绍了表格驱动测试、回归测试和单元测试中常用的断言工具。
Go语言单元测试从零到溜系列共7篇,本文是第0篇,介绍了Go语言单元测试的基础内容。本篇部分内容基于我之前写过的那篇《Go语言基础之单元测试》,内容略有删改。特别是由于篇幅限制移除了基准测试相关内容,想了解基准测试/性能测试的同学可以点击上文链接查看。
《Go单测从零到溜系列》的示例代码已上传至Github,点击👉🏻https://github.com/Q1mi/golang-unit-test-demo 查看完整源代码。
Go语言单元测试
go test工具
Go语言中的测试依赖go test
命令。编写测试代码和编写普通的Go代码过程是类似的,并不需要学习新的语法、规则或工具。
go test命令是一个按照一定约定和组织的测试代码的驱动程序。在包目录内,所有以_test.go
为后缀名的源代码文件都是go test
测试的一部分,不会被go build
编译到最终的可执行文件中。
在*_test.go
文件中有三种类型的函数,单元测试函数、基准测试函数和示例函数。
类型 |
格式 |
作用 |
测试函数 |
函数名前缀为Test |
测试程序的一些逻辑行为是否正确 |
基准函数 |
函数名前缀为Benchmark |
测试函数的性能 |
示例函数 |
函数名前缀为Example |
为文档提供示例文档 |
go test
命令会遍历所有的*_test.go
文件中符合上述命名规则的函数,然后生成一个临时的main包用于调用相应的测试函数,然后构建并运行、报告测试结果,最后清理测试中生成的临时文件。
单元测试函数
格式
每个测试函数必须导入testing
包,测试函数的基本格式(签名)如下:
1
2
3
|
func TestName(t *testing.T){
// ...
}
|
测试函数的名字必须以Test
开头,可选的后缀名必须以大写字母开头,举几个例子:
1
2
3
|
func TestAdd(t *testing.T){ ... }
func TestSum(t *testing.T){ ... }
func TestLog(t *testing.T){ ... }
|
其中参数t
用于报告测试失败和附加的日志信息。 testing.T
的拥有的方法如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
func (c *T) Cleanup(func())
func (c *T) Error(args ...interface{})
func (c *T) Errorf(format string, args ...interface{})
func (c *T) Fail()
func (c *T) FailNow()
func (c *T) Failed() bool
func (c *T) Fatal(args ...interface{})
func (c *T) Fatalf(format string, args ...interface{})
func (c *T) Helper()
func (c *T) Log(args ...interface{})
func (c *T) Logf(format string, args ...interface{})
func (c *T) Name() string
func (c *T) Skip(args ...interface{})
func (c *T) SkipNow()
func (c *T) Skipf(format string, args ...interface{})
func (c *T) Skipped() bool
func (c *T) TempDir() string
|
单元测试示例
就像细胞是构成我们身体的基本单位,一个软件程序也是由很多单元组件构成的。单元组件可以是函数、结构体、方法和最终用户可能依赖的任意东西。总之我们需要确保这些组件是能够正常运行的。单元测试是一些利用各种方法测试单元组件的程序,它会将结果与预期输出进行比较。
接下来,我们在base_demo
包中定义了一个Split
函数,具体实现如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
|
// base_demo/split.go
package base_demo
import "strings"
// Split 把字符串s按照给定的分隔符sep进行分割返回字符串切片
func Split(s, sep string) (result []string) {
i := strings.Index(s, sep)
for i > -1 {
result = append(result, s[:i])
s = s[i+1:]
i = strings.Index(s, sep)
}
result = append(result, s)
return
}
|
在当前目录下,我们创建一个split_test.go
的测试文件,并定义一个测试函数如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
// split/split_test.go
package split
import (
"reflect"
"testing"
)
func TestSplit(t *testing.T) { // 测试函数名必须以Test开头,必须接收一个*testing.T类型参数
got := Split("a:b:c", ":") // 程序输出的结果
want := []string{"a", "b", "c"} // 期望的结果
if !reflect.DeepEqual(want, got) { // 因为slice不能比较直接,借助反射包中的方法比较
t.Errorf("expected:%v, got:%v", want, got) // 测试失败输出错误提示
}
}
|
此时split
这个包中的文件如下:
1
2
3
4
|
❯ ls -l
total 16
-rw-r--r-- 1 liwenzhou staff 408 4 29 15:50 split.go
-rw-r--r-- 1 liwenzhou staff 466 4 29 16:04 split_test.go
|
在当前路径下执行go test
命令,可以看到输出结果如下:
1
2
3
|
❯ go test
PASS
ok golang-unit-test-demo/base_demo 0.005s
|
go test -v
一个测试用例有点单薄,我们再编写一个测试使用多个字符切割字符串的例子,在split_test.go
中添加如下测试函数:
1
2
3
4
5
6
7
|
func TestSplitWithComplexSep(t *testing.T) {
got := Split("abcd", "bc")
want := []string{"a", "d"}
if !reflect.DeepEqual(want, got) {
t.Errorf("expected:%v, got:%v", want, got)
}
}
|
现在我们有多个测试用例了,为了能更好的在输出结果中看到每个测试用例的执行情况,我们可以为go test
命令添加-v
参数,让它输出完整的测试结果。
1
2
3
4
5
6
7
8
9
|
❯ go test -v
=== RUN TestSplit
--- PASS: TestSplit (0.00s)
=== RUN TestSplitWithComplexSep
split_test.go:20: expected:[a d], got:[a cd]
--- FAIL: TestSplitWithComplexSep (0.00s)
FAIL
exit status 1
FAIL golang-unit-test-demo/base_demo 0.009s
|
从上面的输出结果我们能清楚的看到是TestSplitWithComplexSep
这个测试用例没有测试通过。
go test -run
单元测试的结果表明split
函数的实现并不可靠,没有考虑到传入的sep参数是多个字符的情况,下面我们来修复下这个Bug:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
|
package base_demo
import "strings"
// Split 把字符串s按照给定的分隔符sep进行分割返回字符串切片
func Split(s, sep string) (result []string) {
i := strings.Index(s, sep)
for i > -1 {
result = append(result, s[:i])
s = s[i+len(sep):] // 这里使用len(sep)获取sep的长度
i = strings.Index(s, sep)
}
result = append(result, s)
return
}
|
在执行go test
命令的时候可以添加-run
参数,它对应一个正则表达式,只有函数名匹配上的测试函数才会被go test
命令执行。
例如通过给go test
添加-run=Sep
参数来告诉它本次测试只运行TestSplitWithComplexSep
这个测试用例:
1
2
3
4
5
|
❯ go test -run=Sep -v
=== RUN TestSplitWithComplexSep
--- PASS: TestSplitWithComplexSep (0.00s)
PASS
ok golang-unit-test-demo/base_demo 0.010s
|
最终的测试结果表情我们成功修复了之前的Bug。
回归测试
我们修改了代码之后仅仅执行那些失败的测试用例或新引入的测试用例是错误且危险的,正确的做法应该是完整运行所有的测试用例,保证不会因为修改代码而引入新的问题。
1
2
3
4
5
6
7
|
❯ go test -v
=== RUN TestSplit
--- PASS: TestSplit (0.00s)
=== RUN TestSplitWithComplexSep
--- PASS: TestSplitWithComplexSep (0.00s)
PASS
ok golang-unit-test-demo/base_demo 0.011s
|
测试结果表明我们的单元测试全部通过。
通过这个示例我们可以看到,有了单元测试就能够在代码改动后快速进行回归测试,极大地提高开发效率并保证代码的质量。
跳过某些测试用例
为了节省时间支持在单元测试时跳过某些耗时的测试用例。
1
2
3
4
5
6
|
func TestTimeConsuming(t *testing.T) {
if testing.Short() {
t.Skip("short模式下会跳过该测试用例")
}
...
}
|
当执行go test -short
时就不会执行上面的TestTimeConsuming
测试用例。
子测试
在上面的示例中我们为每一个测试数据编写了一个测试函数,而通常单元测试中需要多组测试数据保证测试的效果。Go1.7+中新增了子测试,支持在测试函数中使用t.Run
执行一组测试用例,这样就不需要为不同的测试数据定义多个测试函数了。
1
2
3
4
5
|
func TestXXX(t *testing.T){
t.Run("case1", func(t *testing.T){...})
t.Run("case2", func(t *testing.T){...})
t.Run("case3", func(t *testing.T){...})
}
|
表格驱动测试
介绍
表格驱动测试不是工具、包或其他任何东西,它只是编写更清晰测试的一种方式和视角。
编写好的测试并非易事,但在许多情况下,表格驱动测试可以涵盖很多方面:表格里的每一个条目都是一个完整的测试用例,包含输入和预期结果,有时还包含测试名称等附加信息,以使测试输出易于阅读。
使用表格驱动测试能够很方便的维护多个测试用例,避免在编写单元测试时频繁的复制粘贴。
表格驱动测试的步骤通常是定义一个测试用例表格,然后遍历表格,并使用t.Run
对每个条目执行必要的测试。
示例
官方标准库中有很多表格驱动测试的示例,例如fmt包中便有如下测试代码:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
var flagtests = []struct {
in string
out string
}{
{"%a", "[%a]"},
{"%-a", "[%-a]"},
{"%+a", "[%+a]"},
{"%#a", "[%#a]"},
{"% a", "[% a]"},
{"%0a", "[%0a]"},
{"%1.2a", "[%1.2a]"},
{"%-1.2a", "[%-1.2a]"},
{"%+1.2a", "[%+1.2a]"},
{"%-+1.2a", "[%+-1.2a]"},
{"%-+1.2abc", "[%+-1.2a]bc"},
{"%-1.2abc", "[%-1.2a]bc"},
}
func TestFlagParser(t *testing.T) {
var flagprinter flagPrinter
for _, tt := range flagtests {
t.Run(tt.in, func(t *testing.T) {
s := Sprintf(tt.in, &flagprinter)
if s != tt.out {
t.Errorf("got %q, want %q", s, tt.out)
}
})
}
}
|
通常表格是匿名结构体切片,可以定义结构体或使用已经存在的结构进行结构体数组声明。name属性用来描述特定的测试用例。
接下来让我们试着自己编写表格驱动测试:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
func TestSplitAll(t *testing.T) {
// 定义测试表格
// 这里使用匿名结构体定义了若干个测试用例
// 并且为每个测试用例设置了一个名称
tests := []struct {
name string
input string
sep string
want []string
}{
{"base case", "a:b:c", ":", []string{"a", "b", "c"}},
{"wrong sep", "a:b:c", ",", []string{"a:b:c"}},
{"more sep", "abcd", "bc", []string{"a", "d"}},
{"leading sep", "沙河有沙又有河", "沙", []string{"", "河有", "又有河"}},
}
// 遍历测试用例
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { // 使用t.Run()执行子测试
got := Split(tt.input, tt.sep)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expected:%#v, got:%#v", tt.want, got)
}
})
}
}
|
在终端执行go test -v
,会得到如下测试输出结果:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
|
❯ go test -v
=== RUN TestSplit
--- PASS: TestSplit (0.00s)
=== RUN TestSplitWithComplexSep
--- PASS: TestSplitWithComplexSep (0.00s)
=== RUN TestSplitAll
=== RUN TestSplitAll/base_case
=== RUN TestSplitAll/wrong_sep
=== RUN TestSplitAll/more_sep
=== RUN TestSplitAll/leading_sep
--- PASS: TestSplitAll (0.00s)
--- PASS: TestSplitAll/base_case (0.00s)
--- PASS: TestSplitAll/wrong_sep (0.00s)
--- PASS: TestSplitAll/more_sep (0.00s)
--- PASS: TestSplitAll/leading_sep (0.00s)
PASS
ok golang-unit-test-demo/base_demo 0.010s
|
并行测试
表格驱动测试中通常会定义比较多的测试用例,而Go语言又天生支持并发,所以很容易发挥自身并发优势将表格驱动测试并行化。 想要在单元测试过程中使用并行测试,可以像下面的代码示例中那样通过添加t.Parallel()
来实现。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
|
func TestSplitAll(t *testing.T) {
t.Parallel() // 将 TLog 标记为能够与其他测试并行运行
// 定义测试表格
// 这里使用匿名结构体定义了若干个测试用例
// 并且为每个测试用例设置了一个名称
tests := []struct {
name string
input string
sep string
want []string
}{
{"base case", "a:b:c", ":", []string{"a", "b", "c"}},
{"wrong sep", "a:b:c", ",", []string{"a:b:c"}},
{"more sep", "abcd", "bc", []string{"a", "d"}},
{"leading sep", "沙河有沙又有河", "沙", []string{"", "河有", "又有河"}},
}
// 遍历测试用例
for _, tt := range tests {
tt := tt // 注意这里重新声明tt变量(避免多个goroutine中使用了相同的变量)
t.Run(tt.name, func(t *testing.T) { // 使用t.Run()执行子测试
t.Parallel() // 将每个测试用例标记为能够彼此并行运行
got := Split(tt.input, tt.sep)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expected:%#v, got:%#v", tt.want, got)
}
})
}
}
|
这样我们执行go test -v
的时候就会看到每个测试用例并不是按照我们定义的顺序执行,而是互相并行了。
使用工具生成测试代码
社区里有很多自动生成表格驱动测试函数的工具,比如gotests等,很多编辑器如Goland也支持快速生成测试文件。这里简单演示一下gotests
的使用。
安装
1
|
go get -u github.com/cweill/gotests/...
|
执行
1
|
gotests -all -w split.go
|
上面的命令表示,为split.go
文件的所有函数生成测试代码至split_test.go
文件(目录下如果事先存在这个文件就不再生成)。
生成的测试代码大致如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
|
package base_demo
import (
"reflect"
"testing"
)
func TestSplit(t *testing.T) {
type args struct {
s string
sep string
}
tests := []struct {
name string
args args
wantResult []string
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if gotResult := Split(tt.args.s, tt.args.sep); !reflect.DeepEqual(gotResult, tt.wantResult) {
t.Errorf("Split() = %v, want %v", gotResult, tt.wantResult)
}
})
}
}
|
代码格式与我们上面的类似,只需要在TODO位置添加我们的测试逻辑就可以了。
测试覆盖率
测试覆盖率是指代码被测试套件覆盖的百分比。通常我们使用的都是语句的覆盖率,也就是在测试中至少被运行一次的代码占总代码的比例。在公司内部一般会要求测试覆盖率达到80%左右。
Go提供内置功能来检查你的代码覆盖率,即使用go test -cover
来查看测试覆盖率。
1
2
3
4
|
❯ go test -cover
PASS
coverage: 100.0% of statements
ok golang-unit-test-demo/base_demo 0.009s
|
从上面的结果可以看到我们的测试用例覆盖了100%的代码。
Go还提供了一个额外的-coverprofile
参数,用来将覆盖率相关的记录信息输出到一个文件。例如:
1
2
3
4
|
❯ go test -cover -coverprofile=c.out
PASS
coverage: 100.0% of statements
ok golang-unit-test-demo/base_demo 0.009s
|
上面的命令会将覆盖率相关的信息输出到当前文件夹下面的c.out
文件中。
1
2
3
4
5
|
❯ tree .
.
├── c.out
├── split.go
└── split_test.go
|
然后我们执行go tool cover -html=c.out
,使用cover
工具来处理生成的记录信息,该命令会打开本地的浏览器窗口生成一个HTML报告。上图中每个用绿色标记的语句块表示被覆盖了,而红色的表示没有被覆盖。
testify/assert
testify是一个社区非常流行的Go单元测试工具包,其中使用最多的功能就是它提供的断言工具——testify/assert
或testify/require
。
安装
1
|
go get github.com/stretchr/testify
|
使用示例
我们在写单元测试的时候,通常需要使用断言来校验测试结果,但是由于Go语言官方没有提供断言,所以我们会写出很多的if...else...
语句。而testify/assert
为我们提供了很多常用的断言函数,并且能够输出友好、易于阅读的错误描述信息。
比如我们之前在TestSplit
测试函数中就使用了reflect.DeepEqual
来判断期望结果与实际结果是否一致。
1
2
3
4
5
6
|
t.Run(tt.name, func(t *testing.T) { // 使用t.Run()执行子测试
got := Split(tt.input, tt.sep)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("expected:%#v, got:%#v", tt.want, got)
}
})
|
使用testify/assert
之后就能将上述判断过程简化如下:
1
2
3
4
|
t.Run(tt.name, func(t *testing.T) { // 使用t.Run()执行子测试
got := Split(tt.input, tt.sep)
assert.Equal(t, got, tt.want) // 使用assert提供的断言函数
})
|
当我们有多个断言语句时,还可以使用assert := assert.New(t)
创建一个assert对象,它拥有前面所有的断言方法,只是不需要再传入Testing.T
参数了。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
|
func TestSomething(t *testing.T) {
assert := assert.New(t)
// assert equality
assert.Equal(123, 123, "they should be equal")
// assert inequality
assert.NotEqual(123, 456, "they should not be equal")
// assert for nil (good for errors)
assert.Nil(object)
// assert for not nil (good when you expect something)
if assert.NotNil(object) {
// now we know that object isn't nil, we are safe to make
// further assertions without causing any errors
assert.Equal("Something", object.Value)
}
}
|
testify/assert
提供了非常多的断言函数,这里没办法一一列举出来,大家可以查看官方文档了解。
testify/require
拥有testify/assert
所有断言函数,它们的唯一区别就是——testify/require
遇到失败的用例会立即终止本次测试。
此外,testify
包还提供了mock、http等其他测试工具,篇幅所限这里就不详细介绍了,有兴趣的同学可以自己了解一下。
常用限流策略——漏桶与令牌桶介绍
发布于2020/09/13 ,更新于2020/09/13 18:41:02
| Golang
|总阅读量:8321次
限流又称为流量控制(流控),通常是指限制到达系统的并发请求数,本文列举了常见的限流策略,并以gin框架为例演示了如何为项目添加限流组件。
限流
限流又称为流量控制(流控),通常是指限制到达系统的并发请求数。
我们生活中也会经常遇到限流的场景,比如:某景区限制每日进入景区的游客数量为8万人;沙河地铁站早高峰通过站外排队逐一放行的方式限制同一时间进入车站的旅客数量等。
限流虽然会影响部分用户的使用体验,但是却能在一定程度上报障系统的稳定性,不至于崩溃(大家都没了用户体验)。
而互联网上类似需要限流的业务场景也有很多,比如电商系统的秒杀、微博上突发热点新闻、双十一购物节、12306抢票等等。这些场景下的用户请求量通常会激增,远远超过平时正常的请求量,此时如果不加任何限制很容易就会将后端服务打垮,影响服务的稳定性。
此外,一些厂商公开的API服务通常也会限制用户的请求次数,比如百度地图开放平台等会根据用户的付费情况来限制用户的请求数等。
常用的限流策略
漏桶
漏桶法限流很好理解,假设我们有一个水桶按固定的速率向下方滴落一滴水,无论有多少请求,请求的速率有多大,都按照固定的速率流出,对应到系统中就是按照固定的速率处理请求。
漏桶法的关键点在于漏桶始终按照固定的速率运行,但是它并不能很好的处理有大量突发请求的场景,毕竟在某些场景下我们可能需要提高系统的处理效率,而不是一味的按照固定速率处理请求。
关于漏桶的实现,uber团队有一个开源的github.com/uber-go/ratelimit库。 这个库的使用方法比较简单,Take()
方法会返回漏桶下一次滴水的时间。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
|
import (
"fmt"
"time"
"go.uber.org/ratelimit"
)
func main() {
rl := ratelimit.New(100) // per second
prev := time.Now()
for i := 0; i < 10; i++ {
now := rl.Take()
fmt.Println(i, now.Sub(prev))
prev = now
}
// Output:
// 0 0
// 1 10ms
// 2 10ms
// 3 10ms
// 4 10ms
// 5 10ms
// 6 10ms
// 7 10ms
// 8 10ms
// 9 10ms
}
|
它的源码实现也比较简单,这里大致说一下关键的地方,有兴趣的同学可以自己去看一下完整的源码。
限制器是一个接口类型,其要求实现一个Take()
方法:
1
2
3
4
|
type Limiter interface {
// Take方法应该阻塞已确保满足 RPS
Take() time.Time
}
|
实现限制器接口的结构体定义如下,这里可以重点留意下maxSlack
字段,它在后面的Take()
方法中的处理。
1
2
3
4
5
6
7
8
|
type limiter struct {
sync.Mutex // 锁
last time.Time // 上一次的时刻
sleepFor time.Duration // 需要等待的时间
perRequest time.Duration // 每次的时间间隔
maxSlack time.Duration // 最大的富余量
clock Clock // 时钟
}
|
limiter
结构体实现Limiter
接口的Take()
方法内容如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
// Take 会阻塞确保两次请求之间的时间走完
// Take 调用平均数为 time.Second/rate.
func (t *limiter) Take() time.Time {
t.Lock()
defer t.Unlock()
now := t.clock.Now()
// 如果是第一次请求就直接放行
if t.last.IsZero() {
t.last = now
return t.last
}
// sleepFor 根据 perRequest 和上一次请求的时刻计算应该sleep的时间
// 由于每次请求间隔的时间可能会超过perRequest, 所以这个数字可能为负数,并在多个请求之间累加
t.sleepFor += t.perRequest - now.Sub(t.last)
// 我们不应该让sleepFor负的太多,因为这意味着一个服务在短时间内慢了很多随后会得到更高的RPS。
if t.sleepFor < t.maxSlack {
t.sleepFor = t.maxSlack
}
// 如果 sleepFor 是正值那么就 sleep
if t.sleepFor > 0 {
t.clock.Sleep(t.sleepFor)
t.last = now.Add(t.sleepFor)
t.sleepFor = 0
} else {
t.last = now
}
return t.last
}
|
上面的代码根据记录每次请求的间隔时间和上一次请求的时刻来计算当次请求需要阻塞的时间——sleepFor
,这里需要留意的是sleepFor
的值可能为负,在经过间隔时间长的两次访问之后会导致随后大量的请求被放行,所以代码中针对这个场景有专门的优化处理。创建限制器的New()
函数中会为maxSlack
设置初始值,也可以通过WithoutSlack
这个Option取消这个默认值。
1
2
3
4
5
6
7
8
9
10
11
12
13
|
func New(rate int, opts ...Option) Limiter {
l := &limiter{
perRequest: time.Second / time.Duration(rate),
maxSlack: -10 * time.Second / time.Duration(rate),
}
for _, opt := range opts {
opt(l)
}
if l.clock == nil {
l.clock = clock.New()
}
return l
}
|
令牌桶
令牌桶其实和漏桶的原理类似,令牌桶按固定的速率往桶里放入令牌,并且只要能从桶里取出令牌就能通过,令牌桶支持突发流量的快速处理。
对于从桶里取不到令牌的场景,我们可以选择等待也可以直接拒绝并返回。
对于令牌桶的Go语言实现,大家可以参照github.com/juju/ratelimit库。这个库支持多种令牌桶模式,并且使用起来也比较简单。
创建令牌桶的方法:
1
2
3
4
5
6
7
|
// 创建指定填充速率和容量大小的令牌桶
func NewBucket(fillInterval time.Duration, capacity int64) *Bucket
// 创建指定填充速率、容量大小和每次填充的令牌数的令牌桶
func NewBucketWithQuantum(fillInterval time.Duration, capacity, quantum int64) *Bucket
// 创建填充速度为指定速率和容量大小的令牌桶
// NewBucketWithRate(0.1, 200) 表示每秒填充20个令牌
func NewBucketWithRate(rate float64, capacity int64) *Bucket
|
取出令牌的方法如下:
1
2
3
4
5
6
7
8
9
10
|
// 取token(非阻塞)
func (tb *Bucket) Take(count int64) time.Duration
func (tb *Bucket) TakeAvailable(count int64) int64
// 最多等maxWait时间取token
func (tb *Bucket) TakeMaxDuration(count int64, maxWait time.Duration) (time.Duration, bool)
// 取token(阻塞)
func (tb *Bucket) Wait(count int64)
func (tb *Bucket) WaitMaxDuration(count int64, maxWait time.Duration) bool
|
虽说是令牌桶,但是我们没有必要真的去生成令牌放到桶里,我们只需要每次来取令牌的时候计算一下,当前是否有足够的令牌就可以了,具体的计算方式可以总结为下面的公式:
1
|
当前令牌数 = 上一次剩余的令牌数 + (本次取令牌的时刻-上一次取令牌的时刻)/放置令牌的时间间隔 * 每次放置的令牌数
|
github.com/juju/ratelimit这个库中关于令牌数计算的源代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
func (tb *Bucket) currentTick(now time.Time) int64 {
return int64(now.Sub(tb.startTime) / tb.fillInterval)
}
func (tb *Bucket) adjustavailableTokens(tick int64) {
if tb.availableTokens >= tb.capacity {
return
}
tb.availableTokens += (tick - tb.latestTick) * tb.quantum
if tb.availableTokens > tb.capacity {
tb.availableTokens = tb.capacity
}
tb.latestTick = tick
return
}
|
获取令牌的TakeAvailable()
函数关键部分的源代码如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
func (tb *Bucket) takeAvailable(now time.Time, count int64) int64 {
if count <= 0 {
return 0
}
tb.adjustavailableTokens(tb.currentTick(now))
if tb.availableTokens <= 0 {
return 0
}
if count > tb.availableTokens {
count = tb.availableTokens
}
tb.availableTokens -= count
return count
}
|
大家从代码中也可以看到其实令牌桶的实现并没有很复杂。
gin框架中使用限流中间件
在gin框架构建的项目中,我们可以将限流组件定义成中间件。
这里使用令牌桶作为限流策略,编写一个限流中间件如下:
1
2
3
4
5
6
7
8
9
10
11
12
|
func RateLimitMiddleware(fillInterval time.Duration, cap int64) func(c *gin.Context) {
bucket := ratelimit.NewBucket(fillInterval, cap)
return func(c *gin.Context) {
// 如果取不到令牌就中断本次请求返回 rate limit...
if bucket.TakeAvailable(1) < 1 {
c.String(http.StatusOK, "rate limit...")
c.Abort()
return
}
c.Next()
}
}
|
对于该限流中间件的注册位置,我们可以按照不同的限流策略将其注册到不同的位置,例如:
- 如果要对全站限流就可以注册成全局的中间件。
- 如果是某一组路由需要限流,那么就只需将该限流中间件注册到对应的路由组即可。
对于漏桶,会维护一个稳定的访问速率,通过此次访问和上一次访问的时间差,判断该时间差是否大于访问速率,若大于则会sleep若干时间,不然就可以访问。
对于令牌桶,不会真正的实现一个桶,而是会有一个加令牌的速度,例如5个/tick,会记录一个startTime,用now-start就会得到一个时间段,该时间段除以生成令牌的间隔,就会得到当前的tick,用当前的tick减去上一次的tick再乘以速率就是当前令牌桶的令牌数字,和请求的令牌数比较一下就好了。
完