 import framework
+import itertools
+import numpy as np
+import torch
 
-
-class ElementMulBench(framework.Benchmark):
+# A template class for elementwise operations.
+# A derived class will override the class variables to customize its behavior.
+class ElementBench(framework.Benchmark):
+    # List of customization class variables.
+    op_str = None
+    binary_op_pt_func = None
+    binary_op_np_func = None
+    unary_op_pt_func = None
+    unary_op_np_func = None
+    split_input = True
     def __init__(self, mode, device, N):
         super().__init__(mode, device)
         self.N = N
@@ -11,27 +22,60 @@ def __init__(self, mode, device, N):
         self.d4 = self.rand([N], device=device, requires_grad=self.requires_grad)
         self.inputs = [self.d1, self.d2, self.d3, self.d4]
 
+    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
+        if not binary_op:
+            binary_op = lambda x, y: x + y
+        if not unary_op:
+            unary_op = lambda x: x
+        if self.split_input:
+            d1 = unary_op(d1)
+            d2 = unary_op(d2)
+            d3 = unary_op(d3)
+            d4 = unary_op(d4)
+        else:
+            d2 = unary_op(d1 + 0.001)
+            d3 = unary_op(d1 + 0.002)
+            d4 = unary_op(d1 + 0.003)
+            d1 = unary_op(d1)
+        a = binary_op(d1, d2)
+        b = binary_op(d3, d4)
+        c = a + b
+        return c
+
     def forward(self, d1, d2, d3, d4):
-        y = d1 * d2 + d3 * d4
-        return y
+        binary_op = self.__class__.binary_op_pt_func
+        unary_op = self.__class__.unary_op_pt_func
+        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
 
     def reference(self):
-        return self.numpy(self.d1) * self.numpy(self.d2) + self.numpy(self.d3) * self.numpy(self.d4)
+        binary_op = self.__class__.binary_op_np_func
+        unary_op = self.__class__.unary_op_np_func
+        [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]]
+        return self._eval(d1, d2, d3, d4, binary_op, unary_op)
 
     def config(self):
         return [self.N]
 
-    @staticmethod
-    def module():
-        return 'element_mul'
+    @classmethod
+    def module(cls):
+        return 'element_' + cls.op_str
 
     def memory_workload(self):
+        input_count = len(self.inputs)
         if self.mode == 'fwd':
-            sol_count = 4 + 1
-            algorithmic_count = 3 + 1
+            if self.split_input:
+                sol_count = input_count + 1
+                algorithmic_count = input_count + 1
+            else:
+                sol_count = 1 + 1
+                algorithmic_count = 1 + 1
         else:
-            sol_count = (4 + 1) + (1 + 4)
-            algorithmic_count = (4 + 1) + ((2 + 1) * 4)
+            if self.split_input:
+                sol_count = (input_count + 1) + (1 + input_count)
+                algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)
+            else:
+                sol_count = 1 + 1
+                algorithmic_count = 1 + 1
 
         buffer_size = self.N * 4
         return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count}
@@ -41,4 +85,56 @@ def default_configs():
         return [[1 << 27]]
 
 
-framework.register_benchmark_class(ElementMulBench)
+def register_element_ops():
+    binary_op_list = [
+        ["mul", lambda a, b: a * b],
+        ["add", lambda a, b: a + b],
+        ["sub", lambda a, b: a - b],
+        ["div", lambda a, b: a / (b + 1e-4)],
94+ ["pow" , lambda a , b : torch .pow (a , b ), lambda a , b : np .power (a , b )], # no fuson triggered
95+ ["max" , lambda a , b : torch .max (a , b ), lambda a , b : np .maximum (a , b )],
96+ ["min" , lambda a , b : torch .min (a , b ), lambda a , b : np .minimum (a , b )],
97+ ]
98+
99+ unary_op_list = [
100+ ["exp" , lambda x : torch .exp (x ), lambda x : np .exp (x )],
101+ ["sin" , lambda x : torch .sin (x ), lambda x : np .sin (x )],
102+ ["cos" , lambda x : torch .cos (x ), lambda x : np .cos (x )],
103+ ]
104+
105+ for split_input , binary_op in itertools .product ([True , False ], binary_op_list ):
106+ # Make a copy of ElementBench
107+ if len (binary_op ) == 2 :
108+ [op_str , op_pt_func ] = binary_op
109+ op_np_func = op_pt_func
110+ elif len (binary_op ) == 3 :
111+ [op_str , op_pt_func , op_np_func ] = binary_op
112+ split_str = 'split' if split_input else 'shared'
113+ op_str = split_str + '_' + op_str
114+ bm_cls = type ('ElementBench_' + op_str , (ElementBench ,), {})
115+ bm_cls .op_str = op_str
116+ bm_cls .binary_op_pt_func = op_pt_func
117+ bm_cls .binary_op_np_func = op_np_func
118+ bm_cls .split_input = split_input
119+ framework .register_benchmark_class (bm_cls )
120+
+    for split_input, unary_op in itertools.product([True, False], unary_op_list):
+        # Make a copy of ElementBench
+        if len(unary_op) == 2:
+            [op_str, op_pt_func] = unary_op
+            op_np_func = op_pt_func
+        elif len(unary_op) == 3:
+            [op_str, op_pt_func, op_np_func] = unary_op
+        split_str = 'split' if split_input else 'shared'
+        op_str = split_str + '_' + op_str
+        bm_cls = type('ElementBench_' + op_str, (ElementBench,), {})
+        bm_cls.op_str = op_str
+        bm_cls.unary_op_pt_func = op_pt_func
+        bm_cls.unary_op_np_func = op_np_func
+        bm_cls.split_input = split_input
+        framework.register_benchmark_class(bm_cls)
+
+
+#framework.register_benchmark_class(ElementMulBench)
+register_element_ops()
+
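For orientation, here is a minimal standalone sketch (not part of the diff; `eval_mul` is a hypothetical helper) of what the generated `mul` benchmarks compute via `ElementBench._eval` when the unary op is left as the identity: the split variant reproduces the old `ElementMulBench` expression `d1 * d2 + d3 * d4`, while the shared variant derives all four operands from `d1`. Each generated class reports its module name as `'element_' + op_str`, e.g. `element_split_mul` or `element_shared_mul`.

```python
# Standalone sketch (assumes plain torch/numpy, no benchmark framework).
# Mirrors ElementBench._eval for the "mul" op with unary_op left as identity.
import numpy as np
import torch

def eval_mul(d1, d2, d3, d4, split_input):
    if split_input:
        # split_mul: four independent inputs, same as the old ElementMulBench.
        return d1 * d2 + d3 * d4
    # shared_mul: every operand is derived from d1, so only one input buffer is read.
    d2, d3, d4 = d1 + 0.001, d1 + 0.002, d1 + 0.003
    return d1 * d2 + d3 * d4

tensors = [torch.rand(1024) for _ in range(4)]
out = eval_mul(*tensors, split_input=True)
ref = eval_mul(*[t.numpy() for t in tensors], split_input=True)
print(np.allclose(out.numpy(), ref, atol=1e-6))  # True: forward agrees with the numpy reference
```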
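A quick sanity check on `memory_workload` with the default config (my arithmetic, not from the diff): each buffer is `N * 4` bytes, so for the default `N = 1 << 27` a buffer is 512 MiB, and a split-input forward pass is credited with 2.5 GiB of 'sol' traffic versus 1 GiB for the shared-input variant.

```python
# Back-of-the-envelope check of memory_workload (assumes FP32 and the default N).
N = 1 << 27
buffer_size = N * 4                           # 512 MiB per buffer
fwd_split = (4 + 1) * buffer_size             # read d1..d4, write one output
fwd_shared = (1 + 1) * buffer_size            # only d1 is read in shared mode
print(fwd_split / 2**30, fwd_shared / 2**30)  # 2.5 1.0 (GiB)
```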