# Untitled
"""Gauss-Seidel finite-difference solver for -u'' = f on [0, 1], u(0) = u(1) = 0.

Run as a script to solve the f = 8 case on 21 grid points and plot the result.
"""
import numpy as np


def _plot_solution(x, u):
    """Plot the solution *u* against the grid *x* and show the figure."""
    # Imported lazily so the solver itself can be imported and used without
    # matplotlib (e.g. in headless or test environments).
    import matplotlib.pyplot as plt

    plt.plot(x, u)
    plt.xlabel("x")
    plt.ylabel("u(x)")
    plt.title("Solution of the Differential Equation")
    plt.show()


def solve_diff_eq(
    n: int,
    tol: float,
    *,
    source: float = 8.0,
    max_iter: "int | None" = None,
    plot: bool = True,
) -> np.ndarray:
    """Solve -u''(x) = source on [0, 1] with u(0) = u(1) = 0 by Gauss-Seidel.

    The second derivative is replaced by the central difference
    (u[i+1] - 2*u[i] + u[i-1]) / dx**2 on ``n`` uniformly spaced points, and
    the resulting linear system is solved with in-place Gauss-Seidel sweeps
    starting from u = 0.  With the default ``source=8`` the exact solution is
    u(x) = 4*x*(1 - x); the central difference is exact for quadratics, so
    the converged grid values match it.

    Args:
        n: Number of grid points, including both boundary points (>= 2).
        tol: Stop once ||u_new - u_old|| / ||u_new|| <= tol.  This bounds the
            change between successive sweeps, not the error itself;
            Gauss-Seidel converges slowly on fine grids, so the true error can
            be considerably larger than ``tol``.
        source: Constant right-hand side f in -u'' = f.
        max_iter: Maximum number of sweeps; ``None`` (default) means no limit.
        plot: If true (default), plot the solution with matplotlib.

    Returns:
        Array of length ``n`` holding the approximate solution at the grid
        points.

    Raises:
        ValueError: If ``n < 2``.
        RuntimeError: If ``max_iter`` sweeps complete without meeting ``tol``.
    """
    if n < 2:
        raise ValueError(f"n must be at least 2 (both boundary points), got {n}")

    dx = 1 / (n - 1)  # Grid spacing
    x = np.linspace(0, 1, n)  # Grid points
    u = np.zeros(n)
    u[0] = 0  # Dirichlet boundary values (already zero; kept explicit)
    u[-1] = 0
    rhs = source * dx**2  # loop-invariant forcing term of every update

    iter_count = 0
    rel_change = 1  # forces the first sweep unless tol >= 1, as before
    while rel_change > tol:
        if max_iter is not None and iter_count >= max_iter:
            raise RuntimeError(
                f"Gauss-Seidel did not reach tol={tol} within {max_iter} "
                f"iterations (last relative change {rel_change:.3e})"
            )
        u_old = u.copy()
        # Sweep left to right in place: u[i - 1] already holds this sweep's
        # value, which is what makes this Gauss-Seidel rather than Jacobi.
        for i in range(1, n - 1):
            u[i] = (u[i + 1] + u[i - 1] + rhs) / 2
        change = np.linalg.norm(u - u_old)
        scale = np.linalg.norm(u)
        # If u is identically zero (e.g. n == 2, or source == 0) fall back to
        # the absolute change instead of evaluating 0/0 -> NaN.
        rel_change = change / scale if scale > 0 else change
        iter_count += 1

    if plot:
        _plot_solution(x, u)
    return u


if __name__ == "__main__":
    n = 21
    tol = 1e-6
    u = solve_diff_eq(n, tol)
# Leave a Comment