Curve fitting
unknown
python
2 years ago
1.0 kB
12
Indexable
import numpy as np


def poly_features(x, degree):
    """Return the polynomial design matrix for the sample points *x*.

    Column ``k`` holds ``x**k`` for ``k = 0 .. degree`` (lowest power first),
    which is the coefficient order used by ``predict`` and
    ``fit_polynomial_gd``.

    Args:
        x: 1-D array-like of sample points.
        degree: Highest power to include.

    Returns:
        float64 array of shape ``(len(x), degree + 1)``.
    """
    return np.vander(np.asarray(x, dtype=np.float64), degree + 1, increasing=True)


def predict(weights, x):
    """Evaluate ``sum_k weights[k] * x**k`` at every point of *x*.

    All coefficients are used, including the highest power.

    Args:
        weights: Polynomial coefficients, lowest power first.
        x: 1-D array-like of points.

    Returns:
        float64 array with one prediction per point.
    """
    weights = np.asarray(weights, dtype=np.float64)
    return poly_features(x, len(weights) - 1) @ weights


def fit_polynomial_gd(x, y, weights, lr=0.1, epochs=10):
    """Fit polynomial coefficients to ``(x, y)`` by full-batch gradient descent.

    Minimises the mean squared error ``(1 / 2n) * sum_i (p(x_i) - y_i)**2``
    with ``p(x) = sum_k w[k] * x**k``. The degree is ``len(weights) - 1``,
    and every coefficient is updated with its own gradient component.

    Args:
        x: 1-D sample points.
        y: Targets, same shape as *x*.
        weights: Initial coefficients, lowest power first. Not modified.
        lr: Gradient-descent step size.
        epochs: Number of full-batch updates.

    Returns:
        New float64 array of fitted coefficients, lowest power first.

    Raises:
        ValueError: If *x* and *y* differ in shape or contain no samples.
    """
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )
    if x.size == 0:
        raise ValueError("need at least one sample to fit")

    w = np.array(weights, dtype=np.float64)  # copy: never update the caller's array
    design = poly_features(x, len(w) - 1)
    for _ in range(epochs):
        residual = design @ w - y
        # Average (not sum) the per-sample gradients so the effective step
        # does not grow with the number of samples; a summed gradient with
        # lr=0.1 over 1000 points overshoots and diverges.
        w -= lr * (design.T @ residual) / x.size
    return w


def main():
    """Fit a quartic to samples of sin(2*pi*x), print the coefficients, and plot."""
    import matplotlib.pyplot as plt  # only the plotting driver needs matplotlib

    x = np.random.random(1000)
    noise_std = 0.0  # raise above 0 to add Gaussian noise to the targets
    y = np.random.normal(0, noise_std, len(x)) + np.sin(x * np.pi * 2)

    # Five coefficients (lowest power first) -> a degree-4 polynomial.
    initial = np.array([4, 3, 2, 1, 3], dtype=np.float64)
    # Gradient descent on the monomial basis converges slowly (the basis is
    # badly conditioned on [0, 1]), so use far more than a handful of epochs.
    weights = fit_polynomial_gd(x, y, initial, lr=0.1, epochs=20_000)
    print(weights)

    order = np.argsort(x)  # sort so the fitted curve is drawn left to right
    plt.scatter(x, y, s=4, label="data")
    plt.plot(x[order], predict(weights, x[order]), color="red", label="degree-4 fit")
    plt.legend()
    plt.show()


if __name__ == "__main__":
    main()
Editor is loading...
Leave a Comment