# Source code for EduNLP.ModelZoo.rnn.harnn

from dataclasses import dataclass
from typing import List, Dict, Optional

import torch
import torch.nn as nn
from torch.autograd import Variable
from transformers.modeling_outputs import ModelOutput

# Public API of this module: only ``HAM`` is exported by ``from ... import *``.
__all__ = ["HAM"]


@dataclass
class HAMOutput(ModelOutput):
    """
    Output type of [`HAM`].

    ``ModelOutput`` subclasses are meant to be dataclasses: the decorator
    turns the annotated fields below into constructor keywords, so every
    keyword passed by ``HAM.forward`` must be declared here.

    The shapes below follow from ``AttentionLayer`` / ``LocalLayer`` when the
    model input is ``(batch_size, sequence_length, hidden_size)``; ``C1``,
    ``C2`` and ``C3`` stand for the three entries of ``num_classes_list``.

    Parameters
    ----------
    scores: of shape (batch_size, num_total_classes)
        Blend of the global scores and the concatenated local scores.
    first_att_weight: of shape (batch_size, C1, sequence_length)
        Attention weights of the first level (softmax over the sequence).
    first_visual: of shape (batch_size, sequence_length)
    second_visual: of shape (batch_size, sequence_length)
    third_visual: of shape (batch_size, sequence_length)
    first_logits: of shape (batch_size, C1)
    second_logits: of shape (batch_size, C2)
    third_logits: of shape (batch_size, C3)
    global_logits: of shape (batch_size, num_total_classes)
    first_scores: of shape (batch_size, C1)
    second_scores: of shape (batch_size, C2)
    """
    # Field order mirrors the keyword order used by ``HAM.forward`` so that
    # tuple-style access keeps the same positions.
    scores: torch.FloatTensor = None
    first_att_weight: torch.FloatTensor = None
    first_visual: torch.FloatTensor = None
    second_visual: torch.FloatTensor = None
    third_visual: torch.FloatTensor = None
    first_logits: torch.FloatTensor = None
    second_logits: torch.FloatTensor = None
    third_logits: torch.FloatTensor = None
    global_logits: torch.FloatTensor = None
    first_scores: torch.FloatTensor = None
    second_scores: torch.FloatTensor = None


def truncated_normal_(tensor, mean=0, std=0.1):
    """
    Fill ``tensor`` in place with samples from a truncated normal distribution.

    Standard-normal values are drawn and every element that falls outside the
    open interval (-2, 2) is redrawn until none remain, so the result is
    guaranteed to lie within two standard deviations of ``mean``. The values
    are then scaled by ``std`` and shifted by ``mean``.

    Parameters
    ----------
    tensor: torch.Tensor
        Floating-point tensor to overwrite (a plain tensor or a parameter).
    mean: float
        Centre of the distribution.
    std: float
        Standard deviation of the underlying (untruncated) normal.

    Returns
    -------
    torch.Tensor
        The same ``tensor`` object, modified in place.
    """
    # no_grad lets this run on parameters that require grad without
    # recording the initialisation in the autograd graph.
    with torch.no_grad():
        tensor.normal_()
        out_of_range = (tensor <= -2) | (tensor >= 2)
        # Rejection sampling: only the offending positions take a new draw,
        # the already-valid ones are kept untouched.
        while out_of_range.any():
            redraw = torch.empty_like(tensor).normal_()
            tensor.copy_(torch.where(out_of_range, redraw, tensor))
            out_of_range = (tensor <= -2) | (tensor >= 2)
        tensor.mul_(std).add_(mean)
    return tensor


class AttentionLayer(nn.Module):
    """
    Per-class attention over the positions of a sequence.

    Two bias-free linear maps with a tanh in between produce one score per
    (position, class) pair; the scores are normalised over the positions and
    used to build one weighted sum of the inputs per class.
    """

    def __init__(self, num_units, attention_unit_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(num_units, attention_unit_size, bias=False)
        self.fc2 = nn.Linear(attention_unit_size, num_classes, bias=False)

    def forward(self, input_x):
        """
        Parameters
        ----------
        input_x: torch.Tensor
            Rank-3 input whose last dimension is ``num_units``
            (presumably ``(batch, sequence, num_units)``).

        Returns
        -------
        tuple
            ``(weights, pooled)``: the attention weights with the class axis
            moved to dim 1, and the attended vectors averaged over that axis.
        """
        projected = self.fc1(input_x).tanh()
        # Swap dims 1 and 2 so the softmax below runs over the positions.
        class_scores = self.fc2(projected).transpose(1, 2)
        weights = class_scores.softmax(dim=-1)
        attended = weights @ input_x
        return weights, attended.mean(dim=1)


class LocalLayer(nn.Module):
    """
    Local classifier for one level of the hierarchy.

    A single linear layer yields class logits and sigmoid scores; the scores
    also rescale the attention weights to produce a per-position
    "visual" vector.
    """

    def __init__(self, num_units, num_classes):
        super().__init__()
        self.fc = nn.Linear(num_units, num_classes)

    def forward(self, input_x, input_att_weight):
        """
        Parameters
        ----------
        input_x: torch.Tensor
            Features whose last dimension is ``num_units``.
        input_att_weight: torch.Tensor
            Attention weights; multiplied with the scores broadcast over the
            last axis (presumably ``(batch, num_classes, sequence)``).

        Returns
        -------
        tuple
            ``(logits, scores, visual)``.
        """
        logits = self.fc(input_x)
        scores = logits.sigmoid()
        # Scale each class's attention row by that class's score, normalise
        # over the last axis, then average over dim 1.
        scaled_attention = input_att_weight * scores.unsqueeze(-1)
        visual = scaled_attention.softmax(dim=-1).mean(dim=1)
        return logits, scores, visual


class HAM(nn.Module):
    """
    Three-level hierarchical attention classifier.

    Each level applies an attention layer to the sequence, concatenates the
    attended vector with the mean-pooled sequence, and feeds the result to a
    local classifier. The "visual" vector of one level rescales the sequence
    seen by the next level. The three level representations are then merged
    through a fully connected layer, a highway layer and dropout to produce
    global logits; final scores blend global and local scores with ``beta``.

    Parameters
    ----------
    num_classes_list: List[int]
        Number of classes at each level; entries 0, 1 and 2 are used.
    num_total_classes: int
        Size of the global output; must equal ``sum(num_classes_list)``.
    sequence_model_hidden_size: int
        Size of the last dimension of the input sequence embeddings.
    attention_unit_size: int
        Hidden size inside each attention layer.
    fc_hidden_size: int
        Hidden size of the per-level and global fully connected layers.
    beta: float
        Weight of the global scores in the final blend; the local scores get
        ``1 - beta``.
    dropout_rate: float, optional
        Dropout probability applied before the global classifier. ``None``
        disables dropout (p = 0.0).
    """

    def __init__(
            self,
            num_classes_list: List[int],
            num_total_classes: int,
            sequence_model_hidden_size: int,
            attention_unit_size: Optional[int] = 256,
            fc_hidden_size: Optional[int] = 512,
            beta: Optional[float] = 0.5,
            dropout_rate: Optional[float] = None):
        super(HAM, self).__init__()
        self.beta = beta
        assert num_total_classes == sum(num_classes_list), \
            "num_total_classes must equal sum(num_classes_list)"

        # First Level
        self.first_attention = AttentionLayer(sequence_model_hidden_size, attention_unit_size, num_classes_list[0])
        self.first_fc = nn.Linear(sequence_model_hidden_size * 2, fc_hidden_size)
        self.first_local = LocalLayer(fc_hidden_size, num_classes_list[0])

        # Second Level
        self.second_attention = AttentionLayer(sequence_model_hidden_size, attention_unit_size, num_classes_list[1])
        self.second_fc = nn.Linear(sequence_model_hidden_size * 2, fc_hidden_size)
        self.second_local = LocalLayer(fc_hidden_size, num_classes_list[1])

        # Third Level
        self.third_attention = AttentionLayer(sequence_model_hidden_size, attention_unit_size, num_classes_list[2])
        self.third_fc = nn.Linear(sequence_model_hidden_size * 2, fc_hidden_size)
        self.third_local = LocalLayer(fc_hidden_size, num_classes_list[2])

        # Fully Connected Layer
        self.fc = nn.Linear(fc_hidden_size * 3, fc_hidden_size)

        # Highway Layer
        self.highway_lin = nn.Linear(fc_hidden_size, fc_hidden_size)
        self.highway_gate = nn.Linear(fc_hidden_size, fc_hidden_size)

        # Add dropout
        # nn.Dropout cannot be built from None, so the default (None) is
        # mapped to a probability of 0.0, i.e. no dropout.
        self.dropout = nn.Dropout(0.0 if dropout_rate is None else dropout_rate)

        # Global scores
        self.global_scores_fc = nn.Linear(fc_hidden_size, num_total_classes)

        # Parameters whose name contains 'weight' (and not 'embedding') get a
        # truncated-normal draw; everything else (here: the biases) is set to
        # the constant 0.1.
        for name, param in self.named_parameters():
            if 'embedding' not in name and 'weight' in name:
                truncated_normal_(param.data, mean=0, std=0.1)
            else:
                nn.init.constant_(param.data, 0.1)

    def forward(self, sequential_embeddings):
        """
        Run the three levels and the global classifier.

        Parameters
        ----------
        sequential_embeddings: torch.Tensor
            Sequence representation, presumably of shape
            (batch_size, sequence_length, sequence_model_hidden_size).

        Returns
        -------
        HAMOutput
            Blended scores plus the per-level attention weight (first level
            only), visual vectors, logits and local scores.
        """
        sequential_embeddings_pool = torch.mean(sequential_embeddings, dim=1)

        # First Level
        first_att_weight, first_att_out = self.first_attention(sequential_embeddings)
        first_local_input = torch.cat((sequential_embeddings_pool, first_att_out), dim=1)
        first_local_fc_out = self.first_fc(first_local_input)
        first_logits, first_scores, first_visual = self.first_local(first_local_fc_out, first_att_weight)

        # Second Level: the sequence is rescaled position-wise by the
        # first level's visual vector before attention is applied.
        second_att_input = torch.mul(sequential_embeddings, first_visual.unsqueeze(-1))
        second_att_weight, second_att_out = self.second_attention(second_att_input)
        second_local_input = torch.cat((sequential_embeddings_pool, second_att_out), dim=1)
        second_local_fc_out = self.second_fc(second_local_input)
        second_logits, second_scores, second_visual = self.second_local(second_local_fc_out, second_att_weight)

        # Third Level: same scheme, driven by the second level's visual vector.
        third_att_input = torch.mul(sequential_embeddings, second_visual.unsqueeze(-1))
        third_att_weight, third_att_out = self.third_attention(third_att_input)
        third_local_input = torch.cat((sequential_embeddings_pool, third_att_out), dim=1)
        third_local_fc_out = self.third_fc(third_local_input)
        third_logits, third_scores, third_visual = self.third_local(third_local_fc_out, third_att_weight)

        # Concat
        # shape of ham_out: [batch_size, fc_hidden_size * 3]
        ham_out = torch.cat((first_local_fc_out, second_local_fc_out, third_local_fc_out), dim=1)

        # Fully Connected Layer
        fc_out = self.fc(ham_out)

        # Highway Layer and Dropout
        # The gate highway_t mixes the transformed signal with the untouched fc_out.
        highway_g = torch.relu(self.highway_lin(fc_out))
        highway_t = torch.sigmoid(self.highway_gate(fc_out))
        highway_output = torch.mul(highway_g, highway_t) + torch.mul((1 - highway_t), fc_out)
        h_drop = self.dropout(highway_output)

        # Global scores
        global_logits = self.global_scores_fc(h_drop)
        global_scores = torch.sigmoid(global_logits)
        local_scores = torch.cat((first_scores, second_scores, third_scores), dim=1)
        scores = self.beta * global_scores + (1 - self.beta) * local_scores

        return HAMOutput(
            scores=scores,
            first_att_weight=first_att_weight,
            first_visual=first_visual,
            second_visual=second_visual,
            third_visual=third_visual,
            first_logits=first_logits,
            second_logits=second_logits,
            third_logits=third_logits,
            global_logits=global_logits,
            first_scores=first_scores,
            second_scores=second_scores
        )