import numpy as np
from mpl_toolkits import mplot3d 
import matplotlib.pyplot as plt
import matplotlib.tri as mtri
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
import math
import pylab
from red_refine import red_refine
import scipy.sparse
import scipy.sparse.linalg
from scipy.sparse import csr_matrix, bmat

def get_geom():
    """Return the initial triangulation of the unit square [0, 1]^2.

    Nodes 0..7 run counter-clockwise around the boundary (corners and edge
    midpoints, starting at the origin); nodes 8..11 are the quarter points
    and node 12 is the centre.

    Returns
    -------
    coord : (13, 2) float array of node coordinates.
    triangles : (16, 3) int array of vertex indices (all counter-clockwise).
    dirichlet : (8, 2) int array of boundary edges [k, k+1], closing with [7, 0].
    neumann : empty (0, 2) array, since there are no Neumann edges.
    """
    boundary_nodes = [[0, 0], [0.5, 0], [1, 0], [1, 0.5],
                      [1, 1], [0.5, 1], [0, 1], [0, 0.5]]
    inner_nodes = [[0.25, 0.25], [0.75, 0.25], [0.75, 0.75], [0.25, 0.75],
                   [0.5, 0.5]]
    coord = np.asarray(boundary_nodes + inner_nodes)

    triangles = np.asarray([
        [0, 1, 8], [1, 12, 8], [1, 9, 12], [1, 2, 9],
        [2, 3, 9], [3, 12, 9], [3, 10, 12], [3, 4, 10],
        [4, 5, 10], [5, 12, 10], [5, 11, 12], [5, 6, 11],
        [6, 7, 11], [7, 12, 11], [7, 8, 12], [7, 0, 8],
    ])

    # Consecutive boundary nodes form the edges; rolling by one closes the loop.
    ring = np.arange(len(boundary_nodes))
    dirichlet = np.column_stack((ring, np.roll(ring, -1)))

    neumann = np.zeros([0, 2])
    return coord, triangles, dirichlet, neumann

def FEM(coord,triangles,dirichlet,f,coeffA,dut4Db):
    """Solve for the gradient field w, then recover u from it.

    Stage 1 solves the constrained system built by FEM_data_w. Its solution
    vector holds w_x at the nodes, then w_y at the nodes, followed by the
    Lagrange multipliers of the boundary constraints.

    Stage 2 assembles the P1 system of FEM_data_u and solves it on the free
    nodes only. Every node that appears in `dirichlet` keeps the value 0.

    Returns
    -------
    u : (nnodes,) nodal values of the recovered solution.
    w : the full solution vector of stage 1.
    ndof : the number of free (non-Dirichlet) nodes.
    """
    n_nodes = np.size(coord, 0)

    # Stage 1: gradient field (plus multipliers).
    system, load = FEM_data_w(coord, triangles, dirichlet, f, coeffA, dut4Db)
    w = scipy.sparse.linalg.spsolve(system, load)

    # Stage 2: reconstruct u on the interior nodes.
    stiffness, rhs = FEM_data_u(coord, triangles, w)
    free_nodes = np.setdiff1d(range(0, n_nodes), np.unique(dirichlet))
    R = restrict2dof(free_nodes, n_nodes)
    Rt = R.transpose()
    u = np.zeros(n_nodes)
    u[free_nodes] = scipy.sparse.linalg.spsolve((Rt @ stiffness) @ R, Rt @ rhs)
    return u, w, np.size(free_nodes)

def restrict2dof(dof,nnodes):
    """Return the sparse (nnodes x ndof) prolongation matrix of the free dofs.

    Column k contains a single 1.0, in row dof[k]. R @ x therefore scatters
    free values into a full nodal vector, and R.T @ v gathers the free
    entries of v.
    """
    columns = np.arange(np.size(dof))
    ones = np.ones(columns.size)
    return csr_matrix((ones, (dof, columns)), shape=(nnodes, columns.size))

def FEM_data_u(coord,triangles,w):
    """Assemble the P1 system that recovers u from a nodal gradient field w.

    Entry (i, j) of A is sum_T |T| grad(phi_i) . grad(phi_j). Entry i of b is
    sum_T |T| grad(phi_i) . mean_T(w). This is the exact integral of
    w . grad(phi_i) for piecewise-linear w, because grad(phi_i) is constant
    on each element.

    Parameters
    ----------
    coord : (nnodes, 2) array of node coordinates.
    triangles : (nelems, 3) integer array of vertex indices per element.
        Either vertex orientation is accepted.
    w : 1-D array. Entries [0, nnodes) are the x-components and
        [nnodes, 2*nnodes) the y-components of w at the nodes. Any trailing
        entries are ignored.

    Returns
    -------
    A : (nnodes, nnodes) csr_matrix, the P1 stiffness matrix.
    b : (nnodes,) float array, the load vector.
    """
    nelems = np.size(triangles, 0)
    nnodes = np.size(coord, 0)
    Alocal = np.zeros((nelems, 3, 3))
    # Integer index arrays: they are used directly as sparse row/column indices.
    I1 = np.zeros((nelems, 3, 3), dtype=int)
    I2 = np.zeros((nelems, 3, 3), dtype=int)
    # [1 1 1; x0 x1 x2; y0 y1 y2] @ grads = grad_rhs gives the hat-function
    # gradients as the rows of grads.
    grad_rhs = np.array([[0, 0], [1, 0], [0, 1]])

    b = np.zeros(nnodes)
    for j in range(nelems):
        nodes_loc = triangles[j, :]
        coord_loc = coord[nodes_loc, :]
        T = np.array([coord_loc[1, :] - coord_loc[0, :],
                      coord_loc[2, :] - coord_loc[0, :]])
        # abs(): a clockwise element must not flip the sign of its contribution.
        area = 0.5 * abs(T[0, 0] * T[1, 1] - T[0, 1] * T[1, 0])
        grads = np.linalg.solve(np.vstack((np.ones(3), coord_loc.T)), grad_rhs)
        Alocal[j] = area * (grads @ grads.T)
        I1[j] = np.tile(nodes_loc, (3, 1))
        I2[j] = np.tile(nodes_loc.reshape(-1, 1), (1, 3))
        # Element mean of (w_x, w_y).
        wloc = np.array([1/3 * np.sum(w[nodes_loc]),
                         1/3 * np.sum(w[nnodes + nodes_loc])])
        # nodes_loc has distinct entries, so fancy-index += is safe here.
        b[nodes_loc] += area * grads @ wloc

    # Duplicate (row, col) pairs are summed when the CSR matrix is built.
    A = csr_matrix((Alocal.ravel(), (I1.ravel(), I2.ravel())),
                   shape=(nnodes, nnodes))
    return A, b
 
def _assemble_w_operator(coord, triangles, f, coeffA):
    """Assemble the unconstrained (2*nnodes x 2*nnodes) matrix and load vector for w."""
    nelems = np.size(triangles, 0)
    nnodes = np.size(coord, 0)
    Alocal = np.zeros((nelems, 6, 6))
    # Integer index arrays: they are used directly as sparse row/column indices.
    I1 = np.zeros((nelems, 6, 6), dtype=int)
    I2 = np.zeros((nelems, 6, 6), dtype=int)
    grad_rhs = np.array([[0, 0], [1, 0], [0, 1]])
    Z = np.zeros((3, 2))
    # Columns of Dbasis are (d1 w1, d2 w1, d1 w2, d2 w2). These weight vectors
    # extract the divergence and the rotation d1 w2 - d2 w1.
    div_weights = np.array([1, 0, 0, 1])
    rot_weights = np.array([0, -1, 1, 0])

    b = np.zeros(2 * nnodes)
    for j in range(nelems):
        nodes_loc = triangles[j, :]
        coord_loc = coord[nodes_loc, :]
        T = np.array([coord_loc[1, :] - coord_loc[0, :],
                      coord_loc[2, :] - coord_loc[0, :]])
        # abs(): a clockwise element must not flip the sign of its contribution.
        area = 0.5 * abs(T[0, 0] * T[1, 1] - T[0, 1] * T[1, 0])
        mid = 1/3 * (coord_loc[0, :] + coord_loc[1, :] + coord_loc[2, :])
        # Rows of grads: constant gradients of the three local hat functions.
        grads = np.linalg.solve(np.vstack((np.ones(3), coord_loc.T)), grad_rhs)
        Dbasis = np.block([[grads, Z], [Z, grads]])  # derivative of all 6 basis functions
        a = coeffA(mid[0], mid[1])
        # trace(A) / |A|_F^2 at the centroid.
        gamma = (a[0] + a[3]) / np.linalg.norm(a)**2
        div_basis = Dbasis @ div_weights
        rot_basis = Dbasis @ rot_weights
        Alocal[j] = area * gamma * np.outer(Dbasis @ a, div_basis) \
            + .5 * area * np.outer(rot_basis, rot_basis)
        t = np.concatenate((nodes_loc, nnodes + nodes_loc))
        I1[j] = np.tile(t, (6, 1))
        I2[j] = np.tile(t.reshape(-1, 1), (1, 6))
        b[t] += area * f(mid[0], mid[1]) * gamma * div_basis

    # Duplicate (row, col) pairs are summed when the CSR matrix is built.
    A = csr_matrix((Alocal.ravel(), (I1.ravel(), I2.ravel())),
                   shape=(2 * nnodes, 2 * nnodes))
    return A, b


def _tangential_constraints(coord, dirichlet, dut4Db):
    """Build C (2*nnodes x n_constraints) and rhs so that C.T @ w = rhs.

    Each constraint reads t . w(node) = dut4Db, where t is a unit boundary
    tangent. This assumes the rows of `dirichlet` form a closed,
    consistently ordered chain of edges: the edge before row 0 is taken to
    be the last row.

    At a node where the incoming and outgoing tangents differ (a corner),
    two constraints are added. Their data is sampled just before and just
    after the node, along each edge. NaN values from dut4Db are replaced
    by 0.
    """
    nnodes = np.size(coord, 0)
    eps = np.finfo(float).eps

    def data_at(point):
        value = dut4Db(point[0], point[1])
        return 0 if np.isnan(value) else value

    rows, vals, rhs = [], [], []
    for j in range(np.size(dirichlet, 0)):
        n1 = dirichlet[j - 1, 0]  # for j == 0 this wraps to the last edge
        n2 = dirichlet[j, 0]
        n3 = dirichlet[j, 1]
        tminus = coord[n2, :] - coord[n1, :]
        tminus = tminus / np.linalg.norm(tminus)
        tplus = coord[n3, :] - coord[n2, :]
        tplus = tplus / np.linalg.norm(tplus)

        rows.append([n2, nnodes + n2])
        vals.append(tminus)
        if np.linalg.norm(tplus - tminus) > 100 * eps:
            # Corner: one constraint per adjacent edge, data sampled on that edge.
            rhs.append(data_at(coord[n2, :] - 10 * eps * tminus))
            rows.append([n2, nnodes + n2])
            vals.append(tplus)
            rhs.append(data_at(coord[n2, :] + 10 * eps * tplus))
        else:
            rhs.append(data_at(coord[n2, :]))

    n_constraints = len(rhs)
    C = csr_matrix(
        (np.asarray(vals, dtype=float).reshape(-1),
         (np.asarray(rows, dtype=int).reshape(-1),
          np.repeat(np.arange(n_constraints), 2))),
        shape=(2 * nnodes, n_constraints))
    return C, np.asarray(rhs, dtype=float)


def FEM_data_w(coord,triangles,dirichlet,f,coeffA,dut4Db):
    """Assemble the constrained linear system for the discrete gradient field w.

    Unknown vector layout: [w_x at nodes, w_y at nodes, Lagrange multipliers].
    With psi_i the 2*nnodes vector-valued P1 basis functions, and A, gamma
    and f all sampled at each element centroid (one-point quadrature):

    - Matrix entry (i, j) is
          sum_T |T| * (gamma * (A : D psi_j) * div(psi_i)
                       + 0.5 * rot(psi_j) * rot(psi_i)),
      where gamma = trace(A) / |A|_F^2.
    - Load entry i is sum_T |T| * gamma * f * div(psi_i).

    The tangential boundary condition w . t = dut4Db is imposed at the
    Dirichlet nodes through Lagrange multipliers, and its data forms the
    tail of the right-hand side.

    Parameters
    ----------
    coord : (nnodes, 2) node coordinates.
    triangles : (nelems, 3) integer vertex indices. Either orientation is
        accepted.
    dirichlet : (nedges, 2) boundary edges. They are assumed to form a
        closed, ordered chain.
    f : callable f(x, y) returning the scalar right-hand side.
    coeffA : callable (x, y) -> [a11, a12, a21, a22], the row-major
        coefficient matrix.
    dut4Db : callable (x, y) giving the prescribed tangential component of w.
        NaN results are treated as 0.

    Returns
    -------
    M : square csr_matrix of size 2*nnodes + n_constraints. If there are no
        constraints, the unconstrained 2*nnodes matrix is returned instead.
    b : right-hand side of matching length.
    """
    A, b = _assemble_w_operator(coord, triangles, f, coeffA)
    C, rhs4bc = _tangential_constraints(coord, dirichlet, dut4Db)
    n_constraints = C.shape[1]
    if n_constraints == 0:
        return A, b
    N1 = csr_matrix((n_constraints, n_constraints))
    M = bmat([[A, C], [C.T, N1]], format='csr')
    # The constraint rows read C.T @ w = rhs4bc, so the prescribed tangential
    # data must appear in the corresponding right-hand-side entries.
    b = np.concatenate((b, rhs4bc))
    return M, b

### PDE data
# Model problem in nondivergence form on the unit square:
#     A(x, y) : D^2 u = f,
# with A = [[2, s], [s, 2]] and s = sign(x - 1/2) * sign(y - 1/2).
# The off-diagonal coefficient is discontinuous across x = 1/2 and y = 1/2.
def f(x, y):
    """Right-hand side, equal to A : D^2 uex for the coefficient `a` below."""
    return 4 * ((x - x**2) + (y - y**2)) \
        - 2 * np.sign(x - .5) * np.sign(y - .5) * (1 - 2*x) * (1 - 2*y)


def uex(x, y):
    """Exact solution -(x - x^2)(y - y^2). It vanishes on the boundary of the square."""
    return - (x - x**2) * (y - y**2)


def a(x, y):
    """Coefficient matrix, flattened row-major as [a11, a12, a21, a22]."""
    s = np.sign(x - .5) * np.sign(y - .5)
    return [2, s, s, 2]


def dut4Db(x, y):
    """Tangential derivative of u on the boundary (zero, since uex = 0 there)."""
    return 0
coord, triangles, dirichlet, neumann = get_geom()

### the numerical test
# Each pass red-refines the current mesh once, solves on it and prints the
# largest nodal deviation of the discrete solution from uex.
nmeshes=3
for j in range(nmeshes):
    refined = red_refine(coord, triangles, dirichlet, neumann)
    coord, triangles, dirichlet, neumann, _, _ = refined
    u, w, ndof = FEM(coord, triangles, dirichlet, f, a, dut4Db)
    max_err = np.max(np.abs(uex(coord[:, 0], coord[:, 1]) - u))
    print(f"mesh {j} max err={max_err!s}")

# graphical post processing
## plot the FEM solution, the reference solution and their difference
fig = plt.figure()
nn = np.size(coord, 0)
u_ref = uex(coord[:, 0], coord[:, 1])

# The three surface plots share the mesh and styling; only heights and titles differ.
surfaces = [
    (u, 'finite element solution'),
    (u_ref, 'reference solution'),
    (u - u_ref, 'error'),
]
for position, (heights, title) in enumerate(surfaces, start=1):
    ax3d = fig.add_subplot(2, 2, position, projection='3d')
    ax3d.plot_trisurf(coord[:, 0], coord[:, 1], heights,
                      triangles=triangles,
                      cmap=plt.get_cmap('summer'),
                      edgecolor='Gray')
    ax3d.set_title(title)

## arrow plot of the discrete gradient field: w_x = w[:nn], w_y = w[nn:2*nn]
ax = fig.add_subplot(2, 2, 4)
ax.quiver(coord[:, 0], coord[:, 1], w[:nn], w[nn:2 * nn])
ax.set_title('discrete gradient field')
plt.show()
