import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial import KDTree
import sys
import json

def LoadObj(filename):
	"""Load a single-object, triangulated Wavefront OBJ file with vertex normals.

	Faces must be triangles in ``p/t/n`` or ``p//n`` form. Texture coordinates,
	smoothing groups, object names and any other keywords are ignored. Vertex
	lines may carry extra components (e.g. ``v x y z r g b``); only x, y, z
	are used.

	Returns:
		(positions, normals): two float64 arrays of shape (3*F, 3), F being the
		number of faces. Rows 3k, 3k+1 and 3k+2 hold the three corners of
		face k, i.e. the mesh is unrolled with no shared vertices.

	Raises:
		Exception: more than one ``o`` object, a non-triangular face, or a face
			without normal indices.
		AssertionError: empty file or malformed ``v``/``vn`` line.
	"""
	P = list() # position vectors
	N = list() # normal vectors
	Pi = list() # position indices (1-based, as in the file)
	Ni = list() # normal indices (1-based, as in the file)

	numberOfObjects = 0

	with open(filename, 'r') as f:
		lines = f.read().splitlines()
	assert len(lines) != 0

	for line in lines:
		line = line.strip()
		if len(line) == 0 or line[0] == '#': # skip blank lines and comments
			continue
		# split() rather than split(" "): tolerate repeated spaces and tabs between fields
		tokens = line.split()
		keyword = tokens[0]

		if keyword == 'o':
			numberOfObjects += 1
			if 1 < numberOfObjects:
				raise Exception("only a single object per file is allowed")
			# object names are not used

		elif keyword == 'v':
			assert len(tokens) >= 4
			P.append(tuple(float(t) for t in tokens[1:4]))

		elif keyword == 'vn':
			assert len(tokens) == 4
			# some writers print NaN as "-nan(ind)", which float() rejects; treat it as 0
			N.append(tuple(float(t.replace("-nan(ind)", "0")) for t in tokens[1:4]))

		elif keyword == 'f':
			if len(tokens) != 4:
				raise Exception("only triangulated objects are supported")
			p = list()
			n = list()
			for corner in tokens[1:4]:
				x = corner.split('/')
				if len(x) != 3:
					raise Exception("face format must be p/(t)/n, only models with normals are supported!")
				p.append(int(x[0]))
				# texture coordinates (x[1]) are ignored
				n.append(int(x[2]))
			Pi.append(p)
			Ni.append(n)

		# every other keyword ('s', 'vt', 'g', 'usemtl', ...) is ignored

	# Resolve the 1-based OBJ indices into flat per-corner buffers.
	# reshape(-1, 3) keeps the (0, 3) shape when the file has no v/vn lines.
	P = np.asarray(P, dtype=float).reshape(-1, 3)
	N = np.asarray(N, dtype=float).reshape(-1, 3)
	positions = P[np.asarray(Pi, dtype=int).reshape(-1) - 1]
	normals = N[np.asarray(Ni, dtype=int).reshape(-1) - 1]

	return positions, normals


###########################################################


# Removes all triangles that face away from the origin (back-face culling; occlusion by other triangles is not tested).
def FilterObject(positions, normals):
	"""Keep only the triangles whose front side faces the origin (the laser spot).

	Each triangle's geometric plane normal is oriented to agree with the sum
	of its three vertex normals. The triangle is kept when the origin lies
	strictly on that front side of its plane. Degenerate (zero-area) triangles
	therefore never survive.

	positions, normals: arrays of shape (3*F, 3) as returned by LoadObj.
	Returns a float array of shape (3*K, 3) with the corners of the K kept
	triangles.
	Raises Exception if the input is empty or if every triangle is culled.
	"""
	if len(positions) == 0 or len(normals) == 0:
		raise Exception("Empty mesh!")

	laserSpot = np.array([0, 0, 0])
	kept = []
	for start in range(0, len(positions), 3):
		a = positions[start]
		b = positions[start + 1]
		c = positions[start + 2]
		vertexNormalSum = normals[start] + normals[start + 1] + normals[start + 2]

		# plane through the triangle: dot(planeNormal, x) == planeOffset
		planeNormal = np.cross(a - b, a - c)
		planeOffset = np.dot(planeNormal, a)

		# flip the plane if it points against the averaged vertex normals
		if np.dot(a + vertexNormalSum, planeNormal) < planeOffset:
			planeNormal, planeOffset = -planeNormal, -planeOffset

		# keep the triangle only if the origin is in front of the plane
		if np.dot(laserSpot, planeNormal) > planeOffset:
			kept.extend((a, b, c))

	if not kept:
		raise Exception("All triangles were removed!")
	return np.array(kept, dtype=float)


###########################################################


def TriangleCenters(positions):
	"""Return the centroid of every triangle in an unrolled triangle buffer.

	positions: array-like of shape (3*F, 3); rows 3k..3k+2 are the corners of
		triangle k (the layout produced by LoadObj / FilterObject).
	Returns a float64 array of shape (F, 3).
	Raises ValueError if the number of rows is not a multiple of 3.
	"""
	positions = np.asarray(positions)
	if len(positions) % 3 != 0:
		raise ValueError("positions must contain exactly 3 rows per triangle")
	# vectorized: view the buffer as (F, 3 corners, 3 coords) instead of looping in Python
	corners = positions.reshape(-1, 3, 3)
	return np.asarray((corners[:, 0] + corners[:, 1] + corners[:, 2]) / 3, dtype=float)

def TriangleSizes(positions):
	"""Return the area of every triangle in an unrolled triangle buffer.

	positions: array-like of shape (3*F, 3); rows 3k..3k+2 are the corners of
		triangle k (the layout produced by LoadObj / FilterObject).
	Returns a float64 array of shape (F,) with each triangle's area, computed
	as half the norm of the cross product of two edge vectors.
	Raises ValueError if the number of rows is not a multiple of 3.
	"""
	positions = np.asarray(positions)
	if len(positions) % 3 != 0:
		raise ValueError("positions must contain exactly 3 rows per triangle")
	# vectorized: view the buffer as (F, 3 corners, 3 coords) instead of looping in Python
	corners = positions.reshape(-1, 3, 3)
	edge1 = corners[:, 1] - corners[:, 0]
	edge2 = corners[:, 2] - corners[:, 0]
	return np.asarray(np.linalg.norm(np.cross(edge1, edge2), axis=1) / 2, dtype=float)

def MeshDistance(meshFrom, meshTo):
	"""Area-weighted mean distance from meshFrom to meshTo (asymmetric).

	For every triangle of meshFrom, take the distance from its centroid to the
	nearest triangle centroid of meshTo. Average these distances, weighting
	each by the area of its meshFrom triangle. The result is undefined (0/0)
	when meshFrom has zero total area.

	meshFrom, meshTo: unrolled triangle buffers of shape (3*F, 3).
	"""
	centroidsFrom = TriangleCenters(meshFrom)
	centroidsTo = TriangleCenters(meshTo)
	weights = TriangleSizes(meshFrom)

	nearestDistances = KDTree(centroidsTo).query(centroidsFrom)[0]
	weightedTotal = np.sum(nearestDistances * weights)
	return weightedTotal / np.sum(weights)

def SymmetricMeshDistance(mesh1, mesh2):
	"""Return the larger of the two directed MeshDistance values between mesh1 and mesh2."""
	oneToTwo = MeshDistance(mesh1, mesh2)
	twoToOne = MeshDistance(mesh2, mesh1)
	return max(oneToTwo, twoToOne)


###########################################################


def PlotModel(positions, ax=None, flat=None):
	"""Draw an unrolled triangle mesh: vertices, edges, area-scaled centroids and the origin.

	positions: array of shape (3*F, 3); rows 3k..3k+2 are the corners of triangle k.
	ax: matplotlib Axes to draw into. If omitted, a module-level ``ax`` is used
		when one has been assigned. Otherwise the current figure's axes are
		reused when suitable, and a new subplot is added when not.
	flat: True draws the x/y projection in 2D, False draws in 3D. If omitted,
		a module-level ``flat`` is used when assigned, else False.
	Returns the Axes that was drawn into.
	"""
	# The original relied on globals `ax`/`flat` that this module never defines
	# (NameError). Still honour them if a caller assigns them on the module.
	if flat is None:
		flat = globals().get('flat', False)
	if ax is None:
		ax = globals().get('ax')
	if ax is None:
		fig = plt.gcf()
		current = fig.gca() if fig.axes else None
		if current is not None and (flat or current.name == '3d'):
			ax = current
		else:
			ax = fig.add_subplot() if flat else fig.add_subplot(projection='3d')

	# vertices
	if flat:
		ax.scatter(positions[:, 0], positions[:, 1], color='C0')
	else:
		ax.scatter(positions[:, 0], positions[:, 1], positions[:, 2], color='C0', depthshade=False)

	# edges: repeat the first corner to close each triangle outline
	for k in range(len(positions) // 3):
		outline = positions[[3 * k, 3 * k + 1, 3 * k + 2, 3 * k]]
		if flat:
			ax.plot(outline[:, 0], outline[:, 1], color='C2', alpha=0.5)
		else:
			ax.plot(outline[:, 0], outline[:, 1], outline[:, 2], color='C2', alpha=0.5)

	# centroids, with marker size growing with triangle area
	centers = TriangleCenters(positions)
	markerSizes = (TriangleSizes(positions) * 10) ** 2
	if flat:
		ax.scatter(centers[:, 0], centers[:, 1], s=markerSizes, color='C1')
	else:
		ax.scatter(centers[:, 0], centers[:, 1], centers[:, 2], s=markerSizes, color='C1', depthshade=False)

	# origin (laser spot). In 2D the third positional argument of scatter is `s`,
	# so passing a z coordinate there together with s=100 raised a TypeError.
	if flat:
		ax.scatter(0, 0, s=100, color='C4')
	else:
		ax.scatter(0, 0, 0, s=100, color='C4')

	ax.axis('off')
	return ax

def PlotConnection(meshFrom, meshTo, ax=None, flat=None):
	"""Draw a line from each triangle centroid of meshFrom to its nearest centroid in meshTo.

	These are the same nearest-centroid correspondences that MeshDistance measures.
	meshFrom, meshTo: unrolled triangle buffers of shape (3*F, 3).
	ax: matplotlib Axes to draw into. If omitted, a module-level ``ax`` is used
		when one has been assigned. Otherwise the current figure's axes are
		reused when suitable, and a new subplot is added when not.
	flat: True draws the x/y projection in 2D, False draws in 3D. If omitted,
		a module-level ``flat`` is used when assigned, else False.
	Returns the Axes that was drawn into.
	"""
	# The original relied on globals `ax`/`flat` that this module never defines
	# (NameError). Still honour them if a caller assigns them on the module.
	if flat is None:
		flat = globals().get('flat', False)
	if ax is None:
		ax = globals().get('ax')
	if ax is None:
		fig = plt.gcf()
		current = fig.gca() if fig.axes else None
		if current is not None and (flat or current.name == '3d'):
			ax = current
		else:
			ax = fig.add_subplot() if flat else fig.add_subplot(projection='3d')

	centersFrom = TriangleCenters(meshFrom)
	centersTo = TriangleCenters(meshTo)
	_, nearest = KDTree(centersTo).query(centersFrom)
	toPoints = centersTo[nearest]

	for start, end in zip(centersFrom, toPoints):
		if flat:
			ax.plot([start[0], end[0]], [start[1], end[1]], color='0.2')
		else:
			ax.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], color='0.2')
	return ax


###########################################################




def ComputePrecisionCompleteness(originalFile, reconstructionFile):
	"""Print precision and completeness of a reconstruction, tab-separated.

	Both OBJ files are loaded and culled with FilterObject.
	- Precision (first column) is MeshDistance from the reconstruction to the original.
	- Completeness (second column) is MeshDistance from the original to the reconstruction.
	Each value is printed with 4 significant digits.
	"""
	reference = FilterObject(*LoadObj(originalFile))
	reconstruction = FilterObject(*LoadObj(reconstructionFile))
	precision = MeshDistance(reconstruction, reference)
	completeness = MeshDistance(reference, reconstruction)
	print("{AB:.4}\t{BA:.4}".format(AB=precision, BA=completeness))


doc = \
'''command line arguments: [ReferenceFile] [SubmissionFile]
Example: python Geometry.py DataBase/Geometry/Dinosaur.obj Users/ID404/Geometry/Dinosaur.obj'''
if __name__ == "__main__":
	if len(sys.argv) != 3:
		print("invalid call:\n", doc)
		sys.exit(2) # exit status 2 conventionally signals a command-line usage error

	ref = FilterObject(*LoadObj(sys.argv[1]))
	rec = FilterObject(*LoadObj(sys.argv[2]))

	# precision: how close the submitted surface lies to the reference
	#   (distances from submission triangles, weighted by submission area);
	# completeness: how well the reference is covered by the submission
	#   (distances from reference triangles, weighted by reference area).
	# This is the same argument order ComputePrecisionCompleteness uses; the
	# two measures were previously swapped here.
	precision = float(MeshDistance(rec, ref))
	completeness = float(MeshDistance(ref, rec))

	output = {
		"status": "success",
		"precision": precision,
		"completeness": completeness,
		"total": max(precision, completeness)
		}

	print(json.dumps(output, indent="\t"))
