diff --git a/graphsage/minibatch.py b/graphsage/minibatch.py index 1f53784..0a1cd96 100644 --- a/graphsage/minibatch.py +++ b/graphsage/minibatch.py @@ -108,7 +108,7 @@ class EdgeMinibatchIterator(object): return adj def end(self): - return self.batch_num * self.batch_size > len(self.train_edges) - self.batch_size + 1 + return self.batch_num * self.batch_size >= len(self.train_edges) def batch_feed_dict(self, batch_edges): batch1 = [] @@ -125,9 +125,10 @@ class EdgeMinibatchIterator(object): return feed_dict def next_minibatch_feed_dict(self): - start = self.batch_num * self.batch_size + start_idx = self.batch_num * self.batch_size self.batch_num += 1 - batch_edges = self.train_edges[start : start + self.batch_size] + end_idx = min(start_idx + self.batch_size, len(self.train_edges)) + batch_edges = self.train_edges[start_idx : end_idx] return self.batch_feed_dict(batch_edges) def num_training_batches(self): @@ -258,7 +259,7 @@ class NodeMinibatchIterator(object): return adj def end(self): - return self.batch_num * self.batch_size > len(self.train_nodes) - self.batch_size + return self.batch_num * self.batch_size >= len(self.train_nodes) def batch_feed_dict(self, batch_nodes, val=False): batch1id = batch_nodes @@ -299,9 +300,10 @@ class NodeMinibatchIterator(object): return len(self.train_nodes) // self.batch_size + 1 def next_minibatch_feed_dict(self): - start = self.batch_num * self.batch_size + start_idx = self.batch_num * self.batch_size self.batch_num += 1 - batch_nodes = self.train_nodes[start : start + self.batch_size] + end_idx = min(start_idx + self.batch_size, len(self.train_nodes)) + batch_nodes = self.train_nodes[start_idx : end_idx] return self.batch_feed_dict(batch_nodes) def incremental_embed_feed_dict(self, size, iter_num):