make adjacency matrix a placeholder to make saved graphdef smaller
This commit is contained in:
parent
48f72045d3
commit
050895f671
@ -144,7 +144,8 @@ def train(train_data, test_data=None):
|
|||||||
batch_size=FLAGS.batch_size,
|
batch_size=FLAGS.batch_size,
|
||||||
max_degree=FLAGS.max_degree,
|
max_degree=FLAGS.max_degree,
|
||||||
context_pairs = context_pairs)
|
context_pairs = context_pairs)
|
||||||
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info")
|
adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
|
||||||
|
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
|
||||||
|
|
||||||
if FLAGS.model == 'graphsage_mean':
|
if FLAGS.model == 'graphsage_mean':
|
||||||
# Create model
|
# Create model
|
||||||
@ -248,7 +249,7 @@ def train(train_data, test_data=None):
|
|||||||
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
|
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
|
||||||
|
|
||||||
# Init variables
|
# Init variables
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
|
|
||||||
|
@ -146,7 +146,8 @@ def train(train_data, test_data=None):
|
|||||||
max_degree=FLAGS.max_degree,
|
max_degree=FLAGS.max_degree,
|
||||||
num_neg_samples=FLAGS.neg_sample_size,
|
num_neg_samples=FLAGS.neg_sample_size,
|
||||||
context_pairs = context_pairs)
|
context_pairs = context_pairs)
|
||||||
adj_info = tf.Variable(tf.constant(minibatch.adj, dtype=tf.int32), trainable=False, name="adj_info")
|
adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
|
||||||
|
adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")
|
||||||
|
|
||||||
if FLAGS.model == 'graphsage_mean':
|
if FLAGS.model == 'graphsage_mean':
|
||||||
# Create model
|
# Create model
|
||||||
@ -243,7 +244,7 @@ def train(train_data, test_data=None):
|
|||||||
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
|
summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)
|
||||||
|
|
||||||
# Init variables
|
# Init variables
|
||||||
sess.run(tf.global_variables_initializer())
|
sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user