import nn


class DigitClassificationModel(object):
    """
    A model for handwritten digit classification using the MNIST dataset.

    Each handwritten digit is a 28x28 pixel grayscale image, which is flattened
    into a 784-dimensional vector for the purposes of this model. Each entry in
    the vector is a floating point number between 0 and 1.

    The goal is to sort each digit into one of 10 classes (numbers 0 through 9).

    (See RegressionModel for more information about the APIs of the different
    methods here. We recommend that you implement the RegressionModel before
    working on this part of the project.)
    """
    def __init__(self):
        # Initialize your model parameters here
        "*** YOUR CODE HERE ***"
        self.batch_size = 2
        # Hidden layer: 784 -> 60, output layer: 60 -> 10
        self.wf = nn.Parameter(784, 60)
        self.bf = nn.Parameter(1, 60)
        self.wr = nn.Parameter(60, 10)
        self.br = nn.Parameter(1, 10)

    def run(self, x):
        """
        Runs the model for a batch of examples.

        Your model should predict a node with shape (batch_size x 10),
        containing scores. Higher scores correspond to greater probability of
        the image belonging to a particular class.

        Inputs:
            x: a node with shape (batch_size x 784)
        Output:
            A node with shape (batch_size x 10) containing predicted scores
            (also called logits)
        """
        "*** YOUR CODE HERE ***"
        # One hidden layer with a ReLU nonlinearity, then a linear output layer
        relued = nn.ReLU(nn.AddBias(nn.Linear(x, self.wf), self.bf))
        return nn.AddBias(nn.Linear(relued, self.wr), self.br)

    def get_loss(self, x, y):
        """
        Computes the loss for a batch of examples.

        The correct labels `y` are represented as a node with shape
        (batch_size x 10). Each row is a one-hot vector encoding the correct
        digit class (0-9).

        Inputs:
            x: a node with shape (batch_size x 784)
            y: a node with shape (batch_size x 10)
        Returns: a loss node
        """
        "*** YOUR CODE HERE ***"
        return nn.SoftmaxLoss(self.run(x), y)

    def train(self, dataset):
        """
        Trains the model.
        """
        "*** YOUR CODE HERE ***"
        # Train full epochs of minibatch gradient descent until the
        # validation accuracy reaches 97%
        mistakes = 1
        while mistakes > 0:
            mistakes = 0
            for x, y in dataset.iterate_once(self.batch_size):
                losses = self.get_loss(x, y)
                gradient = nn.gradients(losses, [self.wf, self.wr, self.bf, self.br])
                # Step each parameter in the negative gradient direction
                # (learning rate 0.009)
                self.wf.update(gradient[0], -0.009)
                self.wr.update(gradient[1], -0.009)
                self.bf.update(gradient[2], -0.009)
                self.br.update(gradient[3], -0.009)
            val = dataset.get_validation_accuracy()
            if val < 0.97:
                mistakes += 1
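
For reference, below is a plain-NumPy sketch of what run() and get_loss() compute. The helper names (run_numpy, softmax_loss_numpy) are hypothetical, and it assumes nn.Linear, nn.AddBias, nn.ReLU, and nn.SoftmaxLoss behave like the standard matrix multiply, bias add, ReLU, and mean softmax cross-entropy. It is an illustration only, not part of the project API.

import numpy as np

def run_numpy(x, wf, bf, wr, br):
    # Hidden layer: ReLU(x @ wf + bf), then a linear output layer -> logits
    hidden = np.maximum(0, x @ wf + bf)
    return hidden @ wr + br

def softmax_loss_numpy(logits, y_onehot):
    # Numerically stable softmax followed by mean cross-entropy,
    # mirroring what nn.SoftmaxLoss is assumed to compute
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return -np.mean(np.sum(y_onehot * np.log(probs + 1e-12), axis=1))

# Tiny smoke test with random data in the same shapes as the model above
rng = np.random.default_rng(0)
x = rng.random((2, 784))                       # batch of 2 flattened images
wf, bf = rng.standard_normal((784, 60)) * 0.01, np.zeros((1, 60))
wr, br = rng.standard_normal((60, 10)) * 0.01, np.zeros((1, 10))
y = np.eye(10)[[3, 7]]                         # one-hot labels for digits 3 and 7
print(softmax_loss_numpy(run_numpy(x, wf, bf, wr, br), y))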