85 lines
2.8 KiB
Python
85 lines
2.8 KiB
Python
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import random
|
|
import json
|
|
import sys
|
|
|
|
from networkx.readwrite import json_graph
|
|
|
|
# Number of steps taken in each random walk (see run_random_walks).
WALK_LEN = 5
# Default number of random walks started from every node.
N_WALKS = 50
|
|
|
|
def load_data(prefix, normalize=True):
    """Load a graph dataset stored as a family of files sharing ``prefix``.

    Files read:
        ``prefix + "-G.json"``          node-link JSON of the graph
        ``prefix + "-feats.npy"``       feature array (rows indexed by id_map values)
        ``prefix + "-id_map.json"``     node id -> integer row index
        ``prefix + "-class_map.json"``  node id -> label (int, or list kept as-is)
        ``prefix + "-walks.txt"``       one whitespace-separated list of node ids per line

    Every node is expected to carry boolean-like ``'val'`` and ``'test'``
    attributes. Each edge gets a ``'train_removed'`` attribute that is True
    when either endpoint is a val or test node.

    Args:
        prefix: Common path prefix of the dataset files.
        normalize: If true, standardize the features with a ``StandardScaler``
            fit only on rows of nodes that are neither val nor test.

    Returns:
        Tuple ``(G, feats, id_map, walks, class_map)``.
    """
    with open(prefix + "-G.json") as fp:
        G_data = json.load(fp)
    G = json_graph.node_link_graph(G_data)

    # Node attribute lookup: ``G.node`` on older networkx releases,
    # ``G.nodes`` on releases where ``G.node`` no longer exists.
    node_attrs = G.node if hasattr(G, "node") else G.nodes

    # JSON object keys are always strings; if the graph uses integer node ids,
    # keys and walk tokens must be converted back to int to match them.
    # next(iter(...), None) also tolerates an empty graph.
    if isinstance(next(iter(G.nodes()), None), int):
        conversion = int
    else:
        conversion = lambda n: n

    feats = np.load(prefix + "-feats.npy")

    with open(prefix + "-id_map.json") as fp:
        raw_id_map = json.load(fp)
    id_map = {conversion(k): int(v) for k, v in raw_id_map.items()}

    with open(prefix + "-class_map.json") as fp:
        raw_class_map = json.load(fp)
    # List-valued labels are passed through untouched; anything else is
    # coerced to int.
    if isinstance(next(iter(raw_class_map.values()), None), list):
        lab_conversion = lambda n: n
    else:
        lab_conversion = int
    class_map = {conversion(k): lab_conversion(v)
                 for k, v in raw_class_map.items()}

    # Make sure the graph has edge train_removed annotations
    # (some datasets might already have this..)
    print("Loaded data.. now preprocessing..")
    for u, v in G.edges():
        held_out = (node_attrs[u]['val'] or node_attrs[v]['val'] or
                    node_attrs[u]['test'] or node_attrs[v]['test'])
        G[u][v]['train_removed'] = bool(held_out)

    if normalize:
        from sklearn.preprocessing import StandardScaler
        # Fit the scaler on training rows only, then apply it to all rows.
        train_ids = np.array([id_map[n] for n in G.nodes()
                              if not node_attrs[n]['val']
                              and not node_attrs[n]['test']])
        train_feats = feats[train_ids]
        scaler = StandardScaler()
        scaler.fit(train_feats)
        feats = scaler.transform(feats)

    # Materialize each walk as a list (a bare map() is lazy on Python 3).
    with open(prefix + "-walks.txt") as fp:
        walks = [[conversion(tok) for tok in line.split()] for line in fp]

    return G, feats, id_map, walks, class_map
|
|
|
|
def run_random_walks(G, nodes, num_walks=N_WALKS):
    """Generate (start, visited) co-occurrence pairs from uniform random walks.

    For every node in ``nodes`` with non-zero degree, ``num_walks`` walks of
    ``WALK_LEN`` steps are started at that node. At each step the current
    node is paired with the start node (unless it *is* the start node) and
    the walk then moves to a uniformly chosen neighbor. Because the pair is
    recorded before moving, the node reached by the final step of a walk is
    not recorded.

    Args:
        G: Graph object providing ``degree(node)`` and ``neighbors(node)``.
        nodes: Iterable of start nodes.
        num_walks: Number of walks to start from each node.

    Returns:
        List of ``(start_node, visited_node)`` tuples.
    """
    pairs = []
    for count, node in enumerate(nodes):
        # Isolated nodes have nowhere to walk to.
        if G.degree(node) == 0:
            continue
        for _ in range(num_walks):
            curr_node = node
            for _ in range(WALK_LEN):
                # list(...) so random.choice works whether neighbors()
                # returns a list or an iterator.
                next_node = random.choice(list(G.neighbors(curr_node)))
                # self co-occurrences are useless
                if curr_node != node:
                    pairs.append((node, curr_node))
                curr_node = next_node
        if count % 1000 == 0:
            print("Done walks for", count, "nodes")
    return pairs
|
|
|
|
if __name__ == "__main__":
    # Run random walks.
    # Usage: <script> GRAPH_JSON OUT_FILE
    # Reads a node-link JSON graph, keeps only nodes that are neither val nor
    # test, runs random walks on that subgraph, and writes one
    # tab-separated (start, visited) pair per line.
    graph_file = sys.argv[1]
    out_file = sys.argv[2]

    with open(graph_file) as fp:
        G_data = json.load(fp)
    G = json_graph.node_link_graph(G_data)

    # Node attribute lookup: ``G.node`` on older networkx releases,
    # ``G.nodes`` on releases where ``G.node`` no longer exists.
    node_attrs = G.node if hasattr(G, "node") else G.nodes
    nodes = [n for n in G.nodes()
             if not node_attrs[n]["val"] and not node_attrs[n]["test"]]

    # Walk only over the training subgraph so walks never cross held-out nodes.
    G = G.subgraph(nodes)
    pairs = run_random_walks(G, nodes)

    with open(out_file, "w") as fp:
        fp.write("\n".join([str(p[0]) + "\t" + str(p[1]) for p in pairs]))
|