@@ -2,18 +2,85 @@ local utils = require('luatest.utils')
2
2
3
3
local export = {}
4
4
5
+ local function arrange_hooks (group , hooks , hooks_type , test_name )
6
+ local active_hooks_name = ' active_' .. hooks_type
7
+ if test_name then
8
+ -- Don't forget about named hooks
9
+ active_hooks_name = active_hooks_name .. ' _' .. test_name
10
+ hooks = hooks [test_name ]
11
+ end
12
+ active_hooks_name = active_hooks_name .. ' _hooks'
13
+
14
+ -- If hooks have been arranged already then there is nothing to do
15
+ if group [active_hooks_name ] then
16
+ return group [active_hooks_name ]
17
+ end
18
+
19
+ local active_hooks = {}
20
+ for _ , entry in ipairs (hooks ) do
21
+ -- entry == {fn, params}
22
+ local params_is_ok = true
23
+ for name , value in pairs (entry [2 ]) do
24
+ if group .params
25
+ and group .params [name ] ~= value then
26
+ params_is_ok = false
27
+ break
28
+ end
29
+ end
30
+
31
+ if params_is_ok then
32
+ table.insert (active_hooks , entry )
33
+ end
34
+ end
35
+
36
+ -- after_* hooks are called in order of the most specialized ones (more fixed params)
37
+ -- before_* hooks, on contrary, are called starting with the most general one
38
+ local comparator
39
+ if string.find (hooks_type , ' after' ) then
40
+ comparator = function (a , b )
41
+ return (utils .table_len (a [2 ])) > (utils .table_len (b [2 ]))
42
+ end
43
+ else
44
+ comparator = function (a , b )
45
+ return (utils .table_len (a [2 ])) < (utils .table_len (b [2 ]))
46
+ end
47
+ end
48
+ table.sort (active_hooks , comparator )
49
+
50
+ group [active_hooks_name ] = {}
51
+ for _ , entry in ipairs (active_hooks ) do
52
+ table.insert (group [active_hooks_name ], entry [1 ])
53
+ end
54
+
55
+ return group [active_hooks_name ]
56
+ end
57
+
5
58
local function define_hooks (object , hooks_type )
6
59
local hooks = {}
7
60
object [hooks_type .. ' _hooks' ] = hooks
8
61
9
- object [hooks_type ] = function (fn )
10
- table.insert (hooks , fn )
62
+ object [hooks_type ] = function (...)
63
+ local params , fn = ...
64
+ if fn == nil then
65
+ fn = params
66
+ params = {}
67
+ end
68
+
69
+ assert (type (params ) == ' table' ,
70
+ string.format (' Params should be table, got %s' , type (params )))
71
+ assert (type (fn ) == ' function' ,
72
+ string.format (' Hook should be function, got %s' , type (fn )))
73
+
74
+ params = params or {}
75
+ table.insert (hooks , {fn , params })
11
76
end
12
77
object [' _original_' .. hooks_type ] = object [hooks_type ] -- for leagacy hooks support
13
78
14
79
object [' run_' .. hooks_type ] = function ()
15
- for _ , fn in ipairs (hooks ) do
16
- fn ()
80
+ local stored_hooks = object [hooks_type .. ' _hooks' ]
81
+ local active_hooks = arrange_hooks (object , stored_hooks , hooks_type )
82
+ for _ , hook in ipairs (active_hooks ) do
83
+ hook (object )
17
84
end
18
85
end
19
86
end
@@ -22,22 +89,48 @@ local function define_named_hooks(object, hooks_type)
22
89
local hooks = {}
23
90
object [hooks_type .. ' _hooks' ] = hooks
24
91
25
- object [hooks_type ] = function (test_name , fn )
92
+ object [hooks_type ] = function (...)
93
+ local test_name , params , fn = ...
94
+ if fn == nil then
95
+ fn = params
96
+ params = {}
97
+ end
98
+
99
+ assert (type (test_name ) == ' string' ,
100
+ string.format (' Test name should be string, got %s' , type (test_name )))
101
+ assert (type (params ) == ' table' ,
102
+ string.format (' Params should be table, got %s' , type (params )))
103
+ assert (type (fn ) == ' function' ,
104
+ string.format (' Hook should be function, got %s' , type (fn )))
105
+
26
106
test_name = object .name .. ' .' .. test_name
107
+ params = params or {}
27
108
if not hooks [test_name ] then
28
109
hooks [test_name ] = {}
29
110
end
30
- table.insert (hooks [test_name ], fn )
111
+ table.insert (hooks [test_name ], { fn , params } )
31
112
end
32
113
33
114
object [' run_' .. hooks_type ] = function (test )
115
+ local stored_hooks = object [hooks_type .. ' _hooks' ]
34
116
local test_name = test .name
35
- if not hooks [test_name ] then
117
+
118
+ -- When parametrized groups are defined named hooks saved by
119
+ -- super group test name. When they are called test name is
120
+ -- specific to the parametrized group. So, it should be
121
+ -- converted back to the super one.
122
+ if object .super_group then
123
+ local test_name_parts , parts_amount = utils .split_test_name (test_name )
124
+ test_name = object .super_group .name .. ' .' .. test_name_parts [parts_amount ]
125
+ end
126
+
127
+ if not stored_hooks [test_name ] then
36
128
return
37
129
end
38
130
39
- for _ , fn in ipairs (hooks [test_name ]) do
40
- fn ()
131
+ local active_hooks = arrange_hooks (object , stored_hooks , hooks_type , test_name )
132
+ for _ , hook in ipairs (active_hooks ) do
133
+ hook (object )
41
134
end
42
135
end
43
136
end
0 commit comments