# This code implements the implicit Euler method Y^{n} = Y^{n-1} + h f(t^{n}, Y^{n}), for n = 1, 2, ..., M,
# using Newton's method to solve the nonlinear equation for Y^{n} at each time step.


import numpy as np

# Define the right hand side of the initial value problem
def f(t, y):
    """Right-hand side of the initial value problem y' = f(t, y) = exp(-y).

    The time argument ``t`` is accepted for the generic f(t, y) signature
    but is not used: this problem is autonomous.
    """
    return np.exp(-y)

# Define the exact solution of the initial value problem
def exact(t):
    """Exact solution y(t) = log(t + e) of y' = exp(-y) with y(0) = 1."""
    shift = np.exp(1)
    return np.log(t + shift)
	
# Define the function g of the Newton iteration
def g(t, yold, x, h):
    """Residual of the implicit Euler step: g(x) = x - yold - h*f(t, x).

    With t = t^n and yold = Y^{n-1}, the root of g in x is the new value Y^n.
    """
    increment = h * f(t, x)
    return (x - yold) - increment

# Define the derivative of the function of the Newton iteration
def dg(t, yold, x, h):
    """Derivative dg/dx = 1 + h*exp(-x) of the implicit Euler residual ``g``.

    This hard-codes df/dy = -exp(-y) for f(t, y) = exp(-y), so it must be
    updated by hand whenever ``f`` changes. ``t`` and ``yold`` are unused.
    """
    return 1 + h * np.exp(-x)
	

# Implicit Euler solver: returns the maximum nodal error against the exact solution
def myeuler(a, b, y0, M, L):
    """Solve y' = f(t, y), y(a) = y0 on [a, b] with the implicit Euler method.

    Each step solves the nonlinear equation
    Y^n = Y^{n-1} + h f(t^n, Y^n) for Y^n. It does so with exactly L Newton
    iterations on the residual ``g`` (derivative ``dg``), starting from Y^{n-1}.

    Parameters
    ----------
    a, b : float
        Left and right endpoints of the time interval.
    y0 : float
        Initial value y(a).
    M : int
        Number of time steps (must be >= 1); the step size is h = (b - a) / M.
    L : int
        Number of Newton iterations per time step (must be >= 0).

    Returns
    -------
    float
        The maximum nodal error max_n |Y^n - exact(t^n)| over n = 1, ..., M.

    Raises
    ------
    ValueError
        If M < 1 or L < 0.
    """
    if M < 1:
        raise ValueError("M must be a positive number of time steps")
    if L < 0:
        raise ValueError("L must be a non-negative number of Newton iterations")

    h = (b - a) / M     # Time step
    t = a               # Start from the left endpoint
    yold = y0           # Y^{n-1}; initially the initial value
    err = 0.0           # Running maximum of the nodal error
    for _ in range(M):
        tnew = t + h
        # Newton's method for g(x) = 0, using the previous value as the initial
        # guess. Iterating on x directly keeps the result defined even when L == 0.
        x = yold
        for _ in range(L):  # L iterations (the original mistakenly looped M times)
            x = x - g(tnew, yold, x, h) / dg(tnew, yold, x, h)
        t = tnew
        yold = x            # Accept the last Newton iterate as Y^n
        err = max(err, abs(yold - exact(t)))
    return err

# Define the parameters of the current problem and print the observed
# convergence rates for a sequence of step counts M.

a = 0.0
b = 5.0
t0 = a
y0 = exact(t0)
M = np.array([20, 40, 80, 160])   # Numbers of time steps, each double the previous
L = 3                             # Number of Newton iterations per time step
m = len(M)
error = np.zeros([m])
p = np.zeros([m])
for i in range(m):
    # Pass the Newton iteration count L (previously the array M was passed by mistake).
    err = myeuler(a, b, y0, M[i], L)
    error[i] = err
    if i > 0:
        # Observed order p = log(e_{i-1} / e_i) / log(M_i / M_{i-1});
        # expected to approach 1 for the implicit Euler method.
        p[i] = np.log(error[i - 1] / error[i]) / np.log(M[i] / M[i - 1])
print(error)
print(p)
