diff --git a/src/zeroconf/_core.py b/src/zeroconf/_core.py index 4b29717a..cb488b4e 100644 --- a/src/zeroconf/_core.py +++ b/src/zeroconf/_core.py @@ -39,7 +39,11 @@ from ._protocol.outgoing import DNSOutgoing from ._services import ServiceListener from ._services.browser import ServiceBrowser -from ._services.info import ServiceInfo, instance_name_from_service_info +from ._services.info import ( + AsyncServiceInfo, + ServiceInfo, + instance_name_from_service_info, +) from ._services.registry import ServiceRegistry from ._transport import _WrappedTransport from ._updates import RecordUpdateListener @@ -261,7 +265,13 @@ def get_service_info( ) -> Optional[ServiceInfo]: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, - which defaults to 3 seconds.""" + which defaults to 3 seconds. + + :param type_: fully qualified service type name + :param name: the name of the service + :param timeout: milliseconds to wait for a response + :param question_type: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + """ info = ServiceInfo(type_, name) if info.request(self, timeout, question_type): return info @@ -360,6 +370,23 @@ async def async_update_service(self, info: ServiceInfo) -> Awaitable: self.registry.async_update(info) return asyncio.ensure_future(self._async_broadcast_service(info, _REGISTER_TIME, None)) + async def async_get_service_info( + self, type_: str, name: str, timeout: int = 3000, question_type: Optional[DNSQuestionType] = None + ) -> Optional[AsyncServiceInfo]: + """Returns network's service information for a particular + name and type, or None if no service matches by the timeout, + which defaults to 3 seconds. + + :param type_: fully qualified service type name + :param name: the name of the service + :param timeout: milliseconds to wait for a response + :param question_type: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + """ + info = AsyncServiceInfo(type_, name) + if await info.async_request(self, timeout, question_type): + return info + return None + async def _async_broadcast_service( self, info: ServiceInfo, diff --git a/src/zeroconf/_services/info.py b/src/zeroconf/_services/info.py index 48ad1140..6d68de83 100644 --- a/src/zeroconf/_services/info.py +++ b/src/zeroconf/_services/info.py @@ -770,6 +770,12 @@ def request( While it is not expected during normal operation, this function may raise EventLoopBlocked if the underlying call to `async_request` cannot be completed. + + :param zc: Zeroconf instance + :param timeout: time in milliseconds to wait for a response + :param question_type: question type to ask + :param addr: address to send the request to + :param port: port to send the request to """ assert zc.loop is not None and zc.loop.is_running() if zc.loop == get_running_loop(): @@ -803,6 +809,12 @@ async def async_request( mDNS multicast address and port. This is useful for directing requests to a specific host that may be able to respond across subnets. + + :param zc: Zeroconf instance + :param timeout: time in milliseconds to wait for a response + :param question_type: question type to ask + :param addr: address to send the request to + :param port: port to send the request to """ if not zc.started: await zc.async_wait_for_start() @@ -924,3 +936,7 @@ def __repr__(self) -> str: ) ), ) + + +class AsyncServiceInfo(ServiceInfo): + """An async version of ServiceInfo.""" diff --git a/src/zeroconf/_utils/ipaddress.py b/src/zeroconf/_utils/ipaddress.py index b0b551ff..ba137955 100644 --- a/src/zeroconf/_utils/ipaddress.py +++ b/src/zeroconf/_utils/ipaddress.py @@ -104,7 +104,7 @@ def _cached_ip_addresses(address: Union[str, bytes, int]) -> Optional[Union[IPv4 def get_ip_address_object_from_record(record: DNSAddress) -> Optional[Union[IPv4Address, IPv6Address]]: """Get the IP address object from the record.""" - if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id is not None: + if IPADDRESS_SUPPORTS_SCOPE_ID and record.type == _TYPE_AAAA and record.scope_id: return ip_bytes_and_scope_to_address(record.address, record.scope_id) return cached_ip_addresses_wrapper(record.address) diff --git a/src/zeroconf/asyncio.py b/src/zeroconf/asyncio.py index cfe3693e..b2daeb10 100644 --- a/src/zeroconf/asyncio.py +++ b/src/zeroconf/asyncio.py @@ -28,7 +28,7 @@ from ._dns import DNSQuestionType from ._services import ServiceListener from ._services.browser import _ServiceBrowserBase -from ._services.info import ServiceInfo +from ._services.info import AsyncServiceInfo, ServiceInfo from ._services.types import ZeroconfServiceTypes from ._utils.net import InterfaceChoice, InterfacesType, IPVersion from .const import _BROWSER_TIME, _MDNS_PORT, _SERVICE_TYPE_ENUMERATION_NAME @@ -41,10 +41,6 @@ ] -class AsyncServiceInfo(ServiceInfo): - """An async version of ServiceInfo.""" - - class AsyncServiceBrowser(_ServiceBrowserBase): """Used to browse for a service for specific type(s). @@ -239,11 +235,14 @@ async def async_get_service_info( ) -> Optional[AsyncServiceInfo]: """Returns network's service information for a particular name and type, or None if no service matches by the timeout, - which defaults to 3 seconds.""" - info = AsyncServiceInfo(type_, name) - if await info.async_request(self.zeroconf, timeout, question_type): - return info - return None + which defaults to 3 seconds. + + :param type_: fully qualified service type name + :param name: the name of the service + :param timeout: milliseconds to wait for a response + :param question_type: The type of questions to ask (DNSQuestionType.QM or DNSQuestionType.QU) + """ + return await self.zeroconf.async_get_service_info(type_, name, timeout, question_type) async def async_add_service_listener(self, type_: str, listener: ServiceListener) -> None: """Adds a listener for a particular service type. This object diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index 63255158..382b1a3d 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -680,6 +680,10 @@ async def test_service_info_async_request() -> None: assert aiosinfo is not None assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + aiosinfo = await aiozc.zeroconf.async_get_service_info(type_, registration_name) + assert aiosinfo is not None + assert aiosinfo.addresses == [socket.inet_aton("10.0.1.3")] + aiosinfos = await asyncio.gather( aiozc.async_get_service_info(type_, registration_name), aiozc.async_get_service_info(type_, registration_name2), diff --git a/tests/utils/test_ipaddress.py b/tests/utils/test_ipaddress.py index 3ec1a9a7..ff491e4f 100644 --- a/tests/utils/test_ipaddress.py +++ b/tests/utils/test_ipaddress.py @@ -2,6 +2,10 @@ """Unit tests for zeroconf._utils.ipaddress.""" +import pytest + +from zeroconf import const +from zeroconf._dns import DNSAddress from zeroconf._utils import ipaddress @@ -34,3 +38,34 @@ def test_cached_ip_addresses_wrapper(): assert ipv6 is not None assert ipv6.is_link_local is False assert ipv6.is_unspecified is True + + +@pytest.mark.skipif(not ipaddress.IPADDRESS_SUPPORTS_SCOPE_ID, reason='scope_id is not supported') +def test_get_ip_address_object_from_record(): + """Test the get_ip_address_object_from_record.""" + # not link local + packed = b'&\x06(\x00\x02 \x00\x01\x02H\x18\x93%\xc8\x19F' + record = DNSAddress( + 'domain.local', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed, scope_id=3 + ) + assert record.scope_id == 3 + assert ipaddress.get_ip_address_object_from_record(record) == ipaddress.IPv6Address( + '2606:2800:220:1:248:1893:25c8:1946' + ) + + # link local + packed = b'\xfe\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01' + record = DNSAddress( + 'domain.local', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed, scope_id=3 + ) + assert record.scope_id == 3 + assert ipaddress.get_ip_address_object_from_record(record) == ipaddress.IPv6Address('fe80::1%3') + record = DNSAddress('domain.local', const._TYPE_AAAA, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed) + assert record.scope_id is None + assert ipaddress.get_ip_address_object_from_record(record) == ipaddress.IPv6Address('fe80::1') + record = DNSAddress( + 'domain.local', const._TYPE_A, const._CLASS_IN | const._CLASS_UNIQUE, 1, packed, scope_id=0 + ) + assert record.scope_id == 0 + # Ensure scope_id of 0 is not appended to the address + assert ipaddress.get_ip_address_object_from_record(record) == ipaddress.IPv6Address('fe80::1')