modify minibatch iterator end condition so that the last batch can be of variable size
This commit is contained in:
parent
811b50945e
commit
75794f8a09
@ -108,7 +108,7 @@ class EdgeMinibatchIterator(object):
|
||||
return adj
|
||||
|
||||
def end(self):
|
||||
return self.batch_num * self.batch_size > len(self.train_edges) - self.batch_size + 1
|
||||
return self.batch_num * self.batch_size >= len(self.train_edges)
|
||||
|
||||
def batch_feed_dict(self, batch_edges):
|
||||
batch1 = []
|
||||
@ -125,9 +125,10 @@ class EdgeMinibatchIterator(object):
|
||||
return feed_dict
|
||||
|
||||
def next_minibatch_feed_dict(self):
|
||||
start = self.batch_num * self.batch_size
|
||||
start_idx = self.batch_num * self.batch_size
|
||||
self.batch_num += 1
|
||||
batch_edges = self.train_edges[start : start + self.batch_size]
|
||||
end_idx = min(start_idx + self.batch_size, len(self.train_edges))
|
||||
batch_edges = self.train_edges[start_idx : end_idx]
|
||||
return self.batch_feed_dict(batch_edges)
|
||||
|
||||
def num_training_batches(self):
|
||||
@ -258,7 +259,7 @@ class NodeMinibatchIterator(object):
|
||||
return adj
|
||||
|
||||
def end(self):
|
||||
return self.batch_num * self.batch_size > len(self.train_nodes) - self.batch_size
|
||||
return self.batch_num * self.batch_size >= len(self.train_nodes)
|
||||
|
||||
def batch_feed_dict(self, batch_nodes, val=False):
|
||||
batch1id = batch_nodes
|
||||
@ -299,9 +300,10 @@ class NodeMinibatchIterator(object):
|
||||
return len(self.train_nodes) // self.batch_size + 1
|
||||
|
||||
def next_minibatch_feed_dict(self):
|
||||
start = self.batch_num * self.batch_size
|
||||
start_idx = self.batch_num * self.batch_size
|
||||
self.batch_num += 1
|
||||
batch_nodes = self.train_nodes[start : start + self.batch_size]
|
||||
end_idx = min(start_idx + self.batch_size, len(self.train_nodes))
|
||||
batch_nodes = self.train_nodes[start_idx : end_idx]
|
||||
return self.batch_feed_dict(batch_nodes)
|
||||
|
||||
def incremental_embed_feed_dict(self, size, iter_num):
|
||||
|
Loading…
Reference in New Issue
Block a user