52 lines
3.3 KiB
Python
52 lines
3.3 KiB
Python
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
|
|
|
import json # импортируем модуль/пакет
|
|
import logging # импортируем модуль/пакет
|
|
import os # импортируем модуль/пакет
|
|
from typing import Any # импортируем объекты из модуля
|
|
import matplotlib.pyplot as plt # импортируем модуль/пакет
|
|
|
|
def get_device() -> str: # объявляем функцию/метод
|
|
try: # выполняем инструкцию
|
|
import torch # импортируем модуль/пакет
|
|
return "cuda" if torch.cuda.is_available() else "cpu" # выполняем инструкцию
|
|
except Exception: # выполняем инструкцию
|
|
return "cpu" # выполняем инструкцию
|
|
|
|
|
|
def setup_logger(log_path: str) -> logging.Logger: # объявляем функцию/метод
|
|
os.makedirs(os.path.dirname(log_path), exist_ok=True) # создаём папку, если она отсутствует
|
|
logger = logging.getLogger("train_logger") # выполняем инструкцию
|
|
logger.setLevel(logging.INFO) # выполняем инструкцию
|
|
logger.handlers.clear() # выполняем инструкцию
|
|
|
|
fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") # выполняем инструкцию
|
|
|
|
fh = logging.FileHandler(log_path, encoding="utf-8") # выполняем инструкцию
|
|
fh.setFormatter(fmt) # выполняем инструкцию
|
|
logger.addHandler(fh) # выполняем инструкцию
|
|
|
|
sh = logging.StreamHandler() # выполняем инструкцию
|
|
sh.setFormatter(fmt) # выполняем инструкцию
|
|
logger.addHandler(sh) # выполняем инструкцию
|
|
return logger # выполняем инструкцию
|
|
|
|
|
|
def save_metrics(metrics: dict[str, Any], path: str) -> None: # объявляем функцию/метод
|
|
os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует
|
|
with open(path, "w", encoding="utf-8") as f: # выполняем инструкцию
|
|
json.dump(metrics, f, ensure_ascii=False, indent=2) # выполняем инструкцию
|
|
|
|
|
|
def save_training_plot(epochs: list[int], losses: list[float], accs: list[float], path: str) -> None: # объявляем функцию/метод
|
|
os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует
|
|
plt.figure() # выполняем инструкцию
|
|
plt.plot(epochs, losses, label="Loss") # выполняем инструкцию
|
|
plt.plot(epochs, accs, label="Accuracy") # выполняем инструкцию
|
|
plt.xlabel("Epoch") # выполняем инструкцию
|
|
plt.legend() # выполняем инструкцию
|
|
plt.grid(True, alpha=0.3) # выполняем инструкцию
|
|
plt.tight_layout() # выполняем инструкцию
|
|
plt.savefig(path, dpi=120) # выполняем инструкцию
|
|
plt.close() # выполняем инструкцию
|