graphsage-tf/graphsage/utils.py
2017-05-29 08:35:30 -07:00

71 lines
2.3 KiB
Python

import numpy as np
import random
import json
from networkx.readwrite import json_graph
# Number of steps taken in each random walk (see run_random_walks).
WALK_LEN = 5
# Default number of random walks started from every node.
N_WALKS = 50
def load_data(prefix, normalize=True):
    """Load a graph dataset and its companion files from disk.

    Files read, in this order:
      * ``<prefix>-G.json``         graph in networkx node-link JSON form
      * ``<prefix>-feats.npy``      feature array loaded with ``np.load``
      * ``<prefix>-id_map.json``    node id -> integer index
      * ``<prefix>-class_map.json`` node id -> label (an int, or a list)
      * ``<prefix>-walks.txt``      one whitespace-separated walk per line

    Every node must carry ``val`` and ``test`` attributes; a missing
    attribute raises ``KeyError``. As a side effect every edge of the
    returned graph gets a boolean ``train_removed`` attribute, which is
    true when either endpoint is flagged ``val`` or ``test``.

    Args:
        prefix: Path prefix shared by all dataset files.
        normalize: When true, standardize ``feats`` with a scikit-learn
            ``StandardScaler`` fitted only on the rows of nodes that are
            flagged neither ``val`` nor ``test``.

    Returns:
        The tuple ``(G, feats, id_map, walks, class_map)``, where ``walks``
        is a list of lists of node ids.
    """
    def _identity(value):
        return value

    with open(prefix + "-G.json") as fp:
        G = json_graph.node_link_graph(json.load(fp))

    # JSON object keys are always strings, so keys of the side files are
    # cast back to int when the graph itself uses integer node ids.
    # Iterating the graph yields its nodes without building a node list.
    first_node = next(iter(G), None)
    conversion = int if isinstance(first_node, int) else _identity

    feats = np.load(prefix + "-feats.npy")

    with open(prefix + "-id_map.json") as fp:
        id_map = {conversion(k): int(v) for k, v in json.load(fp).items()}

    with open(prefix + "-class_map.json") as fp:
        raw_class_map = json.load(fp)
    # List-valued labels are kept as they are; anything else is cast to int.
    first_label = next(iter(raw_class_map.values()), None)
    lab_conversion = _identity if isinstance(first_label, list) else int
    class_map = {conversion(k): lab_conversion(v)
                 for k, v in raw_class_map.items()}

    ## Make sure the graph has edge train_removed annotations
    ## (some datasets might already have this..)
    print("Loaded data.. now preprocessing..")
    # NOTE(review): `G.node` and `edges_iter()` look like the networkx 1.x
    # accessors; confirm the pinned networkx version before upgrading.
    for u, v in G.edges_iter():
        touches_heldout = (G.node[u]['val'] or G.node[v]['val'] or
                           G.node[u]['test'] or G.node[v]['test'])
        G[u][v]['train_removed'] = bool(touches_heldout)

    if normalize:
        # Imported lazily so scikit-learn is only needed when normalizing.
        from sklearn.preprocessing import StandardScaler
        train_ids = np.array([id_map[n] for n in G.nodes()
                              if not G.node[n]['val'] and not G.node[n]['test']])
        # Fit on training rows only, then apply the transform to all rows.
        scaler = StandardScaler()
        scaler.fit(feats[train_ids])
        feats = scaler.transform(feats)

    # Build real lists: on Python 3 a bare map() would be a lazy, one-shot
    # iterator rather than a walk that can be indexed or re-read.
    with open(prefix + "-walks.txt") as fp:
        walks = [[conversion(tok) for tok in line.split()] for line in fp]

    return G, feats, id_map, walks, class_map
def run_random_walks(G, nodes, num_walks=N_WALKS):
    """Collect (start, visited) co-occurrence pairs from random walks.

    For each node in ``nodes`` whose degree is non-zero, ``num_walks`` walks
    are started from it. Each walk takes ``WALK_LEN`` uniformly random steps
    to a neighbor. Before every step the node currently occupied is recorded
    as ``(start, current)`` unless it is the start node itself, so the node
    reached by the final step is not recorded.

    Args:
        G: Graph exposing ``len()``, ``degree(node)`` and ``neighbors(node)``;
            ``neighbors`` must return a sequence usable by ``random.choice``.
        nodes: Iterable of start nodes.
        num_walks: Number of walks started from each node.

    Returns:
        List of ``(start, visited)`` tuples in generation order.
    """
    print("Subgraph for walks is of size", len(G))
    cooccurrences = []
    for idx, start in enumerate(nodes):
        # Isolated nodes have nowhere to walk to; they also skip the
        # progress print below.
        if G.degree(start) == 0:
            continue
        for _ in range(num_walks):
            current = start
            for _ in range(WALK_LEN):
                upcoming = random.choice(G.neighbors(current))
                # self co-occurrences are useless
                if current != start:
                    cooccurrences.append((start, current))
                current = upcoming
        # Progress indicator, emitted for every 1000th position in `nodes`.
        if idx % 1000 == 0:
            print(idx)
    return cooccurrences