Untitled
unknown
plain_text
a year ago
2.5 kB
5
Indexable
# k-NN experiments on the Iris dataset, using only the first two features:
#   1. classify one new point and show it on a Setosa-vs-Versicolor scatter plot;
#   2. plot test-set accuracy against the number of neighbors k.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Load the Iris dataset once; both experiments below share it.
iris = load_iris()
X = iris.data[:, :2]  # Use only the first two features (the sepal axes plotted below)
y = iris.target

# ---------------------------------------------------------------------------
# Experiment 1: predict the class of a single new point and plot it
# ---------------------------------------------------------------------------

# The scatter plot shows only classes 0 (Setosa) and 1 (Versicolor), so fit the
# classifier on those same two classes. Fitting on all three classes could
# yield a prediction for a class that never appears in the figure.
two_class_mask = y < 2
X_two = X[two_class_mask]
y_two = y[two_class_mask]

# Create a k-NN classifier and train it on the two plotted classes
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_two, y_two)

# Generate a new data point for prediction (one row, two features)
new_data_point = np.array([[5.0, 3.5]])  # Example new data point

# Predict the class of the new data point
predicted_class = knn.predict(new_data_point)

# Scatter plot of the two classes and the predicted class of the new data point
plt.figure(figsize=(8, 6))
plt.scatter(X_two[y_two == 0, 0], X_two[y_two == 0, 1], color='red', label='Setosa')
plt.scatter(X_two[y_two == 1, 0], X_two[y_two == 1, 1], color='blue', label='Versicolor')
plt.scatter(new_data_point[:, 0], new_data_point[:, 1], color='green',
            label=f'Predicted: {predicted_class[0]}')
plt.xlabel('Sepal Length (cm)')
plt.ylabel('Sepal Width (cm)')
plt.title('Scatter Plot of Iris Dataset (Setosa vs Versicolor)')
plt.legend()
plt.grid(True)
plt.show()

# ---------------------------------------------------------------------------
# Experiment 2: accuracy as a function of k (all three classes)
# ---------------------------------------------------------------------------

# Split the dataset into training and testing sets; the fixed random_state
# makes the split, and therefore the accuracy curve, reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Define a range of k values to test
k_values = range(1, 21)
accuracies = []

# Iterate over different values of k
for k in k_values:
    # Create a k-NN classifier and train it on the training data
    knn = KNeighborsClassifier(n_neighbors=k)
    knn.fit(X_train, y_train)

    # Make predictions on the test data
    y_pred = knn.predict(X_test)

    # Accuracy is the fraction of test samples predicted correctly
    accuracy = np.mean(y_pred == y_test)
    accuracies.append(accuracy)

# Plot accuracy versus k
plt.figure(figsize=(8, 6))
plt.plot(k_values, accuracies, marker='o')
plt.xlabel('Number of Neighbors (k)')
plt.ylabel('Accuracy')
plt.title('Accuracy vs Number of Neighbors (k) for Iris Dataset')
plt.xticks(k_values)
plt.grid(True)
plt.show()
Editor is loading...
Leave a Comment