diff --git a/graphsage/model.py b/graphsage/model.py index 9a08a1c..aeca282 100644 --- a/graphsage/model.py +++ b/graphsage/model.py @@ -12,6 +12,11 @@ from collections import defaultdict from graphsage.encoders import Encoder from graphsage.aggregators import MeanAggregator +""" +Simple supervised GraphSAGE model as well as examples running the model +on the Cora and Pubmed datasets. +""" + class SupervisedGraphSage(nn.Module): def __init__(self, num_classes, enc):