import torch
import torch.nn as nn
from transformers import AutoConfig, RobertaModel

class SchemaItemClassifier(nn.Module):
    def __init__(self, model_name_or_path, mode):
        super(SchemaItemClassifier, self).__init__()
        if mode in ["eval", "test"]:
            # load config
            config = AutoConfig.from_pretrained(model_name_or_path)
            # randomly initialize the model's parameters according to the config
            self.plm_encoder = RobertaModel(config)
        elif mode == "train":
            self.plm_encoder = RobertaModel.from_pretrained(model_name_or_path)
        else:
            raise ValueError("mode must be one of 'train', 'eval', or 'test'")
        self.plm_hidden_size = self.plm_encoder.config.hidden_size

        # column cls head
        self.column_info_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.column_info_cls_head_linear2 = nn.Linear(256, 2)

        # column bi-lstm layer
        self.column_info_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after column bi-lstm layer
        self.column_info_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # table cls head
        self.table_name_cls_head_linear1 = nn.Linear(self.plm_hidden_size, 256)
        self.table_name_cls_head_linear2 = nn.Linear(256, 2)

        # table bi-lstm pooling layer
        self.table_name_bilstm = nn.LSTM(
            input_size = self.plm_hidden_size,
            hidden_size = int(self.plm_hidden_size/2),
            num_layers = 2,
            dropout = 0,
            bidirectional = True
        )
        # linear layer after table bi-lstm layer
        self.table_name_linear_after_pooling = nn.Linear(self.plm_hidden_size, self.plm_hidden_size)

        # activation functions
        self.leakyrelu = nn.LeakyReLU()
        self.tanh = nn.Tanh()

        # table-column cross-attention layer
        self.table_column_cross_attention_layer = nn.MultiheadAttention(embed_dim = self.plm_hidden_size, num_heads = 8)

        # dropout layer, p=0.2 means 20% of the activations are randomly set to 0
        self.dropout = nn.Dropout(p = 0.2)
    def table_column_cross_attention(
        self,
        table_name_embeddings_in_one_db,
        column_info_embeddings_in_one_db,
        column_number_in_each_table
    ):
        table_num = table_name_embeddings_in_one_db.shape[0]
        table_name_embedding_attn_list = []
        for table_id in range(table_num):
            table_name_embedding = table_name_embeddings_in_one_db[[table_id], :]
            column_info_embeddings_in_one_table = column_info_embeddings_in_one_db[
                sum(column_number_in_each_table[:table_id]) : sum(column_number_in_each_table[:table_id+1]), :]
            table_name_embedding_attn, _ = self.table_column_cross_attention_layer(
                table_name_embedding,
                column_info_embeddings_in_one_table,
                column_info_embeddings_in_one_table
            )
            table_name_embedding_attn_list.append(table_name_embedding_attn)

        # residual connection
        table_name_embeddings_in_one_db = table_name_embeddings_in_one_db + torch.cat(table_name_embedding_attn_list, dim = 0)
        # row-wise L2 norm
        table_name_embeddings_in_one_db = torch.nn.functional.normalize(table_name_embeddings_in_one_db, p=2.0, dim=1)

        return table_name_embeddings_in_one_db
    def table_column_cls(
        self,
        encoder_input_ids,
        encoder_input_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table
    ):
        batch_size = encoder_input_ids.shape[0]

        encoder_output = self.plm_encoder(
            input_ids = encoder_input_ids,
            attention_mask = encoder_input_attention_mask,
            return_dict = True
        )  # encoder_output["last_hidden_state"].shape = (batch_size x seq_length x hidden_size)

        batch_table_name_cls_logits, batch_column_info_cls_logits = [], []

        # process each example in the current batch
        for batch_id in range(batch_size):
            column_number_in_each_table = batch_column_number_in_each_table[batch_id]
            sequence_embeddings = encoder_output["last_hidden_state"][batch_id, :, :]  # (seq_length x hidden_size)

            # obtain the token ids of each table name
            aligned_table_name_ids = batch_aligned_table_name_ids[batch_id]
            # obtain the token ids of each column info
            aligned_column_info_ids = batch_aligned_column_info_ids[batch_id]

            table_name_embedding_list, column_info_embedding_list = [], []

            # obtain table embeddings via bi-lstm pooling + a non-linear layer
            for table_name_ids in aligned_table_name_ids:
                table_name_embeddings = sequence_embeddings[table_name_ids, :]
                # BiLSTM pooling
                output_t, (hidden_state_t, cell_state_t) = self.table_name_bilstm(table_name_embeddings)
                table_name_embedding = hidden_state_t[-2:, :].view(1, self.plm_hidden_size)
                table_name_embedding_list.append(table_name_embedding)
            table_name_embeddings_in_one_db = torch.cat(table_name_embedding_list, dim = 0)
            # non-linear mlp layer
            table_name_embeddings_in_one_db = self.leakyrelu(self.table_name_linear_after_pooling(table_name_embeddings_in_one_db))

            # obtain column embeddings via bi-lstm pooling + a non-linear layer
            for column_info_ids in aligned_column_info_ids:
                column_info_embeddings = sequence_embeddings[column_info_ids, :]
                # BiLSTM pooling
                output_c, (hidden_state_c, cell_state_c) = self.column_info_bilstm(column_info_embeddings)
                column_info_embedding = hidden_state_c[-2:, :].view(1, self.plm_hidden_size)
                column_info_embedding_list.append(column_info_embedding)
            column_info_embeddings_in_one_db = torch.cat(column_info_embedding_list, dim = 0)
            # non-linear mlp layer
            column_info_embeddings_in_one_db = self.leakyrelu(self.column_info_linear_after_pooling(column_info_embeddings_in_one_db))

            # table-column (tc) cross-attention
            table_name_embeddings_in_one_db = self.table_column_cross_attention(
                table_name_embeddings_in_one_db,
                column_info_embeddings_in_one_db,
                column_number_in_each_table
            )

            # calculate table 0-1 logits
            table_name_embeddings_in_one_db = self.table_name_cls_head_linear1(table_name_embeddings_in_one_db)
            table_name_embeddings_in_one_db = self.dropout(self.leakyrelu(table_name_embeddings_in_one_db))
            table_name_cls_logits = self.table_name_cls_head_linear2(table_name_embeddings_in_one_db)

            # calculate column 0-1 logits
            column_info_embeddings_in_one_db = self.column_info_cls_head_linear1(column_info_embeddings_in_one_db)
            column_info_embeddings_in_one_db = self.dropout(self.leakyrelu(column_info_embeddings_in_one_db))
            column_info_cls_logits = self.column_info_cls_head_linear2(column_info_embeddings_in_one_db)

            batch_table_name_cls_logits.append(table_name_cls_logits)
            batch_column_info_cls_logits.append(column_info_cls_logits)

        return batch_table_name_cls_logits, batch_column_info_cls_logits
    def forward(
        self,
        encoder_input_ids,
        encoder_attention_mask,
        batch_aligned_column_info_ids,
        batch_aligned_table_name_ids,
        batch_column_number_in_each_table,
    ):
        batch_table_name_cls_logits, batch_column_info_cls_logits \
            = self.table_column_cls(
                encoder_input_ids,
                encoder_attention_mask,
                batch_aligned_column_info_ids,
                batch_aligned_table_name_ids,
                batch_column_number_in_each_table
            )

        return {
            "batch_table_name_cls_logits" : batch_table_name_cls_logits,
            "batch_column_info_cls_logits": batch_column_info_cls_logits
        }
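
# A minimal smoke-test sketch (not part of the original file): it assumes the
# "roberta-base" checkpoint is downloadable and uses made-up token positions for
# two hypothetical tables (with 2 and 1 columns) in a toy batch of size 1. It
# only checks that a forward pass runs and that the returned logits have shapes
# (num_tables, 2) and (num_columns, 2).
if __name__ == "__main__":
    model = SchemaItemClassifier("roberta-base", mode = "train")
    model.eval()

    seq_length = 32
    encoder_input_ids = torch.randint(0, 1000, (1, seq_length))
    encoder_attention_mask = torch.ones(1, seq_length, dtype = torch.long)

    # token positions of each table name / column info span in the input sequence (illustrative values)
    batch_aligned_table_name_ids = [[[1, 2], [5, 6]]]
    batch_aligned_column_info_ids = [[[3], [4], [7, 8]]]
    # the first table has 2 columns, the second has 1
    batch_column_number_in_each_table = [[2, 1]]

    with torch.no_grad():
        outputs = model(
            encoder_input_ids,
            encoder_attention_mask,
            batch_aligned_column_info_ids,
            batch_aligned_table_name_ids,
            batch_column_number_in_each_table
        )

    print(outputs["batch_table_name_cls_logits"][0].shape)   # expected: torch.Size([2, 2])
    print(outputs["batch_column_info_cls_logits"][0].shape)  # expected: torch.Size([3, 2])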