Merge pull request #23 from laurimi/master

make adjacency matrix a placeholder to make saved graphdef smaller
This commit is contained in:
Rex Ying 2017-12-21 15:32:22 -08:00 committed by GitHub
commit c2aa90ee63
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 4 deletions

View File

@ -144,7 +144,8 @@ def train(train_data, test_data=None):
batch_size=FLAGS.batch_size, batch_size=FLAGS.batch_size,
max_degree=FLAGS.max_degree, max_degree=FLAGS.max_degree,
context_pairs = context_pairs) context_pairs = context_pairs)
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
if FLAGS.model == 'graphsage_mean': if FLAGS.model == 'graphsage_mean':
# Create model # Create model
@ -248,7 +249,7 @@ def train(train_data, test_data=None):
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
# Init variables # Init variables
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
# Train model # Train model

View File

@ -146,7 +146,8 @@ def train(train_data, test_data=None):
max_degree=FLAGS.max_degree, max_degree=FLAGS.max_degree,
num_neg_samples=FLAGS.neg_sample_size, num_neg_samples=FLAGS.neg_sample_size,
context_pairs = context_pairs) context_pairs = context_pairs)
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info") adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
if FLAGS.model == 'graphsage_mean': if FLAGS.model == 'graphsage_mean':
# Create model # Create model
@ -243,7 +244,7 @@ def train(train_data, test_data=None):
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph) summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
# Init variables # Init variables
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
# Train model # Train model