Bagaimana cara menggunakan Scikit-Learn Label Propagation pada data terstruktur grafik?


11

Sebagai bagian dari penelitian saya, saya tertarik untuk melakukan propagasi label pada grafik. Saya terutama tertarik pada dua metode ini:

Saya melihat bahwa scikit-belajar menawarkan model untuk melakukan itu. Namun, model ini seharusnya diterapkan pada data terstruktur vektor ( yaitu titik data).

Model membangun matriks afinitas dari titik data menggunakan kernel, dan kemudian menjalankan algoritma pada matriks yang dikonstruksi. Saya ingin dapat langsung memasukkan matriks adjacency dari grafik saya di tempat matriks kesamaan.

Adakah gagasan tentang bagaimana mencapainya? Atau apakah Anda tahu pustaka Python yang akan memungkinkan untuk menjalankan propagasi label secara langsung pada data terstruktur grafik untuk dua metode yang disebutkan di atas?

Terima kasih sebelumnya atas bantuan Anda!


Sudahkah Anda memeriksa kode sumber Scikit-learn untuk melihat apa yang dilakukannya setelah menghitung matriks afinitas? Mungkin bisa "menyalin" kode setelah bagian itu untuk menerapkannya langsung ke matriks adjacency Anda.
Tasos

Terima kasih atas komentar Anda! Jadi, sebenarnya, inilah yang saya lakukan saat ini, tetapi beberapa bagian dari kode yang perlu saya modifikasi agar sesuai dengan kebutuhan saya agak samar. Saya takut menulis ulang bagian-bagian itu akan menyebabkan kesalahan. Saya berharap ada metode yang lebih mudah.
Thibaud Martinez

1
Kode sumber di github.com/scikit-learn/scikit-learn/blob/7389dba/sklearn/… - mengatakan bahwa implementasi harus mengganti metode _build_graph. Jadi secara alami Anda harus mencoba membuat kelas turunan yang menerima matriks yang dikomputasi.
mikalai

Jawaban:


2

Menjawab pertanyaan saya sendiri di sini, semoga bermanfaat bagi sebagian pembaca.

Scikit-belajar terutama dirancang untuk menangani data terstruktur vektor. Oleh karena itu, jika Anda ingin melakukan propagasi label / label menyebar pada data terstruktur grafik, Anda mungkin lebih baik menerapkan kembali metode sendiri daripada menggunakan antarmuka Scikit.

Berikut ini adalah implementasi dari Propagasi Label dan Penyebaran Label di PyTorch.

Kedua metode secara keseluruhan mengikuti langkah-langkah algoritmik yang sama, dengan variasi tentang bagaimana matriks adjacency dinormalisasi dan bagaimana label disebarkan pada setiap langkah. Karena itu, mari kita buat kelas dasar untuk dua model kami.

from abc import abstractmethod
import torch

class BaseLabelPropagation:
    """Base class for label propagation models.

    Parameters
    ----------
    adj_matrix: torch.FloatTensor
        Adjacency matrix of the graph.
    """
    def __init__(self, adj_matrix):
        self.norm_adj_matrix = self._normalize(adj_matrix)
        self.n_nodes = adj_matrix.size(0)
        self.one_hot_labels = None 
        self.n_classes = None
        self.labeled_mask = None
        self.predictions = None

    @staticmethod
    @abstractmethod
    def _normalize(adj_matrix):
        raise NotImplementedError("_normalize must be implemented")

    @abstractmethod
    def _propagate(self):
        raise NotImplementedError("_propagate must be implemented")

    def _one_hot_encode(self, labels):
        # Get the number of classes
        classes = torch.unique(labels)
        classes = classes[classes != -1]
        self.n_classes = classes.size(0)

        # One-hot encode labeled data instances and zero rows corresponding to unlabeled instances
        unlabeled_mask = (labels == -1)
        labels = labels.clone()  # defensive copying
        labels[unlabeled_mask] = 0
        self.one_hot_labels = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)
        self.one_hot_labels = self.one_hot_labels.scatter(1, labels.unsqueeze(1), 1)
        self.one_hot_labels[unlabeled_mask, 0] = 0

        self.labeled_mask = ~unlabeled_mask

    def fit(self, labels, max_iter, tol):
        """Fits a semi-supervised learning label propagation model.

        labels: torch.LongTensor
            Tensor of size n_nodes indicating the class number of each node.
            Unlabeled nodes are denoted with -1.
        max_iter: int
            Maximum number of iterations allowed.
        tol: float
            Convergence tolerance: threshold to consider the system at steady state.
        """
        self._one_hot_encode(labels)

        self.predictions = self.one_hot_labels.clone()
        prev_predictions = torch.zeros((self.n_nodes, self.n_classes), dtype=torch.float)

        for i in range(max_iter):
            # Stop iterations if the system is considered at a steady state
            variation = torch.abs(self.predictions - prev_predictions).sum().item()

            if variation < tol:
                print(f"The method stopped after {i} iterations, variation={variation:.4f}.")
                break

            prev_predictions = self.predictions
            self._propagate()

    def predict(self):
        return self.predictions

    def predict_classes(self):
        return self.predictions.max(dim=1).indices

Model mengambil input matriks adjacency dari grafik serta label dari node. Label tersebut dalam bentuk vektor bilangan bulat yang menunjukkan nomor kelas setiap node dengan -1 pada posisi node yang tidak berlabel.

Algoritma Propagasi Label disajikan di bawah ini.

W: matriks adjacency dari grafik Hitung matriks derajat diagonal D oleh DsayasayajWsayaj Inisialisasi Y^(0)(y1,...,yl,0,0,...,0) Pengulangan  1. Y^(t+1)D-1WY^(t) 2. Y^l(t+1)Yl sampai konvergensi ke Y^() Titik label xsaya dengan tanda y^saya()

Dari Xiaojin Zhu dan Zoubin Ghahramani. Belajar dari data berlabel dan tidak berlabel dengan propagasi label. Laporan Teknis CMU-CALD-02-107, Universitas Carnegie Mellon, 2002

Kami mendapatkan implementasi berikut.

class LabelPropagation(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1 * W"""
        degs = adj_matrix.sum(dim=1)
        degs[degs == 0] = 1  # avoid division by 0 error
        return adj_matrix / degs[:, None]

    def _propagate(self):
        self.predictions = torch.matmul(self.norm_adj_matrix, self.predictions)

        # Put back already known labels
        self.predictions[self.labeled_mask] = self.one_hot_labels[self.labeled_mask]

    def fit(self, labels, max_iter=1000, tol=1e-3):
        super().fit(labels, max_iter, tol)

Algoritma Penyebaran Label adalah:

W: matriks adjacency dari grafik Hitung matriks derajat diagonal D oleh DsayasayajWsayaj Hitung grafik Laplacian yang dinormalisasi L.D-1/2WD-1/2 Inisialisasi Y^(0)(y1,...,yl,0,0,...,0) Pilih parameter α[0,1) Pengulangan Y^(t+1)αL.Y^(t)+(1-α)Y^(0) sampai konvergensi ke Y^() Titik label xsaya dengan tanda y^saya()

Dari Dengyong Zhou, Olivier Bousquet, Thomas Navin Lal, Jason Weston, Bernhard Schoelkopf. Belajar dengan konsistensi lokal dan global (2004)

Karena itu, implementasinya adalah sebagai berikut.

class LabelSpreading(BaseLabelPropagation):
    def __init__(self, adj_matrix):
        super().__init__(adj_matrix)
        self.alpha = None

    @staticmethod
    def _normalize(adj_matrix):
        """Computes D^-1/2 * W * D^-1/2"""
        degs = adj_matrix.sum(dim=1)
        norm = torch.pow(degs, -0.5)
        norm[torch.isinf(norm)] = 1
        return adj_matrix * norm[:, None] * norm[None, :]

    def _propagate(self):
        self.predictions = (
            self.alpha * torch.matmul(self.norm_adj_matrix, self.predictions)
            + (1 - self.alpha) * self.one_hot_labels
        )

    def fit(self, labels, max_iter=1000, tol=1e-3, alpha=0.5):
        """
        Parameters
        ----------
        alpha: float
            Clamping factor.
        """
        self.alpha = alpha
        super().fit(labels, max_iter, tol)

Sekarang mari kita menguji model propagasi kita pada data sintetis. Untuk melakukannya, kami memilih untuk menggunakan grafik gua .

import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

# Create caveman graph
n_cliques = 4
size_cliques = 10
caveman_graph = nx.connected_caveman_graph(n_cliques, size_cliques)
adj_matrix = nx.adjacency_matrix(caveman_graph).toarray()

# Create labels
labels = np.full(n_cliques * size_cliques, -1.)

# Only one node per clique is labeled. Each clique belongs to a different class.
labels[0] = 0
labels[size_cliques] = 1
labels[size_cliques * 2] = 2
labels[size_cliques * 3] = 3

# Create input tensors
adj_matrix_t = torch.FloatTensor(adj_matrix)
labels_t = torch.LongTensor(labels)

# Learn with Label Propagation
label_propagation = LabelPropagation(adj_matrix_t)
label_propagation.fit(labels_t)
label_propagation_output_labels = label_propagation.predict_classes()

# Learn with Label Spreading
label_spreading = LabelSpreading(adj_matrix_t)
label_spreading.fit(labels_t, alpha=0.8)
label_spreading_output_labels = label_spreading.predict_classes()

# Plot graphs
color_map = {-1: "orange", 0: "blue", 1: "green", 2: "red", 3: "cyan"}
input_labels_colors = [color_map[l] for l in labels]
lprop_labels_colors = [color_map[l] for l in label_propagation_output_labels.numpy()]
lspread_labels_colors = [color_map[l] for l in label_spreading_output_labels.numpy()]

plt.figure(figsize=(14, 6))
ax1 = plt.subplot(1, 4, 1)
ax2 = plt.subplot(1, 4, 2)
ax3 = plt.subplot(1, 4, 3)

ax1.title.set_text("Raw data (4 classes)")
ax2.title.set_text("Label Propagation")
ax3.title.set_text("Label Spreading")

pos = nx.spring_layout(caveman_graph)
nx.draw(caveman_graph, ax=ax1, pos=pos, node_color=input_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax2, pos=pos, node_color=lprop_labels_colors, node_size=50)
nx.draw(caveman_graph, ax=ax3, pos=pos, node_color=lspread_labels_colors, node_size=50)

# Legend
ax4 = plt.subplot(1, 4, 4)
ax4.axis("off")
legend_colors = ["orange", "blue", "green", "red", "cyan"]
legend_labels = ["unlabeled", "class 0", "class 1", "class 2", "class 3"]
dummy_legend = [ax4.plot([], [], ls='-', c=c)[0] for c in legend_colors]
plt.legend(dummy_legend, legend_labels)

plt.show()

Model yang diimplementasikan bekerja dengan benar dan memungkinkan untuk mendeteksi komunitas dalam grafik.

Label propagasi dan implementasi label menyebar diuji pada grafik gua

Catatan: Metode propagasi yang disajikan dimaksudkan untuk digunakan pada grafik yang tidak diarahkan.

Kode tersedia sebagai notebook Jupyter interaktif di sini .

Map

Dengan menggunakan situs kami, Anda mengakui telah membaca dan memahami Kebijakan Cookie dan Kebijakan Privasi kami.
Licensed under cc by-sa 3.0 with attribution required.