graphsage/graphsage/encoders.py
2017-12-19 13:42:06 -08:00

52 lines
1.7 KiB
Python

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
class Encoder(nn.Module):
    """
    Encodes nodes using the 'convolutional' GraphSAGE approach.

    For each node in a batch, the aggregator summarises its neighbours.
    Unless ``gcn`` is set, the node's own features are concatenated with
    that summary. The result is passed through a learned linear map and a
    ReLU.
    """

    def __init__(self, features, feature_dim,
                 embed_dim, adj_lists, aggregator,
                 num_sample=10,
                 base_model=None, gcn=False, cuda=False,
                 feature_transform=False):
        """
        Args:
            features: callable mapping a LongTensor of node ids to their
                feature rows (only used when ``gcn`` is False).
            feature_dim: width of each feature row.
            embed_dim: number of output embedding dimensions.
            adj_lists: indexable by node id, giving that node's neighbours;
                passed through to the aggregator.
            aggregator: object exposing ``forward(nodes, to_neighs,
                num_sample)``. Its ``cuda`` attribute is overwritten here.
            num_sample: neighbour sample size forwarded to the aggregator.
            base_model: optional model, stored as ``self.base_model`` only
                when given. If it is an ``nn.Module``, this assignment
                registers its parameters with this encoder.
            gcn: if True, use only the aggregated neighbour features (no
                self-feature concatenation). The weight then has
                ``feature_dim`` columns instead of ``2 * feature_dim``.
            cuda: if True, node-id tensors are moved to the GPU. The flag is
                also propagated to the aggregator.
            feature_transform: accepted for API compatibility; unused.
        """
        super(Encoder, self).__init__()
        self.features = features
        self.feat_dim = feature_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        self.num_sample = num_sample
        if base_model is not None:
            self.base_model = base_model
        self.gcn = gcn
        self.embed_dim = embed_dim
        # NOTE(review): this bool shadows nn.Module.cuda(), so
        # `encoder.cuda()` cannot be called as a method on instances.
        # It is kept because callers may read the flag as `encoder.cuda`.
        self.cuda = cuda
        self.aggregator.cuda = cuda
        in_dim = self.feat_dim if self.gcn else 2 * self.feat_dim
        self.weight = nn.Parameter(
            torch.empty(embed_dim, in_dim, dtype=torch.float32))
        # Use the in-place initialiser; the un-suffixed `xavier_uniform`
        # alias is deprecated and warns on every call.
        init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.

        Args:
            nodes: list of node ids.

        Returns:
            Tensor of shape (embed_dim, len(nodes)), assuming the aggregator
            returns one row per node. Each embedding is a *column*, i.e. the
            transpose of the usual (batch, dim) layout.
        """
        neigh_feats = self.aggregator.forward(
            nodes, [self.adj_lists[node] for node in nodes], self.num_sample)
        if self.gcn:
            # The aggregator output is used on its own. Presumably the
            # aggregator itself includes each node's own features in this
            # mode; confirm against the aggregator configuration.
            combined = neigh_feats
        else:
            node_ids = torch.LongTensor(nodes)
            if self.cuda:
                node_ids = node_ids.cuda()
            self_feats = self.features(node_ids)
            combined = torch.cat([self_feats, neigh_feats], dim=1)
        return F.relu(self.weight.mm(combined.t()))