Chinaunix首页 | 论坛 | 博客
  • 博客访问: 3648372
  • 博文数量: 365
  • 博客积分: 0
  • 博客等级: 民兵
  • 技术积分: 2522
  • 用 户 组: 普通用户
  • 注册时间: 2019-10-28 13:40
文章分类

全部博文(365)

文章存档

2023年(8)

2022年(130)

2021年(155)

2020年(50)

2019年(22)

我的朋友

分类: Python/Ruby

2021-04-23 16:58:53

BiLSTM+CRF

import torch

import torch.nn as nn

from modelgraph.BILSTM import BiLSTM

from itertools import zip_longest

class BiLSTM_CRF(nn.Module):
    """BiLSTM emission model combined with a learned tag-transition matrix.

    ``forward`` produces, for every token, a ``(tag, tag)`` score matrix
    (emission plus transition); ``test`` decodes the best tag sequence from
    those matrices with the Viterbi algorithm.
    """

    def __init__(self, vocab_size, emb_size, hidden_size, out_size):
        """Build the BiLSTM encoder and the ``(out_size, out_size)`` transitions.

        All four arguments are forwarded unchanged to ``BiLSTM``; ``out_size``
        additionally fixes the side length of the transition matrix.
        """
        super(BiLSTM_CRF, self).__init__()
        self.bilstm = BiLSTM(vocab_size, emb_size, hidden_size, out_size)
        # Uniform start: every transition score is initialised to 1 / out_size.
        self.transition = nn.Parameter(
            torch.ones(out_size, out_size) * 1 / out_size)

    def forward(self, sents_tensor, lengths):
        """Return CRF scores of shape ``(batch, max_len, out_size, out_size)``.

        ``crf_scores[b, t, i, j] = emission[b, t, j] + transition[i, j]``.
        ``test`` reads ``i`` as the previous tag and ``j`` as the current one.
        """
        emission = self.bilstm(sents_tensor, lengths)
        _, _, out_size = emission.size()
        # Broadcast the emission over a new "previous tag" axis, then add the
        # transition matrix shared by every batch element and time step.
        crf_scores = (emission.unsqueeze(2).expand(-1, -1, out_size, -1)
                      + self.transition.unsqueeze(0))
        return crf_scores

    def test(self, test_sents_tensor, lengths, tag2id):
        """Viterbi-decode the best tag sequence for each sentence.

        Args:
            test_sents_tensor: batch of inputs, passed straight to ``forward``.
            lengths: per-sentence lengths. The ``[:batch_size_t]`` slicing
                below treats the first rows as the longest sentences, so the
                batch is assumed to be sorted by decreasing length -- TODO
                confirm against the caller.
            tag2id: mapping from tag to id; must contain ``'<start>'``,
                ``'<end>'`` and ``'<pad>'``.

        Returns:
            CPU ``LongTensor`` of shape ``(batch, max_len - 1)`` holding the
            decoded tag ids for steps ``0 .. max_len - 2``; rows of shorter
            sentences are right-padded with the pad id. The last step of each
            sentence is taken to be the end tag and is not returned.
        """
        # NOTE(review): the published source read tag2id[''] three times; the
        # angle-bracket tokens look like they were stripped when the page was
        # extracted. Restored so start/end/pad resolve to distinct ids --
        # confirm these match the keys used when building tag2id.
        start_id = tag2id['<start>']
        end_id = tag2id['<end>']
        pad = tag2id['<pad>']

        crf_scores = self.forward(test_sents_tensor, lengths)
        # Follow the scores' device instead of guessing from CUDA
        # availability, so a CPU model keeps working on a CUDA-capable host.
        device = crf_scores.device
        B, L, T, _ = crf_scores.size()

        # viterbi[b, t, j]: best score of any path ending in tag j at step t.
        viterbi = crf_scores.new_zeros(B, L, T)
        # backpointer[b, t, j]: previous tag on that best path.
        backpointer = torch.zeros(B, L, T, dtype=torch.long, device=device)
        lengths = torch.as_tensor(lengths, dtype=torch.long, device=device)

        # Forward pass: fill the score table and the backpointers.
        for step in range(L):
            # Number of sentences that still have a token at this step.
            batch_size_t = (lengths > step).sum().item()
            if step == 0:
                # The first token can only be preceded by the start tag.
                viterbi[:batch_size_t, step, :] = \
                    crf_scores[:batch_size_t, step, start_id, :]
                backpointer[:batch_size_t, step, :] = start_id
            else:
                max_scores, prev_tags = torch.max(
                    viterbi[:batch_size_t, step - 1, :].unsqueeze(2)
                    + crf_scores[:batch_size_t, step, :, :],
                    dim=1)
                viterbi[:batch_size_t, step, :] = max_scores
                backpointer[:batch_size_t, step, :] = prev_tags

        # Backward pass: follow the backpointers from the end tag.
        tagids = []
        tags_t = None
        for step in range(L - 1, 0, -1):
            batch_size_t = (lengths > step).sum().item()
            n_prev = 0 if tags_t is None else tags_t.size(0)
            # Sentences whose last token sits at this step start their
            # backtrace from the end tag.
            entering = torch.full((batch_size_t - n_prev,), end_id,
                                  dtype=torch.long, device=device)
            if tags_t is None:
                current = entering
            else:
                current = torch.cat([tags_t, entering], dim=0)

            # Direct indexing replaces the flattened view + gather, so no
            # stride arithmetic is needed. Errors now propagate to the caller
            # instead of dropping into an interactive debugger.
            rows = torch.arange(batch_size_t, device=device)
            tags_t = backpointer[rows, step, current]
            tagids.append(tags_t.tolist())

        # tagids holds one list per step (last step first) with growing batch
        # coverage; transposing with zip_longest yields one row per sentence.
        tagids = list(zip_longest(*reversed(tagids), fillvalue=pad))
        # Build the integer tensor directly, avoiding a float32 round-trip.
        return torch.tensor(tagids, dtype=torch.long)

def cal_lstm_crf_loss(crf_scores, targets, tag2id):
    """Compute the CRF loss for a batch.

    The loss is ``(sum of all-path scores - sum of gold-path scores)``
    divided by the batch size, where the all-path score is accumulated with
    ``logsumexp`` over the previous tag at every step.

    Args:
        crf_scores: tensor of shape ``(batch, max_len, T, T)`` with
            ``T == len(tag2id)``; at step 0 the row ``start_id`` is read, and
            later steps reduce over dim 1 of each ``(T, T)`` slice.
        targets: ``(batch, max_len)`` integer tensor of gold tag ids, padded
            with the pad id. It is not modified. The ``[:batch_size_t]``
            slicing assumes rows are sorted by decreasing length -- TODO
            confirm against the caller.
        tag2id: mapping from tag to id; must contain ``'<pad>'``,
            ``'<start>'`` and ``'<end>'``.

    Returns:
        Scalar tensor holding the batch-averaged loss.
    """
    # NOTE(review): the published source looked up '' for all three ids; the
    # angle-bracket tokens look like they were stripped when the page was
    # extracted. Restored so pad/start/end are distinct -- confirm these match
    # the keys used when building tag2id.
    pad_id = tag2id['<pad>']
    start_id = tag2id['<start>']
    end_id = tag2id['<end>']

    batch_size, max_len = targets.size()
    target_size = len(tag2id)

    # True on real tokens, False on padding.
    mask = (targets != pad_id)
    lengths = mask.sum(dim=1)

    # Gold-path score. Each gold (previous tag, tag) pair becomes an index
    # into the flattened (T * T) score matrix of its step. Work on a copy so
    # the caller's `targets` is never modified, regardless of whether
    # `indexed` operates in place.
    pair_index = indexed(targets.clone(), target_size, start_id)
    pair_index = pair_index.masked_select(mask)

    flatten_scores = crf_scores.masked_select(
        mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)
    ).view(-1, target_size * target_size).contiguous()

    golden_scores = flatten_scores.gather(
        dim=1, index=pair_index.unsqueeze(1)).sum()

    # All-path score. scores_upto_t[b, j] is the log-sum-exp over every path
    # that ends in tag j at the current step. It is allocated on the same
    # device and dtype as the scores, not from a global CUDA check.
    scores_upto_t = crf_scores.new_zeros(batch_size, target_size)
    for t in range(max_len):
        # Rows past their own length stop being updated and keep the value
        # from their final real step.
        batch_size_t = (lengths > t).sum().item()
        if t == 0:
            scores_upto_t[:batch_size_t] = \
                crf_scores[:batch_size_t, t, start_id, :]
        else:
            scores_upto_t[:batch_size_t] = torch.logsumexp(
                crf_scores[:batch_size_t, t, :, :]
                + scores_upto_t[:batch_size_t].unsqueeze(2),
                dim=1
            )

    # Only paths finishing in the end tag are counted (presumably every
    # target sequence ends with the end tag -- verify against data prep).
    all_path_scores = scores_upto_t[:, end_id].sum()

    loss = (all_path_scores - golden_scores) / batch_size
    return loss

def indexed(targets, tagset_size, start_id):
    """Turn tag ids into indices of flattened ``(previous tag, tag)`` pairs.

    For a ``(batch, max_len)`` tensor of tag ids the result holds, at every
    position, ``previous_tag * tagset_size + tag``; the tag before the first
    column is ``start_id``.

    The input tensor is left untouched; a new tensor is returned.
    """
    # Enforces the two-dimensional input the original unpacking required.
    _, _ = targets.size()

    # Operate on a copy so the caller's tensor is not modified in place.
    pair_index = targets.clone()
    # Every column but the first is paired with the ORIGINAL value of the
    # column to its left (read from `targets`, which is never written).
    pair_index[:, 1:] += targets[:, :-1] * tagset_size
    # The first column is paired with the start tag.
    pair_index[:, 0] += start_id * tagset_size
    return pair_index

阅读(2010) | 评论(0) | 转发(0) |
给主人留下些什么吧!~~