diff --git a/homeassistant/components/switch/mqtt.py b/homeassistant/components/switch/mqtt.py
index 4442b186e3015ac3a4280106a1780c2deb67b20c..eb91f8d846ae12e60a34888f0a582f8aedc6bcc7 100644
--- a/homeassistant/components/switch/mqtt.py
+++ b/homeassistant/components/switch/mqtt.py
@@ -31,12 +31,16 @@ DEFAULT_PAYLOAD_ON = 'ON'
 DEFAULT_PAYLOAD_OFF = 'OFF'
 DEFAULT_OPTIMISTIC = False
 CONF_UNIQUE_ID = 'unique_id'
+CONF_STATE_ON = "state_on"
+CONF_STATE_OFF = "state_off"
 
 PLATFORM_SCHEMA = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend({
     vol.Optional(CONF_NAME, default=DEFAULT_NAME): cv.string,
     vol.Optional(CONF_ICON): cv.icon,
     vol.Optional(CONF_PAYLOAD_ON, default=DEFAULT_PAYLOAD_ON): cv.string,
     vol.Optional(CONF_PAYLOAD_OFF, default=DEFAULT_PAYLOAD_OFF): cv.string,
+    vol.Optional(CONF_STATE_ON): cv.string,
+    vol.Optional(CONF_STATE_OFF): cv.string,
     vol.Optional(CONF_UNIQUE_ID): cv.string,
     vol.Optional(CONF_OPTIMISTIC, default=DEFAULT_OPTIMISTIC): cv.boolean,
 }).extend(mqtt.MQTT_AVAILABILITY_SCHEMA.schema)
@@ -62,6 +66,8 @@ async def async_setup_platform(hass, config, async_add_devices,
         config.get(CONF_RETAIN),
         config.get(CONF_PAYLOAD_ON),
         config.get(CONF_PAYLOAD_OFF),
+        config.get(CONF_STATE_ON),
+        config.get(CONF_STATE_OFF),
         config.get(CONF_OPTIMISTIC),
         config.get(CONF_PAYLOAD_AVAILABLE),
         config.get(CONF_PAYLOAD_NOT_AVAILABLE),
@@ -75,9 +81,10 @@ class MqttSwitch(MqttAvailability, SwitchDevice):
 
     def __init__(self, name, icon,
                  state_topic, command_topic, availability_topic,
-                 qos, retain, payload_on, payload_off, optimistic,
-                 payload_available, payload_not_available,
-                 unique_id: Optional[str], value_template):
+                 qos, retain, payload_on, payload_off, state_on,
+                 state_off, optimistic, payload_available,
+                 payload_not_available, unique_id: Optional[str],
+                 value_template):
         """Initialize the MQTT switch."""
         super().__init__(availability_topic, qos, payload_available,
                          payload_not_available)
@@ -90,6 +97,8 @@ class MqttSwitch(MqttAvailability, SwitchDevice):
         self._retain = retain
         self._payload_on = payload_on
         self._payload_off = payload_off
+        self._state_on = state_on if state_on else self._payload_on
+        self._state_off = state_off if state_off else self._payload_off
         self._optimistic = optimistic
         self._template = value_template
         self._unique_id = unique_id
@@ -104,9 +113,9 @@ class MqttSwitch(MqttAvailability, SwitchDevice):
             if self._template is not None:
                 payload = self._template.async_render_with_possible_json_value(
                     payload)
-            if payload == self._payload_on:
+            if payload == self._state_on:
                 self._state = True
-            elif payload == self._payload_off:
+            elif payload == self._state_off:
                 self._state = False
 
             self.async_schedule_update_ha_state()
diff --git a/tests/components/switch/test_mqtt.py b/tests/components/switch/test_mqtt.py
index 31f9a729c53d98647c696ece3f2dc37941e37c48..7cd5a42b4a3e51de298b99a5254c36bbc140ba6c 100644
--- a/tests/components/switch/test_mqtt.py
+++ b/tests/components/switch/test_mqtt.py
@@ -249,6 +249,37 @@ class TestSwitchMQTT(unittest.TestCase):
         state = self.hass.states.get('switch.test')
         self.assertEqual(STATE_ON, state.state)
 
+    def test_custom_state_payload(self):
+        """Test the state payload."""
+        assert setup_component(self.hass, switch.DOMAIN, {
+            switch.DOMAIN: {
+                'platform': 'mqtt',
+                'name': 'test',
+                'state_topic': 'state-topic',
+                'command_topic': 'command-topic',
+                'payload_on': 1,
+                'payload_off': 0,
+                'state_on': "HIGH",
+                'state_off': "LOW",
+            }
+        })
+
+        state = self.hass.states.get('switch.test')
+        self.assertEqual(STATE_OFF, state.state)
+        self.assertFalse(state.attributes.get(ATTR_ASSUMED_STATE))
+
+        fire_mqtt_message(self.hass, 'state-topic', 'HIGH')
+        self.hass.block_till_done()
+
+        state = self.hass.states.get('switch.test')
+        self.assertEqual(STATE_ON, state.state)
+
+        fire_mqtt_message(self.hass, 'state-topic', 'LOW')
+        self.hass.block_till_done()
+
+        state = self.hass.states.get('switch.test')
+        self.assertEqual(STATE_OFF, state.state)
+
     def test_unique_id(self):
         """Test unique id option only creates one switch per unique_id."""
         assert setup_component(self.hass, switch.DOMAIN, {