Menjawab pertanyaan saya sendiri di sini, semoga bermanfaat bagi sebagian pembaca.
Scikit-belajar terutama dirancang untuk menangani data terstruktur vektor. Oleh karena itu, jika Anda ingin melakukan propagasi label / label menyebar pada data terstruktur grafik, Anda mungkin lebih baik menerapkan kembali metode sendiri daripada menggunakan antarmuka Scikit.
Berikut ini adalah implementasi dari Propagasi Label dan Penyebaran Label di PyTorch.
Kedua metode secara keseluruhan mengikuti langkah-langkah algoritmik yang sama, dengan variasi tentang bagaimana matriks adjacency dinormalisasi dan bagaimana label disebarkan pada setiap langkah. Karena itu, mari kita buat kelas dasar untuk dua model kami.
from abc import abstractmethod
import torch
class BaseLabelPropagation:
"""Base class for label propagation models.
Parameters
----------
adj_matrix: torch.FloatTensor
Adjacency matrix of the graph.
"""
def __init__(self, adj_matrix):
self.norm_adj_matrix = self._normalize(adj_matrix)
self.n_nodes = adj_matrix.size(0)
self.one_hot_labels = None
self.n_classes = None
self.labeled_mask = None
self.predictions = None
@staticmethod
@abstractmethod
def _normalize(adj_matrix):
raise NotImplementedError("_normalize must be implemented")
@abstractmethod
def _propagate(self):
raise NotImplementedError("_propagate must be implemented")
def _one_hot_encode(self, labels):
# Get the number of classes
classes = torch.unique(labels)
classes = classes[classes != -1]
self.n_classes = classes.size(0)
# One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
unlabeled_mask = (labels == -1)
labels = labels.clone() # defensive copying
labels[unlabeled_mask] = 0
self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
self.one_hot_labels[unlabeled_mask, 0] = 0
self.labeled_mask = ~unlabeled_mask
def fit(self, labels, max_iter, tol):
"""Fits a semi-supervised learning label propagation model.
labels: torch.LongTensor
Tensor of size n_nodes indicating the class number of each node.
Unlabeled nodes are denoted with -1.
max_iter: int
Maximum number of iterations allowed.
tol: float
Convergence tolerance: threshold to consider the system at steady state.
"""
self._one_hot_encode(labels)
self.predictions = self.one_hot_labels.clone()
prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
for i in range(max_iter):
# Stop iterations if the system is considered at a steady state
variation = torch.abs(self.predictions - prev_predictions).sum().item()
if variation < tol:
print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
break
prev_predictions = self.predictions
self._propagate()
def predict(self):
return self.predictions
def predict_classes(self):
return self.predictions.max(dim=1).indices
Model mengambil input matriks adjacency dari grafik serta label dari node. Label tersebut dalam bentuk vektor bilangan bulat yang menunjukkan nomor kelas setiap node dengan -1 pada posisi node yang tidak berlabel.
Algoritma Propagasi Label disajikan di bawah ini.
W : matriks adjacency dari grafik Hitung matriks derajat diagonal D oleh Di i← ∑jWsaya j Inisialisasi Y^( 0 )← ( y1, ... , yl, 0 , 0 , ... , 0 ) Pengulangan 1. Y^( t + 1 )← D- 1W Y^( t ) 2. Y^( t + 1 )l← Yl sampai konvergensi ke Y^( ∞ ) Titik label xsaya dengan tanda y^( ∞)saya
Dari Xiaojin Zhu dan Zoubin Ghahramani. Belajar dari data berlabel dan tidak berlabel dengan propagasi label. Laporan Teknis CMU-CALD-02-107, Universitas Carnegie Mellon, 2002
Kami mendapatkan implementasi berikut.
class LabelPropagation(BaseLabelPropagation):
def __init__(self, adj_matrix):
super().__init__(adj_matrix)
@staticmethod
def _normalize(adj_matrix):
"""Computes D^-1 * W"""
degs = adj_matrix.sum(dim=1)
degs[degs == 0] = 1 # avoid division by 0 error
return adj_matrix / degs[:, None]
def _propagate(self):
self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)
# Put back already known labels
self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]
def fit(self, labels, max_iter=1000, tol=1e-3):
super().fit(labels, max_iter, tol)
Algoritma Penyebaran Label adalah:
W : matriks adjacency dari grafik Hitung matriks derajat diagonal D oleh Di i← ∑jWsaya j Hitung grafik Laplacian yang dinormalisasi L ← D- 1 / 2W D- 1 / 2 Inisialisasi Y^( 0 )← ( y1, ... , yl, 0 , 0 , ... , 0 ) Pilih parameter α ∈ [ 0 , 1 ) Iterate Y^( t + 1 ) ← α L Y^( t )+ ( 1 - α ) Y^( 0 ) sampai konvergensi ke Y^( ∞ ) Titik label xsaya dengan tanda y^( ∞ )saya
Dari Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Belajar dengan konsistensi lokal dan global (2004)
Karena itu, implementasinya adalah sebagai berikut.
class LabelSpreading(BaseLabelPropagation):
def __init__(self, adj_matrix):
super().__init__(adj_matrix)
self.alpha = None
@staticmethod
def _normalize(adj_matrix):
"""Computes D^-1/2 * W * D^-1/2"""
degs = adj_matrix.sum(dim=1)
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 1
return adj_matrix * norm[:, None] * norm[None, :]
def _propagate(self):
self.predictions = (
self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
+ (1 - self.alpha) * self.one_hot_labels
)
def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
"""
Parameters
----------
alpha: float
Clamping factor.
"""
self.alpha = alpha
super().fit(labels, max_iter, tol)
Sekarang mari kita menguji model propagasi kita pada data sintetis. Untuk melakukannya, kami memilih untuk menggunakan grafik gua .
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()
# Create labels
labels = np.full(n_cliques * size_cliques, -1.)
# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3
# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)
# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()
# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()
# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]
plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)
ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")
pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)
# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)
plt.show()
Model yang diimplementasikan bekerja dengan benar dan memungkinkan untuk mendeteksi komunitas dalam grafik.
Catatan: Metode propagasi yang disajikan dimaksudkan untuk digunakan pada grafik yang tidak diarahkan.
Kode tersedia sebagai notebook Jupyter interaktif di sini .