import numpy as np
import scipy.sparse
from scipy.sparse import csr_matrix
# Plotting packages
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
from mpl_toolkits import mplot3d 
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
from matplotlib.animation import FuncAnimation
from matplotlib.animation import PillowWriter
# Mesh refinement routine
from red_refine import red_refine

def get_geom():
    """Return the coarse initial mesh of the unit square [0, 1] x [0, 1].

    The square is split into 16 counter-clockwise oriented triangles.

    Returns
    -------
    coord : ndarray, shape (13, 2), float
        Node coordinates.
    triangles : ndarray, shape (16, 3), int
        Node indices of each triangle.
    dirichlet : ndarray, shape (4, 2), int
        Boundary edges on the bottom (y = 0) and top (y = 1) sides.
    neumann : ndarray, shape (4, 2), int
        Boundary edges on the right (x = 1) and left (x = 0) sides.
    """
    # Nodes 0-7 walk the boundary counter-clockwise starting at the origin;
    # nodes 8-11 are the quarter points and node 12 is the centre.
    boundary_nodes = [[0, 0], [0.5, 0], [1, 0], [1, 0.5],
                      [1, 1], [0.5, 1], [0, 1], [0, 0.5]]
    inner_nodes = [[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75],
                   [0.5, 0.5]]
    coord = np.asarray(boundary_nodes + inner_nodes)

    triangles = np.asarray([
        [0, 1, 8], [1, 12, 8], [1, 9, 12], [1, 2, 9],
        [2, 3, 9], [3, 12, 9], [3, 10, 12], [3, 4, 10],
        [4, 5, 10], [5, 12, 10], [5, 11, 12], [5, 6, 11],
        [6, 7, 11], [7, 12, 11], [7, 8, 12], [7, 0, 8],
    ])

    # To impose Dirichlet conditions on the whole boundary instead, use
    # dirichlet = [[0,1],[1,2],[2,3],[3,4],[4,5],[5,6],[6,7],[7,0]]
    # together with an empty neumann array of shape (0, 2).
    dirichlet = np.array([[0, 1], [1, 2], [4, 5], [5, 6]])
    neumann = np.array([[2, 3], [3, 4], [6, 7], [7, 0]])
    return coord, triangles, dirichlet, neumann

def restrict2dof(dof, nnodes):
    """Build the sparse 0/1 matrix linking free degrees of freedom to nodes.

    The result has shape ``(nnodes, len(dof))`` and column ``k`` holds a
    single one in row ``dof[k]``. For distinct indices in ``dof``, ``R @ v``
    scatters dof values into a full node vector and ``R.T @ w`` gathers
    ``w[dof]``.
    """
    n_cols = np.size(dof)
    values = np.ones(n_cols)
    col_idx = np.arange(n_cols)
    return csr_matrix((values, (dof, col_idx)), shape=(nnodes, n_cols))

def FEM_matrices(coord, triangles):
    """Assemble the P1 finite-element mass and stiffness matrices.

    Parameters
    ----------
    coord : array_like, shape (nnodes, 2)
        Node coordinates.
    triangles : array_like, shape (nelems, 3)
        Node indices of each triangle. Both clockwise and counter-clockwise
        orientation are accepted.

    Returns
    -------
    M : scipy.sparse.csr_matrix, shape (nnodes, nnodes)
        Mass matrix, M_ik = integral of phi_i * phi_k over the mesh.
    A : scipy.sparse.csr_matrix, shape (nnodes, nnodes)
        Stiffness matrix, A_ik = integral of grad phi_i . grad phi_k.

    Here phi_i denotes the piecewise-linear hat function of node i.
    """
    coord = np.asarray(coord, dtype=float)
    # reshape keeps an empty triangle list two-dimensional.
    triangles = np.asarray(triangles, dtype=int).reshape(-1, 3)
    nelems = triangles.shape[0]
    nnodes = coord.shape[0]

    # Local P1 mass matrix of a triangle with unit area.
    mass_ref = np.array([[2.0, 1.0, 1.0],
                         [1.0, 2.0, 1.0],
                         [1.0, 1.0, 2.0]]) / 12
    # Right-hand side of the gradient system:
    # [1 1 1; x0 x1 x2; y0 y1 y2] @ G = [0 0; 1 0; 0 1].
    # Row i of G is the constant gradient of the i-th local hat function.
    grad_rhs = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])

    Alocal = np.zeros((nelems, 3, 3))
    Mlocal = np.zeros((nelems, 3, 3))
    for j, nodes_loc in enumerate(triangles):
        coord_loc = coord[nodes_loc, :]
        T = coord_loc[1:, :] - coord_loc[0, :]
        # abs(): a clockwise triangle has a negative determinant, which
        # would otherwise flip the sign of both local matrices.
        area = 0.5 * abs(T[0, 0] * T[1, 1] - T[0, 1] * T[1, 0])
        grads = np.linalg.solve(np.vstack((np.ones(3), coord_loc.T)), grad_rhs)
        Alocal[j] = area * (grads @ grads.T)
        Mlocal[j] = area * mass_ref

    # Local entry (a, b) of element j belongs to global position
    # (triangles[j, a], triangles[j, b]). Both index arrays are flattened in
    # the same C order as Alocal.ravel() / Mlocal.ravel().
    rows = np.repeat(triangles, 3, axis=1).ravel()
    cols = np.tile(triangles, (1, 3)).ravel()
    # The (data, (row, col)) constructor sums duplicate index pairs, which
    # performs the element-by-element assembly.
    A = csr_matrix((Alocal.ravel(), (rows, cols)), shape=(nnodes, nnodes))
    M = csr_matrix((Mlocal.ravel(), (rows, cols)), shape=(nnodes, nnodes))
    return M, A

def leapfrog(coord, triangles, dirichlet, u0, u1, dt, horizon):
    """Integrate the wave equation u_tt = Laplace(u) with the leapfrog scheme.

    Space is discretized with P1 finite elements (``FEM_matrices``) and a
    row-sum lumped mass matrix. This gives the semi-discrete system
    u'' = -L u on the free nodes, with L = M_lump^{-1} A. Time levels are
    advanced with the explicit central difference

        U^j = 2 U^{j-1} - U^{j-2} - dt^2 L U^{j-1}.

    Nodes occurring in ``dirichlet`` are kept at zero at every time level
    (homogeneous Dirichlet condition). All other nodes are free, including
    boundary nodes not listed in ``dirichlet``. The scheme is explicit, so
    ``dt`` must be small enough relative to the mesh size (CFL-type
    restriction); this is not checked here.

    Parameters
    ----------
    coord : ndarray, shape (nnodes, 2)
        Node coordinates.
    triangles : ndarray, shape (nelems, 3)
        Node indices of each triangle.
    dirichlet : array_like
        Dirichlet boundary edges given as node-index pairs. Only the set of
        nodes they contain is used.
    u0, u1 : callable(x, y)
        Initial displacement and initial velocity, evaluated at the node
        coordinates. A scalar return value is broadcast to all nodes.
    dt : float
        Time step; must be positive.
    horizon : float
        Final time; must be positive.

    Returns
    -------
    U : ndarray, shape (nnodes, nT)
        ``U[:, j]`` approximates the solution at time ``j * dt``, where
        ``nT = ceil(horizon / dt) + 1``. The last column therefore lies at
        a time >= horizon.

    Raises
    ------
    ValueError
        If ``dt`` or ``horizon`` is not positive.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")
    if horizon <= 0:
        raise ValueError("horizon must be positive")
    # n_steps time steps need n_steps + 1 stored levels (t = 0 included).
    n_steps = int(np.ceil(horizon / dt))
    nT = n_steps + 1
    nnodes = np.size(coord, 0)
    x, y = coord[:, 0], coord[:, 1]

    M, A = FEM_matrices(coord, triangles)
    # Lumped mass: the diagonal of row sums of M, so the explicit update
    # only needs a diagonal inverse instead of a linear solve.
    m = np.asarray(M.sum(axis=1)).ravel()
    MlumpI = scipy.sparse.diags(1.0 / m)

    dbnodes = np.unique(dirichlet)
    dof = np.setdiff1d(np.arange(nnodes), dbnodes)
    R = restrict2dof(dof, nnodes)
    A_inner = R.T @ A @ R
    Mli_inner = R.T @ MlumpI @ R
    # '@' is the matrix product for scipy sparse matrices and sparse arrays
    # alike. '*' would be elementwise for sparse arrays.
    L = Mli_inner @ A_inner

    # Dirichlet rows stay zero in every column because only dofs are written.
    U = np.zeros((nnodes, nT))
    U_init = np.broadcast_to(u0(x, y), (nnodes,))
    V_init = np.broadcast_to(u1(x, y), (nnodes,))
    U[dof, 0] = U_init[dof]
    # Second-order start from u(dt) ~ u(0) + dt u'(0) + dt^2/2 u''(0), with
    # u'' = -L u. Dropping the dt^2 term would reduce the whole scheme to
    # first-order accuracy.
    U[dof, 1] = U[dof, 0] + dt * V_init[dof] - 0.5 * dt**2 * (L @ U[dof, 0])
    for j in range(2, nT):
        U[dof, j] = 2 * U[dof, j - 1] - U[dof, j - 2] - dt**2 * (L @ U[dof, j - 1])
    return U


def u0(x, y):
    """Initial displacement u(x, y, 0) = sin(pi*x) * sin(pi*y)."""
    return np.sin(np.pi * y) * np.sin(np.pi * x)


def u1(x, y):
    """Initial velocity u_t(x, y, 0) = 0, returned as an array shaped like x."""
    return np.zeros_like(x, dtype=float)


# Temporal refinement factor: dt is divided by fac, and the animation shows
# every fac-th time level.
fac = 1
dt = .01 / fac
n_ref = 4    # number of red-refinement passes applied to the coarse mesh
horizon = 5  # final simulation time
coord, triangles, dirichlet, neumann = get_geom()
for _ in range(n_ref):
    # red_refine returns the refined mesh plus two further values that are
    # not needed here.
    coord, triangles, dirichlet, neumann, *_ = red_refine(coord, triangles, dirichlet, neumann)
U = leapfrog(coord, triangles, dirichlet, u0, u1, dt, horizon)

# Visualization: animate the computed solution U on the refined mesh.
fig = plt.figure(figsize=(14, 9))
ax = plt.axes(projection='3d')
# The z-range is identical for every frame, so compute it once.
z_limits = [1.1 * np.min(U), 1.1 * np.max(U)]
# Frame i shows time level i*fac. This frame count keeps i*fac within
# 0 .. U.shape[1] - 1 for every fac >= 1.
n_frames = (U.shape[1] - 1) // fac + 1


def animate(i):
    """Redraw the axes with the solution at stored time level ``i * fac``."""
    ax.clear()
    ax.plot_trisurf(coord[:, 0], coord[:, 1], U[:, i * fac],
                    triangles=triangles,
                    cmap=plt.get_cmap('summer'),
                    edgecolor='Gray')
    ax.set_zlim(z_limits)


ani = FuncAnimation(fig, animate, frames=n_frames,
                    interval=10, repeat=False)
# Optional export (uncomment): .mp4 uses matplotlib's default (ffmpeg)
# movie writer, the .gif variant uses Pillow.
# ani.save("test.mp4", fps=20)
# ani.save("test.gif", dpi=300, writer=PillowWriter(fps=1))
plt.show()
plt.close()
