diff --git a/src/asn1.py b/src/asn1.py index 2a81a1e..5214ef3 100644 --- a/src/asn1.py +++ b/src/asn1.py @@ -22,6 +22,7 @@ from builtins import int from builtins import range from builtins import str +from contextlib import contextmanager from enum import IntEnum from numbers import Number @@ -118,6 +119,46 @@ def leave(self): # type: () -> None self._emit_length(len(value)) self._emit(value) + @contextmanager + def construct(self, nr, cls=None): # type: (int, int) -> None + """This method - context manager calls enter and leave methods, + for better code mapping. + + Usage: + ``` + with encoder.construct(asn1.Numbers.Sequence): + encoder.write(1) + with encoder.construct(asn1.Numbers.Sequence): + encoder.write('foo') + encoder.write('bar') + encoder.write(2) + ``` + encoder.output() will result following structure: + SEQUENCE: + INTEGER: 1 + SEQUENCE: + STRING: foo + STRING: bar + INTEGER: 2 + + Args: + nr (int): The desired ASN.1 type. Use ``Numbers`` enumeration. + + cls (int): This optional parameter specifies the class + of the constructed type. The default class to use is the + universal class. Use ``Classes`` enumeration. + + Returns: + None + + Raises: + `Error` + + """ + self.enter(nr, cls) + yield + self.leave() + def write(self, value, nr=None, typ=None, cls=None): # type: (object, int, int, int) -> None """This method encodes one ASN.1 tag and writes it to the output buffer. diff --git a/tests/test_asn1.py b/tests/test_asn1.py index 144822f..e5ce2e8 100644 --- a/tests/test_asn1.py +++ b/tests/test_asn1.py @@ -279,6 +279,41 @@ def test_long_tag_id(self): res = enc.output() assert res == b'\x3f\x83\xff\x7f\x03\x02\x01\x01' + def test_contextmanager_construct(self): + enc = asn1.Encoder() + enc.start() + + with enc.construct(asn1.Numbers.Sequence): + enc.write(1) + enc.write(b'foo') + + res = enc.output() + assert res == b'\x30\x08\x02\x01\x01\x04\x03foo' + + def test_contextmanager_calls_enter(self): + class TestEncoder(asn1.Encoder): + def enter(self, nr, cls=None): + raise RuntimeError() + + enc = TestEncoder() + enc.start() + + with pytest.raises(RuntimeError): + with enc.construct(asn1.Numbers.Sequence): + enc.write(1) + + def test_contextmanager_calls_leave(self): + class TestEncoder(asn1.Encoder): + def leave(self): + raise RuntimeError() + + enc = TestEncoder() + enc.start() + + with pytest.raises(RuntimeError): + with enc.construct(asn1.Numbers.Sequence): + enc.write(1) + def test_long_tag_length(self): enc = asn1.Encoder() enc.start()