From 7c8f9db15d7a73d0372f7be926b5ffffc91e6321 Mon Sep 17 00:00:00 2001 From: Nikolaev Misha <63680387+nikolaevdev@users.noreply.github.com> Date: Mon, 22 Dec 2025 22:35:23 +0300 Subject: [PATCH] =?UTF-8?q?=D0=A1=D0=BE=D0=B7=D0=B4=D0=B0=D0=BD=D0=B8?= =?UTF-8?q?=D0=B5=20=D0=BA=D0=BE=D0=B4=D0=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 55 ++++++++++ src/app.py | 249 +++++++++++++++++++++++++++++++++++++++++++ src/artifacts/A.java | 5 + src/model.py | 31 ++++++ src/preprocess.py | 181 +++++++++++++++++++++++++++++++ src/requirements.txt | 11 ++ src/test_cuda.py | 8 ++ src/train.py | 141 ++++++++++++++++++++++++ src/utils.py | 51 +++++++++ 9 files changed, 732 insertions(+) create mode 100644 .gitignore create mode 100644 src/app.py create mode 100644 src/artifacts/A.java create mode 100644 src/model.py create mode 100644 src/preprocess.py create mode 100644 src/requirements.txt create mode 100644 src/test_cuda.py create mode 100644 src/train.py create mode 100644 src/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..429c03a --- /dev/null +++ b/.gitignore @@ -0,0 +1,55 @@ +# =============================== +# Python +# =============================== +__pycache__/ +*.py[cod] +*.pyo +*.pyd + +# =============================== +# Virtual environments +# =============================== +.venv/ +venv/ +ENV/ + +# =============================== +# IDE (PyCharm / IntelliJ) +# =============================== +.idea/ +*.iml +.gigaide/ + +# =============================== +# OS +# =============================== +.DS_Store +Thumbs.db + +# =============================== +# Project-specific (inside src/) +# =============================== +src/artifacts/ +src/data/ + +# =============================== +# Logs +# =============================== +*.log + +# =============================== +# Temporary / backup files +# =============================== +*.tmp +*.bak +*.swp + +# =============================== +# Jupyter +# =============================== +.ipynb_checkpoints/ + +# =============================== +# CUDA cache (optional) +# =============================== +.cuda/ diff --git a/src/app.py b/src/app.py new file mode 100644 index 0000000..38ed0f8 --- /dev/null +++ b/src/app.py @@ -0,0 +1,249 @@ +from __future__ import annotations # включаем отложенную обработку аннотаций типов + +import os # импортируем модуль/пакет +import threading # импортируем модуль/пакет +import queue # импортируем модуль/пакет + +import tkinter as tk # импортируем модуль/пакет +from tkinter import ttk, messagebox # импортируем объекты из модуля + +import numpy as np # импортируем модуль/пакет +from PIL import Image, ImageDraw # импортируем объекты из модуля + +import torch # импортируем модуль/пакет + +from model import MNISTCNN # импортируем объекты из модуля +from preprocess import preprocess_canvas_to_mnist_tensor # импортируем объекты из модуля +from train import TrainConfig, train # импортируем объекты из модуля +from utils import get_device # импортируем объекты из модуля + + +class App(tk.Tk): # объявляем класс приложения/модели + def __init__(self) -> None: # объявляем функцию/метод + super().__init__() # инициализируем базовый класс + self.title("MNIST Pro: Распознавание чисел (Tkinter + PyTorch)") # задаём заголовок окна + self.geometry("960x600") # задаём стартовый размер окна + self.minsize(960, 600) # задаём минимальный размер окна + + self.device = get_device() # определяем устройство (CPU/GPU) для PyTorch + self.artifacts_dir = "artifacts" # задаём папку для артефактов (модель/логи/картинки) + self.model_path = os.path.join(self.artifacts_dir, "mnist_cnn.pt") # формируем путь к файлу весов модели + + self.canvas_size = 360 # задаём размер холста для рисования + self.brush_size = tk.IntVar(value=12) # храним толщину кисти как переменную Tkinter + + # PIL-буфер: белый фон + self.buffer_img = Image.new("L", (self.canvas_size, self.canvas_size), 255) # создаём PIL-буфер холста (градации серого) + self.buffer_draw = ImageDraw.Draw(self.buffer_img) # создаём объект для рисования по PIL-буфере + + self._last_x = None # инициализируем координаты последней точки рисования + self._last_y = None # инициализируем координаты последней точки рисования + + self.msg_q: queue.Queue[str] = queue.Queue() # создаём очередь сообщений между потоками (обучение → GUI) + self.model: MNISTCNN | None = None # готовим поле для хранения модели + + self._build_ui() # собираем интерфейс + self._load_or_train() # загружаем модель или обучаем, если её нет + self.after(150, self._poll_messages) # планируем периодический опрос очереди сообщений + + def _build_ui(self) -> None: # объявляем функцию/метод + root = ttk.Frame(self, padding=12) # создаём контейнерный фрейм для компоновки UI + root.pack(fill="both", expand=True) # размещаем виджет в окне (менеджер pack) + + left = ttk.Frame(root) # создаём контейнерный фрейм для компоновки UI + left.pack(side="left", fill="y") # размещаем виджет в окне (менеджер pack) + + right = ttk.Frame(root) # создаём контейнерный фрейм для компоновки UI + right.pack(side="right", fill="both", expand=True, padx=(20, 0)) # размещаем виджет в окне (менеджер pack) + + # Холст + self.canvas = tk.Canvas( # создаём холст Tkinter для рисования мышью + left, # выполняем инструкцию + width=self.canvas_size, # выполняем инструкцию + height=self.canvas_size, # выполняем инструкцию + bg="white", # выполняем инструкцию + highlightthickness=1, # выполняем инструкцию + highlightbackground="#999", # выполняем инструкцию + cursor="crosshair", # выполняем инструкцию + ) # выполняем инструкцию + self.canvas.pack() # размещаем виджет в окне (менеджер pack) + + self.canvas.bind("", self._on_down) # подключаем обработчик событий мыши + self.canvas.bind("", self._on_move) # подключаем обработчик событий мыши + self.canvas.bind("", self._on_up) # подключаем обработчик событий мыши + + # Контролы слева + controls = ttk.Frame(left, padding=(0, 12, 0, 0)) # создаём контейнерный фрейм для компоновки UI + controls.pack(fill="x") # размещаем виджет в окне (менеджер pack) + + ttk.Label(controls, text="Толщина кисти:").pack(anchor="w") # размещаем виджет в окне (менеджер pack) + scale = ttk.Scale( # выполняем инструкцию + controls, # выполняем инструкцию + from_=6, to=30, # выполняем инструкцию + orient="horizontal", # выполняем инструкцию + command=self._on_brush_change, # выполняем инструкцию + value=self.brush_size.get(), # выполняем инструкцию + ) # выполняем инструкцию + scale.pack(fill="x", pady=(2, 10)) # размещаем виджет в окне (менеджер pack) + + btns = ttk.Frame(controls) # создаём контейнерный фрейм для компоновки UI + btns.pack(fill="x") # размещаем виджет в окне (менеджер pack) + + # Стилизация кнопок + s = ttk.Style() # инициализируем объект стилей ttk + s.configure("Big.TButton", font=("Segoe UI", 11, "bold")) # настраиваем стиль для крупных кнопок + + ttk.Button(btns, text="Распознать число", style="Big.TButton", command=self.on_recognize).pack(side="left", fill="x", expand=True) # размещаем виджет в окне (менеджер pack) + ttk.Button(btns, text="Очистить", command=self.on_clear).pack(side="left", fill="x", expand=True, padx=(8, 0)) # размещаем виджет в окне (менеджер pack) + + ttk.Button(controls, text="Переобучить (с аугментацией)", command=self.on_train_click).pack(fill="x", pady=(15, 0)) # размещаем виджет в окне (менеджер pack) + + # Правая панель + ttk.Label(right, text="Распознано:", font=("Segoe UI", 12)).pack(anchor="w") # размещаем виджет в окне (менеджер pack) + self.pred_label = ttk.Label(right, text="—", font=("Segoe UI", 80, "bold"), foreground="#007ACC") # создаём/настраиваем текстовый элемент интерфейса + self.pred_label.pack(anchor="w", pady=(0, 10)) # размещаем виджет в окне (менеджер pack) + + self.details_label = ttk.Label(right, text="", font=("Consolas", 10), justify="left") # создаём/настраиваем текстовый элемент интерфейса + self.details_label.pack(anchor="w", fill="both", expand=True) # размещаем виджет в окне (менеджер pack) + + ttk.Separator(right).pack(fill="x", pady=10) # размещаем виджет в окне (менеджер pack) + + info_text = ( # выполняем инструкцию + "Инструкция:\n" # выполняем инструкцию + "1. Напишите одну или несколько цифр (например, 12, 45).\n" # выполняем инструкцию + "2. Пишите цифры раздельно (не соединяйте их).\n" # выполняем инструкцию + "3. Нажмите 'Распознать'.\n\n" # выполняем инструкцию + "Алгоритм:\n" # выполняем инструкцию + "• Поиск связных компонентов (разделение цифр).\n" # выполняем инструкцию + "• Индивидуальная обработка каждой цифры (28x28).\n" # выполняем инструкцию + "• Прогон через CNN батчем." # выполняем инструкцию + ) # выполняем инструкцию + ttk.Label(right, text=info_text, foreground="#555").pack(anchor="w", side="bottom") # размещаем виджет в окне (менеджер pack) + + self.status = ttk.Label(self, text="Готово.", anchor="w", relief="sunken") # создаём/настраиваем текстовый элемент интерфейса + self.status.pack(side="bottom", fill="x", padx=0, pady=0) # размещаем виджет в окне (менеджер pack) + + def _on_brush_change(self, value: str) -> None: # объявляем функцию/метод + try: # выполняем инструкцию + self.brush_size.set(int(float(value))) # храним толщину кисти как переменную Tkinter + except ValueError: # выполняем инструкцию + pass # выполняем инструкцию + + def _on_down(self, event) -> None: # объявляем функцию/метод + self._last_x, self._last_y = event.x, event.y # инициализируем координаты последней точки рисования + + def _on_move(self, event) -> None: # объявляем функцию/метод + if self._last_x is None or self._last_y is None: # выполняем инструкцию + return # выполняем инструкцию + x1, y1 = self._last_x, self._last_y # выполняем инструкцию + x2, y2 = event.x, event.y # выполняем инструкцию + w = self.brush_size.get() # выполняем инструкцию + + self.canvas.create_line(x1, y1, x2, y2, width=w, fill="black", capstyle=tk.ROUND, smooth=True, splinesteps=24) # рисуем линию на видимом холсте Tkinter + self.buffer_draw.line((x1, y1, x2, y2), fill=0, width=w) # рисуем ту же линию в PIL-буфере для распознавания + self._last_x, self._last_y = x2, y2 # инициализируем координаты последней точки рисования + + def _on_up(self, event) -> None: # объявляем функцию/метод + self._last_x, self._last_y = None, None # инициализируем координаты последней точки рисования + + def on_clear(self) -> None: # объявляем функцию/метод + self.canvas.delete("all") # очищаем нарисованные элементы на холсте + self.buffer_img = Image.new("L", (self.canvas_size, self.canvas_size), 255) # создаём PIL-буфер холста (градации серого) + self.buffer_draw = ImageDraw.Draw(self.buffer_img) # создаём объект для рисования по PIL-буфере + self.pred_label.config(text="—") # создаём/настраиваем текстовый элемент интерфейса + self.details_label.config(text="") # создаём/настраиваем текстовый элемент интерфейса + self.status.config(text="Поле очищено.") # создаём/настраиваем текстовый элемент интерфейса + + def _load_model(self) -> None: # объявляем функцию/метод + model = MNISTCNN().to(self.device) # выполняем инструкцию + state = torch.load(self.model_path, map_location=self.device, weights_only=True) # загружаем веса модели из файла + model.load_state_dict(state) # применяем загруженные веса к модели + model.eval() # переводим модель в режим инференса + self.model = model # готовим поле для хранения модели + + def _load_or_train(self) -> None: # объявляем функцию/метод + os.makedirs(self.artifacts_dir, exist_ok=True) # создаём папку, если она отсутствует + if os.path.exists(self.model_path): # проверяем наличие файла модели + try: # выполняем инструкцию + self._load_model() # выполняем инструкцию + self.status.config(text="Модель загружена.") # создаём/настраиваем текстовый элемент интерфейса + return # выполняем инструкцию + except Exception: # выполняем инструкцию + pass # выполняем инструкцию + self.on_train_click(initial=True) # выполняем инструкцию + + def _poll_messages(self) -> None: # объявляем функцию/метод + try: # выполняем инструкцию + while True: # выполняем инструкцию + msg = self.msg_q.get_nowait() # выполняем инструкцию + self.status.config(text=msg) # создаём/настраиваем текстовый элемент интерфейса + except queue.Empty: # выполняем инструкцию + pass # выполняем инструкцию + self.after(150, self._poll_messages) # планируем периодический опрос очереди сообщений + + def on_train_click(self, initial: bool = False) -> None: # объявляем функцию/метод + self.status.config(text="Обучение..." if initial else "Переобучение...") # создаём/настраиваем текстовый элемент интерфейса + + def worker(): # объявляем функцию/метод + try: # выполняем инструкцию + cfg = TrainConfig(epochs=5, batch_size=64, lr=1e-3, device=self.device, artifacts_dir=self.artifacts_dir) # формируем конфигурацию обучения + def progress(msg: str): self.msg_q.put(msg) # объявляем функцию/метод + train(cfg, progress_cb=progress) # запускаем обучение модели + self._load_model() # выполняем инструкцию + self.msg_q.put("Обучение завершено.") # выполняем инструкцию + except Exception as e: # выполняем инструкцию + self.msg_q.put(f"Ошибка: {e}") # выполняем инструкцию + + threading.Thread(target=worker, daemon=True).start() # запускаем отдельный поток, чтобы GUI не зависал + + def on_recognize(self) -> None: # объявляем функцию/метод + if self.model is None: # выполняем инструкцию + messagebox.showinfo("Wait", "Модель загружается...") # показываем пользователю диалоговое окно + return # выполняем инструкцию + + try: # выполняем инструкцию + os.makedirs(self.artifacts_dir, exist_ok=True) # создаём папку, если она отсутствует + debug_path = os.path.join(self.artifacts_dir, "last_debug.png") # выполняем инструкцию + + # 1. Получаем батч тензоров (N, 1, 28, 28) + batch_tensors = preprocess_canvas_to_mnist_tensor( # выполняем инструкцию + self.buffer_img, # выполняем инструкцию + debug_save_path=debug_path, # выполняем инструкцию + ).to(self.device) # выполняем инструкцию + + # 2. Прогоняем сразу все цифры через модель + with torch.no_grad(): # отключаем градиенты для ускорения инференса + logits = self.model(batch_tensors) # (N, 10) + probs = torch.softmax(logits, dim=1) # (N, 10) + + # Получаем классы и уверенность + max_probs, preds = torch.max(probs, dim=1) # берём наиболее вероятный класс и его вероятность + + preds = preds.cpu().numpy() # переносим результаты на CPU и конвертируем в NumPy + max_probs = max_probs.cpu().numpy() # переносим результаты на CPU и конвертируем в NumPy + + # 3. Формируем итоговое число и отчет + result_str = "".join(map(str, preds)) # склеиваем предсказанные цифры в итоговую строку + self.pred_label.config(text=result_str) # создаём/настраиваем текстовый элемент интерфейса + + # Детальный отчет по каждой цифре + details = [] # выполняем инструкцию + for i, (p, conf) in enumerate(zip(preds, max_probs)): # выполняем инструкцию + details.append(f"Цифра #{i+1}: '{p}' (увер. {conf*100:.1f}%)") # выполняем инструкцию + + self.details_label.config(text="\n".join(details)) # создаём/настраиваем текстовый элемент интерфейса + self.status.config(text=f"Распознано чисел: {len(preds)}. Debug: {debug_path}") # создаём/настраиваем текстовый элемент интерфейса + + except ValueError: # выполняем инструкцию + self.status.config(text="Пустой холст.") # создаём/настраиваем текстовый элемент интерфейса + except Exception as e: # выполняем инструкцию + messagebox.showerror("Ошибка", str(e)) # показываем пользователю диалоговое окно + self.status.config(text="Ошибка.") # создаём/настраиваем текстовый элемент интерфейса + + +def main() -> None: # объявляем функцию/метод + app = App() # создаём экземпляр приложения + app.mainloop() # запускаем главный цикл Tkinter + +if __name__ == "__main__": # точка входа при запуске файла как скрипта + main() # выполняем инструкцию diff --git a/src/artifacts/A.java b/src/artifacts/A.java new file mode 100644 index 0000000..7eb47f3 --- /dev/null +++ b/src/artifacts/A.java @@ -0,0 +1,5 @@ +numpy +pillow +matplotlib +torch +torchvision \ No newline at end of file diff --git a/src/model.py b/src/model.py new file mode 100644 index 0000000..6e3c85d --- /dev/null +++ b/src/model.py @@ -0,0 +1,31 @@ +from __future__ import annotations # включаем отложенную обработку аннотаций типов + +import torch # импортируем модуль/пакет +import torch.nn as nn # импортируем модуль/пакет +import torch.nn.functional as F # импортируем модуль/пакет + + +class MNISTCNN(nn.Module): # объявляем класс приложения/модели + """ + Классическая CNN: 2 сверточных слоя, pooling, dropout и 2 полносвязных. + """ + def __init__(self) -> None: # объявляем функцию/метод + super().__init__() # инициализируем базовый класс + self.conv1 = nn.Conv2d(1, 32, 3, 1) # выполняем инструкцию + self.conv2 = nn.Conv2d(32, 64, 3, 1) # выполняем инструкцию + self.dropout1 = nn.Dropout(0.25) # выполняем инструкцию + self.dropout2 = nn.Dropout(0.50) # выполняем инструкцию + # После Conv1 (26x26), Conv2 (24x24) и MaxPool (12x12): + self.fc1 = nn.Linear(64 * 12 * 12, 128) # выполняем инструкцию + self.fc2 = nn.Linear(128, 10) # выполняем инструкцию + + def forward(self, x: torch.Tensor) -> torch.Tensor: # объявляем функцию/метод + x = F.relu(self.conv1(x)) # выполняем инструкцию + x = F.relu(self.conv2(x)) # выполняем инструкцию + x = F.max_pool2d(x, 2) # выполняем инструкцию + x = self.dropout1(x) # выполняем инструкцию + x = torch.flatten(x, 1) # выполняем инструкцию + x = F.relu(self.fc1(x)) # выполняем инструкцию + x = self.dropout2(x) # выполняем инструкцию + x = self.fc2(x) # выполняем инструкцию + return x # выполняем инструкцию diff --git a/src/preprocess.py b/src/preprocess.py new file mode 100644 index 0000000..e411066 --- /dev/null +++ b/src/preprocess.py @@ -0,0 +1,181 @@ +from __future__ import annotations # включаем отложенную обработку аннотаций типов + +import numpy as np # импортируем модуль/пакет +import torch # импортируем модуль/пакет +from PIL import Image, ImageFilter # импортируем объекты из модуля +import collections # импортируем модуль/пакет + +MNIST_MEAN = 0.1307 # выполняем инструкцию +MNIST_STD = 0.3081 # выполняем инструкцию + + +def _shift_image(img: Image.Image, dx: int, dy: int) -> Image.Image: # объявляем функцию/метод + """Сдвигает содержимое картинки на dx, dy, заполняя пустоты черным.""" + w, h = img.size # выполняем инструкцию + shifted = Image.new("L", (w, h), 0) # выполняем инструкцию + x_to, y_to = max(0, dx), max(0, dy) # выполняем инструкцию + x_from, y_from = max(0, -dx), max(0, -dy) # выполняем инструкцию + width, height = w - abs(dx), h - abs(dy) # выполняем инструкцию + + if width <= 0 or height <= 0: # выполняем инструкцию + return shifted # выполняем инструкцию + + crop = img.crop((x_from, y_from, x_from + width, y_from + height)) # выполняем инструкцию + shifted.paste(crop, (x_to, y_to)) # выполняем инструкцию + return shifted # выполняем инструкцию + + +def _center_of_mass_shift(img28: Image.Image) -> Image.Image: # объявляем функцию/метод + """Центрирует изображение 28x28 по центру масс (как в датасете MNIST).""" + arr = np.array(img28, dtype=np.float32) # выполняем инструкцию + total = float(arr.sum()) # выполняем инструкцию + if total <= 1e-6: # выполняем инструкцию + return img28 # выполняем инструкцию + + ys, xs = np.indices(arr.shape) # выполняем инструкцию + cy = float((ys * arr).sum() / total) # выполняем инструкцию + cx = float((xs * arr).sum() / total) # выполняем инструкцию + + # Целевой центр для 28x28 — (14, 14) + dx = int(round(14.0 - cx)) # выполняем инструкцию + dy = int(round(14.0 - cy)) # выполняем инструкцию + return _shift_image(img28, dx, dy) # выполняем инструкцию + + +def _find_connected_components(mask: np.ndarray) -> list[tuple[int, int, int, int]]: # объявляем функцию/метод + """ + Находит bbox (x1, y1, x2, y2) для каждого связного объекта на маске. + Использует BFS (поиск в ширину). + """ + h, w = mask.shape # выполняем инструкцию + visited = np.zeros_like(mask, dtype=bool) # выполняем инструкцию + bboxes = [] # выполняем инструкцию + + for y in range(h): # выполняем инструкцию + for x in range(w): # выполняем инструкцию + if mask[y, x] and not visited[y, x]: # выполняем инструкцию + q = collections.deque([(x, y)]) # выполняем инструкцию + visited[y, x] = True # выполняем инструкцию + min_x, max_x = x, x # выполняем инструкцию + min_y, max_y = y, y # выполняем инструкцию + count = 0 # выполняем инструкцию + + while q: # выполняем инструкцию + cx, cy = q.popleft() # выполняем инструкцию + count += 1 # выполняем инструкцию + + min_x = min(min_x, cx) # выполняем инструкцию + max_x = max(max_x, cx) # выполняем инструкцию + min_y = min(min_y, cy) # выполняем инструкцию + max_y = max(max_y, cy) # выполняем инструкцию + + # Соседи (4 стороны) + for nx, ny in [(cx+1, cy), (cx-1, cy), (cx, cy+1), (cx, cy-1)]: # выполняем инструкцию + if 0 <= nx < w and 0 <= ny < h: # выполняем инструкцию + if mask[ny, nx] and not visited[ny, nx]: # выполняем инструкцию + visited[ny, nx] = True # выполняем инструкцию + q.append((nx, ny)) # выполняем инструкцию + + # Фильтр совсем мелкого шума (точки) + if count > 10: # выполняем инструкцию + bboxes.append((min_x, min_y, max_x, max_y)) # выполняем инструкцию + + return bboxes # выполняем инструкцию + + +def _process_single_crop( # объявляем функцию/метод + crop_img: Image.Image, # выполняем инструкцию + debug_save_path: str | None = None # выполняем инструкцию +) -> torch.Tensor: # выполняем инструкцию + """Подготовка одного кусочка с цифрой для нейросети.""" + + # 1. Легкий блюр для сглаживания пикселизации + crop_img = crop_img.filter(ImageFilter.GaussianBlur(radius=0.6)) # выполняем инструкцию + + w, h = crop_img.size # выполняем инструкцию + + # 2. Ресайз: вписываем в квадрат 20x20, сохраняя пропорции + max_side = max(w, h) # выполняем инструкцию + scale = 20.0 / float(max_side) # выполняем инструкцию + new_w = max(1, int(round(w * scale))) # выполняем инструкцию + new_h = max(1, int(round(h * scale))) # выполняем инструкцию + resized = crop_img.resize((new_w, new_h), resample=Image.BILINEAR) # выполняем инструкцию + + # 3. Вставляем в центр черного квадрата 28x28 + img28 = Image.new("L", (28, 28), 0) # выполняем инструкцию + left = (28 - new_w) // 2 # выполняем инструкцию + top = (28 - new_h) // 2 # выполняем инструкцию + img28.paste(resized, (left, top)) # выполняем инструкцию + + # 4. Центрирование по массе (Критически важно!) + img28 = _center_of_mass_shift(img28) # выполняем инструкцию + + if debug_save_path: # выполняем инструкцию + img28.save(debug_save_path) # выполняем инструкцию + + # 5. Нормализация + x = np.array(img28, dtype=np.float32) / 255.0 # выполняем инструкцию + x = (x - MNIST_MEAN) / MNIST_STD # выполняем инструкцию + return torch.from_numpy(x).unsqueeze(0) # (1, 28, 28) + + +def preprocess_canvas_to_mnist_tensor( # объявляем функцию/метод + img_l: Image.Image, # выполняем инструкцию + *, # выполняем инструкцию + debug_save_path: str | None = None, # выполняем инструкцию +) -> torch.Tensor: # выполняем инструкцию + """ + Основная функция: Холст -> Батч тензоров. + """ + if img_l.mode != "L": # выполняем инструкцию + img_l = img_l.convert("L") # выполняем инструкцию + + arr = np.array(img_l, dtype=np.uint8) # выполняем инструкцию + # Инверсия: (белый фон -> черный фон) + inv = 255 - arr # выполняем инструкцию + + # Бинаризация для поиска объектов + mask = inv > 20 # выполняем инструкцию + + bboxes = _find_connected_components(mask) # выполняем инструкцию + + if not bboxes: # выполняем инструкцию + raise ValueError("Холст пуст.") # выполняем инструкцию + + # --- УМНАЯ СОРТИРОВКА --- + # Определяем габариты всего написанного текста + min_x_total = min(b[0] for b in bboxes) # выполняем инструкцию + max_x_total = max(b[2] for b in bboxes) # выполняем инструкцию + min_y_total = min(b[1] for b in bboxes) # выполняем инструкцию + max_y_total = max(b[3] for b in bboxes) # выполняем инструкцию + + total_w = max_x_total - min_x_total # выполняем инструкцию + total_h = max_y_total - min_y_total # выполняем инструкцию + + # Если высота текста значительно больше ширины (коэффициент 1.2), + # считаем, что это "столбик" -> сортируем по Y. + # Иначе (строка или лесенка) -> сортируем по X. + if total_h > total_w * 1.2: # выполняем инструкцию + bboxes.sort(key=lambda b: b[1]) # Сортировка сверху вниз + else: # выполняем инструкцию + bboxes.sort(key=lambda b: b[0]) # Сортировка слева направо + # ------------------------ + + tensors = [] # выполняем инструкцию + pad = 12 # Отступ при вырезании цифры + + for i, (x1, y1, x2, y2) in enumerate(bboxes): # выполняем инструкцию + x1 = max(0, x1 - pad) # выполняем инструкцию + y1 = max(0, y1 - pad) # выполняем инструкцию + x2 = min(inv.shape[1] - 1, x2 + pad) # выполняем инструкцию + y2 = min(inv.shape[0] - 1, y2 + pad) # выполняем инструкцию + + digit_crop = Image.fromarray(inv[y1 : y2+1, x1 : x2+1]) # выполняем инструкцию + + # Сохраняем для дебага только первую цифру (чтобы не спамить файлами) + path = debug_save_path if i == 0 else None # выполняем инструкцию + + t = _process_single_crop(digit_crop, debug_save_path=path) # выполняем инструкцию + tensors.append(t) # выполняем инструкцию + + return torch.stack(tensors) # выполняем инструкцию diff --git a/src/requirements.txt b/src/requirements.txt new file mode 100644 index 0000000..07ac397 --- /dev/null +++ b/src/requirements.txt @@ -0,0 +1,11 @@ +# Основные зависимости проекта +numpy +pillow +matplotlib + +# Машинное обучение +torch +torchvision + +# GUI (входит в стандартную библиотеку Python) +# tkinter \ No newline at end of file diff --git a/src/test_cuda.py b/src/test_cuda.py new file mode 100644 index 0000000..0d85c43 --- /dev/null +++ b/src/test_cuda.py @@ -0,0 +1,8 @@ +import torch + +print("torch:", torch.__version__) +print("cuda available:", torch.cuda.is_available()) +print("cuda version:", torch.version.cuda) +print("device count:", torch.cuda.device_count()) +if torch.cuda.is_available(): + print("device name:", torch.cuda.get_device_name(0)) \ No newline at end of file diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..1558c32 --- /dev/null +++ b/src/train.py @@ -0,0 +1,141 @@ +from __future__ import annotations # включаем отложенную обработку аннотаций типов + +import os # импортируем модуль/пакет +from dataclasses import dataclass # импортируем объекты из модуля +from typing import Callable, Optional # импортируем объекты из модуля + +import torch # импортируем модуль/пакет +import torch.nn as nn # импортируем модуль/пакет +from torch.utils.data import DataLoader # импортируем объекты из модуля +from torchvision import datasets, transforms # импортируем объекты из модуля + +from model import MNISTCNN # импортируем объекты из модуля +from utils import setup_logger, save_metrics, save_training_plot, get_device # импортируем объекты из модуля + + +ProgressCB = Callable[[str], None] # выполняем инструкцию + + +@dataclass # выполняем инструкцию +class TrainConfig: # объявляем класс приложения/модели + epochs: int = 5 # выполняем инструкцию + batch_size: int = 64 # выполняем инструкцию + lr: float = 1e-3 # выполняем инструкцию + device: str = "cpu" # выполняем инструкцию + artifacts_dir: str = "artifacts" # выполняем инструкцию + + +def _get_loaders(batch_size: int) -> tuple[DataLoader, DataLoader]: # объявляем функцию/метод + # --- АГРЕССИВНАЯ АУГМЕНТАЦИЯ --- + # Это учит модель понимать кривой почерк, наклон и повороты. + train_tfm = transforms.Compose([ # выполняем инструкцию + transforms.RandomAffine( # выполняем инструкцию + degrees=20, # Поворот +/- 20 градусов + translate=(0.15, 0.15), # Сдвиг картинки на 15% + scale=(0.85, 1.15), # Масштаб (толстые/тонкие цифры) + shear=15 # Наклон (курсив) до 15 градусов + ), # выполняем инструкцию + transforms.ToTensor(), # выполняем инструкцию + transforms.Normalize((0.1307,), (0.3081,)), # выполняем инструкцию + ]) # выполняем инструкцию + + test_tfm = transforms.Compose([ # выполняем инструкцию + transforms.ToTensor(), # выполняем инструкцию + transforms.Normalize((0.1307,), (0.3081,)), # выполняем инструкцию + ]) # выполняем инструкцию + + # Загрузка датасета (автоматически скачает в папку ./data) + train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=train_tfm) # выполняем инструкцию + test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=test_tfm) # выполняем инструкцию + + train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0) # выполняем инструкцию + test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0) # выполняем инструкцию + return train_loader, test_loader # выполняем инструкцию + + +@torch.no_grad() # выполняем инструкцию +def evaluate(model: nn.Module, loader: DataLoader, device: str) -> float: # объявляем функцию/метод + model.eval() # переводим модель в режим инференса + correct, total = 0, 0 # выполняем инструкцию + for x, y in loader: # выполняем инструкцию + x, y = x.to(device), y.to(device) # выполняем инструкцию + logits = model(x) # выполняем инструкцию + pred = logits.argmax(dim=1) # выполняем инструкцию + correct += int((pred == y).sum().item()) # выполняем инструкцию + total += int(y.numel()) # выполняем инструкцию + return correct / max(total, 1) # выполняем инструкцию + + +def train(cfg: TrainConfig, progress_cb: Optional[ProgressCB] = None) -> dict: # объявляем функцию/метод + os.makedirs(cfg.artifacts_dir, exist_ok=True) # создаём папку, если она отсутствует + + log_path = os.path.join(cfg.artifacts_dir, "train.log") # выполняем инструкцию + logger = setup_logger(log_path) # выполняем инструкцию + + train_loader, test_loader = _get_loaders(cfg.batch_size) # выполняем инструкцию + + model = MNISTCNN().to(cfg.device) # выполняем инструкцию + optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr) # выполняем инструкцию + criterion = nn.CrossEntropyLoss() # выполняем инструкцию + + epochs_list: list[int] = [] # выполняем инструкцию + losses: list[float] = [] # выполняем инструкцию + accs: list[float] = [] # выполняем инструкцию + + def notify(msg: str) -> None: # объявляем функцию/метод + logger.info(msg) # выполняем инструкцию + if progress_cb: # выполняем инструкцию + progress_cb(msg) # выполняем инструкцию + + notify(f"Старт обучения. Device={cfg.device}. Augmentation=ON (Rotate, Shear, Scale).") # выполняем инструкцию + + for epoch in range(1, cfg.epochs + 1): # выполняем инструкцию + model.train() # выполняем инструкцию + running_loss = 0.0 # выполняем инструкцию + seen = 0 # выполняем инструкцию + + for i, (x, y) in enumerate(train_loader, start=1): # выполняем инструкцию + x, y = x.to(cfg.device), y.to(cfg.device) # выполняем инструкцию + + optimizer.zero_grad() # выполняем инструкцию + logits = model(x) # выполняем инструкцию + loss = criterion(logits, y) # выполняем инструкцию + loss.backward() # выполняем инструкцию + optimizer.step() # выполняем инструкцию + + bs = int(y.size(0)) # выполняем инструкцию + running_loss += float(loss.item()) * bs # выполняем инструкцию + seen += bs # выполняем инструкцию + + if i % 300 == 0: # выполняем инструкцию + notify(f"Epoch {epoch}/{cfg.epochs} | step={i} | loss={running_loss/max(seen,1):.4f}") # выполняем инструкцию + + epoch_loss = running_loss / max(seen, 1) # выполняем инструкцию + acc = evaluate(model, test_loader, cfg.device) # выполняем инструкцию + + epochs_list.append(epoch) # выполняем инструкцию + losses.append(float(epoch_loss)) # выполняем инструкцию + accs.append(float(acc)) # выполняем инструкцию + + notify(f"Epoch {epoch} finished. Test Acc: {acc*100:.2f}%") # выполняем инструкцию + + model_path = os.path.join(cfg.artifacts_dir, "mnist_cnn.pt") # выполняем инструкцию + torch.save(model.state_dict(), model_path) # выполняем инструкцию + + metrics_path = os.path.join(cfg.artifacts_dir, "metrics.json") # выполняем инструкцию + plot_path = os.path.join(cfg.artifacts_dir, "training_plot.png") # выполняем инструкцию + + save_metrics({"epochs": epochs_list, "acc": accs, "loss": losses}, metrics_path) # выполняем инструкцию + save_training_plot(epochs_list, losses, accs, plot_path) # выполняем инструкцию + + notify(f"Обучение завершено. Точность: {accs[-1]*100:.2f}%") # выполняем инструкцию + return {"accuracy": accs[-1]} # выполняем инструкцию + + +if __name__ == "__main__": # точка входа при запуске файла как скрипта + # Для ручного запуска + cfg = TrainConfig( # формируем конфигурацию обучения + epochs=5, # выполняем инструкцию + device="cuda" if get_device() == "cuda" else "cpu", # выполняем инструкцию + ) # выполняем инструкцию + train(cfg) # запускаем обучение модели diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..3e3e0d0 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,51 @@ +from __future__ import annotations # включаем отложенную обработку аннотаций типов + +import json # импортируем модуль/пакет +import logging # импортируем модуль/пакет +import os # импортируем модуль/пакет +from typing import Any # импортируем объекты из модуля +import matplotlib.pyplot as plt # импортируем модуль/пакет + +def get_device() -> str: # объявляем функцию/метод + try: # выполняем инструкцию + import torch # импортируем модуль/пакет + return "cuda" if torch.cuda.is_available() else "cpu" # выполняем инструкцию + except Exception: # выполняем инструкцию + return "cpu" # выполняем инструкцию + + +def setup_logger(log_path: str) -> logging.Logger: # объявляем функцию/метод + os.makedirs(os.path.dirname(log_path), exist_ok=True) # создаём папку, если она отсутствует + logger = logging.getLogger("train_logger") # выполняем инструкцию + logger.setLevel(logging.INFO) # выполняем инструкцию + logger.handlers.clear() # выполняем инструкцию + + fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") # выполняем инструкцию + + fh = logging.FileHandler(log_path, encoding="utf-8") # выполняем инструкцию + fh.setFormatter(fmt) # выполняем инструкцию + logger.addHandler(fh) # выполняем инструкцию + + sh = logging.StreamHandler() # выполняем инструкцию + sh.setFormatter(fmt) # выполняем инструкцию + logger.addHandler(sh) # выполняем инструкцию + return logger # выполняем инструкцию + + +def save_metrics(metrics: dict[str, Any], path: str) -> None: # объявляем функцию/метод + os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует + with open(path, "w", encoding="utf-8") as f: # выполняем инструкцию + json.dump(metrics, f, ensure_ascii=False, indent=2) # выполняем инструкцию + + +def save_training_plot(epochs: list[int], losses: list[float], accs: list[float], path: str) -> None: # объявляем функцию/метод + os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует + plt.figure() # выполняем инструкцию + plt.plot(epochs, losses, label="Loss") # выполняем инструкцию + plt.plot(epochs, accs, label="Accuracy") # выполняем инструкцию + plt.xlabel("Epoch") # выполняем инструкцию + plt.legend() # выполняем инструкцию + plt.grid(True, alpha=0.3) # выполняем инструкцию + plt.tight_layout() # выполняем инструкцию + plt.savefig(path, dpi=120) # выполняем инструкцию + plt.close() # выполняем инструкцию