From a92c52e65b4852893b5faaf2776d7a3c9f55366e Mon Sep 17 00:00:00 2001 From: Sam Wright <s.wright.aus@gmail.com> Date: Sat, 22 Feb 2025 04:14:52 +1100 Subject: [PATCH] Unifi zone based rules (#138974) * Add support for controlling zone based firewall policies * Add test * Address Kane's comments + add real repo * Add firewall icon --- .../components/unifi/hub/entity_loader.py | 1 + homeassistant/components/unifi/icons.json | 3 + homeassistant/components/unifi/switch.py | 35 +++++++ tests/components/unifi/conftest.py | 10 ++ tests/components/unifi/test_switch.py | 97 +++++++++++++++++++ 5 files changed, 146 insertions(+) diff --git a/homeassistant/components/unifi/hub/entity_loader.py b/homeassistant/components/unifi/hub/entity_loader.py index 64403152b0c..84948a92e98 100644 --- a/homeassistant/components/unifi/hub/entity_loader.py +++ b/homeassistant/components/unifi/hub/entity_loader.py @@ -46,6 +46,7 @@ class UnifiEntityLoader: hub.api.port_forwarding.update, hub.api.sites.update, hub.api.system_information.update, + hub.api.firewall_policies.update, hub.api.traffic_rules.update, hub.api.traffic_routes.update, hub.api.wlans.update, diff --git a/homeassistant/components/unifi/icons.json b/homeassistant/components/unifi/icons.json index 6874bb5ae03..616d7cb185f 100644 --- a/homeassistant/components/unifi/icons.json +++ b/homeassistant/components/unifi/icons.json @@ -55,6 +55,9 @@ "off": "mdi:network-off" } }, + "firewall_policy_control": { + "default": "mdi:security-network" + }, "port_forward_control": { "default": "mdi:upload-network" }, diff --git a/homeassistant/components/unifi/switch.py b/homeassistant/components/unifi/switch.py index de0e8d3f412..282d0c9ae93 100644 --- a/homeassistant/components/unifi/switch.py +++ b/homeassistant/components/unifi/switch.py @@ -4,6 +4,7 @@ Support for controlling power supply of clients which are powered over Ethernet Support for controlling network access of clients selected in option flow. Support for controlling deep packet inspection (DPI) restriction groups. Support for controlling WLAN availability. +Support for controlling zone based traffic rules. """ from __future__ import annotations @@ -17,6 +18,7 @@ import aiounifi from aiounifi.interfaces.api_handlers import ItemEvent from aiounifi.interfaces.clients import Clients from aiounifi.interfaces.dpi_restriction_groups import DPIRestrictionGroups +from aiounifi.interfaces.firewall_policies import FirewallPolicies from aiounifi.interfaces.outlets import Outlets from aiounifi.interfaces.port_forwarding import PortForwarding from aiounifi.interfaces.ports import Ports @@ -29,6 +31,7 @@ from aiounifi.models.device import DeviceSetOutletRelayRequest from aiounifi.models.dpi_restriction_app import DPIRestrictionAppEnableRequest from aiounifi.models.dpi_restriction_group import DPIRestrictionGroup from aiounifi.models.event import Event, EventKey +from aiounifi.models.firewall_policy import FirewallPolicy, FirewallPolicyUpdateRequest from aiounifi.models.outlet import Outlet from aiounifi.models.port import Port from aiounifi.models.port_forward import PortForward, PortForwardEnableRequest @@ -129,6 +132,24 @@ async def async_dpi_group_control_fn(hub: UnifiHub, obj_id: str, target: bool) - ) +async def async_firewall_policy_control_fn( + hub: UnifiHub, obj_id: str, target: bool +) -> None: + """Control firewall policy state.""" + policy = hub.api.firewall_policies[obj_id].raw + policy["enabled"] = target + await hub.api.request(FirewallPolicyUpdateRequest.create(policy)) + # Update the policies so the UI is updated appropriately + await hub.api.firewall_policies.update() + + +@callback +def async_firewall_policy_supported_fn(hub: UnifiHub, obj_id: str) -> bool: + """Check if firewall policy is able to be controlled. Predefined policies are unable to be turned off.""" + policy = hub.api.firewall_policies[obj_id] + return not policy.predefined + + @callback def async_outlet_switching_supported_fn(hub: UnifiHub, obj_id: str) -> bool: """Determine if an outlet supports switching.""" @@ -236,6 +257,20 @@ ENTITY_DESCRIPTIONS: tuple[UnifiSwitchEntityDescription, ...] = ( supported_fn=lambda hub, obj_id: bool(hub.api.dpi_groups[obj_id].dpiapp_ids), unique_id_fn=lambda hub, obj_id: obj_id, ), + UnifiSwitchEntityDescription[FirewallPolicies, FirewallPolicy]( + key="Firewall policy control", + translation_key="firewall_policy_control", + device_class=SwitchDeviceClass.SWITCH, + entity_category=EntityCategory.CONFIG, + api_handler_fn=lambda api: api.firewall_policies, + control_fn=async_firewall_policy_control_fn, + device_info_fn=async_unifi_network_device_info_fn, + is_on_fn=lambda hub, firewall_policy: firewall_policy.enabled, + name_fn=lambda firewall_policy: firewall_policy.name, + object_fn=lambda api, obj_id: api.firewall_policies[obj_id], + unique_id_fn=lambda hub, obj_id: f"firewall_policy-{obj_id}", + supported_fn=async_firewall_policy_supported_fn, + ), UnifiSwitchEntityDescription[Outlets, Outlet]( key="Outlet control", device_class=SwitchDeviceClass.OUTLET, diff --git a/tests/components/unifi/conftest.py b/tests/components/unifi/conftest.py index ec7a0595731..4075aa0ad59 100644 --- a/tests/components/unifi/conftest.py +++ b/tests/components/unifi/conftest.py @@ -172,6 +172,7 @@ def fixture_request( device_payload: list[dict[str, Any]], dpi_app_payload: list[dict[str, Any]], dpi_group_payload: list[dict[str, Any]], + firewall_policy_payload: list[dict[str, Any]], port_forward_payload: list[dict[str, Any]], traffic_rule_payload: list[dict[str, Any]], traffic_route_payload: list[dict[str, Any]], @@ -211,6 +212,9 @@ def fixture_request( mock_get_request(f"/api/s/{site_id}/stat/device", device_payload) mock_get_request(f"/api/s/{site_id}/rest/dpiapp", dpi_app_payload) mock_get_request(f"/api/s/{site_id}/rest/dpigroup", dpi_group_payload) + mock_get_request( + f"/v2/api/site/{site_id}/firewall-policies", firewall_policy_payload + ) mock_get_request(f"/api/s/{site_id}/rest/portforward", port_forward_payload) mock_get_request(f"/api/s/{site_id}/stat/sysinfo", system_information_payload) mock_get_request(f"/api/s/{site_id}/rest/wlanconf", wlan_payload) @@ -253,6 +257,12 @@ def fixture_dpi_group_data() -> list[dict[str, Any]]: return [] +@pytest.fixture(name="firewall_policy_payload") +def firewall_policy_payload_data() -> list[dict[str, Any]]: + """Firewall policy data.""" + return [] + + @pytest.fixture(name="port_forward_payload") def fixture_port_forward_data() -> list[dict[str, Any]]: """Port forward data.""" diff --git a/tests/components/unifi/test_switch.py b/tests/components/unifi/test_switch.py index e4765d1181e..c8ee786895c 100644 --- a/tests/components/unifi/test_switch.py +++ b/tests/components/unifi/test_switch.py @@ -827,6 +827,45 @@ TRAFFIC_ROUTE = { ], } +FIREWALL_POLICY = { + "_id": "678ceb9fe3849d293243405c", + "action": "ALLOW", + "connection_state_type": "ALL", + "connection_states": [], + "create_allow_respond": True, + "description": "", + "destination": { + "match_opposite_ports": False, + "matching_target": "ANY", + "port_matching_type": "ANY", + "zone_id": "678ccc26e3849d2932432e26", + }, + "enabled": True, + "icmp_typename": "ANY", + "icmp_v6_typename": "ANY", + "index": 10000, + "ip_version": "BOTH", + "logging": False, + "match_ip_sec": False, + "match_opposite_protocol": False, + "name": "Allow internal to IoT", + "predefined": False, + "protocol": "all", + "schedule": { + "mode": "EVERY_DAY", + "repeat_on_days": [], + "time_all_day": False, + "time_range_end": "12:00", + "time_range_start": "09:00", + }, + "source": { + "match_opposite_ports": False, + "matching_target": "ANY", + "port_matching_type": "ANY", + "zone_id": "678c63bc2d97692f08adcdfa", + }, +} + @pytest.mark.parametrize( "config_entry_options", [{CONF_BLOCK_CLIENT: [BLOCKED["mac"]]}] @@ -1226,6 +1265,62 @@ async def test_traffic_routes( assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call +@pytest.mark.parametrize(("firewall_policy_payload"), [([FIREWALL_POLICY])]) +async def test_firewall_policies( + hass: HomeAssistant, + aioclient_mock: AiohttpClientMocker, + config_entry_setup: MockConfigEntry, + firewall_policy_payload: list[dict[str, Any]], +) -> None: + """Test control of UniFi firewall policies.""" + assert len(hass.states.async_entity_ids(SWITCH_DOMAIN)) == 1 + + # Validate state object + assert ( + hass.states.get("switch.unifi_network_allow_internal_to_iot").state == STATE_ON + ) + + firewall_policy = deepcopy(firewall_policy_payload[0]) + + # Disable firewall policy + aioclient_mock.put( + f"https://{config_entry_setup.data[CONF_HOST]}:1234" + f"/v2/api/site/{config_entry_setup.data[CONF_SITE_ID]}" + f"/firewall-policies/{firewall_policy['_id']}", + ) + + call_count = aioclient_mock.call_count + + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_off", + {"entity_id": "switch.unifi_network_allow_internal_to_iot"}, + blocking=True, + ) + # Updating the value for firewall policies will make another call to retrieve the values + assert aioclient_mock.call_count == call_count + 2 + expected_disable_call = deepcopy(firewall_policy) + expected_disable_call["enabled"] = False + + assert aioclient_mock.mock_calls[call_count][2] == expected_disable_call + + call_count = aioclient_mock.call_count + + # Enable firewall policy + await hass.services.async_call( + SWITCH_DOMAIN, + "turn_on", + {"entity_id": "switch.unifi_network_allow_internal_to_iot"}, + blocking=True, + ) + + expected_enable_call = deepcopy(firewall_policy) + expected_enable_call["enabled"] = True + + assert aioclient_mock.call_count == call_count + 2 + assert aioclient_mock.mock_calls[call_count][2] == expected_enable_call + + @pytest.mark.parametrize( ("device_payload", "entity_id", "outlet_index", "expected_switches"), [ @@ -1677,6 +1772,7 @@ async def test_updating_unique_id( @pytest.mark.parametrize("dpi_group_payload", [DPI_GROUPS]) @pytest.mark.parametrize("port_forward_payload", [[PORT_FORWARD_PLEX]]) @pytest.mark.parametrize(("traffic_rule_payload"), [([TRAFFIC_RULE])]) +@pytest.mark.parametrize("firewall_policy_payload", [[FIREWALL_POLICY]]) @pytest.mark.parametrize("wlan_payload", [[WLAN]]) @pytest.mark.usefixtures("config_entry_setup") @pytest.mark.usefixtures("entity_registry_enabled_by_default") @@ -1691,6 +1787,7 @@ async def test_hub_state_change( "switch.block_media_streaming", "switch.unifi_network_plex", "switch.unifi_network_test_traffic_rule", + "switch.unifi_network_allow_internal_to_iot", "switch.ssid_1", ) for entity_id in entity_ids: -- GitLab