class Model(nn.Module):
    """Sequence classifier: a BERT encoder followed by a 3-layer MLP head.

    The last-layer hidden state at the first token position is passed through
    dropout (p=0.2) and a 768 -> 512 -> 256 -> ``num_classes`` MLP with ReLU
    activations.

    Args:
        bert_path: Path or model id handed to ``BertModel.from_pretrained``.
            Defaults to the local ``'../bert-base-chinese/'`` checkpoint.
        num_classes: Number of output classes (default 13).
    """

    def __init__(self, bert_path='../bert-base-chinese/', num_classes=13):
        super().__init__()
        # Registration order (fc1, fc2, fc3, bert_model) is kept as-is so that
        # parameter ordering -- and thus saved optimizer state -- is unchanged.
        self.fc1 = nn.Linear(768, 512)  # 768 = expected BERT hidden size
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.bert_model = BertModel.from_pretrained(bert_path)
        # Make it explicit that the encoder is fine-tuned, not frozen.
        for param in self.bert_model.parameters():
            param.requires_grad_(True)

    def forward(self, input_ids, attention_mask, token_type_ids, *,
                return_logits=False):
        """Classify a batch of tokenized sequences.

        Args:
            input_ids, attention_mask, token_type_ids: Forwarded unchanged to
                the BERT encoder (presumably ``(batch, seq_len)`` tensors --
                confirm against the data loader).
            return_logits: If True, return raw class scores instead of
                probabilities.

        Returns:
            Tensor of shape ``(batch, num_classes)``: softmax probabilities
            over dim 1 by default, or unnormalized logits when
            ``return_logits`` is True.

        NOTE(review): ``nn.CrossEntropyLoss`` applies log-softmax itself; if
        the training loop uses it, call with ``return_logits=True`` to avoid a
        double softmax -- confirm against the training code.
        """
        out = self.bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Hidden state at sequence position 0 (the [CLS] token with a
        # standard BERT tokenizer).
        x = out.last_hidden_state[:, 0]
        # F.dropout defaults to training=True; tie it to the module's mode so
        # that model.eval() actually disables dropout at inference time.
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)
        if return_logits:
            return logits
        return logits.softmax(dim=1)