graphsage-tf/graphsage/utils.py

from __future__ import print_function
import numpy as np
import random
import json
import sys
import os
import networkx as nx
from networkx.readwrite import json_graph

# This file targets the networkx 1.x API (G.node, list-returning G.nodes()
# and G.neighbors()), so refuse to run under networkx >= 2.0.
version_info = list(map(int, nx.__version__.split('.')))
major = version_info[0]
minor = version_info[1]
assert (major <= 1) and (minor <= 11), "networkx version > 1.11"

# Random-walk hyperparameters: N_WALKS walks of WALK_LEN steps per node.
WALK_LEN = 5
N_WALKS = 50

def load_data(prefix, normalize=True, load_walks=False):
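    """Load a GraphSAGE-formatted dataset.

    Reads the following files under `prefix` (as seen below):
      <prefix>-G.json         -- node-link graph with boolean 'val'/'test' node flags
      <prefix>-id_map.json    -- node id -> consecutive integer index
      <prefix>-class_map.json -- node id -> label (int or list of ints)
      <prefix>-feats.npy      -- optional node feature matrix
      <prefix>-walks.txt      -- optional co-occurrence pairs (if load_walks)

    Returns (G, feats, id_map, walks, class_map).
    """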
    G_data = json.load(open(prefix + "-G.json"))
    G = json_graph.node_link_graph(G_data)
    if isinstance(G.nodes()[0], int):
        conversion = lambda n: int(n)
    else:
        conversion = lambda n: n

    if os.path.exists(prefix + "-feats.npy"):
        feats = np.load(prefix + "-feats.npy")
    else:
        print("No features present.. Only identity features will be used.")
        feats = None

    id_map = json.load(open(prefix + "-id_map.json"))
    id_map = {conversion(k): int(v) for k, v in id_map.items()}
    walks = []
    class_map = json.load(open(prefix + "-class_map.json"))
    if isinstance(list(class_map.values())[0], list):
        lab_conversion = lambda n: n
    else:
        lab_conversion = lambda n: int(n)
    class_map = {conversion(k): lab_conversion(v) for k, v in class_map.items()}

    ## Remove all nodes that do not have val/test annotations
    ## (necessary because of networkx weirdness with the Reddit data)
    broken_count = 0
    for node in G.nodes():
        if 'val' not in G.node[node] or 'test' not in G.node[node]:
            G.remove_node(node)
            broken_count += 1
    print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))

    ## Make sure the graph has edge train_removed annotations
    ## (some datasets might already have this..)
    ## Edges touching a val/test node are flagged so they can be excluded
    ## during training.
    print("Loaded data.. now preprocessing..")
    for edge in G.edges():
        if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
                G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
            G[edge[0]][edge[1]]['train_removed'] = True
        else:
            G[edge[0]][edge[1]]['train_removed'] = False

    if normalize and feats is not None:
        from sklearn.preprocessing import StandardScaler
        # Fit the scaler on training nodes only, then standardize all features.
        train_ids = np.array([id_map[n] for n in G.nodes()
                              if not G.node[n]['val'] and not G.node[n]['test']])
        train_feats = feats[train_ids]
        scaler = StandardScaler()
        scaler.fit(train_feats)
        feats = scaler.transform(feats)

    if load_walks:
        with open(prefix + "-walks.txt") as fp:
            for line in fp:
                walks.append(list(map(conversion, line.split())))

    return G, feats, id_map, walks, class_map

def run_random_walks(G, nodes, num_walks=N_WALKS):
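    """Sample positive co-occurrence pairs with uniform random walks.

    From every node of nonzero degree, runs num_walks walks of WALK_LEN
    steps and records a (start_node, visited_node) pair for each visit,
    skipping self co-occurrences.
    """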
    pairs = []
    for count, node in enumerate(nodes):
        if G.degree(node) == 0:
            continue
        for i in range(num_walks):
            curr_node = node
            for j in range(WALK_LEN):
                next_node = random.choice(G.neighbors(curr_node))
                # self co-occurrences are useless
                if curr_node != node:
                    pairs.append((node, curr_node))
                curr_node = next_node
        if count % 1000 == 0:
            print("Done walks for", count, "nodes")
    return pairs


if __name__ == "__main__":
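    # Run random walks standalone: python graphsage/utils.py <graph_file> <out_file>
    # Restricts the graph to training nodes, runs the walks, and writes
    # tab-separated co-occurrence pairs, one per line.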
""" Run random walks """
graph_file = sys.argv[1]
out_file = sys.argv[2]
G_data = json.load(open(graph_file))
G = json_graph.node_link_graph(G_data)
nodes = [n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]]
G = G.subgraph(nodes)
pairs = run_random_walks(G, nodes)
with open(out_file, "w") as fp:
2017-10-14 04:29:31 +08:00
fp.write("\n".join([str(p[0]) + "\t" + str(p[1]) for p in pairs]))