bak
This commit is contained in:
parent
ff389389dc
commit
bdb9b90827
|
|
@ -2,4 +2,5 @@
|
||||||
.idea
|
.idea
|
||||||
yolo
|
yolo
|
||||||
runs
|
runs
|
||||||
models/
|
models/
|
||||||
|
*.pt
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ['ULTRALYTICS_PLOTS'] = 'False' # 禁用绘图环境变量
|
||||||
|
|
||||||
|
|
||||||
|
def safe_training():
|
||||||
|
"""安全训练,彻底避免NumPy问题"""
|
||||||
|
|
||||||
|
model = YOLO('yolov8n-pose.pt')
|
||||||
|
|
||||||
|
results = model.train(
|
||||||
|
data='./dataset1/train.yaml',
|
||||||
|
epochs=100,
|
||||||
|
imgsz=320,
|
||||||
|
batch=2,
|
||||||
|
device='mps',
|
||||||
|
workers=0, # 设置为0避免多进程问题
|
||||||
|
|
||||||
|
# 关键:彻底禁用所有可能触发NumPy bug的功能
|
||||||
|
plots=False, # 禁用绘图
|
||||||
|
save_json=False, # 禁用JSON保存
|
||||||
|
verbose=True,
|
||||||
|
|
||||||
|
# 简化所有参数
|
||||||
|
lr0=0.001,
|
||||||
|
pose=2.0,
|
||||||
|
kobj=1.5,
|
||||||
|
|
||||||
|
# 关闭数据增强
|
||||||
|
augment=False,
|
||||||
|
hsv_h=0.0,
|
||||||
|
hsv_s=0.0,
|
||||||
|
hsv_v=0.0,
|
||||||
|
degrees=0.0,
|
||||||
|
translate=0.0,
|
||||||
|
scale=0.0,
|
||||||
|
fliplr=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
print("开始安全训练...")
|
||||||
|
safe_training()
|
||||||
|
|
@ -22,5 +22,5 @@ model = YOLO('models/yolov8n-pose.pt')
|
||||||
# save_dir: 训练结果保存路径
|
# save_dir: 训练结果保存路径
|
||||||
# weights: 预训练模型路径
|
# weights: 预训练模型路径
|
||||||
model.info()
|
model.info()
|
||||||
model.train(data='./dataset1/train.yaml', epochs=300, imgsz=640, batch=32, device=0)
|
model.train(data='./dataset1/train.yaml', epochs=300, imgsz=640, batch=32, device='cpu')
|
||||||
print("训练完成")
|
print("训练完成")
|
||||||
|
|
|
||||||
|
|
@ -18,28 +18,29 @@ model.train(
|
||||||
device=0, # 使用GPU 0
|
device=0, # 使用GPU 0
|
||||||
workers=8, # 充分利用CPU核心
|
workers=8, # 充分利用CPU核心
|
||||||
patience=50,
|
patience=50,
|
||||||
# lr0=0.01,
|
lr0=0.01,
|
||||||
# lrf=0.01,
|
lrf=0.01,
|
||||||
# momentum=0.937,
|
momentum=0.937,
|
||||||
# weight_decay=0.0005,
|
weight_decay=0.0005,
|
||||||
# warmup_epochs=3.0,
|
warmup_epochs=3.0,
|
||||||
#
|
|
||||||
# # 损失权重调整(重点解决姿态问题)
|
# 重点:大幅调整损失权重
|
||||||
# box=7.5,
|
pose=10.0, # 大幅提高姿态损失权重
|
||||||
# pose=1.5, # 提高姿态损失权重
|
kobj=5.0, # 提高关键点目标权重
|
||||||
# kobj=2.0, # 提高关键点目标权重
|
box=1.0, # 降低检测权重(因为检测已经很好)
|
||||||
# cls=1.0,
|
cls=0.3, # 降低分类权重
|
||||||
#
|
|
||||||
# # 性能优化
|
|
||||||
# amp=True, # 自动混合精度训练
|
# 性能优化
|
||||||
# cos_lr=True, # 余弦学习率调度
|
amp=True, # 自动混合精度训练
|
||||||
# close_mosaic=10, # 最后10epoch关闭马赛克增强
|
cos_lr=True, # 余弦学习率调度
|
||||||
#
|
close_mosaic=10, # 最后10epoch关闭马赛克增强
|
||||||
# # 数据增强
|
|
||||||
# hsv_h=0.015,
|
# 数据增强
|
||||||
# hsv_s=0.7,
|
hsv_h=0.015,
|
||||||
# hsv_v=0.4,
|
hsv_s=0.7,
|
||||||
# fliplr=0.5, # 水平翻转对姿态很重要
|
hsv_v=0.4,
|
||||||
|
fliplr=0.5, # 水平翻转对姿态很重要
|
||||||
|
|
||||||
save=True,
|
save=True,
|
||||||
exist_ok=True,
|
exist_ok=True,
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,9 @@ val: valid/images # 验证集目录名
|
||||||
task: pose #任务类型为 pose
|
task: pose #任务类型为 pose
|
||||||
nc: 2 # 类别数
|
nc: 2 # 类别数
|
||||||
loss: CIou
|
loss: CIou
|
||||||
names: ['left_hand', 'right_hand'] # 类别名称
|
names:
|
||||||
|
0: 'left_hand'
|
||||||
|
1: 'right_hand'
|
||||||
kpt_shape: [21, 3] # 关键点数量, 维度数量 (2 表示 x,y 或 3 表示 x,y,可见性)
|
kpt_shape: [21, 3] # 关键点数量, 维度数量 (2 表示 x,y 或 3 表示 x,y,可见性)
|
||||||
# 关键点
|
# 关键点
|
||||||
flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
|
flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
|
||||||
|
|
@ -3,4 +3,5 @@ opencv-python~=4.12.0.88
|
||||||
numpy~=2.0.2
|
numpy~=2.0.2
|
||||||
pyautogui~=0.9.54
|
pyautogui~=0.9.54
|
||||||
moviepy~=2.2.1
|
moviepy~=2.2.1
|
||||||
torch~=2.8.0
|
torch~=2.8.0
|
||||||
|
PyYAML~=6.0.2
|
||||||
Loading…
Reference in New Issue