import torch

class MyLinearRegression:
    def __init__(self, num_features, num_steps, lr=1e-2):
        self.W = torch.rand(num_features)
        self.B = torch.rand(1)
        self.lr = lr
        self.num_steps = num_steps
    
    def forward(self, X):  # X: (N, D), i.e. N samples with D features each
        return X @ self.W + self.B  # predictions, shape (N,)
    
    def fit(self, X: torch.Tensor, Y: torch.Tensor):  # inputs: X (N, D), Y (N,)
        N = X.shape[0]
        losses = []
        for itr in range(self.num_steps):
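            out = self.forward(X)
            err = Y - out                            # residuals, shape (N,)
            delW = -2 / N * (X.T @ err)              # dL/dW for L = mean(err**2), shape (D,)
            delB = -2 / N * err.sum()                # dL/dB

            self.W = self.W - self.lr * delW
            self.B = self.B - self.lr * delB
            losses.append((err ** 2).mean().item())  # MSE loss; err.mean() alone can be ~0 while errors are large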
        return losses


N = 100
D = 30
X = torch.rand((N, D))
C = torch.arange(D, dtype=torch.float)  # true weights: 0, 1, ..., D-1
Y = X @ C  # noise-free targets, true bias is 0
linreg = MyLinearRegression(
    num_features = D,
    num_steps = 100000,
)
losses = linreg.fit(X,Y)

"""
Awesome Man, DBS
The gradient part was a struggle.
Basically, my thinking process was:
First, know that when we pass a batch of samples through a forward-backward loop, the gradient becomes the AVERAGE of the gradient terms of the individual samples (because the loss is the mean of the per-sample losses). --- (1)
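A quick way to sanity-check (1) is a minimal sketch on made-up random data. The sizes N=8, D=3 and the seed are arbitrary choices and do not come from the script above. It compares autograd's gradient of the batch MSE with the average of the per-sample gradients. This assumes the batch loss is the mean of per-sample squared errors; with a sum instead, you'd get N times the average.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, toy data only
N, D = 8, 3
X = torch.rand(N, D)
Y = torch.rand(N)
W = torch.rand(D, requires_grad=True)
B = torch.rand(1, requires_grad=True)

# Gradient of the batch loss: mean of squared errors over all N samples.
batch_loss = ((Y - (X @ W + B)) ** 2).mean()
batch_gradW, = torch.autograd.grad(batch_loss, W)

# Gradient of each sample's own squared error, then averaged over samples.
per_sample = []
for n in range(N):
    loss_n = (Y[n] - (X[n] @ W + B)) ** 2  # shape (1,)
    g, = torch.autograd.grad(loss_n.sum(), W)
    per_sample.append(g)
avg_gradW = torch.stack(per_sample).mean(dim=0)

print(torch.allclose(batch_gradW, avg_gradW, atol=1e-6))  # True
```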

Okay, let me unpack

So, for a simple linear regression without any batching and num_features = 1, Y = Wx + B.
Here you can easily get delW as -2*err*x, where err = Y - (Wx + B), right? --(2)
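To check (2) numerically, here's a minimal sketch with one sample and one feature. The values x=3, y=10, w=1.5, b=0.5 are arbitrary picks. Note the sign: with err defined as Y - (Wx + B), as in fit(), the derivative of err**2 with respect to W is -2*err*x. That is where the -2 in fit() comes from.

```python
import torch

# One sample, one feature: y_hat = w*x + b, loss = (y - y_hat)**2
x = torch.tensor(3.0)
y = torch.tensor(10.0)
w = torch.tensor(1.5, requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)

err = y - (w * x + b)   # 10 - 5 = 5
loss = err ** 2
loss.backward()         # autograd fills w.grad and b.grad

# Hand-derived gradients, using err = y - y_hat as in fit()
manual_dw = -2 * err.detach() * x
manual_db = -2 * err.detach()
print(w.grad.item(), manual_dw.item())  # -30.0 -30.0
print(b.grad.item(), manual_db.item())  # -10.0 -10.0
```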
Now suppose there are multiple features; then delW has shape (num_features), and its i-th entry is the gradient for w_i, right?

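Now, as per (1), the gradient for w_i is the average of the w_i gradients across samples, and for a particular sample, as per (2), it is -2*err*x_i.
So for w_i, I just want the average of -2*err_sample*x_sample_i, that's it.
For this, I want the sum, over every sample index, of the product of err and that sample's i-th input value. Doing that for all i at once is exactly X^T @ err, which is why fit() computes delW = -2/N * (X^T @ err).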
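Putting (1) and (2) together for D features gives the sketch below. It uses random data with seed 0 (arbitrary), and autograd serves only as an independent reference. The sketch computes delW two ways: as an explicit per-feature loop and as the matrix product used in fit(). It then compares both against autograd's gradient of the mean squared error.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, toy data only
N, D = 100, 30
X = torch.rand(N, D)
Y = torch.rand(N)
W = torch.rand(D, requires_grad=True)
B = torch.rand(1, requires_grad=True)

err = Y - (X @ W + B)      # residuals, shape (N,)
loss = (err ** 2).mean()   # mean squared error over the batch
loss.backward()            # reference gradients in W.grad and B.grad

e = err.detach()
# Loop form: for each feature i, average -2 * err_n * x_{n,i} over the N samples.
delW_loop = torch.stack([-2 / N * (e * X[:, i]).sum() for i in range(D)])
# Matrix form, as in fit(): row i of X.T is feature i for every sample,
# so X.T @ e computes all D of those sums at once.
delW_mat = -2 / N * (X.T @ e)
delB = -2 / N * e.sum()

print(torch.allclose(delW_mat, W.grad, atol=1e-4))  # True
```

The loop is just (1) and (2) written out literally. The matrix form is the same arithmetic with the per-feature sums handed to a single matmul.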

"""
# -----------------------------------------------------
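import matplotlib.pyplot as plt

print(linreg.B)  # learned bias (true bias is 0)
print(linreg.W, "\n", losses[-10:])  # learned weights (true weights are C)
plt.plot(losses)
plt.xlabel("step")
plt.ylabel("loss")
plt.show()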
