Skip to content
Snippets Groups Projects
Commit 5a83a923 authored by Steven D. Lander's avatar Steven D. Lander Committed by Paulus Schoutsen
Browse files

Refactor imports for tensorflow (#27617)

* Refactoring imports for tensorflow

* Removing whitespace spaces on blank line 110

* Moving tensorflow to try/except block

* Fixed black formatting

* Refactoring try/except to if/else
parent 09de6d58
No related branches found
No related tags found
No related merge requests found
......@@ -2,8 +2,22 @@
import logging
import os
import sys
import io
import voluptuous as vol
from PIL import Image, ImageDraw
import numpy as np
try:
import cv2
except ImportError:
cv2 = None
try:
# Verify that the TensorFlow Object Detection API is pre-installed
import tensorflow as tf # noqa
from object_detection.utils import label_map_util # noqa
except ImportError:
label_map_util = None
from homeassistant.components.image_processing import (
CONF_CONFIDENCE,
......@@ -84,14 +98,8 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
# append custom model path to sys.path
sys.path.append(model_dir)
try:
# Verify that the TensorFlow Object Detection API is pre-installed
# pylint: disable=unused-import,unused-variable
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf # noqa
from object_detection.utils import label_map_util # noqa
except ImportError:
# pylint: disable=line-too-long
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
if label_map_util is None:
_LOGGER.error(
"No TensorFlow Object Detection library found! Install or compile "
"for your system following instructions here: "
......@@ -99,11 +107,7 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
) # noqa
return
try:
# Display warning that PIL will be used if no OpenCV is found.
# pylint: disable=unused-import,unused-variable
import cv2 # noqa
except ImportError:
if cv2 is None:
_LOGGER.warning(
"No OpenCV library found. TensorFlow will process image with "
"PIL at reduced resolution"
......@@ -236,9 +240,6 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
}
def _save_image(self, image, matches, paths):
from PIL import Image, ImageDraw
import io
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
img_width, img_height = img.size
draw = ImageDraw.Draw(img)
......@@ -280,18 +281,8 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
def process_image(self, image):
"""Process the image."""
import numpy as np
try:
import cv2 # pylint: disable=import-error
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
except ImportError:
from PIL import Image
import io
if cv2 is None:
img = Image.open(io.BytesIO(bytearray(image))).convert("RGB")
img.thumbnail((460, 460), Image.ANTIALIAS)
img_width, img_height = img.size
......@@ -301,6 +292,10 @@ class TensorFlowImageProcessor(ImageProcessingEntity):
.astype(np.uint8)
)
inp_expanded = np.expand_dims(inp, axis=0)
else:
img = cv2.imdecode(np.asarray(bytearray(image)), cv2.IMREAD_UNCHANGED)
inp = img[:, :, [2, 1, 0]] # BGR->RGB
inp_expanded = inp.reshape(1, inp.shape[0], inp.shape[1], 3)
image_tensor = self._graph.get_tensor_by_name("image_tensor:0")
boxes = self._graph.get_tensor_by_name("detection_boxes:0")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment