from __future__ import print_function import numpy as np import random import json import sys from networkx.readwrite import json_graph WALK_LEN=5 N_WALKS=50 def load_data(prefix, normalize=True): G_data = json.load(open(prefix + "-G.json")) G = json_graph.node_link_graph(G_data) if isinstance(G.nodes()[0], int): conversion = lambda n : int(n) else: conversion = lambda n : n feats = np.load(prefix + "-feats.npy") id_map = json.load(open(prefix + "-id_map.json")) id_map = {conversion(k):int(v) for k,v in id_map.iteritems()} walks = [] class_map = json.load(open(prefix + "-class_map.json")) if isinstance(class_map.values()[0], list): lab_conversion = lambda n : n else: lab_conversion = lambda n : int(n) class_map = {conversion(k):lab_conversion(v) for k,v in class_map.iteritems()} ## Make sure the graph has edge train_removed annotations ## (some datasets might already have this..) print("Loaded data.. now preprocessing..") for edge in G.edges_iter(): if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or G.node[edge[0]]['test'] or G.node[edge[1]]['test']): G[edge[0]][edge[1]]['train_removed'] = True else: G[edge[0]][edge[1]]['train_removed'] = False if normalize: from sklearn.preprocessing import StandardScaler train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]) train_feats = feats[train_ids] scaler = StandardScaler() scaler.fit(train_feats) feats = scaler.transform(feats) with open(prefix + "-walks.txt") as fp: for line in fp: walks.append(map(conversion, line.split())) return G, feats, id_map, walks, class_map def run_random_walks(G, nodes, num_walks=N_WALKS): pairs = [] for count, node in enumerate(nodes): if G.degree(node) == 0: continue for i in range(num_walks): curr_node = node for j in range(WALK_LEN): next_node = random.choice(G.neighbors(curr_node)) # self co-occurrences are useless if curr_node != node: pairs.append((node,curr_node)) curr_node = next_node if count % 1000 == 0: print("Done walks for", count, "nodes") return pairs if __name__ == "__main__": """ Run random walks """ graph_file = sys.argv[1] out_file = sys.argv[2] G_data = json.load(open(graph_file)) G = json_graph.node_link_graph(G_data) nodes = [n for n in G.nodes() if not G.node[n]["val"] and not G.node[n]["test"]] G = G.subgraph(nodes) pairs = run_random_walks(G, nodes) with open(out_file, "w") as fp: fp.write("\n".join([p[0] + "\t" + p[1] for p in pairs]))