import torch
import torch.nn as nn
import torch_geometric.nn
from torch_geometric.data import Batch, Data

# NOTE: `Vocab` is the project's external-function vocabulary class (defined elsewhere in
# the repository); only its `max_vocab_size` and `pad_idx` attributes are used below.


class HierarchicalGraphNeuralNetwork(nn.Module):
    """Two-level GNN: a GraphSAGE stack over each function's Control Flow Graph (CFG),
    max-pooled per function, followed by a GraphSAGE stack over each sample's Function
    Call Graph (FCG) and a linear projection head for multi-class prediction."""

    def __init__(self, external_vocab: Vocab):
        super(HierarchicalGraphNeuralNetwork, self).__init__()
        self.pool = 'global_max_pool'
        # Hierarchy 1: Control Flow Graph (CFG) embedding and pooling.
        # CFG node features are 11-dimensional; the GraphSAGE stack maps 11 -> 200 -> 200.
        cfg_filter_list = [200, 200]
        cfg_filter_list.insert(0, 11)
        self.cfg_filter_length = len(cfg_filter_list)
        cfg_graphsage_params = [dict(in_channels=cfg_filter_list[i], out_channels=cfg_filter_list[i + 1], bias=True)
                                for i in range(self.cfg_filter_length - 1)]
        cfg_conv = dict(constructor=torch_geometric.nn.conv.SAGEConv, kwargs=cfg_graphsage_params)
        cfg_constructor = cfg_conv['constructor']
        for i in range(self.cfg_filter_length - 1):
            setattr(self, 'CFG_gnn_{}'.format(i + 1), cfg_constructor(**cfg_conv['kwargs'][i]))
        self.dropout = nn.Dropout(p=0.2)
        # Hierarchy 2: Function Call Graph (FCG) embedding and pooling.
        # External functions have no CFG of their own, so they receive a learned embedding of
        # the same width as the pooled CFG vectors (vocabulary size + 2 to cover special ids,
        # padding among them).
        self.external_embedding_layer = nn.Embedding(num_embeddings=external_vocab.max_vocab_size + 2,
                                                     embedding_dim=cfg_filter_list[-1],
                                                     padding_idx=external_vocab.pad_idx)
        fcg_filter_list = [200, 200]
        fcg_filter_list.insert(0, cfg_filter_list[-1])
        self.fcg_filter_length = len(fcg_filter_list)
        fcg_graphsage_params = [dict(in_channels=fcg_filter_list[i], out_channels=fcg_filter_list[i + 1], bias=True)
                                for i in range(self.fcg_filter_length - 1)]
        fcg_conv = dict(constructor=torch_geometric.nn.conv.SAGEConv, kwargs=fcg_graphsage_params)
        fcg_constructor = fcg_conv['constructor']
        for i in range(self.fcg_filter_length - 1):
            setattr(self, 'FCG_gnn_{}'.format(i + 1), fcg_constructor(**fcg_conv['kwargs'][i]))
        # Final projection: gradually reduce the graph embedding with stacked linear layers
        # (200 -> 100 -> 50 -> 6 classes).
        self.pj1 = torch.nn.Linear(in_features=fcg_filter_list[-1], out_features=int(fcg_filter_list[-1] / 2))
        self.pj2 = torch.nn.Linear(in_features=int(fcg_filter_list[-1] / 2), out_features=int(fcg_filter_list[-1] / 4))
        self.pj3 = torch.nn.Linear(in_features=int(fcg_filter_list[-1] / 4), out_features=6)
        self.last_activation = nn.Softmax(dim=1)

    def forward(self, real_local_batch: Batch, real_bt_positions: list, bt_external_names: list,
                bt_all_function_edges: list):
        """`real_local_batch` is the batched CFGs of all internal functions in the mini-batch;
        `real_bt_positions` (length batch_size + 1) delimits which pooled CFG vectors belong to
        which sample; `bt_external_names` holds each sample's external-function vocabulary ids;
        `bt_all_function_edges` holds each sample's FCG edge index (shape [2, num_edges]).
        Returns per-sample class probabilities of shape [batch_size, num_classes]."""
        # Hierarchy 1: run the CFG GNN over every function graph and max-pool each
        # function's node embeddings into a single vector.
        rtn_local_batch = self.forward_cfg_gnn(local_batch=real_local_batch)
        x_cfg_pool = torch_geometric.nn.global_max_pool(x=rtn_local_batch.x, batch=rtn_local_batch.batch)
        # Build one FCG per sample: its nodes are the pooled internal-function vectors
        # followed by the embeddings of that sample's external function names.
        fcg_list = []
        fcg_internal_list = []  # pooled internal-function vectors per sample (not used further here)
        for idx_batch in range(len(real_bt_positions) - 1):
            start_pos, end_pos = real_bt_positions[idx_batch: idx_batch + 2]
            idx_x_cfg = x_cfg_pool[start_pos: end_pos]
            fcg_internal_list.append(idx_x_cfg)
            idx_x_external = self.external_embedding_layer(
                torch.tensor([bt_external_names[idx_batch]], dtype=torch.long))
            idx_x_external = idx_x_external.squeeze(dim=0)
            idx_x_total = torch.cat([idx_x_cfg, idx_x_external], dim=0)
            idx_function_edge = torch.tensor(bt_all_function_edges[idx_batch], dtype=torch.long)
            idx_graph_data = Data(x=idx_x_total, edge_index=idx_function_edge)
            idx_graph_data.validate()
            fcg_list.append(idx_graph_data)
        fcg_batch = Batch.from_data_list(fcg_list)
        # Hierarchy 2: Function Call Graph (FCG) embedding and pooling.
        rtn_fcg_batch = self.forward_fcg_gnn(function_batch=fcg_batch)  # .x: [total_fcg_nodes, dim]
        x_fcg_pool = torch_geometric.nn.global_max_pool(x=rtn_fcg_batch.x, batch=rtn_fcg_batch.batch)
        batch_final = x_fcg_pool  # one embedding per sample: [batch_size, dim]
        # Final projection down to the number of classes (multi-class output).
        bt_final_embed = self.pj3(self.pj2(self.pj1(batch_final)))
        bt_pred = self.last_activation(bt_final_embed)
        return bt_pred

    def forward_cfg_gnn(self, local_batch: Batch):
        in_x, edge_index = local_batch.x, local_batch.edge_index
        for i in range(self.cfg_filter_length - 1):
            out_x = getattr(self, 'CFG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
            out_x = torch.nn.functional.relu(out_x, inplace=True)
            out_x = self.dropout(out_x)
            in_x = out_x
        local_batch.x = in_x
        return local_batch

    def forward_fcg_gnn(self, function_batch: Batch):
        in_x, edge_index = function_batch.x, function_batch.edge_index
        for i in range(self.fcg_filter_length - 1):
            out_x = getattr(self, 'FCG_gnn_{}'.format(i + 1))(x=in_x, edge_index=edge_index)
            out_x = torch.nn.functional.relu(out_x, inplace=True)
            out_x = self.dropout(out_x)
            in_x = out_x
        function_batch.x = in_x
        return function_batch
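

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it illustrates the
# input layout the forward pass above appears to expect. `DummyVocab` and every
# toy tensor below are hypothetical; the real project supplies its own `Vocab`
# object and preprocessed graph data.
if __name__ == '__main__':
    class DummyVocab:
        # Stand-in exposing only the two attributes the model reads.
        max_vocab_size = 10
        pad_idx = 0

    model = HierarchicalGraphNeuralNetwork(external_vocab=DummyVocab())

    # Two samples; sample 0 has two internal functions, sample 1 has one.
    # Each internal function is a small CFG with 11-dimensional node features.
    cfg_graphs = [
        Data(x=torch.randn(3, 11), edge_index=torch.tensor([[0, 1], [1, 2]])),       # sample 0, function 0
        Data(x=torch.randn(2, 11), edge_index=torch.tensor([[0], [1]])),              # sample 0, function 1
        Data(x=torch.randn(4, 11), edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]])),  # sample 1, function 0
    ]
    local_batch = Batch.from_data_list(cfg_graphs)
    bt_positions = [0, 2, 3]            # CFG graphs 0-1 belong to sample 0, graph 2 to sample 1
    bt_external_names = [[1, 2], [3]]   # external-function vocabulary ids per sample
    # FCG edge indices per sample ([2, num_edges]); node ids count the sample's internal
    # functions first, then its external functions appended after them.
    bt_all_function_edges = [
        [[0, 1, 0], [1, 2, 3]],         # sample 0: 2 internal + 2 external nodes
        [[0], [1]],                     # sample 1: 1 internal + 1 external node
    ]
    bt_pred = model(local_batch, bt_positions, bt_external_names, bt_all_function_edges)
    print(bt_pred.shape)                # expected: torch.Size([2, 6])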