#!/usr/bin/env python3
#-----------------------------------------------------------------------------
# Copyright (C) 2017 Correlated Solutions, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
# 
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
# 
#
# This program illustrates how to compute the reference and deformed volume
# between two surfaces stored in two separate data files.
#
# Each "-sys1-" data file must have a matching "-sys2-" file with the same
# number of AOIs; AOI k of one file is paired with AOI k of the other. The
# correspondence between points is based on the X/Y coordinates.
#-----------------------------------------------------------------------------

from VicPy import * 

from hexahedron_volume import *

import numpy as np
import math
from pathlib import Path

import glob
import os
from os import _exit as exit
import sys

def _corner(vals, deformed):
    """Return the [x, y, z] position stored in a point's value list.

    *vals* is indexed like the variable lists built in compute_volume
    (X, Y, Z, U, V, W, sigma). With *deformed* true the displacement
    (U, V, W) is added to the reference position (X, Y, Z).
    """
    if deformed:
        return [vals[0] + vals[3], vals[1] + vals[4], vals[2] + vals[5]]
    return [vals[0], vals[1], vals[2]]


def _lookup_bottom(bot, top_vals, bot_ids):
    """Return the *bot* values found at the global X/Y of each top point.

    Returns one value list per entry of *top_vals*, in the same order, or
    None as soon as one lookup fails (later points are not looked up).
    """
    found = []
    for t_val in top_vals:
        ok, b_val = bot.atGlobalXY(t_val[0], t_val[1], bot_ids)
        if not ok:
            return None
        found.append(b_val)
    return found


def compute_volume(d1, d2):
    """Compute per-point reference/deformed volume and thickness on *d1*.

    For every AOI index k, AOI k of *d1* is the "top" surface and AOI k of
    *d2* the "bottom" surface. Every top-surface point except those in the
    last row/column anchors a cell made of itself and its right, lower and
    lower-right neighbours. The bottom-surface values at the same global
    X/Y positions close the cell into a hexahedron. Its volume is computed
    in the reference (X, Y, Z) and deformed (X+U, Y+V, Z+W) configurations.

    Variables added to *d1* and written at each cell's anchor point:
      V0, V1 -- reference / deformed hexahedron volume
      dV     -- V1/V0 - 1
      t0, t1 -- reference / deformed distance between the top point and
                the bottom point found at the same X/Y
      dt     -- t1 - t0

    Points that could not be evaluated end up with sigma = -1:
      - incomplete cells
      - failed bottom lookups
      - degenerate cells with V0 == 0
      - every point of the last row and last column
    The fix-up pass also sets any sigma above 50 to -1. *d2* is only used
    for value lookups here.

    Returns:
        True (always).
    """
    # The add variable function returns the actual name
    # of the variable, which might be different from the
    # requested name if the variable already exists.
    dv_var = d1.addVariable("dV", "ΔV/V0 [1]")
    v0_var = d1.addVariable("V0", "V0 [mm^3]")
    v1_var = d1.addVariable("V1", "V1 [mm^3]")

    dt_var = d1.addVariable("dt", "Δt [mm]")
    t0_var = d1.addVariable("t0", "t0 [mm]")
    t1_var = d1.addVariable("t1", "t1 [mm]")

    # Temporary sigma marker for points that could not be evaluated; the
    # fix-up pass at the end of this function turns it into -1.
    pending = 100

    for aoi in range(d1.numData()):
        # Get top and bottom AOI
        top = d1.data(aoi)
        bot = d2.data(aoi)

        top_ids = []
        bot_ids = []
        for v in ['X', 'Y', 'Z', 'U', 'V', 'W', 'sigma']:
            top_ids.append(top.varIndex(v))
            bot_ids.append(bot.varIndex(v))
        top_s = top_ids[-1]

        dv_top = top.varIndex(dv_var)
        v0_top = top.varIndex(v0_var)
        v1_top = top.varIndex(v1_var)
        dt_top = top.varIndex(dt_var)
        t0_top = top.varIndex(t0_var)
        t1_top = top.varIndex(t1_var)

        # Show about 30 dots for progress. max() keeps the step at least 1,
        # otherwise AOIs with fewer than 30 points would hit a modulo by 0.
        progress = max(1, int(top.matrixSize() / 30))
        n_col = top.numColumns()
        n_row = top.numRows()
        i = 0
        for row in range(0, n_row - 1):
            r0 = row * n_col
            for col in range(0, n_col - 1):
                idx = r0 + col
                i += 1
                if i % progress == 0:
                    print(".", end='', flush=True)
                # Values on the top surface at the anchor point; a negative
                # sigma marks invalid data.
                t_val0 = top.values(idx, top_ids)
                if t_val0[-1] < 0:
                    continue
                t_val1 = top.values(idx + 1, top_ids)          # right
                t_val2 = top.values(idx + n_col, top_ids)      # below
                t_val3 = top.values(idx + n_col + 1, top_ids)  # below right

                if t_val1[-1] < 0 or t_val2[-1] < 0 or t_val3[-1] < 0:
                    # Incomplete cell: mark the point for the fix-up pass.
                    # NOTE(review): the positive marker was meant to keep
                    # the point usable for lookups "when we process the
                    # other side". However, the fix-up pass below runs at
                    # the end of this call, before the caller processes the
                    # other data set. Confirm whether that order is intended.
                    top.setValue(idx, top_s, pending)
                    continue

                # Corresponding bottom-surface values, matched by the X/Y
                # coordinates (indices 0/1) of the four top points.
                bottoms = _lookup_bottom(
                    bot, (t_val0, t_val1, t_val2, t_val3), bot_ids)
                if bottoms is None:
                    top.setValue(idx, top_s, pending)
                    continue
                b_val0, b_val1, b_val2, b_val3 = bottoms

                # Hexahedron corners: bottom face, then top face, each
                # listed in cyclic order around the cell (0, 1, 3, 2).
                corners = (b_val0, b_val1, b_val3, b_val2,
                           t_val0, t_val1, t_val3, t_val2)
                v0 = hexahedron_volume([_corner(p, False) for p in corners])
                if v0 == 0:
                    # Degenerate reference cell: V1/V0 is undefined, so
                    # treat it like a point without a valid cell instead of
                    # aborting the whole run with ZeroDivisionError.
                    top.setValue(idx, top_s, pending)
                    continue
                v1 = hexahedron_volume([_corner(p, True) for p in corners])

                # Thickness: distance from the top point to the bottom point
                # found at the same X/Y (b_val0), before and after
                # deformation.
                x = t_val0[0] - b_val0[0]
                y = t_val0[1] - b_val0[1]
                z = t_val0[2] - b_val0[2]
                u = t_val0[3] - b_val0[3]
                v = t_val0[4] - b_val0[4]
                w = t_val0[5] - b_val0[5]
                l0 = math.sqrt(x * x + y * y + z * z)
                x = x + u
                y = y + v
                z = z + w
                l1 = math.sqrt(x * x + y * y + z * z)

                # Save change in volume and thickness
                top.setValue(idx, v0_top, v0)
                top.setValue(idx, v1_top, v1)
                top.setValue(idx, dv_top, v1 / v0 - 1)
                top.setValue(idx, dt_top, l1 - l0)
                top.setValue(idx, t1_top, l1)
                top.setValue(idx, t0_top, l0)

    # Fix-up pass: turn the temporary markers (any sigma above 50) into -1
    # and invalidate the last column and last row, which anchor no cell.
    for aoi in range(d1.numData()):
        top = d1.data(aoi)
        top_s = top.varIndex("sigma")
        n_col = top.numColumns()
        n_row = top.numRows()
        for i in range(top.matrixSize()):
            if top.value(i, top_s) > 50:
                top.setValue(i, top_s, -1)
        if n_row <= 0 or n_col <= 0:
            # Empty AOI: the border indices below would be meaningless.
            continue
        # Last point of rows 0 .. n_row-2 (the last row is handled below).
        for i in range(1, n_row):
            top.setValue(i * n_col - 1, top_s, -1)
        lr = (n_row - 1) * n_col
        for i in range(n_col):
            top.setValue(lr + i, top_s, -1)
    return True
    
if __name__ == '__main__':
    def _exit_flushed(code):
        # `exit` is os._exit (see the imports). It terminates the process
        # without flushing Python's stdio buffers, so flush first to keep
        # messages when stdout/stderr are redirected to a file or pipe.
        sys.stdout.flush()
        sys.stderr.flush()
        exit(code)

    if len(sys.argv) < 3:
        print("Usage: %s prefix input_0.out [input_1.out ...]\n" % sys.argv[0])
        _exit_flushed(-1)

    prefix = sys.argv[1]
    if Path(prefix).is_file():
        print("The first argument should be a prefix, not an existing file.",
              file=sys.stderr)
        _exit_flushed(-1)

    # A single argument that does not name an existing file is expanded as
    # a glob pattern.
    input_files = sys.argv[2:]
    if len(input_files) == 1 and not os.path.isfile(input_files[0]):
        input_files = glob.glob(input_files[0])

    # Build the list of (sys1, sys2) file pairs. Either member of a pair may
    # be given; the partner name is derived by swapping -sys1-/-sys2-.
    pairs = []
    queued = set()
    for s in input_files:
        if '-sys1-' in s:
            s1, s2 = s, s.replace('-sys1-', '-sys2-')
        elif '-sys2-' in s:
            s1, s2 = s.replace('-sys2-', '-sys1-'), s
        else:
            print('Skipping file {}'.format(s))
            continue
        if s1 in queued:
            # Pair already queued via its partner file; not a skip.
            continue
        if Path(s1).is_file() and Path(s2).is_file():
            queued.add(s1)
            pairs.append((s1, s2))
        else:
            print('Skipping file {}'.format(s))

    if not pairs:
        print("No matching -sys1-/-sys2- file pairs found.", file=sys.stderr)

    for s1, s2 in pairs:
        # Create data sets and load the data
        d1 = VicDataSet()
        d2 = VicDataSet()
        if not d1.load(s1):
            print("Could not load dataset from file {0}".format(s1))
            continue
        if not d2.load(s2):
            print("Could not load dataset from file {0}".format(s2))
            continue
        if d1.numData() != d2.numData():
            print('Mismatch in number of aois for {}, skipping...'.format(s1))
            continue
        print('Processing file {}'.format(s1), end='')
        # NOTE(review): the prefix is prepended to the path exactly as given,
        # so for inputs with a directory component it lands on the directory
        # part, not the file name.
        if compute_volume(d1, d2):
            oname = '{0}{1}'.format(prefix, s1)
            if d1.save(oname) == False:
                print("Could not save dataset {0}".format(oname))
        if compute_volume(d2, d1):
            oname = '{0}{1}'.format(prefix, s2)
            if d2.save(oname) == False:
                print("Could not save dataset {0}".format(oname))
        print()

    _exit_flushed(0)
    
