Curve fitting
unknown
python
2 years ago
1.0 kB
14
Indexable
import numpy as np
import matplotlib.pyplot as plt
# Training data: 1000 inputs drawn uniformly from [0, 1), with targets
# sin(2*pi*x) plus Gaussian noise. The noise scale is currently 0.0, so the
# targets are exactly the sine curve; raise `scale` to add noise.
X = np.random.random(size=1000)
Y = np.random.normal(loc=0, scale=0.0, size=X.size) + np.sin(X * np.pi * 2)

# Show the raw training data before fitting.
plt.scatter(x=X, y=Y)
plt.show()
# Weights of a degree-4 polynomial, constant term first:
#     y_hat = W[0] + W[1]*x + W[2]*x**2 + W[3]*x**3 + W[4]*x**4
W = np.array([4, 3, 2, 1, 3], dtype=np.float64)
lr = 0.1
epoch = 10

# Design matrix whose column k is X**k, so `A @ W` evaluates the polynomial
# at every sample at once, and `A.T @ residual` gives every weight's gradient
# (including the x**4 term) in a single product.
A = np.vander(X, len(W), increasing=True)

# Full-batch gradient descent on the mean squared error (the constant factor
# of 2 is folded into lr). All weights are updated simultaneously from the
# same residual. The gradient is averaged over the samples rather than summed:
# summing over 1000 points makes the effective step ~1000x larger, and with
# lr=0.1 the weights blow up instead of converging.
for _ in range(epoch):
    residual = A @ W - Y
    W -= lr * (A.T @ residual) / len(X)
    print(W)

# NOTE: X is uniform on [0, 1), where the monomial features x**k are strongly
# correlated, so plain gradient descent converges slowly and 10 epochs is
# likely far too few for a close fit. Raise `epoch`, or solve the least-squares
# problem directly (e.g. np.polyfit), if a tight fit is needed.
# Evaluate the trained degree-4 polynomial (all five weights, including the
# x**4 term) at every training input. np.polyval expects the highest-order
# coefficient first, while W stores the constant term first, hence W[::-1].
pred_Y = np.polyval(W[::-1], X)

# Plot the model's predictions at the sample points.
plt.scatter(X, pred_Y)
plt.show()
Leave a Comment