180417 PyTorch 网络可视化

找到了一个可以方便地实现 PyTorch 网络可视化的工具 torchviz(GitHub 地址:https://github.com/szagoruyko/pytorchviz)。

使用方法如下:

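  1. 安装 graphviz(Install graphviz)。注意:`pip install graphviz` 只安装 Python 接口,还需要另外安装 Graphviz 本体程序(例如 `apt install graphviz` 或 `brew install graphviz`),否则调用 `render` 时会找不到 `dot` 命令: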

    pip install graphviz
  2. torchviz 工具代码,保存为 dot.py(下面的 demo 通过 `from dot import ...` 导入它):

    # dot.py
    from graphviz import Digraph
    import torch
    from torch.autograd import Variable
    from collections import namedtuple


    # One flattened trace-graph record: scoped name, scoped input names,
    # stringified attributes, and the op kind (see ``parse``).
    Node = namedtuple('Node', 'name inputs attr op')


    def make_dot(var, params=None):
        """Produce a Graphviz representation of a PyTorch autograd graph.

        The graph is walked backwards from ``var.grad_fn``. Node colours:

        * light blue: graph nodes exposing a ``.variable`` attribute (the
          leaf Variables that require grad), labelled with their name from
          ``params`` (if any) and their size;
        * orange: tensors reached directly (from ``saved_tensors``, i.e.
          tensors saved for backward in a ``torch.autograd.Function``),
          labelled with their size;
        * unfilled: every other autograd function, labelled with its type name.

        Edges point from an input to the node that consumes it.

        Args:
            var: output Variable whose backward graph is drawn. If it has no
                ``grad_fn`` the returned graph is empty.
            params: optional dict of (name, Variable) used to label leaf
                nodes. Leaves missing from the dict (e.g. an input that
                requires grad) get an empty name instead of raising.

        Returns:
            graphviz.Digraph for the graph.
        """
        if params is not None:
            assert all(isinstance(p, Variable) for p in params.values())
            param_map = {id(v): k for k, v in params.items()}
        else:
            param_map = {}

        node_attr = dict(style='filled',
                         shape='box',
                         align='left',
                         fontsize='12',
                         ranksep='0.1',
                         height='0.2')
        dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
        # Holds strong references to visited objects, so their ids (used as
        # DOT node names) stay stable for the whole traversal.
        seen = set()

        def size_to_str(size):
            return '(' + ', '.join('%d' % v for v in size) + ')'

        # Explicit stack instead of recursion: deep graphs (e.g. long unrolled
        # RNNs) would otherwise exceed Python's recursion limit.
        stack = [var.grad_fn] if var.grad_fn is not None else []
        while stack:
            fn = stack.pop()
            if fn in seen:
                continue
            seen.add(fn)

            if torch.is_tensor(fn):
                dot.node(str(id(fn)), size_to_str(fn.size()), fillcolor='orange')
            elif hasattr(fn, 'variable'):
                u = fn.variable
                name = param_map.get(id(u), '')
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(fn)), node_name, fillcolor='lightblue')
            else:
                dot.node(str(id(fn)), str(type(fn).__name__))

            inputs = []
            if hasattr(fn, 'next_functions'):
                inputs.extend(u[0] for u in fn.next_functions if u[0] is not None)
            if hasattr(fn, 'saved_tensors'):
                inputs.extend(fn.saved_tensors)
            for child in inputs:
                dot.edge(str(id(child)), str(id(fn)))
            # Reversed so inputs are expanded in their original order.
            stack.extend(reversed(inputs))
        return dot


    # For traces

    def replace(name, scope):
        """Return ``name`` prefixed with its scope path, as ``'<scope>/<name>'``.

        Raises KeyError if ``name`` has no entry in ``scope``.
        """
        prefix = scope[name]
        return '/'.join((prefix, name))


    def parse(graph):
        """Flatten a JIT trace graph into a list of ``Node`` records.

        First pass: build ``scope``, mapping each value name to a scope name.
        Each node's first output gets the node's ``scopeName()``, as do all of
        its inputs except the first. Presumably this lets parameter inputs
        inherit the scope of the op that consumes them; that is inferred, not
        verified. Value ``'0'`` is forced to ``'input'``.

        Second pass: emit one ``Node`` per graph node, then one ``'Parameter'``
        ``Node`` per graph input. Graph inputs never seen in the first pass
        are scoped as ``'unused'``.

        Raises:
            AssertionError: if a graph node has an empty scope name.
        """
        def value_name(value):
            # ``Value.uniqueName`` was renamed ``debugName`` in later PyTorch
            # releases; accept either so the parser works across versions.
            getter = getattr(value, 'debugName', None) or value.uniqueName
            return getter()

        def first_output_name(node):
            # ``outputs()`` may be an iterator or a list depending on the
            # PyTorch version; ``iter`` handles both.
            return value_name(next(iter(node.outputs())))

        scope = {}
        for n in graph.nodes():
            inputs = [value_name(i) for i in n.inputs()]
            for input_name in inputs[1:]:
                scope[input_name] = n.scopeName()

            uname = first_output_name(n)
            assert n.scopeName() != '', '{} has empty scope name'.format(n)
            scope[uname] = n.scopeName()
        scope['0'] = 'input'

        nodes = []
        for n in graph.nodes():
            attrs = {k: n[k] for k in n.attributeNames()}
            attrs = str(attrs).replace("'", ' ')
            inputs = [replace(value_name(i), scope) for i in n.inputs()]
            uname = first_output_name(n)
            nodes.append(Node(name=replace(uname, scope),
                              op=n.kind(),
                              inputs=inputs,
                              attr=attrs))

        for n in graph.inputs():
            uname = value_name(n)
            scope.setdefault(uname, 'unused')
            nodes.append(Node(name=replace(uname, scope),
                              op='Parameter',
                              inputs=[],
                              attr=str(n.type())))

        return nodes


    def make_dot_from_trace(trace):
        """Build a Graphviz digraph from the output of ``torch.jit.trace``.

        Example:
        >>> trace, = torch.jit.trace(model, args=(x,))
        >>> dot = make_dot_from_trace(trace)

        The trace is first passed through ``torch.onnx._optimize_trace`` and
        its graph is flattened with ``parse``. Each record becomes one box,
        labelled with its scoped name with ``/`` shown as line breaks. An edge
        runs from each of the record's inputs to the record.
        """
        torch.onnx._optimize_trace(trace, False)
        records = parse(trace.graph())

        dot = Digraph(
            node_attr={
                'style': 'filled',
                'shape': 'box',
                'align': 'left',
                'fontsize': '12',
                'ranksep': '0.1',
                'height': '0.2',
            },
            graph_attr={'size': '12,12'},
        )

        for rec in records:
            dot.node(rec.name, label=rec.name.replace('/', '\n'))
            for src in rec.inputs or ():
                dot.edge(src, rec.name)
        return dot

  3. 使用示例(demo):

    import torch
    from torch.autograd import Variable
    import torchvision
    import torch.nn as nn

    from dot import make_dot, make_dot_from_trace

    def demo_mlp():
        """Render the autograd graph of a small 8-16-1 MLP to ``graph.gv``.

        A smaller alternative to the ResNet-18 example below. It is not run
        by default; call it explicitly to use it.
        """
        model = nn.Sequential()
        model.add_module('W0', nn.Linear(8, 16))
        model.add_module('tanh', nn.Tanh())
        model.add_module('W1', nn.Linear(16, 1))

        x = Variable(torch.randn(1, 8))
        y = model(x)

        dot = make_dot(y.mean(), params=dict(model.named_parameters()))
        dot.render("graph.gv", view=True)


    # ResNet-18 with its classifier head replaced by a 100-way Linear layer.
    # The graph is drawn from the mean of the batch output, so every parameter
    # that contributes to it shows up labelled by its ``named_parameters`` name.
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 100)
    x = Variable(torch.randn(10, 3, 224, 224))
    y = model(x)
    dot = make_dot(y.mean(), params=dict(model.named_parameters()))
    dot.render("resnet18.gv", view=True)

    插图(原文此处为运行示例后生成的网络结构图,图片已缺失):