graphsage-tf/graphsage/neigh_samplers.py
2017-05-31 06:39:04 -07:00

30 lines
818 B
Python

from __future__ import division
from __future__ import print_function
from graphsage.layers import Layer
import tensorflow as tf
# Shorthand for TensorFlow 1.x's command-line flag module.
flags = tf.app.flags
# Global flag-values object exposed by that module. It is not referenced in the
# lines shown here; presumably kept for parity with sibling modules (verify).
FLAGS = flags.FLAGS
"""
Classes that are used to sample node neighborhoods
"""
class UniformNeighborSampler(Layer):
    """
    Layer that picks a fixed number of neighbors per node, uniformly.

    Per the original author's note, the adjacency lists are assumed to be
    padded with random re-sampling, so every row of `adj_info` has the same
    length.
    """

    def __init__(self, adj_info, **kwargs):
        """
        Args:
            adj_info: adjacency table handed to `tf.nn.embedding_lookup`;
                presumably one row of neighbor ids per node (confirm against
                the caller).
            **kwargs: forwarded unchanged to the `Layer` constructor.
        """
        super(UniformNeighborSampler, self).__init__(**kwargs)
        self.adj_info = adj_info

    def _call(self, inputs):
        """
        Look up the adjacency rows for `ids` and keep `num_samples` columns.

        Args:
            inputs: a pair `(ids, num_samples)`.

        Returns:
            The shuffled adjacency rows, truncated to the first `num_samples`
            columns.
        """
        ids, num_samples = inputs

        neighbor_rows = tf.nn.embedding_lookup(self.adj_info, ids)

        # tf.random_shuffle permutes along the first dimension only, so the
        # table is transposed to put the neighbor axis first, shuffled, and
        # transposed back. The same column permutation therefore applies to
        # every row in the batch.
        neighbors_first = tf.transpose(neighbor_rows)
        shuffled_rows = tf.transpose(tf.random_shuffle(neighbors_first))

        # Keep all rows (-1) and the leading num_samples columns.
        return tf.slice(shuffled_rows, [0, 0], [-1, num_samples])