Curve fitting

mail@pastecode.io avatar
unknown
python
7 months ago
1.0 kB
9
Indexable
Never
import numpy as np
import matplotlib.pyplot as plt

# Training data: inputs drawn uniformly from [0, 1) and targets taken from one
# full period of sin(2*pi*x). The Gaussian noise term is drawn with a standard
# deviation of 0.0, so as written it adds nothing; raise it for noisy targets.
n_samples = 1000
X = np.random.random(n_samples)
noise = np.random.normal(0, 0.0, n_samples)
clean = np.sin(X * np.pi * 2)
Y = noise + clean

# Show the raw training data.
plt.scatter(X, Y)
plt.show()

# Initial coefficients of a degree-4 polynomial, lowest order first:
#   y_hat = W[0] + W[1]*x + W[2]*x**2 + W[3]*x**3 + W[4]*x**4
W = np.array([4, 3, 2, 1, 3], dtype=np.float64)

lr = 0.1
epoch = 10

# Design matrix: column k holds X**k, so the predictions for every sample are
# simply `features @ W`. Built once because X does not change during training.
features = np.polynomial.polynomial.polyvander(X, len(W) - 1)
for _ in range(epoch):
    residual = features @ W - Y
    # Gradient of half the mean squared error: entry k is mean(residual * X**k),
    # so every coefficient gets its own power of X (W[3] <- X**3, W[4] <- X**4).
    # Averaging instead of summing keeps the step size independent of len(X);
    # with a plain sum over 1000 samples, lr=0.1 makes the weights blow up.
    gradient = features.T @ residual / len(X)
    W -= lr * gradient

# NOTE(review): 10 epochs at lr=0.1 is likely far from converged for this
# problem; increase `epoch` if the fitted curve looks poor.
print(W)

# Evaluate the fitted polynomial at the training inputs using every
# coefficient in W (lowest order first), so the x**4 term learned during
# training is included in the prediction.
pred_Y = np.polynomial.polynomial.polyval(X, W)

# Show the model's predictions at the same inputs as the training data.
plt.scatter(X, pred_Y)
plt.show()
Leave a Comment