From 050895f671160d20e7951829894377c912902ad4 Mon Sep 17 00:00:00 2001 From: Mikko Lauri Date: Wed, 20 Dec 2017 10:48:36 +0100 Subject: [PATCH] make adjacency matrix a placeholder to make saved graphdef smaller --- graphsage/supervised_train.py | 5 +++-- graphsage/unsupervised_train.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py index 9990cf9..c5bff00 100644 --- a/graphsage/supervised_train.py +++ b/graphsage/supervised_train.py @@ -144,7 +144,8 @@ def train(train_data, test_data=None): batch_size=FLAGS.batch_size, max_degree=FLAGS.max_degree, context_pairs = context_pairs) - adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") + adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) + adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model @@ -248,7 +249,7 @@ def train(train_data, test_data=None): summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # Init variables - sess.run(tf.global_variables_initializer()) + sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) # Train model diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py index b3162fd..44ef609 100644 --- a/graphsage/unsupervised_train.py +++ b/graphsage/unsupervised_train.py @@ -146,7 +146,8 @@ def train(train_data, test_data=None): max_degree=FLAGS.max_degree, num_neg_samples=FLAGS.neg_sample_size, context_pairs = context_pairs) - adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") + adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape) + adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info") if FLAGS.model == 'graphsage_mean': # Create model @@ -243,7 +244,7 @@ def train(train_data, test_data=None): summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) # Init variables - sess.run(tf.global_variables_initializer()) + sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj}) # Train model