199199 ">>=" ,
200200}
201201
202+ # Binary operations on bools that are specialized and don't just promote operands to int
203+ BOOL_BINARY_OPS : Final = {"&" , "&=" , "|" , "|=" , "^" , "^=" , "==" , "!=" , "<" , "<=" , ">" , ">=" }
204+
202205
203206class LowLevelIRBuilder :
204207 def __init__ (self , current_module : str , mapper : Mapper , options : CompilerOptions ) -> None :
@@ -326,13 +329,13 @@ def coerce(
326329 ):
327330 # Equivalent types
328331 return src
329- elif (
330- is_bool_rprimitive ( src_type ) or is_bit_rprimitive ( src_type )
331- ) and is_int_rprimitive ( target_type ) :
332+ elif (is_bool_rprimitive ( src_type ) or is_bit_rprimitive ( src_type )) and is_tagged (
333+ target_type
334+ ):
332335 shifted = self .int_op (
333336 bool_rprimitive , src , Integer (1 , bool_rprimitive ), IntOp .LEFT_SHIFT
334337 )
335- return self .add (Extend (shifted , int_rprimitive , signed = False ))
338+ return self .add (Extend (shifted , target_type , signed = False ))
336339 elif (
337340 is_bool_rprimitive (src_type ) or is_bit_rprimitive (src_type )
338341 ) and is_fixed_width_rtype (target_type ):
@@ -1245,48 +1248,45 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
12451248 return self .compare_bytes (lreg , rreg , op , line )
12461249 if is_tagged (ltype ) and is_tagged (rtype ) and op in int_comparison_op_mapping :
12471250 return self .compare_tagged (lreg , rreg , op , line )
1248- if (
1249- is_bool_rprimitive (ltype )
1250- and is_bool_rprimitive (rtype )
1251- and op in ("&" , "&=" , "|" , "|=" , "^" , "^=" )
1252- ):
1253- return self .bool_bitwise_op (lreg , rreg , op [0 ], line )
1251+ if is_bool_rprimitive (ltype ) and is_bool_rprimitive (rtype ) and op in BOOL_BINARY_OPS :
1252+ if op in ComparisonOp .signed_ops :
1253+ return self .bool_comparison_op (lreg , rreg , op , line )
1254+ else :
1255+ return self .bool_bitwise_op (lreg , rreg , op [0 ], line )
12541256 if isinstance (rtype , RInstance ) and op in ("in" , "not in" ):
12551257 return self .translate_instance_contains (rreg , lreg , op , line )
12561258 if is_fixed_width_rtype (ltype ):
12571259 if op in FIXED_WIDTH_INT_BINARY_OPS :
12581260 if op .endswith ("=" ):
12591261 op = op [:- 1 ]
1262+ if op != "//" :
1263+ op_id = int_op_to_id [op ]
1264+ else :
1265+ op_id = IntOp .DIV
1266+ if is_bool_rprimitive (rtype ) or is_bit_rprimitive (rtype ):
1267+ rreg = self .coerce (rreg , ltype , line )
1268+ rtype = ltype
12601269 if is_fixed_width_rtype (rtype ) or is_tagged (rtype ):
1261- if op != "//" :
1262- op_id = int_op_to_id [op ]
1263- else :
1264- op_id = IntOp .DIV
12651270 return self .fixed_width_int_op (ltype , lreg , rreg , op_id , line )
12661271 if isinstance (rreg , Integer ):
12671272 # TODO: Check what kind of Integer
1268- if op != "//" :
1269- op_id = int_op_to_id [op ]
1270- else :
1271- op_id = IntOp .DIV
12721273 return self .fixed_width_int_op (
12731274 ltype , lreg , Integer (rreg .value >> 1 , ltype ), op_id , line
12741275 )
12751276 elif op in ComparisonOp .signed_ops :
12761277 if is_int_rprimitive (rtype ):
12771278 rreg = self .coerce_int_to_fixed_width (rreg , ltype , line )
1279+ elif is_bool_rprimitive (rtype ) or is_bit_rprimitive (rtype ):
1280+ rreg = self .coerce (rreg , ltype , line )
12781281 op_id = ComparisonOp .signed_ops [op ]
12791282 if is_fixed_width_rtype (rreg .type ):
12801283 return self .comparison_op (lreg , rreg , op_id , line )
12811284 if isinstance (rreg , Integer ):
12821285 return self .comparison_op (lreg , Integer (rreg .value >> 1 , ltype ), op_id , line )
12831286 elif is_fixed_width_rtype (rtype ):
1284- if (
1285- isinstance (lreg , Integer ) or is_tagged (ltype )
1286- ) and op in FIXED_WIDTH_INT_BINARY_OPS :
1287+ if op in FIXED_WIDTH_INT_BINARY_OPS :
12871288 if op .endswith ("=" ):
12881289 op = op [:- 1 ]
1289- # TODO: Support comparison ops (similar to above)
12901290 if op != "//" :
12911291 op_id = int_op_to_id [op ]
12921292 else :
@@ -1296,15 +1296,38 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
12961296 return self .fixed_width_int_op (
12971297 rtype , Integer (lreg .value >> 1 , rtype ), rreg , op_id , line
12981298 )
1299- else :
1299+ if is_tagged (ltype ):
1300+ return self .fixed_width_int_op (rtype , lreg , rreg , op_id , line )
1301+ if is_bool_rprimitive (ltype ) or is_bit_rprimitive (ltype ):
1302+ lreg = self .coerce (lreg , rtype , line )
13001303 return self .fixed_width_int_op (rtype , lreg , rreg , op_id , line )
13011304 elif op in ComparisonOp .signed_ops :
13021305 if is_int_rprimitive (ltype ):
13031306 lreg = self .coerce_int_to_fixed_width (lreg , rtype , line )
1307+ elif is_bool_rprimitive (ltype ) or is_bit_rprimitive (ltype ):
1308+ lreg = self .coerce (lreg , rtype , line )
13041309 op_id = ComparisonOp .signed_ops [op ]
13051310 if isinstance (lreg , Integer ):
13061311 return self .comparison_op (Integer (lreg .value >> 1 , rtype ), rreg , op_id , line )
1312+ if is_fixed_width_rtype (lreg .type ):
1313+ return self .comparison_op (lreg , rreg , op_id , line )
1314+
1315+ # Mixed int comparisons
1316+ if op in ("==" , "!=" ):
1317+ op_id = ComparisonOp .signed_ops [op ]
1318+ if is_tagged (ltype ) and is_subtype (rtype , ltype ):
1319+ rreg = self .coerce (rreg , int_rprimitive , line )
1320+ return self .comparison_op (lreg , rreg , op_id , line )
1321+ if is_tagged (rtype ) and is_subtype (ltype , rtype ):
1322+ lreg = self .coerce (lreg , int_rprimitive , line )
13071323 return self .comparison_op (lreg , rreg , op_id , line )
1324+ elif op in op in int_comparison_op_mapping :
1325+ if is_tagged (ltype ) and is_subtype (rtype , ltype ):
1326+ rreg = self .coerce (rreg , short_int_rprimitive , line )
1327+ return self .compare_tagged (lreg , rreg , op , line )
1328+ if is_tagged (rtype ) and is_subtype (ltype , rtype ):
1329+ lreg = self .coerce (lreg , short_int_rprimitive , line )
1330+ return self .compare_tagged (lreg , rreg , op , line )
13081331
13091332 call_c_ops_candidates = binary_ops .get (op , [])
13101333 target = self .matching_call_c (call_c_ops_candidates , [lreg , rreg ], line )
@@ -1509,14 +1532,21 @@ def bool_bitwise_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value
15091532 assert False , op
15101533 return self .add (IntOp (bool_rprimitive , lreg , rreg , code , line ))
15111534
1535+ def bool_comparison_op (self , lreg : Value , rreg : Value , op : str , line : int ) -> Value :
1536+ op_id = ComparisonOp .signed_ops [op ]
1537+ return self .comparison_op (lreg , rreg , op_id , line )
1538+
15121539 def unary_not (self , value : Value , line : int ) -> Value :
15131540 mask = Integer (1 , value .type , line )
15141541 return self .int_op (value .type , value , mask , IntOp .XOR , line )
15151542
15161543 def unary_op (self , value : Value , expr_op : str , line : int ) -> Value :
15171544 typ = value .type
1518- if (is_bool_rprimitive (typ ) or is_bit_rprimitive (typ )) and expr_op == "not" :
1519- return self .unary_not (value , line )
1545+ if is_bool_rprimitive (typ ) or is_bit_rprimitive (typ ):
1546+ if expr_op == "not" :
1547+ return self .unary_not (value , line )
1548+ if expr_op == "+" :
1549+ return value
15201550 if is_fixed_width_rtype (typ ):
15211551 if expr_op == "-" :
15221552 # Translate to '0 - x'
@@ -1532,6 +1562,8 @@ def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15321562 if is_short_int_rprimitive (typ ):
15331563 num >>= 1
15341564 return Integer (- num , typ , value .line )
1565+ if is_tagged (typ ) and expr_op == "+" :
1566+ return value
15351567 if isinstance (typ , RInstance ):
15361568 if expr_op == "-" :
15371569 method = "__neg__"
0 commit comments