#!/usr/bin/env python3
#-----------------------------------------------------------------------------
# Copyright (C) 2020 Correlated Solutions, Inc.
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
# 
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
# AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
# OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
# PERFORMANCE OF THIS SOFTWARE.
#-----------------------------------------------------------------------------
#
# Convert an Abaqus mesh made of shell elements (plus its field outputs) into VTK polygon (.vtp) files
#
# version 1.0: initial release

from odbAccess import *

import argparse
import math
import os
import vtk

def write_clean_vtp(object, filename):
    """Write a vtkPolyData to ``filename`` as an XML poly-data (.vtp) file.

    The input is first run through vtkCleanPolyData, and the cleaned output
    (not ``object`` itself) is what gets written. Point ids that no cell
    references are dropped, so the written point ids are consecutive and
    start at 0.
    NOTE(review): with its default settings vtkCleanPolyData may also merge
    coincident points -- confirm this is desired for the meshes converted.

    Parameters:
        object: the vtkPolyData to save (left unmodified).
        filename: path of the .vtp file to create.
    """
    cleaner = vtk.vtkCleanPolyData()
    cleaner.SetInputData(object)
    cleaner.Update()
    cleaned = cleaner.GetOutput()

    vtp_writer = vtk.vtkXMLPolyDataWriter()
    vtp_writer.SetInputData(cleaned)
    vtp_writer.SetFileName(filename)
    vtp_writer.Write()

def available_fields(odb):
    """Return whether displacement and strains are available.

    Only the last frame of the first step in ``odb.steps`` is inspected.

    Parameters:
        odb: an opened Abaqus output database.

    Returns:
        ``(has_disp, has_strain)`` -- booleans telling whether the field
        outputs ``'U'`` and ``'E'`` are present. Both are ``False`` when the
        odb has no steps or its first step has no frames.
    """
    # list() makes this work whether keys() returns a list (Python 2) or a
    # non-indexable view (Python 3).
    step_names = list(odb.steps.keys())
    if not step_names:
        return False, False
    frames = odb.steps[step_names[0]].frames
    if len(frames) == 0:
        return False, False
    f_keys = frames[-1].fieldOutputs.keys()
    return 'U' in f_keys, 'E' in f_keys
    

# Help text for the command-line interface; it is passed as the
# ``description`` of the argparse.ArgumentParser built by the script entry
# point and is printed by ``--help``.
__desc__ = '''This program converts a mesh and field outputs contained
in an abaqus output database file (odb) to one or more VTK polygon files
(vtp), depending on the number of steps contained in the odb. The script
automatically detects whether displacement and strains are present. This
script only works for shell elements.

'''

def triangulate_cell(cell):
    """Split one shell element into triangles.

    Parameters:
        cell: sequence of node labels of the element, in Abaqus
            connectivity order.

    Returns:
        A list of ``[a, b, c]`` node-label triangles. Every triangle is
        wound in the same direction as the element's corner nodes, so
        normals agree across element types. Node counts other than
        3, 4, 6 or 8 are not handled and give an empty list.
    """
    n_nodes = len(cell)
    if n_nodes == 3:
        return [list(cell)]
    if n_nodes == 4:
        return [[cell[0], cell[1], cell[3]],
                [cell[1], cell[2], cell[3]]]
    if n_nodes == 6:
        # Quadratic triangle: corners 0, 1, 2; mid-side nodes 3 (0-1),
        # 4 (1-2), 5 (2-0). Three corner triangles plus the central one.
        return [[cell[0], cell[3], cell[5]],
                [cell[3], cell[1], cell[4]],
                [cell[5], cell[4], cell[2]],
                [cell[3], cell[4], cell[5]]]
    if n_nodes == 8:
        # Quadratic quad: corners 0, 1, 2, 3; mid-side nodes 4 (0-1),
        # 5 (1-2), 6 (2-3), 7 (3-0). One triangle per corner with its two
        # neighbouring mid-side nodes, plus two for the central quad of
        # mid-side nodes (split along 4-6).
        # NOTE(review): an 8-node continuum shell would also match this
        # branch and is not distinguished here.
        return [[cell[0], cell[4], cell[7]],
                [cell[4], cell[1], cell[5]],
                [cell[5], cell[2], cell[6]],
                [cell[6], cell[3], cell[7]],
                [cell[4], cell[5], cell[6]],
                [cell[4], cell[6], cell[7]]]
    return []


def num_digits(count):
    """Return how many decimal digits ``count`` has (at least 1)."""
    return len(str(max(int(count), 1)))


def component_index(labels, name, default):
    """Return the position of component ``name`` in ``labels``.

    Falls back to ``default`` when ``labels`` is empty or lacks ``name``.
    """
    labels = list(labels or ())
    if name in labels:
        return labels.index(name)
    return default


def collect_mesh(assembly):
    """Gather nodes and element connectivity from all instances.

    Abaqus numbers nodes per instance, so labels of every instance after
    the first are shifted by an offset to keep them unique. The first
    instance has offset 0, so its labels are unchanged.

    Returns:
        ``(nodes, elements, offsets)``:
        ``nodes`` maps shifted label -> ``[x, y, z]``;
        ``elements`` is a list of shifted-label connectivity lists;
        ``offsets`` maps instance name -> label offset.
    """
    nodes = {}
    elements = []
    offsets = {}
    offset = 0
    for key in assembly.instances.keys():
        print('Adding elements and nodes for instance: {}'.format(key))
        instance = assembly.instances[key]
        offsets[key] = offset
        max_label = offset
        for node in instance.nodes:
            label = node.label + offset
            nodes[label] = [node.coordinates[0], node.coordinates[1], node.coordinates[2]]
            max_label = max(max_label, label)
        for element in instance.elements:
            elements.append([n + offset for n in element.connectivity])
        offset = max_label
    return nodes, elements, offsets


def node_key(value, offsets):
    """Return the shifted node label for a field value.

    The offset comes from ``offsets``, keyed by the value's instance name.
    Values without an instance, or from an unknown instance, get offset 0.
    """
    instance = getattr(value, 'instance', None)
    offset = offsets.get(instance.name, 0) if instance is not None else 0
    return int(value.nodeLabel) + offset


def add_point_arrays(surface, names, keys, zero):
    """Create one named vtkFloatArray per entry in ``names``.

    Each array is attached to ``surface``'s point data. When ``zero`` is
    true, each array is filled with 0.0 at every key in ``keys``.
    """
    arrays = []
    for name in names:
        array = vtk.vtkFloatArray()
        array.SetName(name)
        surface.GetPointData().AddArray(array)
        arrays.append(array)
        if zero:
            for key in keys:
                array.InsertValue(key, 0.0)
    return arrays


def fill_displacements(arrays, values, offsets):
    """Store the U, V, W displacement components of ``values`` at their nodes.

    Values with only two components (2D data) get W = 0.
    """
    for d in values:
        key = node_key(d, offsets)
        data = d.data
        arrays[0].InsertValue(key, data[0])
        arrays[1].InsertValue(key, data[1])
        arrays[2].InsertValue(key, data[2] if len(data) > 2 else 0.0)


def fill_strains(arrays, strain_field, offsets):
    """Average element-nodal strains at each node into exx, eyy, exy.

    Component positions are taken from ``componentLabels``. When those
    labels are unavailable, the fixed positions 0, 1 and 3 are used.
    NOTE(review): Abaqus E12 is believed to be engineering shear strain --
    confirm whether consumers of 'exy' expect that.
    """
    labels = getattr(strain_field, 'componentLabels', ())
    i11 = component_index(labels, 'E11', 0)
    i22 = component_index(labels, 'E22', 1)
    i12 = component_index(labels, 'E12', 3)
    acc = {}
    for d in strain_field.getSubset(position=ELEMENT_NODAL).values:
        v = acc.setdefault(node_key(d, offsets), [0.0, 0.0, 0.0, 0.0])
        v[0] += 1.0
        v[1] += d.data[i11]
        v[2] += d.data[i22]
        v[3] += d.data[i12]
    for key, v in acc.items():
        scale = 1.0 / v[0]
        arrays[0].InsertValue(key, scale * v[1])
        arrays[1].InsertValue(key, scale * v[2])
        arrays[2].InsertValue(key, scale * v[3])


def convert_odb(odb, prefix, zero=False):
    """Write one .vtp file per step of ``odb``.

    File names are ``<prefix>_<n>.vtp``, with ``n`` zero-padded.
    Each step's file uses the last frame of that step. When ``zero`` is
    true, an undeformed file with index 0 is written first.
    """
    nodes, elements, offsets = collect_mesh(odb.rootAssembly)

    # Point ids are the (shifted) Abaqus node labels.
    surface = vtk.vtkPolyData()
    points = vtk.vtkPoints()
    surface.SetPoints(points)
    for key, p in nodes.items():
        points.InsertPoint(key, p[0], p[1], p[2])

    polys = vtk.vtkCellArray()
    surface.SetPolys(polys)
    skipped = 0
    for cell in elements:
        triangles = triangulate_cell(cell)
        if not triangles:
            skipped += 1
        for tri in triangles:
            polys.InsertNextCell(3, tri)
    if skipped:
        print('Skipped {} element(s) with an unsupported node count'.format(skipped))

    has_disp, has_strain = available_fields(odb)
    U = add_point_arrays(surface, ['U', 'V', 'W'], nodes, zero) if has_disp else []
    E = add_point_arrays(surface, ['exx', 'eyy', 'exy'], nodes, zero) if has_strain else []

    # The offset accounts for the optional zero-deformation file.
    file_offset = 1 if zero else 0
    odb_keys = list(odb.steps.keys())
    num_files = len(odb_keys) + file_offset
    fmt_string = prefix + '_{:0' + '{}'.format(num_digits(num_files)) + 'd}.vtp'

    if zero:
        write_clean_vtp(surface, fmt_string.format(0))

    for step_no, step in enumerate(odb_keys):
        for S in U + E:
            S.Reset()
        field = odb.steps[step].frames[-1].fieldOutputs
        if has_disp:
            fill_displacements(U, field['U'].values, offsets)
        if has_strain:
            fill_strains(E, field['E'], offsets)
        write_clean_vtp(surface, fmt_string.format(step_no + file_offset))

        # Without field output every step is identical; write only one.
        if not has_disp and not has_strain:
            break


def main():
    """Parse the command line, convert the odb, and always close it."""
    parser = argparse.ArgumentParser(description=__desc__)
    parser.add_argument('input', help='Abaqus odb file.')
    parser.add_argument('--prefix', '-p', help='Prefix for output file.', required=True)
    parser.add_argument('--zero', '-z', default=False, action='store_true',
                            help='Generate an output file with no deformation before first step in odb')
    # Unrecognized arguments are tolerated rather than rejected.
    args, _unknown = parser.parse_known_args()

    odb = openOdb(path=args.input)
    try:
        convert_odb(odb, args.prefix, args.zero)
    finally:
        odb.close()


if __name__ == '__main__':
    main()