First commit.
commit 65cb65c9c0
.gitignore (vendored, new file, 110 lines)
@@ -0,0 +1,110 @@
# Custom
*.csv
*.idea
*.png
*.pdf
tmp/
*.txt
*swp*
*.sw?
gcn_back
.DS_STORE
*.aux
*.log
*.out
*.bbl
*.synctex.gz

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

*.pickle
*.pkl
graphsage/__init__.py (new file, empty)
graphsage/aggregators.py (new file, 58 lines)
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

import random

"""
Set of modules for aggregating embeddings of neighbors.
These modules take as input embeddings of neighbors.
"""


class MeanAggregator(nn.Module):
    """
    Aggregates a node's embeddings using the mean of its neighbors' embeddings
    """
    def __init__(self, features, cuda=False, gcn=False):
        """
        Initializes the aggregator for a specific graph.

        features -- function mapping (node_list, features, offset) to feature values
                    see torch.nn.EmbeddingBag and forward function below docs for offset meaning.
        """
        super(MeanAggregator, self).__init__()

        self.features = features
        self.cuda = cuda
        self.gcn = gcn

    def forward(self, nodes, to_neighs, num_sample=10):
        """
        nodes -- list of nodes in a batch
        to_neighs -- list of sets, each the set of neighbors for a node in the batch
        num_sample -- number of neighbors to sample; no sampling if None
        """
        # Local pointers to functions (speed hack)
        _set = set
        if num_sample is not None:
            _sample = random.sample
            # random.sample needs a sequence, so materialize each neighbor set.
            samp_neighs = [_set(_sample(list(to_neigh), num_sample))
                           if len(to_neigh) >= num_sample else to_neigh
                           for to_neigh in to_neighs]
        else:
            samp_neighs = to_neighs

        if self.gcn:
            # Add a self-loop; sets do not support +, so union each set
            # with the node itself.
            samp_neighs = [samp_neigh | {nodes[i]} for i, samp_neigh in enumerate(samp_neighs)]
        unique_nodes_list = list(set.union(*samp_neighs))
        unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
        # Row-normalized membership mask: mask.mm(embed_matrix) then yields
        # the mean of each node's sampled neighbors in one matrix multiply.
        mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
        column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
        row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
        mask[row_indices, column_indices] = 1
        if self.cuda:
            mask = mask.cuda()
        num_neigh = mask.sum(1, keepdim=True)
        mask = mask.div(num_neigh)
        if self.cuda:
            embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
        else:
            embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
        to_feats = mask.mm(embed_matrix)
        return to_feats
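The core of MeanAggregator is that row-normalized membership mask: mask.mm(embed_matrix) computes the mean of each node's sampled neighbors in a single matrix multiply. Below is a minimal standalone sketch, with hypothetical toy values rather than real data, that rebuilds the mask the same way and checks it against a direct per-node average.

import torch

# Hypothetical toy values, for illustration only: three nodes in the
# batch with their sampled neighbor sets, and a 5-node feature table.
samp_neighs = [{1, 2}, {0, 3}, {4}]
features = torch.arange(10.0).view(5, 2)

unique_nodes_list = list(set.union(*samp_neighs))
unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}

# Same membership mask MeanAggregator.forward builds, row-normalized
# so each row sums to 1.
mask = torch.zeros(len(samp_neighs), len(unique_nodes))
for row, neighs in enumerate(samp_neighs):
    for n in neighs:
        mask[row, unique_nodes[n]] = 1
mask = mask.div(mask.sum(1, keepdim=True))

to_feats = mask.mm(features[torch.LongTensor(unique_nodes_list)])

# Direct per-node neighbor average for comparison.
direct = torch.stack([features[list(n)].mean(0) for n in samp_neighs])
assert torch.allclose(to_feats, direct)
print(to_feats)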
graphsage/encoders.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F


class Encoder(nn.Module):
    """
    Encodes a node's features using the 'convolutional' GraphSage approach
    """
    def __init__(self, features, feature_dim,
                 embed_dim, adj_lists, aggregator,
                 num_sample=10,
                 base_model=None, gcn=False, cuda=False,
                 feature_transform=False):
        super(Encoder, self).__init__()

        self.features = features
        self.feat_dim = feature_dim
        self.adj_lists = adj_lists
        self.aggregator = aggregator
        self.num_sample = num_sample
        if base_model is not None:
            self.base_model = base_model

        self.gcn = gcn
        self.embed_dim = embed_dim
        self.cuda = cuda
        self.aggregator.cuda = cuda
        # GCN mode folds self features into the neighbor mean, so the weight
        # sees feat_dim inputs; otherwise self and neighbor features are
        # concatenated, doubling the input dimension.
        self.weight = nn.Parameter(
            torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        """
        Generates embeddings for a batch of nodes.

        nodes -- list of nodes
        """
        neigh_feats = self.aggregator.forward(nodes,
                                              [self.adj_lists[node] for node in nodes],
                                              self.num_sample)
        if not self.gcn:
            if self.cuda:
                self_feats = self.features(torch.LongTensor(nodes).cuda())
            else:
                self_feats = self.features(torch.LongTensor(nodes))
            combined = torch.cat([self_feats, neigh_feats], dim=1)
        else:
            combined = neigh_feats
        combined = F.relu(self.weight.mm(combined.t()))
        return combined
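As a sanity check on the shapes: in non-GCN mode the encoder concatenates self and neighbor features into a (batch, 2 * feat_dim) matrix, applies the (embed_dim, 2 * feat_dim) weight, and returns the result transposed as (embed_dim, batch). A minimal sketch on a hypothetical 4-node toy graph, assuming the graphsage package from this commit is importable; num_sample exceeds every degree here, so no sampling kicks in.

import torch.nn as nn

from graphsage.aggregators import MeanAggregator
from graphsage.encoders import Encoder

# Hypothetical toy graph: 4 nodes on a cycle with 6-dim learned features.
adj_lists = {0: {1, 3}, 1: {0, 2}, 2: {1, 3}, 3: {0, 2}}
features = nn.Embedding(4, 6)

agg = MeanAggregator(features)
# num_sample exceeds every degree here, so the aggregator keeps all neighbors.
enc = Encoder(features, 6, 8, adj_lists, agg, num_sample=5)

out = enc([0, 2])
# Non-GCN mode concatenates self and neighbor features (2 * feat_dim = 12
# columns) before the (embed_dim x 12) weight, then transposes.
print(out.size())  # torch.Size([8, 2]), i.e. (embed_dim, batch)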
graphsage/model.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
from torch.nn import init
from torch.autograd import Variable

import numpy as np
import time
import random
from sklearn.metrics import f1_score
from collections import defaultdict

from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator


class SupervisedGraphSage(nn.Module):

    def __init__(self, num_classes, enc):
        super(SupervisedGraphSage, self).__init__()
        self.enc = enc
        self.xent = nn.CrossEntropyLoss()

        self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, nodes):
        embeds = self.enc(nodes)
        scores = self.weight.mm(embeds)
        return scores.t()

    def loss(self, nodes, labels):
        scores = self.forward(nodes)
        return self.xent(scores, labels.squeeze())


def load_pubmed():
    # hardcoded for simplicity...
    num_nodes = 19717
    num_feats = 500
    feat_data = np.zeros((num_nodes, num_feats))
    labels = np.empty((num_nodes, 1), dtype=np.int64)
    node_map = {}
    with open("pubmed-data/Pubmed-Diabetes.NODE.paper.tab") as fp:
        fp.readline()
        feat_map = {entry.split(":")[1]: i - 1 for i, entry in enumerate(fp.readline().split("\t"))}
        for i, line in enumerate(fp):
            info = line.split("\t")
            node_map[info[0]] = i
            labels[i] = int(info[1].split("=")[1]) - 1
            for word_info in info[2:-1]:
                word_info = word_info.split("=")
                feat_data[i][feat_map[word_info[0]]] = float(word_info[1])
    adj_lists = defaultdict(set)
    with open("pubmed-data/Pubmed-Diabetes.DIRECTED.cites.tab") as fp:
        fp.readline()
        fp.readline()
        for line in fp:
            info = line.strip().split("\t")
            paper1 = node_map[info[1].split(":")[1]]
            paper2 = node_map[info[-1].split(":")[1]]
            adj_lists[paper1].add(paper2)
            adj_lists[paper2].add(paper1)
    return feat_data, labels, adj_lists


if __name__ == "__main__":
    np.random.seed(1)
    random.seed(1)
    num_nodes = 19717
    feat_data, labels, adj_lists = load_pubmed()
    features = nn.Embedding(19717, 500)
    features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=False)
    # features.cuda()

    # Two-layer model: enc1 aggregates raw features, enc2 aggregates enc1's
    # outputs. Encoder resets its aggregator's cuda flag to match its own
    # cuda argument, so agg1 effectively runs with cuda=False here.
    agg1 = MeanAggregator(features, cuda=True)
    enc1 = Encoder(features, 500, 128, adj_lists, agg1, gcn=True, cuda=False)
    agg2 = MeanAggregator(lambda nodes: enc1(nodes).t(), cuda=False)
    enc2 = Encoder(lambda nodes: enc1(nodes).t(), enc1.embed_dim, 128, adj_lists, agg2,
                   base_model=enc1, gcn=True, cuda=False)
    enc1.num_sample = 10
    enc2.num_sample = 25

    graphsage = SupervisedGraphSage(3, enc2)
    # graphsage.cuda()
    rand_indices = np.random.permutation(num_nodes)
    test = rand_indices[:1000]
    val = rand_indices[1000:1500]
    train = list(rand_indices[1500:])

    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, graphsage.parameters()), lr=0.7)
    times = []
    for batch in range(200):
        batch_nodes = train[:1024]
        random.shuffle(train)
        start_time = time.time()
        optimizer.zero_grad()
        loss = graphsage.loss(batch_nodes,
                              Variable(torch.LongTensor(labels[np.array(batch_nodes)])))
        loss.backward()
        optimizer.step()
        end_time = time.time()
        times.append(end_time - start_time)
        print(batch, loss.item())

    val_output = graphsage.forward(val)
    print("Validation F1:", f1_score(labels[val], val_output.data.numpy().argmax(axis=1), average="micro"))
    print("Average batch time:", np.mean(times))
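The script holds out a 1000-node test split but only reports validation F1. A minimal follow-up sketch, reusing the variables from the __main__ block above, scores the test split the same way the script scores the validation split.

# Assumes graphsage, test, and labels from the __main__ block are in scope.
test_output = graphsage.forward(test)
print("Test F1:", f1_score(labels[test],
                           test_output.data.numpy().argmax(axis=1),
                           average="micro"))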
pubmed-data/Pubmed-Diabetes.DIRECTED.cites.tab (new file, 44340 lines)
(File diff suppressed because it is too large.)
pubmed-data/Pubmed-Diabetes.GRAPH.pubmed.tab (new file, 3 lines)
@@ -0,0 +1,3 @@
GRAPH pubmed
NO_FEATURES
Pubmed-Diabetes
pubmed-data/Pubmed-Diabetes.NODE.paper.tab (new file, 19719 lines)
(File diff suppressed because one or more lines are too long.)