32 lines
979 B
Python
32 lines
979 B
Python
|
import tensorflow as tf
|
||
|
import numpy as np
|
||
|
|
||
|
# DISCLAIMER:
|
||
|
# Parts of this code file are derived from
|
||
|
# https://github.com/tkipf/gcn
|
||
|
# (A full license with proper attributions will be provided in the
|
||
|
# public repo of this code base)
|
||
|
|
||
|
def uniform(shape, scale=0.05, name=None):
    """Create a float32 ``tf.Variable`` drawn from U(-scale, scale).

    Uses ``tf.random.uniform`` rather than the TF1-only
    ``tf.random_uniform``, which was removed in TensorFlow 2.x.

    Args:
        shape: Shape of the variable, as accepted by ``tf.random.uniform``.
        scale: Half-width of the sampling interval; values are drawn from
            ``[-scale, scale)``.
        name: Optional name for the variable.

    Returns:
        A ``tf.Variable`` holding the sampled float32 values.
    """
    initial = tf.random.uniform(shape, minval=-scale, maxval=scale, dtype=tf.float32)
    return tf.Variable(initial, name=name)
|
||
|
|
||
|
|
||
|
def glorot(shape, name=None):
    """Glorot & Bengio (AISTATS 2010) uniform init.

    Samples from U(-r, r) with ``r = sqrt(6 / (shape[0] + shape[1]))``.
    Only the first two dimensions of ``shape`` enter the range computation,
    so this is intended for 2-D weight matrices (fan_in, fan_out).

    Uses ``tf.random.uniform`` rather than the TF1-only
    ``tf.random_uniform``, which was removed in TensorFlow 2.x.

    Args:
        shape: Shape of the variable; must be indexable with at least two
            entries (``shape[0]`` and ``shape[1]``).
        name: Optional name for the variable.

    Returns:
        A ``tf.Variable`` holding the sampled float32 values.
    """
    # Range chosen so the variance is 2 / (fan_in + fan_out).
    init_range = np.sqrt(6.0 / (shape[0] + shape[1]))
    initial = tf.random.uniform(shape, minval=-init_range, maxval=init_range, dtype=tf.float32)
    return tf.Variable(initial, name=name)
|
||
|
|
||
|
|
||
|
def zeros(shape, name=None):
    """Create a float32 ``tf.Variable`` filled with 0.0.

    Args:
        shape: Shape of the variable, as accepted by ``tf.zeros``.
        name: Optional name for the variable.

    Returns:
        A ``tf.Variable`` whose initial value is all zeros.
    """
    return tf.Variable(tf.zeros(shape, dtype=tf.float32), name=name)
|
||
|
|
||
|
def ones(shape, name=None):
    """Create a float32 ``tf.Variable`` filled with 1.0.

    Args:
        shape: Shape of the variable, as accepted by ``tf.ones``.
        name: Optional name for the variable.

    Returns:
        A ``tf.Variable`` whose initial value is all ones.
    """
    return tf.Variable(tf.ones(shape, dtype=tf.float32), name=name)
|