modify minibatch iterator end condition so that the last batch can be of variable size
This commit is contained in:
parent
811b50945e
commit
75794f8a09
@ -108,7 +108,7 @@ class EdgeMinibatchIterator(object):
|
|||||||
return adj
|
return adj
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
return self.batch_num * self.batch_size > len(self.train_edges) - self.batch_size + 1
|
return self.batch_num * self.batch_size >= len(self.train_edges)
|
||||||
|
|
||||||
def batch_feed_dict(self, batch_edges):
|
def batch_feed_dict(self, batch_edges):
|
||||||
batch1 = []
|
batch1 = []
|
||||||
@ -125,9 +125,10 @@ class EdgeMinibatchIterator(object):
|
|||||||
return feed_dict
|
return feed_dict
|
||||||
|
|
||||||
def next_minibatch_feed_dict(self):
|
def next_minibatch_feed_dict(self):
|
||||||
start = self.batch_num * self.batch_size
|
start_idx = self.batch_num * self.batch_size
|
||||||
self.batch_num += 1
|
self.batch_num += 1
|
||||||
batch_edges = self.train_edges[start : start + self.batch_size]
|
end_idx = min(start_idx + self.batch_size, len(self.train_edges))
|
||||||
|
batch_edges = self.train_edges[start_idx : end_idx]
|
||||||
return self.batch_feed_dict(batch_edges)
|
return self.batch_feed_dict(batch_edges)
|
||||||
|
|
||||||
def num_training_batches(self):
|
def num_training_batches(self):
|
||||||
@ -258,7 +259,7 @@ class NodeMinibatchIterator(object):
|
|||||||
return adj
|
return adj
|
||||||
|
|
||||||
def end(self):
|
def end(self):
|
||||||
return self.batch_num * self.batch_size > len(self.train_nodes) - self.batch_size
|
return self.batch_num * self.batch_size >= len(self.train_nodes)
|
||||||
|
|
||||||
def batch_feed_dict(self, batch_nodes, val=False):
|
def batch_feed_dict(self, batch_nodes, val=False):
|
||||||
batch1id = batch_nodes
|
batch1id = batch_nodes
|
||||||
@ -299,9 +300,10 @@ class NodeMinibatchIterator(object):
|
|||||||
return len(self.train_nodes) // self.batch_size + 1
|
return len(self.train_nodes) // self.batch_size + 1
|
||||||
|
|
||||||
def next_minibatch_feed_dict(self):
|
def next_minibatch_feed_dict(self):
|
||||||
start = self.batch_num * self.batch_size
|
start_idx = self.batch_num * self.batch_size
|
||||||
self.batch_num += 1
|
self.batch_num += 1
|
||||||
batch_nodes = self.train_nodes[start : start + self.batch_size]
|
end_idx = min(start_idx + self.batch_size, len(self.train_nodes))
|
||||||
|
batch_nodes = self.train_nodes[start_idx : end_idx]
|
||||||
return self.batch_feed_dict(batch_nodes)
|
return self.batch_feed_dict(batch_nodes)
|
||||||
|
|
||||||
def incremental_embed_feed_dict(self, size, iter_num):
|
def incremental_embed_feed_dict(self, size, iter_num):
|
||||||
|
Loading…
Reference in New Issue
Block a user