diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py
index fc3fff475148f7af2da003b14fe3ba3039847f19..29aea64c9c53dfe7b3aa8b1ddf4412dc23c4c3e3 100644
--- a/homeassistant/components/automation/state.py
+++ b/homeassistant/components/automation/state.py
@@ -6,10 +6,14 @@ from typing import Dict
 import voluptuous as vol
 
 from homeassistant import exceptions
-from homeassistant.const import CONF_FOR, CONF_PLATFORM, MATCH_ALL
+from homeassistant.const import CONF_FOR, CONF_PLATFORM, EVENT_STATE_CHANGED, MATCH_ALL
 from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
 from homeassistant.helpers import config_validation as cv, template
-from homeassistant.helpers.event import async_track_same_state, async_track_state_change
+from homeassistant.helpers.event import (
+    Event,
+    async_track_same_state,
+    process_state_match,
+)
 
 # mypy: allow-incomplete-defs, allow-untyped-calls, allow-untyped-defs
 # mypy: no-check-untyped-defs
@@ -56,10 +60,30 @@ async def async_attach_trigger(
     match_all = from_state == MATCH_ALL and to_state == MATCH_ALL
     unsub_track_same = {}
     period: Dict[str, timedelta] = {}
+    match_from_state = process_state_match(from_state)
+    match_to_state = process_state_match(to_state)
 
     @callback
-    def state_automation_listener(entity, from_s, to_s):
+    def state_automation_listener(event: Event):
         """Listen for state changes and calls action."""
+        entity: str = event.data["entity_id"]
+        if entity not in entity_id:
+            return
+
+        from_s = event.data.get("old_state")
+        to_s = event.data.get("new_state")
+
+        if (
+            (from_s is not None and not match_from_state(from_s.state))
+            or (to_s is not None and not match_to_state(to_s.state))
+            or (
+                not match_all
+                and from_s is not None
+                and to_s is not None
+                and from_s.state == to_s.state
+            )
+        ):
+            return
 
         @callback
         def call_action():
@@ -75,7 +99,7 @@ async def async_attach_trigger(
                             "for": time_delta if not time_delta else period[entity],
                         }
                     },
-                    context=to_s.context,
+                    context=event.context,
                 )
             )
 
@@ -120,17 +144,16 @@ async def async_attach_trigger(
             )
             return
 
+        def _check_same_state(_, _2, new_st):
+            if new_st is None:
+                return False
+            return new_st.state == to_s.state
+
         unsub_track_same[entity] = async_track_same_state(
-            hass,
-            period[entity],
-            call_action,
-            lambda _, _2, to_state: to_state.state == to_s.state,
-            entity_ids=entity,
+            hass, period[entity], call_action, _check_same_state, entity_ids=entity,
         )
 
-    unsub = async_track_state_change(
-        hass, entity_id, state_automation_listener, from_state, to_state
-    )
+    unsub = hass.bus.async_listen(EVENT_STATE_CHANGED, state_automation_listener)
 
     @callback
     def async_remove():
diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py
index 74faca6a1d2f357396c451fc05c6530cc8fdd1d4..8a4b4bc2b7612a7524e3c58b056d681d3d6c36e3 100644
--- a/homeassistant/helpers/event.py
+++ b/homeassistant/helpers/event.py
@@ -67,8 +67,8 @@ def async_track_state_change(
 
     Must be run within the event loop.
     """
-    match_from_state = _process_state_match(from_state)
-    match_to_state = _process_state_match(to_state)
+    match_from_state = process_state_match(from_state)
+    match_to_state = process_state_match(to_state)
 
     # Ensure it is a lowercase list with entity ids we want to match on
     if entity_ids == MATCH_ALL:
@@ -473,7 +473,7 @@ def async_track_time_change(
 track_time_change = threaded_listener_factory(async_track_time_change)
 
 
-def _process_state_match(
+def process_state_match(
     parameter: Union[None, str, Iterable[str]]
 ) -> Callable[[str], bool]:
     """Convert parameter to function that matches input against parameter."""
diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py
index 9d4fa9a1100172f986cbea1acca2182a305b01a6..173af8158a476405df5370f64789e7622937e53e 100644
--- a/tests/components/automation/test_state.py
+++ b/tests/components/automation/test_state.py
@@ -519,6 +519,28 @@ async def test_if_fires_on_entity_change_with_for(hass, calls):
     assert 1 == len(calls)
 
 
+async def test_if_fires_on_entity_removal(hass, calls):
+    """Test for firing on entity removal, when new_state is None."""
+    hass.states.async_set("test.entity", "hello")
+    await hass.async_block_till_done()
+
+    assert await async_setup_component(
+        hass,
+        automation.DOMAIN,
+        {
+            automation.DOMAIN: {
+                "trigger": {"platform": "state", "entity_id": "test.entity"},
+                "action": {"service": "test.automation"},
+            }
+        },
+    )
+    await hass.async_block_till_done()
+
+    assert hass.states.async_remove("test.entity")
+    await hass.async_block_till_done()
+    assert 1 == len(calls)
+
+
 async def test_if_fires_on_for_condition(hass, calls):
     """Test for firing if condition is on."""
     point1 = dt_util.utcnow()