|
咱直接捞干的,聊聊怎么让PyTorch读数据更快,别让数据加载成了训练速度的绊脚石。
首先得明白为啥会慢。一般卡住就几个地方:硬盘太磨叽、数据预处理太复杂、CPU和GPU配合不好。
第一招:上DataLoader的多进程
别用单线程慢慢读,把num_workers调起来。一般设成CPU核数左右。比如8核的就试试设成4或者8。
python
loader = DataLoader(dataset, num_workers=4, batch_size=32, shuffle=True)
不过注意,设太高可能内存爆炸,得自己摸着石头过河试试。
第二招:玩命缓存
要是数据不大,直接一股脑全塞内存里。在__init__里读完,在__getitem__里直接取,硬盘就彻底歇了。
要是数据太大,把那些经常要做的转换结果缓存下来,下次直接用。
第三招:数据预处理能提前就提前
别在训练的时候现场做旋转缩放裁剪这些。能提前存好处理完的数据最好。如果非要做,试试torchvision.transforms.functional,比普通的transforms快一点。
第四招:PIN MEMORY
如果用的是GPU,把pin_memory=True打开。这样数据从CPU到GPU能跑得更快,尤其是数据量小的时候效果明显。
python
loader = DataLoader(dataset, num_workers=4, batch_size=32, shuffle=True, pin_memory=True)
第五招:调大batch size,但别太大
一次多读点,效率肯定高。但得看显存脸色,别给撑爆了。找个速度和内存都能接受的平衡点。
第六招:换个快点的硬盘
最笨但最好用的方法。把数据放到SSD硬盘上,速度能起飞。如果搞分布式训练,弄个高速网络存储也行。
第七招:自定义DataLoader
如果数据特别奇葩,PyTorch自带的搞不定,那就自己写个数据读取的逻辑。比如把一堆小文件提前打包成一个大文件,读的时候一次读一大块,能减少硬盘寻道时间。
第八招:数据格式用对
别用一堆小图片,换成像HDF5或者TFRecord这种专门为快速读取设计的格式。或者用WebDataset,直接把数据打包成tar文件,读起来嗖嗖的。
简单说,就是让CPU忙起来,别闲着等数据;让数据离计算芯片近一点,少折腾;硬盘能少读就少读。具体用哪招,得看自己的数据、机器和任务,多试试,找个最适合的组合。
|
|