bak
This commit is contained in:
parent
ff389389dc
commit
bdb9b90827
|
|
@ -2,4 +2,5 @@
|
|||
.idea
|
||||
yolo
|
||||
runs
|
||||
models/
|
||||
models/
|
||||
*.pt
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from ultralytics import YOLO
|
||||
import os
|
||||
|
||||
os.environ['ULTRALYTICS_PLOTS'] = 'False' # 禁用绘图环境变量
|
||||
|
||||
|
||||
def safe_training():
|
||||
"""安全训练,彻底避免NumPy问题"""
|
||||
|
||||
model = YOLO('yolov8n-pose.pt')
|
||||
|
||||
results = model.train(
|
||||
data='./dataset1/train.yaml',
|
||||
epochs=100,
|
||||
imgsz=320,
|
||||
batch=2,
|
||||
device='mps',
|
||||
workers=0, # 设置为0避免多进程问题
|
||||
|
||||
# 关键:彻底禁用所有可能触发NumPy bug的功能
|
||||
plots=False, # 禁用绘图
|
||||
save_json=False, # 禁用JSON保存
|
||||
verbose=True,
|
||||
|
||||
# 简化所有参数
|
||||
lr0=0.001,
|
||||
pose=2.0,
|
||||
kobj=1.5,
|
||||
|
||||
# 关闭数据增强
|
||||
augment=False,
|
||||
hsv_h=0.0,
|
||||
hsv_s=0.0,
|
||||
hsv_v=0.0,
|
||||
degrees=0.0,
|
||||
translate=0.0,
|
||||
scale=0.0,
|
||||
fliplr=0.0,
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
print("开始安全训练...")
|
||||
safe_training()
|
||||
|
|
@ -22,5 +22,5 @@ model = YOLO('models/yolov8n-pose.pt')
|
|||
# save_dir: 训练结果保存路径
|
||||
# weights: 预训练模型路径
|
||||
model.info()
|
||||
model.train(data='./dataset1/train.yaml', epochs=300, imgsz=640, batch=32, device=0)
|
||||
model.train(data='./dataset1/train.yaml', epochs=300, imgsz=640, batch=32, device='cpu')
|
||||
print("训练完成")
|
||||
|
|
|
|||
|
|
@ -18,28 +18,29 @@ model.train(
|
|||
device=0, # 使用GPU 0
|
||||
workers=8, # 充分利用CPU核心
|
||||
patience=50,
|
||||
# lr0=0.01,
|
||||
# lrf=0.01,
|
||||
# momentum=0.937,
|
||||
# weight_decay=0.0005,
|
||||
# warmup_epochs=3.0,
|
||||
#
|
||||
# # 损失权重调整(重点解决姿态问题)
|
||||
# box=7.5,
|
||||
# pose=1.5, # 提高姿态损失权重
|
||||
# kobj=2.0, # 提高关键点目标权重
|
||||
# cls=1.0,
|
||||
#
|
||||
# # 性能优化
|
||||
# amp=True, # 自动混合精度训练
|
||||
# cos_lr=True, # 余弦学习率调度
|
||||
# close_mosaic=10, # 最后10epoch关闭马赛克增强
|
||||
#
|
||||
# # 数据增强
|
||||
# hsv_h=0.015,
|
||||
# hsv_s=0.7,
|
||||
# hsv_v=0.4,
|
||||
# fliplr=0.5, # 水平翻转对姿态很重要
|
||||
lr0=0.01,
|
||||
lrf=0.01,
|
||||
momentum=0.937,
|
||||
weight_decay=0.0005,
|
||||
warmup_epochs=3.0,
|
||||
|
||||
# 重点:大幅调整损失权重
|
||||
pose=10.0, # 大幅提高姿态损失权重
|
||||
kobj=5.0, # 提高关键点目标权重
|
||||
box=1.0, # 降低检测权重(因为检测已经很好)
|
||||
cls=0.3, # 降低分类权重
|
||||
|
||||
|
||||
# 性能优化
|
||||
amp=True, # 自动混合精度训练
|
||||
cos_lr=True, # 余弦学习率调度
|
||||
close_mosaic=10, # 最后10epoch关闭马赛克增强
|
||||
|
||||
# 数据增强
|
||||
hsv_h=0.015,
|
||||
hsv_s=0.7,
|
||||
hsv_v=0.4,
|
||||
fliplr=0.5, # 水平翻转对姿态很重要
|
||||
|
||||
save=True,
|
||||
exist_ok=True,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ val: valid/images # 验证集目录名
|
|||
task: pose #任务类型为 pose
|
||||
nc: 2 # 类别数
|
||||
loss: CIou
|
||||
names: ['left_hand', 'right_hand'] # 类别名称
|
||||
names:
|
||||
0: 'left_hand'
|
||||
1: 'right_hand'
|
||||
kpt_shape: [21, 3] # 关键点数量, 维度数量 (2 表示 x,y 或 3 表示 x,y,可见性)
|
||||
# 关键点
|
||||
flip_idx: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
|
||||
|
|
@ -3,4 +3,5 @@ opencv-python~=4.12.0.88
|
|||
numpy~=2.0.2
|
||||
pyautogui~=0.9.54
|
||||
moviepy~=2.2.1
|
||||
torch~=2.8.0
|
||||
torch~=2.8.0
|
||||
PyYAML~=6.0.2
|
||||
Loading…
Reference in New Issue