Hotfix for networkx weirdness.

This commit is contained in:
William Hamilton 2017-11-15 16:10:43 -08:00
parent da7b91464a
commit 811b50945e

View File

@ -8,7 +8,9 @@ import os
import networkx as nx
from networkx.readwrite import json_graph
major, minor = map(int, nx.__version__.split('.'))
version_info = map(int, nx.__version__.split('.'))
major = version_info[0]
minor = version_info[1]
assert (major <= 1) and (minor <= 11), "networkx major version > 1.11"
WALK_LEN=5
@ -38,6 +40,15 @@ def load_data(prefix, normalize=True, load_walks=False):
class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
## Remove all nodes that do not have val/test annotations
## (necessary because of networkx weirdness with the Reddit data)
broken_count = 0
for node in G.nodes():
if not 'val' in G.node[node] or not 'test' in G.node[node]:
G.remove_node(node)
broken_count += 1
print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))
## Make sure the graph has edge train_removed annotations
## (some datasets might already have this..)
print("Loaded data.. now preprocessing..")