import numpy as np
from mpl_toolkits import mplot3d 
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
from mpl_toolkits import mplot3d
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import math
import pylab
import scipy.sparse
import scipy.sparse.linalg
from scipy.sparse import csr_matrix
from red_refine import red_refine

def get_geom():
    """Return the initial triangulation of the unit square [0, 1]^2.

    Returns
    -------
    coord : ndarray, shape (13, 2)
        Node coordinates. Nodes 0..7 lie on the boundary (counted around the
        square starting at the origin), nodes 8..12 are interior.
    triangles : ndarray of int, shape (16, 3)
        Node indices of each triangle.
    dirichlet : ndarray of int, shape (8, 2)
        Boundary edges 0-1, 1-2, ..., 6-7, 7-0 (the whole outer boundary).
    neumann : ndarray, shape (0, 2)
        No Neumann edges.
    """
    outer = [[0, 0], [0.5, 0], [1, 0], [1, 0.5],
             [1, 1], [0.5, 1], [0, 1], [0, 0.5]]
    inner = [[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75],
             [0.5, 0.5]]
    coord = np.array(outer + inner)

    triangles = np.array([
        [0, 1, 8], [1, 12, 8], [1, 9, 12], [1, 2, 9],      # bottom side
        [2, 3, 9], [3, 12, 9], [3, 10, 12], [3, 4, 10],    # right side
        [4, 5, 10], [5, 12, 10], [5, 11, 12], [5, 6, 11],  # top side
        [6, 7, 11], [7, 12, 11], [7, 8, 12], [7, 0, 8],    # left side
    ])

    # Closed loop over the 8 boundary nodes: (k, k+1) with wrap-around.
    dirichlet = np.array([[k, (k + 1) % 8] for k in range(8)])
    neumann = np.zeros((0, 2))
    return coord, triangles, dirichlet, neumann

def FEM(coord,triangles,dirichlet,f):
    """Compute the P1 finite element solution with zero Dirichlet values.

    The global system is assembled by ``FEM_data``. It is restricted to the
    nodes that do not appear in ``dirichlet`` and solved there with ``cg``.
    Nodes listed in ``dirichlet`` keep the value 0.

    Returns
    -------
    x : ndarray, shape (nnodes,)
        Nodal values of the discrete solution.
    ndof : int
        Number of free (non-Dirichlet) nodes.
    """
    n_nodes = np.size(coord, 0)
    stiffness, load = FEM_data(coord, triangles, f)

    boundary_nodes = np.unique(dirichlet)
    free_nodes = np.setdiff1d(np.arange(n_nodes), boundary_nodes)

    # R embeds the free unknowns into the full nodal vector; R^T A R and R^T b
    # are the system restricted to the free nodes.
    R = restrict2dof(free_nodes, n_nodes)
    Rt = R.transpose()

    solution = np.zeros(n_nodes)
    solution[free_nodes] = cg((Rt @ stiffness) @ R, Rt @ load)
    return solution, np.size(free_nodes)

def restrict2dof(dof,nnodes):
    """Build the sparse ``(nnodes, len(dof))`` embedding of free dofs into all nodes.

    Column ``k`` holds a single 1 in row ``dof[k]``. Hence ``R.T @ v`` extracts
    the entries ``v[dof]`` of a full nodal vector, and ``R.T @ A @ R`` is the
    submatrix of ``A`` on the rows and columns ``dof``.
    """
    columns = np.arange(np.size(dof))
    entries = np.ones(columns.size)
    return csr_matrix((entries, (dof, columns)), shape=(nnodes, columns.size))

def cg(A,b,tol=1e-8,maxiter=None):
    """Solve ``A x = b`` with the (unpreconditioned) conjugate gradient method.

    CG is only guaranteed to converge for symmetric positive definite ``A``
    (dense ndarray or scipy sparse matrix). The iteration starts from
    ``x = 0`` and stops once the Euclidean norm of the true residual
    ``A x - b`` is at most ``tol``, or after ``maxiter`` iterations.

    Parameters
    ----------
    A : ndarray or sparse matrix, shape (n, n)
    b : ndarray, shape (n,)
    tol : float, optional
        Absolute residual tolerance (default 1e-8).
    maxiter : int, optional
        Upper bound on the number of iterations. Defaults to
        ``max(10 * n, 100)``. This prevents an endless loop when the absolute
        tolerance cannot be reached in floating point, e.g. for a right-hand
        side of large magnitude.

    Returns
    -------
    x : ndarray, shape like ``b``
        Approximate solution. The number of iterations is printed. A
        RuntimeWarning is issued if ``tol`` was not reached.
    """
    import warnings

    if maxiter is None:
        maxiter = max(10 * np.size(b), 100)

    x = 0*b
    n_steps = 0
    g = A@x - b          # residual (= energy gradient) of the initial guess
    d = -g
    while np.linalg.norm(g) > tol and n_steps < maxiter:
        n_steps += 1
        g_sq = g.T@g
        Ad = A@d         # the only product with d needed in this step
        alpha = g_sq / (d.T@Ad)
        x = x + alpha*d
        # Recompute the true residual (rather than the recursive update
        # g + alpha*Ad) so the stopping test measures ||A x - b|| exactly.
        g = A@x - b
        beta = (g.T@g) / g_sq
        d = -g + beta*d

    # `not (... <= tol)` also flags a NaN residual as non-converged.
    if not np.linalg.norm(g) <= tol:
        warnings.warn(
            f"cg did not reach tol={tol} within {maxiter} iterations",
            RuntimeWarning,
            stacklevel=2,
        )
    print('#cg iteration:'+str(n_steps))
    return x

def FEM_data(coord,triangles,f):
    """Assemble the P1 (linear Lagrange) stiffness matrix and load vector.

    Parameters
    ----------
    coord : array_like, shape (nnodes, 2)
        Node coordinates.
    triangles : array_like of int, shape (nelems, 3)
        Node indices of each triangle. Counter-clockwise and clockwise vertex
        orderings are both accepted.
    f : callable
        Load function. It is called as ``f(x, y)`` with scalar arguments, once
        per triangle at the triangle's centroid.

    Returns
    -------
    A : scipy.sparse.csr_matrix, shape (nnodes, nnodes)
        ``A[i, j] = sum_T |T| * grad(phi_i) . grad(phi_j)`` over all triangles T.
    b : ndarray, shape (nnodes,)
        ``b[i] = sum_T |T| / 3 * f(centroid(T))`` over the triangles containing
        node i (one-point centroid quadrature).
    """
    coord = np.asarray(coord, dtype=float)
    # Integer connectivity so it can be used directly as sparse-matrix indices.
    triangles = np.asarray(triangles, dtype=np.intp).reshape(-1, 3)
    nnodes = np.size(coord, 0)
    nelems = np.size(triangles, 0)

    corners = coord[triangles]                      # (nelems, 3, 2)
    e1 = corners[:, 1] - corners[:, 0]
    e2 = corners[:, 2] - corners[:, 0]
    # The 2-D cross product is negative for clockwise triangles; use its
    # absolute value so every element contributes with its true area.
    area = 0.5 * np.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])

    # Row r of grads[k] is the (constant) gradient of the hat function of
    # local vertex r on triangle k. The gradients solve
    #   [1 1 1; x0 x1 x2; y0 y1 y2] @ G = [0 0; 1 0; 0 1],
    # i.e. they sum to zero and reproduce grad(x) = (1, 0), grad(y) = (0, 1).
    lhs = np.ones((nelems, 3, 3))
    lhs[:, 1:, :] = corners.transpose(0, 2, 1)
    rhs = np.broadcast_to(np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
                          (nelems, 3, 2))
    grads = np.linalg.solve(lhs, rhs)               # (nelems, 3, 2)
    Alocal = area[:, None, None] * (grads @ grads.transpose(0, 2, 1))

    # Global (row, col) index of every local entry Alocal[k, r, c].
    rows = np.broadcast_to(triangles[:, :, None], (nelems, 3, 3)).ravel()
    cols = np.broadcast_to(triangles[:, None, :], (nelems, 3, 3)).ravel()
    # Duplicate (row, col) pairs coming from neighbouring triangles are summed
    # by the (data, (row, col)) constructor.
    A = csr_matrix((Alocal.ravel(), (rows, cols)), shape=(nnodes, nnodes))

    # Centroids; f is evaluated point by point so scalar-only callables work.
    mid = 1/3 * (corners[:, 0] + corners[:, 1] + corners[:, 2])
    f_mid = np.array([f(px, py) for px, py in mid], dtype=float)
    b = np.zeros(nnodes)
    # Each vertex receives |T|/3 * f(centroid); np.add.at accumulates
    # correctly when a node index occurs in several triangles.
    np.add.at(b, triangles, (area / 3 * f_mid)[:, None])
    return A, b

# Manufactured test problem: u_ex vanishes on the boundary of the unit square
# and f = -Laplace(u_ex), so the FEM solution can be compared with u_ex.
def f(x, y):
    """Right-hand side f = -Laplace(u_ex)."""
    return -np.exp(x)*x*(x*(y**2-y+2)+3*y**2-3*y-2)


def u_ex(x, y):
    """Exact solution exp(x) * x(1-x) * y(1-y)."""
    return np.exp(x)*(x-x**2)*(y-y**2)


coord, triangles, dirichlet, neumann = get_geom()

n_ref = 6

for j in range(n_ref):
    # Level 0 uses the initial mesh; each later level refines the previous
    # mesh with red_refine (its last two return values are not used here).
    if j > 0:
        coord, triangles, dirichlet, neumann, _, _ = red_refine(
            coord, triangles, dirichlet, neumann)

    x, ndof = FEM(coord, triangles, dirichlet, f)
    x_reference = u_ex(coord[:, 0], coord[:, 1])

    # Maximum nodal error of the FEM solution.
    e_max = np.max(np.abs(x - x_reference))
    print(f"Gitter ndof={ndof}   {j} : max err = ", str(e_max))

    ## plot the FEM solution graph
    fig = plt.figure(figsize=(14, 9))
    ax = plt.axes(projection='3d')
    # Mesh edges are hidden on the finest levels, where they would cover the surface.
    ec = 'none' if j > 3 else 'Gray'
    trisurf = ax.plot_trisurf(coord[:, 0], coord[:, 1], x,
                              triangles=triangles,
                              cmap=plt.get_cmap('summer'),
                              edgecolor=ec)
    ax.set_title(f'finite element solution {j}')
    plt.show(block=False)


## plot the reference solution on the finest mesh
fig = plt.figure(figsize=(14, 9))
ax = plt.axes(projection='3d')
ec = 'none'
trisurf = ax.plot_trisurf(coord[:, 0], coord[:, 1], u_ex(coord[:, 0], coord[:, 1]),
                          triangles=triangles,
                          cmap=plt.get_cmap('summer'),
                          edgecolor=ec)
ax.set_title('reference solution')
plt.show()
