Hi~ 我又回来了。在之前的帖子中,我已经实现了PC端和STM32端的蓝牙数据交互,所以后面可以做的东西就很多了。上个帖子说过,这次会做一个QT + STM32 的具体应用,我大概看了一下,好像没有人用stm32做识别的,于是乎我就做一下吧,弥补一下这个空白。
具体来说,我就做一个基于MNIST数据集的手写数字识别吧。
视频演示:【STM32 手写数字识别 点阵显示 蓝牙传输数据】
一、系统设计
用QT设计一个GUI界面,可以写字,写完数字,通过蓝牙把写的28x28的灰度数字发给STM32进行识别,识别之后通过屏幕之类的显示。系统框图如下:
二、QT 上位机设计
上位机蓝牙通信部分还是延续前两次的代码,但是需要做一个画图的界面,所以就需要加一个GUI专门画图像。
上位机界面如下:
具体代码上,在之前代码的基础上新添加一个drawingwidget.cpp ,默认画笔大小是20,这部分代码也是参考了别人的代码加以修改过来的。其中发送图片像素的函数给图片做了一个缩放,最后会变成28x28的图像,以适应mnist数据集。
DrawingWidget::DrawingWidget(QWidget *parent) : QWidget(parent), penColor(Qt::darkGray), drawingArea(50, 50, 280, 280) { PenWidth=20; } DrawingWidget::~DrawingWidget() { } void DrawingWidget::setPenColor(void) { QColor color = QColorDialog::getColor(Qt::white, this, "选择颜色"); if (color.isValid()) { penColor = color; } } void DrawingWidget::clearDrawing() { points.clear(); update(); } QImage DrawingWidget::getDrawingBits(int scaleFactor){ QPixmap pixmap(drawingArea.size()); pixmap.fill(Qt::black); // 在 QPixmap 上绘制内容 QPainter painter(&pixmap); QPen pen(penColor); pen.setWidth(PenWidth); painter.setPen(pen); // 平移 painter 以匹配绘图区域的相对坐标 painter.translate(-drawingArea.topLeft()); // 绘制线条 painter.drawPolyline(QPolygon(point)); for (auto p : points) { painter.drawPolyline(QPolygon(p)); } // 缩放 QPixmap scaledPixmap = pixmap.scaled(pixmap.size() / scaleFactor, Qt::KeepAspectRatio, Qt::SmoothTransformation); grayImage = scaledPixmap.toImage().convertToFormat(QImage::Format_Grayscale8); return grayImage; } void DrawingWidget::saveDrawing(const QString &filePath, int scaleFactor) { // 创建一个与绘图区域相同大小的 QPixmap QPixmap pixmap(drawingArea.size()); pixmap.fill(Qt::black); // 在 QPixmap 上绘制内容 QPainter painter(&pixmap); QPen pen(penColor); pen.setWidth(PenWidth); painter.setPen(pen); // 平移 painter 以匹配绘图区域的相对坐标 painter.translate(-drawingArea.topLeft()); // 绘制线条 painter.drawPolyline(QPolygon(point)); for (auto p : points) { painter.drawPolyline(QPolygon(p)); } // 缩放 QPixmap scaledPixmap = pixmap.scaled(pixmap.size() / scaleFactor, Qt::KeepAspectRatio, Qt::SmoothTransformation); QImage grayImage = scaledPixmap.toImage().convertToFormat(QImage::Format_Grayscale8); //const uchar *data = grayImage.bits(); //only pix data //int size_T=grayImage.height()*grayImage.width(); // 保存图片到文件 if (!filePath.isEmpty()) { grayImage.save(filePath, "BMP"); } } bool DrawingWidget::event(QEvent *event) { QMouseEvent *e = static_cast<QMouseEvent *>(event); if (event->type() == QEvent::MouseMove) { if 
(drawingArea.contains(e->pos())) { point.append(e->pos()); this->update(); } } else if (event->type() == QEvent::MouseButtonRelease) { points.append(point); point.clear(); } else if (event->type() == QEvent::Paint) { QPainter painter(this); // 绘制绘图区域边框 QPen borderPen(Qt::DashLine); borderPen.setColor(Qt::gray); painter.setPen(borderPen); painter.drawRect(drawingArea); // 绘制用户的线条 QPen pen(penColor); pen.setWidth(PenWidth); painter.setPen(pen); painter.drawPolyline(QPolygon(point)); for (auto p : points) { painter.drawPolyline(QPolygon(p)); } } return QWidget::event(event); }
使用这个类的话直接 new 就行
drawingWidget_t = new DrawingWidget(this); ui->verticalLayout_2->addWidget(drawingWidget_t);
图片数据发送部分,我选择了分成两个 characteristic 发送,因为图片太大了(28×28 = 784 bytes),而一次最多只能发 500 bytes。
void Widget::on_pushButton_Send_clicked() { QImage imageBuf=drawingWidget_t->getDrawingBits(10); QByteArray byteArray1(reinterpret_cast<const char*>(imageBuf.bits()), 500); QByteArray byteArray2(reinterpret_cast<const char*>(imageBuf.bits()+500), 284); gatt_comm_t->Characteristic_Write_Trigger(0,byteArray1); gatt_comm_t->Characteristic_Write_Trigger(1,byteArray2); }
三、STM32 程序设计
stm32 程序分为:接收触发识别、识别、显示,三个部分组成。
3.1 首先是接收,接收之后存到数组里。接收完成触发识别flag
/* Functions Definition ------------------------------------------------------*/ void P2P_SERVER_Notification(P2P_SERVER_NotificationEvt_t *p_Notification) { /* USER CODE BEGIN Service1_Notification_1 */ uint8_t *dataBuf=p_Notification->DataTransfered.p_Payload; uint8_t dataLen=p_Notification->DataTransfered.Length; LOG_INFO_APP("\r\nData len:%d\r\n",dataLen); /* USER CODE END Service1_Notification_1 */ switch(p_Notification->EvtOpcode) { /* USER CODE BEGIN Service1_Notification_Service1_EvtOpcode */ /* USER CODE END Service1_Notification_Service1_EvtOpcode */ case P2P_SERVER_TX_READ_EVT: /* USER CODE BEGIN Service1Char1_READ_EVT */ LOG_INFO_APP("P2P_SERVER_TX_READ_EVT\r\n"); /* USER CODE END Service1Char1_READ_EVT */ break; case P2P_SERVER_TX_WRITE_NO_RESP_EVT: /* USER CODE BEGIN Service1Char1_WRITE_NO_RESP_EVT */ //LOG_INFO_APP("P2P_SERVER_TX_WRITE_NO_RESP_EVT:%s\r\n",dataBuf); memcpy(recv_pic_1+recv_index1,dataBuf,dataLen); recv_index1+=dataLen; if(recv_index1==FIRST_PACKET_LEN){ recv_index1=0; } /* USER CODE END Service1Char1_WRITE_NO_RESP_EVT */ break; case P2P_SERVER_TX_WRITE_EVT: /* USER CODE BEGIN Service1Char1_WRITE_EVT */ LOG_INFO_APP("P2P_SERVER_TX_WRITE_EVT:%s\r\n",dataBuf); /* USER CODE END Service1Char1_WRITE_EVT */ break; case P2P_SERVER_TX_NOTIFY_ENABLED_EVT: /* USER CODE BEGIN Service1Char1_NOTIFY_ENABLED_EVT */ LOG_INFO_APP("P2P_SERVER_TX_NOTIFY_ENABLED_EVT:%s\r\n",dataBuf); /* USER CODE END Service1Char1_NOTIFY_ENABLED_EVT */ break; case P2P_SERVER_TX_NOTIFY_DISABLED_EVT: /* USER CODE BEGIN Service1Char1_NOTIFY_DISABLED_EVT */ LOG_INFO_APP("P2P_SERVER_TX_NOTIFY_DISABLED_EVT:%s\r\n",dataBuf); /* USER CODE END Service1Char1_NOTIFY_DISABLED_EVT */ break; case P2P_SERVER_RX_READ_EVT: /* USER CODE BEGIN Service1Char2_READ_EVT */ /* USER CODE END Service1Char2_READ_EVT */ break; case P2P_SERVER_RX_WRITE_NO_RESP_EVT: /* USER CODE BEGIN Service1Char2_WRITE_NO_RESP_EVT */ //LOG_INFO_APP("P2P_SERVER_RX_WRITE_NO_RESP_EVT:%s\r\n",dataBuf); 
memcpy(recv_pic_2+recv_index1,dataBuf,dataLen); recv_index1+=dataLen; if(recv_index1==SECOND_PACKET_LEN){ recv_index1=0; set_reg_flag(1); } /* USER CODE END Service1Char2_WRITE_NO_RESP_EVT */ break; case P2P_SERVER_RX_WRITE_EVT: /* USER CODE BEGIN Service1Char2_WRITE_EVT */ LOG_INFO_APP("P2P_SERVER_RX_WRITE_EVT:%s\r\n",dataBuf); /* USER CODE END Service1Char2_WRITE_EVT */ break; case P2P_SERVER_RX_NOTIFY_ENABLED_EVT: /* USER CODE BEGIN Service1Char2_NOTIFY_ENABLED_EVT */ LOG_INFO_APP("P2P_SERVER_RX_NOTIFY_ENABLED_EVT\r\n"); /* USER CODE END Service1Char2_NOTIFY_ENABLED_EVT */ break; case P2P_SERVER_RX_NOTIFY_DISABLED_EVT: /* USER CODE BEGIN Service1Char2_NOTIFY_DISABLED_EVT */ LOG_INFO_APP("P2P_SERVER_RX_NOTIFY_DISABLED_EVT\r\n"); /* USER CODE END Service1Char2_NOTIFY_DISABLED_EVT */ break; case P2P_SERVER_SW_READ_EVT: /* USER CODE BEGIN Service1Char3_READ_EVT */ LOG_INFO_APP("P2P_SERVER_SW_READ_EVT\r\n"); /* USER CODE END Service1Char3_READ_EVT */ break; case P2P_SERVER_SW_WRITE_NO_RESP_EVT: /* USER CODE BEGIN Service1Char3_WRITE_NO_RESP_EVT */ LOG_INFO_APP("P2P_SERVER_SW_WRITE_NO_RESP_EVT:%s\r\n",dataBuf); /* USER CODE END Service1Char3_WRITE_NO_RESP_EVT */ break; case P2P_SERVER_SW_WRITE_EVT: /* USER CODE BEGIN Service1Char3_WRITE_EVT */ LOG_INFO_APP("P2P_SERVER_SW_WRITE_EVT:%s\r\n",dataBuf); /* USER CODE END Service1Char3_WRITE_EVT */ break; case P2P_SERVER_SW_NOTIFY_ENABLED_EVT: /* USER CODE BEGIN Service1Char3_NOTIFY_ENABLED_EVT */ LOG_INFO_APP("P2P_SERVER_SW_NOTIFY_ENABLED_EVT\r\n"); /* USER CODE END Service1Char3_NOTIFY_ENABLED_EVT */ break; case P2P_SERVER_SW_NOTIFY_DISABLED_EVT: /* USER CODE BEGIN Service1Char3_NOTIFY_DISABLED_EVT */ LOG_INFO_APP("P2P_SERVER_SW_NOTIFY_DISABLED_EVT\r\n"); /* USER CODE END Service1Char3_NOTIFY_DISABLED_EVT */ break; default: /* USER CODE BEGIN Service1_Notification_default */ LOG_INFO_APP("P2P_SERVER_Notification FUC\r\n"); /* USER CODE END Service1_Notification_default */ break; } /* USER CODE BEGIN Service1_Notification_2 
*/ /* USER CODE END Service1_Notification_2 */ return; }
3.2 然后是识别,识别是关键,这里用的是nnom框架,它可以生成C语言的神经网络,不过我这里层数并不多,不然单片机的内存不够用,会Hard Fault 。但是牺牲的是识别的精度,实际用起来其实并不是太好,勉勉强强能识别吧。本来想用LCD显示,奈何内存严重不足,退而求其次,最后只能用点阵显示数字了。
while(1)部分:
if(reg_flag){ temp_pic1=get_recv_pic_first(); temp_pic2=get_recv_pic_second(); //display_pic(temp_pic); pre_num=pic_recognition(temp_pic1,temp_pic2); display_num(pre_num); printf("NN recognition num:%d\r\n",pre_num); reg_flag=0; }
识别的函数:
uint32_t pic_recognition(uint8_t *custom_pic,uint8_t *custom_pic2){ uint32_t pre_label=0; float prob; // model:mnist nn mod , nnom_output_data:10 classes, get top-1:record num //pre = prediction_create(model, nnom_output_data, sizeof(nnom_output_data), 1); memcpy(nnom_input_data, custom_pic, 500); memcpy(nnom_input_data+500, custom_pic2, 284); //this provide more infor but requires prediction API //prediction_run(pre, 0, &pre_label, &prob); nnom_predict(model, &pre_label, &prob); //printf("pre result is %d prob is %f\r\n",pre_label,prob); // print prediction result //prediction_end(pre); //prediction_summary(pre); //prediction_delete(pre); // model Print running stat //model_stat(model); return pre_label; }
模型具体结构:
static nnom_model_t* nnom_model_create(void) { static nnom_model_t model; nnom_layer_t* layer[12]; new_model(&model); layer[0] = Input(shape(28, 28, 1), nnom_input_data); layer[1] = model.hook(Conv2D(8, kernel(5, 5), stride(1, 1), dilation(1, 1), PADDING_VALID, &conv2d_w, &conv2d_b), layer[0]); layer[2] = model.active(act_relu(), layer[1]); layer[3] = model.hook(AvgPool(kernel(6, 6), stride(1, 1), PADDING_VALID), layer[2]); layer[4] = model.hook(Conv2D(8, kernel(6, 6), stride(1, 1), dilation(1, 1), PADDING_VALID, &conv2d_1_w, &conv2d_1_b), layer[3]); layer[5] = model.active(act_relu(), layer[4]); layer[6] = model.hook(AvgPool(kernel(7, 7), stride(1, 1), PADDING_VALID), layer[5]); layer[7] = model.hook(Dense(32, &dense_w, &dense_b), layer[6]); layer[8] = model.active(act_relu(), layer[7]); layer[9] = model.hook(Dense(10, &dense_1_w, &dense_1_b), layer[8]); layer[10] = model.hook(Softmax(), layer[9]); layer[11] = model.hook(Output(shape(10,1,1), nnom_output_data), layer[10]); model_compile(&model, layer[0], layer[11]); return &model; }
最后来一张完工的截图: