#!/usr/bin/env python3
#-----------------------------------------------------------------------------
# Copyright (C) 2017 Correlated Solutions, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
# 
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
# 
#
# This program illustrates how to compute the reference and deformed volume
# between two surfaces given as two separate AOIs in the data file, as well as
# the reference and deformed thickness (distance between the two surfaces).
#
# The data set must contain two AOIs corresponding to the two surfaces. The
# correspondence between points is based on the X/Y coordinates.
#-----------------------------------------------------------------------------

from VicPy import * 

from hexahedron_volume import *

import numpy as np
import math
from pathlib import Path

import glob
import os
from os import _exit as exit
import sys

def compute_volume(data):
    # Check that two AOIs are present
    if data.numData() != 2:
        print("Data must have exactly two aois.", file=sys.stderr)
        return false

    # The add variable function returns the actual name
    # of the variable, which might be different from the
    # requested name if the variable already exists.
    dv_var = data.addVariable("dV", "ΔV/V0 [1]")
    v0_var = data.addVariable("V0", "V0 [mm^3]")
    v1_var = data.addVariable("V1", "V1 [mm^3]")

    dt_var = data.addVariable("dt", "Δt [mm]")
    t0_var = data.addVariable("t0", "t0 [mm]")
    t1_var = data.addVariable("t1", "t1 [mm]")
    
    # Get top and bottom AOI
    top = data.data(0)
    bot = data.data(1)
    
    for side in ['Side 1', 'Side 2']:
        top_ids = []
        bot_ids = []
        for v in ['X', 'Y', 'Z', 'U', 'V', 'W', 'sigma']:
            top_ids.append(top.varIndex(v))
            bot_ids.append(bot.varIndex(v))
    
        dv_top = top.varIndex(dv_var)
        v0_top = top.varIndex(v0_var)
        v1_top = top.varIndex(v1_var)
        dt_top = top.varIndex(dt_var)
        t0_top = top.varIndex(t0_var)
        t1_top = top.varIndex(t1_var)
        
        # Show 30 dots for progress
        progress = int(top.matrixSize() / 30)
        print(side + ": ", end='')
        n_col = top.numColumns()
        n_row = top.numRows()
        i = 0
        sum_v0 = 0
        sum_v1 = 0
        for row in range(0, n_row - 1):
            r0 = row * n_col
            for col in range(0, n_col - 1):
                idx = r0 + col
                i += 1
                if i % progress == 0:
                    print(".", end='', flush='True')
                # Get values on top surface at current point
                t_val0 = top.values(idx, top_ids)
                # Check if data is valid at this point
                if t_val0[-1] < 0:
                    continue
                # right neighbor
                t_val1 = top.values(idx + 1, top_ids)
                # bottom neighbor
                t_val2 = top.values(idx + n_col, top_ids)
                # bottom right neighbor
                t_val3 = top.values(idx + n_col + 1, top_ids)
    
                if t_val1[-1] < 0 or t_val2[-1] < 0 or t_val3[-1] < 0:
                    # This point does not have valid neighbors. We set
                    # the sigma value to a high positive value here, so
                    # it will still participate in data lookup when we
                    # process the other side. In a final pass, we set
                    # the value to a negative one to eliminate the point
                    top.setValue(idx, top_ids[-1], 100)
                    continue
                
                # Try to get corresponding point on bottom surface
                # t_val[0/1] are the X/Y coordinates on top
                ok, b_val0 = bot.atGlobalXY(t_val0[0], t_val0[1], bot_ids);
                if ok:
                    ok, b_val1 = bot.atGlobalXY(t_val1[0], t_val1[1], bot_ids);
                if ok:
                    ok, b_val2 = bot.atGlobalXY(t_val2[0], t_val2[1], bot_ids);
                if ok:
                    ok, b_val3 = bot.atGlobalXY(t_val3[0], t_val3[1], bot_ids);
                if ok:
                    pts = []
                    pts.append([b_val0[0], b_val0[1], b_val0[2]])
                    pts.append([b_val1[0], b_val1[1], b_val1[2]])
                    pts.append([b_val3[0], b_val3[1], b_val3[2]])
                    pts.append([b_val2[0], b_val2[1], b_val2[2]])
                    pts.append([t_val0[0], t_val0[1], t_val0[2]])
                    pts.append([t_val1[0], t_val1[1], t_val1[2]])
                    pts.append([t_val3[0], t_val3[1], t_val3[2]])
                    pts.append([t_val2[0], t_val2[1], t_val2[2]])
                    v0 = hexahedron_volume(pts)
    
                    pts = []
                    pts.append([b_val0[0] + b_val0[3], b_val0[1] + b_val0[4], b_val0[2] + b_val0[5]])
                    pts.append([b_val1[0] + b_val1[3], b_val1[1] + b_val1[4], b_val1[2] + b_val1[5]])
                    pts.append([b_val3[0] + b_val3[3], b_val3[1] + b_val3[4], b_val3[2] + b_val3[5]])
                    pts.append([b_val2[0] + b_val2[3], b_val2[1] + b_val2[4], b_val2[2] + b_val2[5]])
                    pts.append([t_val0[0] + t_val0[3], t_val0[1] + t_val0[4], t_val0[2] + t_val0[5]])
                    pts.append([t_val1[0] + t_val1[3], t_val1[1] + t_val1[4], t_val1[2] + t_val1[5]])
                    pts.append([t_val3[0] + t_val3[3], t_val3[1] + t_val3[4], t_val3[2] + t_val3[5]])
                    pts.append([t_val2[0] + t_val2[3], t_val2[1] + t_val2[4], t_val2[2] + t_val2[5]])
                    v1 = hexahedron_volume(pts);

                    # Compute thickness change
                    x = t_val0[0] - b_val1[0]
                    y = t_val0[1] - b_val1[1]
                    z = t_val0[2] - b_val1[2]
                    u = t_val0[3] - b_val1[3]
                    v = t_val0[4] - b_val1[4]
                    w = t_val0[5] - b_val1[5]
                    l0 = math.sqrt(x * x + y * y + z * z)
                    x = x + u
                    y = y + v
                    z = z + w
                    l1 = math.sqrt(x * x + y * y + z * z)
                    
                    # Save change in volume
                    top.setValue(idx, v0_top, v0)
                    top.setValue(idx, v1_top, v1)
                    top.setValue(idx, dv_top, v1/v0 - 1)
                    top.setValue(idx, dt_top, l1 - l0)
                    top.setValue(idx, t1_top, l1)
                    top.setValue(idx, t0_top, l0)
                    sum_v0 += v0
                    sum_v1 += v1
                else:
                    # Here, we set sigma to a high value so we can fix it later
                    top.setValue(idx, top_ids[-1], 100)
    
        # Swap top and bottom and repeat
        top, bot = bot, top
        print(' v0: {0:.1f}mm^3 v1: {1:.1f}mm^3\n'.format(sum_v0, sum_v1), end='')
    
    # In this pass, we simply fix invalid points
    for side in ['Side 1', 'Side 2']:
        top_s = top.varIndex("sigma")
        n_col = top.numColumns()
        n_row = top.numRows()
        for i in range(top.matrixSize()):
            s = top.value(i, top_s)
            if s > 50:
                top.setValue(i, top_s, -1)
        for i in range(1, n_row):
            top.setValue(i * n_col - 1, top_s, -1)
        lr = (n_row - 1) * n_col
        for i in range(n_col):
            top.setValue(lr + i, top_s, -1)
        # Swap and repeat
        top, bot = bot, top
    return True
    
if __name__ == '__main__':
    # Usage: <script> prefix input_0.out [input_1.out ...]
    # Each processed data set is saved next to its input file, with *prefix*
    # prepended to the file name.

    def _quit(code):
        """Flush stdio, then terminate through ``exit`` (os._exit, see imports).

        os._exit ends the process without normal interpreter cleanup, which
        includes flushing buffered stdout/stderr; flush explicitly so output
        redirected to a file or pipe is not lost. Using os._exit itself is
        kept as in the original (presumably deliberate -- TODO confirm).
        """
        sys.stdout.flush()
        sys.stderr.flush()
        exit(code)

    if len(sys.argv) < 3:
        print("Usage: %s prefix input_0.out [input_1.out ...]\n" % sys.argv[0],
              file=sys.stderr)
        _quit(-1)

    prefix = sys.argv[1]
    if Path(prefix).is_file():
        print("The first argument should be a prefix, not an existing file.",
              file=sys.stderr)
        _quit(-1)

    input_files = sys.argv[2:]
    # A single argument that is not an existing file is treated as a glob
    # pattern, so a quoted pattern such as 'data/*.out' can be passed.
    if len(input_files) == 1 and not os.path.isfile(input_files[0]):
        pattern = input_files[0]
        input_files = sorted(glob.glob(pattern))
        if not input_files:
            print("No input files match {0}".format(pattern), file=sys.stderr)

    for filename in input_files:
        # Create data set and load the data
        data = VicDataSet()
        if data.load(filename) == False:
            print("Could not load dataset from file {0}".format(filename),
                  file=sys.stderr)
            continue

        if compute_volume(data):
            # Prefix the file name, not the directory part of the path.
            folder, base = os.path.split(filename)
            oname = os.path.join(folder, prefix + base)
            if data.save(oname) == False:
                print("Could not save dataset {0}".format(oname),
                      file=sys.stderr)

    _quit(0)
    
