With the training environment set up, it is time to train on my own data. From the downloaded data, keep only the image folders for the three letters B, S and V and delete everything else; each folder holds 3000 images. Split the images randomly into a training set and a test set at a ratio of 0.85:0.15 (a sketch of the split follows). Since my project is a classification problem, I used the official example, the cats-vs-dogs classifier, as the template for my own gesture-training project.
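For reference, a small script along these lines performs the 0.85:0.15 split. The raw/ source folder and the data/gesture output layout are my assumptions, chosen to match what the gesture.py loader further below expects:

import os
import random
import shutil

random.seed(0)  # fixed seed so the split is reproducible

SRC = 'raw'             # assumed: folder holding the downloaded B/S/V image folders
DST = 'data/gesture'    # assumed: layout expected by the gesture.py loader below
TRAIN_RATIO = 0.85

for label in ('B', 'S', 'V'):
    files = sorted(os.listdir(os.path.join(SRC, label)))
    random.shuffle(files)
    n_train = int(len(files) * TRAIN_RATIO)
    for i, name in enumerate(files):
        d_type = 'train' if i < n_train else 'test'
        out_dir = os.path.join(DST, d_type, label)
        os.makedirs(out_dir, exist_ok=True)
        shutil.copy(os.path.join(SRC, label, name), os.path.join(out_dir, name))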
1. Training in ai8x-training needs three files. The network model: here I simply use ai85cdnet, the cats-vs-dogs model, as-is. The dataset script ai8x-training\datasets\gesture.py: it loads the training and test data, specifying where the data comes from, which transforms are applied, and how many output classes there are. And policies/qat_policy_cd.yaml: I never worked out exactly what this file does, so I reused the cats-vs-dogs file unchanged (see the sketch after the training command below).
python train.py --epochs 200 --optimizer Adam --lr 0.001 --wd 0 --deterministic --compress policies/schedule-gesture.yaml --qat-policy policies/qat_policy_cd.yaml --model ai85cdnet --dataset gesture --confusion --param-hist --embedding --device MAX78000 "$@"
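For orientation, the QAT policy files that ship with ai8x-training appear to control quantization-aware training, and they are short YAML files along the following lines. The values here are illustrative, not the verbatim contents of qat_policy_cd.yaml:

# QAT policy sketch (illustrative values, not the actual qat_policy_cd.yaml):
# switch the run to quantization-aware training at the given epoch and
# quantize weights to the given bit width.
start_epoch: 30
weight_bits: 8

The dataset loader, ai8x-training\datasets\gesture.py, is reproduced in full below: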
###################################################################################################
#
# Copyright (C) 2023 Analog Devices, Inc. All Rights Reserved.
# This software is proprietary to Analog Devices, Inc. and its licensors.
#
###################################################################################################
#
# Copyright (C) 2022 Maxim Integrated Products, Inc. (now owned by Analog Devices Inc.)
# All Rights Reserved.
#
# Maxim Integrated Products, Inc. Default Copyright Notice:
# https://www.maximintegrated.com/en/aboutus/legal/copyrights.html
#
###################################################################################################
"""
Scissors (V), Rock (S), Paper (B) gesture datasets
"""
import os
import sys
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations as album
import cv2
import ai8x

class Gesture(Dataset):
    """
    Rock-paper-scissors gesture dataset with the class labels B (paper),
    S (rock) and V (scissors), structured after the cats-vs-dogs example.

    Args:
        root_dir (string): Root directory of the dataset, where the
            ``gesture/train`` and ``gesture/test`` folders exist.
        d_type(string): Option for the created dataset. ``train`` or ``test``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version.
        resize_size(int, int): Width and height of the images to be resized for the dataset.
        augment_data(bool): Flag to augment the data or not. If d_type is `test`, augmentation is
            disabled.
    """
    labels = ['B', 'S', 'V']
    label_to_id_map = {k: v for v, k in enumerate(labels)}
    label_to_folder_map = {'B': 'B', 'S': 'S', 'V': 'V'}

    def __init__(self, root_dir, d_type, transform=None,
                 resize_size=(128, 128), augment_data=False):
        self.root_dir = root_dir
        self.data_dir = os.path.join(root_dir, 'gesture', d_type)

        if not self.__check_gesture_data_exist():
            self.__print_download_manual()
            sys.exit("Dataset not found!")

        self.__get_image_paths()

        self.album_transform = None
        if d_type == 'train' and augment_data:
            self.album_transform = album.Compose([
                album.GaussNoise(var_limit=(1.0, 20.0), p=0.25),
                album.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
                album.ColorJitter(p=0.5),
                album.SmallestMaxSize(max_size=int(1.2*min(resize_size))),
                album.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05,
                                       rotate_limit=15, p=0.5),
                album.RandomCrop(height=resize_size[0], width=resize_size[1]),
                album.HorizontalFlip(p=0.5),
                album.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))])
        if not augment_data or d_type == 'test':
            self.album_transform = album.Compose([
                album.SmallestMaxSize(max_size=int(1.2*min(resize_size))),
                album.CenterCrop(height=resize_size[0], width=resize_size[1]),
                album.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0))])
        self.transform = transform

    def __check_gesture_data_exist(self):
        return os.path.isdir(self.data_dir)

    def __print_download_manual(self):
        print("******************************************")
        print("Please follow the instructions below:")
        print("Copy your gesture images into the \'data\' folder, one folder "
              "per class (B, S, V), split into \'train\' and \'test\' sets.")
        print("Make sure that images are in the following directory structure:")
        print("  \'data/gesture/train/B\'")
        print("  \'data/gesture/train/S\'")
        print("  \'data/gesture/train/V\'")
        print("  \'data/gesture/test/B\'")
        print("  \'data/gesture/test/S\'")
        print("  \'data/gesture/test/V\'")
        print("Re-run the script.")
        print("******************************************")

    def __get_image_paths(self):
        self.data_list = []
        for label in self.labels:
            image_dir = os.path.join(self.data_dir, self.label_to_folder_map[label])
            for file_name in sorted(os.listdir(image_dir)):
                file_path = os.path.join(image_dir, file_name)
                if os.path.isfile(file_path):
                    self.data_list.append((file_path, self.label_to_id_map[label]))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        label = torch.tensor(self.data_list[index][1], dtype=torch.int64)
        image_path = self.data_list[index][0]
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.album_transform:
            image = self.album_transform(image=image)["image"]
        if self.transform:
            image = self.transform(image)
        return image, label


def get_gesture_dataset(data, load_train, load_test):
    (data_dir, args) = data

    transform = transforms.Compose([
        transforms.ToTensor(),
        ai8x.normalize(args=args),
    ])

    if load_train:
        train_dataset = Gesture(root_dir=data_dir, d_type='train',
                                transform=transform, augment_data=True)
    else:
        train_dataset = None

    if load_test:
        test_dataset = Gesture(root_dir=data_dir, d_type='test', transform=transform)
    else:
        test_dataset = None

    return train_dataset, test_dataset


datasets = [
    {
        'name': 'gesture',
        'input': (3, 128, 128),
        'output': ('B', 'S', 'V'),
        'loader': get_gesture_dataset,
    },
]
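
A quick sanity check of the loader before committing to a long training run (a minimal sketch, assuming the images are already in data/gesture as described above):

# Smoke test for gesture.py (illustrative). Without the torchvision
# transform, __getitem__ returns the albumentations output: a float
# HWC array of shape (128, 128, 3) plus an int64 label tensor.
ds = Gesture(root_dir='data', d_type='test')
print(len(ds))             # number of test images found
image, label = ds[0]
print(image.shape, label)  # expected: (128, 128, 3) and a label in {0, 1, 2}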

After an extremely long training run (6 hours), training finally finished! This step produces the result file qat_best.pth.tar.
2. Model conversion. Run the following command under ai8x-synthesis to generate gesture-q.pth.tar from qat_best.pth.tar.
python quantize.py ../ai8x-training/logs/2025.07.21-152827/qat_best.pth.tar ../ai8x-training/logs/2025.07.21-152827/gesture-q.pth.tar --device MAX78000 -v
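
If you want to verify the conversion, the quantized checkpoint opens like any PyTorch checkpoint (a minimal sketch; the exact set of keys depends on the ai8x tooling version):

import torch

# Load the quantized checkpoint on the CPU and list what it contains (illustrative).
ckpt = torch.load('gesture-q.pth.tar', map_location='cpu')
print(ckpt.keys())                   # typically includes 'state_dict', among others
print(list(ckpt['state_dict'])[:5])  # first few parameter names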

3. Model evaluation. Run the following command under ai8x-training; --confusion prints a confusion matrix, and -8 evaluates the checkpoint in 8-bit quantized mode.
python train.py --model ai85cdnet --dataset gesture --confusion --evaluate --exp-load-weights-from ./logs/2025.07.21-152827/gesture-q.pth.tar -8 --device MAX78000
4. Generate a test sample. Run the following command under ai8x-training. It produces a sample_gesture.npy file; copy this file into ai8x-synthesis\tests, because the next step needs it.
python train.py --model ai85cdnet --save-sample 10 --dataset gesture --evaluate --exp-load-weights-from ./logs/2025.07.21-152827/gesture-q.pth.tar -8 --device MAX78000 --data data --use-bias
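
The sample is an ordinary NumPy file, so it is easy to check what was written (a minimal sketch; it should match the (3, 128, 128) input shape declared in the dataset entry):

import numpy as np

# Inspect the generated known-answer sample (illustrative).
sample = np.load('sample_gesture.npy')
print(sample.shape, sample.dtype)  # expected shape: (3, 128, 128)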
5. Generate a project usable on the MAX78000. Run the following command under ai8x-synthesis. Note the "--fifo" parameter: the script in the catsdogs example does not have it, but when I actually ran the command, I found that without this parameter synthesis stops with a FIFO error and no project is generated. The networks/gesture-hwc.yaml file is copied from cats-dogs-hwc.yaml with just the dataset changed (a sketch of the file's shape follows after the command).
python ai8xize.py --verbose --test-dir "sdk/Examples/MAX78000/CNN" --prefix gesture --checkpoint-file ../ai8x-training/logs/2025.07.21-152827/gesture-q.pth.tar --config-file networks/gesture-hwc.yaml --fifo --device MAX78000 --compact-data --mexpress --softmax --overwrite
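
For orientation, these network description files are YAML files that name the model architecture and the dataset and then map each CNN layer onto the accelerator. A heavily abbreviated, illustrative fragment, not the verbatim cats-dogs-hwc.yaml:

# Illustrative fragment only -- not the verbatim cats-dogs-hwc.yaml.
arch: ai85cdnet
dataset: gesture          # the only line changed from the cats-vs-dogs file

layers:
  # First layer: HWC input, 3x3 convolution with ReLU
  - out_offset: 0x2000
    processors: 0x0000000000000007
    operation: conv2d
    kernel_size: 3x3
    pad: 1
    activate: ReLU
    data_format: HWC
  # ...remaining layers copied unchanged from the cats-vs-dogs description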

With that, the training side is done. Next comes programming the microcontroller.