Source code for hetseq.data.h5pyDataset

import os
import bisect
from functools import lru_cache

import h5py

import numpy as np
import torch
import torch.utils.data
from torch.utils.data.dataloader import default_collate


class BertH5pyData(torch.utils.data.Dataset):
    """BERT pre-training dataset backed by a single HDF5 file.

    The file is reopened in SWMR (single-writer, multiple-reader) mode on
    every read.
    NOTE: support for multi-process data loading has not been verified.
    """

    def __init__(self, path, max_pred_length=512):
        super(BertH5pyData, self).__init__()
        self.keys = ('input_ids', 'input_mask', 'segment_ids',
                     'masked_lm_positions', 'masked_lm_ids',
                     'next_sentence_labels')
        self.max_pred_length = max_pred_length
        self.data_file = None
        self.path = path
        self.read_data(path)

    def read_data(self, path):
        # Open the file only to record the number of examples, then close it
        # so the dataset object remains picklable for DataLoader workers.
        with h5py.File(path, "r", libver='latest', swmr=True) as data_file:
            self._len = len(data_file[self.keys[0]])

    def check_index(self, i):
        if i < 0 or i >= self._len:
            raise IndexError('index out of range')

    @lru_cache(maxsize=8)
    def __getitem__(self, index):
        with h5py.File(self.path, "r", libver='latest', swmr=True) as data_file:
            self.check_index(index)
            inputs = [data_file[key][index] for key in self.keys]

        [input_ids, input_mask, segment_ids, masked_lm_positions,
         masked_lm_ids, next_sentence_labels] = [
            torch.from_numpy(inp.astype(np.int64)) if i < 5
            else torch.from_numpy(np.asarray(inp.astype(np.int64)))
            for i, inp in enumerate(inputs)]

        # -1 marks positions that are ignored by the masked-LM loss.
        masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -1
        # store the number of masked tokens in index
        index = self.max_pred_length
        padded_mask_indices = (masked_lm_positions == 0).nonzero()
        if len(padded_mask_indices) != 0:
            index = padded_mask_indices[0].item()
        masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index]

        return [input_ids, segment_ids, input_mask,
                masked_lm_labels, next_sentence_labels]

    def __del__(self):
        # If a persistent handle was ever opened, flush and close it so the
        # HDF5 file descriptor is not leaked.
        # NOTE: a bug was encountered here that has not been resolved.
        if self.data_file:
            self.data_file.flush()
            self.data_file.close()

    def __len__(self):
        return self._len

    def size(self, idx: int):
        """Return an example's size as a float or tuple."""
        # In this BERT data preparation every example has the same padded
        # length, so max_pred_length (512 by default) is returned directly.
        return self.max_pred_length

    def set_epoch(self, epoch):
        pass
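
For context, a minimal usage sketch (not part of the module) showing how BertH5pyData might feed a standard PyTorch DataLoader; the shard file name is a hypothetical placeholder.

# Illustrative usage sketch, not part of hetseq; the file name
# "pretraining_shard_0.h5" is a hypothetical placeholder.
from torch.utils.data import DataLoader

dataset = BertH5pyData("pretraining_shard_0.h5", max_pred_length=512)
loader = DataLoader(dataset, batch_size=8, collate_fn=default_collate)

# Each sample is a list of five tensors, so default_collate yields a list of
# five batched LongTensors; masked_lm_labels uses -1 for ignored positions.
input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = next(iter(loader))
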
class ConBertH5pyData(torch.utils.data.Dataset):
    """Concatenation of multiple BertH5pyData datasets.

    A global index is mapped to (dataset, sample) with a binary search over
    cumulative sizes; each dataset can be up-sampled by an integer ratio.
    """

    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConBertH5pyData, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        # Wrap around the real dataset size when a sample_ratio > 1 repeats
        # examples.
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples):
        # For now this only supports datasets whose underlying collater
        # implementations are identical; the first dataset's collater is used.
        if len(samples) == 0:
            return None
        if hasattr(self.datasets[0], 'collater'):
            return self.datasets[0].collater(samples)
        else:
            return default_collate(samples)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed
        based on this order."""
        return np.arange(len(self))

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def size(self, idx: int):
        """Return an example's size as a float or tuple."""
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def set_epoch(self, epoch):
        pass
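
As a sketch (assuming two hypothetical shard files), ConBertH5pyData can concatenate several BertH5pyData shards and optionally up-sample them; its collater falls back to default_collate because BertH5pyData defines no collater of its own.

# Illustrative usage sketch, not part of hetseq; the shard file names are
# hypothetical placeholders.
shards = [BertH5pyData("shard_0.h5"), BertH5pyData("shard_1.h5")]
combined = ConBertH5pyData(shards, sample_ratios=[1, 2])

# Length is len(shard_0) + 2 * len(shard_1); indices past a shard's real
# size wrap around via the modulo in _get_dataset_and_sample_index.
print(len(combined))

# BertH5pyData has no collater attribute, so default_collate is used.
batch = combined.collater([combined[0], combined[1]])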