graphsage-tf/graphsage/minibatch.py
2017-05-29 08:35:30 -07:00

292 lines
12 KiB
Python

from __future__ import division
from __future__ import print_function
import numpy as np
# Seed NumPy's global RNG at import time so that the permutations and
# neighbor sampling performed by the iterators below are repeatable.
np.random.seed(123)
class EdgeMinibatchIterator(object):
    """Minibatch iterator over sampled edges or pairs of co-occurring nodes.

    G -- graph object (presumably a NetworkX graph; verify against caller).
         Nodes are read through ``G.node[n]['test']`` / ``['val']`` and edges
         through ``G[u][v]['train_removed']``.
    id2idx -- dict mapping node ids to integer indices
    placeholders -- dict with keys 'batch_size', 'batch1' and 'batch2' whose
                    values are used as the keys of the produced feed dicts
    context_pairs -- optional sequence of (node, node) pairs used instead of
                     the graph edges as the training pairs
    batch_size -- number of pairs per minibatch
    max_degree -- number of columns of the adjacency tables
    num_neg_samples -- stored on the instance; not used inside this class
    n2v_retrain -- if true, the same (optionally pruned) pair list is used
                   for both training and validation
    fixed_n2v -- only read when n2v_retrain is true; drops pairs whose second
                 node is a val/test node
    """

    def __init__(self, G, id2idx,
                 placeholders, context_pairs=None, batch_size=100,
                 max_degree=25, num_neg_samples=20,
                 n2v_retrain=False, fixed_n2v=False,
                 **kwargs):
        self.G = G
        self.id2idx = id2idx
        self.placeholders = placeholders
        self.batch_size = batch_size
        self.max_degree = max_degree
        self.num_neg_samples = num_neg_samples
        self.batch_num = 0

        # The RNG-consuming steps run in a fixed order (node permutation,
        # train adjacency, test adjacency, edge permutation) so that seeded
        # runs keep producing the same samples.
        # nodes()/edges() are used instead of the *_iter variants, which
        # newer NetworkX releases no longer provide.
        self.nodes = np.random.permutation(list(G.nodes()))
        self.adj, self.deg = self.construct_adj()
        self.test_adj = self.construct_test_adj()

        edges = list(G.edges()) if context_pairs is None else context_pairs
        self.train_edges = self.edges = np.random.permutation(edges)

        if not n2v_retrain:
            self.train_edges = self._remove_isolated(self.train_edges)
            self.val_edges = [e for e in G.edges()
                              if G[e[0]][e[1]]['train_removed']]
        elif fixed_n2v:
            self.train_edges = self.val_edges = self._n2v_prune(self.edges)
        else:
            self.train_edges = self.val_edges = self.edges

        num_held_out = len([n for n in G.nodes()
                            if G.node[n]['test'] or G.node[n]['val']])
        num_train = len([n for n in G.nodes()
                         if not G.node[n]['test'] and not G.node[n]['val']])
        print(num_train, 'train nodes')
        print(num_held_out, 'test nodes')
        self.val_set_size = len(self.val_edges)

    def _n2v_prune(self, edges):
        """Return the pairs of `edges` whose second node is not val/test."""
        def held_out(node):
            return self.G.node[node]["val"] or self.G.node[node]["test"]
        return [e for e in edges if not held_out(e[1])]

    def _remove_isolated(self, edge_list):
        """Return `edge_list` without pairs that touch a zero-degree node.

        A pair is dropped when at least one endpoint has training degree 0
        and, for both endpoints, ``not test or val`` holds.
        """
        # NOTE(review): `not test or val` also matches validation nodes; the
        # intent may have been `not (test or val)`. Logic kept as found --
        # confirm before changing.
        kept = []
        for n1, n2 in edge_list:
            has_isolated_end = (self.deg[self.id2idx[n1]] == 0
                                or self.deg[self.id2idx[n2]] == 0)
            n1_flag = not self.G.node[n1]['test'] or self.G.node[n1]['val']
            n2_flag = not self.G.node[n2]['test'] or self.G.node[n2]['val']
            if has_isolated_end and n1_flag and n2_flag:
                continue
            kept.append((n1, n2))
        return kept

    def _fit_to_max_degree(self, neighbors):
        """Return exactly `max_degree` entries drawn from `neighbors`.

        Longer arrays are subsampled without replacement, shorter ones are
        resampled with replacement, exact-length arrays are returned as is.
        """
        if len(neighbors) > self.max_degree:
            return np.random.choice(neighbors, self.max_degree, replace=False)
        if len(neighbors) < self.max_degree:
            return np.random.choice(neighbors, self.max_degree, replace=True)
        return neighbors

    def construct_adj(self):
        """Build the training adjacency table and the degree vector.

        Returns (adj, deg). `adj` has len(id2idx)+1 rows and max_degree
        columns; rows that are never filled keep the value len(id2idx), i.e.
        the index of the extra last row. `deg` holds, per node index, the
        number of neighbors reachable over edges not marked 'train_removed'.
        Val/test nodes are skipped, so their row and degree stay at the
        defaults.
        """
        num_nodes = len(self.id2idx)
        adj = np.full((num_nodes + 1, self.max_degree), num_nodes, dtype=float)
        deg = np.zeros((num_nodes,))

        for nodeid in self.G.nodes():
            if self.G.node[nodeid]['test'] or self.G.node[nodeid]['val']:
                continue
            neighbors = np.array([self.id2idx[neighbor]
                                  for neighbor in self.G.neighbors(nodeid)
                                  if not self.G[nodeid][neighbor]['train_removed']])
            deg[self.id2idx[nodeid]] = len(neighbors)
            if len(neighbors) == 0:
                continue
            adj[self.id2idx[nodeid], :] = self._fit_to_max_degree(neighbors)
        return adj, deg

    def construct_test_adj(self):
        """Build the adjacency table over all nodes and all edges.

        Same layout as the table from construct_adj, but no node and no edge
        is filtered out.
        """
        num_nodes = len(self.id2idx)
        adj = np.full((num_nodes + 1, self.max_degree), num_nodes, dtype=float)

        for nodeid in self.G.nodes():
            neighbors = np.array([self.id2idx[neighbor]
                                  for neighbor in self.G.neighbors(nodeid)])
            if len(neighbors) == 0:
                continue
            adj[self.id2idx[nodeid], :] = self._fit_to_max_degree(neighbors)
        return adj

    def end(self):
        """Return True once every training pair has been served this epoch.

        The final batch may be shorter than batch_size; the feed dict always
        carries the actual batch length.
        """
        return self.batch_num * self.batch_size >= len(self.train_edges)

    def batch_feed_dict(self, batch_edges):
        """Map (node, node) pairs to a feed dict of index lists."""
        batch1 = [self.id2idx[node1] for node1, _ in batch_edges]
        batch2 = [self.id2idx[node2] for _, node2 in batch_edges]
        return {
            self.placeholders['batch_size']: len(batch_edges),
            self.placeholders['batch1']: batch1,
            self.placeholders['batch2']: batch2,
        }

    def next_minibatch_feed_dict(self):
        """Return the feed dict for the next training batch and advance."""
        start = self.batch_num * self.batch_size
        self.batch_num += 1
        return self.batch_feed_dict(self.train_edges[start:start + self.batch_size])

    def val_feed_dict(self, size=None):
        """Return a feed dict over the validation pairs.

        With size=None all validation pairs are used; otherwise a random
        selection of at most `size` pairs (without repetition) is used.
        """
        if size is None:
            return self.batch_feed_dict(self.val_edges)
        order = np.random.permutation(len(self.val_edges))
        # Slicing past the end is harmless, so no explicit min() is needed.
        return self.batch_feed_dict([self.val_edges[i] for i in order[:size]])

    def incremental_val_feed_dict(self, size, iter_num):
        """Return (feed_dict, finished, pairs) for validation chunk `iter_num`."""
        chunk = self.val_edges[iter_num * size:(iter_num + 1) * size]
        finished = (iter_num + 1) * size >= len(self.val_edges)
        return self.batch_feed_dict(chunk), finished, chunk

    def incremental_embed_feed_dict(self, size, iter_num):
        """Return (feed_dict, finished, pairs) for node chunk `iter_num`.

        Every node of the chunk is paired with itself.
        """
        chunk = self.nodes[iter_num * size:(iter_num + 1) * size]
        self_pairs = [(n, n) for n in chunk]
        finished = (iter_num + 1) * size >= len(self.nodes)
        return self.batch_feed_dict(self_pairs), finished, self_pairs

    def label_val(self):
        """Split the graph edges into (train_edges, val_edges).

        An edge counts as validation when either endpoint is a val or test
        node.
        """
        train_edges = []
        val_edges = []
        for n1, n2 in self.G.edges():
            touches_held_out = (self.G.node[n1]['val'] or self.G.node[n1]['test']
                                or self.G.node[n2]['val'] or self.G.node[n2]['test'])
            if touches_held_out:
                val_edges.append((n1, n2))
            else:
                train_edges.append((n1, n2))
        return train_edges, val_edges

    def shuffle(self):
        """Re-shuffle the training pairs and the node order.

        Also resets the batch counter.
        """
        self.train_edges = np.random.permutation(self.train_edges)
        self.nodes = np.random.permutation(self.nodes)
        self.batch_num = 0
class NodeMinibatchIterator(object):
    """Minibatch iterator over nodes for supervised learning.

    G -- graph object (presumably a NetworkX graph; verify against caller).
         Nodes are read through ``G.node[n]['test']`` / ``['val']`` and edges
         through ``G[u][v]['train_removed']``.
    id2idx -- dict mapping node ids to integer indices
    placeholders -- dict with keys 'batch_size', 'batch' and 'labels' whose
                    values are used as the keys of the produced feed dicts
    label_map -- dict mapping node ids to either a list (used directly as the
                 label vector) or a class index (turned into a one-hot vector)
    num_classes -- length of the one-hot label vectors
    context_pairs -- accepted but not used by this class
    batch_size -- number of nodes per minibatch
    max_degree -- number of columns of the adjacency tables
    """

    def __init__(self, G, id2idx,
                 placeholders, label_map, num_classes, context_pairs=None,
                 batch_size=100, max_degree=25,
                 **kwargs):
        self.G = G
        # Materialized as a list because incremental_embed_feed_dict slices it.
        self.nodes = list(G.nodes())
        self.id2idx = id2idx
        self.placeholders = placeholders
        self.batch_size = batch_size
        self.max_degree = max_degree
        self.batch_num = 0
        self.label_map = label_map
        self.num_classes = num_classes

        self.adj, self.deg = self.construct_adj()
        self.test_adj = self.construct_test_adj()

        # nodes() is used instead of nodes_iter(), which newer NetworkX
        # releases no longer provide.
        self.val_nodes = [n for n in self.G.nodes() if self.G.node[n]['val']]
        self.test_nodes = [n for n in self.G.nodes() if self.G.node[n]['test']]

        self.no_train_nodes_set = set(self.val_nodes + self.test_nodes)
        candidates = set(G.nodes()).difference(self.no_train_nodes_set)
        # Don't train on nodes without any training neighbor (degree 0 in
        # the table built by construct_adj).
        self.train_nodes = [n for n in candidates if self.deg[id2idx[n]] > 0]

    def _make_label_vec(self, node):
        """Return the label vector of `node`.

        List labels are converted to an array unchanged; any other label is
        treated as a class index and one-hot encoded over num_classes.
        """
        label = self.label_map[node]
        if isinstance(label, list):
            return np.array(label)
        label_vec = np.zeros((self.num_classes))
        label_vec[label] = 1
        return label_vec

    def _fit_to_max_degree(self, neighbors):
        """Return exactly `max_degree` entries drawn from `neighbors`.

        Longer arrays are subsampled without replacement, shorter ones are
        resampled with replacement, exact-length arrays are returned as is.
        """
        if len(neighbors) > self.max_degree:
            return np.random.choice(neighbors, self.max_degree, replace=False)
        if len(neighbors) < self.max_degree:
            return np.random.choice(neighbors, self.max_degree, replace=True)
        return neighbors

    def construct_adj(self):
        """Build the training adjacency table and the degree vector.

        Returns (adj, deg). `adj` has len(id2idx)+1 rows and max_degree
        columns; rows that are never filled keep the value len(id2idx), i.e.
        the index of the extra last row. `deg` holds, per node index, the
        number of neighbors reachable over edges not marked 'train_removed'.
        Val/test nodes are skipped, so their row and degree stay at the
        defaults.
        """
        num_nodes = len(self.id2idx)
        adj = np.full((num_nodes + 1, self.max_degree), num_nodes, dtype=float)
        deg = np.zeros((num_nodes,))

        for nodeid in self.G.nodes():
            if self.G.node[nodeid]['test'] or self.G.node[nodeid]['val']:
                continue
            neighbors = np.array([self.id2idx[neighbor]
                                  for neighbor in self.G.neighbors(nodeid)
                                  if not self.G[nodeid][neighbor]['train_removed']])
            deg[self.id2idx[nodeid]] = len(neighbors)
            if len(neighbors) == 0:
                continue
            adj[self.id2idx[nodeid], :] = self._fit_to_max_degree(neighbors)
        return adj, deg

    def construct_test_adj(self):
        """Build the adjacency table over all nodes and all edges.

        Same layout as the table from construct_adj, but no node and no edge
        is filtered out.
        """
        num_nodes = len(self.id2idx)
        adj = np.full((num_nodes + 1, self.max_degree), num_nodes, dtype=float)

        for nodeid in self.G.nodes():
            neighbors = np.array([self.id2idx[neighbor]
                                  for neighbor in self.G.neighbors(nodeid)])
            if len(neighbors) == 0:
                continue
            adj[self.id2idx[nodeid], :] = self._fit_to_max_degree(neighbors)
        return adj

    def end(self):
        """Return True once every training node has been served this epoch.

        The final batch may be shorter than batch_size; the feed dict always
        carries the actual batch length.
        """
        return self.batch_num * self.batch_size >= len(self.train_nodes)

    def batch_feed_dict(self, batch_nodes, val=False):
        """Return (feed_dict, labels) for the given node ids.

        `labels` stacks one label vector per node. `val` is accepted for
        interface compatibility and is not used.
        """
        batch1 = [self.id2idx[n] for n in batch_nodes]
        labels = np.vstack([self._make_label_vec(node) for node in batch_nodes])
        feed_dict = {
            self.placeholders['batch_size']: len(batch1),
            self.placeholders['batch']: batch1,
            self.placeholders['labels']: labels,
        }
        return feed_dict, labels

    def node_val_feed_dict(self, size=None, test=False):
        """Return (feed_dict, labels) over the val (or test) nodes.

        With a `size`, that many nodes are drawn with replacement; otherwise
        the whole node list is used.
        """
        val_nodes = self.test_nodes if test else self.val_nodes
        if size is not None:
            val_nodes = np.random.choice(val_nodes, size, replace=True)
        feed_dict, labels = self.batch_feed_dict(val_nodes)
        return feed_dict, labels

    def incremental_node_val_feed_dict(self, size, iter_num, test=False):
        """Return (feed_dict, labels, finished, nodes) for chunk `iter_num`.

        Chunks of `size` nodes are taken in order from the val (or test)
        node list.
        """
        val_nodes = self.test_nodes if test else self.val_nodes
        # Slicing past the end is harmless, so no explicit min() is needed.
        subset = val_nodes[iter_num * size:(iter_num + 1) * size]
        feed_dict, labels = self.batch_feed_dict(subset)
        finished = (iter_num + 1) * size >= len(val_nodes)
        return feed_dict, labels, finished, subset

    def next_minibatch_feed_dict(self):
        """Return (feed_dict, labels) for the next training batch and advance."""
        start = self.batch_num * self.batch_size
        self.batch_num += 1
        return self.batch_feed_dict(self.train_nodes[start:start + self.batch_size])

    def incremental_embed_feed_dict(self, size, iter_num):
        """Return ((feed_dict, labels), finished, nodes) for node chunk `iter_num`.

        Chunks are taken in order from all graph nodes. The first element is
        the (feed_dict, labels) pair exactly as batch_feed_dict returns it.
        """
        chunk = self.nodes[iter_num * size:(iter_num + 1) * size]
        finished = (iter_num + 1) * size >= len(self.nodes)
        return self.batch_feed_dict(chunk), finished, chunk

    def shuffle(self):
        """Re-shuffle the training nodes.

        Also resets the batch counter.
        """
        self.train_nodes = np.random.permutation(self.train_nodes)
        self.batch_num = 0