k230/k230/state_detection.py
2026-01-01 23:52:26 +08:00

376 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# state_detection.py
# 状态检测模块 - 通过HTTP请求发送图像进行心情和皮肤状态检测
# 适用于庐山派 K230-CanMV 开发板
import network
import socket
import time
import os
try:
import ubinascii as binascii
except ImportError:
import binascii
try:
import ujson as json
except ImportError:
import json
try:
import ussl as ssl
except ImportError:
try:
import ssl
except ImportError:
ssl = None
class StateDetector:
"""
状态检测类 - 通过HTTP POST请求发送图像数据到服务器进行心情和皮肤状态检测
功能:
- 连接WiFi网络
- 发送图像数据到检测服务器 (POST /api/detection/analyze)
- 接收并解析心情和皮肤状态检测结果
参数:
server_host: 检测服务器主机地址
server_port: 检测服务器端口 (默认80)
api_path: API路径 (默认 "/api/detection/analyze")
使用示例:
detector = StateDetector(server_host="192.168.0.21", server_port=8081)
detector.connect_wifi("SSID", "password")
result = detector.detect(image_data)
print(result["emotion"])
detector.disconnect()
"""
# 情绪类型定义
EMOTIONS = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]
def __init__(self, server_host, server_port=80, api_path="/api/detection/analyze"):
"""
初始化状态检测器
参数:
server_host: 检测服务器主机地址
server_port: 检测服务器端口
api_path: API路径
"""
self._server_host = server_host
self._server_port = server_port
self._api_path = api_path
self._sta = None
self._is_connected = False
def connect_wifi(self, ssid, password=None, timeout=15):
"""
连接WiFi网络
参数:
ssid: WiFi SSID
password: WiFi密码 (无密码时可为None)
timeout: 连接超时时间(秒)
返回:
str: 获取的IP地址
"""
self._sta = network.WLAN(network.STA_IF)
if password:
self._sta.connect(ssid, password)
else:
self._sta.connect(ssid)
print("正在连接WiFi: " + ssid + "...")
start_time = time.time()
while not self._sta.isconnected():
if time.time() - start_time > timeout:
raise RuntimeError("WiFi连接超时: " + ssid)
time.sleep(1)
os.exitpoint()
ip = self._sta.ifconfig()[0]
print("WiFi连接成功! IP: " + ip)
self._is_connected = True
return ip
def detect(self, image_data, width=640, height=480, timestamp=None):
"""
发送图像数据进行状态检测
参数:
image_data: 图像数据 (bytes类型RGB888格式)
width: 图像宽度 (默认640)
height: 图像高度 (默认480)
timestamp: 图像采集时间 (可选,格式: "YYYY-MM-DD HH:MM:SS")
返回:
dict: 检测结果
"""
if not self._is_connected:
raise RuntimeError("网络未连接,请先调用 connect_wifi() 或 connect_lan()")
# 添加BMP头将原始RGB数据包装为BMP图片
bmp_image = self._add_bmp_header(image_data, width, height)
# 将BMP图像数据编码为base64
image_base64 = binascii.b2a_base64(bmp_image).decode('utf-8').strip()
# 构建请求体
request_body = {
"type": "status_detection",
"image": image_base64
}
if timestamp:
request_body["timestamp"] = timestamp
else:
# 使用当前时间
t = time.localtime()
ts = "%04d-%02d-%02d %02d:%02d:%02d" % (t[0], t[1], t[2], t[3], t[4], t[5])
request_body["timestamp"] = ts
body_json = json.dumps(request_body)
# 创建socket并发送HTTP请求
response = self._send_http_request(body_json)
# 解析响应
return self._parse_response(response)
def _add_bmp_header(self, rgb_data, width, height):
"""
为RGB888数据添加BMP文件头
"""
# BMP文件头 (14 bytes)
# 0x42, 0x4D ("BM")
# File size (4 bytes)
# Reserved (4 bytes)
# Data offset (4 bytes) -> 54
# DIB Header (40 bytes)
# Header size (4 bytes) -> 40
# Width (4 bytes)
# Height (4 bytes) -> -height for top-down
# Planes (2 bytes) -> 1
# BPP (2 bytes) -> 24
# Compression (4 bytes) -> 0
# Image size (4 bytes)
# X ppm (4 bytes)
# Y ppm (4 bytes)
# Colors used (4 bytes)
# Colors important (4 bytes)
file_size = 54 + len(rgb_data)
# 构建头部
header = bytearray(54)
# BM
header[0], header[1] = 0x42, 0x4D
# File Size
header[2] = file_size & 0xFF
header[3] = (file_size >> 8) & 0xFF
header[4] = (file_size >> 16) & 0xFF
header[5] = (file_size >> 24) & 0xFF
# Data Offset (54)
header[10] = 54
# DIB Header Size (40)
header[14] = 40
# Width
header[18] = width & 0xFF
header[19] = (width >> 8) & 0xFF
header[20] = (width >> 16) & 0xFF
header[21] = (width >> 24) & 0xFF
# Height (Use negative for top-down RGB)
# 2's complement for negative number
h = -height
header[22] = h & 0xFF
header[23] = (h >> 8) & 0xFF
header[24] = (h >> 16) & 0xFF
header[25] = (h >> 24) & 0xFF
# Planes (1)
header[26] = 1
# BPP (24)
header[28] = 24
# Image Size
data_len = len(rgb_data)
header[34] = data_len & 0xFF
header[35] = (data_len >> 8) & 0xFF
header[36] = (data_len >> 16) & 0xFF
header[37] = (data_len >> 24) & 0xFF
return header + rgb_data
def _send_http_request(self, body, max_retries=3):
"""
发送HTTP POST请求
参数:
body: 请求体(JSON字符串)
max_retries: 最大重试次数
返回:
str: HTTP响应内容
"""
# 获取服务器地址
addr_info = None
for attempt in range(max_retries):
try:
addr_info = socket.getaddrinfo(self._server_host, self._server_port)
break
except BaseException:
print("DNS解析重试 (" + str(attempt + 1) + "/" + str(max_retries) + ")")
time.sleep(1)
if not addr_info:
raise RuntimeError("无法解析服务器地址: " + self._server_host)
addr = addr_info[0][-1]
print("连接服务器: " + str(addr))
# 确保body是bytes类型
if isinstance(body, str):
body_bytes = body.encode('utf-8')
else:
body_bytes = body
# 创建socket并连接
s = socket.socket()
try:
s.connect(addr)
# 构建HTTP请求头
host_header = self._server_host
if self._server_port != 80:
host_header = self._server_host + ":" + str(self._server_port)
# 构建头部
header = "POST " + self._api_path + " HTTP/1.1\r\n"
header += "Host: " + host_header + "\r\n"
header += "Content-Type: application/json\r\n"
header += "Content-Length: " + str(len(body_bytes)) + "\r\n"
header += "Connection: close\r\n"
header += "\r\n"
# 发送头部 (确保全部发送)
self._send_all(s, header.encode('utf-8'))
# 发送body (确保全部发送)
self._send_all(s, body_bytes)
# 接收响应
response = b""
while True:
chunk = s.recv(4096)
if not chunk:
break
response += chunk
return response.decode('utf-8')
finally:
s.close()
def _send_all(self, s, data):
"""
辅助函数:确保数据全部发送
使用memoryview避免内存复制分块发送
"""
# 使用memoryview避免切片时的内存复制
mv = memoryview(data)
total_len = len(data)
total_sent = 0
chunk_size = 4096 # 每次发送4KB
while total_sent < total_len:
try:
# 计算本次发送的切片
remaining = total_len - total_sent
to_send = min(chunk_size, remaining)
# 发送数据
# mv[start:end] 创建新的memoryview切片不复制数据
sent = s.send(mv[total_sent : total_sent + to_send])
if sent == 0:
raise RuntimeError("Socket连接断开 (sent=0)")
total_sent += sent
except OSError as e:
# EAGAIN (11) or EWOULDBLOCK
if e.args[0] == 11:
time.sleep(0.01)
continue
# 重新抛出其他错误 (如 ECONNRESET)
raise e
def _parse_response(self, response):
"""
解析HTTP响应
参数:
response: HTTP响应字符串
返回:
dict: 解析后的JSON响应体
"""
# 分离HTTP头部和body
if "\r\n\r\n" in response:
header, body = response.split("\r\n\r\n", 1)
elif "\n\n" in response:
header, body = response.split("\n\n", 1)
else:
raise RuntimeError("无效的HTTP响应格式")
# 解析JSON
try:
result = json.loads(body)
return result
except Exception as e:
print("JSON解析错误: " + str(e))
print("响应内容: " + body)
return {
"type": "status_detection_response",
"success": False,
"error": "JSON解析失败: " + str(e)
}
def disconnect(self):
"""断开网络连接"""
if self._sta and self._sta.isconnected():
self._sta.disconnect()
print("WiFi已断开")
self._is_connected = False
@property
def is_connected(self):
"""是否已连接网络"""
return self._is_connected
@property
def server_host(self):
"""服务器主机地址"""
return self._server_host
@property
def server_port(self):
"""服务器端口"""
return self._server_port