Экологичный AI обучаем нейросеть сортировать мусор

Экологичный ИИ: обучаем нейросеть классифицировать мусор

Объединяем заботу об окружающей среде и машинное обучение. Разрабатываем систему, способную в реальном времени определять тип отходов с помощью компьютерного зрения.

Система состоит из двух основных компонентов:

  1. Обучение модели для классификации отходов
  2. Распознавания отходов в реальном времени с использованием камеры

Установка библиотек

Установим torch, torchvision, pilliow

Bash
pip install torch, torchvision, pilliow

Скачаем датасет

Скачаем изображения мусора (dataset-resized.zip) из репозитория garythung/trashnet и разархивируем.

Структура директорий

Типичная структура датасета для задачи классификации изображений выглядит следующим образом

Plaintext
path_to_your_dataset/
├── train/
│   ├── plastic/
│   ├── paper/
│   ├── metal/
│   ├── glass/
│   ├── organic/
│   └── other/
├── val/
│   ├── plastic/
│   ├── paper/
│   ├── metal/
│   ├── glass/
│   ├── organic/
│   └── other/
└── test/
    ├── plastic/
    ├── paper/
    ├── metal/
    ├── glass/
    ├── organic/
    └── other/
  1. Содержимое директорий:
    • Каждая поддиректория (plastic, paper, metal и т. д.) содержит изображения соответствующего типа отходов.
    • Изображения обычно представлены в форматах JPEG или PNG.
    • Имена файлов могут быть произвольными, но часто используют последовательную нумерацию, например: plastic_001.jpg, plastic_002.jpg и т. д.
  2. Разделение данных:
    • train/: Содержит основную часть данных (обычно 70-80% от общего количества), используемую для обучения модели.
    • val/: Валидационный набор (обычно 10-15%), используется для оценки модели во время обучения и настройки гиперпараметров.
    • test/: Тестовый набор (обычно 10-15%), используется для финальной оценки модели после завершения обучения.

Python-скрипт, который создаст нужную структуру директорий из исходного датасета и распределит изображения по соответствующим папкам. Скрипт будет использовать 80% изображений для обучающего набора (train), 10% для валидационного набора (val) и 10% для тестового набора (test).

Python
import os
import shutil
import random


def create_dataset_structure(base_path, categories):
    # Создаем основные директории
    for split in ['train', 'val', 'test']:
        for category in categories:
            os.makedirs(os.path.join(base_path, split, category), exist_ok=True)


def distribute_images(source_dir, dest_dir, categories):
    for category in categories:
        # Получаем список всех изображений в исходной директории категории
        source_category_dir = os.path.join(source_dir, category)
        images = [f for f in os.listdir(source_category_dir) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

        # Перемешиваем список изображений
        random.shuffle(images)

        # Вычисляем количество изображений для каждого набора
        total_images = len(images)
        train_count = int(total_images * 0.8)
        val_count = int(total_images * 0.1)

        # Распределяем изображения по наборам
        for i, image in enumerate(images):
            source_path = os.path.join(source_category_dir, image)
            if i < train_count:
                dest_path = os.path.join(dest_dir, 'train', category, image)
            elif i < train_count + val_count:
                dest_path = os.path.join(dest_dir, 'val', category, image)
            else:
                dest_path = os.path.join(dest_dir, 'test', category, image)

            shutil.copy2(source_path, dest_path)

        print(f"Категория {category}: распределено {total_images} изображений")


# Основные параметры
source_directory = r'/путь/к/исходным/изображениям'
destination_directory = r'/путь/к/новому/датасету'
categories = ['plastic', 'paper', 'metal', 'glass', 'cardboard', 'trash']

# Создаем структуру директорий
create_dataset_structure(destination_directory, categories)

# Распределяем изображения
distribute_images(source_directory, destination_directory, categories)

print("Датасет успешно создан!")

Обучение модели

Настройка устройства

Этот код проверяет, доступен ли GPU (CUDA), и если да, то использует его. В противном случае используется CPU.

Python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Загрузка предобученной модели

Загружается предобученная модель ResNet50 из torchvision.

Python
model = models.resnet50(pretrained=True)

Модификация модели

ResNet50, предварительно обученная на ImageNet, изначально настроена на классификацию 1000 различных классов объектов. Однако в нашей задаче классификации отходов у нас всего 6 классов. Поэтому нам нужно адаптировать модель под нашу конкретную задачу. Последний полносвязный слой модели заменяется новым с 6 выходами (по числу классов отходов).

Python
num_classes = 6
model.fc = nn.Linear(model.fc.in_features, num_classes)

Перемещение модели на выбранное устройство

Модель перемещается на GPU, если он доступен, или остается на CPU.

Python
model = model.to(device)

Определение преобразований для входных изображений

Задаются преобразования, которые будут применяться к каждому изображению перед подачей в модель.

Python
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
  1. transforms.Compose([...]):
    • Это функция, которая объединяет несколько преобразований в одну последовательность. Каждое преобразование будет применяться к изображению по порядку.
  2. transforms.Resize(256):
    • Изменяет размер входного изображения так, чтобы его меньшая сторона была равна 256 пикселей, сохраняя при этом соотношение сторон.
    • Это нужно для стандартизации размера входных данных, так как изображения могут иметь разные размеры.
  3. transforms.CenterCrop(224):
    • Вырезает центральную часть изображения размером 224×224 пикселя.
    • Это обеспечивает, что все изображения будут иметь одинаковый размер и центрирование.
  4. transforms.ToTensor():
    • Преобразует изображение из формата PIL (Python Imaging Library) или numpy.ndarray в тензор PyTorch.
    • Также нормализует значения пикселей из диапазона [0, 255] в диапазон [0.0, 1.0].
      • Почему нужна нормализация:
        • Нейронные сети обычно работают лучше с числами в меньшем диапазоне, близком к 0.
        • Использование чисел с плавающей точкой вместо целых позволяет более точно представлять градации цвета.
    • transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
      • Нормализует тензор изображения, используя заданные средние значения и стандартные отклонения для каждого канала (R, G, B).
      • Формула нормализации: output = (input - mean) / std
      • Значения mean и std здесь соответствуют статистике набора данных ImageNet, на котором была предварительно обучена модель ResNet50.
      • Нормализация помогает привести данные к диапазону, на котором обучалась модель, что улучшает сходимость и производительность.

    Почему эти преобразования важны:

    1. Стандартизация размера: Сети с фиксированной архитектурой, как ResNet50, ожидают входные данные определенного размера.
    2. Центрирование: Помогает фокусироваться на наиболее важной части изображения.
    3. Преобразование в тензор: Необходимо для работы с PyTorch.
    4. Нормализация: Приводит данные к диапазону, на котором модель была обучена, что улучшает обобщающую способность и скорость обучения.

    Подготовка данных

    Создается набор данных из папки с изображениями и загрузчик данных для эффективной подачи батчей в модель.

    Python
    train_dataset = ImageFolder(data_dir + '/train', transform=preprocess)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    Настройка функции потерь и оптимизатора

    Python
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    Обучение модели

    Основной цикл обучения. Для каждой эпохи модель проходит через все батчи данных, вычисляет потери, выполняет обратное распространение ошибки и обновляет веса.

    Python
    for epoch in range(num_epochs):
        model.train()
        for i, (inputs, labels) in enumerate(train_loader):
            ...

    Сохранение обученной модели

    После завершения обучения модель сохраняется в файл.

    Python
    torch.save(model.state_dict(), 'waste_classification_model.pth')

    Распознавание мусора в реальном времени

    После обучения модели мы создаем систему для распознавания мусора в реальном времени с использованием веб-камеры.

    Основные шаги

    1. Загрузка обученной модели.
    2. Инициализация камеры.
    3. Непрерывное считывание кадров с камеры.
    4. Предобработка каждого кадра.
    5. Получение предсказания от модели.
    6. Отображение результата на кадре.

    Загрузка модели

    Загружается предобученная модель ResNet50, модифицируется последний слой для 6 классов, загружаются веса обученной модели, и модель переводится в режим оценки.

    Python
    model = models.resnet50()
    num_classes = 6
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load('waste_classification_model.pth'))
    model = model.to(device)
    model.eval()

    Настройка предобработки изображений

    Python
    preprocess = transforms.Compose([...])

    Инициализация камеры

    Python
    cap = cv2.VideoCapture(0)

    Предобработка кадра

    Python
    input_tensor = preprocess(frame)
    input_batch = input_tensor.unsqueeze(0).to(device)
    

    Получение предсказания

    Python
    with torch.no_grad():
        output = model(input_batch)

    Определение класса с наибольшей вероятностью

    Python
    _, predicted = torch.max(output, 1)
    label = classes[predicted.item()]

    Отображение результата

    Python
    cv2.putText(frame, f"Class: {label}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    cv2.imshow('Waste Classification', frame)

    Результат работы

    Металл

    Результат распознавания металла

    Пластик

    Результат распознавания пластика

    Другой мусор (см. датасет — trash)

    Результат распознавания другого мусора

    Стекло

    Результат распознавания стекла

    Полный код

    Полный код доступен в репозитории на Гитхабе.