diff --git a/functions/pipeline/download/__init__.py b/functions/pipeline/download/__init__.py index 96f3965a..d208b25f 100644 --- a/functions/pipeline/download/__init__.py +++ b/functions/pipeline/download/__init__.py @@ -35,12 +35,19 @@ def main(req: func.HttpRequest) -> func.HttpResponse: # DB configuration data_access = ImageTagDataAccess(get_postgres_provider()) user_id = data_access.create_user(user_name) - image_urls = list(data_access.get_new_images(image_count, user_id)) + image_id_to_urls = data_access.get_images_for_tagging(image_count, user_id) + image_urls = list(image_id_to_urls.values()) - # TODO: Populate starting json with tags, if any exist... (precomputed or retagging?) - vott_json = create_starting_vott_json(image_urls) + image_id_to_image_tags = {} + for image_id in image_id_to_urls.keys(): + image_id_to_image_tags[image_id] = data_access.get_image_tags(image_id) - return_body_json = {"imageUrls": image_urls, "vottJson": vott_json} + existing_classifications_list = data_access.get_existing_classifications() + + vott_json = create_starting_vott_json(image_id_to_urls, image_id_to_image_tags, existing_classifications_list) + + return_body_json = {"imageUrls": image_urls, + "vottJson": vott_json} content = json.dumps(return_body_json) return func.HttpResponse( diff --git a/functions/pipeline/shared/db_access/db_access_v2.py b/functions/pipeline/shared/db_access/db_access_v2.py index c174f57b..8d0efee8 100644 --- a/functions/pipeline/shared/db_access/db_access_v2.py +++ b/functions/pipeline/shared/db_access/db_access_v2.py @@ -1,15 +1,12 @@ -# import sys import string -# import os -# import time import logging import random from enum import IntEnum, unique import getpass import itertools - from ..db_provider import DatabaseInfo, PostGresProvider + @unique class ImageTagState(IntEnum): NOT_READY = 0 @@ -19,6 +16,7 @@ class ImageTagState(IntEnum): INCOMPLETE_TAG = 4 ABANDONED = 5 + # An entity class for a VOTT image class ImageInfo(object): def __init__(self, image_name, image_location, height, width): @@ -27,14 +25,25 @@ def __init__(self, image_name, image_location, height, width): self.height = height self.width = width + +# Entity class for Tags stored in DB class ImageTag(object): def __init__(self, image_id, x_min, x_max, y_min, y_max, classification_names): - self.image_id = image_id - self.x_min = x_min - self.x_max = x_max - self.y_min = y_min - self.y_max = y_max - self.classification_names = classification_names + self.image_id = image_id + self.x_min = x_min + self.x_max = x_max + self.y_min = y_min + self.y_max = y_max + self.classification_names = classification_names + + +# Vott tags have image height & width data as well. +class VottImageTag(ImageTag): + def __init__(self, image_id, x_min, x_max, y_min, y_max, classification_names, image_height, image_width): + super().__init__(image_id, x_min, x_max, y_min, y_max, classification_names) + self.image_height = image_height + self.image_width = image_width + class ImageTagDataAccess(object): def __init__(self, db_provider): @@ -69,7 +78,7 @@ def create_user(self,user_name): finally: conn.close() return user_id - def get_new_images(self, number_of_images, user_id): + def get_images_for_tagging(self, number_of_images, user_id): if number_of_images <= 0: raise ArgumentException("Parameter must be greater than zero") @@ -84,14 +93,16 @@ def get_new_images(self, number_of_images, user_id): cursor.execute(query.format(number_of_images, ImageTagState.READY_TO_TAG, ImageTagState.INCOMPLETE_TAG)) for row in cursor: logging.debug('Image Id: {0} \t\tImage Name: {1} \t\tTag State: {2}'.format(row[0], row[1], row[2])) - selected_images_to_tag[str(row[0])] = str(row[1]) + selected_images_to_tag[row[0]] = str(row[1]) self._update_images(selected_images_to_tag,ImageTagState.TAG_IN_PROGRESS, user_id, conn) - finally: cursor.close() + finally: + cursor.close() except Exception as e: logging.error("An errors occured getting images: {0}".format(e)) raise - finally: conn.close() - return selected_images_to_tag.values() + finally: + conn.close() + return selected_images_to_tag def add_new_images(self,list_of_image_infos, user_id): @@ -119,6 +130,106 @@ def add_new_images(self,list_of_image_infos, user_id): finally: conn.close() return url_to_image_id_map + def get_tag_complete_images(self, number_of_images, user_id): + if number_of_images <= 0: + raise ArgumentException("Parameter must be greater than zero") + + tag_complete_images = {} + try: + conn = self._db_provider.get_connection() + try: + cursor = conn.cursor() + query = ("SELECT b.ImageId, b.ImageLocation, a.TagStateId FROM Image_Tagging_State a " + "JOIN Image_Info b ON a.ImageId = b.ImageId WHERE a.TagStateId = {1} order by " + "a.createddtim DESC limit {0}") + cursor.execute(query.format(number_of_images, ImageTagState.COMPLETED_TAG)) + for row in cursor: + logging.debug('Image Id: {0} \t\tImage Name: {1} \t\tTag State: {2}'.format(row[0], row[1], row[2])) + tag_complete_images[row[0]] = str(row[1]) + finally: + cursor.close() + except Exception as e: + logging.error("An errors occured getting images: {0}".format(e)) + raise + finally: + conn.close() + return tag_complete_images + + def get_image_tags(self, image_id): + if type(image_id) is not int: + raise TypeError('image_id must be an integer') + + try: + conn = self._db_provider.get_connection() + try: + cursor = conn.cursor() + query = ("SELECT image_tags.imagetagid, image_info.imageid, x_min, x_max, y_min, y_max, " + "classification_info.classificationname, image_info.height, image_info.width " + "FROM image_tags " + "inner join tags_classification on image_tags.imagetagid = tags_classification.imagetagid " + "inner join classification_info on tags_classification.classificationid = classification_info.classificationid " + "inner join image_info on image_info.imageid = image_tags.imageid " + "WHERE image_tags.imageid = {0};") + cursor.execute(query.format(image_id,)) + + logging.debug("Got image tags back for image_id={}".format(image_id)) + tag_id_to_VottImageTag = self.__build_id_to_VottImageTag(cursor) + + finally: + cursor.close() + except Exception as e: + logging.error("An error occurred getting image tags for image id = {0}: {1}".format(image_id, e)) + raise + finally: + conn.close() + return list(tag_id_to_VottImageTag.values()) + + def __build_id_to_VottImageTag(self, tag_db_cursor): + tag_id_to_VottImageTag = {} + try : + for row in tag_db_cursor: + logging.debug(row) + tag_id = row[0] + if tag_id in tag_id_to_VottImageTag: + logging.debug("Existing ImageTag found, appending classification {}", row[6]) + tag_id_to_VottImageTag[tag_id].classification_names.append(row[6].strip()) + else: + logging.debug("No existing ImageTag found, creating new ImageTag: " + "id={0} x_min={1} x_max={2} x_min={3} x_max={4} classification={5} " + "image_height={6} image_width={7}" + .format(row[1], float(row[2]), float(row[3]), float(row[4]), float(row[5]), + [row[6].strip()], row[7], row[8])) + tag_id_to_VottImageTag[tag_id] = VottImageTag(row[1], float(row[2]), float(row[3]), + float(row[4]), float(row[5]), [row[6].strip()], + row[7], row[8]) + except Exception as e: + logging.error("An error occurred building VottImageTag dict: {0}".format(e)) + raise + return tag_id_to_VottImageTag + + + def get_existing_classifications(self): + try: + conn = self._db_provider.get_connection() + try: + cursor = conn.cursor() + query = "SELECT classificationname from classification_info order by classificationname asc" + cursor.execute(query) + + classification_set = set() + for row in cursor: + logging.debug(row) + classification_set.add(row[0]) + logging.debug("Got back {0} classifications existing in db.".format(len(classification_set))) + finally: + cursor.close() + except Exception as e: + logging.error("An error occurred getting classifications from DB: {0}".format(e)) + raise + finally: + conn.close() + return list(classification_set) + def update_incomplete_images(self, list_of_image_ids, user_id): #TODO: Make sure the image ids are in a TAG_IN_PROGRESS state self._update_images(list_of_image_ids,ImageTagState.INCOMPLETE_TAG,user_id, self._db_provider.get_connection()) @@ -229,6 +340,12 @@ def main(): # Checking in images been tagged ################################################################# + # import sys + # import os + # sys.path.append("..") + # sys.path.append(os.path.abspath('db_provider')) + # from db_provider import DatabaseInfo, PostGresProvider + #Replace me for testing db_config = DatabaseInfo("","","","") data_access = ImageTagDataAccess(PostGresProvider(db_config)) @@ -241,7 +358,9 @@ def main(): image_tags = generate_test_image_tags(list(url_to_image_id_map.values()),4,4) data_access.update_tagged_images(image_tags,user_id) -TestClassifications = ("maine coon","german shephard","goldfinch","mackerel"," african elephant","rattlesnake") + +TestClassifications = ("maine coon","german shephard","goldfinch","mackerel","african elephant","rattlesnake") + def generate_test_image_infos(count): list_of_image_infos = [] diff --git a/functions/pipeline/shared/db_access/test_db_access_v2.py b/functions/pipeline/shared/db_access/test_db_access_v2.py index 45eea1db..6de5af0c 100644 --- a/functions/pipeline/shared/db_access/test_db_access_v2.py +++ b/functions/pipeline/shared/db_access/test_db_access_v2.py @@ -74,7 +74,7 @@ def test_get_new_images_bad_request(self): with self.assertRaises(ArgumentException): data_access = ImageTagDataAccess(MockDBProvider()) num_of_images = -5 - data_access.get_new_images(num_of_images,5) + data_access.get_images_for_tagging(num_of_images, 5) def test_add_new_images_user_id_type_error(self): with self.assertRaises(TypeError): diff --git a/functions/pipeline/shared/vott_parser/vott_parser.py b/functions/pipeline/shared/vott_parser/vott_parser.py index 1bffff7f..a476c4d7 100644 --- a/functions/pipeline/shared/vott_parser/vott_parser.py +++ b/functions/pipeline/shared/vott_parser/vott_parser.py @@ -1,27 +1,60 @@ import json -def __build_frames_data(images): + +def __build_tag_from_VottImageTag(image_tag): + return { + "x1": image_tag.x_min, + "x2": image_tag.x_max, + "y1": image_tag.y_min, + "y2": image_tag.y_max, + "width": image_tag.image_width, + "height": image_tag.image_height, + "tags": image_tag.classification_names + } + + +def __build_tag_list_from_VottImageTags(image_tag_list): + tag_list = [] + for image_tag in image_tag_list: + tag_list.append(__build_tag_from_VottImageTag(image_tag)) + return tag_list + + +def __build_frames_data(image_id_to_urls, image_id_to_image_tags): frames = {} - for filename in images: - # TODO: Build tag data per frame if they exist already - frames[__get_filename_from_fullpath(filename)] = [] #list of tags + for image_id in image_id_to_image_tags.keys(): + image_file_name = __get_filename_from_fullpath(image_id_to_urls[image_id]) + image_tags = __build_tag_list_from_VottImageTags(image_id_to_image_tags[image_id]) + frames[image_file_name] = image_tags return frames + # For download function -def create_starting_vott_json(images): +def create_starting_vott_json(image_id_to_urls, image_id_to_image_tags, existing_classifications_list): + # "frames" + frame_to_tag_list_map = __build_frames_data(image_id_to_urls, image_id_to_image_tags) + + # "inputTags" + classification_str = "" + for classification in existing_classifications_list: + classification_str += classification + "," + return { - "frames": __build_frames_data(images), - "inputTags": "", # TODO: populate classifications that exist in db already + "frames": frame_to_tag_list_map, + "inputTags": classification_str, "scd": False # Required for VoTT and image processing? unknown if it's also used for video. } + def __get_filename_from_fullpath(filename): path_components = filename.split('/') return path_components[-1] + def __get_id_from_fullpath(fullpath): return int(__get_filename_from_fullpath(fullpath).split('.')[0]) + # Returns a list of processed tags for a single frame def __create_tag_data_list(json_tag_list): processed_tags = [] @@ -29,6 +62,7 @@ def __create_tag_data_list(json_tag_list): processed_tags.append(__process_json_tag(json_tag)) return processed_tags + def __process_json_tag(json_tag): return { "x1": json_tag['x1'], @@ -42,6 +76,7 @@ def __process_json_tag(json_tag): "name": json_tag["name"] } + # For upload function def process_vott_json(json): all_frame_data = json['frames'] @@ -79,6 +114,7 @@ def process_vott_json(json): "imageIdToTags": id_to_tags_dict } + def main(): images = { "1.png" : {}, @@ -121,7 +157,6 @@ def main(): # add_tag_to_db('something', 2, (tag_data)) - # Currently only used for testing... # returns a json representative of a tag given relevant components def __build_json_tag(x1, x2, y1, y2, img_width, img_height, UID, id, type, tags, name): @@ -145,5 +180,6 @@ def __build_json_tag(x1, x2, y1, y2, img_width, img_height, UID, id, type, tags, "name": name } + if __name__ == '__main__': main() diff --git a/functions/pipeline/taggedimages/__init__.py b/functions/pipeline/taggedimages/__init__.py new file mode 100644 index 00000000..b602345b --- /dev/null +++ b/functions/pipeline/taggedimages/__init__.py @@ -0,0 +1,62 @@ +import logging + +import azure.functions as func +import json + +from ..shared.vott_parser import create_starting_vott_json +from ..shared.db_provider import get_postgres_provider +from ..shared.db_access import ImageTagDataAccess + + +def main(req: func.HttpRequest) -> func.HttpResponse: + logging.info('Python HTTP trigger function processed a request.') + + image_count = int(req.params.get('imageCount')) + user_name = req.params.get('userName') + + # setup response object + headers = { + "content-type": "application/json" + } + if not user_name: + return func.HttpResponse( + status_code=401, + headers=headers, + body=json.dumps({"error": "invalid userName given or omitted"}) + ) + elif not image_count: + return func.HttpResponse( + status_code=400, + headers=headers, + body=json.dumps({"error": "image count not specified"}) + ) + else: + try: + # DB configuration + data_access = ImageTagDataAccess(get_postgres_provider()) + user_id = data_access.create_user(user_name) + image_id_to_urls = data_access.get_tag_complete_images(image_count, user_id) + image_urls = list(image_id_to_urls.values()) + + image_id_to_image_tags = {} + for image_id in image_id_to_urls.keys(): + image_id_to_image_tags[image_id] = data_access.get_image_tags(image_id) + + existing_classifications_list = data_access.get_existing_classifications() + + vott_json = create_starting_vott_json(image_id_to_urls, image_id_to_image_tags, existing_classifications_list) + + return_body_json = {"imageUrls": image_urls, + "vottJson": vott_json} + + content = json.dumps(return_body_json) + return func.HttpResponse( + status_code=200, + headers=headers, + body=content + ) + except Exception as e: + return func.HttpResponse( + "exception:" + str(e), + status_code=500 + ) diff --git a/functions/pipeline/taggedimages/function.json b/functions/pipeline/taggedimages/function.json new file mode 100644 index 00000000..f24e4f1d --- /dev/null +++ b/functions/pipeline/taggedimages/function.json @@ -0,0 +1,19 @@ +{ + "scriptFile": "__init__.py", + "bindings": [ + { + "authLevel": "anonymous", + "type": "httpTrigger", + "direction": "in", + "name": "req", + "methods": [ + "get" + ] + }, + { + "type": "http", + "direction": "out", + "name": "$return" + } + ] +} diff --git a/functions/pipeline/taggedimages/host.json b/functions/pipeline/taggedimages/host.json new file mode 100644 index 00000000..81e35b7b --- /dev/null +++ b/functions/pipeline/taggedimages/host.json @@ -0,0 +1,3 @@ +{ + "version": "2.0" +} \ No newline at end of file