Создание кода
This commit is contained in:
55
.gitignore
vendored
Normal file
55
.gitignore
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
# ===============================
|
||||
# Python
|
||||
# ===============================
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.pyd
|
||||
|
||||
# ===============================
|
||||
# Virtual environments
|
||||
# ===============================
|
||||
.venv/
|
||||
venv/
|
||||
ENV/
|
||||
|
||||
# ===============================
|
||||
# IDE (PyCharm / IntelliJ)
|
||||
# ===============================
|
||||
.idea/
|
||||
*.iml
|
||||
.gigaide/
|
||||
|
||||
# ===============================
|
||||
# OS
|
||||
# ===============================
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# ===============================
|
||||
# Project-specific (inside src/)
|
||||
# ===============================
|
||||
src/artifacts/
|
||||
src/data/
|
||||
|
||||
# ===============================
|
||||
# Logs
|
||||
# ===============================
|
||||
*.log
|
||||
|
||||
# ===============================
|
||||
# Temporary / backup files
|
||||
# ===============================
|
||||
*.tmp
|
||||
*.bak
|
||||
*.swp
|
||||
|
||||
# ===============================
|
||||
# Jupyter
|
||||
# ===============================
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# ===============================
|
||||
# CUDA cache (optional)
|
||||
# ===============================
|
||||
.cuda/
|
||||
249
src/app.py
Normal file
249
src/app.py
Normal file
@@ -0,0 +1,249 @@
|
||||
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
||||
|
||||
import os # импортируем модуль/пакет
|
||||
import threading # импортируем модуль/пакет
|
||||
import queue # импортируем модуль/пакет
|
||||
|
||||
import tkinter as tk # импортируем модуль/пакет
|
||||
from tkinter import ttk, messagebox # импортируем объекты из модуля
|
||||
|
||||
import numpy as np # импортируем модуль/пакет
|
||||
from PIL import Image, ImageDraw # импортируем объекты из модуля
|
||||
|
||||
import torch # импортируем модуль/пакет
|
||||
|
||||
from model import MNISTCNN # импортируем объекты из модуля
|
||||
from preprocess import preprocess_canvas_to_mnist_tensor # импортируем объекты из модуля
|
||||
from train import TrainConfig, train # импортируем объекты из модуля
|
||||
from utils import get_device # импортируем объекты из модуля
|
||||
|
||||
|
||||
class App(tk.Tk): # объявляем класс приложения/модели
|
||||
def __init__(self) -> None: # объявляем функцию/метод
|
||||
super().__init__() # инициализируем базовый класс
|
||||
self.title("MNIST Pro: Распознавание чисел (Tkinter + PyTorch)") # задаём заголовок окна
|
||||
self.geometry("960x600") # задаём стартовый размер окна
|
||||
self.minsize(960, 600) # задаём минимальный размер окна
|
||||
|
||||
self.device = get_device() # определяем устройство (CPU/GPU) для PyTorch
|
||||
self.artifacts_dir = "artifacts" # задаём папку для артефактов (модель/логи/картинки)
|
||||
self.model_path = os.path.join(self.artifacts_dir, "mnist_cnn.pt") # формируем путь к файлу весов модели
|
||||
|
||||
self.canvas_size = 360 # задаём размер холста для рисования
|
||||
self.brush_size = tk.IntVar(value=12) # храним толщину кисти как переменную Tkinter
|
||||
|
||||
# PIL-буфер: белый фон
|
||||
self.buffer_img = Image.new("L", (self.canvas_size, self.canvas_size), 255) # создаём PIL-буфер холста (градации серого)
|
||||
self.buffer_draw = ImageDraw.Draw(self.buffer_img) # создаём объект для рисования по PIL-буфере
|
||||
|
||||
self._last_x = None # инициализируем координаты последней точки рисования
|
||||
self._last_y = None # инициализируем координаты последней точки рисования
|
||||
|
||||
self.msg_q: queue.Queue[str] = queue.Queue() # создаём очередь сообщений между потоками (обучение → GUI)
|
||||
self.model: MNISTCNN | None = None # готовим поле для хранения модели
|
||||
|
||||
self._build_ui() # собираем интерфейс
|
||||
self._load_or_train() # загружаем модель или обучаем, если её нет
|
||||
self.after(150, self._poll_messages) # планируем периодический опрос очереди сообщений
|
||||
|
||||
def _build_ui(self) -> None: # объявляем функцию/метод
|
||||
root = ttk.Frame(self, padding=12) # создаём контейнерный фрейм для компоновки UI
|
||||
root.pack(fill="both", expand=True) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
left = ttk.Frame(root) # создаём контейнерный фрейм для компоновки UI
|
||||
left.pack(side="left", fill="y") # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
right = ttk.Frame(root) # создаём контейнерный фрейм для компоновки UI
|
||||
right.pack(side="right", fill="both", expand=True, padx=(20, 0)) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
# Холст
|
||||
self.canvas = tk.Canvas( # создаём холст Tkinter для рисования мышью
|
||||
left, # выполняем инструкцию
|
||||
width=self.canvas_size, # выполняем инструкцию
|
||||
height=self.canvas_size, # выполняем инструкцию
|
||||
bg="white", # выполняем инструкцию
|
||||
highlightthickness=1, # выполняем инструкцию
|
||||
highlightbackground="#999", # выполняем инструкцию
|
||||
cursor="crosshair", # выполняем инструкцию
|
||||
) # выполняем инструкцию
|
||||
self.canvas.pack() # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
self.canvas.bind("<ButtonPress-1>", self._on_down) # подключаем обработчик событий мыши
|
||||
self.canvas.bind("<B1-Motion>", self._on_move) # подключаем обработчик событий мыши
|
||||
self.canvas.bind("<ButtonRelease-1>", self._on_up) # подключаем обработчик событий мыши
|
||||
|
||||
# Контролы слева
|
||||
controls = ttk.Frame(left, padding=(0, 12, 0, 0)) # создаём контейнерный фрейм для компоновки UI
|
||||
controls.pack(fill="x") # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
ttk.Label(controls, text="Толщина кисти:").pack(anchor="w") # размещаем виджет в окне (менеджер pack)
|
||||
scale = ttk.Scale( # выполняем инструкцию
|
||||
controls, # выполняем инструкцию
|
||||
from_=6, to=30, # выполняем инструкцию
|
||||
orient="horizontal", # выполняем инструкцию
|
||||
command=self._on_brush_change, # выполняем инструкцию
|
||||
value=self.brush_size.get(), # выполняем инструкцию
|
||||
) # выполняем инструкцию
|
||||
scale.pack(fill="x", pady=(2, 10)) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
btns = ttk.Frame(controls) # создаём контейнерный фрейм для компоновки UI
|
||||
btns.pack(fill="x") # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
# Стилизация кнопок
|
||||
s = ttk.Style() # инициализируем объект стилей ttk
|
||||
s.configure("Big.TButton", font=("Segoe UI", 11, "bold")) # настраиваем стиль для крупных кнопок
|
||||
|
||||
ttk.Button(btns, text="Распознать число", style="Big.TButton", command=self.on_recognize).pack(side="left", fill="x", expand=True) # размещаем виджет в окне (менеджер pack)
|
||||
ttk.Button(btns, text="Очистить", command=self.on_clear).pack(side="left", fill="x", expand=True, padx=(8, 0)) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
ttk.Button(controls, text="Переобучить (с аугментацией)", command=self.on_train_click).pack(fill="x", pady=(15, 0)) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
# Правая панель
|
||||
ttk.Label(right, text="Распознано:", font=("Segoe UI", 12)).pack(anchor="w") # размещаем виджет в окне (менеджер pack)
|
||||
self.pred_label = ttk.Label(right, text="—", font=("Segoe UI", 80, "bold"), foreground="#007ACC") # создаём/настраиваем текстовый элемент интерфейса
|
||||
self.pred_label.pack(anchor="w", pady=(0, 10)) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
self.details_label = ttk.Label(right, text="", font=("Consolas", 10), justify="left") # создаём/настраиваем текстовый элемент интерфейса
|
||||
self.details_label.pack(anchor="w", fill="both", expand=True) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
ttk.Separator(right).pack(fill="x", pady=10) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
info_text = ( # выполняем инструкцию
|
||||
"Инструкция:\n" # выполняем инструкцию
|
||||
"1. Напишите одну или несколько цифр (например, 12, 45).\n" # выполняем инструкцию
|
||||
"2. Пишите цифры раздельно (не соединяйте их).\n" # выполняем инструкцию
|
||||
"3. Нажмите 'Распознать'.\n\n" # выполняем инструкцию
|
||||
"Алгоритм:\n" # выполняем инструкцию
|
||||
"• Поиск связных компонентов (разделение цифр).\n" # выполняем инструкцию
|
||||
"• Индивидуальная обработка каждой цифры (28x28).\n" # выполняем инструкцию
|
||||
"• Прогон через CNN батчем." # выполняем инструкцию
|
||||
) # выполняем инструкцию
|
||||
ttk.Label(right, text=info_text, foreground="#555").pack(anchor="w", side="bottom") # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
self.status = ttk.Label(self, text="Готово.", anchor="w", relief="sunken") # создаём/настраиваем текстовый элемент интерфейса
|
||||
self.status.pack(side="bottom", fill="x", padx=0, pady=0) # размещаем виджет в окне (менеджер pack)
|
||||
|
||||
def _on_brush_change(self, value: str) -> None: # объявляем функцию/метод
|
||||
try: # выполняем инструкцию
|
||||
self.brush_size.set(int(float(value))) # храним толщину кисти как переменную Tkinter
|
||||
except ValueError: # выполняем инструкцию
|
||||
pass # выполняем инструкцию
|
||||
|
||||
def _on_down(self, event) -> None: # объявляем функцию/метод
|
||||
self._last_x, self._last_y = event.x, event.y # инициализируем координаты последней точки рисования
|
||||
|
||||
def _on_move(self, event) -> None: # объявляем функцию/метод
|
||||
if self._last_x is None or self._last_y is None: # выполняем инструкцию
|
||||
return # выполняем инструкцию
|
||||
x1, y1 = self._last_x, self._last_y # выполняем инструкцию
|
||||
x2, y2 = event.x, event.y # выполняем инструкцию
|
||||
w = self.brush_size.get() # выполняем инструкцию
|
||||
|
||||
self.canvas.create_line(x1, y1, x2, y2, width=w, fill="black", capstyle=tk.ROUND, smooth=True, splinesteps=24) # рисуем линию на видимом холсте Tkinter
|
||||
self.buffer_draw.line((x1, y1, x2, y2), fill=0, width=w) # рисуем ту же линию в PIL-буфере для распознавания
|
||||
self._last_x, self._last_y = x2, y2 # инициализируем координаты последней точки рисования
|
||||
|
||||
def _on_up(self, event) -> None: # объявляем функцию/метод
|
||||
self._last_x, self._last_y = None, None # инициализируем координаты последней точки рисования
|
||||
|
||||
def on_clear(self) -> None: # объявляем функцию/метод
|
||||
self.canvas.delete("all") # очищаем нарисованные элементы на холсте
|
||||
self.buffer_img = Image.new("L", (self.canvas_size, self.canvas_size), 255) # создаём PIL-буфер холста (градации серого)
|
||||
self.buffer_draw = ImageDraw.Draw(self.buffer_img) # создаём объект для рисования по PIL-буфере
|
||||
self.pred_label.config(text="—") # создаём/настраиваем текстовый элемент интерфейса
|
||||
self.details_label.config(text="") # создаём/настраиваем текстовый элемент интерфейса
|
||||
self.status.config(text="Поле очищено.") # создаём/настраиваем текстовый элемент интерфейса
|
||||
|
||||
def _load_model(self) -> None: # объявляем функцию/метод
|
||||
model = MNISTCNN().to(self.device) # выполняем инструкцию
|
||||
state = torch.load(self.model_path, map_location=self.device, weights_only=True) # загружаем веса модели из файла
|
||||
model.load_state_dict(state) # применяем загруженные веса к модели
|
||||
model.eval() # переводим модель в режим инференса
|
||||
self.model = model # готовим поле для хранения модели
|
||||
|
||||
def _load_or_train(self) -> None: # объявляем функцию/метод
|
||||
os.makedirs(self.artifacts_dir, exist_ok=True) # создаём папку, если она отсутствует
|
||||
if os.path.exists(self.model_path): # проверяем наличие файла модели
|
||||
try: # выполняем инструкцию
|
||||
self._load_model() # выполняем инструкцию
|
||||
self.status.config(text="Модель загружена.") # создаём/настраиваем текстовый элемент интерфейса
|
||||
return # выполняем инструкцию
|
||||
except Exception: # выполняем инструкцию
|
||||
pass # выполняем инструкцию
|
||||
self.on_train_click(initial=True) # выполняем инструкцию
|
||||
|
||||
def _poll_messages(self) -> None: # объявляем функцию/метод
|
||||
try: # выполняем инструкцию
|
||||
while True: # выполняем инструкцию
|
||||
msg = self.msg_q.get_nowait() # выполняем инструкцию
|
||||
self.status.config(text=msg) # создаём/настраиваем текстовый элемент интерфейса
|
||||
except queue.Empty: # выполняем инструкцию
|
||||
pass # выполняем инструкцию
|
||||
self.after(150, self._poll_messages) # планируем периодический опрос очереди сообщений
|
||||
|
||||
def on_train_click(self, initial: bool = False) -> None: # объявляем функцию/метод
|
||||
self.status.config(text="Обучение..." if initial else "Переобучение...") # создаём/настраиваем текстовый элемент интерфейса
|
||||
|
||||
def worker(): # объявляем функцию/метод
|
||||
try: # выполняем инструкцию
|
||||
cfg = TrainConfig(epochs=5, batch_size=64, lr=1e-3, device=self.device, artifacts_dir=self.artifacts_dir) # формируем конфигурацию обучения
|
||||
def progress(msg: str): self.msg_q.put(msg) # объявляем функцию/метод
|
||||
train(cfg, progress_cb=progress) # запускаем обучение модели
|
||||
self._load_model() # выполняем инструкцию
|
||||
self.msg_q.put("Обучение завершено.") # выполняем инструкцию
|
||||
except Exception as e: # выполняем инструкцию
|
||||
self.msg_q.put(f"Ошибка: {e}") # выполняем инструкцию
|
||||
|
||||
threading.Thread(target=worker, daemon=True).start() # запускаем отдельный поток, чтобы GUI не зависал
|
||||
|
||||
def on_recognize(self) -> None: # объявляем функцию/метод
|
||||
if self.model is None: # выполняем инструкцию
|
||||
messagebox.showinfo("Wait", "Модель загружается...") # показываем пользователю диалоговое окно
|
||||
return # выполняем инструкцию
|
||||
|
||||
try: # выполняем инструкцию
|
||||
os.makedirs(self.artifacts_dir, exist_ok=True) # создаём папку, если она отсутствует
|
||||
debug_path = os.path.join(self.artifacts_dir, "last_debug.png") # выполняем инструкцию
|
||||
|
||||
# 1. Получаем батч тензоров (N, 1, 28, 28)
|
||||
batch_tensors = preprocess_canvas_to_mnist_tensor( # выполняем инструкцию
|
||||
self.buffer_img, # выполняем инструкцию
|
||||
debug_save_path=debug_path, # выполняем инструкцию
|
||||
).to(self.device) # выполняем инструкцию
|
||||
|
||||
# 2. Прогоняем сразу все цифры через модель
|
||||
with torch.no_grad(): # отключаем градиенты для ускорения инференса
|
||||
logits = self.model(batch_tensors) # (N, 10)
|
||||
probs = torch.softmax(logits, dim=1) # (N, 10)
|
||||
|
||||
# Получаем классы и уверенность
|
||||
max_probs, preds = torch.max(probs, dim=1) # берём наиболее вероятный класс и его вероятность
|
||||
|
||||
preds = preds.cpu().numpy() # переносим результаты на CPU и конвертируем в NumPy
|
||||
max_probs = max_probs.cpu().numpy() # переносим результаты на CPU и конвертируем в NumPy
|
||||
|
||||
# 3. Формируем итоговое число и отчет
|
||||
result_str = "".join(map(str, preds)) # склеиваем предсказанные цифры в итоговую строку
|
||||
self.pred_label.config(text=result_str) # создаём/настраиваем текстовый элемент интерфейса
|
||||
|
||||
# Детальный отчет по каждой цифре
|
||||
details = [] # выполняем инструкцию
|
||||
for i, (p, conf) in enumerate(zip(preds, max_probs)): # выполняем инструкцию
|
||||
details.append(f"Цифра #{i+1}: '{p}' (увер. {conf*100:.1f}%)") # выполняем инструкцию
|
||||
|
||||
self.details_label.config(text="\n".join(details)) # создаём/настраиваем текстовый элемент интерфейса
|
||||
self.status.config(text=f"Распознано чисел: {len(preds)}. Debug: {debug_path}") # создаём/настраиваем текстовый элемент интерфейса
|
||||
|
||||
except ValueError: # выполняем инструкцию
|
||||
self.status.config(text="Пустой холст.") # создаём/настраиваем текстовый элемент интерфейса
|
||||
except Exception as e: # выполняем инструкцию
|
||||
messagebox.showerror("Ошибка", str(e)) # показываем пользователю диалоговое окно
|
||||
self.status.config(text="Ошибка.") # создаём/настраиваем текстовый элемент интерфейса
|
||||
|
||||
|
||||
def main() -> None: # объявляем функцию/метод
|
||||
app = App() # создаём экземпляр приложения
|
||||
app.mainloop() # запускаем главный цикл Tkinter
|
||||
|
||||
if __name__ == "__main__": # точка входа при запуске файла как скрипта
|
||||
main() # выполняем инструкцию
|
||||
5
src/artifacts/A.java
Normal file
5
src/artifacts/A.java
Normal file
@@ -0,0 +1,5 @@
|
||||
numpy
|
||||
pillow
|
||||
matplotlib
|
||||
torch
|
||||
torchvision
|
||||
31
src/model.py
Normal file
31
src/model.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
||||
|
||||
import torch # импортируем модуль/пакет
|
||||
import torch.nn as nn # импортируем модуль/пакет
|
||||
import torch.nn.functional as F # импортируем модуль/пакет
|
||||
|
||||
|
||||
class MNISTCNN(nn.Module): # объявляем класс приложения/модели
|
||||
"""
|
||||
Классическая CNN: 2 сверточных слоя, pooling, dropout и 2 полносвязных.
|
||||
"""
|
||||
def __init__(self) -> None: # объявляем функцию/метод
|
||||
super().__init__() # инициализируем базовый класс
|
||||
self.conv1 = nn.Conv2d(1, 32, 3, 1) # выполняем инструкцию
|
||||
self.conv2 = nn.Conv2d(32, 64, 3, 1) # выполняем инструкцию
|
||||
self.dropout1 = nn.Dropout(0.25) # выполняем инструкцию
|
||||
self.dropout2 = nn.Dropout(0.50) # выполняем инструкцию
|
||||
# После Conv1 (26x26), Conv2 (24x24) и MaxPool (12x12):
|
||||
self.fc1 = nn.Linear(64 * 12 * 12, 128) # выполняем инструкцию
|
||||
self.fc2 = nn.Linear(128, 10) # выполняем инструкцию
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # объявляем функцию/метод
|
||||
x = F.relu(self.conv1(x)) # выполняем инструкцию
|
||||
x = F.relu(self.conv2(x)) # выполняем инструкцию
|
||||
x = F.max_pool2d(x, 2) # выполняем инструкцию
|
||||
x = self.dropout1(x) # выполняем инструкцию
|
||||
x = torch.flatten(x, 1) # выполняем инструкцию
|
||||
x = F.relu(self.fc1(x)) # выполняем инструкцию
|
||||
x = self.dropout2(x) # выполняем инструкцию
|
||||
x = self.fc2(x) # выполняем инструкцию
|
||||
return x # выполняем инструкцию
|
||||
181
src/preprocess.py
Normal file
181
src/preprocess.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
||||
|
||||
import numpy as np # импортируем модуль/пакет
|
||||
import torch # импортируем модуль/пакет
|
||||
from PIL import Image, ImageFilter # импортируем объекты из модуля
|
||||
import collections # импортируем модуль/пакет
|
||||
|
||||
MNIST_MEAN = 0.1307 # выполняем инструкцию
|
||||
MNIST_STD = 0.3081 # выполняем инструкцию
|
||||
|
||||
|
||||
def _shift_image(img: Image.Image, dx: int, dy: int) -> Image.Image: # объявляем функцию/метод
|
||||
"""Сдвигает содержимое картинки на dx, dy, заполняя пустоты черным."""
|
||||
w, h = img.size # выполняем инструкцию
|
||||
shifted = Image.new("L", (w, h), 0) # выполняем инструкцию
|
||||
x_to, y_to = max(0, dx), max(0, dy) # выполняем инструкцию
|
||||
x_from, y_from = max(0, -dx), max(0, -dy) # выполняем инструкцию
|
||||
width, height = w - abs(dx), h - abs(dy) # выполняем инструкцию
|
||||
|
||||
if width <= 0 or height <= 0: # выполняем инструкцию
|
||||
return shifted # выполняем инструкцию
|
||||
|
||||
crop = img.crop((x_from, y_from, x_from + width, y_from + height)) # выполняем инструкцию
|
||||
shifted.paste(crop, (x_to, y_to)) # выполняем инструкцию
|
||||
return shifted # выполняем инструкцию
|
||||
|
||||
|
||||
def _center_of_mass_shift(img28: Image.Image) -> Image.Image: # объявляем функцию/метод
|
||||
"""Центрирует изображение 28x28 по центру масс (как в датасете MNIST)."""
|
||||
arr = np.array(img28, dtype=np.float32) # выполняем инструкцию
|
||||
total = float(arr.sum()) # выполняем инструкцию
|
||||
if total <= 1e-6: # выполняем инструкцию
|
||||
return img28 # выполняем инструкцию
|
||||
|
||||
ys, xs = np.indices(arr.shape) # выполняем инструкцию
|
||||
cy = float((ys * arr).sum() / total) # выполняем инструкцию
|
||||
cx = float((xs * arr).sum() / total) # выполняем инструкцию
|
||||
|
||||
# Целевой центр для 28x28 — (14, 14)
|
||||
dx = int(round(14.0 - cx)) # выполняем инструкцию
|
||||
dy = int(round(14.0 - cy)) # выполняем инструкцию
|
||||
return _shift_image(img28, dx, dy) # выполняем инструкцию
|
||||
|
||||
|
||||
def _find_connected_components(mask: np.ndarray) -> list[tuple[int, int, int, int]]: # объявляем функцию/метод
|
||||
"""
|
||||
Находит bbox (x1, y1, x2, y2) для каждого связного объекта на маске.
|
||||
Использует BFS (поиск в ширину).
|
||||
"""
|
||||
h, w = mask.shape # выполняем инструкцию
|
||||
visited = np.zeros_like(mask, dtype=bool) # выполняем инструкцию
|
||||
bboxes = [] # выполняем инструкцию
|
||||
|
||||
for y in range(h): # выполняем инструкцию
|
||||
for x in range(w): # выполняем инструкцию
|
||||
if mask[y, x] and not visited[y, x]: # выполняем инструкцию
|
||||
q = collections.deque([(x, y)]) # выполняем инструкцию
|
||||
visited[y, x] = True # выполняем инструкцию
|
||||
min_x, max_x = x, x # выполняем инструкцию
|
||||
min_y, max_y = y, y # выполняем инструкцию
|
||||
count = 0 # выполняем инструкцию
|
||||
|
||||
while q: # выполняем инструкцию
|
||||
cx, cy = q.popleft() # выполняем инструкцию
|
||||
count += 1 # выполняем инструкцию
|
||||
|
||||
min_x = min(min_x, cx) # выполняем инструкцию
|
||||
max_x = max(max_x, cx) # выполняем инструкцию
|
||||
min_y = min(min_y, cy) # выполняем инструкцию
|
||||
max_y = max(max_y, cy) # выполняем инструкцию
|
||||
|
||||
# Соседи (4 стороны)
|
||||
for nx, ny in [(cx+1, cy), (cx-1, cy), (cx, cy+1), (cx, cy-1)]: # выполняем инструкцию
|
||||
if 0 <= nx < w and 0 <= ny < h: # выполняем инструкцию
|
||||
if mask[ny, nx] and not visited[ny, nx]: # выполняем инструкцию
|
||||
visited[ny, nx] = True # выполняем инструкцию
|
||||
q.append((nx, ny)) # выполняем инструкцию
|
||||
|
||||
# Фильтр совсем мелкого шума (точки)
|
||||
if count > 10: # выполняем инструкцию
|
||||
bboxes.append((min_x, min_y, max_x, max_y)) # выполняем инструкцию
|
||||
|
||||
return bboxes # выполняем инструкцию
|
||||
|
||||
|
||||
def _process_single_crop( # объявляем функцию/метод
|
||||
crop_img: Image.Image, # выполняем инструкцию
|
||||
debug_save_path: str | None = None # выполняем инструкцию
|
||||
) -> torch.Tensor: # выполняем инструкцию
|
||||
"""Подготовка одного кусочка с цифрой для нейросети."""
|
||||
|
||||
# 1. Легкий блюр для сглаживания пикселизации
|
||||
crop_img = crop_img.filter(ImageFilter.GaussianBlur(radius=0.6)) # выполняем инструкцию
|
||||
|
||||
w, h = crop_img.size # выполняем инструкцию
|
||||
|
||||
# 2. Ресайз: вписываем в квадрат 20x20, сохраняя пропорции
|
||||
max_side = max(w, h) # выполняем инструкцию
|
||||
scale = 20.0 / float(max_side) # выполняем инструкцию
|
||||
new_w = max(1, int(round(w * scale))) # выполняем инструкцию
|
||||
new_h = max(1, int(round(h * scale))) # выполняем инструкцию
|
||||
resized = crop_img.resize((new_w, new_h), resample=Image.BILINEAR) # выполняем инструкцию
|
||||
|
||||
# 3. Вставляем в центр черного квадрата 28x28
|
||||
img28 = Image.new("L", (28, 28), 0) # выполняем инструкцию
|
||||
left = (28 - new_w) // 2 # выполняем инструкцию
|
||||
top = (28 - new_h) // 2 # выполняем инструкцию
|
||||
img28.paste(resized, (left, top)) # выполняем инструкцию
|
||||
|
||||
# 4. Центрирование по массе (Критически важно!)
|
||||
img28 = _center_of_mass_shift(img28) # выполняем инструкцию
|
||||
|
||||
if debug_save_path: # выполняем инструкцию
|
||||
img28.save(debug_save_path) # выполняем инструкцию
|
||||
|
||||
# 5. Нормализация
|
||||
x = np.array(img28, dtype=np.float32) / 255.0 # выполняем инструкцию
|
||||
x = (x - MNIST_MEAN) / MNIST_STD # выполняем инструкцию
|
||||
return torch.from_numpy(x).unsqueeze(0) # (1, 28, 28)
|
||||
|
||||
|
||||
def preprocess_canvas_to_mnist_tensor( # объявляем функцию/метод
|
||||
img_l: Image.Image, # выполняем инструкцию
|
||||
*, # выполняем инструкцию
|
||||
debug_save_path: str | None = None, # выполняем инструкцию
|
||||
) -> torch.Tensor: # выполняем инструкцию
|
||||
"""
|
||||
Основная функция: Холст -> Батч тензоров.
|
||||
"""
|
||||
if img_l.mode != "L": # выполняем инструкцию
|
||||
img_l = img_l.convert("L") # выполняем инструкцию
|
||||
|
||||
arr = np.array(img_l, dtype=np.uint8) # выполняем инструкцию
|
||||
# Инверсия: (белый фон -> черный фон)
|
||||
inv = 255 - arr # выполняем инструкцию
|
||||
|
||||
# Бинаризация для поиска объектов
|
||||
mask = inv > 20 # выполняем инструкцию
|
||||
|
||||
bboxes = _find_connected_components(mask) # выполняем инструкцию
|
||||
|
||||
if not bboxes: # выполняем инструкцию
|
||||
raise ValueError("Холст пуст.") # выполняем инструкцию
|
||||
|
||||
# --- УМНАЯ СОРТИРОВКА ---
|
||||
# Определяем габариты всего написанного текста
|
||||
min_x_total = min(b[0] for b in bboxes) # выполняем инструкцию
|
||||
max_x_total = max(b[2] for b in bboxes) # выполняем инструкцию
|
||||
min_y_total = min(b[1] for b in bboxes) # выполняем инструкцию
|
||||
max_y_total = max(b[3] for b in bboxes) # выполняем инструкцию
|
||||
|
||||
total_w = max_x_total - min_x_total # выполняем инструкцию
|
||||
total_h = max_y_total - min_y_total # выполняем инструкцию
|
||||
|
||||
# Если высота текста значительно больше ширины (коэффициент 1.2),
|
||||
# считаем, что это "столбик" -> сортируем по Y.
|
||||
# Иначе (строка или лесенка) -> сортируем по X.
|
||||
if total_h > total_w * 1.2: # выполняем инструкцию
|
||||
bboxes.sort(key=lambda b: b[1]) # Сортировка сверху вниз
|
||||
else: # выполняем инструкцию
|
||||
bboxes.sort(key=lambda b: b[0]) # Сортировка слева направо
|
||||
# ------------------------
|
||||
|
||||
tensors = [] # выполняем инструкцию
|
||||
pad = 12 # Отступ при вырезании цифры
|
||||
|
||||
for i, (x1, y1, x2, y2) in enumerate(bboxes): # выполняем инструкцию
|
||||
x1 = max(0, x1 - pad) # выполняем инструкцию
|
||||
y1 = max(0, y1 - pad) # выполняем инструкцию
|
||||
x2 = min(inv.shape[1] - 1, x2 + pad) # выполняем инструкцию
|
||||
y2 = min(inv.shape[0] - 1, y2 + pad) # выполняем инструкцию
|
||||
|
||||
digit_crop = Image.fromarray(inv[y1 : y2+1, x1 : x2+1]) # выполняем инструкцию
|
||||
|
||||
# Сохраняем для дебага только первую цифру (чтобы не спамить файлами)
|
||||
path = debug_save_path if i == 0 else None # выполняем инструкцию
|
||||
|
||||
t = _process_single_crop(digit_crop, debug_save_path=path) # выполняем инструкцию
|
||||
tensors.append(t) # выполняем инструкцию
|
||||
|
||||
return torch.stack(tensors) # выполняем инструкцию
|
||||
11
src/requirements.txt
Normal file
11
src/requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
# Основные зависимости проекта
|
||||
numpy
|
||||
pillow
|
||||
matplotlib
|
||||
|
||||
# Машинное обучение
|
||||
torch
|
||||
torchvision
|
||||
|
||||
# GUI (входит в стандартную библиотеку Python)
|
||||
# tkinter
|
||||
8
src/test_cuda.py
Normal file
8
src/test_cuda.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import torch
|
||||
|
||||
print("torch:", torch.__version__)
|
||||
print("cuda available:", torch.cuda.is_available())
|
||||
print("cuda version:", torch.version.cuda)
|
||||
print("device count:", torch.cuda.device_count())
|
||||
if torch.cuda.is_available():
|
||||
print("device name:", torch.cuda.get_device_name(0))
|
||||
141
src/train.py
Normal file
141
src/train.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
||||
|
||||
import os # импортируем модуль/пакет
|
||||
from dataclasses import dataclass # импортируем объекты из модуля
|
||||
from typing import Callable, Optional # импортируем объекты из модуля
|
||||
|
||||
import torch # импортируем модуль/пакет
|
||||
import torch.nn as nn # импортируем модуль/пакет
|
||||
from torch.utils.data import DataLoader # импортируем объекты из модуля
|
||||
from torchvision import datasets, transforms # импортируем объекты из модуля
|
||||
|
||||
from model import MNISTCNN # импортируем объекты из модуля
|
||||
from utils import setup_logger, save_metrics, save_training_plot, get_device # импортируем объекты из модуля
|
||||
|
||||
|
||||
ProgressCB = Callable[[str], None] # выполняем инструкцию
|
||||
|
||||
|
||||
@dataclass # выполняем инструкцию
|
||||
class TrainConfig: # объявляем класс приложения/модели
|
||||
epochs: int = 5 # выполняем инструкцию
|
||||
batch_size: int = 64 # выполняем инструкцию
|
||||
lr: float = 1e-3 # выполняем инструкцию
|
||||
device: str = "cpu" # выполняем инструкцию
|
||||
artifacts_dir: str = "artifacts" # выполняем инструкцию
|
||||
|
||||
|
||||
def _get_loaders(batch_size: int) -> tuple[DataLoader, DataLoader]: # объявляем функцию/метод
|
||||
# --- АГРЕССИВНАЯ АУГМЕНТАЦИЯ ---
|
||||
# Это учит модель понимать кривой почерк, наклон и повороты.
|
||||
train_tfm = transforms.Compose([ # выполняем инструкцию
|
||||
transforms.RandomAffine( # выполняем инструкцию
|
||||
degrees=20, # Поворот +/- 20 градусов
|
||||
translate=(0.15, 0.15), # Сдвиг картинки на 15%
|
||||
scale=(0.85, 1.15), # Масштаб (толстые/тонкие цифры)
|
||||
shear=15 # Наклон (курсив) до 15 градусов
|
||||
), # выполняем инструкцию
|
||||
transforms.ToTensor(), # выполняем инструкцию
|
||||
transforms.Normalize((0.1307,), (0.3081,)), # выполняем инструкцию
|
||||
]) # выполняем инструкцию
|
||||
|
||||
test_tfm = transforms.Compose([ # выполняем инструкцию
|
||||
transforms.ToTensor(), # выполняем инструкцию
|
||||
transforms.Normalize((0.1307,), (0.3081,)), # выполняем инструкцию
|
||||
]) # выполняем инструкцию
|
||||
|
||||
# Загрузка датасета (автоматически скачает в папку ./data)
|
||||
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=train_tfm) # выполняем инструкцию
|
||||
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=test_tfm) # выполняем инструкцию
|
||||
|
||||
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0) # выполняем инструкцию
|
||||
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0) # выполняем инструкцию
|
||||
return train_loader, test_loader # выполняем инструкцию
|
||||
|
||||
|
||||
@torch.no_grad() # выполняем инструкцию
|
||||
def evaluate(model: nn.Module, loader: DataLoader, device: str) -> float: # объявляем функцию/метод
|
||||
model.eval() # переводим модель в режим инференса
|
||||
correct, total = 0, 0 # выполняем инструкцию
|
||||
for x, y in loader: # выполняем инструкцию
|
||||
x, y = x.to(device), y.to(device) # выполняем инструкцию
|
||||
logits = model(x) # выполняем инструкцию
|
||||
pred = logits.argmax(dim=1) # выполняем инструкцию
|
||||
correct += int((pred == y).sum().item()) # выполняем инструкцию
|
||||
total += int(y.numel()) # выполняем инструкцию
|
||||
return correct / max(total, 1) # выполняем инструкцию
|
||||
|
||||
|
||||
def train(cfg: TrainConfig, progress_cb: Optional[ProgressCB] = None) -> dict: # объявляем функцию/метод
|
||||
os.makedirs(cfg.artifacts_dir, exist_ok=True) # создаём папку, если она отсутствует
|
||||
|
||||
log_path = os.path.join(cfg.artifacts_dir, "train.log") # выполняем инструкцию
|
||||
logger = setup_logger(log_path) # выполняем инструкцию
|
||||
|
||||
train_loader, test_loader = _get_loaders(cfg.batch_size) # выполняем инструкцию
|
||||
|
||||
model = MNISTCNN().to(cfg.device) # выполняем инструкцию
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) # выполняем инструкцию
|
||||
criterion = nn.CrossEntropyLoss() # выполняем инструкцию
|
||||
|
||||
epochs_list: list[int] = [] # выполняем инструкцию
|
||||
losses: list[float] = [] # выполняем инструкцию
|
||||
accs: list[float] = [] # выполняем инструкцию
|
||||
|
||||
def notify(msg: str) -> None: # объявляем функцию/метод
|
||||
logger.info(msg) # выполняем инструкцию
|
||||
if progress_cb: # выполняем инструкцию
|
||||
progress_cb(msg) # выполняем инструкцию
|
||||
|
||||
notify(f"Старт обучения. Device={cfg.device}. Augmentation=ON (Rotate, Shear, Scale).") # выполняем инструкцию
|
||||
|
||||
for epoch in range(1, cfg.epochs + 1): # выполняем инструкцию
|
||||
model.train() # выполняем инструкцию
|
||||
running_loss = 0.0 # выполняем инструкцию
|
||||
seen = 0 # выполняем инструкцию
|
||||
|
||||
for i, (x, y) in enumerate(train_loader, start=1): # выполняем инструкцию
|
||||
x, y = x.to(cfg.device), y.to(cfg.device) # выполняем инструкцию
|
||||
|
||||
optimizer.zero_grad() # выполняем инструкцию
|
||||
logits = model(x) # выполняем инструкцию
|
||||
loss = criterion(logits, y) # выполняем инструкцию
|
||||
loss.backward() # выполняем инструкцию
|
||||
optimizer.step() # выполняем инструкцию
|
||||
|
||||
bs = int(y.size(0)) # выполняем инструкцию
|
||||
running_loss += float(loss.item()) * bs # выполняем инструкцию
|
||||
seen += bs # выполняем инструкцию
|
||||
|
||||
if i % 300 == 0: # выполняем инструкцию
|
||||
notify(f"Epoch {epoch}/{cfg.epochs} | step={i} | loss={running_loss/max(seen,1):.4f}") # выполняем инструкцию
|
||||
|
||||
epoch_loss = running_loss / max(seen, 1) # выполняем инструкцию
|
||||
acc = evaluate(model, test_loader, cfg.device) # выполняем инструкцию
|
||||
|
||||
epochs_list.append(epoch) # выполняем инструкцию
|
||||
losses.append(float(epoch_loss)) # выполняем инструкцию
|
||||
accs.append(float(acc)) # выполняем инструкцию
|
||||
|
||||
notify(f"Epoch {epoch} finished. Test Acc: {acc*100:.2f}%") # выполняем инструкцию
|
||||
|
||||
model_path = os.path.join(cfg.artifacts_dir, "mnist_cnn.pt") # выполняем инструкцию
|
||||
torch.save(model.state_dict(), model_path) # выполняем инструкцию
|
||||
|
||||
metrics_path = os.path.join(cfg.artifacts_dir, "metrics.json") # выполняем инструкцию
|
||||
plot_path = os.path.join(cfg.artifacts_dir, "training_plot.png") # выполняем инструкцию
|
||||
|
||||
save_metrics({"epochs": epochs_list, "acc": accs, "loss": losses}, metrics_path) # выполняем инструкцию
|
||||
save_training_plot(epochs_list, losses, accs, plot_path) # выполняем инструкцию
|
||||
|
||||
notify(f"Обучение завершено. Точность: {accs[-1]*100:.2f}%") # выполняем инструкцию
|
||||
return {"accuracy": accs[-1]} # выполняем инструкцию
|
||||
|
||||
|
||||
if __name__ == "__main__": # точка входа при запуске файла как скрипта
|
||||
# Для ручного запуска
|
||||
cfg = TrainConfig( # формируем конфигурацию обучения
|
||||
epochs=5, # выполняем инструкцию
|
||||
device="cuda" if get_device() == "cuda" else "cpu", # выполняем инструкцию
|
||||
) # выполняем инструкцию
|
||||
train(cfg) # запускаем обучение модели
|
||||
51
src/utils.py
Normal file
51
src/utils.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
||||
|
||||
import json # импортируем модуль/пакет
|
||||
import logging # импортируем модуль/пакет
|
||||
import os # импортируем модуль/пакет
|
||||
from typing import Any # импортируем объекты из модуля
|
||||
import matplotlib.pyplot as plt # импортируем модуль/пакет
|
||||
|
||||
def get_device() -> str: # объявляем функцию/метод
|
||||
try: # выполняем инструкцию
|
||||
import torch # импортируем модуль/пакет
|
||||
return "cuda" if torch.cuda.is_available() else "cpu" # выполняем инструкцию
|
||||
except Exception: # выполняем инструкцию
|
||||
return "cpu" # выполняем инструкцию
|
||||
|
||||
|
||||
def setup_logger(log_path: str) -> logging.Logger: # объявляем функцию/метод
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True) # создаём папку, если она отсутствует
|
||||
logger = logging.getLogger("train_logger") # выполняем инструкцию
|
||||
logger.setLevel(logging.INFO) # выполняем инструкцию
|
||||
logger.handlers.clear() # выполняем инструкцию
|
||||
|
||||
fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") # выполняем инструкцию
|
||||
|
||||
fh = logging.FileHandler(log_path, encoding="utf-8") # выполняем инструкцию
|
||||
fh.setFormatter(fmt) # выполняем инструкцию
|
||||
logger.addHandler(fh) # выполняем инструкцию
|
||||
|
||||
sh = logging.StreamHandler() # выполняем инструкцию
|
||||
sh.setFormatter(fmt) # выполняем инструкцию
|
||||
logger.addHandler(sh) # выполняем инструкцию
|
||||
return logger # выполняем инструкцию
|
||||
|
||||
|
||||
def save_metrics(metrics: dict[str, Any], path: str) -> None: # объявляем функцию/метод
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует
|
||||
with open(path, "w", encoding="utf-8") as f: # выполняем инструкцию
|
||||
json.dump(metrics, f, ensure_ascii=False, indent=2) # выполняем инструкцию
|
||||
|
||||
|
||||
def save_training_plot(epochs: list[int], losses: list[float], accs: list[float], path: str) -> None: # объявляем функцию/метод
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует
|
||||
plt.figure() # выполняем инструкцию
|
||||
plt.plot(epochs, losses, label="Loss") # выполняем инструкцию
|
||||
plt.plot(epochs, accs, label="Accuracy") # выполняем инструкцию
|
||||
plt.xlabel("Epoch") # выполняем инструкцию
|
||||
plt.legend() # выполняем инструкцию
|
||||
plt.grid(True, alpha=0.3) # выполняем инструкцию
|
||||
plt.tight_layout() # выполняем инструкцию
|
||||
plt.savefig(path, dpi=120) # выполняем инструкцию
|
||||
plt.close() # выполняем инструкцию
|
||||
Reference in New Issue
Block a user