diff --git a/homeassistant/components/switch/mqtt.py b/homeassistant/components/switch/mqtt.py index 4442b186e3015ac3a4280106a1780c2deb67b20c..eb91f8d846ae12e60a34888f0a582f8aedc6bcc7 100644 --- a/homeassistant/components/switch/mqtt.py +++ b/homeassistant/components/switch/mqtt.py @@ -31,12 +31,16 @@ DEFAULT_PAYLOAD_ON = 'ON' DEFAULT_PAYLOAD_OFF = 'OFF' DEFAULT_OPTIMISTIC = False CONF_UNIQUE_ID = 'unique_id' +CONF_STATE_ON = "state_on" +CONF_STATE_OFF = "state_off" PLATFORM_SCHEMA = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend({ vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string, vol.Optional(CONF_ICON): cv.icon, vol.Optional(CONF_PAYLOAD_ON, default=DEFAULT_PAYLOAD_ON): cv.string, vol.Optional(CONF_PAYLOAD_OFF, default=DEFAULT_PAYLOAD_OFF): cv.string, + vol.Optional(CONF_STATE_ON): cv.string, + vol.Optional(CONF_STATE_OFF): cv.string, vol.Optional(CONF_UNIQUE_ID): cv.string, vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean, }).extend(mqtt.MQTT_AVAILABILITY_SCHEMA.schema) @@ -62,6 +66,8 @@ async def async_setup_platform(hass, config, async_add_devices, config.get(CONF_RETAIN), config.get(CONF_PAYLOAD_ON), config.get(CONF_PAYLOAD_OFF), + config.get(CONF_STATE_ON), + config.get(CONF_STATE_OFF), config.get(CONF_OPTIMISTIC), config.get(CONF_PAYLOAD_AVAILABLE), config.get(CONF_PAYLOAD_NOT_AVAILABLE), @@ -75,9 +81,10 @@ class MqttSwitch(MqttAvailability, SwitchDevice): def __init__(self, name, icon, state_topic, command_topic, availability_topic, - qos, retain, payload_on, payload_off, optimistic, - payload_available, payload_not_available, - unique_id: Optional[str], value_template): + qos, retain, payload_on, payload_off, state_on, + state_off, optimistic, payload_available, + payload_not_available, unique_id: Optional[str], + value_template): """Initialize the MQTT switch.""" super().__init__(availability_topic, qos, payload_available, payload_not_available) @@ -90,6 +97,8 @@ class MqttSwitch(MqttAvailability, SwitchDevice): self._retain = retain self._payload_on = payload_on self._payload_off = payload_off + self._state_on = state_on if state_on else self._payload_on + self._state_off = state_off if state_off else self._payload_off self._optimistic = optimistic self._template = value_template self._unique_id = unique_id @@ -104,9 +113,9 @@ class MqttSwitch(MqttAvailability, SwitchDevice): if self._template is not None: payload = self._template.async_render_with_possible_json_value( payload) - if payload == self._payload_on: + if payload == self._state_on: self._state = True - elif payload == self._payload_off: + elif payload == self._state_off: self._state = False self.async_schedule_update_ha_state() diff --git a/tests/components/switch/test_mqtt.py b/tests/components/switch/test_mqtt.py index 31f9a729c53d98647c696ece3f2dc37941e37c48..7cd5a42b4a3e51de298b99a5254c36bbc140ba6c 100644 --- a/tests/components/switch/test_mqtt.py +++ b/tests/components/switch/test_mqtt.py @@ -249,6 +249,37 @@ class TestSwitchMQTT(unittest.TestCase): state = self.hass.states.get('switch.test') self.assertEqual(STATE_ON, state.state) + def test_custom_state_payload(self): + """Test the state payload.""" + assert setup_component(self.hass, switch.DOMAIN, { + switch.DOMAIN: { + 'platform': 'mqtt', + 'name': 'test', + 'state_topic': 'state-topic', + 'command_topic': 'command-topic', + 'payload_on': 1, + 'payload_off': 0, + 'state_on': "HIGH", + 'state_off': "LOW", + } + }) + + state = self.hass.states.get('switch.test') + self.assertEqual(STATE_OFF, state.state) + self.assertFalse(state.attributes.get(ATTR_ASSUMED_STATE)) + + fire_mqtt_message(self.hass, 'state-topic', 'HIGH') + self.hass.block_till_done() + + state = self.hass.states.get('switch.test') + self.assertEqual(STATE_ON, state.state) + + fire_mqtt_message(self.hass, 'state-topic', 'LOW') + self.hass.block_till_done() + + state = self.hass.states.get('switch.test') + self.assertEqual(STATE_OFF, state.state) + def test_unique_id(self): """Test unique id option only creates one switch per unique_id.""" assert setup_component(self.hass, switch.DOMAIN, {