diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py index 9990cf9..c5bff00 100644 --- a/graphsage/supervised_train.py +++ b/graphsage/supervised_train.py @@ -144,7 +144,8 @@ def train(train_data, test_data=None): batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs = context_pairs) - adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") + adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) + adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model @@ -248,7 +249,7 @@ def train(train_data, test_data=None): summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # Init variables - sess.run(tf.global_variables_initializer()) + sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) # Train model diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py index b3162fd..44ef609 100644 --- a/graphsage/unsupervised_train.py +++ b/graphsage/unsupervised_train.py @@ -146,7 +146,8 @@ def train(train_data, test_data=None): max_degree=FLAGS.max_degree, num_neg_samples=FLAGS.neg_sample_size, context_pairs = context_pairs) - adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") + adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) + adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model @@ -243,7 +244,7 @@ def train(train_data, test_data=None): summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # Init variables - sess.run(tf.global_variables_initializer()) + sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) # Train model