Untitled

 avatar
unknown
python
2 months ago
666 B
2
Indexable
import numpy as np
import matplotlib.pyplot as plt

def solve_diff_eq(n, tol, *, max_iter=None, plot=True):
    """Solve u''(x) = -8 on [0, 1] with u(0) = u(1) = 0 by iterative sweeps.

    The second derivative is replaced by the central difference
    ``(u[i+1] - 2*u[i] + u[i-1]) / dx**2``, which gives the update
    ``u[i] = (u[i+1] + u[i-1] + 8*dx**2) / 2``.  Each sweep updates ``u``
    in place from left to right, so ``u[i-1]`` already holds the value
    from the current sweep (a Gauss-Seidel style iteration).  Sweeping
    stops once the relative change ``||u - u_old|| / ||u||`` is no longer
    greater than ``tol``.

    Args:
        n: Number of equally spaced grid points on [0, 1], including both
            boundary points.  Must be at least 2.
        tol: Convergence threshold for the relative change between two
            successive sweeps.  Must be non-negative.
        max_iter: Optional upper bound on the number of sweeps.  ``None``
            (the default) places no limit on the iteration.
        plot: When true (the default), plot the solution with matplotlib
            and call ``plt.show()`` before returning.

    Returns:
        A NumPy array of length ``n`` with the approximate solution at the
        grid points.

    Raises:
        ValueError: If ``n < 2`` or ``tol < 0``.
        RuntimeError: If ``max_iter`` sweeps were performed without
            reaching the requested tolerance.
    """
    if n < 2:
        raise ValueError(f"n must be at least 2 grid points, got {n}")
    if tol < 0:
        # The relative change is never negative, so a negative tolerance
        # could never be met and the loop would not terminate.
        raise ValueError(f"tol must be non-negative, got {tol}")

    dx = 1 / (n - 1)  # Grid spacing
    source = 8 * dx**2  # Loop-invariant right-hand-side contribution

    # Zero initial guess; it already satisfies u(0) = u(1) = 0.
    u = np.zeros(n)

    iter_count = 0
    error = 1

    # With n == 2 there are no interior nodes to update, so the zero
    # vector is returned as is (this also avoids a 0/0 relative error).
    while n > 2 and error > tol:
        if max_iter is not None and iter_count >= max_iter:
            raise RuntimeError(
                f"no convergence after {iter_count} sweeps "
                f"(relative change {error:.3e} > tol {tol:.3e})"
            )

        u_old = u.copy()

        # In-place sweep: u[i - 1] was already refreshed in this pass.
        for i in range(1, n - 1):
            u[i] = (u[i + 1] + u[i - 1] + source) / 2

        change = np.linalg.norm(u - u_old)
        scale = np.linalg.norm(u)
        # Fall back to the absolute change if the iterate is identically 0.
        error = change / scale if scale > 0 else change

        iter_count += 1

    if plot:
        x = np.linspace(0, 1, n)  # Grid points
        plt.plot(x, u)
        plt.xlabel("x")
        plt.ylabel("u(x)")
        plt.title("Solution of the Differential Equation")
        plt.show()

    return u

# Problem setup: 21 grid points; iterate until the relative change
# between successive sweeps is no longer above 1e-6.
n, tol = 21, 1e-6

# Run the solver (it also shows the plot) and keep the solution array.
u = solve_diff_eq(n=n, tol=tol)
Leave a Comment