Hotfix for networkx weirdness.
This commit is contained in:
parent
da7b91464a
commit
811b50945e
@ -8,7 +8,9 @@ import os
|
|||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
from networkx.readwrite import json_graph
|
from networkx.readwrite import json_graph
|
||||||
major, minor = map(int, nx.__version__.split('.'))
|
version_info = map(int, nx.__version__.split('.'))
|
||||||
|
major = version_info[0]
|
||||||
|
minor = version_info[1]
|
||||||
assert (major <= 1) and (minor <= 11), "networkx major version > 1.11"
|
assert (major <= 1) and (minor <= 11), "networkx major version > 1.11"
|
||||||
|
|
||||||
WALK_LEN=5
|
WALK_LEN=5
|
||||||
@ -38,6 +40,15 @@ def load_data(prefix, normalize=True, load_walks=False):
|
|||||||
|
|
||||||
class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
|
class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
|
||||||
|
|
||||||
|
## Remove all nodes that do not have val/test annotations
|
||||||
|
## (necessary because of networkx weirdness with the Reddit data)
|
||||||
|
broken_count = 0
|
||||||
|
for node in G.nodes():
|
||||||
|
if not 'val' in G.node[node] or not 'test' in G.node[node]:
|
||||||
|
G.remove_node(node)
|
||||||
|
broken_count += 1
|
||||||
|
print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))
|
||||||
|
|
||||||
## Make sure the graph has edge train_removed annotations
|
## Make sure the graph has edge train_removed annotations
|
||||||
## (some datasets might already have this..)
|
## (some datasets might already have this..)
|
||||||
print("Loaded data.. now preprocessing..")
|
print("Loaded data.. now preprocessing..")
|
||||||
|
Loading…
Reference in New Issue
Block a user