import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import scipy.signal
import json


# ----- File functions -----
# (internally, images are represented as floats to reduce errors)
def LoadImage(filename):
	"""Load an image file as a single-channel float64 array.

	Floating point pixel data is multiplied by 255; integer pixel data is
	converted to float64 without rescaling. matplotlib's imread documents
	that PNG files come back as floats in 0..1 and other formats as integer
	arrays, so both kinds end up on the 0..255 scale for 8-bit images.

	filename: path (or file-like object) handed to matplotlib's imread.
	Returns the 2-D array itself for monochrome data, or the first (red)
	channel of a 3-D array.
	Raises ValueError if the decoded array is neither 2-D nor 3-D.
	"""
	raw = mpimg.imread(filename)
	img = raw.astype(np.float64)
	if np.issubdtype(raw.dtype, np.floating):
		# only float data is normalised to 0..1 and needs the scale factor
		img = img * 255
	if img.ndim == 2: # monochrome image
		return img
	if img.ndim == 3:
		return img[:, :, 0] # return red channel
	raise ValueError("unsupported image shape: " + str(img.shape))


# returns an image scaled to the nearest power of 2
def ScaleImage(inputImage):
	"""Resample a 2-D float64 image to a square with a power-of-two side.

	The target side length is the power of two nearest (on a log2 scale) to
	the longer side of the input. An image that already has that square
	shape is returned unchanged (same object).

	inputImage: 2-D numpy array of dtype float64.
	Returns a float64 array of shape (side, side). Pixel values keep their
	original scale and are limited to the value range of the input.
	"""
	# local import: the file itself only imports scipy.signal
	from scipy import ndimage

	assert inputImage.dtype == np.float64
	size = np.max(inputImage.shape)
	roundedSize = 2**int(np.log2(size)+0.5)

	if inputImage.shape[0] == roundedSize and inputImage.shape[1] == roundedSize:
		return inputImage

	# zoom derives the output shape as round(shape * factor), which yields
	# exactly roundedSize along both axes
	factors = (
		float(roundedSize) / inputImage.shape[0],
		float(roundedSize) / inputImage.shape[1],
	)
	scaled = ndimage.zoom(inputImage, factors, order=3, mode="nearest")
	# cubic interpolation can overshoot at sharp edges; keep the result
	# inside the value range of the source image
	scaled = np.clip(scaled, inputImage.min(), inputImage.max())
	return np.asarray(scaled, dtype=np.float64)


def SaveImage(filename, img):
	"""Write img to filename via pyplot's imsave.

	The pixel values are first limited to 0..255 and converted to uint8.
	"""
	limited = np.clip(img, 0, 255)
	asBytes = limited.astype(np.uint8)
	plt.imsave(filename, asBytes)


def ShowImage(img, title=""):
	"""Draw img into a pyplot figure named after title.

	img: numpy array of dtype float64; values are mapped with the fixed
	display range 0..255.
	title: converted with str() and used as the figure identifier.

	Only the figure is created here; this function does not call plt.show().
	"""
	# np.float64 is the type the removed alias np.float used to compare as
	assert img.dtype == np.float64
	plt.figure(str(title))
	plt.imshow(img, vmin=0, vmax=255)





# ----- Sampling functions -----

def Down(img, n=1):
	"""Halve the resolution of a 2-D float64 image n times.

	Each pass convolves with a 2x2 averaging kernel and keeps every second
	sample, so for even side lengths each output pixel is the mean of one
	2x2 block of the input.

	img: 2-D numpy array of dtype float64.
	n: number of halving passes; with n=0 the input is returned as is.
	"""
	assert img.dtype == np.float64

	kernel = np.full((2, 2), 0.25) # 2x2 box filter
	down = img
	for _ in range(n):
		filtered = scipy.signal.convolve2d(down, kernel, 'full')
		down = filtered[1::2, 1::2] # the 1 crops the border from the convolution
	return np.asarray(down)


def Up(img, n=1):
	"""Double the resolution of a float64 image n times by pixel repetition.

	Every pass repeats each element twice along axis 0 and twice along
	axis 1 (nearest-neighbour upsampling, values are not rescaled).

	img: numpy array of dtype float64 with at least two axes.
	n: number of doubling passes; with n=0 the input is returned as is.
	"""
	assert img.dtype == np.float64

	u = img
	for _ in range(n):
		u = np.repeat(np.repeat(u, 2, 0), 2, 1)
	return u
	

# Computes the Gaussian Pyramid from an image
def GaussianPyramid(img):
	"""Build a stack of progressively coarser versions of a square image.

	Level i is the image reduced i times with Down and enlarged i times
	with Up again, so every level has the size of the input. Level 0 is the
	image itself; the last level has been reduced to a single pixel before
	being enlarged.

	img: square 2-D float64 array whose side length s is a power of two.
	Returns an array of shape (log2(s) + 1, s, s).
	"""
	assert img.dtype == np.float64
	assert img.shape[0] == img.shape[1] # check for square
	s = img.shape[0]
	assert ((s & (s - 1)) == 0) and s > 0 # check power of two

	# for a power of two, bit_length() equals log2(s) + 1 without any
	# floating point rounding
	n = int(s).bit_length()
	return np.stack([Up(Down(img, i), i) for i in range(n)])


# Computes the Laplacian pyramid from a Gaussian pyramid
def LaplacianPyramid(gaussianPyramid):
	"""Turn a Gaussian pyramid into band-pass layers.

	Each layer except the last is the difference between a level and the
	next level of the input; the last layer is a copy of the last input
	level. The result has the same shape as the input.
	"""
	bands = np.empty(gaussianPyramid.shape)
	# all neighbouring level differences in one slice operation
	bands[:-1, :, :] = gaussianPyramid[:-1, :, :] - gaussianPyramid[1:, :, :]
	# the final level has no successor and is carried over unchanged
	bands[-1, :, :] = gaussianPyramid[-1, :, :]
	return bands


def Decompose(img):
	"""Return the Laplacian pyramid of img (built via its Gaussian pyramid)."""
	levels = GaussianPyramid(img)
	bands = LaplacianPyramid(levels)
	return bands






def Compare(originalFilename, solutionFilename):
	"""Measure how far a solution image is from a reference, per pyramid layer.

	Both files are loaded, scaled to a power-of-two square and decomposed
	into Laplacian pyramids. Layer 0 of such a pyramid holds the finest
	details and the last layer the coarsest residual, so pyramids of
	different depth are aligned at their coarse end before comparing.

	originalFilename: path of the reference image.
	solutionFilename: path of the image to be rated.
	Returns (layers, total): layers is a 1-D array with one value per layer
	of the reference pyramid, the Frobenius norm of the layer difference
	divided by the image side length; total is the root mean square of
	those values.
	"""
	original = ScaleImage(LoadImage(originalFilename))
	solution = ScaleImage(LoadImage(solutionFilename))

	lo = Decompose(original)
	ls = Decompose(solution)

	extra = ls.shape[0] - lo.shape[0]
	if extra > 0: # solution has more details, so we drop the finest layers
		ls = ls[extra:, :, :]
	elif extra < 0: # fewer details, fill the missing fine layers with zeros
		padding = np.zeros((-extra, ls.shape[1], ls.shape[2]))
		ls = np.concatenate((padding, ls))

	# now the image sizes might still mismatch - use Up and Down to adjust them
	while ls.shape[1] < lo.shape[1]:
		ls = np.stack([Up(layer) for layer in ls])

	while ls.shape[1] > lo.shape[1]:
		ls = np.stack([Down(layer) for layer in ls])

	assert lo.shape == ls.shape

	# compare the layers: one Frobenius norm per layer, relative to the side length
	layers = np.linalg.norm(lo - ls, ord='fro', axis=(1, 2)) / lo.shape[1]
	total = np.sqrt(np.mean(layers**2))

	return (layers, total)



# usage text shown when the script is called with the wrong argument count
doc = (
	"command line arguments: [ReferenceFile] [SubmissionFile]\n"
	"Example: python Texture.py DataBase/Texture/Concert.png Users/ID404/Concert.png"
)
if __name__ == "__main__":

	# exactly two file names are expected after the script name
	if len(sys.argv) != 3:
		print("invalid call:\n", doc)
		sys.exit(2) # the 2 sometimes signals command line syntax errors

	referenceFile, submissionFile = sys.argv[1:]
	layers, total = Compare(referenceFile, submissionFile)

	# fill beginning with 0s so that at least eight layer values exist
	missing = 8 - len(layers)
	if missing > 0:
		layers = np.concatenate((np.zeros(missing), layers))

	# report the first eight layer values followed by the overall score
	output = {"status": "success"}
	for index in range(8):
		output["L" + str(index)] = layers[index]
	output["total"] = total

	print(json.dumps(output, indent="\t"))