Skip to content

Commit fea8bae

Browse files
committed
Implement parametrization
1 parent bac09a6 commit fea8bae

File tree

5 files changed

+444
-11
lines changed

5 files changed

+444
-11
lines changed

luatest/hooks.lua

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,85 @@ local utils = require('luatest.utils')
22

33
local export = {}
44

5+
local function arrange_hooks(group, hooks, hooks_type, test_name)
6+
local active_hooks_name = 'active_' .. hooks_type
7+
if test_name then
8+
-- Don't forget about named hooks
9+
active_hooks_name = active_hooks_name .. '_' .. test_name
10+
hooks = hooks[test_name]
11+
end
12+
active_hooks_name = active_hooks_name .. '_hooks'
13+
14+
-- If hooks have been arranged already then there is nothing to do
15+
if group[active_hooks_name] then
16+
return group[active_hooks_name]
17+
end
18+
19+
local active_hooks = {}
20+
for _, entry in ipairs(hooks) do
21+
-- entry == {fn, params}
22+
local params_is_ok = true
23+
for name, value in pairs(entry[2]) do
24+
if group.params
25+
and group.params[name] ~= value then
26+
params_is_ok = false
27+
break
28+
end
29+
end
30+
31+
if params_is_ok then
32+
table.insert(active_hooks, entry)
33+
end
34+
end
35+
36+
-- after_* hooks are called in order of the most specialized ones (more fixed params)
37+
-- before_* hooks, on contrary, are called starting with the most general one
38+
local comparator
39+
if string.find(hooks_type, 'after') then
40+
comparator = function(a, b)
41+
return (utils.table_len(a[2])) > (utils.table_len(b[2]))
42+
end
43+
else
44+
comparator = function(a, b)
45+
return (utils.table_len(a[2])) < (utils.table_len(b[2]))
46+
end
47+
end
48+
table.sort(active_hooks, comparator)
49+
50+
group[active_hooks_name] = {}
51+
for _, entry in ipairs(active_hooks) do
52+
table.insert(group[active_hooks_name], entry[1])
53+
end
54+
55+
return group[active_hooks_name]
56+
end
57+
558
local function define_hooks(object, hooks_type)
659
local hooks = {}
760
object[hooks_type .. '_hooks'] = hooks
861

9-
object[hooks_type] = function(fn)
10-
table.insert(hooks, fn)
62+
object[hooks_type] = function(...)
63+
local params, fn = ...
64+
if fn == nil then
65+
fn = params
66+
params = {}
67+
end
68+
69+
assert(type(params) == 'table',
70+
string.format('Params should be table, got %s', type(params)))
71+
assert(type(fn) == 'function',
72+
string.format('Hook should be function, got %s', type(fn)))
73+
74+
params = params or {}
75+
table.insert(hooks, {fn, params})
1176
end
1277
object['_original_' .. hooks_type] = object[hooks_type] -- for leagacy hooks support
1378

1479
object['run_' .. hooks_type] = function()
15-
for _, fn in ipairs(hooks) do
16-
fn()
80+
local stored_hooks = object[hooks_type .. '_hooks']
81+
local active_hooks = arrange_hooks(object, stored_hooks, hooks_type)
82+
for _, hook in ipairs(active_hooks) do
83+
hook(object)
1784
end
1885
end
1986
end
@@ -22,22 +89,48 @@ local function define_named_hooks(object, hooks_type)
2289
local hooks = {}
2390
object[hooks_type .. '_hooks'] = hooks
2491

25-
object[hooks_type] = function(test_name, fn)
92+
object[hooks_type] = function(...)
93+
local test_name, params, fn = ...
94+
if fn == nil then
95+
fn = params
96+
params = {}
97+
end
98+
99+
assert(type(test_name) == 'string',
100+
string.format('Test name should be string, got %s', type(test_name)))
101+
assert(type(params) == 'table',
102+
string.format('Params should be table, got %s', type(params)))
103+
assert(type(fn) == 'function',
104+
string.format('Hook should be function, got %s', type(fn)))
105+
26106
test_name = object.name .. '.' .. test_name
107+
params = params or {}
27108
if not hooks[test_name] then
28109
hooks[test_name] = {}
29110
end
30-
table.insert(hooks[test_name], fn)
111+
table.insert(hooks[test_name], {fn, params})
31112
end
32113

33114
object['run_' .. hooks_type] = function(test)
115+
local stored_hooks = object[hooks_type .. '_hooks']
34116
local test_name = test.name
35-
if not hooks[test_name] then
117+
118+
-- When parametrized groups are defined named hooks saved by
119+
-- super group test name. When they are called test name is
120+
-- specific to the parametrized group. So, it should be
121+
-- converted back to the super one.
122+
if object.super_group then
123+
local test_name_parts, parts_amount = utils.split_test_name(test_name)
124+
test_name = object.super_group.name .. '.' .. test_name_parts[parts_amount]
125+
end
126+
127+
if not stored_hooks[test_name] then
36128
return
37129
end
38130

39-
for _, fn in ipairs(hooks[test_name]) do
40-
fn()
131+
local active_hooks = arrange_hooks(object, stored_hooks, hooks_type, test_name)
132+
for _, hook in ipairs(active_hooks) do
133+
hook(object)
41134
end
42135
end
43136
end

luatest/init.lua

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ luatest.Server = require('luatest.server')
1717

1818
local Group = require('luatest.group')
1919
local hooks = require('luatest.hooks')
20+
local parametrizer = require('luatest.parametrizer')
2021

2122
--- Add before suite hook.
2223
--
@@ -36,13 +37,26 @@ luatest.groups = {}
3637
-- @string[opt] name
3738
-- @return Group object
3839
-- @see luatest.group
39-
function luatest.group(name)
40+
function luatest.group(name, params)
4041
local group = Group:new(name)
4142
name = group.name
4243
if luatest.groups[name] then
4344
error('Test group already exists: ' .. name ..'.')
4445
end
45-
luatest.groups[name] = group
46+
47+
if params then
48+
parametrizer.parametrize(group, params)
49+
end
50+
51+
-- Register all parametrized groups
52+
if group.pgroups then
53+
for _, pgroup in ipairs(group.pgroups) do
54+
luatest.groups[pgroup.name] = pgroup
55+
end
56+
else
57+
luatest.groups[name] = group
58+
end
59+
4660
return group
4761
end
4862

luatest/parametrizer.lua

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
local checks = require('checks')
2+
3+
local Group = require('luatest.group')
4+
5+
local export = {}
6+
7+
local function params_combinations(params)
8+
local combinations = {}
9+
10+
local params_names = {}
11+
for name, _ in pairs(params) do
12+
table.insert(params_names, name)
13+
end
14+
table.sort(params_names)
15+
16+
local function combinator(_params, ind, ...)
17+
if ind < 1 then
18+
local combination = {}
19+
for _, entry in ipairs({...}) do
20+
combination[entry[1]] = entry[2]
21+
end
22+
table.insert(combinations, combination)
23+
else
24+
local name = params_names[ind]
25+
for i=1,#(_params[name]) do combinator(_params, ind - 1, {name, _params[name][i]}, ...) end
26+
end
27+
end
28+
29+
combinator(params, #params_names)
30+
return combinations
31+
end
32+
33+
local function get_group_name(params)
34+
local params_names = {}
35+
for name, _ in pairs(params) do
36+
table.insert(params_names, name)
37+
end
38+
table.sort(params_names)
39+
40+
for ind, param_name in ipairs(params_names) do
41+
params_names[ind] = param_name .. '_' .. params[param_name]
42+
end
43+
return table.concat(params_names, ".")
44+
end
45+
46+
local function redefine_hooks(group, hooks_type)
47+
-- Super group shares its hooks with pgroups
48+
for _, pgroup in ipairs(group.pgroups) do
49+
pgroup[hooks_type .. '_hooks'] = group[hooks_type .. '_hooks']
50+
end
51+
end
52+
53+
local function redefine_pgroups_hooks(group)
54+
redefine_hooks(group, 'before_each')
55+
redefine_hooks(group, 'after_each')
56+
redefine_hooks(group, 'before_all')
57+
redefine_hooks(group, 'after_all')
58+
59+
redefine_hooks(group, 'before_test')
60+
redefine_hooks(group, 'after_test')
61+
end
62+
63+
local function redirect_index(group)
64+
local super_group_mt = table.deepcopy(getmetatable(group))
65+
super_group_mt.__newindex = function(_group, key, value)
66+
for _, pgroup in ipairs(_group.pgroups) do
67+
pgroup[key] = value
68+
end
69+
end
70+
setmetatable(group, super_group_mt)
71+
end
72+
73+
function export.parametrize(object, params)
74+
checks('table', 'table')
75+
-- Validate params' name and values
76+
for parameter_name, parameter_values in pairs(params) do
77+
assert(type(parameter_name) == 'string',
78+
string.format('Parameter name should be string, got %s', type(parameter_name)))
79+
assert(type(parameter_values) == 'table',
80+
string.format('Parameter values should be table, got %s', type(parameter_values)))
81+
end
82+
object.params = params
83+
84+
-- Create a subgroup on every param combination
85+
object.pgroups = {}
86+
local combinations = params_combinations(object.params)
87+
for _, pgroup_params in ipairs(combinations) do
88+
local pgroup_name = get_group_name(pgroup_params)
89+
local pgroup = Group:new(object.name .. '.' .. pgroup_name)
90+
pgroup.params = pgroup_params
91+
92+
pgroup.super_group = object
93+
table.insert(object.pgroups, pgroup)
94+
end
95+
96+
redirect_index(object)
97+
redefine_pgroups_hooks(object)
98+
end
99+
100+
return export

luatest/utils.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,17 @@ function utils.pattern_filter(patterns, expr)
120120
return result
121121
end
122122

123+
function utils.split_test_name(test_name)
124+
local test_name_parts = string.split(test_name, '.')
125+
return test_name_parts, #test_name_parts
126+
end
127+
128+
function utils.table_len(t)
129+
local counter = 0
130+
for _, _ in pairs(t) do
131+
counter = counter + 1
132+
end
133+
return counter
134+
end
135+
123136
return utils

0 commit comments

Comments
 (0)