make adjacency matrix a placeholder to make saved graphdef smaller

This commit is contained in:
Mikko Lauri 2017-12-20 10:48:36 +01:00
parent 48f72045d3
commit 050895f671
2 changed files with 6 additions and 4 deletions

View File

@ -144,7 +144,8 @@ def train(train_data, test_data=None):
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
max_degree=FLAGS.max_degree, max_degree=FLAGS.max_degree,
context_pairs = context_pairs) context_pairs = context_pairs)
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
if FLAGS.model == 'graphsage_mean': if FLAGS.model == 'graphsage_mean':
# Create model # Create model
@ -248,7 +249,7 @@ def train(train_data, test_data=None):
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
# Init variables # Init variables
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
# Train model # Train model

View File

@ -146,7 +146,8 @@ def train(train_data, test_data=None):
max_degree=FLAGS.max_degree, max_degree=FLAGS.max_degree,
num_neg_samples=FLAGS.neg_sample_size, num_neg_samples=FLAGS.neg_sample_size,
context_pairs = context_pairs) context_pairs = context_pairs)
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
if FLAGS.model == 'graphsage_mean': if FLAGS.model == 'graphsage_mean':
# Create model # Create model
@ -243,7 +244,7 @@ def train(train_data, test_data=None):
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
# Init variables # Init variables
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
# Train model # Train model