import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class LinearEnn(nn.Module):
    """Linear evidential classifier: the logits are turned into non-negative
    evidence for a Dirichlet distribution over class probabilities (in the
    style of evidential deep learning, Sensoy et al., 2018)."""

    in_dim: int
    out_dim: int
    alpha_kl: float
    def __init__(self, in_dim: int, out_dim: int, focal: int, alpha_kl: float):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.alpha_kl = alpha_kl  # weight of the KL regulariser in `criterion`
        self.focal = focal  # stored but not used by any method below
        self.classifier = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ELU(),
        )

    def forward(self, inputs: torch.FloatTensor) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        logits = self.classifier(inputs)
        # Non-negative evidence; alpha = evidence + 1 are the Dirichlet concentrations.
        evidence = torch.exp(logits)
        # Expected class probabilities under the Dirichlet: alpha / alpha_0.
        prob = F.normalize(evidence + 1, p=1, dim=1)
        return evidence, prob

    def criterion(self, evidence: torch.FloatTensor, label: torch.LongTensor) -> torch.FloatTensor:
        if len(label.shape) == 1:
            label = F.one_hot(label, self.out_dim).float()
        alpha = evidence + 1
        alpha_0 = alpha.sum(dim=1, keepdim=True)  # broadcasts against alpha below
        loss_ece = torch.sum(label * (torch.digamma(alpha_0) - torch.digamma(alpha)), dim=1)
        loss_ece = torch.mean(loss_ece)
        if self.alpha_kl > 0:
            # Strip the true-class evidence, then pull what remains toward the
            # uniform Dirichlet Dir(1, ..., 1).
            tilde_alpha = label + (1 - label) * alpha
            uncertainty_alpha = torch.ones_like(tilde_alpha)  # inherits device; no hard-coded .cuda()
            estimate_dirichlet = torch.distributions.Dirichlet(tilde_alpha)
            uncertainty_dirichlet = torch.distributions.Dirichlet(uncertainty_alpha)
            kl = torch.distributions.kl_divergence(estimate_dirichlet, uncertainty_dirichlet)
            loss_kl = torch.mean(kl)
        else:
            loss_kl = 0
        return loss_ece + self.alpha_kl * loss_kl
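
    # The loss above in one formula, as I read the code (it matches the
    # digamma Bayes-risk loss of evidential deep learning):
    #   L = sum_k y_k * (digamma(alpha_0) - digamma(alpha_k))
    #       + alpha_kl * KL( Dir(tilde_alpha) || Dir(1, ..., 1) )
    # where the first term is the expected cross-entropy under Dir(alpha), and
    # tilde_alpha fixes the true-class concentration at 1 so the KL term only
    # penalises evidence assigned to the wrong classes.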

    def predict(self, inputs: torch.FloatTensor | list[list[float]] | np.ndarray) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return the predicted probability for each class and the uncertainty
        of the current prediction."""
        # Accept tensors, lists, or ndarrays; convert to float32 in one step.
        inputs = torch.as_tensor(inputs, dtype=torch.float32)
        with torch.no_grad():
            evidence, prob = self.forward(inputs)
            alpha = evidence + 1
            S = alpha.sum(dim=1)
            # Vacuity u = K / S: large when little total evidence was collected.
            u = self.out_dim / S
        return prob, u
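
# Worked example of the numbers `predict` returns (hypothetical values): with
# out_dim = 3 and evidence [8, 1, 0] for one sample, alpha = [9, 2, 1] and
# S = 12, so prob = [0.75, 0.167, 0.083] and vacuity u = 3 / 12 = 0.25. With
# zero evidence everywhere, prob is uniform and u = 3 / 3 = 1 (maximal
# uncertainty).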


def train_enn(enn_model: LinearEnn, embedding: np.ndarray | torch.FloatTensor, labels: np.ndarray | torch.LongTensor, bs: int = 64, lr: float = 1e-3, epoch: int = 100) -> list[float]:
    optimizer = torch.optim.AdamW(enn_model.parameters(), lr=lr)

    sample_num = len(embedding)
    sample_indice = np.arange(sample_num)
    bs_num = int(np.ceil(sample_num / bs))

    training_losses = []

    for i in range(epoch):
        # Anneal the KL weight from 0 up to 0.9 over the first epochs, so early
        # training is not dominated by the regulariser.
        enn_model.alpha_kl = min(0.9, i / 10)
        np.random.shuffle(sample_indice)
        train_loss = 0.0
        for bs_i in range(bs_num):
            start = bs_i * bs
            end = min(sample_num, start + bs)
            data_indice = sample_indice[start:end]
            data = torch.as_tensor(embedding[data_indice], dtype=torch.float32)
            label = torch.as_tensor(labels[data_indice], dtype=torch.long)
            evidence, _ = enn_model(data)
            loss = enn_model.criterion(evidence, label)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        training_losses.append(train_loss / bs_num)

    return training_losses
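

# A minimal usage sketch (my own illustration; the synthetic data and the
# hyper-parameters are assumptions, not part of the original file): fit the
# ENN on two Gaussian blobs, then compare the vacuity of a confident
# in-distribution point with that of an ambiguous point midway between the
# classes, which typically receives less evidence and hence a higher u.
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0 = rng.normal(-2.0, 0.5, size=(200, 8)).astype(np.float32)
    x1 = rng.normal(2.0, 0.5, size=(200, 8)).astype(np.float32)
    embedding = np.concatenate([x0, x1])
    labels = np.concatenate([np.zeros(200, dtype=np.int64), np.ones(200, dtype=np.int64)])

    model = LinearEnn(in_dim=8, out_dim=2, focal=0, alpha_kl=0.0)
    losses = train_enn(model, embedding, labels, bs=64, lr=1e-3, epoch=20)
    print(f"final training loss: {losses[-1]:.4f}")

    queries = np.stack([x0[0], np.zeros(8, dtype=np.float32)])  # confident vs. ambiguous
    prob, u = model.predict(queries)
    print("in-distribution:", prob[0].tolist(), "u =", float(u[0]))
    print("boundary point: ", prob[1].tolist(), "u =", float(u[1]))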