Compare commits

..

No commits in common. "cf85cf78429e23e1c1c65a690646b2966f592f8f" and "127147fab1a84dd57a170c3cdc273faf95a64667" have entirely different histories.

4 changed files with 5 additions and 34 deletions

View File

@ -1,19 +0,0 @@
# Reference PyTorch GraphSAGE Implementation
### Author: William L. Hamilton
Basic reference PyTorch implementation of [GraphSAGE](https://github.com/williamleif/GraphSAGE).
This reference implementation is not as fast as the TensorFlow version for large graphs, but the code is easier to read and it performs better (in terms of speed) on small-graph benchmarks.
The code is also intended to be simpler, more extensible, and easier to work with than the TensorFlow version.
Currently, only supervised versions of GraphSAGE-mean and GraphSAGE-GCN are implemented.
#### Requirements
pytorch >0.2 is required.
#### Running examples
Execute `python -m graphsage.model` to run the Cora example.
It assumes that CUDA is not being used, but modifying the run functions in `model.py` in the obvious way can change this.
There is also a pubmed example (called via the `run_pubmed` function in model.py).

View File

@ -2,11 +2,11 @@ import torch
import torch.nn as nn
from torch.autograd import Variable
import random
"""
Set of modules for aggregating embeddings of neighbors.
These modules take as input embeddings of neighbors.
"""
class MeanAggregator(nn.Module):
@ -17,9 +17,8 @@ class MeanAggregator(nn.Module):
"""
Initializes the aggregator for a specific graph.
features -- function mapping LongTensor of node ids to FloatTensor of feature values.
cuda -- whether to use GPU
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
features -- function mapping (node_list, features, offset) to feature values
see torch.nn.EmbeddingBag and forward function below docs for offset meaning.
"""
super(MeanAggregator, self).__init__()
@ -29,11 +28,6 @@ class MeanAggregator(nn.Module):
self.gcn = gcn
def forward(self, nodes, to_neighs, num_sample=10):
"""
nodes --- list of nodes in a batch
to_neighs --- list of sets, each set is the set of neighbors for node in batch
num_sample --- number of neighbors to sample. No sampling if None.
"""
# Local pointers to functions (speed hack)
_set = set
if not num_sample is None:

View File

@ -35,8 +35,9 @@ class Encoder(nn.Module):
Generates embeddings for a batch of nodes.
nodes -- list of nodes
mode -- string desiginating the mode of the nodes
"""
neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes],
neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[node] for node in nodes],
self.num_sample)
if not self.gcn:
if self.cuda:

View File

@ -12,11 +12,6 @@ from collections import defaultdict
from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator
"""
Simple supervised GraphSAGE model as well as examples running the model
on the Cora and Pubmed datasets.
"""
class SupervisedGraphSage(nn.Module):
def __init__(self, num_classes, enc):