modify minibatch iterator end condition so that the last batch can be of variable size

This commit is contained in:
RexYing 2017-11-27 09:24:19 -08:00
parent 811b50945e
commit 75794f8a09

View File

@ -108,7 +108,7 @@ class EdgeMinibatchIterator(object):
return adj
def end(self):
return self.batch_num * self.batch_size > len(self.train_edges) - self.batch_size + 1
return self.batch_num * self.batch_size >= len(self.train_edges)
def batch_feed_dict(self, batch_edges):
batch1 = []
@ -125,9 +125,10 @@ class EdgeMinibatchIterator(object):
return feed_dict
def next_minibatch_feed_dict(self):
start = self.batch_num * self.batch_size
start_idx = self.batch_num * self.batch_size
self.batch_num += 1
batch_edges = self.train_edges[start : start + self.batch_size]
end_idx = min(start_idx + self.batch_size, len(self.train_edges))
batch_edges = self.train_edges[start_idx : end_idx]
return self.batch_feed_dict(batch_edges)
def num_training_batches(self):
@ -258,7 +259,7 @@ class NodeMinibatchIterator(object):
return adj
def end(self):
return self.batch_num * self.batch_size > len(self.train_nodes) - self.batch_size
return self.batch_num * self.batch_size >= len(self.train_nodes)
def batch_feed_dict(self, batch_nodes, val=False):
batch1id = batch_nodes
@ -299,9 +300,10 @@ class NodeMinibatchIterator(object):
return len(self.train_nodes) // self.batch_size + 1
def next_minibatch_feed_dict(self):
start = self.batch_num * self.batch_size
start_idx = self.batch_num * self.batch_size
self.batch_num += 1
batch_nodes = self.train_nodes[start : start + self.batch_size]
end_idx = min(start_idx + self.batch_size, len(self.train_nodes))
batch_nodes = self.train_nodes[start_idx : end_idx]
return self.batch_feed_dict(batch_nodes)
def incremental_embed_feed_dict(self, size, iter_num):