Hotfix for networkx weirdness.
This commit is contained in:
parent
da7b91464a
commit
811b50945e
@ -8,7 +8,9 @@ import os
|
||||
|
||||
import networkx as nx
|
||||
from networkx.readwrite import json_graph
|
||||
major, minor = map(int, nx.__version__.split('.'))
|
||||
# Materialize the parsed version components as a list: on Python 3 map()
# returns a lazy iterator, which cannot be indexed like the lookups below do.
version_info = list(map(int, nx.__version__.split('.')))
|
||||
major = version_info[0]
|
||||
minor = version_info[1]
|
||||
# Compare (major, minor) as a tuple so the check means "version <= 1.11";
# testing the components independently would wrongly reject 0.x releases
# whose minor number is above 11.
assert (major, minor) <= (1, 11), "networkx major version > 1.11"
|
||||
|
||||
WALK_LEN=5
|
||||
@ -38,6 +40,15 @@ def load_data(prefix, normalize=True, load_walks=False):
|
||||
|
||||
class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
|
||||
|
||||
## Remove all nodes that do not have val/test annotations
|
||||
## (necessary because of networkx weirdness with the Reddit data)
|
||||
broken_count = 0
|
||||
for node in G.nodes():
|
||||
if not 'val' in G.node[node] or not 'test' in G.node[node]:
|
||||
G.remove_node(node)
|
||||
broken_count += 1
|
||||
print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))
|
||||
|
||||
## Make sure the graph has edge train_removed annotations
|
||||
## (some datasets might already have this..)
|
||||
print("Loaded data.. now preprocessing..")
|
||||
|
Loading…
Reference in New Issue
Block a user