dgl tricks

社蕙 71 2023-04-12

DGLGraph.num_nodes()

(method) def num_nodes(ntype: Any | None = None) -> Any

Return the number of nodes in the graph.

Parameters

ntype : str, optional

 The node type name. If given, it returns the number of nodes of the type. If not given (default), it returns the total number of nodes of all types.

Returns

int

 The number of nodes.

Examples

# The following example uses PyTorch backend.
>>> import dgl
>>> import torch

# Create a graph with two node types -- 'user' and 'game'.
>>> g = dgl.heterograph({
...     ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])),
...     ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6]))
... })

# Query for the number of nodes.
>>> g.num_nodes('user')
5
>>> g.num_nodes('game')
7
>>> g.num_nodes()
12

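This example puzzled me for quite a while: how can the number of game nodes be 7? It did not help that I had listened to GPT's bad advice and tried to count the nodes by reading off edges such as "user0 follows user1", which does not add up. After looking into it properly, I found that the count does not work that way: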

Image from https://docs.dgl.ai/en/0.8.x/guide_cn/graph-heterogeneous.html (the Chinese-language DGL documentation); code from https://blog.csdn.net/weixin_41650348/article/details/117334408

Heterogeneous graph

The code corresponding to it is:

import dgl
import torch

# Each edge type maps to a (source IDs, destination IDs) pair of tensors.
g = dgl.heterograph({
    ('user', 'follows', 'user'): (torch.tensor([0]), torch.tensor([1])),
    ('user', 'plays', 'game'): (torch.tensor([0, 0, 1, 1]), torch.tensor([0, 1, 1, 2]))
})

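With that, everything becomes clear. The two tensors are matched element by element: the i-th entry of the first tensor is the source node and the i-th entry of the second is the destination, so here we get user0 plays game0, user0 plays game1, and so on. Going back to the first example, the largest game ID is 6 and the largest user ID is 4. Because DGL numbers the nodes of each type starting from zero, there are 5 users (IDs 0 to 4) and 7 games (IDs 0 to 6), which matches the 5, 7 and 12 in the docstring. Note that IDs which never appear in any edge, such as game0 through game4, still count as nodes.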
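To check the counting rule without DGL installed, here is a rough pure-Python sketch of it. This is not DGL's actual implementation; `infer_num_nodes` is a hypothetical helper I made up for illustration, and it assumes the node counts are not given explicitly, so each type simply gets "largest ID seen + 1" nodes.

```python
from collections import defaultdict


def infer_num_nodes(edges_by_type):
    """Hypothetical helper: per node type, count = largest ID seen + 1."""
    max_id = defaultdict(lambda: -1)
    for (src_type, _etype, dst_type), (src_ids, dst_ids) in edges_by_type.items():
        # The i-th source ID pairs with the i-th destination ID to form one edge.
        for s, d in zip(src_ids, dst_ids):
            max_id[src_type] = max(max_id[src_type], s)
            max_id[dst_type] = max(max_id[dst_type], d)
    return {ntype: m + 1 for ntype, m in max_id.items()}


# Same edge lists as the docstring example, written as plain Python lists.
edges = {
    ('user', 'follows', 'user'): ([0, 1], [1, 2]),
    ('user', 'plays', 'game'): ([3, 4], [5, 6]),
}

counts = infer_num_nodes(edges)
print(counts)                # {'user': 5, 'game': 7}
print(sum(counts.values()))  # 12
```

Running the same helper on the edge lists of the graph from the picture gives 2 users and 3 games, since the largest user ID there is 1 and the largest game ID is 2.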