Comment tracer (visualiser) un réseau de neurones artificiels en python avec GraphViz ?

Pour illustrer un projet de recherche utilisant un réseau de neurones, j'avais besoin d'un outil de visualisation simple. Vous trouverez ici quelques résultats basés sur la bibliothèque Graphviz:

Installer Graphviz pour python

Pour installer par exemple Graphviz avec anaconda, il faut entrer les deux commandes suivantes:

conda install -c anaconda graphviz

et

conda install -c conda-forge python-graphviz

Tracer un simple graphe avec graphviz

On peut maintenant tracer un simple graphe avec graphviz( voir User Guide)

>>> from graphviz import Digraph
>>> dot = Digraph(comment='A simple Graph')
>>> dot.node('A', 'Cloudy')
>>> dot.node('B', 'Sunny')
>>> dot.node('C', 'Rainy')
>>> dot.edges(['AB', 'AC'])
>>> dot.edge('B', 'C', constraint='false')
>>> dot.format = 'png'
>>> dot.render('my_graph', view=False) 
'my_graph.png'

donne ici

Example of a simple graph with graphviz
Example of a simple graph with graphviz

Note: dot.source donne l'ensemble des balises pour générer le graphe:

>>> print(dot.source) 
// A simple Graph
digraph {
    A [label=Cloudy]
    B [label=Sunny]
    C [label=Rainy]
    A -> B
    A -> C
    B -> C [constraint=false]
}

Tracer un réseau de neurones artificiels avec graphviz

Pour tracer un réseau de neurones artificiels avec graphviz on peut par exemple utiliser le template proposé par Zeyuan Hu (voir) pour tracer un simple réseau de neurones artificiels:

>>> graph = temp = '''
... digraph G {
... 
...      graph[ fontname = "Helvetica-Oblique",
...             fontsize = 12,
...             label = "",
...             size = "7.75,10.25" ];
... 
...     rankdir = LR;
...     splines=false;
...     edge[style=invis];
...     ranksep= 1.4;
...     {
...     node [shape=circle, color=chartreuse, style=filled, fillcolor=chartreuse];
...     x1 [label=<x1>];
...     x2 [label=<x2>]; 
... }
... {
...     node [shape=circle, color=dodgerblue, style=filled, fillcolor=dodgerblue];
...     a12 [label=<a<sub>1</sub><sup>(2)</sup>>];
...     a22 [label=<a<sub>2</sub><sup>(2)</sup>>];
...     a32 [label=<a<sub>3</sub><sup>(2)</sup>>];
...     a42 [label=<a<sub>4</sub><sup>(2)</sup>>];
...     a52 [label=<a<sub>5</sub><sup>(2)</sup>>];
...     a13 [label=<a<sub>1</sub><sup>(3)</sup>>];
...     a23 [label=<a<sub>2</sub><sup>(3)</sup>>];
...     a33 [label=<a<sub>3</sub><sup>(3)</sup>>];
...     a43 [label=<a<sub>4</sub><sup>(3)</sup>>];
...     a53 [label=<a<sub>5</sub><sup>(3)</sup>>];
... }
... {
...     node [shape=circle, color=coral1, style=filled, fillcolor=coral1];
...     O1 [label=<y1>];
...     O2 [label=<y2>]; 
...     O3 [label=<y3>]; 
... }
...     {
...         rank=same;
...         x1->x2;
...     }
...     {
...         rank=same;
...         a12->a22->a32->a42->a52;
...     }
...     {
...         rank=same;
...         a13->a23->a33->a43->a53;
...     }
...     {
...         rank=same;
...         O1->O2->O3;
...     }
...     l0 [shape=plaintext, label="layer 1 (input layer)"];
...     l0->x1;
...     {rank=same; l0;x1};
...     l1 [shape=plaintext, label="layer 2 (hidden layer)"];
...     l1->a12;
...     {rank=same; l1;a12};
...     l2 [shape=plaintext, label="layer 3 (hidden layer)"];
...     l2->a13;
...     {rank=same; l2;a13};
...     l3 [shape=plaintext, label="layer 4 (output layer)"];
...     l3->O1;
...     {rank=same; l3;O1};
...     edge[style=solid, tailport=e, headport=w];
...     {x1; x2} -> {a12;a22;a32;a42;a52};
...     {a12;a22;a32;a42;a52} -> {a13;a23;a33;a43;a53};
...     {a13;a23;a33;a43;a53} -> {O1,O2,O3};
... }'''
>>> from graphviz import Source
>>> dot = Source(graph)
>>> dot.format = 'png'
>>> dot.render('neural_network_01', view=False) 
'neural_network_01.png'

donne

Comment tracer (visualiser) un réseau de neurones artificiels en python avec GraphViz ?
Comment tracer (visualiser) un réseau de neurones artificiels en python avec GraphViz ?

Note 1: pour changer la taille de l'image ou la taille du texte, il faut éditer les lignes suivantes:

...      graph[ fontname = "Helvetica-Oblique",
...             fontsize = 12,
...             label = "",
...             size = "7.75,10.25" ];

Note 2: marche aussi dans un jupyter notebook:

Comment tracer (visualiser) un réseau de neurones artificiels en python avec GraphViz dans un jupyter notebook ?
Comment tracer (visualiser) un réseau de neurones artificiels en python avec GraphViz dans un jupyter notebook ?

Références

Image

of