Skip to content

Commit fa87476

Browse files
committed
Implement parametrization
1 parent bac09a6 commit fa87476

File tree

8 files changed

+547
-11
lines changed

8 files changed

+547
-11
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Unreleased
44

5+
- Add group parametrization
6+
57
## 0.5.4
68

79
- Add `after_test` and `before_test` hooks.

README.rst

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ Define tests.
7373
g.test_example_2 = function() ... end
7474
g.test_example_m = function() ... end
7575
76+
-- Define parametrized groups
77+
local pg = t.group('pgroup', {param_1 = {1, 2}, param_2 = {3, 4}})
78+
pg.test_example_1 = function(cg) ... end
79+
-- type(cg.params.param_1) == 'number'
80+
pg.test_example_n = function(cg) ... end
81+
82+
-- Hooks can be specified for one parameter
83+
pg.before_all(function() ... end)
84+
pg.before_each({param_1 = 1}, function() ... end)
85+
pg.after_each({param_2 = 3}, function() ... end)
86+
pg.after_test('test_example_1', {param_1 = 2, param_2 = 4}, function() ... end)
87+
7688
7789
Run them.
7890

@@ -232,6 +244,64 @@ Capturing output
232244
By default runner captures all stdout/stderr output and shows it only for failed tests.
233245
Capturing can be disabled with ``-c`` flag.
234246

247+
.. _parametrization:
248+
249+
---------------------------------
250+
Parametrization
251+
---------------------------------
252+
253+
Test group can be parametrized.
254+
255+
.. code-block:: Lua
256+
257+
local g = t.group('pgroup', {a = {1, 2}, b = {3, 4}})
258+
259+
g.test_params = function(cg)
260+
...
261+
local param_a_val = cg.params.a
262+
local param_b_val = cg.params.b
263+
...
264+
end
265+
266+
Each test will be performed for every params combination. Hooks will work as usual
267+
unless there are specified params.
268+
269+
.. code-block:: Lua
270+
271+
-- called before every test
272+
g.before_each(function(cg) ... end)
273+
274+
-- called before tests when a == 1
275+
g.before_each({a = 1}, function(cg) ... end)
276+
277+
-- called only before the test when a == 1 and b == 3
278+
g.before_each({a = 1, b = 3}, function(cg) ... end)
279+
280+
-- called before test named 'test_something' when a == 1
281+
g.before_test('test_something', {a = 1}, function(cg) ... end)
282+
283+
--etc
284+
285+
`before_` hooks called in order from most general to most specialized.
286+
`after_` hooks on contrary called in order from most specialized to most general.
287+
288+
.. code-block:: Lua
289+
290+
-- The order which hooks are called in.
291+
g.before_all(function() ... end)
292+
g.before_all({a = 1}, function() ... end)
293+
g.before_all({a = 1, b = 3}, function() ... end)
294+
...
295+
g.after_all({a = 1, b = 3}, function() ... end)
296+
g.after_all({a = 1}, function() ... end)
297+
g.after_all(function() ... end)
298+
299+
Test with specific params can be called from command line.
300+
301+
.. code-block:: Bash
302+
303+
luatest path-to-file pgroup.a:1.b:3.test_params
304+
235305
.. _test-helpers:
236306

237307
---------------------------------

luatest/hooks.lua

Lines changed: 102 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,85 @@ local utils = require('luatest.utils')
22

33
local export = {}
44

5+
local function arrange_hooks(group, hooks, hooks_type, test_name)
6+
local active_hooks_name = 'active_' .. hooks_type
7+
if test_name then
8+
-- Don't forget about named hooks
9+
active_hooks_name = active_hooks_name .. '_' .. test_name
10+
hooks = hooks[test_name]
11+
end
12+
active_hooks_name = active_hooks_name .. '_hooks'
13+
14+
-- If hooks have been arranged already then there is nothing to do
15+
if group[active_hooks_name] then
16+
return group[active_hooks_name]
17+
end
18+
19+
local active_hooks = {}
20+
for _, entry in ipairs(hooks) do
21+
-- entry == {fn, params}
22+
local params_is_ok = true
23+
for name, value in pairs(entry[2]) do
24+
if group.params
25+
and group.params[name] ~= value then
26+
params_is_ok = false
27+
break
28+
end
29+
end
30+
31+
if params_is_ok then
32+
table.insert(active_hooks, entry)
33+
end
34+
end
35+
36+
-- after_* hooks are called in order of the most specialized ones (more fixed params)
37+
-- before_* hooks, on contrary, are called starting with the most general one
38+
local comparator
39+
if string.find(hooks_type, 'after') then
40+
comparator = function(a, b)
41+
return (utils.table_len(a[2])) > (utils.table_len(b[2]))
42+
end
43+
else
44+
comparator = function(a, b)
45+
return (utils.table_len(a[2])) < (utils.table_len(b[2]))
46+
end
47+
end
48+
table.sort(active_hooks, comparator)
49+
50+
group[active_hooks_name] = {}
51+
for _, entry in ipairs(active_hooks) do
52+
table.insert(group[active_hooks_name], entry[1])
53+
end
54+
55+
return group[active_hooks_name]
56+
end
57+
558
local function define_hooks(object, hooks_type)
659
local hooks = {}
760
object[hooks_type .. '_hooks'] = hooks
861

9-
object[hooks_type] = function(fn)
10-
table.insert(hooks, fn)
62+
object[hooks_type] = function(...)
63+
local params, fn = ...
64+
if fn == nil then
65+
fn = params
66+
params = {}
67+
end
68+
69+
assert(type(params) == 'table',
70+
string.format('Params should be table, got %s', type(params)))
71+
assert(type(fn) == 'function',
72+
string.format('Hook should be function, got %s', type(fn)))
73+
74+
params = params or {}
75+
table.insert(hooks, {fn, params})
1176
end
1277
object['_original_' .. hooks_type] = object[hooks_type] -- for leagacy hooks support
1378

1479
object['run_' .. hooks_type] = function()
15-
for _, fn in ipairs(hooks) do
16-
fn()
80+
local stored_hooks = object[hooks_type .. '_hooks']
81+
local active_hooks = arrange_hooks(object, stored_hooks, hooks_type)
82+
for _, hook in ipairs(active_hooks) do
83+
hook(object)
1784
end
1885
end
1986
end
@@ -22,22 +89,48 @@ local function define_named_hooks(object, hooks_type)
2289
local hooks = {}
2390
object[hooks_type .. '_hooks'] = hooks
2491

25-
object[hooks_type] = function(test_name, fn)
92+
object[hooks_type] = function(...)
93+
local test_name, params, fn = ...
94+
if fn == nil then
95+
fn = params
96+
params = {}
97+
end
98+
99+
assert(type(test_name) == 'string',
100+
string.format('Test name should be string, got %s', type(test_name)))
101+
assert(type(params) == 'table',
102+
string.format('Params should be table, got %s', type(params)))
103+
assert(type(fn) == 'function',
104+
string.format('Hook should be function, got %s', type(fn)))
105+
26106
test_name = object.name .. '.' .. test_name
107+
params = params or {}
27108
if not hooks[test_name] then
28109
hooks[test_name] = {}
29110
end
30-
table.insert(hooks[test_name], fn)
111+
table.insert(hooks[test_name], {fn, params})
31112
end
32113

33114
object['run_' .. hooks_type] = function(test)
115+
local stored_hooks = object[hooks_type .. '_hooks']
34116
local test_name = test.name
35-
if not hooks[test_name] then
117+
118+
-- When parametrized groups are defined named hooks saved by
119+
-- super group test name. When they are called test name is
120+
-- specific to the parametrized group. So, it should be
121+
-- converted back to the super one.
122+
if object.super_group then
123+
local test_name_parts, parts_amount = utils.split_test_name(test_name)
124+
test_name = object.super_group.name .. '.' .. test_name_parts[parts_amount]
125+
end
126+
127+
if not stored_hooks[test_name] then
36128
return
37129
end
38130

39-
for _, fn in ipairs(hooks[test_name]) do
40-
fn()
131+
local active_hooks = arrange_hooks(object, stored_hooks, hooks_type, test_name)
132+
for _, hook in ipairs(active_hooks) do
133+
hook(object)
41134
end
42135
end
43136
end

luatest/init.lua

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ luatest.Server = require('luatest.server')
1717

1818
local Group = require('luatest.group')
1919
local hooks = require('luatest.hooks')
20+
local parametrizer = require('luatest.parametrizer')
2021

2122
--- Add before suite hook.
2223
--
@@ -36,13 +37,26 @@ luatest.groups = {}
3637
-- @string[opt] name
3738
-- @return Group object
3839
-- @see luatest.group
39-
function luatest.group(name)
40+
function luatest.group(name, params)
4041
local group = Group:new(name)
4142
name = group.name
4243
if luatest.groups[name] then
4344
error('Test group already exists: ' .. name ..'.')
4445
end
45-
luatest.groups[name] = group
46+
47+
if params then
48+
parametrizer.parametrize(group, params)
49+
end
50+
51+
-- Register all parametrized groups
52+
if group.pgroups then
53+
for _, pgroup in ipairs(group.pgroups) do
54+
luatest.groups[pgroup.name] = pgroup
55+
end
56+
else
57+
luatest.groups[name] = group
58+
end
59+
4660
return group
4761
end
4862

luatest/parametrizer.lua

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
local checks = require('checks')
2+
3+
local Group = require('luatest.group')
4+
local pp = require('luatest.pp')
5+
6+
local export = {}
7+
8+
local function params_combinations(params)
9+
local combinations = {}
10+
11+
local params_names = {}
12+
for name, _ in pairs(params) do
13+
table.insert(params_names, name)
14+
end
15+
table.sort(params_names)
16+
17+
local function combinator(_params, ind, ...)
18+
if ind < 1 then
19+
local combination = {}
20+
for _, entry in ipairs({...}) do
21+
combination[entry[1]] = entry[2]
22+
end
23+
table.insert(combinations, combination)
24+
else
25+
local name = params_names[ind]
26+
for i=1,#(_params[name]) do combinator(_params, ind - 1, {name, _params[name][i]}, ...) end
27+
end
28+
end
29+
30+
combinator(params, #params_names)
31+
return combinations
32+
end
33+
34+
local function get_group_name(params)
35+
local params_names = {}
36+
for name, _ in pairs(params) do
37+
table.insert(params_names, name)
38+
end
39+
table.sort(params_names)
40+
41+
for ind, param_name in ipairs(params_names) do
42+
params_names[ind] = param_name .. ':' .. pp.tostring(params[param_name])
43+
end
44+
return table.concat(params_names, ".")
45+
end
46+
47+
local function redefine_hooks(group, hooks_type)
48+
-- Super group shares its hooks with pgroups
49+
for _, pgroup in ipairs(group.pgroups) do
50+
pgroup[hooks_type .. '_hooks'] = group[hooks_type .. '_hooks']
51+
end
52+
end
53+
54+
local function redefine_pgroups_hooks(group)
55+
redefine_hooks(group, 'before_each')
56+
redefine_hooks(group, 'after_each')
57+
redefine_hooks(group, 'before_all')
58+
redefine_hooks(group, 'after_all')
59+
60+
redefine_hooks(group, 'before_test')
61+
redefine_hooks(group, 'after_test')
62+
end
63+
64+
local function redirect_index(group)
65+
local super_group_mt = table.deepcopy(getmetatable(group))
66+
super_group_mt.__newindex = function(_group, key, value)
67+
for _, pgroup in ipairs(_group.pgroups) do
68+
pgroup[key] = value
69+
end
70+
end
71+
setmetatable(group, super_group_mt)
72+
end
73+
74+
function export.parametrize(object, params)
75+
checks('table', 'table')
76+
-- Validate params' name and values
77+
for parameter_name, parameter_values in pairs(params) do
78+
assert(type(parameter_name) == 'string',
79+
string.format('Parameter name should be string, got %s', type(parameter_name)))
80+
assert(type(parameter_values) == 'table',
81+
string.format('Parameter values should be table, got %s', type(parameter_values)))
82+
83+
for ind, _ in pairs(parameter_values) do
84+
assert(type(ind) == 'number',
85+
'Parameter values should be contained in an array')
86+
end
87+
end
88+
object.params = params
89+
90+
-- Create a subgroup on every param combination
91+
object.pgroups = {}
92+
local combinations = params_combinations(object.params)
93+
for _, pgroup_params in ipairs(combinations) do
94+
local pgroup_name = get_group_name(pgroup_params)
95+
local pgroup = Group:new(object.name .. '.' .. pgroup_name)
96+
pgroup.params = pgroup_params
97+
98+
pgroup.super_group = object
99+
table.insert(object.pgroups, pgroup)
100+
end
101+
102+
redirect_index(object)
103+
redefine_pgroups_hooks(object)
104+
end
105+
106+
return export

luatest/utils.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,17 @@ function utils.pattern_filter(patterns, expr)
120120
return result
121121
end
122122

123+
function utils.split_test_name(test_name)
124+
local test_name_parts = string.split(test_name, '.')
125+
return test_name_parts, #test_name_parts
126+
end
127+
128+
function utils.table_len(t)
129+
local counter = 0
130+
for _, _ in pairs(t) do
131+
counter = counter + 1
132+
end
133+
return counter
134+
end
135+
123136
return utils

0 commit comments

Comments
 (0)