
在之前的专栏中,介绍了交叉熵训练阶段的运算逻辑,分割数据集、封装batch、加载数据进行训练可以通过下面三行代码来实现。
(资料图)
对于强化学习训练,却需要通过下面的代码展开训练。
其中加载数据展开训练的train_scst没什么好说的,下面具体记录一下训练开始前的数据准备过程。
首先来看一下train_dataset对象中的image_dictionary方法,如下:
可以看到其实际上是实例化了一个名为DictionaryDataset的类,并实例化后的的对象作为返回值。即在强化学习阶段使用的数据集不再是PairedDataset类的对象,而是DictionaryDataset类的对象。接下来看一下DictionaryDataset类的初始化函数。PS:此处的self.examples是交叉熵处理过程中提到的样本对象列表。
【1】基于样本对象构建两个字典,然后封装成对象。字典是{'image':'this_is_a_filename.jpg'}和{'text':‘this is a caption’},封装后的对象名分别为key_example和value_example,记为图片对象和描述对象.
【2】样本对象列表examples是基于描述语句构建的,即10w张图片对应50w条描述,则examples的长度为50w,其中每个样本由一条描述语句和一张图片组成。显然,通过这样的方式构建的examples中,连续的五条样本实际上对应着相同的图片。
所以,在这里需要进行判断,如果当前样本的图片对象第一次出现,则选择添加,否则跳过。当选择添加时,首先构建一个字典 { 图片对象:已有图片对象列表的长度 },通过该字典可以很方便的判断当前样本的图片对象是否已添加;然后将当前样本的图片对象添加到一个列表中。
【3】描述对象不会存在重复添加的问题,所以将每个样本的描述对象都添加到列表中即可。
【4】构建一个列表来确定图片对象和描述对象的对应关系(一个图片对应五条描述),具体来说就是构建一个以图片对象编号为key的字典,其value为样本编号)等价于描述对象编号)。
上述变量之间的关系如图所示:
在初始化函数的最后,利用Dataset类和ValueDataset类分别实例化图片对象列表和描述对象列表。需要注意的时此处使用Dataset类的方式与交叉熵训练阶段不同,此处的第二个传入参数是仅包含一个元素的字典,所以在Dataset类的getitem和collate_fn中也只会进行一次循环,即仅调用ImageDetectionsField中的preprocess和preproces对视觉特征进行处理,而不关注描述文本部分。
由于在强化阶段使用数据集是DictionaryDataset类的对象,所以在调用DataLoader封装batch时使用的也将会是DictionaryDataset类中自定义的collate_fn和getitem方法。
在上述getitem代码中,直接通过下标的方式从两个对象中取出数据,self.key_dataset是Dataset类的对象,自不必多说,而self.value_dataset是ValueDataset类的对象,该类还没有进行过介绍。在collate_fn代码中,可以看到首先利用zip将图片对象和描述对象分离,然后分别调用Dataset类和ValueDataset类中定义的collate_fn方法处理。视觉特的处理部分和交叉熵训练中相似,故不多讲,接下来重点介绍文本处理部分。
ValueDataset类的初始化方法中没有什么值得在意的,只不过是将传入参数赋值给了类中的变量。接下来看getitem方法。
self.dictionary是一个字典,self.dictionary[i]是将其中key为i个元素的value取出,即一个包含描述样本编号的列表。接下来根据这个列表建立训练,在循环中的super(ValueDataset,self)代表其父类,调用父类中的getitem方法来利用描述样本的编号取出描述文本,随后将描述样本组成列表,作为返回值。
最后是ValueDataset类的collate_fn,没有仔细研究,等看明白了再补充。