Создание кода
This commit is contained in:
51
src/utils.py
Normal file
51
src/utils.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from __future__ import annotations # включаем отложенную обработку аннотаций типов
|
||||
|
||||
import json # импортируем модуль/пакет
|
||||
import logging # импортируем модуль/пакет
|
||||
import os # импортируем модуль/пакет
|
||||
from typing import Any # импортируем объекты из модуля
|
||||
import matplotlib.pyplot as plt # импортируем модуль/пакет
|
||||
|
||||
def get_device() -> str: # объявляем функцию/метод
|
||||
try: # выполняем инструкцию
|
||||
import torch # импортируем модуль/пакет
|
||||
return "cuda" if torch.cuda.is_available() else "cpu" # выполняем инструкцию
|
||||
except Exception: # выполняем инструкцию
|
||||
return "cpu" # выполняем инструкцию
|
||||
|
||||
|
||||
def setup_logger(log_path: str) -> logging.Logger: # объявляем функцию/метод
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True) # создаём папку, если она отсутствует
|
||||
logger = logging.getLogger("train_logger") # выполняем инструкцию
|
||||
logger.setLevel(logging.INFO) # выполняем инструкцию
|
||||
logger.handlers.clear() # выполняем инструкцию
|
||||
|
||||
fmt = logging.Formatter("%(asctime)s | %(levelname)s | %(message)s") # выполняем инструкцию
|
||||
|
||||
fh = logging.FileHandler(log_path, encoding="utf-8") # выполняем инструкцию
|
||||
fh.setFormatter(fmt) # выполняем инструкцию
|
||||
logger.addHandler(fh) # выполняем инструкцию
|
||||
|
||||
sh = logging.StreamHandler() # выполняем инструкцию
|
||||
sh.setFormatter(fmt) # выполняем инструкцию
|
||||
logger.addHandler(sh) # выполняем инструкцию
|
||||
return logger # выполняем инструкцию
|
||||
|
||||
|
||||
def save_metrics(metrics: dict[str, Any], path: str) -> None: # объявляем функцию/метод
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует
|
||||
with open(path, "w", encoding="utf-8") as f: # выполняем инструкцию
|
||||
json.dump(metrics, f, ensure_ascii=False, indent=2) # выполняем инструкцию
|
||||
|
||||
|
||||
def save_training_plot(epochs: list[int], losses: list[float], accs: list[float], path: str) -> None: # объявляем функцию/метод
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True) # создаём папку, если она отсутствует
|
||||
plt.figure() # выполняем инструкцию
|
||||
plt.plot(epochs, losses, label="Loss") # выполняем инструкцию
|
||||
plt.plot(epochs, accs, label="Accuracy") # выполняем инструкцию
|
||||
plt.xlabel("Epoch") # выполняем инструкцию
|
||||
plt.legend() # выполняем инструкцию
|
||||
plt.grid(True, alpha=0.3) # выполняем инструкцию
|
||||
plt.tight_layout() # выполняем инструкцию
|
||||
plt.savefig(path, dpi=120) # выполняем инструкцию
|
||||
plt.close() # выполняем инструкцию
|
||||
Reference in New Issue
Block a user