【PyTorch与深度学习】5、深入剖析PyTorch DataLoader源码

课程地址
最近做实验发现自己还是基础框架上掌握得不好,于是开始重学一遍PyTorch框架,这个是课程笔记,此节课很详细,笔记记的比较粗

1. DataLoader

1.1 DataLoader类实现

1.1.1 构造函数__init__实现

构造函数有如下参数:

  • dataset:传入自己定义好的数据集类Dataset
  • batch_size:默认值为1,它代表着每批次训练的样本的个数
  • shuffle:布尔类型,True为打乱数据集,False为不打乱数据集
  • sampler:决定以何种方式对数据进行采样,可以不用shuffle随机打乱样本,可以用自己编写的函数去决定如何取样本,比如:你想让你的样本以一种有序的方式来组织成mini-batch,比如把长度比较接近的样本放入到一个mini-batch中,这个时候就不能用shuffle,因为一打乱,这些样本的长度就是乱的。如果传入该参数,则shuffle就没有意义。
  • batch_sampler:可以用自己编写的函数成批次地取样本。如果传入该参数,则shuffle就没有意义。
  • num_workers:默认值为0,它是指数据加载的子进程数量,以加快数据加载的速度,提高训练效率。一般数值设定取决于CPU的核心数,通常数字大到一定程度,其加载速度也不会再提高了。
  • collate_fn:聚集函数,它是对一个批次batch进行后处理,比如:我们通过shuffle打乱后得到一个批次batch,然后对这个batch我们希望对它进行一个pad,但是这个pad的长度只能通过batch去算出来,而不是预先能计算出长度,这个时候我们就要用到collate_fn参数,对之前的shuffle后的mini-batch再处理一下,把这个批次batch给它pad成一样的长度,然后再返回一个新的批次batch。

【注】在深度学习和自然语言处理(NLP)等领域中,pad(填充)是一个常见的预处理步骤,特别是在处理变长序列(如文本、时间序列等)时。当使用DataLoader从数据集中批量提取数据时,如果每个数据项(例如,句子或时间序列)的长度不同,那么为了能够在同一批次中进行高效计算(例如,通过矩阵运算),我们通常需要将这些数据项填充(或截断)到相同的长度。
这就是collate_fn参数发挥作用的地方。默认情况下,DataLoader使用了一个内置的collate_fn来将一批数据项组合成一个张量(tensor),但这个默认函数并不进行填充。为了进行填充,你需要提供一个自定义的collate_fn。

  • pin_memory:布尔类型,默认值为False,用于指定是否将数据加载到固定的内存区域(pinned memory)中。固定内存区域是指一块被操作系统锁定的内存,这样可以防止它被移动,从而提高数据传输的效率。当pin_memory参数设置为True时,PyTorch会尝试将从数据集加载的数据存储在固定的内存中,这对于GPU加速的情况下可以提高数据传输效率,因为GPU可以直接从固定内存中访问数据,而不需要进行额外的内存拷贝操作。需要注意的是,只有当你使用GPU进行训练时,才会考虑使用pin_memory参数。对于CPU训练来说,pin_memory参数的影响通常不太明显。而且这个东西对训练速度的影响还有待考究。
  • drop_last:布尔类型,默认为False,如果你的总样本数目不是每个批次batch的整数倍的话,这时候我们可以将drop_last设置为True,让最后那个小批次(样本数没达到batch-size的批次)丢掉。

构造函数的具体代码和注释如下:

    def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 shuffle: bool = False, sampler: Union[Sampler, Iterable, None] = None,
                 batch_sampler: Union[Sampler[Sequence], Iterable[Sequence], None] = None,
                 num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                 pin_memory: bool = False, drop_last: bool = False,
                 timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                 multiprocessing_context=None, generator=None,
                 *, prefetch_factor: int = 2,
                 persistent_workers: bool = False):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError('prefetch_factor option could only be specified in multiprocessing.'
                             'let num_workers > 0 to enable multiprocessing.')
        assert prefetch_factor > 0

        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')
		# 设置成员函数
        self.dataset = dataset
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # 这里不用看,一般我们都是用Dataset类,而不是IterableDataset,所以直接看这个if条件后面对应的else条件
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            
            if isinstance(dataset, IterDataPipe):
                torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle)
            elif shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))

            if sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        # 直接跳到else条件
        else:
        	# 设置数据集的种类是DatasetKind.Map类型
            self._dataset_kind = _DatasetKind.Map


		# 如果你设置了sampler(默认为None),如果你传入了自定义的sampler且shuffle设置为True的话,这种情况是没有意义的,shuffle是官方提供的一种随机采用党的sampler,你都自定义sampler了,就不需要shuffle来随机打乱。所以shuffle和sampler是互斥的,不能同时去设置
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')
        # batch_sampler是批次级别的采样,sampler是样本级的采样,
        if batch_sampler is not None:
            # 如果你设置了batch_size不是1,或者你设置了shuffle或者你设置了sampler,或者你设置了drop_last,这些都与batch_sampler是互斥的,总结一句话就是:你只要设置了batch_sampler就不需要设置batch_size了,因为你设置了batch_sampler就已经告诉PyTorch框架你的batch_size和以什么样的方式去构成mini-batch
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        # 如果batch_size是None,同时如果有drop_last,这时候会报错
        elif batch_size is None:
            # no auto_collation
            if drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with drop_last')
		# 如果你没有设置sampler的话
        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style(常用的),如果你设置了shuffle的话,它就会用内置的一个叫random sample的类来去对我们这个Dataset进行一个随机的打乱。具体实现在下面的章节
                if shuffle:
                    sampler = RandomSampler(dataset, generator=generator)  # type: ignore[arg-type]
                # 如果没有设置shuffle为True的话,它就用SequentialSampler即按原本的顺序来采样
                else:
                    sampler = SequentialSampler(dataset)  # type: ignore[arg-type]
		# 如果你的batch_size不是None并且batch_sampler也不是None
		# 它就默认给你构造一个batch_sampler
		# BatchSampler源码实现见下面的章节
        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.generator = generator
		# 如果collate_fn参数为None,则如果设置了auto_collatoion,就调用默认的default_collate
        if collate_fn is None:
        	# _auto_collation是根据batch_sampler是否为None来去设置的,如果batch_sampler不是None,_auto_collation设置为True,如果batch_sampler是None的话,它就会调用_utils.collate.default_convert这个函数,否则调用_utils.collate.default_collate函数。
        	# _utils.collate.default_collate函数是以batch作为输入,它相当于什么都没做,最后返回了个batch,如果自己要实现这个collate_fn,要以batch做输入,然后再做处理。
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

        self._iterator = None

        self.check_worker_number_rationality()

        torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

1.1.2 _get_iterator函数

    def _get_iterator(self) -> '_BaseDataLoaderIter':
    	# 如果设置num_workers为0的话,它就走单个样本处理过程
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
        # 如果num_workers不为0,说明是多进程读取样本
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

一般迭代用,是在__iter__方法中实现的,使得DataLoader能变成一个可迭代的对象。

1.2 RandomSampler 类的实现

重点看中文注释

class RandomSampler(Sampler[int]):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify :attr:`num_samples` to draw.

    Args:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`.
        generator (Generator): Generator used in sampling.
    """
    data_source: Sized
    replacement: bool

    def __init__(self, data_source: Sized, replacement: bool = False,
                 num_samples: Optional[int] = None, generator=None) -> None:
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples
        self.generator = generator

        if not isinstance(self.replacement, bool):
            raise TypeError("replacement should be a boolean value, but got "
                            "replacement={}".format(self.replacement))

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    @property
    def num_samples(self) -> int:
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples
	
	# 首先看__iter__方法
    def __iter__(self) -> Iterator[int]:
    	# 获取数据集的大小
        n = len(self.data_source)
        # 如果没有传入generator的话,他就会随机生成一个种子,去构建一个生成器generator
        if self.generator is None:
       		# 设置随机数的种子
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(seed)
        else:
            generator = self.generator

        if self.replacement:
            for _ in range(self.num_samples // 32):
                yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
            # 返回0到n-1的列表的随机组合,n是数据集长度
            yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
        else:
            for _ in range(self.num_samples // n):
                yield from torch.randperm(n, generator=generator).tolist()
            yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]

    def __len__(self) -> int:
        return self.num_samples

1.3 SequentialSampler类的实现

class SequentialSampler(Sampler[int]):
    r"""Samples elements sequentially, always in the same order.

    Args:
        data_source (Dataset): dataset to sample from
    """
    data_source: Sized

    def __init__(self, data_source: Sized) -> None:
        self.data_source = data_source
	# 如果迭代它,返回的是有序的索引
    def __iter__(self) -> Iterator[int]:
        return iter(range(len(self.data_source)))

    def __len__(self) -> int:
        return len(self.data_source)

1.4 BatchSampler类的实现

也是直接看__iter__函数

class BatchSampler(Sampler[List[int]]):

    def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool) -> None:
        # Since collections.abc.Iterable does not check for `__getitem__`, which
        # is one way for an object to be an iterable, we don't do an `isinstance`
        # check here.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
	# 先看iter函数
    def __iter__(self) -> Iterator[List[int]]:
    	# 先创建一个空列表batch
        batch = []
        # 对sampler进行一个迭代,去元素的索引
        for idx in self.sampler:
        	# 将其索引添加到列表中
            batch.append(idx)
            # 如果列表长度等于batch_size,这时候就返回列表,相当于返回一个批次batch,然后把batch置为空
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # 如果drop_last(是否丢弃最后的不够一个批次数量的元素)设置为False,那我们就把最后这个不够数量的批次也返回
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self) -> int:
        # Can only be called if self.sampler has __len__ implemented
        # We cannot enforce this condition, so we turn off typechecking for the
        # implementation below.
        # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
        if self.drop_last:
            return len(self.sampler) // self.batch_size  # type: ignore[arg-type]
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size  # type: ignore[arg-type]

1.5 其他

这个UP讲的太详细了,没全记录,部分细节可以看看视频

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/592144.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL基础_1.MySQL概述

文章目录 一、关系型数据库和非关系型数据库1.1 关系型&#xff08;RDBMS&#xff09;1.2 非关系型&#xff08;非RDBMS&#xff09; 二、常用的基础语句2.1 查看表的创建信息2.2 编码问题 一、关系型数据库和非关系型数据库 1.1 关系型&#xff08;RDBMS&#xff09; 是最古…

都上3D数字孪生了,2D的WEB组态和大屏可视化未来的发展在哪里?趋势是基于页面嵌套、蓝图连线等新技术,与功能业务应用融合

首先回顾下组态工具的发展史&#xff1a; 回顾发展史&#xff0c;WEB组态终于可以搭建业务系统了&#xff01;&#xff08;页面嵌套 节点编辑 WEB组态 上位机 大屏可视化 无代码 0代码 iframe nodered 蓝图&#xff09;-CSDN博客文章浏览阅读624次&#xff0c;点赞12次&#x…

ThreeJS:纹理的颜色空间

色彩空间Color Space 在ThreeJS中&#xff0c;纹理的colorSpace属性用于定义文里的颜色空间。 颜色空间是一个用于描述颜色的数学模型&#xff0c;在现实生活中&#xff0c;人眼可以观察到无数种颜色&#xff0c;而颜色空间就是用来描述这些颜色的一个方法&#xff0c;不同的颜…

C语言-自定义类型:结构体,枚举,联合

目录 一、结构体1.1 结构体变量的定义和初始化1.2 结构体内存对齐1.3 修改默认对齐数1.4 结构体传参 二、位段2.1 什么是位段2.2 位段的内存分配2.3 位段的跨平台问题2.4 位段的应用 三、枚举3.1 枚举类型的定义3.2 枚举的优点 四、联合&#xff08;共用体&#xff09;4.1 联合…

c#数据库: 9.删除和添加新字段/数据更新

先把原来数据表的sexy字段删除,然后重新在添加字段sexy,如果添加成功,sexy列的随机内容会更新.原数据表如下: using System; using System.Collections.Generic; using System.Data; using System.Data.Common; using System.Data.SqlClient; using System.Linq; using System.…

Linux理解文件操作 文件描述符fd 理解重定向 dup2 缓冲区 C语言实现自己的shell

文章目录 前言一、文件相关概念与操作1.1 open()1.2 close()1.3 write()1.4 read()1.4 写入的时候先清空文件内容再写入1.5 追加&#xff08;a && a&#xff09; 二、文件描述符2.1 文件描述符 fd 0 1 2 的理解2.2 FILE结构体&#xff1a;的源代码 三、深入理解文件描述…

jupyter notebook 设置密码报错ModuleNotFoundError: No module named ‘notebook.auth‘

jupyter notebook 设置密码报错ModuleNotFoundError: No module named ‘notebook.auth‘ 原因是notebook新版本没有notebook.auth 直接输入以下命令即可设置密码 jupyter notebook password

k8s调度原理以及自定义调度器

kube-scheduler 是 kubernetes 的核心组件之一&#xff0c;主要负责整个集群资源的调度功能&#xff0c;根据特定的调度算法和策略&#xff0c;将 Pod 调度到最优的工作节点上面去&#xff0c;从而更加合理、更加充分的利用集群的资源&#xff0c;这也是我们选择使用 kubernete…

Linux---软硬链接

软链接 我们先学习一下怎样创建软链接文件&#xff0c;指令格式为&#xff1a;ln -s 被链接的文件 生成的链接文件名 我们可以这样记忆&#xff1a;ln是link的简称&#xff0c;s是soft的简称。 我们在下面的图片中就是给test文件生成了一个软链接mytest&#xff1a; 我们来解…

数据结构篇其四---栈:后进先出的魔法世界

前言 栈的学习难度非常简单&#xff0c;前提是如果你学过顺序表和单链表的话&#xff0c;我直接说我的观点了&#xff0c;栈就是有限制的顺序表和单链表。 栈只允许一端进行插入删除。栈去除了各种情况复杂的插入删除&#xff0c;只允许一端插入删除的特性&#xff0c;这一种数…

5月4(信息差)

&#x1f384; HDMI ARC国产双精度浮点dsp杜比数码7.1声道解码AC3/dts/AAC环绕声光纤、同轴、USB输入解码板KC33C &#x1f30d; 国铁集团回应高铁票价将上涨 https://finance.eastmoney.com/a/202405043066422773.html ✨ 源代码管理平台GitLab发布人工智能编程助手DuoCha…

【数据结构】您有一份KMP算法教学已到账,请注意查收!!!

KMP算法 导读一、KMP算法1.1 重要术语1.2 部分匹配值1.3 部分匹配值的作用 二、KMP算法原理2.1 从指针的角度理解KMP算法2.2 从匹配的角度理解KMP算法2.3 小结 三、KMP算法的实现3.1 next数组3.2 next数组的计算3.2.1 通过PM值计算next数组3.2.2 通过移位模拟计算next数组3.2.3…

Web Storage 笔记11 网页数据存储

相关内容&#xff1a;Web Storage基本概念、localStorage、sessionStorage、登录注销实例、…… 在制作网页时会希望记录一些信息&#xff0c;例如用户登录状态、计数器或者小游戏等&#xff0c;但是又不希望用到数据库&#xff0c;就可以利用WebStorage技术将数据存储在用户浏…

Kubelet containerd 管理命令 ctr常用操作

镜像常用操作 1. 拉取镜像 ctr images pull docker.io/library/nginx:alpine 指定平台 --all-platforms&#xff1a;所有平台&#xff08;amd64 、arm、386 、ppc64le 等&#xff09;&#xff0c;不加的话下载当前平台架构 --platform&#xff1a;指定linux/amd64平台 ctr …

鸿蒙开发仿咸鱼TabBar

鸿蒙开发自定义TabBar&#xff0c;实现tabBar 上中间按钮凸起效果 第一步、定义数据模型 export default class TabItemData{defaultIcon: ResourceselectedIcon: Resourcetitle: stringisMiddle: booleanconstructor(defaultIcon:Resource, selectedIcon:Resource, title:st…

并发-启动线程的正确姿势

目录 启动线程的正确姿势 Start方法原理解读 Run方法原理解读 常见问题 启动线程的正确姿势 start()与run()方法的比较测试结果可以看出&#xff0c;runnable.run()方法是由main线程执行的&#xff0c;而要子线程执行就一定要先调用start()启动新线程去执行run方法并不能成…

【数据结构】第四讲:双向链表

目录 一、链表的分类 二、双向链表的结构及实现 1.带头双向链表的结构 2.创建节点 3.初始化 4.尾插 5.打印 6.头插 7.尾删 8.头删 9.在pos位置之后插入数据 10.删除pos节点 11.查找 12.销毁 个人主页&#xff1a;深情秋刀鱼-CSDN博客 数据结构专栏&#xff1a;数…

Mac M2 本地下载 Xinference

想要在Mac M2 上部署一个本地的模型。看到了Xinference 这个工具 一、Xorbits Inference 是什么 Xorbits Inference&#xff08;Xinference&#xff09;是一个性能强大且功能全面的分布式推理框架。可用于大语言模型&#xff08;LLM&#xff09;&#xff0c;语音识别模型&…

激动,五四青年节,拿下YashanDB认证YCP

&#x1f4e2;&#x1f4e2;&#x1f4e2;&#x1f4e3;&#x1f4e3;&#x1f4e3; 作者&#xff1a;IT邦德 中国DBA联盟(ACDU)成员&#xff0c;10余年DBA工作经验&#xff0c; Oracle、PostgreSQL ACE CSDN博客专家及B站知名UP主&#xff0c;全网粉丝10万 擅长主流Oracle、My…

中间件之搜索和数据分析组件Elasticsearch

一、概述 1.1介绍 The Elastic Stack, 包括 Elasticsearch、Kibana、Beats 和 Logstash&#xff08;也称为 ELK Stack&#xff09;。 能够安全可靠地获取任何来源、任何格式的数据&#xff0c;然后实时地对数据进行搜索、分析和可视 化。Elaticsearch&#xff0c;简称为 ES&a…
最新文章