这些小活动你都参加了吗?快来围观一下吧!>>
电子产品世界 » 论坛首页 » 活动中心 » 板卡试用 » 【MAX78000FEATHER】板卡试用——数据训练

共4条 1/1 1 跳转至

【MAX78000FEATHER】板卡试用——数据训练

助工
2025-07-22 14:20:23     打赏

搭建好训练环境后,就需要训练自己的数据啦。在下载到的数据中将 B、S、V三个字母的图片文件夹保留,其余均删除。每个文件夹下都有3000张图片。使用0.85:0.15的比例,随机分配,分出训练集和测试集。自己的项目是一个分类问题,所以参考官方提供的例子,猫狗分类的例子,来写一个自己的手势训练的项目。

image.png1、在ai8x-training训练需要3个文件,训练模型,这里直接使用的是ai85cdnet;载入训练、测试数据的脚本文件 ai8x-training\datasets\gesture.py 这个文件指明了,训练数据和测试数据的来源、对测试数据的变换以及输出内容的数量。policies/qat_policy_cd.yaml这个文件没有搞清楚具体是做什么用的,直接复用猫狗分类模型的文件。

python train.py --epochs 200 --optimizer Adam --lr 0.001 --wd 0 --deterministic --compress policies/schedule-gesture.yaml --qat-policy policies/qat_policy_cd.yaml --model ai85cdnet --dataset gesture --confusion --param-hist --embedding --device MAX78000 "$@"


###################################################################################################
#
# Copyright (C) 2023 Analog Devices, Inc. All Rights Reserved.
# This software is proprietary to Analog Devices, Inc. and its licensors.
#
###################################################################################################
#
# Copyright (C) 2022 Maxim Integrated Products, Inc. (now owned by Analog Devices Inc.)
# All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
剪刀V 石头S 布B Datasets
"""
import os
import sys
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as album
import cv2
import ai8x
class Gesture(Dataset):
    """
    `Cats vs Dogs dataset <https://www.kaggle.com/datasets/salader/dogs-vs-cats>` Dataset.
    Args:
    root_dir (string): Root directory of dataset where ``KWS/processed/dataset.pt``
        exist.
    d_type(string): Option for the created dataset. ``train`` or ``test``.
    transform (callable, optional): A function/transform that takes in an PIL image
        and returns a transformed version.
    resize_size(int, int): Width and height of the images to be resized for the dataset.
    augment_data(bool): Flag to augment the data or not. If d_type is `test`, augmentation is
        disabled.
    """
    labels = ['B', 'S','V']
    label_to_id_map = {k: v for v, k in enumerate(labels)}
    label_to_folder_map = {'B': 'B', 'S': 'S','V':'V'}
    def __init__(self, root_dir, d_type, transform=None,
                 resize_size=(128, 128), augment_data=False):
        self.root_dir = root_dir
        self.data_dir = os.path.join(root_dir, 'gesture', d_type)
        if not self.__check_gesture_data_exist():
            self.__print_download_manual()
            sys.exit("Dataset not found!")
        self.__get_image_paths()
        self.album_transform = None
        if d_type == 'train' and augment_data:
            self.album_transform = album.Compose([
                album.GaussNoise(var_limit=(1.0, 20.0), p=0.25),
                album.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
                album.ColorJitter(p=0.5),
                album.SmallestMaxSize(max_size=int(1.2*min(resize_size))),
                album.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
                album.RandomCrop(height=resize_size[0], width=resize_size[1]),
                album.HorizontalFlip(p=0.5),
                album.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))])
        if not augment_data or d_type == 'test':
            self.album_transform = album.Compose([
                album.SmallestMaxSize(max_size=int(1.2*min(resize_size))),
                album.CenterCrop(height=resize_size[0], width=resize_size[1]),
                album.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))])
        self.transform = transform
    def __check_gesture_data_exist(self):
        return os.path.isdir(self.data_dir)
    def __print_download_manual(self):
        print("******************************************")
        print("Please follow the instructions below:")
        print("Download the dataset to the \'data\' folder by visiting this link: "
              "\'https://www.kaggle.com/datasets/salader/dogs-vs-cats\'")
        print("If you do not have a Kaggle account, sign up first.")
        print("Unzip the downloaded file and find \'test\' and \'train\' folders "
              "and copy them into \'data/cats_vs_dogs\'. ")
        print("Make sure that images are in the following directory structure:")
        print("  \'data/cats_vs_dogs/train/cats\'")
        print("  \'data/cats_vs_dogs/train/dogs\'")
        print("  \'data/cats_vs_dogs/test/cats\'")
        print("  \'data/cats_vs_dogs/test/dogs\'")
        print("Re-run the script. The script will create an \'augmented\' folder ")
        print("with all the original and augmented images. Remove this folder if you want "
              "to change the augmentation and to recreate the dataset.")
        print("******************************************")
    def __get_image_paths(self):
        self.data_list = []
        for label in self.labels:
            image_dir = os.path.join(self.data_dir, self.label_to_folder_map[label])
            for file_name in sorted(os.listdir(image_dir)):
                file_path = os.path.join(image_dir, file_name)
                if os.path.isfile(file_path):
                    self.data_list.append((file_path, self.label_to_id_map[label]))
    def __len__(self):
        return len(self.data_list)
    def __getitem__(self, index):
        label = torch.tensor(self.data_list[index][1], dtype=torch.int64)
        image_path = self.data_list[index][0]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.album_transform:
            image = self.album_transform(image=image)["image"]
        if self.transform:
            image = self.transform(image)
        return image, label
def get_gesture_dataset(data, load_train, load_test):
    (data_dir, args) = data
    transform = transforms.Compose([
        transforms.ToTensor(),
        ai8x.normalize(args=args),
    ])
    if load_train:
        train_dataset = Gesture(root_dir=data_dir, d_type='train',
                                   transform=transform, augment_data=True)
    else:
        train_dataset = None
    if load_test:
        test_dataset = Gesture(root_dir=data_dir, d_type='test', transform=transform)
    else:
        test_dataset = None
    return train_dataset, test_dataset
datasets = [
    {
        'name': 'gesture',
        'input': (3, 128, 128),
        'output': ('B', 'S','V'),
        'loader': get_gesture_dataset,
    },
]

image.pngimage.png

经过超长时间的训练(6小时),终于训练完成啦!该步骤会生成训练结果文件qat_best.pth.tar。

2、模型转换。在ai8x-synthesis 下执行以下命令,由qat_best.pth.tar生成gesture-q.pth.tar。

python quantize.py ../ai8x-training/logs/2025.07.21-152827/qat_best.pth.tar ../ai8x-training/logs/2025.07.21-152827/gesture-q.pth.tar  --device MAX78000 -v

image.png

3、模型评估。在 ai8x-training下执行以下命令。

python train.py --model ai85cdnet --dataset gesture --confusion --evaluate --exp-load-weights-from ./logs/2025.07.21-152827/gesture-q.pth.tar -8 --device MAX78000

image.png4、生成测试样本。在 ai8x-training下执行以下命令。这次命令会产生一个sample_gesture.npy文件,需要把这个文件拷贝到ai8x-synthesis\tests下,下一步操作会用到这个文件。


python train.py --model ai85cdnet --save-sample 10 --dataset gesture --evaluate --exp-load-weights-from ./logs/2025.07.21-152827/gesture-q.pth.tar -8 --device MAX78000 --data data --use-bias

image.png5、生成MAX78000可用的工程。在ai8x-synthesis 下执行以下命令。留意命令中“--fifo”参数,catsdogs例程中的脚本是没有这个参数的,但是我在实际跑的过程中,发现如果不带这个参数,会报fifo错误,无法生成工程。networks/gesture-hwc.yaml文件是照搬cats-dogs-hwc.yaml文件的,简单修改了一下dataset

python ai8xize.py --verbose --test-dir "sdk/Examples/MAX78000/CNN" --prefix gesture --checkpoint-file ../ai8x-training/logs/2025.07.21-152827/gesture-q.pth.tar --config-file networks/gesture-hwc.yaml --fifo  --device MAX78000 --compact-data --mexpress --softmax --overwrite

image.png

至此数据训练完成。接下来就是单片机上的编程了。


专家
2025-07-22 21:22:12     打赏
2楼

感谢分享


专家
2025-07-22 21:23:19     打赏
3楼

感谢分享


专家
2025-07-22 21:25:04     打赏
4楼

感谢分享


共4条 1/1 1 跳转至

回复

匿名不能发帖!请先 [ 登陆 注册 ]