diff --git a/graphsage/models.py b/graphsage/models.py index 861e382..695cbc5 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -231,7 +231,7 @@ class SampleAndAggregate(GeneralizedModel): else: self.embeds = None if features is None: - if identity_dim is None: + if identity_dim == 0: raise Exception("Must have a positive value for identity feature dimension if no input features given.") self.features = self.embeds else: diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py index 4bff401..08fc01e 100644 --- a/graphsage/supervised_models.py +++ b/graphsage/supervised_models.py @@ -53,7 +53,7 @@ class SupervisedGraphsage(models.SampleAndAggregate): else: self.embeds = None if features is None: - if identity_dim is None: + if identity_dim == 0: raise Exception("Must have a positive value for identity feature dimension if no input features given.") self.features = self.embeds else: