376 lines
11 KiB
Python
376 lines
11 KiB
Python
# state_detection.py
|
||
# 状态检测模块 - 通过HTTP请求发送图像进行心情和皮肤状态检测
|
||
# 适用于庐山派 K230-CanMV 开发板
|
||
|
||
import network
|
||
import socket
|
||
import time
|
||
import os
|
||
|
||
try:
|
||
import ubinascii as binascii
|
||
except ImportError:
|
||
import binascii
|
||
|
||
try:
|
||
import ujson as json
|
||
except ImportError:
|
||
import json
|
||
|
||
try:
|
||
import ussl as ssl
|
||
except ImportError:
|
||
try:
|
||
import ssl
|
||
except ImportError:
|
||
ssl = None
|
||
|
||
|
||
class StateDetector:
|
||
"""
|
||
状态检测类 - 通过HTTP POST请求发送图像数据到服务器进行心情和皮肤状态检测
|
||
|
||
功能:
|
||
- 连接WiFi网络
|
||
- 发送图像数据到检测服务器 (POST /api/detection/analyze)
|
||
- 接收并解析心情和皮肤状态检测结果
|
||
|
||
参数:
|
||
server_host: 检测服务器主机地址
|
||
server_port: 检测服务器端口 (默认80)
|
||
api_path: API路径 (默认 "/api/detection/analyze")
|
||
|
||
使用示例:
|
||
detector = StateDetector(server_host="192.168.0.21", server_port=8081)
|
||
detector.connect_wifi("SSID", "password")
|
||
result = detector.detect(image_data)
|
||
print(result["emotion"])
|
||
detector.disconnect()
|
||
"""
|
||
|
||
# 情绪类型定义
|
||
EMOTIONS = ["angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"]
|
||
|
||
def __init__(self, server_host, server_port=80, api_path="/api/detection/analyze"):
|
||
"""
|
||
初始化状态检测器
|
||
|
||
参数:
|
||
server_host: 检测服务器主机地址
|
||
server_port: 检测服务器端口
|
||
api_path: API路径
|
||
"""
|
||
self._server_host = server_host
|
||
self._server_port = server_port
|
||
self._api_path = api_path
|
||
|
||
self._sta = None
|
||
self._is_connected = False
|
||
|
||
def connect_wifi(self, ssid, password=None, timeout=15):
|
||
"""
|
||
连接WiFi网络
|
||
|
||
参数:
|
||
ssid: WiFi SSID
|
||
password: WiFi密码 (无密码时可为None)
|
||
timeout: 连接超时时间(秒)
|
||
|
||
返回:
|
||
str: 获取的IP地址
|
||
"""
|
||
self._sta = network.WLAN(network.STA_IF)
|
||
|
||
if password:
|
||
self._sta.connect(ssid, password)
|
||
else:
|
||
self._sta.connect(ssid)
|
||
|
||
print("正在连接WiFi: " + ssid + "...")
|
||
|
||
start_time = time.time()
|
||
while not self._sta.isconnected():
|
||
if time.time() - start_time > timeout:
|
||
raise RuntimeError("WiFi连接超时: " + ssid)
|
||
time.sleep(1)
|
||
os.exitpoint()
|
||
|
||
ip = self._sta.ifconfig()[0]
|
||
print("WiFi连接成功! IP: " + ip)
|
||
self._is_connected = True
|
||
return ip
|
||
|
||
def detect(self, image_data, width=640, height=480, timestamp=None):
|
||
"""
|
||
发送图像数据进行状态检测
|
||
|
||
参数:
|
||
image_data: 图像数据 (bytes类型,RGB888格式)
|
||
width: 图像宽度 (默认640)
|
||
height: 图像高度 (默认480)
|
||
timestamp: 图像采集时间 (可选,格式: "YYYY-MM-DD HH:MM:SS")
|
||
|
||
返回:
|
||
dict: 检测结果
|
||
"""
|
||
if not self._is_connected:
|
||
raise RuntimeError("网络未连接,请先调用 connect_wifi() 或 connect_lan()")
|
||
|
||
# 添加BMP头,将原始RGB数据包装为BMP图片
|
||
bmp_image = self._add_bmp_header(image_data, width, height)
|
||
|
||
# 将BMP图像数据编码为base64
|
||
image_base64 = binascii.b2a_base64(bmp_image).decode('utf-8').strip()
|
||
|
||
# 构建请求体
|
||
request_body = {
|
||
"type": "status_detection",
|
||
"image": image_base64
|
||
}
|
||
|
||
if timestamp:
|
||
request_body["timestamp"] = timestamp
|
||
else:
|
||
# 使用当前时间
|
||
t = time.localtime()
|
||
ts = "%04d-%02d-%02d %02d:%02d:%02d" % (t[0], t[1], t[2], t[3], t[4], t[5])
|
||
request_body["timestamp"] = ts
|
||
|
||
body_json = json.dumps(request_body)
|
||
|
||
# 创建socket并发送HTTP请求
|
||
response = self._send_http_request(body_json)
|
||
|
||
# 解析响应
|
||
return self._parse_response(response)
|
||
|
||
def _add_bmp_header(self, rgb_data, width, height):
|
||
"""
|
||
为RGB888数据添加BMP文件头
|
||
"""
|
||
# BMP文件头 (14 bytes)
|
||
# 0x42, 0x4D ("BM")
|
||
# File size (4 bytes)
|
||
# Reserved (4 bytes)
|
||
# Data offset (4 bytes) -> 54
|
||
|
||
# DIB Header (40 bytes)
|
||
# Header size (4 bytes) -> 40
|
||
# Width (4 bytes)
|
||
# Height (4 bytes) -> -height for top-down
|
||
# Planes (2 bytes) -> 1
|
||
# BPP (2 bytes) -> 24
|
||
# Compression (4 bytes) -> 0
|
||
# Image size (4 bytes)
|
||
# X ppm (4 bytes)
|
||
# Y ppm (4 bytes)
|
||
# Colors used (4 bytes)
|
||
# Colors important (4 bytes)
|
||
|
||
file_size = 54 + len(rgb_data)
|
||
|
||
# 构建头部
|
||
header = bytearray(54)
|
||
|
||
# BM
|
||
header[0], header[1] = 0x42, 0x4D
|
||
|
||
# File Size
|
||
header[2] = file_size & 0xFF
|
||
header[3] = (file_size >> 8) & 0xFF
|
||
header[4] = (file_size >> 16) & 0xFF
|
||
header[5] = (file_size >> 24) & 0xFF
|
||
|
||
# Data Offset (54)
|
||
header[10] = 54
|
||
|
||
# DIB Header Size (40)
|
||
header[14] = 40
|
||
|
||
# Width
|
||
header[18] = width & 0xFF
|
||
header[19] = (width >> 8) & 0xFF
|
||
header[20] = (width >> 16) & 0xFF
|
||
header[21] = (width >> 24) & 0xFF
|
||
|
||
# Height (Use negative for top-down RGB)
|
||
# 2's complement for negative number
|
||
h = -height
|
||
header[22] = h & 0xFF
|
||
header[23] = (h >> 8) & 0xFF
|
||
header[24] = (h >> 16) & 0xFF
|
||
header[25] = (h >> 24) & 0xFF
|
||
|
||
# Planes (1)
|
||
header[26] = 1
|
||
|
||
# BPP (24)
|
||
header[28] = 24
|
||
|
||
# Image Size
|
||
data_len = len(rgb_data)
|
||
header[34] = data_len & 0xFF
|
||
header[35] = (data_len >> 8) & 0xFF
|
||
header[36] = (data_len >> 16) & 0xFF
|
||
header[37] = (data_len >> 24) & 0xFF
|
||
|
||
return header + rgb_data
|
||
|
||
def _send_http_request(self, body, max_retries=3):
|
||
"""
|
||
发送HTTP POST请求
|
||
|
||
参数:
|
||
body: 请求体(JSON字符串)
|
||
max_retries: 最大重试次数
|
||
|
||
返回:
|
||
str: HTTP响应内容
|
||
"""
|
||
# 获取服务器地址
|
||
addr_info = None
|
||
for attempt in range(max_retries):
|
||
try:
|
||
addr_info = socket.getaddrinfo(self._server_host, self._server_port)
|
||
break
|
||
except BaseException:
|
||
print("DNS解析重试 (" + str(attempt + 1) + "/" + str(max_retries) + ")")
|
||
time.sleep(1)
|
||
|
||
if not addr_info:
|
||
raise RuntimeError("无法解析服务器地址: " + self._server_host)
|
||
|
||
addr = addr_info[0][-1]
|
||
print("连接服务器: " + str(addr))
|
||
|
||
# 确保body是bytes类型
|
||
if isinstance(body, str):
|
||
body_bytes = body.encode('utf-8')
|
||
else:
|
||
body_bytes = body
|
||
|
||
# 创建socket并连接
|
||
s = socket.socket()
|
||
|
||
try:
|
||
s.connect(addr)
|
||
|
||
# 构建HTTP请求头
|
||
host_header = self._server_host
|
||
if self._server_port != 80:
|
||
host_header = self._server_host + ":" + str(self._server_port)
|
||
|
||
# 构建头部
|
||
header = "POST " + self._api_path + " HTTP/1.1\r\n"
|
||
header += "Host: " + host_header + "\r\n"
|
||
header += "Content-Type: application/json\r\n"
|
||
header += "Content-Length: " + str(len(body_bytes)) + "\r\n"
|
||
header += "Connection: close\r\n"
|
||
header += "\r\n"
|
||
|
||
# 发送头部 (确保全部发送)
|
||
self._send_all(s, header.encode('utf-8'))
|
||
|
||
# 发送body (确保全部发送)
|
||
self._send_all(s, body_bytes)
|
||
|
||
# 接收响应
|
||
response = b""
|
||
while True:
|
||
chunk = s.recv(4096)
|
||
if not chunk:
|
||
break
|
||
response += chunk
|
||
|
||
return response.decode('utf-8')
|
||
|
||
finally:
|
||
s.close()
|
||
|
||
def _send_all(self, s, data):
|
||
"""
|
||
辅助函数:确保数据全部发送
|
||
使用memoryview避免内存复制,分块发送
|
||
"""
|
||
# 使用memoryview避免切片时的内存复制
|
||
mv = memoryview(data)
|
||
total_len = len(data)
|
||
total_sent = 0
|
||
chunk_size = 4096 # 每次发送4KB
|
||
|
||
while total_sent < total_len:
|
||
try:
|
||
# 计算本次发送的切片
|
||
remaining = total_len - total_sent
|
||
to_send = min(chunk_size, remaining)
|
||
|
||
# 发送数据
|
||
# mv[start:end] 创建新的memoryview切片,不复制数据
|
||
sent = s.send(mv[total_sent : total_sent + to_send])
|
||
|
||
if sent == 0:
|
||
raise RuntimeError("Socket连接断开 (sent=0)")
|
||
|
||
total_sent += sent
|
||
|
||
except OSError as e:
|
||
# EAGAIN (11) or EWOULDBLOCK
|
||
if e.args[0] == 11:
|
||
time.sleep(0.01)
|
||
continue
|
||
# 重新抛出其他错误 (如 ECONNRESET)
|
||
raise e
|
||
|
||
def _parse_response(self, response):
|
||
"""
|
||
解析HTTP响应
|
||
|
||
参数:
|
||
response: HTTP响应字符串
|
||
|
||
返回:
|
||
dict: 解析后的JSON响应体
|
||
"""
|
||
# 分离HTTP头部和body
|
||
if "\r\n\r\n" in response:
|
||
header, body = response.split("\r\n\r\n", 1)
|
||
elif "\n\n" in response:
|
||
header, body = response.split("\n\n", 1)
|
||
else:
|
||
raise RuntimeError("无效的HTTP响应格式")
|
||
|
||
# 解析JSON
|
||
try:
|
||
result = json.loads(body)
|
||
return result
|
||
except Exception as e:
|
||
print("JSON解析错误: " + str(e))
|
||
print("响应内容: " + body)
|
||
return {
|
||
"type": "status_detection_response",
|
||
"success": False,
|
||
"error": "JSON解析失败: " + str(e)
|
||
}
|
||
|
||
def disconnect(self):
|
||
"""断开网络连接"""
|
||
if self._sta and self._sta.isconnected():
|
||
self._sta.disconnect()
|
||
print("WiFi已断开")
|
||
self._is_connected = False
|
||
|
||
@property
|
||
def is_connected(self):
|
||
"""是否已连接网络"""
|
||
return self._is_connected
|
||
|
||
@property
|
||
def server_host(self):
|
||
"""服务器主机地址"""
|
||
return self._server_host
|
||
|
||
@property
|
||
def server_port(self):
|
||
"""服务器端口"""
|
||
return self._server_port
|