Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions src/zeroconf/_utils/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
import socket
import struct
import sys
from collections.abc import Sequence
import warnings
from collections.abc import Iterable, Sequence
from typing import Any, Union, cast

import ifaddr
Expand Down Expand Up @@ -73,19 +74,39 @@ def _encode_address(address: str) -> bytes:
return socket.inet_pton(address_family, address)


def get_all_addresses() -> list[str]:
return list({addr.ip for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv4}) # type: ignore[misc]
def get_all_addresses_ipv4(adapters: Iterable[ifaddr.Adapter]) -> list[str]:
return list({addr.ip for iface in adapters for addr in iface.ips if addr.is_IPv4}) # type: ignore[misc]


def get_all_addresses_v6() -> list[tuple[tuple[str, int, int], int]]:
def get_all_addresses_ipv6(adapters: Iterable[ifaddr.Adapter]) -> list[tuple[tuple[str, int, int], int]]:
# IPv6 multicast uses positive indexes for interfaces
# TODO: What about multi-address interfaces?
return list(
{(addr.ip, iface.index) for iface in ifaddr.get_adapters() for addr in iface.ips if addr.is_IPv6} # type: ignore[misc]
{(addr.ip, iface.index) for iface in adapters for addr in iface.ips if addr.is_IPv6} # type: ignore[misc]
)


def get_all_addresses() -> list[str]:
warnings.warn(
"get_all_addresses is deprecated, and will be removed in a future version. Use ifaddr"
"directly instead to get a list of adapters.",
DeprecationWarning,
stacklevel=2,
)
return get_all_addresses_ipv4(ifaddr.get_adapters())


def get_all_addresses_v6() -> list[tuple[tuple[str, int, int], int]]:
warnings.warn(
"get_all_addresses_v6 is deprecated, and will be removed in a future version. Use ifaddr"
"directly instead to get a list of adapters.",
DeprecationWarning,
stacklevel=2,
)
return get_all_addresses_ipv6(ifaddr.get_adapters())


def ip6_to_address_and_index(adapters: list[ifaddr.Adapter], ip: str) -> tuple[tuple[str, int, int], int]:
def ip6_to_address_and_index(adapters: Iterable[ifaddr.Adapter], ip: str) -> tuple[tuple[str, int, int], int]:
if "%" in ip:
ip = ip[: ip.index("%")] # Strip scope_id.
ipaddr = ipaddress.ip_address(ip)
Expand All @@ -102,7 +123,7 @@ def ip6_to_address_and_index(adapters: list[ifaddr.Adapter], ip: str) -> tuple[t
raise RuntimeError(f"No adapter found for IP address {ip}")


def interface_index_to_ip6_address(adapters: list[ifaddr.Adapter], index: int) -> tuple[str, int, int]:
def interface_index_to_ip6_address(adapters: Iterable[ifaddr.Adapter], index: int) -> tuple[str, int, int]:
for adapter in adapters:
if adapter.index == index:
for adapter_ip in adapter.ips:
Expand Down Expand Up @@ -152,10 +173,11 @@ def normalize_interface_choice(
if ip_version != IPVersion.V6Only:
result.append("0.0.0.0")
elif choice is InterfaceChoice.All:
adapters = ifaddr.get_adapters()
if ip_version != IPVersion.V4Only:
result.extend(get_all_addresses_v6())
result.extend(get_all_addresses_ipv6(adapters))
if ip_version != IPVersion.V6Only:
result.extend(get_all_addresses())
result.extend(get_all_addresses_ipv4(adapters))
if not result:
raise RuntimeError(
f"No interfaces to listen on, check that any interfaces have IP version {ip_version}"
Expand Down
40 changes: 38 additions & 2 deletions tests/utils/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import socket
import sys
import unittest
import warnings
from unittest.mock import MagicMock, Mock, patch

import ifaddr
import pytest

import zeroconf as r
from zeroconf import get_all_addresses, get_all_addresses_v6
from zeroconf._utils import net as netutils


Expand All @@ -35,6 +37,40 @@ def _generate_mock_adapters():
return [mock_eth0, mock_lo0, mock_eth1, mock_vtun0]


def test_get_all_addresses() -> None:
"""Test public get_all_addresses API."""
with (
patch(
"zeroconf._utils.net.ifaddr.get_adapters",
return_value=_generate_mock_adapters(),
),
warnings.catch_warnings(record=True) as warned,
):
addresses = get_all_addresses()
assert isinstance(addresses, list)
assert len(addresses) == 3
assert len(warned) == 1
first_warning = warned[0]
assert "get_all_addresses is deprecated" in str(first_warning.message)


def test_get_all_addresses_v6() -> None:
"""Test public get_all_addresses_v6 API."""
with (
patch(
"zeroconf._utils.net.ifaddr.get_adapters",
return_value=_generate_mock_adapters(),
),
warnings.catch_warnings(record=True) as warned,
):
addresses = get_all_addresses_v6()
assert isinstance(addresses, list)
assert len(addresses) == 1
assert len(warned) == 1
first_warning = warned[0]
assert "get_all_addresses_v6 is deprecated" in str(first_warning.message)


def test_ip6_to_address_and_index():
"""Test we can extract from mocked adapters."""
adapters = _generate_mock_adapters()
Expand Down Expand Up @@ -84,8 +120,8 @@ def test_ip6_addresses_to_indexes():
def test_normalize_interface_choice_errors():
"""Test we generate exception on invalid input."""
with (
patch("zeroconf._utils.net.get_all_addresses", return_value=[]),
patch("zeroconf._utils.net.get_all_addresses_v6", return_value=[]),
patch("zeroconf._utils.net.get_all_addresses_ipv4", return_value=[]),
patch("zeroconf._utils.net.get_all_addresses_ipv6", return_value=[]),
pytest.raises(RuntimeError),
):
netutils.normalize_interface_choice(r.InterfaceChoice.All)
Expand Down
Loading