From b52c239228acfddaae88cf716fad43ccbe63087d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ale=C5=A1=20Mr=C3=A1zek?= <ales.mrazek@nic.cz> Date: Fri, 28 Apr 2023 17:18:59 +0200 Subject: [PATCH] manager: datamodel: ListOrItem custom generic type --- manager/etc/knot-resolver/config.dev.yml | 2 +- .../etc/knot-resolver/config.policy.dev.yml | 2 +- .../datamodel/forward_schema.py | 6 +- .../datamodel/network_schema.py | 21 ++++--- .../templates/macros/network_macros.lua.j2 | 20 ++----- .../datamodel/types/__init__.py | 2 + .../datamodel/types/generic_types.py | 34 +++++++++++ .../datamodel/types/types.py | 11 ++-- .../templates/test_network_macros.py | 12 ++-- .../unit/datamodel/test_config_schema.py | 2 +- .../unit/datamodel/test_network_schema.py | 18 +++--- .../datamodel/types/test_generic_types.py | 56 +++++++++++++++++++ 12 files changed, 133 insertions(+), 53 deletions(-) create mode 100644 manager/knot_resolver_manager/datamodel/types/generic_types.py create mode 100644 manager/tests/unit/datamodel/types/test_generic_types.py diff --git a/manager/etc/knot-resolver/config.dev.yml b/manager/etc/knot-resolver/config.dev.yml index e9d63dc96..7555b98c7 100644 --- a/manager/etc/knot-resolver/config.dev.yml +++ b/manager/etc/knot-resolver/config.dev.yml @@ -11,4 +11,4 @@ logging: - supervisord network: listen: - - interface: [127.0.0.1@5353] + - interface: 127.0.0.1@5353 diff --git a/manager/etc/knot-resolver/config.policy.dev.yml b/manager/etc/knot-resolver/config.policy.dev.yml index e0d2646d7..5f5a7429d 100644 --- a/manager/etc/knot-resolver/config.policy.dev.yml +++ b/manager/etc/knot-resolver/config.policy.dev.yml @@ -11,7 +11,7 @@ logging: - supervisord network: listen: - - interface: [127.0.0.1@5353] + - interface: 127.0.0.1@5353 views: - subnets: [127.0.0.0/24] diff --git a/manager/knot_resolver_manager/datamodel/forward_schema.py b/manager/knot_resolver_manager/datamodel/forward_schema.py index 7c1227cf8..2f7e07ead 100644 --- a/manager/knot_resolver_manager/datamodel/forward_schema.py +++ b/manager/knot_resolver_manager/datamodel/forward_schema.py @@ -2,7 +2,7 @@ from typing import List, Optional, Union from typing_extensions import Literal -from knot_resolver_manager.datamodel.types import DomainName, IPAddressOptionalPort +from knot_resolver_manager.datamodel.types import DomainName, IPAddressOptionalPort, ListOrItem from knot_resolver_manager.datamodel.types.files import FilePath from knot_resolver_manager.utils.modeling import ConfigSchema @@ -19,9 +19,9 @@ class ForwardServerSchema(ConfigSchema): ca_file: Path to CA certificate file. """ - address: List[IPAddressOptionalPort] + address: ListOrItem[IPAddressOptionalPort] transport: Optional[Literal["tls"]] = None - pin_sha256: Optional[List[str]] = None + pin_sha256: Optional[str] = None hostname: Optional[DomainName] = None ca_file: Optional[FilePath] = None diff --git a/manager/knot_resolver_manager/datamodel/network_schema.py b/manager/knot_resolver_manager/datamodel/network_schema.py index a12ebe637..2349bdc56 100644 --- a/manager/knot_resolver_manager/datamodel/network_schema.py +++ b/manager/knot_resolver_manager/datamodel/network_schema.py @@ -12,6 +12,7 @@ from knot_resolver_manager.datamodel.types import ( IPNetwork, IPv4Address, IPv6Address, + ListOrItem, PortNumber, SizeUnit, ) @@ -84,24 +85,24 @@ class ListenSchema(ConfigSchema): freebind: Used for binding to non-local address. """ - interface: Optional[List[InterfaceOptionalPort]] = None - unix_socket: Optional[List[FilePath]] = None + interface: Optional[ListOrItem[InterfaceOptionalPort]] = None + unix_socket: Optional[ListOrItem[FilePath]] = None port: Optional[PortNumber] = None kind: KindEnum = "dns" freebind: bool = False _LAYER = Raw - interface: Optional[List[InterfaceOptionalPort]] - unix_socket: Optional[List[FilePath]] + interface: Optional[ListOrItem[InterfaceOptionalPort]] + unix_socket: Optional[ListOrItem[FilePath]] port: Optional[PortNumber] kind: KindEnum freebind: bool - def _interface(self, origin: Raw) -> Optional[List[InterfaceOptionalPort]]: - if isinstance(origin.interface, list): + def _interface(self, origin: Raw) -> Optional[ListOrItem[InterfaceOptionalPort]]: + if origin.interface: port_set: Optional[bool] = None - for intrfc in origin.interface: + for intrfc in origin.interface: # type: ignore[attr-defined] if origin.port and intrfc.port: raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).") if port_set is not None and (bool(intrfc.port) != port_set): @@ -109,8 +110,6 @@ class ListenSchema(ConfigSchema): "The '@<port>' syntax must be used either for all or none of the interface in the list." ) port_set = bool(intrfc.port) - elif isinstance(origin.interface, InterfaceOptionalPort) and origin.interface.port and origin.port: - raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).") return origin.interface def _port(self, origin: Raw) -> Optional[PortNumber]: @@ -175,6 +174,6 @@ class NetworkSchema(ConfigSchema): tls: TLSSchema = TLSSchema() proxy_protocol: Union[Literal[False], ProxyProtocolSchema] = False listen: List[ListenSchema] = [ - ListenSchema({"interface": ["127.0.0.1"]}), - ListenSchema({"interface": ["::1"], "freebind": True}), + ListenSchema({"interface": "127.0.0.1"}), + ListenSchema({"interface": "::1", "freebind": True}), ] diff --git a/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2 b/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2 index 933ecdfa6..ff78fbd80 100644 --- a/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2 +++ b/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2 @@ -44,20 +44,12 @@ net.{{ interface.if_name }}, {% macro network_listen(listen) -%} {%- if listen.unix_socket -%} - {%- if listen.unix_socket is iterable-%} - {% for path in listen.unix_socket -%} - {{ net_listen_unix_socket(path, listen.kind, listen.freebind) }} - {% endfor -%} - {%- else -%} - {{ net_listen_unix_socket(listen.unix_socket, listen.kind, listen.freebind) }} - {%- endif -%} +{% for path in listen.unix_socket %} +{{ net_listen_unix_socket(path, listen.kind, listen.freebind) }} +{% endfor %} {%- elif listen.interface -%} - {%- if listen.interface is iterable-%} - {% for interface in listen.interface -%} - {{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }} - {% endfor -%} - {%- else -%} - {{ net_listen_interface(listen.interface, listen.kind, listen.freebind, listen.port) }} - {%- endif -%} +{% for interface in listen.interface %} +{{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }} +{% endfor %} {%- endif -%} {%- endmacro %} \ No newline at end of file diff --git a/manager/knot_resolver_manager/datamodel/types/__init__.py b/manager/knot_resolver_manager/datamodel/types/__init__.py index 256092a1d..33d8c90d4 100644 --- a/manager/knot_resolver_manager/datamodel/types/__init__.py +++ b/manager/knot_resolver_manager/datamodel/types/__init__.py @@ -1,5 +1,6 @@ from .enums import DNSRecordTypeEnum, PolicyActionEnum, PolicyFlagEnum from .files import AbsoluteDir, Dir, File, FilePath +from .generic_types import ListOrItem from .types import ( DomainName, IDPattern, @@ -43,6 +44,7 @@ __all__ = [ "IPv4Address", "IPv6Address", "IPv6Network96", + "ListOrItem", "Percent", "PortNumber", "SizeUnit", diff --git a/manager/knot_resolver_manager/datamodel/types/generic_types.py b/manager/knot_resolver_manager/datamodel/types/generic_types.py new file mode 100644 index 000000000..c5711f582 --- /dev/null +++ b/manager/knot_resolver_manager/datamodel/types/generic_types.py @@ -0,0 +1,34 @@ +from typing import Any, List, TypeVar, Union + +from knot_resolver_manager.utils.modeling import BaseGenericTypeWrapper + +T = TypeVar("T") + + +class ListOrItem(BaseGenericTypeWrapper[Union[List[T], T]]): + + _value_orig: Union[List[T], T] + _list: List[T] + + def __init__(self, source_value: Any, object_path: str = "/") -> None: + super().__init__(source_value) + self._value_orig: Union[List[T], T] = source_value + self._list: List[T] = source_value if isinstance(source_value, list) else [source_value] + + def __getitem__(self, index: Any) -> T: + return self._list[index] + + def __int__(self) -> int: + raise ValueError(f"Can't convert '{type(self).__name__}' to an integer.") + + def __str__(self) -> str: + return str(self._value_orig) + + def to_std(self) -> List[T]: + return self._list + + def __eq__(self, o: object) -> bool: + return isinstance(o, ListOrItem) and o._value_orig == self._value_orig + + def serialize(self) -> Union[List[T], T]: + return self._value_orig diff --git a/manager/knot_resolver_manager/datamodel/types/types.py b/manager/knot_resolver_manager/datamodel/types/types.py index d0bc21a97..f38759c82 100644 --- a/manager/knot_resolver_manager/datamodel/types/types.py +++ b/manager/knot_resolver_manager/datamodel/types/types.py @@ -1,13 +1,10 @@ import ipaddress import re -from typing import Any, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Dict, Optional, Type, Union from knot_resolver_manager.datamodel.types.base_types import IntRangeBase, PatternBase, StrBase, UnitBase from knot_resolver_manager.utils.modeling import BaseValueType -_InnerType = TypeVar("_InnerType") -ListOrSingle = List[_InnerType] - class IntNonNegative(IntRangeBase): _min: int = 0 @@ -51,14 +48,14 @@ class SizeUnit(UnitBase): return self._value def mbytes(self) -> int: - return self._value // 1024 ** 2 + return self._value // 1024**2 class TimeUnit(UnitBase): - _units = {"us": 1, "ms": 10 ** 3, "s": 10 ** 6, "m": 60 * 10 ** 6, "h": 3600 * 10 ** 6, "d": 24 * 3600 * 10 ** 6} + _units = {"us": 1, "ms": 10**3, "s": 10**6, "m": 60 * 10**6, "h": 3600 * 10**6, "d": 24 * 3600 * 10**6} def seconds(self) -> int: - return self._value // 1000 ** 2 + return self._value // 1000**2 def millis(self) -> int: return self._value // 1000 diff --git a/manager/tests/unit/datamodel/templates/test_network_macros.py b/manager/tests/unit/datamodel/templates/test_network_macros.py index 6d6637edf..ad193d982 100644 --- a/manager/tests/unit/datamodel/templates/test_network_macros.py +++ b/manager/tests/unit/datamodel/templates/test_network_macros.py @@ -8,8 +8,8 @@ def test_network_listen(): tmpl = template_from_str(tmpl_str) soc = ListenSchema({"unix-socket": "/tmp/kresd-socket", "kind": "dot"}) - assert tmpl.render(listen=soc) == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})" - soc_list = ListenSchema({"unix-socket": [soc.unix_socket, "/tmp/kresd-socket2"], "kind": "dot"}) + assert tmpl.render(listen=soc) == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})\n" + soc_list = ListenSchema({"unix-socket": [soc.unix_socket.to_std()[0], "/tmp/kresd-socket2"], "kind": "dot"}) assert ( tmpl.render(listen=soc_list) == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})\n" @@ -17,8 +17,8 @@ def test_network_listen(): ) ip = ListenSchema({"interface": "::1@55", "freebind": True}) - assert tmpl.render(listen=ip) == "net.listen('::1',55,{kind='dns',freebind=true})" - ip_list = ListenSchema({"interface": [ip.interface, "127.0.0.1@5353"]}) + assert tmpl.render(listen=ip) == "net.listen('::1',55,{kind='dns',freebind=true})\n" + ip_list = ListenSchema({"interface": [ip.interface.to_std()[0], "127.0.0.1@5353"]}) assert ( tmpl.render(listen=ip_list) == "net.listen('::1',55,{kind='dns',freebind=false})\n" @@ -26,8 +26,8 @@ def test_network_listen(): ) intrfc = ListenSchema({"interface": "eth0", "kind": "doh2"}) - assert tmpl.render(listen=intrfc) == "net.listen(net.eth0,443,{kind='doh2',freebind=false})" - intrfc_list = ListenSchema({"interface": [intrfc.interface, "lo"], "port": 5555, "kind": "doh2"}) + assert tmpl.render(listen=intrfc) == "net.listen(net.eth0,443,{kind='doh2',freebind=false})\n" + intrfc_list = ListenSchema({"interface": [intrfc.interface.to_std()[0], "lo"], "port": 5555, "kind": "doh2"}) assert ( tmpl.render(listen=intrfc_list) == "net.listen(net.eth0,5555,{kind='doh2',freebind=false})\n" diff --git a/manager/tests/unit/datamodel/test_config_schema.py b/manager/tests/unit/datamodel/test_config_schema.py index 502334734..31703b967 100644 --- a/manager/tests/unit/datamodel/test_config_schema.py +++ b/manager/tests/unit/datamodel/test_config_schema.py @@ -49,6 +49,6 @@ def test_config_json_schema(): try: _ = json.dumps(obj) except BaseException as e: - raise Exception(f"failed to serialize '{path}'") from e + raise Exception(f"failed to serialize '{path}': {e}") from e recser(dct) diff --git a/manager/tests/unit/datamodel/test_network_schema.py b/manager/tests/unit/datamodel/test_network_schema.py index 1a398b50a..7b616f347 100644 --- a/manager/tests/unit/datamodel/test_network_schema.py +++ b/manager/tests/unit/datamodel/test_network_schema.py @@ -13,12 +13,12 @@ def test_listen_defaults(): assert len(o.listen) == 2 # {"ip-address": "127.0.0.1"} - assert o.listen[0].interface == InterfaceOptionalPort("127.0.0.1") + assert o.listen[0].interface.to_std() == [InterfaceOptionalPort("127.0.0.1")] assert o.listen[0].port == PortNumber(53) assert o.listen[0].kind == "dns" assert o.listen[0].freebind == False # {"ip-address": "::1", "freebind": True} - assert o.listen[1].interface == InterfaceOptionalPort("::1") + assert o.listen[1].interface.to_std() == [InterfaceOptionalPort("::1")] assert o.listen[1].port == PortNumber(53) assert o.listen[1].kind == "dns" assert o.listen[1].freebind == True @@ -27,11 +27,11 @@ def test_listen_defaults(): @pytest.mark.parametrize( "listen,port", [ - ({"unix-socket": "/tmp/kresd-socket"}, None), - ({"interface": "::1"}, 53), - ({"interface": "::1", "kind": "dot"}, 853), - ({"interface": "::1", "kind": "doh-legacy"}, 443), - ({"interface": "::1", "kind": "doh2"}, 443), + ({"unix-socket": ["/tmp/kresd-socket"]}, None), + ({"interface": ["::1"]}, 53), + ({"interface": ["::1"], "kind": "dot"}, 853), + ({"interface": ["::1"], "kind": "doh-legacy"}, 443), + ({"interface": ["::1"], "kind": "doh2"}, 443), ], ) def test_listen_port_defaults(listen: Dict[str, Any], port: Optional[int]): @@ -64,8 +64,8 @@ def test_listen_valid(listen: Dict[str, Any]): @pytest.mark.parametrize( "listen", [ - {"unit-socket": "/tmp/kresd-socket", "port": "53"}, - {"interface": "::1", "unit-socket": "/tmp/kresd-socket"}, + {"unix-socket": "/tmp/kresd-socket", "port": "53"}, + {"interface": "::1", "unix-socket": "/tmp/kresd-socket"}, {"interface": "::1@5353", "port": 5353}, {"interface": ["127.0.0.1", "::1@5353"]}, {"interface": ["127.0.0.1@5353", "::1@5353"], "port": 5353}, diff --git a/manager/tests/unit/datamodel/types/test_generic_types.py b/manager/tests/unit/datamodel/types/test_generic_types.py new file mode 100644 index 000000000..7803ed005 --- /dev/null +++ b/manager/tests/unit/datamodel/types/test_generic_types.py @@ -0,0 +1,56 @@ +from typing import Any, List, Optional, Union + +import pytest +from pytest import raises + +from knot_resolver_manager.datamodel.types import ListOrItem +from knot_resolver_manager.utils.modeling import BaseSchema +from knot_resolver_manager.utils.modeling.exceptions import DataValidationError +from knot_resolver_manager.utils.modeling.types import get_generic_type_wrapper_argument + + +@pytest.mark.parametrize("val", [str, int]) +def test_list_or_item_inner_type(val: Any): + assert get_generic_type_wrapper_argument(ListOrItem[val]) == Union[List[val], val] + + +@pytest.mark.parametrize( + "typ,val", + [ + (int, [1, 65_535, 5353, 5000]), + (int, 65_535), + (str, ["string1", "string2"]), + (str, "string1"), + ], +) +def test_list_or_item_valid(typ: Any, val: Any): + class ListOrItemSchema(BaseSchema): + test: ListOrItem[typ] + + o = ListOrItemSchema({"test": val}) + assert o.test.serialize() == val + assert o.test.to_std() == val if isinstance(val, list) else [val] + + i = 0 + for item in o.test: + assert item == val[i] if isinstance(val, list) else val + i += 1 + + +@pytest.mark.parametrize( + "typ,val", + [ + (str, [True, False, True, False]), + (str, False), + (bool, [1, 65_535, 5353, 5000]), + (bool, 65_535), + (int, "string1"), + (int, ["string1", "string2"]), + ], +) +def test_list_or_item_invalid(typ: Any, val: Any): + class ListOrItemSchema(BaseSchema): + test: ListOrItem[typ] + + with raises(DataValidationError): + ListOrItemSchema({"test": val}) -- GitLab