diff --git a/graphsage/utils.py b/graphsage/utils.py index 4b73c35..4143d46 100644 --- a/graphsage/utils.py +++ b/graphsage/utils.py @@ -8,7 +8,9 @@ import os import networkx as nx from networkx.readwrite import json_graph -major, minor = map(int, nx.__version__.split('.')) +version_info = map(int, nx.__version__.split('.')) +major = version_info[0] +minor = version_info[1] assert (major <= 1) and (minor <= 11), "networkx major version > 1.11" WALK_LEN=5 @@ -38,6 +40,15 @@ def load_data(prefix, normalize=True, load_walks=False): class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()} + ## Remove all nodes that do not have val/test annotations + ## (necessary because of networkx weirdness with the Reddit data) + broken_count = 0 + for node in G.nodes(): + if not 'val' in G.node[node] or not 'test' in G.node[node]: + G.remove_node(node) + broken_count += 1 + print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count)) + ## Make sure the graph has edge train_removed annotations ## (some datasets might already have this..) print("Loaded data.. now preprocessing..")