import numpy as np
from scipy.sparse import csr_matrix # importiere für sparse matrix
import scipy as py
import matplotlib.pyplot as plt
import matplotlib.tri as mtri

# red refinement
def red_refine(coord, triangles, dirichlet, neumann):
	# getting colums and rows
	[nE, _] = triangles.shape
	[nC, d] = coord.shape

	# call to function edge_data
	# edges enthält alle direkten Verbindungen zwischen den Knoten
	# el2edges enthält die Indizes, um die Verbindungen zwischen den Kanten rekonstruieren zu können
	# Db2edges/ Nb2edges ist dafür da, um die Verbindungen des Dirichlet-/ Neumannrandes rekonstruieren zu können
	edges, el2edges, Db2edges, Nb2edges = edge_data(triangles, dirichlet, neumann)
	# cast edges to int
	edges = edges.astype(int)
	
	[nS,_] = edges.shape
	new_nodes = nC + np.transpose(np.arange(nS))
	mat1 = new_nodes[el2edges]
	new_indices = mat1.reshape(el2edges.shape, order = 'F').copy() 

	# Aufruf der Funktion refinement_rule, hier ist die Art und Weise, wie die Verfeinerung für die jeweilige Dimension vorgenommen wird gespeichert
	idx_elements = refinement_rule(d) 

	# die Verfeinerung der Dreiecke wird durchgeführt
	triangles = np.concatenate((triangles,new_indices), axis = 1) 
	triangles = np.transpose(triangles[:, idx_elements])  
	triangles = np.transpose(triangles.reshape(d+1,-1, order = 'F').copy())

	# die neuen Knoten werden erzeugt
	new_coord = 0.5 * (coord[edges[:, 0]] + coord[edges[:, 1]])  
	coord = np.concatenate((coord, new_coord), axis=0) 

	# erzeugen des neuen Neumann- und Dirichletrandes
	idx_boundary = refinement_rule(d-1) # Aufruf der Funktion refinement_rule
	if np.size(dirichlet, 0) > 0:
		mat2 = new_nodes[Db2edges] 
		new_dirichlet = mat2.reshape(Db2edges.shape, order = 'F').copy()
		dirichlet = np.concatenate((dirichlet,new_dirichlet), axis = 1) 
		dirichlet = np.transpose(dirichlet[:, idx_boundary])
		dirichlet = np.transpose(dirichlet.reshape(d,-1, order = 'F').copy()) 
	if np.size(neumann, 0) > 0:
		mat3 = new_nodes[Nb2edges]
		new_neumann = mat3.reshape(Nb2edges.shape, order = 'F').copy()
		neumann = np.concatenate((neumann,new_neumann), axis = 1)
		neumann = np.transpose(neumann[:, idx_boundary])
		neumann = np.transpose(neumann.reshape(d,-1, order = 'F').copy()) 
		
	# erzeugen von Sparse- Martizen und das anschließende schreiben in eine komprimierte Form,
	# in der nur die Teile der Matrix angegeben werde, die nicht Null sind und deren Platz in der Matrix
	# diese erlauben uns, wenn wir eine finite Elemente Funktion auf dem groben Gitter haben,
	# diese auf das feine Gitter zu legen über Matrix- Vektor- Multiplikation
	# p0 ist die für stückweise konstante Funktion
	# p1 ist die für stetige stückweise affine Funktionen
	p0 = get_sparse_matrix_p0(nE, d, triangles) # Aufruf der Funktion
	p1 = get_sparse_matrix_p1(nC, new_nodes, edges, nS) # Aufruf der Funktion
	return coord, triangles, dirichlet, neumann, p0, p1

def get_sparse_matrix_p0(nE, d, triangles):
	col = np.tile(np.arange(nE),(2**d,1)).reshape(1,-1, order = 'F').copy() 
	col = col[0]
	[t,_] = triangles.shape # neu berechnete Dreiecke
	row = np.arange(t)
	data = np.full(t,1) 
	# erstelle Matrix in der Göße, des größten Wertes aus row + 1 und dem größten Wert aus col + 1
	# mit den Werten aus data, jeweils an der Stelle [row[i], col[i]] für data[i]
	sparseMatrix = csr_matrix((data, (row, col)), shape = (max(row)+1, max(col+1)))
	# bringe die erstellte Matrix in die Form (Zeile, Spalte) Wert
	p0 = py.sparse.csc_matrix(sparseMatrix)
	return p0

def get_sparse_matrix_p1(nC, new_nodes, edges, nS):
	t = np.arange(nC)
	mat1 = np.concatenate((t, new_nodes, new_nodes)) 
	mat2 = np.concatenate((t, edges[:,0], edges[:,1])) 
	mat3 = np.concatenate((np.full(nC,1), 0.5*np.full(2*nS,1))) 
	sparseMatrix = csr_matrix((mat3, (mat1, mat2)), shape = (nC+nS,nC))
	# bringe die erstellte Matrix in die Form (Zeile, Spalte) Wert
	p1 = py.sparse.csc_matrix(sparseMatrix)
	return p1


def edge_data(triangles, dirichlet, neumann):
    [nE, nV] = triangles.shape
    # d = dimension
    d = nV - 1 
    idx = np.array([0,1,3,6]) 
    boundary = np.concatenate((dirichlet, neumann), axis=0) 
    new_edges = idx[d] * nE
    new_dirichlet = idx[d-1] * dirichlet.shape[0]
    new_neumann = idx[d-1] * neumann.shape[0]

    # erhalte alle Verbindungen zwischen den Knoten
    edges = get_edges(d, triangles, boundary) 

    # sortiere edges, gebe jeweils nur die erste Verbindung für die Knoten zurück 
    # edge_numbers sind die Indizes, um das alte Array rekonstruieren zu können
    edges, edge_numbers = np.unique(np.sort(edges), return_inverse = True, axis = 0) 

    # forme edge_numbers in ein Array der Form von new_edges um
    el2edges = get_el2edges(new_edges, edge_numbers, idx, d)
    # erhalte die Indzies, um den Dirichletrand rekonstruieren zu können
    Db2edges = get_Db2edges(new_edges, edge_numbers, new_dirichlet, idx, d)
    # erhalte die Indzies, um den Neumannrand rekonstruieren zu können
    Nb2edges = get_Nb2edges(new_edges, edge_numbers, new_dirichlet, new_neumann, idx, d)
    return edges, el2edges, Db2edges, Nb2edges

# d = 1 -> case_1
# d = 2 -> case_2
# d = 3 -> case_3
def get_edges(d, triangles, boundary): 
    switcher = {
        1: case_1,
        2: case_2,
        3: case_3,
    }
    func = switcher.get(d)
    return func(triangles, boundary)

def case_1(A, B):
    return A

def case_2(A, B):
    edges = np.concatenate((get_matrix(A,[0,1,0,2,1,2]),B))
    return edges

def case_3(A, B):
    edges = np.transpose(get_matrix(A,[0,1,0,2,0,3,1,2,1,3,2,3]))
    edges = np.concatenate(edges, get_matrix(B,[0,1,0,2,1,2]), axis = 0)
    return(edges)

# von einer gegebenen Matrix werden die Spalten, die in der gegebenen Liste stehen hintereinander gereiht und dann in die gewünschte Form gebracht    
def get_matrix(A, list_of_colums):
    new_A = np.transpose(A[:, list_of_colums])
    reshaped_A = new_A.reshape(2,-1, order = 'F').copy()
    return np.transpose(reshaped_A)

# bringe edge_numbers in die richtige Form
def get_el2edges(new_edges, edge_numbers, idx, d):
    edge_numbers = np.transpose(edge_numbers[0:new_edges])
    el2edges = edge_numbers.reshape(idx[d],-1, order = 'F').copy()
    return np.transpose(el2edges)

# extrahiere die Indizis von edge_numbers die für den Dirichlet-Rand wichtig sind
def get_Db2edges(new_edges, edge_numbers, new_dirichlet, idx, d):
    ind1 = np.arange(new_dirichlet) + new_edges
    edge_numbers1 = edge_numbers[ind1]
    edge_numbers1 = np.transpose(edge_numbers1)
    Db2edges = edge_numbers1.reshape(idx[d-1],-1, order = 'F').copy()
    return np.transpose(Db2edges)

# extrahiere die Indizis von edge_numbers die für den Neumann-Rand wichtig sind
def get_Nb2edges(new_edges, edge_numbers, new_dirichlet, new_neumann, idx, d):
    ind2 = np.arange(new_neumann) + new_edges + new_dirichlet
    edge_numbers2 = edge_numbers[ind2]
    edge_numbers2 = np.transpose(edge_numbers2)
    Nb2edges = edge_numbers2.reshape(idx[d-1],-1, order = 'F').copy()
    return np.transpose(Nb2edges)



# the dimesnion d decides how the refinement will be done
def refinement_rule(d):
    switcher = {
        0: 0,
        1: [0,2,2,1],
        2: [0,3,4,4,5,2,3,1,5,5,4,3],
        3: [0, 4, 5, 6, 4, 1, 7, 8, 5, 7, 2, 9, 6, 8, 9, 3, 4, 5, 6, 8, 8, 5, 7, 4, 5, 6, 8, 9, 9, 7, 8, 5],

    }
    return switcher.get(d)




# einfaches Beispiel, wie alles angegeben werden muss

# coord = np.array([[0,0], 
#                     [1,0],
#                     [1,1],  
#                     [0,1], 
#                     [0.5,0.5]])
# dreiecke = np.array([[0,1,4],
#                     [1,2,4],
#                     [2,3,4],
#                     [3,0,4]])
# diri = np.array([[0,1],
#                    [1,2]])
# neu = np.array([[2,3],
#                    [3,0]])

# [coord,dreiecke,dirichlet,neumann, p0, p1] = red_refine(coord,dreiecke,diri,neu)
# print("coord = " ,coord, "\n" , "dreiecke = ", dreiecke, "\n", "dirichlet = ", dirichlet,"\n", "neumann = ", neumann, "\n", "p0 = ", p0, "\n", "p1 = ", p1 )
