This commit is contained in:
mshe 2025-09-27 18:55:42 +08:00
parent ff389389dc
commit bdb9b90827
6 changed files with 76 additions and 26 deletions

1
.gitignore vendored
View File

@ -3,3 +3,4 @@
yolo
runs
models/
*.pt

45
00.test.py Normal file
View File

@ -0,0 +1,45 @@
from ultralytics import YOLO
import os
os.environ['ULTRALYTICS_PLOTS'] = 'False' # 禁用绘图环境变量
def safe_training():
"""安全训练彻底避免NumPy问题"""
model = YOLO('yolov8n-pose.pt')
results = model.train(
data='./dataset1/train.yaml',
epochs=100,
imgsz=320,
batch=2,
device='mps',
workers=0, # 设置为0避免多进程问题
# 关键彻底禁用所有可能触发NumPy bug的功能
plots=False, # 禁用绘图
save_json=False, # 禁用JSON保存
verbose=True,
# 简化所有参数
lr0=0.001,
pose=2.0,
kobj=1.5,
# 关闭数据增强
augment=False,
hsv_h=0.0,
hsv_s=0.0,
hsv_v=0.0,
degrees=0.0,
translate=0.0,
scale=0.0,
fliplr=0.0,
)
return results
print("开始安全训练...")
safe_training()

View File

@ -22,5 +22,5 @@ model = YOLO('models/yolov8n-pose.pt')
# save_dir: 训练结果保存路径
# weights: 预训练模型路径
model.info()
model.train(data='./dataset1/train.yaml', epochs=300, imgsz=640, batch=32, device=0)
model.train(data='./dataset1/train.yaml', epochs=300, imgsz=640, batch=32, device='cpu')
print("训练完成")

View File

@ -18,28 +18,29 @@ model.train(
device=0, # 使用GPU 0
workers=8, # 充分利用CPU核心
patience=50,
# lr0=0.01,
# lrf=0.01,
# momentum=0.937,
# weight_decay=0.0005,
# warmup_epochs=3.0,
#
# # 损失权重调整(重点解决姿态问题)
# box=7.5,
# pose=1.5, # 提高姿态损失权重
# kobj=2.0, # 提高关键点目标权重
# cls=1.0,
#
# # 性能优化
# amp=True, # 自动混合精度训练
# cos_lr=True, # 余弦学习率调度
# close_mosaic=10, # 最后10epoch关闭马赛克增强
#
# # 数据增强
# hsv_h=0.015,
# hsv_s=0.7,
# hsv_v=0.4,
# fliplr=0.5, # 水平翻转对姿态很重要
lr0=0.01,
lrf=0.01,
momentum=0.937,
weight_decay=0.0005,
warmup_epochs=3.0,
# 重点:大幅调整损失权重
pose=10.0, # 大幅提高姿态损失权重
kobj=5.0, # 提高关键点目标权重
box=1.0, # 降低检测权重(因为检测已经很好)
cls=0.3, # 降低分类权重
# 性能优化
amp=True, # 自动混合精度训练
cos_lr=True, # 余弦学习率调度
close_mosaic=10, # 最后10epoch关闭马赛克增强
# 数据增强
hsv_h=0.015,
hsv_s=0.7,
hsv_v=0.4,
fliplr=0.5, # 水平翻转对姿态很重要
save=True,
exist_ok=True,

View File

@ -4,7 +4,9 @@ val: valid/images # 验证集目录名
task: pose #任务类型为 pose
nc: 2 # 类别数
loss: CIou
names: ['left_hand', 'right_hand'] # 类别名称
names:
0: 'left_hand'
1: 'right_hand'
kpt_shape: [21, 3] # 关键点数量, 维度数量 (2 表示 x,y 或 3 表示 x,y,可见性)
# 关键点
flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]

View File

@ -4,3 +4,4 @@ numpy~=2.0.2
pyautogui~=0.9.54
moviepy~=2.2.1
torch~=2.8.0
PyYAML~=6.0.2