From b52c239228acfddaae88cf716fad43ccbe63087d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ale=C5=A1=20Mr=C3=A1zek?= <ales.mrazek@nic.cz>
Date: Fri, 28 Apr 2023 17:18:59 +0200
Subject: [PATCH] manager: datamodel: ListOrItem custom generic type

---
 manager/etc/knot-resolver/config.dev.yml      |  2 +-
 .../etc/knot-resolver/config.policy.dev.yml   |  2 +-
 .../datamodel/forward_schema.py               |  6 +-
 .../datamodel/network_schema.py               | 21 ++++---
 .../templates/macros/network_macros.lua.j2    | 20 ++-----
 .../datamodel/types/__init__.py               |  2 +
 .../datamodel/types/generic_types.py          | 34 +++++++++++
 .../datamodel/types/types.py                  | 11 ++--
 .../templates/test_network_macros.py          | 12 ++--
 .../unit/datamodel/test_config_schema.py      |  2 +-
 .../unit/datamodel/test_network_schema.py     | 18 +++---
 .../datamodel/types/test_generic_types.py     | 56 +++++++++++++++++++
 12 files changed, 133 insertions(+), 53 deletions(-)
 create mode 100644 manager/knot_resolver_manager/datamodel/types/generic_types.py
 create mode 100644 manager/tests/unit/datamodel/types/test_generic_types.py

diff --git a/manager/etc/knot-resolver/config.dev.yml b/manager/etc/knot-resolver/config.dev.yml
index e9d63dc96..7555b98c7 100644
--- a/manager/etc/knot-resolver/config.dev.yml
+++ b/manager/etc/knot-resolver/config.dev.yml
@@ -11,4 +11,4 @@ logging:
     - supervisord
 network:
   listen:
-    - interface: [127.0.0.1@5353]
+    - interface: 127.0.0.1@5353
diff --git a/manager/etc/knot-resolver/config.policy.dev.yml b/manager/etc/knot-resolver/config.policy.dev.yml
index e0d2646d7..5f5a7429d 100644
--- a/manager/etc/knot-resolver/config.policy.dev.yml
+++ b/manager/etc/knot-resolver/config.policy.dev.yml
@@ -11,7 +11,7 @@ logging:
     - supervisord
 network:
   listen:
-    - interface: [127.0.0.1@5353]
+    - interface: 127.0.0.1@5353
 
 views:
   - subnets: [127.0.0.0/24]
diff --git a/manager/knot_resolver_manager/datamodel/forward_schema.py b/manager/knot_resolver_manager/datamodel/forward_schema.py
index 7c1227cf8..2f7e07ead 100644
--- a/manager/knot_resolver_manager/datamodel/forward_schema.py
+++ b/manager/knot_resolver_manager/datamodel/forward_schema.py
@@ -2,7 +2,7 @@ from typing import List, Optional, Union
 
 from typing_extensions import Literal
 
-from knot_resolver_manager.datamodel.types import DomainName, IPAddressOptionalPort
+from knot_resolver_manager.datamodel.types import DomainName, IPAddressOptionalPort, ListOrItem
 from knot_resolver_manager.datamodel.types.files import FilePath
 from knot_resolver_manager.utils.modeling import ConfigSchema
 
@@ -19,9 +19,9 @@ class ForwardServerSchema(ConfigSchema):
     ca_file: Path to CA certificate file.
     """
 
-    address: List[IPAddressOptionalPort]
+    address: ListOrItem[IPAddressOptionalPort]
     transport: Optional[Literal["tls"]] = None
-    pin_sha256: Optional[List[str]] = None
+    pin_sha256: Optional[str] = None
     hostname: Optional[DomainName] = None
     ca_file: Optional[FilePath] = None
 
diff --git a/manager/knot_resolver_manager/datamodel/network_schema.py b/manager/knot_resolver_manager/datamodel/network_schema.py
index a12ebe637..2349bdc56 100644
--- a/manager/knot_resolver_manager/datamodel/network_schema.py
+++ b/manager/knot_resolver_manager/datamodel/network_schema.py
@@ -12,6 +12,7 @@ from knot_resolver_manager.datamodel.types import (
     IPNetwork,
     IPv4Address,
     IPv6Address,
+    ListOrItem,
     PortNumber,
     SizeUnit,
 )
@@ -84,24 +85,24 @@ class ListenSchema(ConfigSchema):
         freebind: Used for binding to non-local address.
         """
 
-        interface: Optional[List[InterfaceOptionalPort]] = None
-        unix_socket: Optional[List[FilePath]] = None
+        interface: Optional[ListOrItem[InterfaceOptionalPort]] = None
+        unix_socket: Optional[ListOrItem[FilePath]] = None
         port: Optional[PortNumber] = None
         kind: KindEnum = "dns"
         freebind: bool = False
 
     _LAYER = Raw
 
-    interface: Optional[List[InterfaceOptionalPort]]
-    unix_socket: Optional[List[FilePath]]
+    interface: Optional[ListOrItem[InterfaceOptionalPort]]
+    unix_socket: Optional[ListOrItem[FilePath]]
     port: Optional[PortNumber]
     kind: KindEnum
     freebind: bool
 
-    def _interface(self, origin: Raw) -> Optional[List[InterfaceOptionalPort]]:
-        if isinstance(origin.interface, list):
+    def _interface(self, origin: Raw) -> Optional[ListOrItem[InterfaceOptionalPort]]:
+        if origin.interface:
             port_set: Optional[bool] = None
-            for intrfc in origin.interface:
+            for intrfc in origin.interface:  # type: ignore[attr-defined]
                 if origin.port and intrfc.port:
                     raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).")
                 if port_set is not None and (bool(intrfc.port) != port_set):
@@ -109,8 +110,6 @@ class ListenSchema(ConfigSchema):
                         "The '@<port>' syntax must be used either for all or none of the interface in the list."
                     )
                 port_set = bool(intrfc.port)
-        elif isinstance(origin.interface, InterfaceOptionalPort) and origin.interface.port and origin.port:
-            raise ValueError("The port number is defined in two places ('port' option and '@<port>' syntax).")
         return origin.interface
 
     def _port(self, origin: Raw) -> Optional[PortNumber]:
@@ -175,6 +174,6 @@ class NetworkSchema(ConfigSchema):
     tls: TLSSchema = TLSSchema()
     proxy_protocol: Union[Literal[False], ProxyProtocolSchema] = False
     listen: List[ListenSchema] = [
-        ListenSchema({"interface": ["127.0.0.1"]}),
-        ListenSchema({"interface": ["::1"], "freebind": True}),
+        ListenSchema({"interface": "127.0.0.1"}),
+        ListenSchema({"interface": "::1", "freebind": True}),
     ]
diff --git a/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2 b/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2
index 933ecdfa6..ff78fbd80 100644
--- a/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2
+++ b/manager/knot_resolver_manager/datamodel/templates/macros/network_macros.lua.j2
@@ -44,20 +44,12 @@ net.{{ interface.if_name }},
 
 {% macro network_listen(listen) -%}
 {%- if listen.unix_socket -%}
-    {%- if listen.unix_socket is iterable-%}
-        {% for path in listen.unix_socket -%}
-            {{ net_listen_unix_socket(path, listen.kind, listen.freebind) }}
-        {% endfor -%}
-    {%- else -%}
-        {{ net_listen_unix_socket(listen.unix_socket, listen.kind, listen.freebind) }}
-    {%- endif -%}
+{% for path in listen.unix_socket %}
+{{ net_listen_unix_socket(path, listen.kind, listen.freebind) }}
+{% endfor %}
 {%- elif listen.interface -%}
-    {%- if listen.interface is iterable-%}
-        {% for interface in listen.interface -%}
-            {{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }}
-        {% endfor -%}
-    {%- else -%}
-        {{ net_listen_interface(listen.interface, listen.kind, listen.freebind, listen.port) }}
-    {%- endif -%}
+{% for interface in listen.interface %}
+{{ net_listen_interface(interface, listen.kind, listen.freebind, listen.port) }}
+{% endfor %}
 {%- endif -%}
 {%- endmacro %}
\ No newline at end of file
diff --git a/manager/knot_resolver_manager/datamodel/types/__init__.py b/manager/knot_resolver_manager/datamodel/types/__init__.py
index 256092a1d..33d8c90d4 100644
--- a/manager/knot_resolver_manager/datamodel/types/__init__.py
+++ b/manager/knot_resolver_manager/datamodel/types/__init__.py
@@ -1,5 +1,6 @@
 from .enums import DNSRecordTypeEnum, PolicyActionEnum, PolicyFlagEnum
 from .files import AbsoluteDir, Dir, File, FilePath
+from .generic_types import ListOrItem
 from .types import (
     DomainName,
     IDPattern,
@@ -43,6 +44,7 @@ __all__ = [
     "IPv4Address",
     "IPv6Address",
     "IPv6Network96",
+    "ListOrItem",
     "Percent",
     "PortNumber",
     "SizeUnit",
diff --git a/manager/knot_resolver_manager/datamodel/types/generic_types.py b/manager/knot_resolver_manager/datamodel/types/generic_types.py
new file mode 100644
index 000000000..c5711f582
--- /dev/null
+++ b/manager/knot_resolver_manager/datamodel/types/generic_types.py
@@ -0,0 +1,34 @@
+from typing import Any, List, TypeVar, Union
+
+from knot_resolver_manager.utils.modeling import BaseGenericTypeWrapper
+
+T = TypeVar("T")
+
+
+class ListOrItem(BaseGenericTypeWrapper[Union[List[T], T]]):
+
+    _value_orig: Union[List[T], T]
+    _list: List[T]
+
+    def __init__(self, source_value: Any, object_path: str = "/") -> None:
+        super().__init__(source_value)
+        self._value_orig: Union[List[T], T] = source_value
+        self._list: List[T] = source_value if isinstance(source_value, list) else [source_value]
+
+    def __getitem__(self, index: Any) -> T:
+        return self._list[index]
+
+    def __int__(self) -> int:
+        raise ValueError(f"Can't convert '{type(self).__name__}' to an integer.")
+
+    def __str__(self) -> str:
+        return str(self._value_orig)
+
+    def to_std(self) -> List[T]:
+        return self._list
+
+    def __eq__(self, o: object) -> bool:
+        return isinstance(o, ListOrItem) and o._value_orig == self._value_orig
+
+    def serialize(self) -> Union[List[T], T]:
+        return self._value_orig
diff --git a/manager/knot_resolver_manager/datamodel/types/types.py b/manager/knot_resolver_manager/datamodel/types/types.py
index d0bc21a97..f38759c82 100644
--- a/manager/knot_resolver_manager/datamodel/types/types.py
+++ b/manager/knot_resolver_manager/datamodel/types/types.py
@@ -1,13 +1,10 @@
 import ipaddress
 import re
-from typing import Any, Dict, List, Optional, Type, TypeVar, Union
+from typing import Any, Dict, Optional, Type, Union
 
 from knot_resolver_manager.datamodel.types.base_types import IntRangeBase, PatternBase, StrBase, UnitBase
 from knot_resolver_manager.utils.modeling import BaseValueType
 
-_InnerType = TypeVar("_InnerType")
-ListOrSingle = List[_InnerType]
-
 
 class IntNonNegative(IntRangeBase):
     _min: int = 0
@@ -51,14 +48,14 @@ class SizeUnit(UnitBase):
         return self._value
 
     def mbytes(self) -> int:
-        return self._value // 1024 ** 2
+        return self._value // 1024**2
 
 
 class TimeUnit(UnitBase):
-    _units = {"us": 1, "ms": 10 ** 3, "s": 10 ** 6, "m": 60 * 10 ** 6, "h": 3600 * 10 ** 6, "d": 24 * 3600 * 10 ** 6}
+    _units = {"us": 1, "ms": 10**3, "s": 10**6, "m": 60 * 10**6, "h": 3600 * 10**6, "d": 24 * 3600 * 10**6}
 
     def seconds(self) -> int:
-        return self._value // 1000 ** 2
+        return self._value // 1000**2
 
     def millis(self) -> int:
         return self._value // 1000
diff --git a/manager/tests/unit/datamodel/templates/test_network_macros.py b/manager/tests/unit/datamodel/templates/test_network_macros.py
index 6d6637edf..ad193d982 100644
--- a/manager/tests/unit/datamodel/templates/test_network_macros.py
+++ b/manager/tests/unit/datamodel/templates/test_network_macros.py
@@ -8,8 +8,8 @@ def test_network_listen():
     tmpl = template_from_str(tmpl_str)
 
     soc = ListenSchema({"unix-socket": "/tmp/kresd-socket", "kind": "dot"})
-    assert tmpl.render(listen=soc) == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})"
-    soc_list = ListenSchema({"unix-socket": [soc.unix_socket, "/tmp/kresd-socket2"], "kind": "dot"})
+    assert tmpl.render(listen=soc) == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})\n"
+    soc_list = ListenSchema({"unix-socket": [soc.unix_socket.to_std()[0], "/tmp/kresd-socket2"], "kind": "dot"})
     assert (
         tmpl.render(listen=soc_list)
         == "net.listen('/tmp/kresd-socket',nil,{kind='tls',freebind=false})\n"
@@ -17,8 +17,8 @@ def test_network_listen():
     )
 
     ip = ListenSchema({"interface": "::1@55", "freebind": True})
-    assert tmpl.render(listen=ip) == "net.listen('::1',55,{kind='dns',freebind=true})"
-    ip_list = ListenSchema({"interface": [ip.interface, "127.0.0.1@5353"]})
+    assert tmpl.render(listen=ip) == "net.listen('::1',55,{kind='dns',freebind=true})\n"
+    ip_list = ListenSchema({"interface": [ip.interface.to_std()[0], "127.0.0.1@5353"]})
     assert (
         tmpl.render(listen=ip_list)
         == "net.listen('::1',55,{kind='dns',freebind=false})\n"
@@ -26,8 +26,8 @@ def test_network_listen():
     )
 
     intrfc = ListenSchema({"interface": "eth0", "kind": "doh2"})
-    assert tmpl.render(listen=intrfc) == "net.listen(net.eth0,443,{kind='doh2',freebind=false})"
-    intrfc_list = ListenSchema({"interface": [intrfc.interface, "lo"], "port": 5555, "kind": "doh2"})
+    assert tmpl.render(listen=intrfc) == "net.listen(net.eth0,443,{kind='doh2',freebind=false})\n"
+    intrfc_list = ListenSchema({"interface": [intrfc.interface.to_std()[0], "lo"], "port": 5555, "kind": "doh2"})
     assert (
         tmpl.render(listen=intrfc_list)
         == "net.listen(net.eth0,5555,{kind='doh2',freebind=false})\n"
diff --git a/manager/tests/unit/datamodel/test_config_schema.py b/manager/tests/unit/datamodel/test_config_schema.py
index 502334734..31703b967 100644
--- a/manager/tests/unit/datamodel/test_config_schema.py
+++ b/manager/tests/unit/datamodel/test_config_schema.py
@@ -49,6 +49,6 @@ def test_config_json_schema():
             try:
                 _ = json.dumps(obj)
             except BaseException as e:
-                raise Exception(f"failed to serialize '{path}'") from e
+                raise Exception(f"failed to serialize '{path}': {e}") from e
 
     recser(dct)
diff --git a/manager/tests/unit/datamodel/test_network_schema.py b/manager/tests/unit/datamodel/test_network_schema.py
index 1a398b50a..7b616f347 100644
--- a/manager/tests/unit/datamodel/test_network_schema.py
+++ b/manager/tests/unit/datamodel/test_network_schema.py
@@ -13,12 +13,12 @@ def test_listen_defaults():
 
     assert len(o.listen) == 2
     # {"ip-address": "127.0.0.1"}
-    assert o.listen[0].interface == InterfaceOptionalPort("127.0.0.1")
+    assert o.listen[0].interface.to_std() == [InterfaceOptionalPort("127.0.0.1")]
     assert o.listen[0].port == PortNumber(53)
     assert o.listen[0].kind == "dns"
     assert o.listen[0].freebind == False
     # {"ip-address": "::1", "freebind": True}
-    assert o.listen[1].interface == InterfaceOptionalPort("::1")
+    assert o.listen[1].interface.to_std() == [InterfaceOptionalPort("::1")]
     assert o.listen[1].port == PortNumber(53)
     assert o.listen[1].kind == "dns"
     assert o.listen[1].freebind == True
@@ -27,11 +27,11 @@ def test_listen_defaults():
 @pytest.mark.parametrize(
     "listen,port",
     [
-        ({"unix-socket": "/tmp/kresd-socket"}, None),
-        ({"interface": "::1"}, 53),
-        ({"interface": "::1", "kind": "dot"}, 853),
-        ({"interface": "::1", "kind": "doh-legacy"}, 443),
-        ({"interface": "::1", "kind": "doh2"}, 443),
+        ({"unix-socket": ["/tmp/kresd-socket"]}, None),
+        ({"interface": ["::1"]}, 53),
+        ({"interface": ["::1"], "kind": "dot"}, 853),
+        ({"interface": ["::1"], "kind": "doh-legacy"}, 443),
+        ({"interface": ["::1"], "kind": "doh2"}, 443),
     ],
 )
 def test_listen_port_defaults(listen: Dict[str, Any], port: Optional[int]):
@@ -64,8 +64,8 @@ def test_listen_valid(listen: Dict[str, Any]):
 @pytest.mark.parametrize(
     "listen",
     [
-        {"unit-socket": "/tmp/kresd-socket", "port": "53"},
-        {"interface": "::1", "unit-socket": "/tmp/kresd-socket"},
+        {"unix-socket": "/tmp/kresd-socket", "port": "53"},
+        {"interface": "::1", "unix-socket": "/tmp/kresd-socket"},
         {"interface": "::1@5353", "port": 5353},
         {"interface": ["127.0.0.1", "::1@5353"]},
         {"interface": ["127.0.0.1@5353", "::1@5353"], "port": 5353},
diff --git a/manager/tests/unit/datamodel/types/test_generic_types.py b/manager/tests/unit/datamodel/types/test_generic_types.py
new file mode 100644
index 000000000..7803ed005
--- /dev/null
+++ b/manager/tests/unit/datamodel/types/test_generic_types.py
@@ -0,0 +1,56 @@
+from typing import Any, List, Optional, Union
+
+import pytest
+from pytest import raises
+
+from knot_resolver_manager.datamodel.types import ListOrItem
+from knot_resolver_manager.utils.modeling import BaseSchema
+from knot_resolver_manager.utils.modeling.exceptions import DataValidationError
+from knot_resolver_manager.utils.modeling.types import get_generic_type_wrapper_argument
+
+
+@pytest.mark.parametrize("val", [str, int])
+def test_list_or_item_inner_type(val: Any):
+    assert get_generic_type_wrapper_argument(ListOrItem[val]) == Union[List[val], val]
+
+
+@pytest.mark.parametrize(
+    "typ,val",
+    [
+        (int, [1, 65_535, 5353, 5000]),
+        (int, 65_535),
+        (str, ["string1", "string2"]),
+        (str, "string1"),
+    ],
+)
+def test_list_or_item_valid(typ: Any, val: Any):
+    class ListOrItemSchema(BaseSchema):
+        test: ListOrItem[typ]
+
+    o = ListOrItemSchema({"test": val})
+    assert o.test.serialize() == val
+    assert o.test.to_std() == val if isinstance(val, list) else [val]
+
+    i = 0
+    for item in o.test:
+        assert item == val[i] if isinstance(val, list) else val
+        i += 1
+
+
+@pytest.mark.parametrize(
+    "typ,val",
+    [
+        (str, [True, False, True, False]),
+        (str, False),
+        (bool, [1, 65_535, 5353, 5000]),
+        (bool, 65_535),
+        (int, "string1"),
+        (int, ["string1", "string2"]),
+    ],
+)
+def test_list_or_item_invalid(typ: Any, val: Any):
+    class ListOrItemSchema(BaseSchema):
+        test: ListOrItem[typ]
+
+    with raises(DataValidationError):
+        ListOrItemSchema({"test": val})
-- 
GitLab