Model¶
- class hetseq.bert_modeling.BertConfig(vocab_size_or_config_json_file, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act='gelu', hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02)
Configuration class to store the configuration of a BertModel.
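The first positional argument doubles as either an integer vocabulary size or a path to a JSON config file. A minimal self-contained sketch of that dual-purpose constructor (the class `BertConfigSketch` below is an illustrative stand-in, not hetseq's actual implementation):

```python
import json


class BertConfigSketch:
    """Illustrative stand-in mimicking hetseq.bert_modeling.BertConfig."""

    def __init__(self, vocab_size_or_config_json_file, hidden_size=768,
                 num_hidden_layers=12, num_attention_heads=12,
                 intermediate_size=3072, hidden_act='gelu',
                 hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512, type_vocab_size=2,
                 initializer_range=0.02):
        if isinstance(vocab_size_or_config_json_file, str):
            # Argument is a path: load every field from the JSON file.
            with open(vocab_size_or_config_json_file) as f:
                for key, value in json.load(f).items():
                    setattr(self, key, value)
        else:
            # Argument is the vocabulary size: take fields from the keywords.
            self.vocab_size = vocab_size_or_config_json_file
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
            self.intermediate_size = intermediate_size
            self.hidden_act = hidden_act
            self.hidden_dropout_prob = hidden_dropout_prob
            self.attention_probs_dropout_prob = attention_probs_dropout_prob
            self.max_position_embeddings = max_position_embeddings
            self.type_vocab_size = type_vocab_size
            self.initializer_range = initializer_range


# Integer form: 32000-token vocabulary, all other fields at their defaults.
config = BertConfigSketch(32000)
```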
- class hetseq.bert_modeling.BertForPreTraining(*args: Any, **kwargs: Any)
BERT model with pre-training heads. This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
- the next sentence classification head.
- Params:
config: a BertConfig class instance with the configuration to build a new model.
- Inputs:
- input_ids: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the token preprocessing logic in the scripts extract_features.py, run_classifier.py and run_squad.py)
- token_type_ids: an optional torch.LongTensor of shape [batch_size, sequence_length] with token type indices selected in [0, 1]. Type 0 corresponds to a sentence A token and type 1 corresponds to a sentence B token (see the BERT paper for more details).
- attention_mask: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It is the mask typically used for attention when a batch contains sentences of varying length: positions beyond a sequence's actual length (i.e. padding up to the max input sequence length in the current batch) are masked out.
- masked_lm_labels: optional masked language modeling labels: a torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [-1, 0, …, vocab_size - 1]. All labels set to -1 are ignored (masked out of the loss); the loss is only computed for the labels set in [0, …, vocab_size - 1].
- next_sentence_label: optional next sentence classification labels: a torch.LongTensor of shape [batch_size] with indices selected in [0, 1]. 0 => the next sentence is the continuation, 1 => the next sentence is a random sentence.
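In practice, input_ids and attention_mask are built together by padding every sequence in the batch to the longest one and marking real tokens with 1. A small illustrative helper (hypothetical, not part of hetseq), using the same token ids as the example below:

```python
def pad_batch(sequences, pad_id=0):
    """Pad variable-length token-id lists to a rectangle and build the
    matching attention mask (1 = real token, 0 = padding)."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * pad)
        attention_mask.append([1] * len(seq) + [0] * pad)
    return input_ids, attention_mask


ids, mask = pad_batch([[31, 51, 99], [15, 5]])
# ids  == [[31, 51, 99], [15, 5, 0]]
# mask == [[1, 1, 1], [1, 1, 0]]
```

Wrapping the resulting lists in torch.LongTensor yields tensors of shape [batch_size, sequence_length] as required above.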
- Outputs:
- if masked_lm_labels and next_sentence_label are not None:
Outputs the total_loss which is the sum of the masked language modeling loss and the next sentence classification loss.
- if masked_lm_labels or next_sentence_label is None:
Outputs a tuple comprising:
- the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
- the next sentence classification logits of shape [batch_size, 2].
Example usage:

```python
# Inputs have already been converted into WordPiece token ids.
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])

config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
                    num_hidden_layers=12, num_attention_heads=12,
                    intermediate_size=3072)

model = BertForPreTraining(config)
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
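The -1 convention for masked_lm_labels means the masked language modeling loss averages the negative log-likelihood only over the positions that were actually masked. A pure-Python sketch of that convention (the helper masked_lm_loss is hypothetical and operates on plain lists rather than tensors):

```python
import math


def masked_lm_loss(logits, labels):
    """Mean negative log-likelihood over positions whose label != -1.

    logits: one [vocab_size] list of scores per token position.
    labels: one int per position in [-1, 0, ..., vocab_size - 1];
            -1 marks positions excluded from the loss.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == -1:
            continue  # position was not masked; contributes nothing
        # Negative log-softmax of the correct class:
        log_z = math.log(sum(math.exp(s) for s in scores))
        total += log_z - scores[label]
        count += 1
    return total / count if count else 0.0


# Two positions over a vocabulary of 3; only the first position
# (label 0) contributes, the second (label -1) is skipped.
loss = masked_lm_loss([[2.0, 0.0, 0.0], [0.0, 0.0, 5.0]], [0, -1])
```

The same reduction with ignore_index=-1 is what torch.nn.CrossEntropyLoss performs on real tensors.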