Compare commits
No commits in common. "cf85cf78429e23e1c1c65a690646b2966f592f8f" and "127147fab1a84dd57a170c3cdc273faf95a64667" have entirely different histories.
cf85cf7842...127147fab1
README.md (19 lines changed)

@@ -1,19 +0,0 @@
# Reference PyTorch GraphSAGE Implementation
### Author: William L. Hamilton

Basic reference PyTorch implementation of [GraphSAGE](https://github.com/williamleif/GraphSAGE).
This reference implementation is not as fast as the TensorFlow version for large graphs, but the code is easier to read and it performs better (in terms of speed) on small-graph benchmarks.
The code is also intended to be simpler, more extensible, and easier to work with than the TensorFlow version.

Currently, only supervised versions of GraphSAGE-mean and GraphSAGE-GCN are implemented.

#### Requirements

pytorch >0.2 is required.

#### Running examples

Execute `python -m graphsage.model` to run the Cora example.
It assumes that CUDA is not being used, but modifying the run functions in `model.py` in the obvious way can change this.
There is also a PubMed example (called via the `run_pubmed` function in `model.py`).
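The README above says the run functions in `model.py` can be modified "in the obvious way" to use CUDA but does not show the change. A minimal, self-contained sketch of that kind of toggle follows; the `maybe_cuda` helper and the stand-in linear model are assumptions for illustration, not code from this repository.

```python
# Hypothetical sketch of the CUDA toggle the README alludes to; the helper name
# and the stand-in model are illustrative, not this repository's code.
import torch
import torch.nn as nn

def maybe_cuda(module, tensors, use_cuda=False):
    """Move a module and a list of tensors to the GPU when requested and available."""
    if use_cuda and torch.cuda.is_available():
        module = module.cuda()
        tensors = [t.cuda() for t in tensors]
    return module, tensors

model = nn.Linear(1433, 7)          # stand-in: Cora has 1433 features and 7 classes
feats = [torch.randn(4, 1433)]      # stand-in batch of node features
model, feats = maybe_cuda(model, feats, use_cuda=False)
print(model(feats[0]).shape)        # torch.Size([4, 7])
```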
@@ -2,11 +2,11 @@ import torch
import torch.nn as nn
from torch.autograd import Variable

import random

"""
Set of modules for aggregating embeddings of neighbors.
These modules take as input embeddings of neighbors.
"""

class MeanAggregator(nn.Module):
@@ -17,9 +17,8 @@ class MeanAggregator(nn.Module):
        """
        Initializes the aggregator for a specific graph.

        features -- function mapping LongTensor of node ids to FloatTensor of feature values.
        cuda -- whether to use GPU
        gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
        features -- function mapping (node_list, features, offset) to feature values
                    see torch.nn.EmbeddingBag and forward function below docs for offset meaning.
        """

        super(MeanAggregator, self).__init__()
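The docstring in the hunk above describes `features` only abstractly. As a minimal sketch of such a callable, assuming it is backed by an `nn.Embedding` lookup (sizes and names below are illustrative, not taken from this repository):

```python
# Hypothetical sketch of a `features` callable as the docstring describes:
# a lookup mapping a LongTensor of node ids to a FloatTensor of features.
import torch
import torch.nn as nn

num_nodes, feat_dim = 100, 16                    # illustrative sizes
embedding = nn.Embedding(num_nodes, feat_dim)
features = lambda node_ids: embedding(node_ids)  # LongTensor -> FloatTensor

print(features(torch.LongTensor([0, 3, 7])).shape)  # torch.Size([3, 16])
```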
@@ -29,11 +28,6 @@ class MeanAggregator(nn.Module):
        self.gcn = gcn

    def forward(self, nodes, to_neighs, num_sample=10):
        """
        nodes --- list of nodes in a batch
        to_neighs --- list of sets, each set is the set of neighbors for node in batch
        num_sample --- number of neighbors to sample. No sampling if None.
        """
        # Local pointers to functions (speed hack)
        _set = set
        if not num_sample is None:
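The `forward` docstring above describes sampling up to `num_sample` neighbors per batch node and averaging their features. A minimal, self-contained sketch of that idea, using the hypothetical helper below (not this repository's `MeanAggregator`):

```python
# Hypothetical sketch of sampled mean aggregation, following the forward
# docstring: sample up to num_sample neighbors per node, then average their
# feature vectors. Not this repository's MeanAggregator.
import random
import torch
import torch.nn as nn

def sampled_mean(to_neighs, features, num_sample=10):
    """to_neighs: list of neighbor sets; features: node ids -> feature rows."""
    out = []
    for neighs in to_neighs:
        neighs = list(neighs)
        if num_sample is not None and len(neighs) > num_sample:
            neighs = random.sample(neighs, num_sample)
        out.append(features(torch.LongTensor(neighs)).mean(dim=0))
    return torch.stack(out)

feats = nn.Embedding(10, 4)                    # toy feature lookup
to_neighs = [{1, 2, 3}, {4, 5}]                # neighbor sets for two batch nodes
print(sampled_mean(to_neighs, feats, num_sample=2).shape)  # torch.Size([2, 4])
```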
@@ -35,8 +35,9 @@ class Encoder(nn.Module):
        Generates embeddings for a batch of nodes.

        nodes -- list of nodes
        mode -- string designating the mode of the nodes
        """
        neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes],
        neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[node] for node in nodes],
                self.num_sample)
        if not self.gcn:
            if self.cuda:
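The `if not self.gcn:` branch in the hunk above corresponds to the two modes named in the README: GraphSAGE-style concatenation of a node's own features with aggregated neighbor features, versus a GCN-style variant that uses the aggregated features alone. A toy illustration of the difference (shapes and names are assumptions, not this repository's `Encoder`):

```python
# Hypothetical illustration of the gcn flag: concatenate self and neighbor
# features (GraphSAGE-style) or use the aggregated features alone (GCN-style).
import torch

self_feats  = torch.randn(4, 16)   # features of 4 batch nodes
neigh_feats = torch.randn(4, 16)   # mean-aggregated neighbor features

gcn = False
combined = neigh_feats if gcn else torch.cat([self_feats, neigh_feats], dim=1)
print(combined.shape)              # torch.Size([4, 32]) when gcn is False
```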
@@ -12,11 +12,6 @@ from collections import defaultdict
from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator

"""
Simple supervised GraphSAGE model as well as examples running the model
on the Cora and Pubmed datasets.
"""

class SupervisedGraphSage(nn.Module):

    def __init__(self, num_classes, enc):
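The model.py hunk shows only the `SupervisedGraphSage` signature. As a rough, self-contained sketch of a supervised head over node embeddings (a linear scoring layer trained with cross-entropy), the class below is an assumption for illustration, not this repository's implementation:

```python
# Hypothetical sketch of a supervised node-classification head over embeddings;
# not this repository's SupervisedGraphSage.
import torch
import torch.nn as nn

class NodeClassifier(nn.Module):
    def __init__(self, num_classes, embed_dim):
        super(NodeClassifier, self).__init__()
        self.xent = nn.CrossEntropyLoss()
        self.weight = nn.Parameter(torch.randn(num_classes, embed_dim))

    def forward(self, embeds):
        # embeds: (embed_dim, batch); returns (batch, num_classes) scores
        return self.weight.mm(embeds).t()

    def loss(self, embeds, labels):
        return self.xent(self.forward(embeds), labels)

clf = NodeClassifier(num_classes=7, embed_dim=128)  # e.g. 7 Cora classes
embeds = torch.randn(128, 4)                        # embeddings for 4 nodes
labels = torch.LongTensor([0, 2, 1, 6])
print(clf.loss(embeds, labels).item())
```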