PyTorch 针对 IDX 类型数据文件创建 Dataset

众所周知的 MNIST 数据集的源文件就是以 IDX 数据形式存放的,其文件名为:

  • train-images-idx3-ubyte
  • train-labels-idx1-ubyte
  • t10k-images-idx3-ubyte
  • t10k-labels-idx1-ubyte

MINST 数据集的官方网页有写明其 IDX3 文件和 IDX1 文件的具体内部结构。但是大家一般使用 MNIST 数据集都是直接用 PyTorch 提供的 MNIST Class,似乎没有很多人研究直接从 IDX3 和 IDX1 文件中读取数据。

从源文件读数据的方法在一定情况下是需要的,毕竟也有很多数据集是以 IDX 类型文件提供的。这里构建了一个(个人感觉)较为完备和易用的方法来读数据。

N.B. 太长不看可以直接拉到最后。

正文

IDX1 和 IDX3 文件结构

内容结构在上面提到的官方网站写的很清楚了:

The IDX file format is a simple format for vectors and multidimensional matrices of various numerical types.

The basic format is:

magic number

size in dimension 0

size in dimension 1

size in dimension 2

size in dimension N

data

The magic number is an integer (MSB first). The first 2 bytes are always 0.

The third byte codes the type of the data:

0x08: unsigned byte

0x09: signed byte

0x0B: short (2 bytes)

0x0C: int (4 bytes)

0x0D: float (4 bytes)

0x0E: double (8 bytes)

The 4-th byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices….

The sizes in each dimension are 4-byte integers (MSB first, high endian, like in most non-Intel processors).

The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.

总结下来就是一个 IDX 文件头部的一些二进制数据说明了这个文件的数据类型数据量每个数据的大小,之后每个数据就一个挨一个以二进制数据排列。

例子:具体到 MNIST 数据集:

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

0000 到 0004 位置上的数据是 magic number,也就是说明本数据集的数据类型;

0005 到 0008 位置上的数据值是这个数据集的数据量,在 MNIST 中的值为 60000,说明这个 MNIST 数据集有 60000 个数据;

随后 0009 到 0012 和 0013 到 0016 的数据分别指图像的高和宽.

再往后每个二进制表示一个像素,中间没有隔断,需要自己按照头部数据裁剪。比如这里就可以从 0016 这个二进制数据开始,每 28 * 28 个二进制数据可以转换为一个图片,一直进行下去。

了解源数据结构之后就容易下手了。

初始化 Dataset 类

毕竟数据要自己处理成数据集,所以 Dataset 类要继承着 torch.utils.dataDataset 写一个:

N.B. 这个数据集的类姑且命名为 MNIST

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# author: Chuanbo HUA
# date: 2020-05-06
# update: 2020-05-10
'''
Data Processing.

Decode IDX files to dataset. 
- IDX3: images.
- IDX1: labels.
'''

from torch.utils import data


class MNIST(data.Dataset):


    def __init__(self):
        pass
    
    
    def __getitem__(self, index):
        pass
    
    
    def __len__(self):
        pass

之后定义头部数据的 Offset,可以看到 IDX3 的头部是 4 个 int,IDX1 的头部是 2 个 int。这里涉及 structcalcsize 函数,之后裁剪图片数据的时候也会用到,可参考 Python 3.7 Doc 中对 struct 的介绍。这里的含义即为 4 个 int 的二进制大小。

import struct

self.img_offset = struct.calcsize('>iiii')
self.label_offset = struct.calcsize('>ii')

然后以二进制读取数据集文件,并获取头部数据。这里 'rb' 指 read & binary。

bin_data = open(img_path, 'rb').read()
_, self.num_img, self.num_rows, self.num_cols = struct.unpack_from('>iiii', bin_data, offset=0)

之后就可以开始读图片了。网上很多关于读取 IDX 文件的教程都是包含头部文件一口气读完并写入内存。这样很占用内存和低效。

根据 PyTorch 的 Dataset 优化规则,我们最好把读文件这样较为消耗算力的操作放在 __getitem__ 函数中,而不是在 __init__ 函数里面一口气加载完。而且 IDX 二进制读取是调用底层的 C,读的速度非常的快。

所以说初始化的时候只需要获得一些关键的数据保留在成员变量中即可,主要操作都放在 __getitem__ 函数中。

除了基本的头部数据之外,初始化还储存了几个数据,方便之后 get item 函数调用。这两个 csize 分别是图片和标签的二进制长度,之后读数就按照这个长度依次读就可以。

self.img_fmt = '>' + str(num_rows * num_cols) + 'B'
self.img_fmt_csize = struct.calcsize(self.img_fmt)

self.label_fmt = '>B'
self.label_fmt_csize = struct.calcsize(self.label_fmt)

Get Item 函数

Offset 有了,每个数据的长度有了,那么读第几个数据只要算一算起始位置就可以。

例子:比如要读第 10 个数据,每个数据的二进制长度是 4,算上头部占了 16 个数,那么第 10 个数的位置开始于 16 + 9 * 4 = 52,即 0052 到 0056 这四个二进制数表示第 10 个数据。

img_data = open(self.img_path, 'rb').read()
label_data = open(self.label_path, 'rb').read()

# self.img_offset 指的是头文件长度
img_offset = self.img_offset + self.img_fmt_csize * index
label_offset = self.label_offset + self.label_fmt_csize * index

img = np.array(struct.unpack_from(self.img_fmt, img_data, img_offset)).reshape((1, 	self.img_rows, self.img_cols))
label = struct.unpack_from(self.label_fmt, label_data, label_offset)[0]

整合

把上面的整合起来,加上 input 和 output 就是一个效率不错的 Dataset 类了。

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# author: Chuanbo HUA
# date: 2020-05-06
# update: 2020-05-10
'''
Data Processing.
Decode IDX files to dataset. 
- IDX3: images.
- IDX1: labels.
'''

import numpy as np
import struct
from torch.utils import data
from torchvision import transforms


class NetFlowTrain(data.Dataset):
    '''
    Dataset of train set.
    The percentage of train set is 70%.
    '''


    def __init__(self, img_path, label_path):
        '''
        Args:
            img_path: str=...
            label_path: str=...
        '''

        self.img_path = img_path
        self.label_path = label_path

        self.img_offset = struct.calcsize('>iiii')
        self.label_offset = struct.calcsize('>ii')

        bin_data = open(img_path, 'rb').read()
        _, self.len, self.img_rows, self.img_cols = \
            struct.unpack_from('>iiii', bin_data, 0)

        self.img_fmt = '>' + str(num_rows * num_cols) + 'B'
        self.img_fmt_csize = struct.calcsize(self.img_fmt)

        self.label_fmt = '>B'
        self.label_fmt_csize = struct.calcsize(self.label_fmt)

        print("Create train dataset successfully. \n    -> Number of image: %d \n    -> Size of image: %d * %d\n" % (self.len, num_cols, num_rows))


    def __getitem__(self, index):
        '''
        Returns:
            Turple: (image: array, label: int).
        '''
        
        img_data = open(self.img_path, 'rb').read()
        label_data = open(self.label_path, 'rb').read()

        img_offset = self.img_offset + self.img_fmt_csize * index
        label_offset = self.label_offset + self.label_fmt_csize * index

        img = np.array(struct.unpack_from\
            (self.img_fmt, img_data, img_offset)).\
            reshape((1, self.img_rows, self.img_cols))
        label = struct.unpack_from(self.label_fmt, label_data, label_offset)[0]

        return img, label


    def __len__(self):
        '''
        Return:
            Int: the length of dataset.
        '''

        return self.len

其他的操作和普通数据集处理类似,如归一化等等。