# Basic version
import numpy as np
import tkinter as tk
from tkinter import ttk
from torchvision import datasets, transforms
from scipy.special import expit # for sigmoid function
class DigitRecognizer:
    """Single-layer digit classifier trained on MNIST with a Tkinter drawing pad.

    The model is one weight matrix plus a bias vector followed by an
    element-wise sigmoid, giving one independent score per digit.  The GUI
    shows a grid of cells the user paints on with the mouse; when the button
    is released the grid is fed through the model and the ten scores are
    shown as progress bars.
    """

    # Cells per side of the drawing grid; the flattened grid must have the
    # same length as one training image.
    GRID_SIZE = 28
    # On-screen size of one grid cell, in pixels.
    CELL_SIZE = 10
    # Number of output classes (digits 0-9).
    NUM_CLASSES = 10
    # Length of one flattened image.
    INPUT_SIZE = GRID_SIZE * GRID_SIZE

    def __init__(self, learning_rate=0.01, epochs=5):
        """Load the MNIST training split, initialise the model and build the GUI.

        Args:
            learning_rate: Step size used for every per-sample weight update.
            epochs: Number of full passes over the training data in ``train``.
        """
        # Load MNIST data from torchvision (downloaded into ./data if absent).
        mnist_data = datasets.MNIST(
            'data', train=True, download=True, transform=transforms.ToTensor()
        )
        # Flatten each image and scale the raw pixel values into [0, 1].
        self.x_train = mnist_data.data.numpy().reshape(-1, self.INPUT_SIZE) / 255.0
        self.y_train = mnist_data.targets.numpy()

        # Small random weights and zero biases, one column per digit.
        self.weights = np.random.randn(self.INPUT_SIZE, self.NUM_CLASSES) * 0.01
        self.biases = np.zeros(self.NUM_CLASSES)

        # Training parameters
        self.learning_rate = learning_rate
        self.epochs = epochs

        # GUI setup
        self.setup_gui()

    def sigmoid(self, x):
        """Return the element-wise logistic sigmoid of ``x``."""
        return expit(x)

    def predict(self, x):
        """Return the per-digit sigmoid scores for one flattened image ``x``."""
        return self.sigmoid(np.dot(x, self.weights) + self.biases)

    def train(self):
        """Fit the weights and biases with per-sample gradient steps.

        For every sample, in stored order, the one-hot target is subtracted
        from the current scores and the weights and biases are moved against
        that error, scaled by ``learning_rate``.  Progress is printed after
        every epoch.
        """
        print("Training started...")
        for epoch in range(self.epochs):
            for x, label in zip(self.x_train, self.y_train):
                # One-hot encode the label.
                target = np.zeros(self.NUM_CLASSES)
                target[label] = 1

                # Forward pass, then step against the prediction error.
                error = self.predict(x) - target
                self.weights -= self.learning_rate * np.outer(x, error)
                self.biases -= self.learning_rate * error
            print(f"Epoch {epoch + 1}/{self.epochs} completed")
        print("Training completed!")

    def setup_gui(self):
        """Create the window: drawing grid, score bars and the Clear button."""
        self.root = tk.Tk()
        self.root.title("Digit Recognizer")

        # Drawing area
        side = self.GRID_SIZE * self.CELL_SIZE
        self.canvas = tk.Canvas(self.root, width=side, height=side, bg='black')
        self.canvas.pack(pady=20)

        # One rectangle per grid cell, stored as self.cells[row][column].
        self.cells = []
        for i in range(self.GRID_SIZE):
            row = []
            for j in range(self.GRID_SIZE):
                x1, y1 = j * self.CELL_SIZE, i * self.CELL_SIZE
                x2, y2 = x1 + self.CELL_SIZE, y1 + self.CELL_SIZE
                rect = self.canvas.create_rectangle(
                    x1, y1, x2, y2, fill='black', outline='gray'
                )
                row.append(rect)
            self.cells.append(row)

        # Paint on the initial press as well as on drag, so a single click
        # also leaves a mark; re-score when the button is released.
        self.canvas.bind('<Button-1>', self.paint)
        self.canvas.bind('<B1-Motion>', self.paint)
        self.canvas.bind('<ButtonRelease-1>', self.update_probabilities)

        # Probability bars: digit label, bar, numeric value per row.
        self.prob_frame = ttk.Frame(self.root)
        self.prob_frame.pack(pady=20, padx=20, fill='x')
        self.prob_bars = []
        self.prob_labels = []
        for i in range(self.NUM_CLASSES):
            label = ttk.Label(self.prob_frame, text=str(i))
            label.grid(row=i, column=0, padx=5)
            progressbar = ttk.Progressbar(
                self.prob_frame, length=200, mode='determinate'
            )
            progressbar.grid(row=i, column=1, padx=5, pady=2)
            value_label = ttk.Label(self.prob_frame, text="0.00%")
            value_label.grid(row=i, column=2, padx=5)
            self.prob_bars.append(progressbar)
            self.prob_labels.append(value_label)

        # Clear button
        clear_btn = ttk.Button(self.root, text="Clear", command=self.clear_canvas)
        clear_btn.pack(pady=10)

    @staticmethod
    def _brush_intensity(dx, dy):
        """Return the 0-255 brush intensity at offset ``(dx, dy)`` from the pointer."""
        return 255 - (abs(dx) + abs(dy)) * 50

    @staticmethod
    def _hex_to_intensity(color):
        """Return the 0-255 grey level of a cell fill (``'black'`` or ``'#rrggbb'``)."""
        if color == 'black':
            return 0
        # Cells are always painted grey, so the red channel is the grey level.
        return int(color[1:3], 16)

    def paint(self, event):
        """Brighten the cell under the pointer and its eight neighbours.

        The brush is brightest at the centre and dimmer toward the edges.  A
        cell is only repainted when the brush would make it brighter, so the
        dim edge of the brush never erases an earlier full-intensity stroke.
        """
        col = event.x // self.CELL_SIZE
        row = event.y // self.CELL_SIZE
        if not (0 <= col < self.GRID_SIZE and 0 <= row < self.GRID_SIZE):
            return

        for dx in (-1, 0, 1):
            for dy in (-1, 0, 1):
                nx, ny = col + dx, row + dy
                if not (0 <= nx < self.GRID_SIZE and 0 <= ny < self.GRID_SIZE):
                    continue
                cell = self.cells[ny][nx]
                intensity = self._brush_intensity(dx, dy)
                current = self._hex_to_intensity(self.canvas.itemcget(cell, 'fill'))
                if intensity > current:
                    color = f'#{intensity:02x}{intensity:02x}{intensity:02x}'
                    self.canvas.itemconfig(cell, fill=color)

    def update_probabilities(self, event=None):
        """Score the current drawing and refresh the bars and labels.

        Args:
            event: The Tk event when used as a callback; unused, and optional
                so the method can also be called directly.

        Each score is an independent sigmoid output shown as a percentage, so
        the ten values are not normalised to sum to 100%.
        """
        # Read the grid back from the canvas as intensities in [0, 1].
        image = np.zeros((self.GRID_SIZE, self.GRID_SIZE))
        for i in range(self.GRID_SIZE):
            for j in range(self.GRID_SIZE):
                color = self.canvas.itemcget(self.cells[i][j], 'fill')
                image[i][j] = self._hex_to_intensity(color) / 255.0

        # Flatten and predict
        probabilities = self.predict(image.reshape(self.INPUT_SIZE))

        # Update probability bars
        for i in range(self.NUM_CLASSES):
            prob = probabilities[i] * 100
            self.prob_bars[i]['value'] = prob
            self.prob_labels[i]['text'] = f"{prob:.2f}%"

    def clear_canvas(self):
        """Reset every cell to black and zero all score bars and labels."""
        for row in self.cells:
            for cell in row:
                self.canvas.itemconfig(cell, fill='black')
        for i in range(self.NUM_CLASSES):
            self.prob_bars[i]['value'] = 0
            self.prob_labels[i]['text'] = "0.00%"

    def run(self):
        """Train the model, then enter the Tk event loop (blocks until exit)."""
        self.train()
        self.root.mainloop()
if __name__ == "__main__":
    # Only start the application when the file is executed as a script.
    # Constructing the recognizer loads MNIST and builds the window; run()
    # trains the model and then enters the Tk event loop.
    app = DigitRecognizer()
    app.run()
# Last updated