Python 3 support.

This commit is contained in:
williamleif 2017-10-11 14:05:36 -07:00
parent 05c72baeb7
commit c29361aa7d
3 changed files with 12 additions and 12 deletions

View File

@ -42,15 +42,15 @@ class EdgeMinibatchIterator(object):
self.train_edges = self.edges = np.random.permutation(edges) self.train_edges = self.edges = np.random.permutation(edges)
if not n2v_retrain: if not n2v_retrain:
self.train_edges = self._remove_isolated(self.train_edges) self.train_edges = self._remove_isolated(self.train_edges)
self.val_edges = [e for e in G.edges_iter() if G[e[0]][e[1]]['train_removed']] self.val_edges = [e for e in G.edges() if G[e[0]][e[1]]['train_removed']]
else: else:
if fixed_n2v: if fixed_n2v:
self.train_edges = self.val_edges = self._n2v_prune(self.edges) self.train_edges = self.val_edges = self._n2v_prune(self.edges)
else: else:
self.train_edges = self.val_edges = self.edges self.train_edges = self.val_edges = self.edges
print(len([n for n in G.nodes_iter() if not G.node[n]['test'] and not G.node[n]['val']]), 'train nodes') print(len([n for n in G.nodes() if not G.node[n]['test'] and not G.node[n]['val']]), 'train nodes')
print(len([n for n in G.nodes_iter() if G.node[n]['test'] or G.node[n]['val']]), 'test nodes') print(len([n for n in G.nodes() if G.node[n]['test'] or G.node[n]['val']]), 'test nodes')
self.val_set_size = len(self.val_edges) self.val_set_size = len(self.val_edges)
def _n2v_prune(self, edges): def _n2v_prune(self, edges):
@ -150,7 +150,7 @@ class EdgeMinibatchIterator(object):
def label_val(self): def label_val(self):
train_edges = [] train_edges = []
val_edges = [] val_edges = []
for n1, n2 in self.G.edges_iter(): for n1, n2 in self.G.edges():
if (self.G.node[n1]['val'] or self.G.node[n1]['test'] if (self.G.node[n1]['val'] or self.G.node[n1]['test']
or self.G.node[n2]['val'] or self.G.node[n2]['test']): or self.G.node[n2]['val'] or self.G.node[n2]['test']):
val_edges.append((n1,n2)) val_edges.append((n1,n2))
@ -197,8 +197,8 @@ class NodeMinibatchIterator(object):
self.adj, self.deg = self.construct_adj() self.adj, self.deg = self.construct_adj()
self.test_adj = self.construct_test_adj() self.test_adj = self.construct_test_adj()
self.val_nodes = [n for n in self.G.nodes_iter() if self.G.node[n]['val']] self.val_nodes = [n for n in self.G.nodes() if self.G.node[n]['val']]
self.test_nodes = [n for n in self.G.nodes_iter() if self.G.node[n]['test']] self.test_nodes = [n for n in self.G.nodes() if self.G.node[n]['test']]
self.no_train_nodes_set = set(self.val_nodes + self.test_nodes) self.no_train_nodes_set = set(self.val_nodes + self.test_nodes)
self.train_nodes = set(G.nodes()).difference(self.no_train_nodes_set) self.train_nodes = set(G.nodes()).difference(self.no_train_nodes_set)

View File

@ -125,8 +125,8 @@ def train(train_data, test_data=None):
features = train_data[1] features = train_data[1]
id_map = train_data[2] id_map = train_data[2]
class_map = train_data[4] class_map = train_data[4]
if isinstance(class_map.values()[0], list): if isinstance(list(class_map.values())[0], list):
num_classes = len(class_map.values()[0]) num_classes = len(list(class_map.values())[0])
else: else:
num_classes = len(set(class_map.values())) num_classes = len(set(class_map.values()))

View File

@ -25,20 +25,20 @@ def load_data(prefix, normalize=True, load_walks=False):
print("No features present.. Only identity features will be used.") print("No features present.. Only identity features will be used.")
feats = None feats = None
id_map = json.load(open(prefix + "-id_map.json")) id_map = json.load(open(prefix + "-id_map.json"))
id_map = {conversion(k):int(v) for k,v in id_map.iteritems()} id_map = {conversion(k):int(v) for k,v in id_map.items()}
walks = [] walks = []
class_map = json.load(open(prefix + "-class_map.json")) class_map = json.load(open(prefix + "-class_map.json"))
if isinstance(class_map.values()[0], list): if isinstance(list(class_map.values())[0], list):
lab_conversion = lambda n : n lab_conversion = lambda n : n
else: else:
lab_conversion = lambda n : int(n) lab_conversion = lambda n : int(n)
class_map = {conversion(k):lab_conversion(v) for k,v in class_map.iteritems()} class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
## Make sure the graph has edge train_removed annotations ## Make sure the graph has edge train_removed annotations
## (some datasets might already have this..) ## (some datasets might already have this..)
print("Loaded data.. now preprocessing..") print("Loaded data.. now preprocessing..")
for edge in G.edges_iter(): for edge in G.edges():
if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
G.node[edge[0]]['test'] or G.node[edge[1]]['test']): G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
G[edge[0]][edge[1]]['train_removed'] = True G[edge[0]][edge[1]]['train_removed'] = True