@@ -285,6 +285,11 @@ def flatten(node):
285285 return '' .join (acc )
286286
287287
288+ def make_xml (text ):
289+ xml = ET .XML ('<xml>%s</xml>' % text )
290+ return xml
291+
292+
288293def normalize_xpath (path ):
289294 path = path .replace ("{{channel}}" , channel )
290295 if path .startswith ('//' ):
@@ -401,7 +406,7 @@ def get_tree_count(tree, path):
401406 return len (tree .findall (path ))
402407
403408
404- def check_snapshot (snapshot_name , tree , normalize_to_text ):
409+ def check_snapshot (snapshot_name , actual_tree , normalize_to_text ):
405410 assert rust_test_path .endswith ('.rs' )
406411 snapshot_path = '{}.{}.{}' .format (rust_test_path [:- 3 ], snapshot_name , 'html' )
407412 try :
@@ -414,11 +419,15 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
414419 raise FailedCheck ('No saved snapshot value' )
415420
416421 if not normalize_to_text :
417- actual_str = ET .tostring (tree ).decode ('utf-8' )
422+ actual_str = ET .tostring (actual_tree ).decode ('utf-8' )
418423 else :
419- actual_str = flatten (tree )
424+ actual_str = flatten (actual_tree )
425+
426+ if not expected_str \
427+ or (not normalize_to_text and
428+ not compare_tree (make_xml (actual_str ), make_xml (expected_str ), stderr )) \
429+ or (normalize_to_text and actual_str != expected_str ):
420430
421- if expected_str != actual_str :
422431 if bless :
423432 with open (snapshot_path , 'w' ) as snapshot_file :
424433 snapshot_file .write (actual_str )
@@ -430,6 +439,59 @@ def check_snapshot(snapshot_name, tree, normalize_to_text):
430439 print ()
431440 raise FailedCheck ('Actual snapshot value is different than expected' )
432441
442+
443+ # Adapted from https://github.com/formencode/formencode/blob/3a1ba9de2fdd494dd945510a4568a3afeddb0b2e/formencode/doctest_xml_compare.py#L72-L120
444+ def compare_tree (x1 , x2 , reporter = None ):
445+ if x1 .tag != x2 .tag :
446+ if reporter :
447+ reporter ('Tags do not match: %s and %s' % (x1 .tag , x2 .tag ))
448+ return False
449+ for name , value in x1 .attrib .items ():
450+ if x2 .attrib .get (name ) != value :
451+ if reporter :
452+ reporter ('Attributes do not match: %s=%r, %s=%r'
453+ % (name , value , name , x2 .attrib .get (name )))
454+ return False
455+ for name in x2 .attrib :
456+ if name not in x1 .attrib :
457+ if reporter :
458+ reporter ('x2 has an attribute x1 is missing: %s'
459+ % name )
460+ return False
461+ if not text_compare (x1 .text , x2 .text ):
462+ if reporter :
463+ reporter ('text: %r != %r' % (x1 .text , x2 .text ))
464+ return False
465+ if not text_compare (x1 .tail , x2 .tail ):
466+ if reporter :
467+ reporter ('tail: %r != %r' % (x1 .tail , x2 .tail ))
468+ return False
469+ cl1 = list (x1 )
470+ cl2 = list (x2 )
471+ if len (cl1 ) != len (cl2 ):
472+ if reporter :
473+ reporter ('children length differs, %i != %i'
474+ % (len (cl1 ), len (cl2 )))
475+ return False
476+ i = 0
477+ for c1 , c2 in zip (cl1 , cl2 ):
478+ i += 1
479+ if not compare_tree (c1 , c2 , reporter = reporter ):
480+ if reporter :
481+ reporter ('children %i do not match: %s'
482+ % (i , c1 .tag ))
483+ return False
484+ return True
485+
486+
487+ def text_compare (t1 , t2 ):
488+ if not t1 and not t2 :
489+ return True
490+ if t1 == '*' or t2 == '*' :
491+ return True
492+ return (t1 or '' ).strip () == (t2 or '' ).strip ()
493+
494+
433495def stderr (* args ):
434496 if sys .version_info .major < 3 :
435497 file = codecs .getwriter ('utf-8' )(sys .stderr )
0 commit comments