import cv2
import os
import base64
from PIL import  Image
import PIL
import io
import json
import numpy as np
from  multiprocessing import  Pool
from generateLabel import *
import shutil

class imgToSplit():
    """Split one labelme-annotated image into four quadrant images plus JSON.

    The steps are separate methods and must be called in this order::

        task = imgToSplit('DJI_0301.jpg')
        task.splitImage()      # write DJI_0301_1.jpg .. DJI_0301_4.jpg
        task.gen_4sub_json()   # build one empty labelme dict per quadrant
        task.justDrawLabel()   # read DJI_0301.json and distribute its shapes
        task.saveSubJson()     # write DJI_0301_1.json .. DJI_0301_4.json

    All output goes to the ``target`` sub-directory of the current working
    directory, addressed with Windows-style paths (see ``_subPath``).
    Quadrants are numbered 1..4 in file names and 0..3 internally:
    top-left, top-right, bottom-left, bottom-right.
    """

    def __init__(self, imgFile):
        """Load *imgFile* with OpenCV and precompute the split geometry.

        Raises:
            IOError: if OpenCV cannot read the image (cv2.imread returns None
                instead of raising).
        """
        # splitext drops only the final extension; str.strip('.jpg') removes
        # any of the characters '.', 'j', 'p', 'g' from both ends
        # (e.g. 'jpg_1.jpg' -> '_1').
        self.objName = os.path.splitext(imgFile)[0]
        self.img = cv2.imread(imgFile)
        if self.img is None:
            raise IOError('Cannot read image file: %s' % imgFile)
        self.h, self.w = self.img.shape[:2]
        self.half_h, self.half_w = self.h // 2, self.w // 2
        # One labelme-style dict per quadrant (index 0..3), see gen_4sub_json().
        self.sub_json_list = []
        # (x, y) offset of each quadrant's top-left corner in the full image.
        self.disp = np.array([(0, 0), (self.half_w, 0),
                              (0, self.half_h), (self.half_w, self.half_h)])
        # Sorted tuple of the label names in the source JSON; set by
        # justDrawLabel().
        self.category = None

    def _subPath(self, i, ext):
        """Return the output path of quadrant *i* (1..4) with extension *ext*."""
        return r'.\target\%s_%i.%s' % (self.objName, i, ext)

    def _quadrantSlices(self):
        """Return (row_slice, col_slice) for quadrants 0..3 (TL, TR, BL, BR)."""
        top, bottom = slice(None, self.half_h), slice(self.half_h, None)
        left, right = slice(None, self.half_w), slice(self.half_w, None)
        return [(top, left), (top, right), (bottom, left), (bottom, right)]

    def _quadrantIndex(self, x, y):
        """Return the 0-based quadrant (0=TL, 1=TR, 2=BL, 3=BR) holding (x, y).

        Points exactly on a split line belong to the right/bottom quadrant,
        which matches the slices used by splitImage().
        """
        return int(x >= self.half_w) + 2 * int(y >= self.half_h)

    def splitImage(self):
        """1st step: write the four quadrants of the image as JPEG files.

        Raises:
            IOError: if a quadrant cannot be written (e.g. the target
                directory is missing); cv2.imwrite only returns False.
        """
        for i, (rows, cols) in enumerate(self._quadrantSlices(), start=1):
            path = self._subPath(i, 'jpg')
            if not cv2.imwrite(path, self.img[rows, cols]):
                raise IOError('Failed to write %s' % path)

    def gen_4sub_json(self):
        """2nd step: build one empty labelme JSON dict per quadrant image.

        Reads the quadrant images written by splitImage() and refills
        self.sub_json_list with four dicts in quadrant order. Each dict holds
        the image as base64 PNG data and an empty 'shapes' list.

        Raises:
            IOError: if a quadrant image cannot be read (run splitImage first).
        """

        def apply_exif_orientation(image):
            """Return *image* rotated/mirrored according to its EXIF Orientation tag."""
            # ExifTags / ImageOps are PIL submodules that `import PIL` alone
            # does not load, so import them explicitly.
            from PIL import ExifTags, ImageOps

            try:
                exif = image._getexif()
            except AttributeError:
                # Formats without EXIF support (e.g. PNG) have no _getexif().
                exif = None

            if exif is None:
                return image

            exif = {
                ExifTags.TAGS[k]: v
                for k, v in exif.items()
                if k in ExifTags.TAGS
            }

            transforms = {
                2: ImageOps.mirror,  # left-to-right mirror
                3: lambda im: im.transpose(PIL.Image.ROTATE_180),
                4: ImageOps.flip,  # top-to-bottom mirror
                5: lambda im: ImageOps.mirror(im.transpose(PIL.Image.ROTATE_270)),
                6: lambda im: im.transpose(PIL.Image.ROTATE_270),
                7: lambda im: ImageOps.mirror(im.transpose(PIL.Image.ROTATE_90)),
                8: lambda im: im.transpose(PIL.Image.ROTATE_90),
            }
            # Orientation 1 (normal), a missing tag or unknown values: unchanged.
            transform = transforms.get(exif.get('Orientation', None))
            return image if transform is None else transform(image)

        def generateImgData(path):
            """Return the image at *path* as base64-encoded PNG text, or None."""
            try:
                image_pil = PIL.Image.open(path)
            except IOError:
                print('Failed opening image file: %s' % path)
                return None

            # `with` closes the file handle PIL keeps open after open().
            with image_pil:
                oriented = apply_exif_orientation(image_pil)
                with io.BytesIO() as f:
                    oriented.save(f, format='PNG')
                    # getvalue() returns the whole buffer; f.read() right
                    # after save() starts at the end and returns b''.
                    raw = f.getvalue()
            return base64.b64encode(raw).decode('utf8')

        self.sub_json_list = []
        for i in range(1, 5):
            imgName = '%s_%i.jpg' % (self.objName, i)
            path = self._subPath(i, 'jpg')
            img = cv2.imread(path)
            if img is None:
                raise IOError('Cannot read sub image %s' % path)
            h, w = img.shape[:2]

            self.sub_json_list.append({
                'version': '3.10.0',
                'flags': {},
                'shapes': [],
                'lineColor': [0, 255, 0, 128],
                'fillColor': [255, 0, 0, 128],
                'imagePath': imgName,
                'imageData': generateImgData(path),
                'imageHeight': h,
                'imageWidth': w,
            })

    def justDrawLabel(self):
        """3rd step: rasterize the source annotations per label and map them
        onto the quadrant JSON dicts.

        Reads '<objName>.json' (labelme format) and records the distinct label
        names in self.category. For each label, it draws all shapes of that
        label into one mask and passes the mask to _mapPiecesToJson().

        Raises:
            RuntimeError: if gen_4sub_json() has not been run first.
        """
        if len(self.sub_json_list) != 4:
            raise RuntimeError('gen_4sub_json() must be called before justDrawLabel()')

        with open('%s.json' % self.objName, 'r') as f:
            data = json.load(f)

        # sorted() makes the order deterministic (set order is not).
        self.category = tuple(sorted({shape['label'] for shape in data['shapes']}))

        # Only the image geometry is needed to rasterize the shapes. Using the
        # loaded image avoids decoding imageData, which labelme may leave null.
        img_shape = self.img.shape

        for cateName, shapes in pointsDivByLabel(data['shapes']).items():
            label_name_to_value = {'_background_': 0}
            for shape in sorted(shapes, key=lambda x: x['label']):
                # New labels get the next free integer id.
                label_name_to_value.setdefault(shape['label'], len(label_name_to_value))

            lbl = shapes_to_label(img_shape, shapes, label_name_to_value)
            self._mapPiecesToJson(cateName, lbl)

    def _mapPiecesToJson(self, cate, img):
        """4th step: cut a label mask at the split lines and add each piece as a
        polygon shape to the JSON dict of the quadrant that contains it.

        Args:
            cate: label name given to every generated shape.
            img: 2-D label mask of the full image (non-zero = labelled), e.g.
                the output of shapes_to_label().
        """
        # Copy as uint8: cv2.findContours needs an 8-bit image, and the
        # caller's array must not be modified by the blanking below.
        mask = np.array(img, dtype=np.uint8)

        # Blank a 5-pixel band along both split lines, so that every contour
        # lies entirely inside a single quadrant.
        mask[:, max(self.half_w - 2, 0):self.half_w + 3] = 0
        mask[max(self.half_h - 2, 0):self.half_h + 3, :] = 0

        # [-2] picks the contour list under both OpenCV 3 (3 return values)
        # and OpenCV 4 (2 return values). RETR_TREE also returns inner
        # (hole) contours, and these become separate polygons.
        contours = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

        for contour in contours:
            poly = cv2.approxPolyDP(contour, 1, True).reshape(-1, 2)
            if len(poly) < 3:
                # Fewer than 3 vertices is not a valid polygon.
                continue

            # Any vertex identifies the quadrant, because of the blanked band.
            order = self._quadrantIndex(poly[0, 0], poly[0, 1])
            local = poly - self.disp[order]

            self.sub_json_list[order]['shapes'].append({
                'label': cate,
                'line_color': None,
                'fill_color': None,
                # tolist(): numpy integers are not JSON-serialisable.
                'points': local.tolist(),
                'shape_type': 'polygon',
            })

    def saveSubJson(self):
        """5th step: write each quadrant dict to '<objName>_<i>.json' in target."""
        for i, content in enumerate(self.sub_json_list, start=1):
            with open(self._subPath(i, 'json'), 'w') as f:
                json.dump(content, f)

def mission(file):
    """Run the whole split pipeline for one image file in the current directory.

    Presumably this is the per-file worker for the multiprocessing Pool
    created under ``__main__``. It is kept at module level so it can be
    pickled.

    Args:
        file: image file name. An upper-case '.JPG' extension is normalised
            to '.jpg' first.
    """
    root, ext = os.path.splitext(file)
    if ext == '.JPG':
        # NOTE(review): this relies on a case-insensitive file system
        # (Windows) for cv2.imread to find 'X.JPG' under the name 'X.jpg'.
        file = root + '.jpg'

    task = imgToSplit(file)
    task.splitImage()
    task.gen_4sub_json()
    task.justDrawLabel()
    task.saveSubJson()

if __name__ == '__main__':
    # Output directory for the quadrant images / JSON (Windows-style path,
    # the same one used for the sub-image paths).
    if not os.path.exists(r'.\target'):
        os.makedirs(r'.\target')

    # Match on the extension only; a substring test would also pick up any
    # name that merely contains 'jpg'.
    imgFiles = [f for f in os.listdir('.') if f.lower().endswith('.jpg')]

    # map() blocks until every file is processed; the context manager then
    # shuts the worker processes down.
    with Pool(processes=4) as pool:
        pool.map(mission, imgFiles)

import json
import io
import base64
import numpy as np
from PIL import  Image
from PIL import ImageDraw
import PIL
import math
import  os.path as osp

def label_colormap(N=256):
    """Build the PASCAL VOC-style label colour map.

    Colour ``i`` spreads the bits of ``i`` over the three channels:
    - bits 0, 3, 6, ... go to red;
    - bits 1, 4, 7, ... go to green;
    - bits 2, 5, 8, ... go to blue.
    Each channel byte is filled from its most significant bit downwards.

    Args:
        N: number of colours to generate.

    Returns:
        float32 array of shape (N, 3) with RGB values scaled to [0, 1].
    """
    colors = np.zeros((N, 3))
    for index in range(N):
        code = index
        red = green = blue = 0
        for shift in range(7, -1, -1):
            red |= (code & 1) << shift
            green |= ((code >> 1) & 1) << shift
            blue |= ((code >> 2) & 1) << shift
            code >>= 3
        colors[index] = (red, green, blue)
    return colors.astype(np.float32) / 255

def lblsave(filename, lbl):
    """Save an integer label map as an indexed-colour (palette) PNG.

    Args:
        filename: output path; '.png' is appended when missing.
        lbl: integer label array with values in [-1, 254]. A value of -1 is
            stored as 255 by the uint8 cast.

    Raises:
        ValueError: if *lbl* has values outside [-1, 254], which cannot be
            stored in an 8-bit palette PNG.
    """
    if osp.splitext(filename)[1] != '.png':
        filename += '.png'
    # Assume label ranges [-1, 254] for int32,
    # and [0, 255] for uint8 as VOC.
    if lbl.min() >= -1 and lbl.max() < 255:
        lbl_pil = PIL.Image.fromarray(lbl.astype(np.uint8), mode='P')
        colormap = label_colormap(255)
        lbl_pil.putpalette((colormap * 255).astype(np.uint8).flatten())
        lbl_pil.save(filename)
    else:
        raise ValueError(
            '[%s] Cannot save the pixel-wise class label as PNG, '
            'so please use the npy file.' % filename
        )

def img_b64_to_arr(img_b64):
    """Decode a base64-encoded image (e.g. labelme ``imageData``) to a numpy array.

    Args:
        img_b64: base64 text (str or bytes) of an encoded image file.

    Returns:
        np.ndarray holding the decoded pixels, as produced by PIL.
    """
    # The decoded bytes must actually be in the buffer before PIL reads it.
    # np.array() forces PIL's lazy load while the buffer is still open.
    with io.BytesIO(base64.b64decode(img_b64)) as f:
        img_arr = np.array(Image.open(f))
    return img_arr

def shape_to_mask(img_shape, points, shape_type=None,
                  line_width=10, point_size=5):
    """Rasterize one labelme shape into a boolean mask.

    Args:
        img_shape: shape of the target image; only the first two entries
            (height, width) are used.
        points: sequence of (x, y) vertices.
        shape_type: one of the following; anything else (including None) is
            drawn as a polygon.
            - 'circle': centre plus one point on the circumference;
            - 'rectangle': two opposite corners;
            - 'line';
            - 'linestrip';
            - 'point'.
        line_width: stroke width for 'line' and 'linestrip'.
        point_size: radius used for 'point'.

    Returns:
        np.ndarray of bool with shape img_shape[:2]; True inside the shape.
    """
    mask = Image.fromarray(np.zeros(img_shape[:2], dtype=np.uint8))
    draw = ImageDraw.Draw(mask)
    xy = [tuple(point) for point in points]
    if shape_type == 'circle':
        assert len(xy) == 2, 'Shape of shape_type=circle must have 2 points'
        (cx, cy), (px, py) = xy
        d = math.sqrt((cx - px) ** 2 + (cy - py) ** 2)
        draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
    elif shape_type == 'rectangle':
        assert len(xy) == 2, 'Shape of shape_type=rectangle must have 2 points'
        (x1, y1), (x2, y2) = xy
        # The corners may be given in any order; newer Pillow versions reject
        # rectangles whose second corner is left of or above the first.
        draw.rectangle([min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)],
                       outline=1, fill=1)
    elif shape_type == 'line':
        assert len(xy) == 2, 'Shape of shape_type=line must have 2 points'
        draw.line(xy=xy, fill=1, width=line_width)
    elif shape_type == 'linestrip':
        draw.line(xy=xy, fill=1, width=line_width)
    elif shape_type == 'point':
        assert len(xy) == 1, 'Shape of shape_type=point must have 1 points'
        cx, cy = xy[0]
        r = point_size
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
    else:
        # Default: polygon.
        assert len(xy) > 2, 'Polygon must have points more than 2'
        draw.polygon(xy=xy, outline=1, fill=1)
    return np.array(mask, dtype=bool)

def shapes_to_label(img_shape, shapes, label_name_to_value, type='class'):
    """Rasterize labelme shapes into a class-id map and, optionally, an instance-id map.

    Args:
        img_shape: shape of the target image; only (height, width) are used.
        shapes: list of labelme shape dicts with 'points', 'label' and an
            optional 'shape_type'.
        label_name_to_value: maps class name -> integer class id.
        type: 'class' to return only the class map, or 'instance' to also
            return an instance map. In 'instance' mode, the class name is the
            part of the label before the first '-'.

    Returns:
        int32 array of shape img_shape[:2] ('class'), or a tuple
        (class_map, instance_map) ('instance'). Later shapes overwrite earlier
        ones where they overlap.
    """
    assert type in ['class', 'instance']

    cls = np.zeros(img_shape[:2], dtype=np.int32)
    if type == 'instance':
        ins = np.zeros(img_shape[:2], dtype=np.int32)
        instance_names = ['_background_']
    for shape in shapes:
        label = shape['label']
        if type == 'class':
            cls_name = label
        else:
            cls_name = label.split('-')[0]
            if label not in instance_names:
                instance_names.append(label)
            # Reuse the same id when a label appears again later;
            # len(instance_names) - 1 would give the most recent label's id.
            ins_id = instance_names.index(label)
        cls_id = label_name_to_value[cls_name]
        mask = shape_to_mask(img_shape[:2], shape['points'],
                             shape.get('shape_type', None))
        cls[mask] = cls_id
        if type == 'instance':
            ins[mask] = ins_id

    if type == 'instance':
        return cls, ins
    return cls

def pointsDivByLabel(listOfDict):
    """Group labelme shapes by their 'label'.

    Args:
        listOfDict: list of shape dicts, e.g. data['shapes'] of a labelme JSON.

    Returns:
        dict mapping each label to the list of its shapes, in input order.
    """
    res = dict()
    for shape in listOfDict:
        res.setdefault(shape['label'], []).append(shape)
    return res

# def justDrawLabel(file):
#     obj = file.strip('.json')
#     with open(file,'r') as f:
#         data = json.load(f)
#     imgData = data['imageData']
#     img = img_b64_to_arr(imgData)
#     categories = pointsDivByLabel(data['shapes'])
#     for cateName,shapes in categories.items():
#         label_name_to_value = {'_background_': 0}
#         for shape in sorted(shapes, key=lambda x: x['label']):
#             label_name = shape['label']
#             if label_name in label_name_to_value:
#                 label_value = label_name_to_value[label_name]
#             else:
#                 label_value = len(label_name_to_value)
#                 label_name_to_value[label_name] = label_value
#         lbl = shapes_to_label(img.shape, shapes, label_name_to_value)
#         lblsave('label_%s.png'%cateName, lbl)

if __name__ == '__main__':
    # Nothing to run when this module is executed directly. An `if` body
    # cannot consist of comments alone, so an explicit `pass` is required.
    pass

