"""Callable classes for Containerized Processors.
Currently this module does not automatically clean up containers after the
program finishes. Use the following command to clean up the containers started
by this module.
$ docker stop -t 0 $(docker ps -a -q --filter="name=GABRIELTOOL")
"""
import ast
import copy
import io
import os
import time
import cv2
import docker
import requests
from logzero import logger
from gabrieltool.statemachine.callable_zoo import record_kwargs
from gabrieltool.statemachine.callable_zoo import CallableBase
from gabrieltool.statemachine.callable_zoo.processor_zoo import tfutils
# Module-level Docker client shared by every container manager in this module.
docker_client = docker.from_env()
class SingletonContainerManager:
    """Helper class to start, get, and remove a Docker container identified by a name.

    At most one container per name is managed: ``start_container`` reuses an
    already-running container with the same name instead of launching another.
    """

    def __init__(self, container_name):
        """Constructor.

        Args:
            container_name (string): Unique name identifying the managed container.
        """
        self._container_name = container_name
        self._container = None

    @property
    def container_name(self):
        """Name of the managed container."""
        return self._container_name

    @property
    def container(self):
        """The managed Container object, or None if it does not exist.

        Lazily looked up from the Docker daemon on first access.
        """
        if self._container is None:
            self._container = self._get_container_obj()
        return self._container

    def _add_image_tag_if_not_available(self, image_url):
        """Default to the 'latest' tag if no tag is specified."""
        if ':' not in image_url:
            image_url = '{}:{}'.format(image_url, 'latest')
        return image_url

    def _make_image_available(self, image_url):
        """Pull the container image from its registry unless it already exists locally.

        Args:
            image_url (string): Tagged container image URL (e.g. 'repo/image:tag').
        """
        logger.info(
            'Checking if the container image is available locally... (image: {})'.format(image_url))
        try:
            docker_client.images.get(image_url)
            logger.info(
                'Found local image (image: {})!'.format(image_url))
        except docker.errors.ImageNotFound:
            logger.info('{} not found locally. Trying to downloading it...'.format(image_url))
            logger.info('Downloading image: {}...'.format(image_url))
            # split 'repo:tag' on the last colon so the tag is never lost
            repo, _, tag = image_url.rpartition(':')
            docker_client.images.pull(repo, tag=tag)
            logger.info('Download finished!')

    def start_container(self, image_url, command, startup_wait=10, **kwargs):
        """Start a container unless one with this name is already running.

        Args:
            image_url (string): Container Image URL.
            command (string): Container command.
            startup_wait (float, optional): Seconds to sleep after launching so
                the server inside the container has time to start. Defaults to 10.
            kwargs (dictionary): Extra arguments to pass to the Docker client.

        Returns:
            Container: The running container, or None if the launch failed.
        """
        image_url = self._add_image_tag_if_not_available(image_url)
        if self.container is None or self.container.status != 'running':
            self._make_image_available(image_url)
            logger.info(
                'launching Docker container (name: {}, image: {}, command: {}, extra args: {})'.format(
                    self.container_name,
                    image_url,
                    command,
                    str(kwargs)
                ))
            try:
                container = docker_client.containers.run(
                    image_url,
                    command=command,
                    name=self.container_name,
                    auto_remove=True,
                    detach=True,
                    **kwargs
                )
                # refresh attributes (e.g. port mappings) assigned by the
                # daemon after the container started
                container.reload()
                # give the container's server some time to come up
                time.sleep(startup_wait)
                self._container = container
            except Exception as e:
                # best-effort: log and fall through; callers can inspect the
                # (possibly None) return value
                logger.error('Error starting container: {}'.format(e))
        return self._container

    def _get_container_obj(self):
        """Return the existing container with our name, or None if absent."""
        containers = docker_client.containers.list(filters={'name': self.container_name})
        return containers[0] if containers else None

    def clean(self):
        """Remove the container if it exists."""
        container = self._get_container_obj()
        if container:
            container.remove(force=True)
class FasterRCNNContainerCallable(CallableBase):
    """A callable class to execute a containerized FasterRCNN model in Caffe.

    Use this class if your object detector is generated by TPOD v1 and the
    container image is hosted by cmusatyalab's gitlab container registry.
    """

    # UNIQUE to each python process
    CONTAINER_NAME = 'GABRIELTOOL-FasterRCNNContainerCallable-{}'.format(os.getpid())

    @record_kwargs
    def __init__(self, container_image_url, conf_threshold=0.5):
        """Constructor.

        Args:
            container_image_url (string): URL to the container image.
            conf_threshold (float, optional): Cutoff threshold for detection. Defaults to 0.5.
        """
        # For default parameter settings, see:
        # https://github.com/rbgirshick/fast-rcnn/blob/b612190f279da3c11dd8b1396dd5e72779f8e463/lib/fast_rcnn/config.py
        super(FasterRCNNContainerCallable, self).__init__()
        self.container_image_url = container_image_url
        self.conf_threshold = conf_threshold
        self.container_manager = SingletonContainerManager(self.CONTAINER_NAME)
        # port number inside the container that the detection server listens on
        self.container_port = '8000/tcp'
        ports = {self.container_port: None}  # map container port 8000 to a random host port
        # BUGFIX: this assignment used to end with a trailing comma, turning
        # the command into a 1-element tuple and breaking the container launch.
        command = '/bin/bash run_server.sh'
        self.container_manager.start_container(self.container_image_url, command, ports=ports)

    @property
    def container_server_url(self):
        """URL of the /detect endpoint, or None if the container is not running."""
        container = self.container_manager.container
        if container is not None and container.status == 'running':
            return 'http://localhost:{}/detect'.format(
                container.ports[self.container_port][0]['HostPort'])
        return None

    @classmethod
    def from_json(cls, json_obj):
        """Deserialize.

        Args:
            json_obj (dict): Serialized constructor keyword arguments.

        Returns:
            FasterRCNNContainerCallable

        Raises:
            ValueError: If the json object cannot be converted.
        """
        try:
            kwargs = copy.copy(json_obj)
            kwargs['container_image_url'] = json_obj['container_image_url']
            kwargs['conf_threshold'] = float(json_obj['conf_threshold'])
        except ValueError as e:
            # chain the original exception for easier debugging
            raise ValueError(
                'Failed to convert json object to {} instance. '
                'The input json object is {}. ({})'.format(cls.__name__,
                                                           json_obj, e)) from e
        return cls(**kwargs)

    def __call__(self, image):
        """Run object detection on an OpenCV-style BGR image.

        Args:
            image (numpy array): BGR image.

        Returns:
            dict: Mapping from label to a list of
                [x1, y1, x2, y2, confidence, label] detections.

        Raises:
            RuntimeError: If the detection container is not running.
        """
        server_url = self.container_server_url
        if server_url is None:
            raise RuntimeError('FasterRCNN container is not running.')
        fp = io.BytesIO()
        # tobytes() replaces tostring(), which is deprecated and removed in NumPy 2.0
        fp.write(cv2.imencode('.jpg', image)[1].tobytes())
        fp.seek(0, 0)
        response = requests.post(server_url, data={
            'confidence': self.conf_threshold,
            'format': 'box'
        }, files={
            'picture': fp
        })
        # the server replies with a Python-literal list of
        # (label, bbox, confidence) entries
        detections = ast.literal_eval(response.text)
        result = {}
        for detection in detections:
            logger.info(detection)
            label, bbox, confidence = detection[0], detection[1], detection[2]
            result.setdefault(label, []).append([*bbox, confidence, label])
        return result

    def clean(self):
        """Stop and remove the FasterRCNN container."""
        self.container_manager.clean()
class TFServingContainerCallable(CallableBase):
    """A callable class to execute frozen tensorflow models using TF serving container images.

    Use this class if your object detector is generated by OpenTPOD and you
    have downloaded the model. The TF serving container is started lazily when
    an FSM runner starts.
    """

    # UNIQUE to each python process
    CONTAINER_NAME = 'GABRIELTOOL-TFServingContainerCallable-{}'.format(os.getpid())
    # TF Serving image by default listens on 8500 for GRPC
    TFSERVING_GRPC_PORT = 8500
    # registry of model_name -> absolute serving dir; a single TF serving
    # container serves every entry in this dict
    SERVED_DIRS = {}

    @record_kwargs
    def __init__(self, model_name, serving_dir, conf_threshold=0.5):
        """Constructor.

        Args:
            model_name (string): Arbitrary Name of the Model. Each DNN should
                have a unique model name.
            serving_dir (string): Path to the TF saved_model. This should refer
                to the 'saved_model' directory of the downloaded OpenTPOD model.
            conf_threshold (float, optional): Cutoff threshold for detection. Defaults to 0.5.
        """
        super(TFServingContainerCallable, self).__init__()
        self.serving_dir = serving_dir
        self.model_name = model_name
        self.conf_threshold = conf_threshold
        TFServingContainerCallable.SERVED_DIRS[model_name] = os.path.abspath(serving_dir)
        self.container_manager = SingletonContainerManager(TFServingContainerCallable.CONTAINER_NAME)
        self.container_internal_port = '{}/tcp'.format(TFServingContainerCallable.TFSERVING_GRPC_PORT)
        self.predictor = None

    def prepare(self):
        """Launch the TF serving container. Do not call this method directly unless debugging.

        This function is called when an FSM runner starts. This enables
        gabrieltool to start only one TF serving container to serve many models.
        """
        # launch container
        self._start_container()

    def _start_container(self):
        """Launch TF serving container image to serve all entries in SERVED_DIRS."""
        ports = {self.container_internal_port: ('127.0.0.1', None)}
        container_image_url = 'tensorflow/serving:2.1.0'
        # imported lazily so tensorflow_serving is only required when a
        # container is actually launched
        from tensorflow_serving.config import model_server_config_pb2
        from google.protobuf import text_format
        # generate the model server config listing every served model
        model_server_config = model_server_config_pb2.ModelServerConfig()
        for model_name in TFServingContainerCallable.SERVED_DIRS:
            model_config = model_server_config.model_config_list.config.add()
            model_config.name = model_name
            model_config.base_path = '/models/{}'.format(model_name)
            model_config.model_platform = 'tensorflow'
        # NOTE(review): the config is written to the current working directory;
        # the process must have write access there
        with open('models.config', 'w') as f:
            f.write(text_format.MessageToString(model_server_config))
        # mount the config file and each model directory read-only
        volumes = {
            os.path.abspath('models.config'): {'bind': '/models/models.config', 'mode': 'ro'}
        }
        for model_name, model_dir in TFServingContainerCallable.SERVED_DIRS.items():
            volumes[model_dir] = {'bind': '/models/{}'.format(model_name), 'mode': 'ro'}
        logger.debug('volumes: {}'.format(volumes))
        cmd = '--model_config_file=/models/models.config'
        self.container_manager.start_container(container_image_url, cmd, ports=ports, volumes=volumes)

    @property
    def container_external_port(self):
        """Host port mapped to the container's GRPC port, or None if not running."""
        container = self.container_manager.container
        if container is not None and container.status == 'running':
            return container.ports[self.container_internal_port][0]['HostPort']
        return None

    @classmethod
    def from_json(cls, json_obj):
        """Deserialize.

        Args:
            json_obj (dict): Serialized constructor keyword arguments.

        Returns:
            TFServingContainerCallable

        Raises:
            ValueError: If the json object cannot be converted.
        """
        try:
            kwargs = copy.copy(json_obj)
            kwargs['model_name'] = json_obj['model_name']
            kwargs['serving_dir'] = json_obj['serving_dir']
            kwargs['conf_threshold'] = float(json_obj['conf_threshold'])
        except ValueError as e:
            # chain the original exception for easier debugging
            raise ValueError(
                'Failed to convert json object to {} instance. '
                'The input json object is {}. ({})'.format(cls.__name__,
                                                           json_obj, e)) from e
        return cls(**kwargs)

    def __call__(self, image):
        """Run inference on an OpenCV-style BGR image.

        Args:
            image (numpy array): BGR image.

        Returns:
            Detection results from TFServingPredictor.infer_one.
        """
        if not self.predictor:
            # connect lazily; the container may not exist until prepare() runs
            self.predictor = tfutils.TFServingPredictor('localhost', self.container_external_port)
        # TF models expect RGB input while OpenCV images are BGR
        rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        return self.predictor.infer_one(self.model_name, rgb_image, conf_threshold=self.conf_threshold)

    def clean(self):
        """Stop and remove the TF serving container."""
        self.container_manager.clean()
if __name__ == "__main__":
    # Manual smoke test: requires a running Docker daemon and a local 'test.png'.
    container_image_url = 'registry.cmusatyalab.org/junjuew/container-registry:tpod-image-sandwich-sandwich'
    processor = FasterRCNNContainerCallable(container_image_url=container_image_url,
                                            conf_threshold=0.1)
    try:
        image = cv2.imread('test.png')
        # cv2.imread returns None (no exception) on a missing/unreadable file
        if image is None:
            raise IOError('Failed to read test.png')
        print(processor(image))
    finally:
        # remove the launched container even if detection fails
        processor.clean()