Source code for hetseq.tasks.tasks

import os
import collections

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from hetseq import distributed_utils
from hetseq.data import (
    BertNerDataset,
    MNISTDataset,
    BertH5pyData,
    ConBertH5pyData,
    data_utils,
    iterators,
)



class Task(object):
    """
    Tasks store dictionaries and provide helpers for loading/iterating over
    Datasets, initializing the Model/Criterion and calculating the loss.
    """

    def __init__(self, args):
        self.args = args
        self.datasets = {}
        self.dataset_to_epoch_iter = {}

    def load_dictionary(self, vocab_file):
        """Loads a vocabulary file into a dictionary."""
        vocab = collections.OrderedDict()
        index = 0
        with open(vocab_file, "r", encoding="utf-8") as reader:
            while True:
                token = reader.readline()
                if not token:
                    break
                token = token.strip()
                vocab[token] = index
                index += 1
        print('| loaded dictionary with {} subwords from: {}'.format(index, vocab_file))
        return vocab

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        raise NotImplementedError

    def dataset(self, split):
        """
        Return a loaded dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)

        Returns:
            a :class:`~torch.utils.data.Dataset` corresponding to *split*
        """
        if split not in self.datasets:
            raise KeyError('Dataset not loaded: ' + split)
        if not isinstance(self.datasets[split], torch.utils.data.Dataset):
            raise TypeError('Datasets are expected to be of type torch.utils.data.Dataset')
        return self.datasets[split]

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=0,
    ):
        """
        Get an iterator that yields batches of data from the given dataset.

        Args:
            dataset (~torch.utils.data.Dataset): dataset to batch
            max_tokens (int, optional): max number of tokens in each batch
                (default: None).
            max_sentences (int, optional): max number of sentences in each
                batch (default: None).
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise an Exception
                for sentences that are too long (default: False).
            required_batch_size_multiple (int, optional): require batch size
                to be a multiple of N (default: 1).
            seed (int, optional): seed for the random number generator, for
                reproducibility (default: 1; the command-line options default
                to 19940802).
            num_shards (int, optional): shard the data iterator into N shards
                (default: 1); typically the number of GPUs.
            shard_id (int, optional): which shard of the data iterator to
                return (default: 0); typically the index of the current GPU.
            num_workers (int, optional): how many subprocesses to use for data
                loading. 0 means the data will be loaded in the main process
                (default: 0; the command-line options default to 16).
            epoch (int, optional): the epoch to start the iterator from
                (default: 0).

        Returns:
            ~iterators.EpochBatchIterator: a batched iterator over the given
            dataset split
        """
        # For the default fairseq-style task, return the same iterator across
        # epochs because the datasets are not dynamic; this can be overridden
        # in a task-specific setting.
        if dataset in self.dataset_to_epoch_iter:
            return self.dataset_to_epoch_iter[dataset]

        # initialize the dataset with the correct starting epoch
        # dataset.set_epoch(epoch)

        # get indices ordered by example size
        with data_utils.numpy_seed(seed):
            indices = dataset.ordered_indices()

        # create mini-batches with given size constraints
        print('| build batch sampler')
        batch_sampler = data_utils.batch_by_size(
            indices,
            dataset.num_tokens,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )
        print('| finish building batch sampler')

        # return a reusable, sharded iterator
        epoch_iter = iterators.EpochBatchIterator(
            dataset=dataset,
            collate_fn=dataset.collater,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )
        self.dataset_to_epoch_iter[dataset] = epoch_iter
        return epoch_iter

    def build_model(self, args):
        """
        Build the :class:`~torch.nn.Module` instance for this task.

        Args:
            args (argparse.Namespace): parsed command-line arguments

        Returns:
            a :class:`~torch.nn.Module` instance
        """
        raise NotImplementedError

    def train_step(self, sample, model, optimizer, ignore_grad=False):
        """
        Do forward and backward, and return the loss computed by the *model*
        for the given *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~torch.utils.data.Dataset`.
            model (~torch.nn.Module): the model
            optimizer (~optim._Optimizer): the optimizer
            ignore_grad (bool): multiply the loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        model.train()
        loss = model(*sample)
        if ignore_grad:
            loss *= 0

        if sample is None or len(sample) == 0 or len(sample[0][0]) == 0:
            sample_size = 0
        else:
            sample_size = len(sample[0][0])

        nsentences = sample_size
        logging_output = {
            'nsentences': nsentences,
            'loss': loss.data,
            'nll_loss': loss.data,
            'ntokens': 0,
            'sample_size': sample_size,
        }

        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def update_step(self, num_updates):
        """Task-level update when the number of updates increases.

        This is called after the optimization step and learning rate update
        of each step.
        """
        pass
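

# --- Usage sketches (illustration only, not part of hetseq) -----------------
# A minimal sketch of `Task.load_dictionary`: a plain-text vocabulary file with
# one subword per line becomes an ordered token -> index mapping. The file path
# and tokens below are hypothetical placeholders.
def _demo_load_dictionary(vocab_path='/tmp/toy_vocab.txt'):
    with open(vocab_path, 'w', encoding='utf-8') as f:
        f.write('[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\nworld\n')
    vocab = Task(args=None).load_dictionary(vocab_path)
    # tokens are indexed in file order: '[PAD]' -> 0, ..., 'world' -> 5
    assert vocab['[PAD]'] == 0 and vocab['world'] == 5
    return vocab


# A sketch of how a concrete task's `get_batch_iterator` output is typically
# consumed. It assumes `dataset` exposes `ordered_indices`, `num_tokens`, and
# `collater` (as the datasets imported above do), and that `EpochBatchIterator`
# provides a fairseq-style `next_epoch_itr(shuffle=...)`; treat both as
# assumptions about the surrounding hetseq API rather than guarantees.
def _demo_consume_batches(task, dataset, max_sentences=32):
    epoch_itr = task.get_batch_iterator(
        dataset=dataset,
        max_sentences=max_sentences,
        num_shards=1,   # single process: no sharding across GPUs
        shard_id=0,
    )
    for batch in epoch_itr.next_epoch_itr(shuffle=False):
        pass            # each `batch` comes from `dataset.collater(...)`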
class LanguageModelingTask(Task):
    """
    Train a language model; currently supports BERT.

    Args:
        args: parsed from the command line
        dictionary: the BPE dictionary for the input of the language model
    """

    def __init__(self, args, dictionary):
        super(LanguageModelingTask, self).__init__(args)
        self.dictionary = dictionary
        '''
        self.output_dictionary = output_dictionary or dictionary

        if targets is None:
            targets = ["future"]
        self.targets = targets
        '''

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        dictionary = cls.load_dictionary(cls, args.dict)
        return cls(args, dictionary)

    def build_model(self, args):
        if args.task == 'bert':
            from hetseq.bert_modeling import BertForPreTraining, BertConfig
            config = BertConfig.from_json_file(args.config_file)
            model = BertForPreTraining(config)
        else:
            raise ValueError(
                "Unsupported language modeling task: {}".format(args.task)
            )
        return model

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        path = self.args.data
        if not os.path.exists(path):
            raise FileNotFoundError(
                "Dataset not found: ({})".format(path)
            )

        files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path]
        files = sorted([f for f in files if split in f])

        # # debug
        if self.args.num_file > 0:
            files = files[0:self.args.num_file]

        assert len(files) > 0, "no suitable file in split ***{}***".format(split)

        datasets = []
        for i, f in enumerate(files):
            datasets.append(BertH5pyData(f))

        dataset = ConBertH5pyData(datasets)
        print('| loaded {} sentences from: {}'.format(len(dataset), path), flush=True)

        self.datasets[split] = dataset
        print('| loading finished')
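

# --- Usage sketch (illustration only, not part of hetseq) -------------------
# A minimal sketch of wiring up the BERT language-modeling task. The attribute
# names (`task`, `dict`, `config_file`, `data`, `num_file`) come from how this
# class reads `args` above; the values are hypothetical placeholders, and in
# practice they are supplied by hetseq's command-line options.
def _demo_setup_bert_task():
    import argparse
    args = argparse.Namespace(
        task='bert',                           # selects BertForPreTraining in build_model
        dict='/path/to/vocab.txt',             # vocabulary file, one subword per line
        config_file='/path/to/bert_config.json',
        data='/path/to/h5py_data_dir',         # directory of pre-processed h5py shards
        num_file=0,                            # 0 keeps every file matching the split
    )
    task = LanguageModelingTask.setup_task(args)   # loads the dictionary
    task.load_dataset('train')                     # builds ConBertH5pyData over *train* files
    model = task.build_model(args)                 # BertForPreTraining from the JSON config
    return task, model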
class MNISTTask(Task):
    def __init__(self, args):
        super(MNISTTask, self).__init__(args)

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        # return cls(args, dictionary, output_dictionary, targets=targets)
        return cls(args)

    def build_model(self, args):
        model = MNISTNet()
        return model

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        path = self.args.data
        if not os.path.exists(path):
            os.makedirs(path)
            raise FileNotFoundError(
                "Dataset not found: ({})".format(path)
            )

        if os.path.isdir(path):
            if os.path.exists(os.path.join(path, 'MNIST/processed/')):
                path = os.path.join(path, 'MNIST/processed/')
            elif os.path.basename(os.path.normpath(path)) != 'processed':
                torchvision.datasets.MNIST(path, train=True, download=True)
                path = os.path.join(path, 'MNIST/processed/')

        files = [os.path.join(path, f) for f in os.listdir(path)] if os.path.isdir(path) else [path]
        files = sorted([f for f in files if split in f])

        assert len(files) == 1, "no suitable file in split ***{}***".format(split)

        dataset = MNISTDataset(files[0])
        print('| loaded {} sentences from: {}'.format(len(dataset), path), flush=True)

        self.datasets[split] = dataset
        print('| loading finished')


class MNISTNet(nn.Module):
    def __init__(self):
        super(MNISTNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x, target, eval=False):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        loss = F.nll_loss(output, target)
        return loss
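

# --- Usage sketch (illustration only, not part of hetseq) -------------------
# A minimal sketch of MNISTNet's forward pass on random data, which is enough
# to see that the module returns the NLL training loss directly rather than
# logits. The batch of fake 1x28x28 "images" and labels is made up for the demo.
def _demo_mnist_forward():
    net = MNISTNet()
    images = torch.randn(8, 1, 28, 28)       # MNIST-sized inputs: N x 1 x 28 x 28
    targets = torch.randint(0, 10, (8,))     # one class label in [0, 10) per image
    loss = net(images, targets)              # scalar loss from F.nll_loss
    return loss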