Source code for EduNLP.ModelZoo.disenqnet.disenqnet

# -*- coding: utf-8 -*-

import logging
import torch
from torch import nn
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import StepLR
import os
import json
from gensim.models import KeyedVectors

from .modules import TextEncoder, AttnModel, ConceptEstimator, MIEstimator, DisenEstimator
from .utils import get_mask
from ..utils import set_device
from ..base_model import BaseModel
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig


class DisenQNetOutput(ModelOutput):
    """
    Output type of [`DisenQNet`]

    Parameters
    ----------
    embed: Tensor of (batch_size, seq_len, hidden_size), word embedding
    k_hidden: Tensor of (batch_size, hidden_size) or None, concept representation of question
    i_hidden: Tensor of (batch_size, hidden_size) or None, individual representation of question
    """
    embeded: torch.FloatTensor = None
    k_hidden: torch.FloatTensor = None
    i_hidden: torch.FloatTensor = None


[docs]class DisenQNet(BaseModel): base_model_prefix = 'disenq' """ DisenQNet question representation model Parameters ---------- vocab_size: int size of vocabulary hidden_size: int size of word and question embedding dropout_rate: float dropout rate wv: torch.Tensor Tensor of (vocab_size, hidden_size) or None, initial word embedding, default = None """ def __init__(self, vocab_size: int, hidden_size: int, dropout_rate: float, wv=None, **kwargs): super(DisenQNet, self).__init__() self.hidden_size = hidden_size self.encoder = TextEncoder(vocab_size, hidden_size, dropout_rate, wv=wv) self.k_model = AttnModel(hidden_size, dropout_rate) self.i_model = AttnModel(hidden_size, dropout_rate) self.dropout = nn.Dropout(p=dropout_rate) # config self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "kwargs", 'wv']} self.config.update(kwargs) self.config['architecture'] = 'DisenQNet' self.config = PretrainedConfig.from_dict(self.config)
[docs] def forward(self, seq_idx=None, seq_len=None, get_vk=True, get_vi=True) -> ModelOutput: """ Parameters ---------- seq_idx: Tensor of (batch_size, seq_len) word index seq_len: Tensor of (batch_size) valid sequence length of each batch get_vk: bool whether to return vk get_vi: bool whether to return vi Returns ------- DisenQNetOutput - embed: Tensor of (batch_size, seq_len, hidden_size), word embedding - k_hidden: Tensor of (batch_size, hidden_size) or None, concept representation of question - i_hidden: Tensor of (batch_size, hidden_size) or None, individual representation of question """ # embed: batch_size * seq_len * hidden_size # q_hidden: batch_size * hidden_size embed, q_hidden = self.encoder(seq_idx) # batch_size * seq_len, 0 means valid, 1 means pad mask = get_mask(seq_idx.size(1), seq_len) embed.masked_fill_(mask.unsqueeze(-1), 0) k_hidden, i_hidden = None, None q_hidden_dp = self.dropout(q_hidden) embed_dp = self.dropout(embed) # batch_size * hidden_size if get_vk: k_hidden, _ = self.k_model(q_hidden_dp, embed_dp, embed_dp, mask) if get_vi: i_hidden, _ = self.i_model(q_hidden_dp, embed_dp, embed_dp, mask) return DisenQNetOutput( embeded=embed, k_hidden=k_hidden, i_hidden=i_hidden )
[docs] @classmethod def from_config(cls, config_path, **kwargs): with open(config_path, "r", encoding="utf-8") as rf: model_config = json.load(rf) model_config.update(kwargs) return cls( vocab_size=model_config['vocab_size'], hidden_size=model_config['hidden_size'], dropout_rate=model_config['dropout_rate'], )
class DisenQNetForPreTrainingOutput(ModelOutput): """ Output type of [`DisenQNetForPreTraining`] Parameters ---------- loss embed: Tensor of (batch_size, seq_len, hidden_size), word embedding k_hidden: Tensor of (batch_size, hidden_size) or None, concept representation of question i_hidden: Tensor of (batch_size, hidden_size) or None, individual representation of question """ loss: torch.FloatTensor = None embeded: torch.FloatTensor = None k_hidden: torch.FloatTensor = None i_hidden: torch.FloatTensor = None
[docs]class DisenQNetForPreTraining(BaseModel): base_model_prefix = 'disenq' def __init__(self, vocab_size, concept_size, hidden_size, dropout_rate, pos_weight, w_cp, w_mi, w_dis, warmup, n_adversarial, wv=None, **kwargs): super(DisenQNetForPreTraining, self).__init__() self.disenq = DisenQNet( vocab_size=vocab_size, hidden_size=hidden_size, dropout_rate=dropout_rate, wv=wv, **kwargs) self.mi_estimator = MIEstimator(hidden_size, hidden_size * 2, dropout_rate) self.concept_estimator = ConceptEstimator(hidden_size, concept_size, pos_weight, dropout_rate) self.disen_estimator = DisenEstimator(hidden_size, dropout_rate) self.w_cp = w_cp self.w_mi = w_mi self.w_dis = w_dis self.hidden_size = hidden_size self.warming_up = False self.params = { "vocab_size": vocab_size, "concept_size": concept_size, "hidden_size": hidden_size, "dropout": dropout_rate, "pos_weight": pos_weight, "w_cp": w_cp, "w_mi": w_mi, "w_dis": w_dis, 'warmup': warmup, 'n_adversarial': n_adversarial, } self.modules = (self.disenq, self.mi_estimator, self.concept_estimator, self.disen_estimator) self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "kwargs", 'wv']} self.config.update(kwargs) self.config['architecture'] = 'DisenQNetForPreTraining' self.config = PretrainedConfig.from_dict(self.config) model_params = list() for params in [list(self.disenq.parameters()), list(self.mi_estimator.parameters()), list(self.concept_estimator.parameters())]: model_params.extend(params) self.model_params = model_params self.adv_params = list(self.disen_estimator.parameters())
[docs] def forward(self, seq_idx=None, seq_len=None, concept=None) -> ModelOutput: # train enc outputs = self.disenq(seq_idx, seq_len) embed = outputs.embeded k_hidden = outputs.k_hidden i_hidden = outputs.i_hidden hidden = torch.cat((k_hidden, i_hidden), dim=-1) # max mi mi_loss = - self.mi_estimator(embed, hidden, seq_len) # min concept_loss cp_loss = self.concept_estimator(k_hidden, concept) if self.warming_up: loss = self.w_mi * mi_loss + self.w_cp * cp_loss else: # min dis dis_loss = self.disen_estimator(k_hidden, i_hidden) loss = self.w_mi * mi_loss + self.w_cp * cp_loss + self.w_dis * dis_loss return DisenQNetForPreTrainingOutput( loss=loss, embeded=embed, k_hidden=k_hidden, i_hidden=i_hidden )
[docs] @classmethod def from_config(cls, config_path, **kwargs): with open(config_path, "r", encoding="utf-8") as rf: model_config = json.load(rf) model_config.update(kwargs) return cls( vocab_size=model_config['vocab_size'], concept_size=model_config['concept_size'], hidden_size=model_config['hidden_size'], dropout_rate=model_config['dropout_rate'], pos_weight=model_config['pos_weight'], w_cp=model_config['w_cp'], w_mi=model_config['w_mi'], w_dis=model_config['w_dis'], warmup=model_config['warmup'], n_adversarial=model_config['n_adversarial'], )