diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py
index d3252525b7647da630d5b5ba3452bc532e4164ac..9cc13ac94bd69d1c0bda336e9534e4133488af3b 100644
--- a/homeassistant/components/mqtt/switch.py
+++ b/homeassistant/components/mqtt/switch.py
@@ -1,11 +1,14 @@
 """Support for MQTT switches."""
+from __future__ import annotations
+
 import functools
 
 import voluptuous as vol
 
 from homeassistant.components import switch
-from homeassistant.components.switch import SwitchEntity
+from homeassistant.components.switch import DEVICE_CLASSES_SCHEMA, SwitchEntity
 from homeassistant.const import (
+    CONF_DEVICE_CLASS,
     CONF_NAME,
     CONF_OPTIMISTIC,
     CONF_PAYLOAD_OFF,
@@ -48,6 +51,7 @@ PLATFORM_SCHEMA = mqtt.MQTT_RW_PLATFORM_SCHEMA.extend(
         vol.Optional(CONF_STATE_OFF): cv.string,
         vol.Optional(CONF_STATE_ON): cv.string,
         vol.Optional(CONF_VALUE_TEMPLATE): cv.template,
+        vol.Optional(CONF_DEVICE_CLASS): DEVICE_CLASSES_SCHEMA,
     }
 ).extend(MQTT_ENTITY_COMMON_SCHEMA.schema)
 
@@ -158,6 +162,11 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
         """Return true if we do optimistic updates."""
         return self._optimistic
 
+    @property
+    def device_class(self) -> str | None:
+        """Return the device class of the sensor."""
+        return self._config.get(CONF_DEVICE_CLASS)
+
     async def async_turn_on(self, **kwargs):
         """Turn the device on.
 
diff --git a/tests/components/mqtt/test_switch.py b/tests/components/mqtt/test_switch.py
index 263ec0a2825ad01b1b77a8634baff3397a11259c..a3ef29d0d080b316372b4d7eb1d410db67bb6bb3 100644
--- a/tests/components/mqtt/test_switch.py
+++ b/tests/components/mqtt/test_switch.py
@@ -6,7 +6,12 @@ import pytest
 
 from homeassistant.components import switch
 from homeassistant.components.mqtt.switch import MQTT_SWITCH_ATTRIBUTES_BLOCKED
-from homeassistant.const import ATTR_ASSUMED_STATE, STATE_OFF, STATE_ON
+from homeassistant.const import (
+    ATTR_ASSUMED_STATE,
+    ATTR_DEVICE_CLASS,
+    STATE_OFF,
+    STATE_ON,
+)
 import homeassistant.core as ha
 from homeassistant.setup import async_setup_component
 
@@ -56,6 +61,7 @@ async def test_controlling_state_via_topic(hass, mqtt_mock):
                 "command_topic": "command-topic",
                 "payload_on": 1,
                 "payload_off": 0,
+                "device_class": "switch",
             }
         },
     )
@@ -63,6 +69,7 @@ async def test_controlling_state_via_topic(hass, mqtt_mock):
 
     state = hass.states.get("switch.test")
     assert state.state == STATE_OFF
+    assert state.attributes.get(ATTR_DEVICE_CLASS) == "switch"
     assert not state.attributes.get(ATTR_ASSUMED_STATE)
 
     async_fire_mqtt_message(hass, "state-topic", "1")
@@ -387,6 +394,7 @@ async def test_discovery_update_unchanged_switch(hass, mqtt_mock, caplog):
     """Test update of discovered switch."""
     data1 = (
         '{ "name": "Beer",'
+        '  "device_class": "switch",'
         '  "state_topic": "test_topic",'
         '  "command_topic": "test_topic" }'
     )