I took some time to test Lightning's newly released LitData and ran a quick comparison against WebDataset. LitData is still at an early stage.

Project links#

https://github.com/Lightning-AI/litdata
https://github.com/webdataset/webdataset

Versions tested#

litdata=0.2.18
webdataset=0.2.90

Processing the SA-1B dataset (10 TB) with LitData#

https://lightning.ai/lightning-ai/studios/prepare-large-image-segmentation-datasets-with-litdata

Changing LitData's optimize cache path#

The default cache path is /tmp/chunks/. It can be changed by setting an environment variable:

import os
# Set your desired cache directory
os.environ["DATA_OPTIMIZER_CACHE_FOLDER"] = "/path/to/your/cache_dir"

Avoid CombinedStreamingDataset#

With CombinedStreamingDataset, training deadlocks as soon as one of the constituent datasets is exhausted during reading. StreamingDataset is simpler, but it can also hang during validation.

Comparing LitData and WebDataset#

LitData is quite unstable: while training a VQGAN, the run fails at the switch from training to validation. On 4x A800 GPUs with batch_size=16 and num_workers=12, LitData processed 1.69 it/s.

With WebDataset, using resample and with_epoch/with_length, the shards are first distributed across the 4 GPUs (processes), then num_workers=12 worker processes read the data, and finally samples are re-shuffled in the WebLoader process. With batch_size=16, the number of steps per epoch is dataset_size // (batch_size * world_size). Throughput was 1.52 it/s, slightly slower than LitData.
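The steps-per-epoch arithmetic above can be sketched as follows (the dataset size here is an illustrative number, not from my runs):

```python
def steps_per_epoch(dataset_size: int, batch_size: int, world_size: int) -> int:
    # Each step consumes batch_size samples on each of the world_size
    # processes, so one epoch covers batch_size * world_size samples per step;
    # the remainder is dropped.
    return dataset_size // (batch_size * world_size)

# Illustrative: 1,000,000 samples, 4 GPUs, batch_size=16
print(steps_per_epoch(1_000_000, 16, 4))  # → 15625
```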

The parallel map function provided by LitData is more convenient than what WebDataset offers.
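For context, the plain-stdlib version of a parallel map looks like this; as I understand it, litdata.map layers multi-machine scheduling and output handling on top of this pattern. The function and inputs below are illustrative:

```python
from multiprocessing import get_context

def process_shard(name: str) -> str:
    # Stand-in for real per-shard work (decode, resize, re-encode, ...)
    return name.upper()

inputs = [f"shard_{i}.tar" for i in range(4)]
# The "fork" start method keeps this module-level script safe to run on Linux
with get_context("fork").Pool(processes=2) as pool:
    results = pool.map(process_shard, inputs)
print(results)  # → ['SHARD_0.TAR', 'SHARD_1.TAR', 'SHARD_2.TAR', 'SHARD_3.TAR']
```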

WebDataset unbatched IndexError#

# webdataset/filters.py:515-521, raises IndexError
def _unbatched(data):
    """Turn batched data back into unbatched data."""
    for sample in data:
        assert isinstance(sample, (tuple, list)), sample
        assert len(sample) > 0
        for i in range(len(sample[0])):
            yield tuple(x[i] for x in sample)

# Patched version: discard samples whose components have unequal lengths
def _unbatched(data):
    """Turn batched data back into unbatched data."""
    for sample in data:
        assert isinstance(sample, (tuple, list)), sample
        assert len(sample) > 0
        min_length = min(len(x) for x in sample)
        max_length = max(len(x) for x in sample)
        if max_length != min_length:
            print(f"{max_length} != {min_length}, some data will be discarded.")
        for i in range(min_length):
            yield tuple(x[i] for x in sample)