找回密码
 立即注册
搜索

如何优化PyTorch数据读取?

[复制链接]
xinwen.mobi 发表于 2026-1-4 21:21:51 | 显示全部楼层 |阅读模式

提速PyTorch数据读取的几种土办法
数据读得慢,训练就得干等。这破事谁遇上了都头疼。下面几个法子亲测有效,照着搞能让数据读取快不少。

一、用DataLoader别蛮干
DataLoader里几个参数调对了,速度能上来:

python
from torch.utils.data import DataLoader

关键在这几个参数
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,  多开几个进程读数据
    pin_memory=True,  数据拷到GPU显存旁边,省时间
    prefetch_factor=2,  提前准备几批数据
    persistent_workers=True  别老重启进程
)
num_workers别设太大,一般设成CPU核数比较合适。设大了反而可能拖慢速度。

二、数据预处理能提前就提前
训练时现做数据增强太费事。能提前处理的就别留着:

python
别在训练时干这些
class BadDataset:
    def __init__(self, images):
        self.images = images
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])
   
    def __getitem__(self, idx):
        img = self.images[idx]
        每次都要重新处理,慢!
        return self.transform(img)

好的做法:提前处理好常见的变换
class SmartDataset:
    def __init__(self, preprocessed_images):
        已经预处理好了,只剩随机增强
        self.images = preprocessed_images
        self.augment = transforms.RandomHorizontalFlip(p=0.5)
   
    def __getitem__(self, idx):
        img = self.images[idx]
        return self.augment(img)  只剩轻量操作
图片解码、尺寸调整这些固定操作,提前做成缓存文件,训练时直接读。

三、数据格式选对路
不同数据格式速度差得远:

小文件太多 → 合并成几个大文件

图片数据 → 转成LMDB、HDF5或TFRecord

读磁盘太慢 → 有条件就上NVMe固态硬盘

python
用LMDB存图片,读得快
import lmdb
import pickle

class LMDBDataset:
    def __init__(self, lmdb_path):
        self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
        
    def __getitem__(self, idx):
        with self.env.begin() as txn:
            直接按索引取,速度快
            data = txn.get(f'{idx:08d}'.encode())
        return pickle.loads(data)
四、数据放对地方
数据放哪里也有讲究:

数据放本地SSD > 放机械硬盘

放内存盘(/dev/shm) > 放SSD

数据别和代码放一个盘,省得打架

要是数据量大,考虑整个数据服务器,用万兆网连起来。

五、自定义数据读取逻辑
PyTorch自带的Sampler不一定最合适,自己写可能更高效:

python
from torch.utils.data import Sampler

class FastSampler(Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size
        
    def __iter__(self):
        让相邻的数据尽量一起读,减少磁盘寻道
        indices = list(range(len(self.data_source)))
        按数据在磁盘上的顺序排序
        indices.sort(key=lambda i: self.data_source.get_disk_location(i))
        
        for i in range(0, len(indices), self.batch_size):
            yield indices[i:i + self.batch_size]
六、监控瓶颈在哪里
先搞清楚慢在哪:

python
from torch.utils.data import DataLoader
import time

测试数据读取时间
dataset = YourDataset()
loader = DataLoader(dataset, batch_size=32, num_workers=4)

start = time.time()
for batch_idx, batch in enumerate(loader):
    if batch_idx >= 10:  测10批就行
        break
        
total_time = time.time() - start
print(f"平均每批读取时间: {total_time/10:.3f}秒")

关掉数据增强再测一次,就知道预处理花了多少时间
七、终极狠招:全放内存
如果数据不大,直接全读到内存里:

python
class MemoryDataset:
    def __init__(self, data_path):
        一开始就把所有数据读进来
        self.all_data = []
        for path in data_path:
            self.all_data.append(self._load_and_preprocess(path))
        
    def __getitem__(self, idx):
        直接从内存拿,飞快
        return self.all_data[idx]
八、其他零碎建议
数据别混类型:float32和float64混用可能拖慢速度

数据对齐:让数据在内存里连续存放

定期清缓存:训练久了系统缓存可能满了

用最新版PyTorch:新版本通常优化更好

简单总结
数据读取优化就几句话:能提前处理的别留着,能放内存的别放硬盘,能多进程的别单干,数据排好序再读。具体用哪个法子得看实际情况,一般先加num_workers和pin_memory,效果不明显再试其他方法。

实在不行就堆硬件,换固态硬盘、加内存,钱能解决的问题就不是问题。不过先试试上面的办法,多半能省下这笔钱。

回复

使用道具 举报

QQ|周边二手车|手机版|标签|xml|txt|新闻魔笔科技XinWen.MoBi - 海量语音新闻! ( 粤ICP备2024355322号-1|粤公网安备44090202001230号 )|网站地图

GMT+8, 2026-1-14 23:02 , Processed in 0.057756 second(s), 19 queries .

Powered by Discuz! X3.5

© 2001-2026 Discuz! Team.

快速回复 返回顶部 返回列表