import os
import sys
import bpy, bmesh, mathutils
import time, struct, os
from bpy.props import *


import glob

import nrfile
import nrtools
import nrdump
import nrblendtools


def writeLog(s):
    """Forward the log message *s* to ``nrblendtools.writeLog`` (short wrapper)."""
    sink = nrblendtools.writeLog
    sink(s)


def retVal(status, msg):
    """Return the result record used by the importers.

    The record is a dict with key ``'status'`` (the given status value) and
    key ``'msg'`` (the given message).
    """
    return {'status': status, 'msg': msg}


def isMeshSkippedFromImport(options, rsttProps):
    """Decide whether a mesh must be skipped based on its RSTT properties.

    Each rule pairs an option flag with a property test. Rules are evaluated
    in a fixed order, and a property is only queried when its option flag is
    truthy. The first rule that fires logs its reason and the function returns
    True. When *rsttProps* is falsy (e.g. None) nothing is skipped.
    """
    if not rsttProps:
        return False

    # (option key, test with isTrue?, property name, log message)
    rules = (
        ('ignoreDepthEnableFalse', False, 'DEPTH_ENABLE',
         "Ignoring mesh DEPTH_ENABLE==FALSE"),
        ('ignoreDepthWriteEnableFalse', False, 'DEPTH_WRITE_ENABLE',
         "Ignoring mesh DEPTH_WRITE_ENABLE==FALSE"),
        ('ignoreRGBWriteDisabled', True, 'RGBWRITE_DISABLED',
         "Ignoring mesh RGBWRITE_DISABLED==TRUE"),
        ('ignoreIfRenderedToRenderTarget', True, 'RENDERED_TO_RENDERTARGET',
         "Ignoring mesh RENDERED_TO_RENDERTARGET==TRUE"),
        ('ignoreIfRenderedToBackBuffer', False, 'RENDERED_TO_RENDERTARGET',
         "Ignoring mesh RENDERED_TO_RENDERTARGET==FALSE"),
        ('ignoreIfRenderTargetWidthDoesNotMatchToBackBuffer', True, 'RT_WIDTH_NOT_MATCH_BACKBUF',
         "Ignoring mesh RT_WIDTH_NOT_MATCH_BACKBUF==TRUE"),
    )

    for optionKey, useIsTrue, propName, message in rules:
        if not options[optionKey]:
            continue
        probe = rsttProps.isTrue if useIsTrue else rsttProps.isFalse
        if probe(propName):
            writeLog(message)
            return True

    return False


def isPostShaderStage(shaderStage):
    """Return True if *shaderStage* equals ``nrfile.ShaderStage.Vs`` or ``nrfile.ShaderStage.DsGs``."""
    postStages = (nrfile.ShaderStage.Vs, nrfile.ShaderStage.DsGs)
    return any(shaderStage == stage for stage in postStages)


def assignNormals(options, mesh, vatrs, vert, vertexData, faces):
    normalsTab = options['normalsTab']
    if normalsTab == 'DISABLED':
        return
    elif normalsTab == 'AUTO':
        # Assign normals
        # TODO: Manual settings
        mesh.polygons.foreach_set("use_smooth", [True] * len(faces))
        normalsVaList  = vatrs.findSemantic("NORMAL")
        if len(normalsVaList):
            # Check for normal is 3 components
            if normalsVaList[0].compCount == 3:
                normals = nrtools.unpackVertexComponentAsList(vert, vertexData, normalsVaList[0])
                if normals is not None:
                    mesh.use_auto_smooth = True
                    if hasattr(mesh, "show_normal_vertex"):
                        mesh.show_normal_vertex = True
                    if hasattr(mesh, "show_normal_loop"):
                        mesh.show_normal_loop = True
                    mesh.normals_split_custom_set_from_vertices(normals)
    elif normalsTab == 'MANUAL':
        normalAttrCompX  = options['normalAttrCompX']  # [attrIdx, compIdx]
        normalAttrCompY  = options['normalAttrCompY']  # [attrIdx, compIdx]
        normalAttrCompZ  = options['normalAttrCompZ']  # [attrIdx, compIdx]

        normals = nrtools.unpackVertexComponentVaAsList(vert, vertexData, vatrs, [normalAttrCompX, normalAttrCompY, normalAttrCompZ])
        if normals is not None:
            mesh.use_auto_smooth = True
            if hasattr(mesh, "show_normal_vertex"):
                mesh.show_normal_vertex = True
            if hasattr(mesh, "show_normal_loop"):
                mesh.show_normal_loop = True
            mesh.normals_split_custom_set_from_vertices(normals)


def importNrFiles(paths, options):
    """Import post-vertex-shader meshes from NR files and/or directories.

    Each entry of *paths* is either an NR file, which is imported directly, or a
    directory, whose top-level ``*.nr`` files are imported in sorted order.
    Entries that are neither are ignored. Afterwards the importer summary is
    logged, ``nrblendtools.setFarClipDistance()`` is called, and the object
    from the largest file is selected and framed.

    :param paths: iterable of file and/or directory paths.
    :param options: importer options dict; an optional 'logger' entry is
        installed via ``nrblendtools.setLogger``.
    """
    if 'logger' in options:
        nrblendtools.setLogger(options['logger'])

    importer = BlenderImporter(options)

    for path in paths:
        if os.path.isfile(path):
            importer.importPostVs(path)
        elif os.path.isdir(path):
            # Join with a separator so "dir" and "dir/" both match files inside
            # the directory. Escape it so names like "Game [v1]" are not
            # treated as glob character classes.
            pattern = os.path.join(glob.escape(path), "*.nr")
            for nrPath in sorted(glob.glob(pattern)):
                importer.importPostVs(nrPath)

    importer.printInfo("PostVS")
    nrblendtools.setFarClipDistance()
    importer.selectLargestObjectViewSelected()


###################################################################################################
## PREVS
###################################################################################################
def importNrFilesPreVs(paths, options):
    """Import pre-vertex-shader meshes from NR files and/or directories.

    Each entry of *paths* is either an NR file, which is imported directly, or a
    directory, whose top-level ``*.nr`` files are imported in sorted order.
    Entries that are neither are ignored. Afterwards the importer summary is
    logged, ``nrblendtools.setFarClipDistance()`` is called, and the object
    from the largest file is selected and framed.

    :param paths: iterable of file and/or directory paths.
    :param options: importer options dict; an optional 'logger' entry is
        installed via ``nrblendtools.setLogger``.
    """
    if 'logger' in options:
        nrblendtools.setLogger(options['logger'])

    importer = BlenderImporter(options)

    for path in paths:
        if os.path.isfile(path):
            importer.importPreVs(path)
        elif os.path.isdir(path):
            # Join with a separator so "dir" and "dir/" both match files inside
            # the directory. Escape it so names like "Game [v1]" are not
            # treated as glob character classes.
            pattern = os.path.join(glob.escape(path), "*.nr")
            for nrPath in sorted(glob.glob(pattern)):
                importer.importPreVs(nrPath)

    importer.printInfo("PreVS")
    nrblendtools.setFarClipDistance()
    importer.selectLargestObjectViewSelected()


class BlenderImporter(object):
    """Create Blender mesh objects from Ninja Ripper (.nr) files.

    One instance is meant to import many files. It accumulates per-run
    statistics (files parsed, meshes created/skipped) and remembers the object
    created from the largest NR file so it can be selected and framed at the end.
    """

    # Maximum number of textures passed to the material cache per mesh.
    _MAX_TEXTURES = 8

    def __init__(self, options):
        """Store *options* and set up the material cache, group manager and counters."""
        self.options = options
        self.matCache = nrblendtools.MaterialCache()
        self.groupMgr = nrblendtools.GroupManager()
        self.totalFilesCount = 0
        self.totalCreated = 0
        self.totalSkipped = 0
        # The EXTRA_* texture modes read texcoords from the second vertex stream (index 1).
        self.loadExtraVertexData = options['texMode'] in ('EXTRA_TEXCOORDBYNAME', 'EXTRA_SCATTERTEXCOORD')
        self._maxNrSize = 0
        self._maxMeshName = ''   # basename of the largest NR file (used for logging)
        self._maxObjName = ''    # actual Blender object name created for it

    def printInfo(self, shaderStage):
        """Log the accumulated import statistics, prefixed with *shaderStage*."""
        writeLog("{:s}. Total parsed files count={:d}".format(shaderStage, self.totalFilesCount))
        writeLog("Total created/skipped meshs {:d}/{:d}".format(self.totalCreated, self.totalSkipped))
        writeLog("Images loaded={:d}. Images failed={:d}".format(len(self.matCache.loadedImgs), len(self.matCache.failedImgs)))
        for filename in self.matCache.failedImgs:
            writeLog("Failed: {:s}".format(filename))
        writeLog("Largest NR-file: {:s}. FileSize={:d}".format(self._maxMeshName, self._maxNrSize))

    def selectLargestObjectViewSelected(self):
        """Deselect the default 'Cube', select the object from the largest file and frame it."""
        cube = bpy.data.objects.get('Cube')  # Default Cube
        # Look up by the real object name: Blender may have renamed the object
        # (e.g. "file.nr.001") if the basename was already taken.
        largeObj = bpy.data.objects.get(self._maxObjName)
        if 0 == nrblendtools.ver_blender():
            # blender < 2.8
            if cube:
                cube.select = False
            if largeObj:
                largeObj.select = True
                bpy.context.scene.objects.active = largeObj
        else:
            # blender >= 2.8
            if cube:
                cube.select_set(False)
            if largeObj:
                largeObj.select_set(True)
                bpy.context.view_layer.objects.active = largeObj

        self._frameSelectedInViewports()

    @staticmethod
    def _frameSelectedInViewports():
        """Run view3d.view_selected in every 3D viewport of the current screen."""
        screen = bpy.context.screen
        if screen is None:
            # No UI screen (e.g. background mode): nothing to frame.
            return
        for area in screen.areas:
            if area.type != 'VIEW_3D':
                continue
            region = next((r for r in area.regions if r.type == 'WINDOW'), area.regions[-1])
            if hasattr(bpy.context, "temp_override"):
                # Blender >= 3.2; passing a context dict to operators was removed in 4.0.
                with bpy.context.temp_override(area=area, region=region):
                    bpy.ops.view3d.view_selected()
            else:
                ctx = bpy.context.copy()
                ctx['area'] = area
                ctx['region'] = region
                bpy.ops.view3d.view_selected(ctx)

    def importPreVs(self, filename):
        """Import all PreVs meshes of one NR file.

        :returns: ``retVal`` dict; status False with the parser error when parsing fails.
        """
        nr, errMsg = self._parseNrFile(filename)
        if nr is None:
            return retVal(False, errMsg)

        fileDirectory = os.path.dirname(os.path.abspath(filename))

        createdMeshsCount = 0
        for meshIdx in range(nr.getMeshCount()):
            nrmesh = nr.getMesh(meshIdx)

            if nrfile.ShaderStage.PreVs != nrmesh.getShaderStage():
                continue

            buffers = self._getMeshBuffers(nrmesh)
            if buffers is None:
                continue
            vert, indx, vatrs = buffers

            extra = self._loadExtraVertexData(nrmesh, vert)

            # Get Textures object (TXTR tag in rip file)
            textures = nrmesh.getTextures()

            vertexData = vert.read()

            positions3 = []
            layoutTab = self.options['vertexLayoutTab']
            if layoutTab == 'AUTO':
                # Components 0..2 of attribute 0 are used as x, y, z.
                positions3 = nrtools.unpackVertexComponentVaAsList(vert, vertexData, vatrs, [[0, 0], [0, 1], [0, 2]])
            elif layoutTab == 'MANUAL':
                positions3 = nrtools.unpackVertexComponentVaAsList(
                    vert, vertexData, vatrs,
                    [self.options['posX'], self.options['posY'], self.options['posZ']])

            if positions3 is None:
                writeLog("positions3 == Null")
                continue

            meshName = os.path.basename(filename)
            self._buildMeshObject(nr, meshName, positions3, indx, vert, vertexData, vatrs,
                                  textures, fileDirectory, extra)
            createdMeshsCount += 1

        self._accountFile(nr, createdMeshsCount)
        return retVal(True, '')

    def importPostVs(self, filename):
        """Import all non-PreVs meshes of one NR file, honouring the RSTT skip options.

        :returns: ``retVal`` dict; status False with the parser error when parsing fails.
        """
        nr, errMsg = self._parseNrFile(filename)
        if nr is None:
            return retVal(False, errMsg)

        fileDirectory = os.path.dirname(os.path.abspath(filename))

        # TODO: grouping
        grp = self.groupMgr.getGroup(filename)

        createdMeshsCount = 0
        for meshIdx in range(nr.getMeshCount()):
            nrmesh = nr.getMesh(meshIdx)

            if nrfile.ShaderStage.PreVs == nrmesh.getShaderStage():
                continue

            rsttProps = nrmesh.getProperties(nrfile.createTag('R', 'S', 'T', 'T'))
            if isMeshSkippedFromImport(self.options, rsttProps):
                continue

            buffers = self._getMeshBuffers(nrmesh)
            if buffers is None:
                continue
            vert, indx, vatrs = buffers

            extra = self._loadExtraVertexData(nrmesh, vert)

            # Get Textures object (TXTR tag in rip file)
            textures = nrmesh.getTextures()

            # Attr 0 is expected to be the 4-component POSITION output.
            positionAttr = vatrs.getAttr(0)
            if positionAttr.compCount != 4:
                writeLog("Post vs position not 4 component")
                continue

            vertexData = vert.read()
            positions3 = nrtools.restorePositionAsList(vert, vertexData, positionAttr, self.options['projmat'])
            if positions3 is None:
                writeLog("positions3 == Null")
                continue

            meshName = os.path.basename(filename)
            self._buildMeshObject(nr, meshName, positions3, indx, vert, vertexData, vatrs,
                                  textures, fileDirectory, extra, grp)
            createdMeshsCount += 1

        self._accountFile(nr, createdMeshsCount)
        return retVal(True, '')

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _parseNrFile(filename):
        """Parse *filename*; return ``(nr, '')`` on success or ``(None, errMsg)`` on failure."""
        nr = nrfile.NRFile()
        if not nr.parse(filename):
            errMsg = "Ninja Ripper file parsing failed: " + nr.getErrorString()
            writeLog(errMsg)
            return None, errMsg
        writeLog("Loading: {:s}".format(filename))
        return nr, ''

    @staticmethod
    def _getMeshBuffers(nrmesh):
        """Return ``(vert, indx, vatrs)`` for stream 0, or None (after logging) if unusable."""
        vert = nrmesh.getVertexes(0)
        if vert is None:
            writeLog("vertexes == Null")
            return None

        indx = nrmesh.getIndexes(0)
        if indx is None:
            writeLog("indexes == Null")
            return None

        if indx.getIndexTopology() != nrfile.IndexTopology.TriangleList:
            writeLog("IndexTopology != TriangleList  (Topology={:s})".format(nrdump.topologyToStr(indx.getIndexTopology())))
            return None

        vatrs = nrmesh.getVertexAttributes(0)
        if vatrs is None:
            writeLog("vertexAttribs == Null")
            return None

        return vert, indx, vatrs

    def _loadExtraVertexData(self, nrmesh, vert):
        """Read the optional second vertex stream.

        :returns: ``(ok, vert1, vertexData1, vatrs1)``. *ok* is True only when the
            extra stream is requested, fully present and has the same vertex count
            as *vert*.
        """
        if not self.loadExtraVertexData:
            return False, None, None, None

        vert1 = nrmesh.getVertexes(1)
        vatrs1 = nrmesh.getVertexAttributes(1)
        vertexData1 = vert1.read() if vert1 else None

        if not (vert1 and vertexData1 and vatrs1):
            return False, vert1, vertexData1, vatrs1

        if vert.getVertexCount() != vert1.getVertexCount():
            writeLog("Vertex count != Extra vertex count {:d} != {:d}".format(vert.getVertexCount(), vert1.getVertexCount()))
            return False, vert1, vertexData1, vatrs1

        return True, vert1, vertexData1, vatrs1

    @staticmethod
    def _unpackTriangles(indexData, indexCount):
        """Decode *indexCount* 32-bit signed indices from *indexData* into (a, b, c) tuples.

        Trailing indices that do not form a full triangle are ignored.
        """
        return [struct.unpack_from("iii", indexData, 12 * tri) for tri in range(indexCount // 3)]

    @classmethod
    def _texturePaths(cls, textures, fileDirectory):
        """Return full paths of at most ``_MAX_TEXTURES`` textures, resolved against *fileDirectory*."""
        texCount = min(textures.getTexturesCount(), cls._MAX_TEXTURES)
        return [os.path.join(fileDirectory, textures.getTexture(i).fileName) for i in range(texCount)]

    @staticmethod
    def _createObject(meshName, grp=None):
        """Create a mesh datablock and object named *meshName* and link it into the scene.

        When *grp* is truthy the object is also linked into ``grp.objects``.
        """
        mesh = bpy.data.meshes.new(meshName)
        obj = bpy.data.objects.new(meshName, mesh)

        if grp:
            # If collection created then add object to its collection
            grp.objects.link(obj)

        scene = bpy.context.scene
        if hasattr(scene, "cursor_location"):
            # blender < 2.8
            obj.location = scene.cursor_location
            scene.objects.link(obj)
        elif hasattr(scene.cursor, "location"):
            # blender >= 2.8
            obj.location = scene.cursor.location
            bpy.context.collection.objects.link(obj)

        return mesh, obj

    def _assignMaterial(self, obj, textures, fileDirectory):
        """Create (or reuse) a material for *textures* via the cache and append it to *obj*."""
        if textures is None:
            return
        mat = self.matCache.createMaterial(self._texturePaths(textures, fileDirectory))
        if mat:
            obj.data.materials.append(mat)

    def _applyTexCoords(self, mesh, vert, vertexData, vatrs, extra):
        """Run the nrblendtools texcoord helper selected by ``options['texMode']`` on *mesh* via BMesh."""
        extraOk, vert1, vertexData1, vatrs1 = extra
        texMode = self.options['texMode']

        bm = bmesh.new()
        try:
            bm.from_mesh(mesh)
            bm.verts.ensure_lookup_table()

            if texMode in ('AUTO', 'TEXCOORDBYNAME'):
                nrblendtools.createTexCoords(bm, self.options, vert, vertexData, vatrs)
            elif texMode == 'SCATTERTEXCOORD':
                nrblendtools.createTexCoordsScatter(bm, self.options, vert, vertexData, vatrs)
            elif texMode == 'EXTRA_TEXCOORDBYNAME' and extraOk:
                nrblendtools.createTexCoords(bm, self.options, vert1, vertexData1, vatrs1)
            elif texMode == 'EXTRA_SCATTERTEXCOORD' and extraOk:
                nrblendtools.createTexCoordsScatter(bm, self.options, vert1, vertexData1, vatrs1)

            bm.to_mesh(mesh)
        finally:
            bm.free()

    def _buildMeshObject(self, nr, meshName, positions3, indx, vert, vertexData, vatrs,
                         textures, fileDirectory, extra, grp=None):
        """Build one Blender object from decoded positions and the triangle index buffer.

        Builds faces, normals, material and texcoords, then records the object
        if *nr* is the largest file seen so far.
        """
        faces = self._unpackTriangles(indx.read(), indx.getIndexCount())

        mesh, obj = self._createObject(meshName, grp)
        mesh.from_pydata(positions3, [], faces)

        assignNormals(self.options, mesh, vatrs, vert, vertexData, faces)
        self._assignMaterial(obj, textures, fileDirectory)
        mesh.update()

        self._applyTexCoords(mesh, vert, vertexData, vatrs, extra)
        mesh.update()

        fileSize = nr.getFileSize()
        if fileSize > self._maxNrSize:
            self._maxNrSize = fileSize
            self._maxMeshName = meshName
            self._maxObjName = obj.name

    def _accountFile(self, nr, createdMeshsCount):
        """Add one parsed file and its created/skipped mesh counts to the totals."""
        self.totalCreated += createdMeshsCount
        self.totalSkipped += nr.getMeshCount() - createdMeshsCount
        self.totalFilesCount += 1
