import torch
class ArgClass(object):
    """Empty attribute container used as a lightweight arguments object.

    Instances start with no attributes; callers attach whatever fields they
    need afterwards (e.g. via ``setattr``), and read them back with normal
    attribute access.
    """

    def __init__(self):
        """Create an instance with no attributes set."""
        pass
def build_args_from_dict(dic):
    """Build an ``ArgClass`` object whose attributes mirror a dictionary.

    Args:
        dic: Mapping of attribute name to value. Each key becomes an
            attribute on the returned object, so keys must be strings
            (``setattr`` raises ``TypeError`` otherwise).

    Returns:
        A new ``ArgClass`` instance with one attribute per entry of ``dic``.
        An empty mapping yields an instance with no attributes.
    """
    args = ArgClass()
    for key, value in dic.items():
        # Use the builtin rather than calling the dunder method directly.
        setattr(args, key, value)
    return args
def add_remaining_self_loops(edge_index, edge_weight, fill_value, num_nodes):
    """Give every node ``0..num_nodes-1`` one self-loop, keeping existing loop weights.

    All self-loops already present in ``edge_index`` are removed, and one
    self-loop per node is appended after the remaining (non-loop) edges.
    A node that already had a self-loop keeps that loop's weight; every other
    node's new self-loop gets ``fill_value``.

    Args:
        edge_index: Integer tensor of shape ``(2, E)``; row 0 holds source
            nodes and row 1 holds target nodes.
        edge_weight: 1-D tensor of length ``E`` with one weight per edge, or
            ``None`` for an unweighted graph.
        fill_value: Weight assigned to self-loops that did not exist before.
        num_nodes: Number of nodes ``N``; self-loops are created for the
            node ids ``0..N-1``.

    Returns:
        A tuple ``(edge_index, edge_weight)`` where ``edge_index`` has shape
        ``(2, E' + N)`` (``E'`` = number of non-loop input edges) and
        ``edge_weight`` has length ``E' + N``, or is ``None`` when the input
        weight was ``None``.
    """
    row, col = edge_index[0], edge_index[1]
    non_loop = row != col

    # One (i, i) pair per node, appended after the surviving non-loop edges.
    loop_index = torch.arange(
        0, num_nodes, dtype=edge_index.dtype, device=edge_index.device
    )
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)
    new_edge_index = torch.cat([edge_index[:, non_loop], loop_index], dim=1)

    # Unweighted graph: there is nothing to carry over or fill in.
    if edge_weight is None:
        return new_edge_index, None

    is_loop = ~non_loop
    loop_weight = torch.full(
        (num_nodes,), fill_value,
        dtype=edge_weight.dtype, device=edge_weight.device,
    )
    existing_loop_weight = edge_weight[is_loop]
    if existing_loop_weight.numel() > 0:
        # Preserve weights of self-loops that were already in the graph.
        # NOTE(review): if a node has several self-loops, which weight
        # survives this indexed assignment is not specified here.
        loop_weight[row[is_loop]] = existing_loop_weight

    new_edge_weight = torch.cat([edge_weight[non_loop], loop_weight], dim=0)
    return new_edge_index, new_edge_weight
if __name__ == "__main__":
    # Smoke test: build an args object from a dict and read its fields back.
    args = build_args_from_dict({'a': 1, 'b': 2})
    print(args.a, args.b)