#!/usr/bin/python3
import time
from http.server import HTTPServer, BaseHTTPRequestHandler
#from picamera import PiCamera
import operator
import os
import datetime

import sqlite3 as lite

# Detection threshold shared by the frame-processing loops: a train is
# flagged when any rounded per-channel mean of the crop differs from the
# learned background mean by more than this value (pixel-value units).
trainthreshold=15


# Module-wide SQLite connection used by the db_* helpers below.
# NOTE(review): opened at import time with a path relative to the current
# working directory; sqlite3 creates trains.db there if it does not exist.
con = lite.connect('trains.db')

def db_purge():
    """Remove every row from the ``trains`` table (committed on success)."""
    with con as connection:
        connection.cursor().execute("DELETE FROM trains")


def db_insert(ts,t):
    """Insert one train record into the ``trains`` table.

    Args:
        ts: Start timestamp of the event; stored as ``str(ts)`` in the
            ``ts`` column.
        t: Length of the event; stored as ``str(t)`` in the ``length``
            column (SQLite's column affinity decides the final storage type).

    The values are bound as query parameters rather than spliced into the
    SQL text, so quotes or other SQL metacharacters in them can no longer
    break the statement or inject extra SQL.  ``with con`` commits on
    success and rolls back if the insert fails.
    """
    sql = "INSERT INTO trains(ts,length) VALUES(?, ?)"
    with con:
        con.execute(sql, (str(ts), str(t)))



def db_dump():
    """Print each row of the ``trains`` table as a tuple, one per line."""
    with con as connection:
        records = connection.execute("SELECT * FROM trains").fetchall()
        for record in records:
            print(record)

def db_create():
    """(Re)create an empty ``trains`` table, dropping any existing one first."""
    statements = (
        "DROP TABLE IF EXISTS trains",
        "CREATE TABLE trains(id INTEGER PRIMARY KEY,ts varchar(20) NOT NULL, length INT DEFAULT 0)",
    )
    with con as connection:
        cursor = connection.cursor()
        for statement in statements:
            cursor.execute(statement)


def millis():
    """Return the current wall-clock time in whole milliseconds since the epoch."""
    now_ms = time.time() * 1000
    return round(now_ms)

def getCPUtemperature():
    """Return the temperature text reported by ``vcgencmd measure_temp``.

    The first output line has its ``temp=`` prefix and its ``'C`` + newline
    suffix removed, e.g. a line ``temp=48.3'C`` yields ``"48.3"``.  If the
    command prints nothing, an empty string is returned.

    The pipe is closed through ``with`` so the child process is reaped right
    away instead of lingering until garbage collection; the capture loop
    calls this repeatedly.
    """
    with os.popen('vcgencmd measure_temp') as pipe:
        res = pipe.readline()
    return res.replace("temp=", "").replace("'C\n", "")

class MjpegMixin:
    """
    Mixin giving a BaseHTTPRequestHandler subclass an MJPEG
    (multipart/x-mixed-replace) response.

    Call mjpegBegin() once, then for every frame write the JPEG bytes to
    self.wfile and call mjpegEndFrame().
    """

    # Multipart boundary token, advertised in the Content-Type header.
    mjpegBound = 'eb4154aac1c9ee636b8a6f5622176d1fbc08d382ee161bbd42e8483808c684b6'
    # Part header written in front of each JPEG payload.
    frameBegin = b'Content-Type: image/jpeg\n\n'
    # Boundary closing one frame, immediately followed by the next part header.
    frameBound = b''.join((b'\n--', mjpegBound.encode('ascii'), b'\n', frameBegin))

    def mjpegBegin(self):
        """Send status 200 plus the multipart header, then open the first part."""
        content_type = 'multipart/x-mixed-replace;boundary=' + MjpegMixin.mjpegBound
        self.send_response(200)
        self.send_header('Content-Type', content_type)
        self.end_headers()
        self.wfile.write(MjpegMixin.frameBegin)

    def mjpegEndFrame(self):
        """Close the current frame and open the next part."""
        self.wfile.write(MjpegMixin.frameBound)


class SmoothedFpsCalculator:
    """
    Frames-per-second estimate smoothed with an exponential moving average.

    Call the instance once per frame; each call returns the smoothed FPS.

    Args:
        alpha: Weight of the newest sample in the average (0..1). Larger
            values react faster; smaller values smooth more.
    """

    def __init__(self, alpha=0.1):
        # Monotonic clock: unaffected by wall-clock jumps (e.g. NTP setting
        # the time after boot), which with time.time() produced negative or
        # near-zero intervals and corrupted the running average.
        self.t = time.monotonic()
        self.alpha = alpha
        self.sfps = None  # smoothed FPS; None until the first sample

    def __call__(self):
        """Record a frame and return the smoothed FPS.

        If no time has elapsed since the previous call (clock granularity),
        no sample is taken and the previous estimate is returned (0.0 if
        there is none yet) instead of dividing by zero.
        """
        t = time.monotonic()
        d = t - self.t
        self.t = t
        if d <= 0:
            return self.sfps if self.sfps is not None else 0.0
        fps = 1.0 / d
        if self.sfps is None:
            self.sfps = fps
        else:
            self.sfps = fps * self.alpha + self.sfps * (1.0 - self.alpha)
        return self.sfps


def cropND(img, startx, endx, starty, endy):
    """Return the rectangular region ``img[starty:endy, startx:endx]``.

    The first axis (rows) is limited by ``starty``/``endy`` and the second
    axis (columns) by ``startx``/``endx``; any further axes (e.g. colour
    channels) are kept whole.  For NumPy arrays this is basic slicing, so
    the result is a view of ``img``.
    """
    row_window = slice(starty, endy)
    col_window = slice(startx, endx)
    return img[row_window, col_window]


class Handler(BaseHTTPRequestHandler, MjpegMixin):
    """HTTP request handler that streams a train-detection view as MJPEG.

    ``GET /image.mjpeg?STARTX,ENDX,STARTY,ENDY`` streams frames read from
    ``track1_cut.mp4``. The frames are cropped to the given pixel bounds and
    annotated with the current and background mean colours. Any other path
    gets a 404; malformed crop parameters get a 400.
    """

    def do_GET(self):
        """Dispatch a GET request to the MJPEG stream, or answer 404/400."""
        if not self.path.startswith('/image.mjpeg'):
            self.send_response(404)
            self.end_headers()
            return

        segments = self.path.split('?')
        paramstr = segments[1] if len(segments) > 1 else ''
        print(paramstr)
        bounds = self._parse_crop_params(paramstr)
        if bounds is None:
            # Reject bad input before mjpegBegin() commits a 200 response.
            # Previously a missing or short query raised IndexError here,
            # and non-numeric values raised ValueError mid-stream.
            self.send_error(400, 'expected /image.mjpeg?startx,endx,starty,endy')
            return
        self.handleContourMjpeg(*bounds)

    @staticmethod
    def _parse_crop_params(paramstr):
        """Parse "startx,endx,starty,endy" into four ints, or return None.

        Extra comma-separated fields are ignored. None is returned when:
        - fewer than four fields are present;
        - a field is not an integer;
        - a start value is negative;
        - a start is not strictly below its end. An empty crop makes the
          channel means NaN, and ``round`` then raises.

        Bounds beyond the frame size are not checked here.
        """
        fields = paramstr.split(',')
        if len(fields) < 4:
            return None
        try:
            startx, endx, starty, endy = (int(f) for f in fields[:4])
        except ValueError:
            return None
        if not (0 <= startx < endx and 0 <= starty < endy):
            return None
        return startx, endx, starty, endy

    def handleContourMjpeg(self,startx,endx,starty,endy):
        """Stream an annotated MJPEG view of a cropped region of the test video.

        Args:
            startx, endx, starty, endy: Crop bounds in pixels (ints or
                strings of ints): columns ``startx:endx``, rows
                ``starty:endy`` of the 640x480 frame.

        Per frame, the crop's per-channel mean is compared with a learned
        background mean. Once a background exists, a difference above
        ``trainthreshold`` marks a train, and begin/end transitions are
        printed. The background is re-learned after 10 roughly one-second
        ticks with no jump of more than 5 in any channel mean. Two bars are
        drawn on the crop before it is sent: the current mean on top and
        the background below.

        Streaming stops at the end of the video or when the client
        disconnects; the capture is always released.
        """
        import cv2

        startx, endx, starty, endy = int(startx), int(endx), int(starty), int(endy)
        width, height = 640, 480
        avc = [0, 0, 0]    # crop mean per channel, index 0=b, 1=g, 2=r
        avbgc = [0, 0, 0]  # learned background mean, same channel order
        havebg = 0         # becomes 1 once a background has been learned
        bgtimeout = 0      # ticks since the crop mean last jumped by >5
        trainevent = 0     # 1 while the crop differs from the background
        barheight = (endy - starty) // 2

        self.mjpegBegin()
        starttime = millis()
        cap = cv2.VideoCapture('track1_cut.mp4')
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    # End of file or read failure: frame is None here, and
                    # the old code crashed on frame.reshape().
                    break
                image = cropND(frame.reshape((height, width, 3)),
                               startx, endx, starty, endy)

                lastr, lastg, lastb = round(avc[2]), round(avc[1]), round(avc[0])
                avc = [image[:, :, i].mean() for i in range(image.shape[-1])]
                r, g, b = round(avc[2]), round(avc[1]), round(avc[0])

                # A noticeable change restarts the background-learning countdown.
                if abs(r - lastr) > 5 or abs(g - lastg) > 5 or abs(b - lastb) > 5:
                    bgtimeout = 0

                lasttrainevent = trainevent
                if (abs(r - avbgc[2]) > trainthreshold
                        or abs(g - avbgc[1]) > trainthreshold
                        or abs(b - avbgc[0]) > trainthreshold):
                    if havebg != 0:
                        trainevent = 1
                else:
                    trainevent = 0

                if trainevent != lasttrainevent:
                    # NOTE(review): fixed +7200 s (2 h) added to local time,
                    # presumably a timezone workaround; confirm.
                    stamp = datetime.datetime.fromtimestamp(
                        time.time() + 7200).strftime('%Y-%m-%d %H:%M:%S')
                    print("train begin" if trainevent == 1 else "train end", stamp)

                if trainevent == 1:
                    print("%03i  %03i  %03i" % (r, g, b))

                # Roughly once per second: after 10 quiet ticks, adopt the
                # current mean as the background.
                if millis() - starttime > 1000:
                    bgtimeout += 1
                    if bgtimeout == 10:
                        print("have bgval")
                        avbgc[0], avbgc[1], avbgc[2] = b, g, r
                        havebg = 1
                        bgtimeout = 0
                    starttime = millis()

                cv2.rectangle(image, (0, 0), (80, barheight),
                              (avc[0], avc[1], avc[2]), -1)
                cv2.rectangle(image, (0, barheight), (80, barheight * 2),
                              (avbgc[0], avbgc[1], avbgc[2]), -1)

                # NOTE(review): debug preview window. It needs a display, and
                # without cv2.waitKey it is not refreshed; consider removing.
                cv2.imshow("frame", image)
                try:
                    self.wfile.write(cv2.imencode('.jpg', image)[1])
                    self.mjpegEndFrame()
                except ConnectionError:
                    # Client went away (broken pipe / reset): stop streaming
                    # instead of dying with a traceback.
                    break
        finally:
            cap.release()



def handleNostream(startx,endx,starty,endy):
    """Detect trains from the Pi camera and record them, without HTTP streaming.

    Args:
        startx, endx, starty, endy: Crop bounds in pixels (ints or strings
            of ints): columns ``startx:endx``, rows ``starty:endy`` of the
            640x480 capture.

    Loops over camera frames until interrupted. The crop's per-channel mean
    is compared with a learned background mean; once a background exists,
    a difference above ``trainthreshold`` marks a train. Begin/end
    transitions are printed, and each finished train is stored with
    ``db_insert(start, duration)``.

    The background is re-learned after 10 roughly one-second ticks with no
    jump of more than 5 in any channel mean. While no train is present, the
    CPU temperature is printed on each tick.
    """
    import cv2
    import numpy as np
    # The module-level picamera import is commented out, so PiCamera used
    # to be an undefined name here (NameError on every call). Import it
    # locally, like cv2/numpy, so the module still loads without picamera.
    from picamera import PiCamera

    startx, endx, starty, endy = int(startx), int(endx), int(starty), int(endy)
    avc = [0, 0, 0]    # crop mean per channel, index 0=b, 1=g, 2=r
    avbgc = [0, 0, 0]  # learned background mean, same channel order
    havebg = 0         # becomes 1 once a background has been learned
    bgtimeout = 0      # ticks since the crop mean last jumped by >5
    trainevent = 0     # 1 while the crop differs from the background
    ts = None          # start time of the current train, set on "begin"
    width, height, blur = 640, 480, 2
    fpsFont, fpsXY = cv2.FONT_HERSHEY_SIMPLEX, (0, height - 1)
    barheight = (endy - starty) // 2

    with PiCamera() as camera:
        camera.resolution = (width, height)
        camera.video_denoise = False
        camera.image_effect = 'blur'
        camera.image_effect_params = (blur,)
        # capture_continuous() below writes each BGR frame into this buffer.
        rgb = np.empty((height, width, 3), dtype=np.uint8)
        sfps = SmoothedFpsCalculator()

        starttime = millis()

        for _ in camera.capture_continuous(rgb, format='bgr', use_video_port=True):
            image = cropND(rgb.reshape((height, width, 3)), startx, endx, starty, endy)

            lastr, lastg, lastb = round(avc[2]), round(avc[1]), round(avc[0])
            avc = [image[:, :, i].mean() for i in range(image.shape[-1])]
            r, g, b = round(avc[2]), round(avc[1]), round(avc[0])

            # A noticeable change restarts the background-learning countdown.
            if abs(r - lastr) > 5 or abs(g - lastg) > 5 or abs(b - lastb) > 5:
                bgtimeout = 0

            lasttrainevent = trainevent
            if (abs(r - avbgc[2]) > trainthreshold
                    or abs(g - avbgc[1]) > trainthreshold
                    or abs(b - avbgc[0]) > trainthreshold):
                if havebg != 0:
                    trainevent = 1
            else:
                trainevent = 0

            if trainevent != lasttrainevent:
                # NOTE(review): fixed +7200 s (2 h) added to local time,
                # presumably a timezone workaround; confirm.
                if trainevent == 1:
                    ts = time.time() + 7200
                    st = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d %H:%M:%S')
                    print("train begin", st)
                else:
                    et = time.time() + 7200
                    ets = datetime.datetime.fromtimestamp(et).strftime('%Y-%m-%d %H:%M:%S')
                    print("train end", ets)
                    # trainevent only falls to 0 after rising to 1, so ts is set.
                    db_insert(ts, et - ts)

            if trainevent == 1:
                print("%03i  %03i  %03i" % (r, g, b))

            # Roughly once per second: report temperature when idle, and after
            # 10 quiet ticks adopt the current mean as the background.
            if millis() - starttime > 1000:
                if trainevent == 0:
                    print(getCPUtemperature())
                bgtimeout += 1
                if bgtimeout == 10:
                    print("have bgval")
                    avbgc[0], avbgc[1], avbgc[2] = b, g, r
                    havebg = 1
                    bgtimeout = 0
                starttime = millis()

            cv2.putText(image, '%0.2f fps' % sfps(), fpsXY, fpsFont, 1.0, 255)
            cv2.rectangle(image, (0, 0), (100, barheight),
                          (avc[0], avc[1], avc[2]), -1)
            cv2.rectangle(image, (0, barheight), (100, barheight * 2),
                          (avbgc[0], avbgc[1], avbgc[2]), -1)







def run(port=8001):
    """Serve ``Handler`` on all interfaces at *port*, blocking indefinitely."""
    server_address = ('', port)
    HTTPServer(server_address, Handler).serve_forever()


if __name__ == '__main__':
    import argparse

    # Crop bounds (startx, endx, starty, endy) used by the headless mode.
    NOSTREAM_CROP = (24, 183, 143, 445)

    parser = argparse.ArgumentParser(description='HTTP streaming camera.')
    parser.add_argument('--port', type=int, default=8001,
                        help='listening port number')
    parser.add_argument('--nostream', type=int, default=0,
                        help='1 for no streaming')
    parser.add_argument('--purge', type=int, default=0, help='1 for purge')
    options = parser.parse_args()

    # Optionally wipe stored trains, then either serve the MJPEG stream or
    # run detection headless on the camera.
    if options.purge == 1:
        db_purge()
    if options.nostream != 1:
        run(port=options.port)
    else:
        handleNostream(*NOSTREAM_CROP)