|
提速PyTorch数据读取的几种土办法
数据读得慢,训练就得干等。这破事谁遇上了都头疼。下面几个法子亲测有效,照着搞能让数据读取快不少。
一、用DataLoader别蛮干
DataLoader里几个参数调对了,速度能上来:
python
from torch.utils.data import DataLoader
关键在这几个参数
loader = DataLoader(
dataset,
batch_size=32,
num_workers=4, 多开几个进程读数据
pin_memory=True, 数据拷到GPU显存旁边,省时间
prefetch_factor=2, 提前准备几批数据
persistent_workers=True 别老重启进程
)
num_workers别设太大,一般设成CPU核数比较合适。设大了反而可能拖慢速度。
二、数据预处理能提前就提前
训练时现做数据增强太费事。能提前处理的就别留着:
python
别在训练时干这些
class BadDataset:
def __init__(self, images):
self.images = images
self.transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
def __getitem__(self, idx):
img = self.images[idx]
每次都要重新处理,慢!
return self.transform(img)
好的做法:提前处理好常见的变换
class SmartDataset:
def __init__(self, preprocessed_images):
已经预处理好了,只剩随机增强
self.images = preprocessed_images
self.augment = transforms.RandomHorizontalFlip(p=0.5)
def __getitem__(self, idx):
img = self.images[idx]
return self.augment(img) 只剩轻量操作
图片解码、尺寸调整这些固定操作,提前做成缓存文件,训练时直接读。
三、数据格式选对路
不同数据格式速度差得远:
小文件太多 → 合并成几个大文件
图片数据 → 转成LMDB、HDF5或TFRecord
读磁盘太慢 → 有条件就上NVMe固态硬盘
python
用LMDB存图片,读得快
import lmdb
import pickle
class LMDBDataset:
def __init__(self, lmdb_path):
self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
def __getitem__(self, idx):
with self.env.begin() as txn:
直接按索引取,速度快
data = txn.get(f'{idx:08d}'.encode())
return pickle.loads(data)
四、数据放对地方
数据放哪里也有讲究:
数据放本地SSD > 放机械硬盘
放内存盘(/dev/shm) > 放SSD
数据别和代码放一个盘,省得打架
要是数据量大,考虑整个数据服务器,用万兆网连起来。
五、自定义数据读取逻辑
PyTorch自带的Sampler不一定最合适,自己写可能更高效:
python
from torch.utils.data import Sampler
class FastSampler(Sampler):
def __init__(self, data_source, batch_size):
self.data_source = data_source
self.batch_size = batch_size
def __iter__(self):
让相邻的数据尽量一起读,减少磁盘寻道
indices = list(range(len(self.data_source)))
按数据在磁盘上的顺序排序
indices.sort(key=lambda i: self.data_source.get_disk_location(i))
for i in range(0, len(indices), self.batch_size):
yield indices[i:i + self.batch_size]
六、监控瓶颈在哪里
先搞清楚慢在哪:
python
from torch.utils.data import DataLoader
import time
测试数据读取时间
dataset = YourDataset()
loader = DataLoader(dataset, batch_size=32, num_workers=4)
start = time.time()
for batch_idx, batch in enumerate(loader):
if batch_idx >= 10: 测10批就行
break
total_time = time.time() - start
print(f"平均每批读取时间: {total_time/10:.3f}秒")
关掉数据增强再测一次,就知道预处理花了多少时间
七、终极狠招:全放内存
如果数据不大,直接全读到内存里:
python
class MemoryDataset:
def __init__(self, data_path):
一开始就把所有数据读进来
self.all_data = []
for path in data_path:
self.all_data.append(self._load_and_preprocess(path))
def __getitem__(self, idx):
直接从内存拿,飞快
return self.all_data[idx]
八、其他零碎建议
数据别混类型:float32和float64混用可能拖慢速度
数据对齐:让数据在内存里连续存放
定期清缓存:训练久了系统缓存可能满了
用最新版PyTorch:新版本通常优化更好
简单总结
数据读取优化就几句话:能提前处理的别留着,能放内存的别放硬盘,能多进程的别单干,数据排好序再读。具体用哪个法子得看实际情况,一般先加num_workers和pin_memory,效果不明显再试其他方法。
实在不行就堆硬件,换固态硬盘、加内存,钱能解决的问题就不是问题。不过先试试上面的办法,多半能省下这笔钱。
|