1
1
local utils = require (' luatest.utils' )
2
+ local comparator = require (' luatest.comparator' )
2
3
3
4
local export = {}
4
5
6
+ local function check_params (required , actual )
7
+ for param_name , param_val in pairs (required ) do
8
+ if not comparator .equals (param_val , actual [param_name ]) then
9
+ return false
10
+ end
11
+ end
12
+
13
+ return true
14
+ end
15
+
5
16
local function define_hooks (object , hooks_type )
6
17
local hooks = {}
7
18
object [hooks_type .. ' _hooks' ] = hooks
8
19
9
- object [hooks_type ] = function (fn )
10
- table.insert (hooks , fn )
20
+ object [hooks_type ] = function (...)
21
+ local params , fn = ...
22
+ if fn == nil then
23
+ fn = params
24
+ params = {}
25
+ end
26
+
27
+ assert (type (params ) == ' table' ,
28
+ string.format (' Params should be table, got %s' , type (params )))
29
+ assert (type (fn ) == ' function' ,
30
+ string.format (' Hook should be function, got %s' , type (fn )))
31
+
32
+ params = params or {}
33
+ table.insert (hooks , {fn , params })
11
34
end
12
35
object [' _original_' .. hooks_type ] = object [hooks_type ] -- for leagacy hooks support
13
36
14
37
object [' run_' .. hooks_type ] = function ()
15
- for _ , fn in ipairs (hooks ) do
16
- fn ()
38
+ local hooks = object [hooks_type .. ' _hooks' ]
39
+ for _ , hook in ipairs (hooks ) do
40
+ if check_params (hook [2 ], object .params ) then
41
+ hook [1 ](object )
42
+ end
17
43
end
18
44
end
19
45
end
@@ -22,22 +48,49 @@ local function define_named_hooks(object, hooks_type)
22
48
local hooks = {}
23
49
object [hooks_type .. ' _hooks' ] = hooks
24
50
25
- object [hooks_type ] = function (test_name , fn )
51
+ object [hooks_type ] = function (...)
52
+ local test_name , params , fn = ...
53
+ if fn == nil then
54
+ fn = params
55
+ params = {}
56
+ end
57
+
58
+ assert (type (test_name ) == ' string' ,
59
+ string.format (' Test name should be string, got %s' , type (test_name )))
60
+ assert (type (params ) == ' table' ,
61
+ string.format (' Params should be table, got %s' , type (params )))
62
+ assert (type (fn ) == ' function' ,
63
+ string.format (' Hook should be function, got %s' , type (fn )))
64
+
26
65
test_name = object .name .. ' .' .. test_name
66
+ params = params or {}
27
67
if not hooks [test_name ] then
28
68
hooks [test_name ] = {}
29
69
end
30
- table.insert (hooks [test_name ], fn )
70
+ table.insert (hooks [test_name ], { fn , params } )
31
71
end
32
72
33
73
object [' run_' .. hooks_type ] = function (test )
74
+ local hooks = object [hooks_type .. ' _hooks' ]
34
75
local test_name = test .name
76
+
77
+ -- When parametrized groups are defined named hooks saved by
78
+ -- super group test name. When they are called test name is
79
+ -- specific to the parametrized group. So, it should be
80
+ -- converted back to the super one.
81
+ if object .super_group then
82
+ local test_name_parts , parts_amount = utils .split_test_name (test_name )
83
+ test_name = object .super_group .name .. ' .' .. test_name_parts [parts_amount ]
84
+ end
85
+
35
86
if not hooks [test_name ] then
36
87
return
37
88
end
38
89
39
- for _ , fn in ipairs (hooks [test_name ]) do
40
- fn ()
90
+ for _ , hook in ipairs (hooks [test_name ]) do
91
+ if check_params (hook [2 ], object .params ) then
92
+ hook [1 ](object )
93
+ end
41
94
end
42
95
end
43
96
end
0 commit comments