make adjacency matrix a placeholder to make saved graphdef smaller
This commit is contained in:
parent
48f72045d3
commit
050895f671
@ -144,7 +144,8 @@ def train(train_data, test_data=None):
|
||||
batch_size=FLAGS.batch_size,
|
||||
max_degree=FLAGS.max_degree,
|
||||
context_pairs = context_pairs)
|
||||
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info")
|
||||
adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
|
||||
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
|
||||
|
||||
if FLAGS.model == 'graphsage_mean':
|
||||
# Create model
|
||||
@ -248,7 +249,7 @@ def train(train_data, test_data=None):
|
||||
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
|
||||
|
||||
# Init variables
|
||||
sess.run(tf.global_variables_initializer())
|
||||
sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
|
||||
|
||||
# Train model
|
||||
|
||||
|
@ -146,7 +146,8 @@ def train(train_data, test_data=None):
|
||||
max_degree=FLAGS.max_degree,
|
||||
num_neg_samples=FLAGS.neg_sample_size,
|
||||
context_pairs = context_pairs)
|
||||
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info")
|
||||
adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
|
||||
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
|
||||
|
||||
if FLAGS.model == 'graphsage_mean':
|
||||
# Create model
|
||||
@ -243,7 +244,7 @@ def train(train_data, test_data=None):
|
||||
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
|
||||
|
||||
# Init variables
|
||||
sess.run(tf.global_variables_initializer())
|
||||
sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
|
||||
|
||||
# Train model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user