def convdif(u, a, d, dt, dx):
    """Solve u_t = d*u_xx - a*u_x on x in [0, 1], t in [0, 1] by the method of lines.

    The interior of a uniform spatial grid is discretized with second-order
    central differences, giving the semi-discrete system
    ``u' = (d*Tdx - a*Sdx) u`` with u(0, t) = u(1, t) = 0. The system is then
    advanced by repeated calls to the file-level ``step`` helper with
    ``choice=1`` (which time integrator that selects is defined elsewhere —
    see ``step``).

    The initial condition is g(x) = x**3 - (4/3)*x**2 + (1/3)*x, which
    vanishes at both boundaries.

    Args:
        u: Unused; kept only so existing callers keep working.
        a: Coefficient of the first-derivative (convection) term.
        d: Coefficient of the second-derivative (diffusion) term.
        dt: Requested time step. [0, 1] is split into ``round(1/dt)`` equal
            steps, and the resulting step length is what is passed to ``step``.
        dx: Requested grid spacing. [0, 1] is split into ``round(1/dx)`` equal
            intervals, i.e. ``N = round(1/dx) - 1`` interior points.

    Returns:
        ``(m, T, X)``. ``m`` has shape ``(n_steps + 1, N + 2)``; ``m[k]`` is the
        solution at time ``tt[k]``, including the zero boundary values.
        ``T`` and ``X`` come from ``np.meshgrid(tt, xx)`` and have shape
        ``(N + 2, n_steps + 1)``, so ``m.T`` lines up with them element-wise.

    Raises:
        ValueError: if ``dx`` or ``dt`` is not positive, or if either is so
            large that there is no interior point or no time step.
    """
    if dx <= 0 or dt <= 0:
        raise ValueError("dx and dt must be positive")

    # Counts must be ints: range/np.ones/np.linspace reject the float 1/dx.
    n_intervals = int(round(1 / dx))
    n_steps = int(round(1 / dt))
    N = n_intervals - 1  # number of interior grid points
    if N < 1:
        raise ValueError("dx is too large: the grid needs at least one interior point")
    if n_steps < 1:
        raise ValueError("dt is too large: need at least one time step on [0, 1]")

    xx = np.linspace(0, 1, N + 2)
    tt = np.linspace(0, 1, n_steps + 1)
    T, X = np.meshgrid(tt, xx)

    # Use the spacings of the grids actually built, so the operators and the
    # time step agree with xx/tt even when 1/dx or 1/dt is not an integer.
    h = 1 / (N + 1)
    k = 1 / n_steps

    off = np.ones(N - 1)
    # Second-difference operator: (u[i-1] - 2 u[i] + u[i+1]) / h^2
    Tdx = (np.diag(off, -1) - 2 * np.eye(N) + np.diag(off, 1)) / h**2
    # Central first-difference operator: (u[i+1] - u[i-1]) / (2 h)
    Sdx = (np.diag(off, 1) - np.diag(off, -1)) / (2 * h)
    A = d * Tdx - a * Sdx

    # Initial condition at the interior points only. The boundary values are
    # the zeros padded on below.
    x_int = xx[1:-1]
    uold = x_int**3 - (4 / 3) * x_int**2 + (1 / 3) * x_int

    m = [np.array([0, *uold, 0])]
    # One step per interval of tt, so m ends up with exactly len(tt) rows.
    for _ in range(n_steps):
        uold = step(choice=1, Tdx=A, uold=uold, dt=k)
        m.append(np.array([0, *uold, 0]))
    return np.array(m), T, X