HW4-2a/b

 avatar
unknown
python
7 days ago
2.6 kB
9
Indexable
import copy
from tqdm.notebook import tqdm

def stochasticGradientDescent(d, Xin, yin, eta=0.01, alpha=1, beta=1, max_iters=10**5,
                              seed=None, verbose=True):
    '''
    Train a network with SGD using a variable learning-rate heuristic.

    Each iteration draws one data point uniformly at random and computes its
    gradient with backpropagation. It then forms candidate weights
    ``W - eta * G``. The candidate is accepted only if it lowers the in-sample
    error measured on ALL N points; in that case ``eta`` is multiplied by
    ``alpha``. Otherwise the weights are kept and ``eta`` is multiplied by
    ``beta``, floored at 1e-5.

    Parameters
    ----------
    d : sequence of int
        Layer widths. Layer l's weight matrix has shape (1 + d[l-1], d[l]);
        presumably d[0] == Xin.shape[1] (TODO confirm against init_weights).
    Xin : 2-D array, shape (N, din)
        Inputs, one row per data point.
    yin : 2-D array
        Targets indexed as yin[n, :]; presumably shape (N, 1) (verify).
    eta : float
        Initial learning rate.
    alpha : float
        Learning-rate multiplier applied after an accepted step.
    beta : float
        Learning-rate multiplier applied after a rejected step.
    max_iters : int
        Number of SGD iterations to run.
    seed : int or None
        Seed for the sampling RNG; None gives a fresh, unseeded generator.
    verbose : bool
        If True, print a line describing every iteration.

    Returns
    -------
    W : list
        Final weights in the same layout as ``init_weights(d)``, with index 0
        left as None.

    Side effects
    ------------
    Plots Ein (measured before each update) against the iteration number on
    the current matplotlib axes.

    TODO: early stopping using a held-out validation set is not implemented.
    '''
    L = len(d)
    N, din = Xin.shape

    rng = np.random.default_rng(seed)
    W = init_weights(d)
    Eins = []

    # In-sample error at the current weights. After an accepted step it equals
    # the error already computed for the candidate. After a rejected step W is
    # unchanged. Either way it never has to be recomputed from scratch.
    # (Assumes computeError is a pure function of its arguments.)
    Ein_W = computeError(d, W, Xin, yin)

    for iteration in tqdm(range(max_iters), desc="SGD Progress"):
        # Sample one data point uniformly at random, keeping 2-D shapes.
        n = rng.integers(0, N)
        Xn = Xin[n, :].reshape((1, din))
        yn = yin[n, :].reshape((1, 1))

        # Record Ein before this iteration's update.
        Eins.append(Ein_W)

        # Only the gradient is needed; the single-point error is discarded.
        _, G = computeGradientBackpropagation(d, W, Xn, yn)

        # Candidate weights. Index 0 stays None to mirror W's layer indexing.
        Wtmp = [None] + [W[l] - eta * G[l] for l in range(1, L)]
        Ein_Wtmp = computeError(d, Wtmp, Xin, yin)

        if Ein_Wtmp < Ein_W:
            # The step lowered in-sample error: accept it and scale eta by alpha.
            eta = eta * alpha
            if verbose:
                print(f'good iteration {iteration} Ein_Wtmp {Ein_Wtmp} Ein_W {Ein_W} eta {eta}')
            W, Ein_W = Wtmp, Ein_Wtmp
        else:
            # Bad update: keep W and shrink eta, but never below 1e-5.
            eta = max(eta * beta, 1e-5)
            if verbose:
                print(f'bad iteration {iteration} Ein_Wtmp {Ein_Wtmp} Ein_W {Ein_W} eta {eta}')

    # Plot the progress made.
    plt.plot(range(len(Eins)), Eins)
    plt.xlabel('# iterations')
    plt.ylabel('Ein')

    return W
Editor is loading...
Leave a Comment