import networkx as nx
import EoN
import matplotlib.pyplot as plt
import random
import numpy as np
from matplotlib.animation import FuncAnimation
from collections import defaultdict
class Simulation_Investigation():
r'''Simulation_Display is a class which is used for creating a particular
type of plot or an animation.
The plot shows an image of the network at a snapshot in time. In addition
to the right of these plots it can show various timeseries from the simulation
or from some other calculation.
A longer term goal is to have the *_from_graph methods be directly callable and
read in the IC and then get the appropriate time series.
'''
#I want to improve how the labels, colors, linetypes etc are passed through
#here.
#within the Simulation_Investigation class there sits a _time_series_ class
class _time_series_():
def __init__(self, ts_data, color_dict, label=None, tex = True,
**kwargs):
r'''
:Arguments:
**ts_data** a pair (t, D)
where ``t`` is a numpy array of times and ``D`` is a dict such that
``D[status]`` is a numpy array giving the number of individuals
of given status at corresponding time in ``t``.
**color_dict** a dict
``color_dict[status]`` is the color to be used for status in plots
**label** a string
The label to be used for these plots
**tex** A boolean
tells whether the status should be rendered as tex math mode
or not in the labels.
**kwargs** key word arguments to be passed along to the plotting command
'''
self._t_ = ts_data[0]
self._D_ = ts_data[1]
self._tex_ = tex
self.color_dict = color_dict
self.label=label
self.plt_kwargs = kwargs
def _plot_(self, ax, statuses_to_plot):
if self.label:
if self._tex_:
for status in statuses_to_plot:
if status in self._D_:
ax.plot(self._t_, self._D_[status], color = self.color_dict[status], label=self.label+': ${}$'.format(status), **self.plt_kwargs)
else:
for status in statuses_to_plot:
if status in self._D_:
ax.plot(self._t_, self._D_[status], color = self.color_dict[status], label=self.label+': {}'.format(status), **self.plt_kwargs)
else:
for status in statuses_to_plot:
if status in self._D_:
ax.plot(self._t_, self._D_[status], color = self.color_dict[status], **self.plt_kwargs)
def update_kwargs(self, **kwargs):
self.plt_kwargs.update(kwargs)
def __init__(self, G, node_history, transmissions = None,
possible_statuses = None, pos = None, color_dict=None,
tex = True):
r'''
:Arguments:
**G** The graph
**node_history** (dict)
``node_history[node]`` is a tuple (times, statuses) where
- ``times`` is a list of the times at which ``node`` changes status.
the first entry is the initial time.
- ``statuses`` is a list giving the status the new status of the node
at each corresponding time.
**transmissions** (list)
Each event which is induced by a neighbor appears (in order) in
``transmissions``. It appears as a triple ``(time, source, target)``
where
- ``time`` is the time of the event
- ``source`` is the neighbor inducing the transition.
- ``target`` is the node undergoing the transition.
**possible_statuses** list (default None)
a list of the statuses to be considered.
If not given, then defaults to the values in node_history.
**pos** (dict - default None)
The ``pos`` argument to be given to the networkx plotting commands.
**color_dict** (dict - default None)
A dictionary stating for each status what color is to be used when
plotting.
If not given and your statuses are ``'S'``, and ``'I'`` or they are
``'S'``, ``'I'``, and ``'R'``, it will attempt to use a greenish color for
``'S'`` and a reddish color for ``'I'`` and gray for ``'R'``. These
should be color-blind friendly, despite appearing green/red to
me.
Otherwise if not given, it will cycle through a set of 7 colors
which I believe are color-blind friendly. If you have more than
7 statuses, you probably want to set your own color_dict.
**tex** Boolean (default `True`)
If 'True`, then labels for statuses will be in tex's math mode
If ``False``, just plain text.
'''
if possible_statuses is None:
ps = set()
for node in node_history:
ps = ps.union(set(node_history[node]))
possible_statuses = list(ps)
if color_dict is None:
if set(possible_statuses) == set(['S', 'I', 'R']):
color_dict = {'S':'#009a80','I':'#ff2000', 'R':'gray'}
elif set(possible_statuses) == set(['S', 'I']):
color_dict = {'S':'#009a80','I':'#ff2000'}
else:
colors = ['#FF2000', '#009A80', '#5AB3E6', '#E69A00', '#CD9AB3', '#0073B3','#F0E442']
color_dict = {status:colors[index%len(colors)] for index, status in enumerate(possible_statuses)}
self.G = G
self._node_history_ = node_history
self._transmissions_ = transmissions
self._tex_ = tex
self._possible_statuses_ = possible_statuses
self.sim_color_dict = color_dict
self.pos = pos #don't go through the effort to define this until a plot
#is made
self.summary() #defines self._t_, self._D_
self._time_series_list_ = []
self._simulation_time_series_ = self._time_series_(self._summary_,
color_dict=self.sim_color_dict,
label = 'Simulation', tex = tex
)
self._time_series_list_.append(self._simulation_time_series_)
[docs] def node_history(self, node):
r'''
returns the history of a node.
:Arguments:
**node**
the node
:Returns:
**timelist, statuslist** lists
the times at which the node changes status and what the new status is at each time.
'''
return self._node_history_[node]
[docs] def node_status(self, node, time):
r'''
returns the status of a given node at a given time.
:Arguments:
**node**
the node
**time** float
the time of interest.
:Returns:
**status** string (such as `'S'`, `'I'`, or `'R'`)
status of node at time.
'''
changetimes = self._node_history_[node][0]
number_swaps = len([changetime for changetime in changetimes if changetime<= time])
status = self._node_history_[node][1][number_swaps-1]
return status
[docs] def get_statuses(self, nodelist=None, time=None):
r'''
returns the status of nodes at a given time.
:Arguments:
**nodelist** iterable (default None):
Some sort of iterable of nodes.
If default value, then returns statuses of all nodes.
**time** float (default None)
the time of interest.
if default value, then returns initial time
:Returns:
**status** dict
A dict whose keys are the nodes in nodelist giving their status at time.
'''
if nodelist is None:
nodelist = self.G
if time is None:
time = self._t_[0]
status = {}
for node in nodelist:
changetimes = self._node_history_[node][0]
number_swaps = len([changetime for changetime in changetimes if changetime<= time])
status[node] = self._node_history_[node][1][number_swaps-1]
return status
[docs] def summary(self, nodelist = None):
r'''
Provides the population-scale summary of the dynamics. It returns
a numpy array t as well as numpy arrays for each of the ``possible_statuses``
giving how many nodes had that status at the corresponding time.
Assumes that all entries in node_history start with same tmin
:Arguments:
**nodelist** (default None)
The nodes that we want to focus on. By default this is all nodes.
If you want all nodes, the most efficient thing to do is to
not include ``'nodelist'``. Otherwise it will recalculate everything.
:Returns:
**summary** tuple
a pair (t, D) where
- t is a numpy array of times and
- D is a dict whose keys are the possible statuses and whose values
are numpy arrays giving the count of each status at the specific
times.
If nodelist is empty, this is for the entire graph. Otherwise
it is just for the node in nodelist.
'''
if nodelist is None: #calculate everything.
nodelist =self.G
if nodelist is self.G:
try:
self._summary_ #after first time through, don't recalculate.
return self._summary_
except AttributeError: #hey, it's the first time through, let's calculate
pass
times = set()
delta = {status:defaultdict(int) for status in self._possible_statuses_}
for node in nodelist:
node_times = self._node_history_[node][0]
node_statuses = self._node_history_[node][1]
tmin = node_times[0] #should be the same for each node, but hard to choose a single node at start.
times.add(tmin)
delta[node_statuses[0]][tmin]+=1
for new_status, old_status, time in zip(node_statuses[1:], node_statuses[:-1], node_times[1:]):
delta[new_status][time] = delta[new_status][time]+1
delta[old_status][time] = delta[old_status][time]-1
times.add(time)
t = np.array(sorted(list(times)))
tmin = t[0]
mysummary = (t, {status:[delta[status][tmin]] for status in self._possible_statuses_})
for time in t[1:]:
for status in self._possible_statuses_:
mysummary[1][status].append(mysummary[1][status][-1]+delta[status][time])
for status in self._possible_statuses_:
mysummary[1][status] = np.array(mysummary[1][status])
if nodelist == self.G: #
self._summary_ = mysummary
self._t_ = t
self._D_ = mysummary[1]
return mysummary
[docs] def t(self):
r''' Returns the times of events
Generally better to get these all through summary()'''
return self._summary_[0]
[docs] def S(self):
r'''
If ``'S'`` is a state, then this will return the number susceptible at each time.
Else it raises an error
Generally better to get these all through ``summary()`` '''
if 'S' in self._possible_statuses_:
return self._summary_[1]['S']
else:
raise EoN.EoNError("'S' is not a possible status")
[docs] def I(self):
r'''
See notes for S
Returns the number infected at each time
Generally better to get these all through summary()'''
if 'I' in self._possible_statuses_:
return self._summary_[1]['I']
else:
raise EoN.EoNError("'I' is not a possible status")
[docs] def R(self):
r'''
See notes for S
Returns the number recovered at each time
Generally better to get these all through summary()'''
if 'R' in self._possible_statuses_:
return self._summary_[1]['R']
else:
raise EoN.EoNError("'R' is not a possible status")
[docs] def transmissions(self):
r'''Returns a list of tuples (t,u,v) stating that node u infected node
v at time t. In the standard code, if v was already infected at tmin, then
the source is None
Note - this only includes successful transmissions. So if u tries
to infect v, but fails because v is already infected this is not
recorded.'''
if self._transmissions_ is None:
raise EoN.EoNError("transmissions were not provided when created")
return self._transmissions_
[docs] def transmission_tree(self):
r'''
Produces a MultiDigraph whose edges correspond to transmission events.
If SIR, then this is a tree (or a forest).
:Returns:
**T** a directed Multi graph
T has all the information in ``transmissions``.
An edge from u to v with time t means u transmitted to v at time t.
:Warning:
Although we refer to this as a "tree", if the disease is SIS, there
are likely to be cycles and/or repeated edges. If the disease is SIR
but there are multiple initial infections, then this will be a "forest".
If it's an SIR, then this is a tree (or forest).
The graph contains only those nodes that are infected at some point.
'''
if self._transmissions_ is None:
raise EoN.EoNError("transmissions were not provided when created")
T = nx.MultiDiGraph()
for t, u, v in self._transmissions_:
if u is not None:
T.add_edge(u, v, time=t)
return T
[docs] def add_timeseries(self, ts_data, color_dict = None, label = None, tex = None,
**kwargs):
r'''
This allows us to include some additional timeseries for comparision
with the simulation. So for example, if we perform a simulation and
want to plot the simulation but also a prediction, this is what we
would use.
:Arguments:
**ts_data** a pair (t, D)
where t is a numpy array of times
and D is a dict
where D[status] is the number of individuals of given status at
corresponding time.
**color_dict** dict (default None)
a dictionary mapping statuses to the color
desired for their plots. Defaults to the same as the simulation
**label** (string)
The label to be used for these plots in the legend.
**tex** (boolean)
Tells whether status should be rendered in tex's math mode in
labels. Defaults to whatever was done for creation of this
simulation_investigation object.
****kwargs**
any matplotlib key word args to affect how the curve is shown.
:Returns:
**ts** timeseries object
:Modifies:
This adds the timeseries object ``ts`` to the internal ``_time_series_list_``
'''
if color_dict is None:
color_dict = self.color_dict
ts = self._time_series_(ts_data, color_dict = color_dict, label=label,
tex=self._tex_, **kwargs)
self._time_series_list_.append(ts)
return ts
[docs] def update_ts_kwargs(self, ts, **kwargs):
r'''Allows us to change some of the matplotlib key word arguments
for a timeseries object
:Arguments:
**ts** (timeseries object)
the timeseries object whose key word args we are updating.
****kwargs**
the new matplotlib key word arguments
'''
ts.update_kwargs(**kwargs)
[docs] def update_ts_tex(self, ts, tex):
r'''updates the tex flag for time series plots
:Arguments:
**ts** (timeseries object)
the timeseries object whose key word args we are updating.
**tex**
the new value for ``tex``
'''
ts._tex_=tex
[docs] def update_ts_label(self, ts, label):
r'''updates the label for time series plots
:Arguments:
**ts** timeseries object
the timeseries object whose key word args we are updating.
**label** string
the new label
'''
ts.label=label
[docs] def update_ts_color_dict(self, ts, color_dict):
r'''
updates the color_dict for time series plots
:Arguments:
**ts** timeseries object
the timeseries object whose key word args we are updating.
**color_dict** dict
the new color_dict
'''
for status in ts._D_:
if status not in color_dict:
raise EoN.EoNError("Status {} is not in color_dict".format(status))
ts.color_dict=color_dict
[docs] def sim_update_kwargs(self, **kwargs):
r'''Allows us to change some of the matplotlib key word arguments
for the simulation. This is identical to update_ts_kwargs except
we don't need to tell it which time series to use.
:Arguments:
****kwargs**
the new matplotlib key word arguments
'''
self._simulation_time_series_.update_kwargs(**kwargs)
[docs] def sim_update_tex(self, tex):
r'''updates the tex flag for the simulation in the time series plots
and in the network plots
:Arguments:
**tex** string
the new value of ``tex``
'''
self._tex_=tex
self._simulation_time_series_._tex_=tex
[docs] def sim_update_label(self, label):
r'''updates the label for the simulation in the time series plots
:Arguments:
**label** string
the new ``label``
'''
self.label=label
[docs] def sim_update_color_dict(self, color_dict):
r'''
updates the color_dict for the simulation
:Arguments:
**color_dict** dict
the new color_dict
'''
for status in self._possible_statuses_:
if status not in color_dict:
raise EoN.EoNError("Status {} is not in color_dict".format(status))
self.sim_color_dict = color_dict
self._simulation_time_series_.color_dict=color_dict
# def add_timeseries_from_analytic_model(self, tau, gamma, model = EoN.EBCM_from_graph, tmin = 0, tmax = 10, tcount = 1001, SIR = True, color_dict={'S':'#009a80','I':'#ff2000', 'R':'gray'}, label = None):
# r''' Uses one of the analytic models to predict the curve.
# The analytic model needs to be one of the *_from_graph models.
# (currently the pref mixing EBCM models do not work with this either)
# only works for cts time models.
#
# Arguments:
# tau (float)
# transmission rate
# gamma (float)
# recovery rate
# model (function)
# A function like the *_from_graph models
#
#
# '''
# node = self.G.nodes()[0]
# tmin = self._node_history_[node][0][0]
# initial_status = self.get_statuses(tmin)
#
# initial_infecteds = [node for node in self.G.nodes() if initial_status[node]=='I']
#
# if SIR:
# initial_recovereds = [node for node in self.G.nodes() if initial_status[node]=='R']
#
# t, S, I, R = model(self.G, tau, gamma, initial_infecteds=initial_infecteds,
# initial_recovereds = initial_recovereds, tmin = tmin, tmax=tmax,
# tcount=tcount, return_full_data=False)
#
# else:
# t, S, I = model(self.G, tau, gamma, initial_infecteds=initial_infecteds,
# tmin = tmin, tmax=tmax,
# tcount=tcount, return_full_data=False)
# R=None
#
# self.add_timeseries(t, S=S, I=I, R=R, color_dict=color_dict, label=label)
# def calculate_approximate_time_series(self, function, rho = None, **params):
# r'''calls function to estimate time series. If using one of the
# networkx functions, it will use one of the X_from_graph functions'''
# print("this function isn't up yet")
# if rho is None:
# if self.SIR:
# initialS, initialI, initialR = get_IC(self.G, node_history, SIR=self.SIR)
# else:
# initialS, initialI = get_IC(self.G, node_history, SIR=self.SIR)
#
[docs] def set_pos(self, pos):
r'''Set the position of the nodes.
:Arguments:
**pos** (dict)
as in ``nx.draw_networkx``
'''
self.pos = pos
def _display_graph_(self, pos, nodestatus, nodelist, status_order, statuses_to_plot, ax, **nx_kwargs):
'''
:Arguments:
**pos** (dict)
position as for networkx
**nodestatus** (dict)
status of all nodes at given time
**nodelist** (list)
a list of the nodes to plot. This partially determines which
nodes appear on top
**status_order** list of statuses
Each status will appear on top of all later statuses. If list
empty or ``False``, will ignore.
Any statuses not appearing in list will simply be below those on the
list and will not have priority by status.
**statuses_to_plot** list of statuses to plot.
If given, then the other nodes will be left invisible when plotting
but I think this requires networkx v2.3 or later.
**ax** axis
**nx_kwargs**
'''
if nodelist:
nodeset = set(nodelist) #containment test in next line is quicker with set
edgelist = [edge for edge in self.G.edges() if edge[0] in nodeset and edge[1] in nodeset]
else:
nodelist = list(self.G.nodes())
random.shuffle(nodelist)#assume no order desired unless sent in
edgelist = list(self.G.edges())
if status_order: #redefine nodelist order so that particular status on top
nodes_by_status = [[node for node in nodelist if nodestatus[node] == status] for status in reversed(status_order)]
# I_nodes = [node for node in nodelist if nodestatus[node] == 'I']
other_nodes = [node for node in nodelist if nodestatus[node] not in status_order]
nodelist = other_nodes
for L in nodes_by_status:
nodelist.extend(L)
color_list = [self.sim_color_dict[nodestatus[node]] if nodestatus[node] in statuses_to_plot else "None" for node in nodelist]
nx.draw_networkx_edges(self.G, pos, edgelist=edgelist, ax=ax, **nx_kwargs)
drawn_nodes = nx.draw_networkx_nodes(self.G, pos, nodelist = nodelist, node_color=color_list, ax=ax, **nx_kwargs)
if "with_labels" in nx_kwargs and nx_kwargs['with_labels']==True:
nx.draw_networkx_labels(self.G, pos)
if "with_edge_labels" in nx_kwargs and nx_kwargs['with_edge_labels'] == True:
nx.draw_networkx_edge_labels(self.G, pos)
ax.set_xticks([])
ax.set_yticks([])
fakelines = []
for status in statuses_to_plot:
fakelines.append(plt.Line2D([0,0],[0,1], color=self.sim_color_dict[status], marker = 'o', linestyle = ''))
if self._tex_:
ax.legend(fakelines, ['${}$'.format(status) for status in statuses_to_plot])
else:
ax.legend(fakelines, statuses_to_plot)
return drawn_nodes
def _display_time_series_(self, fig, t, ts_plots, ts_list, timelabel):
'''
:ARGUMENTS:
**fig** a matplotlib figure
**t** float
the time for the snapshot of the network.
**ts_plots** (list of lists or list of strings)
lists such as ``[['S'], ['I'], ['R']]`` or ``[['S', 'I'], ['R']]``
equivalently ``['S', 'I', 'R']`` and ``['SI', 'R']`` will do the same
but is problematic if a status has a string longer than 1.
denotes what should appear in the timeseries plots. The
length of the list determines how many plots there are. If
entry i is ``['A', 'B']`` then plot i has both ``'A'`` and ``'B'`` plotted.
.
So ``[['S'], ['I'], ['R']]`` or ``['SIR']`` will result in
3 plots, one with just ``'S'``, one with just ``'I'`` and one with just ``'R'``
while ``[['S', 'I'], ['R']]`` or ``['SI', 'R']`` will result in
2 plots, one with both ``'S'`` and ``'I'`` and one with just ``'R'``.
**ts_list** (list of timeseries objects - default ``None``)
If multiple time series have been added, we might want to plot
only some of them. This says which ones to plot.
The simulation is always included.
**timelabel** (string, default ``'$t$'``)
the horizontal label to be used on the time series plots
'''
#the handling of the final element separately is ugly.
#should figure out how to put it all into a single loop.
if ts_list is None:
ts_list = self._time_series_list_
elif self._simulation_time_series_ not in ts_list:
ts_list.append(self._simulation_time_series_)
ts_axes = []
time_markers = []
ts_plot_count = len(ts_plots)
for cnt, ts_plot in enumerate(ts_plots[:-1]):
ax = fig.add_subplot(ts_plot_count, 2, 2*(cnt+1))
ax.set_xticks([])
for ts in reversed(ts_list):
ts._plot_(ax, ts_plot)
ax.legend()
if self._tex_:
ax.set_title(", ".join(['${}$'.format(status) for status in ts_plot]))
else:
ax.set_title(", ".join(ts_plot))
tm = ax.axvline(x=t, linestyle='--', color='k')
ts_axes.append(ax)
time_markers.append(tm)
ax = fig.add_subplot(ts_plot_count, 2, 2*ts_plot_count)
ax.set_xlabel(timelabel)
ts_plot = ts_plots[-1]
for ts in reversed(ts_list):
ts._plot_(ax, ts_plot)
ax.legend()
if self._tex_:
ax.set_title(", ".join(['${}$'.format(status) for status in ts_plot]))
else:
ax.set_title(", ".join(ts_plot))
tm = ax.axvline(x=t, linestyle='--', color='k')
ts_axes.append(ax)
time_markers.append(tm)
return ts_axes, time_markers
[docs] def display(self, time, ts_plots = None, ts_list = None, nodelist=None,
status_order=False, timelabel=r'$t$', pos=None, statuses_to_plot = None,
**nx_kwargs):
r'''
Provides a plot of the network at a specific time and (optionally)
some of the time series
By default it plots the network and all time series. The time series
are plotted in 3 (for SIR) or 2 (for SIS) different plots to the right
of the network. There are options to control how many plots appear
and which time series objects are plotted in it.
We can make the number of time series plots to the right be zero
by setting ts_plots to be an empty list.
:Arguments:
**time** float
the time for the snapshot of the network.
**ts_plots** (list of strings, defaults to ``statuses_to_plot``, which defaults
to ``self._possible_statuses_``)
if ``[]`` or ``False`` then the display only shows the network.
lists such as ``[['S'], ['I'], ['R']]`` or ``[['S', 'I'], ['R']]``
equivalently ``['S', 'I', 'R']`` and ``['SI', 'R']`` will do the same
but is problematic if a status has a string longer than 1.
denotes what should appear in the timeseries plots. The
length of the list determines how many plots there are. If
entry i is ``['A', 'B']`` then plot i has both ``'A'`` and ``'B'`` plotted.
.
So ``[['S'], ['I'], ['R']]`` or ``['SIR']`` will result in
3 plots, one with just ``'S'``, one with just ``'I'`` and one with just ``'R'``
while ``[['S', 'I'], ['R']]`` or ``['SI', 'R']`` will result in
2 plots, one with both ``'S'`` and ``'I'`` and one with just ``'R'``.
Defaults to the possible_statuses
**ts_list** (list of timeseries objects - default None)
If multiple time series have been added, we might want to plot
only some of them. This says which ones to plot.
The simulation is always included.
**nodelist** (list, default None)
which nodes should be included in the network plot. By default
this is the entire network.
This also determines which nodes are on top of each other
(particularly if ``status_order`` is ``False``).
**status_order** list of statuses default ``False``
Each status will appear on top of all later statuses. If list
empty or ``False``, will ignore.
Any statuses not appearing in list will simply be below those on the
list and will not have priority by status.
**timelabel** (string, default ``'$t$'``)
the horizontal label to be used on the time series plots
**pos**
overrides self.pos for this display (but does not overwrite
self.pos. Use set_pos if you want to do this)
**statuses_to_plot** list of statuses to plot.
If given, then the other nodes will be left invisible when plotting
but I think this requires networkx v2.3 or later.
****nx_kwargs**
any networkx keyword arguments to go into the network plot.
:Returns:
**network_ax, ts_ax_list** (axis, list of axises)
The axes for the network plot and a list of all the axes for the
timeseries plots
Notes :
If you only want to plot the graph, set ts_plots equal to [].
If you want S, I, and R on a single plot, set ts_plots equal to ['SIR']
If you only want some of the timeseries objects, set ts_list to be those
(the simulation time series will always be plotted).
Examples :
To show a plot where sim is the Simulation_Investigation object
simply do
::
sim.display()
plt.show()
To save it,
::
sim.display()
plt.savefig(filename).
If you want to do more detailed modifications of the plots, this
returns the axes:
::
network_ax, timeseries_axes = sim.display()
'''
if statuses_to_plot is None:
statuses_to_plot = self._possible_statuses_
if ts_plots is None:
ts_plots = [[x] for x in statuses_to_plot]
if ts_plots:
fig = plt.figure(figsize=(10,4))
graph_ax = fig.add_subplot(121)
else:
fig = plt.figure()
graph_ax = fig.add_subplot(111)
nodestatus = self.get_statuses(self.G, time)
if pos is None:
if self.pos is None:
pos = nx.spring_layout(self.G)
else:
pos = self.pos
self._display_graph_(pos, nodestatus, nodelist, status_order, statuses_to_plot, graph_ax, **nx_kwargs)
if ts_plots:
ts_ax_list, time_markers = self._display_time_series_(fig, time, ts_plots, ts_list, timelabel)
else:
ts_ax_list, time_markers = [], []
plt.tight_layout()
return graph_ax, ts_ax_list
def _draw_specific_status(self, pos, nodes, status, ax, **nx_kwargs):
drawn = nx.draw_networkx_nodes(self.G, pos, nodelist = nodes, node_color = self.sim_color_dict[status], **nx_kwargs)
return drawn
def _update_ani_(self, time, pos, nodelist, drawn_nodes, drawn_elevated, status_order, graph_ax, ts_axes, time_markers, nx_kwargs):
'''
'''
nodestatus = self.get_statuses(self.G, time)
drawn_nodes.set_color([self.sim_color_dict[nodestatus[node]] for node in nodelist])
for status in reversed(status_order):
nodes_with_status = [node for node in nodelist if nodestatus[node] == status]
drawn_elevated[status][0].remove()
drawn_elevated[status][0] = nx.draw_networkx_nodes(self.G, pos, nodelist=nodes_with_status, color = self.sim_color_dict[status], ax = graph_ax, **nx_kwargs)
for index, ax in enumerate(ts_axes):
time_markers[index].remove()
time_markers[index] = ax.axvline(x=time, linestyle='--', color='k')
return
[docs] def animate(self, frame_times=None, ts_plots = None,
ts_list = None, nodelist=None, status_order=False, timelabel=r'$t$',
pos = None, statuses_to_plot = None, **nx_kwargs):
r'''
As in display, but this produces an animation.
To display an animation where sim is the Simulation_Investigation object
simply do
::
sim.animate()
plt.show()
To save an animation [on a mac with appropriate additional libraries
installed], you can do
::
ani = sim.animate()
ani.save(filename, fps=5, extra_args=['-vcodec', 'libx264'])
here ``ani`` is a matplotlib animation.
See
https://matplotlib.org/api/_as_gen/matplotlib.animation.Animation.save.html
for more about the save command for matplotlib animations.
:Arguments:
The same as in display, except that time is replaced by frame_times
**frame_times** (list/numpy array)
The times for animation frames. If nothing is given, then it
uses 101 times between 0 and t[-1]
**ts_plots** (list of strings, defaults to ``statuses_to_plot``, which defaults
to ``self._possible_statuses_``)
if ``[]`` or ``False`` then the display only shows the network.
lists such as ``[['S'], ['I'], ['R']]`` or ``[['S', 'I'], ['R']]``
equivalently ``['S', 'I', 'R']`` and ``['SI', 'R']`` will do the same
but is problematic if a status has a string longer than 1.
denotes what should appear in the timeseries plots. The
length of the list determines how many plots there are. If
entry i is ``['A', 'B']`` then plot i has both ``'A'`` and ``'B'`` plotted.
.
So ``[['S'], ['I'], ['R']]`` or ``['SIR']`` will result in
3 plots, one with just ``'S'``, one with just ``'I'`` and one with just ``'R'``
while ``[['S', 'I'], ['R']]`` or ``['SI', 'R']`` will result in
2 plots, one with both ``'S'`` and ``'I'`` and one with just ``'R'``.
Defaults to the possible_statuses
**ts_list** list of timeseries objects (default None)
If multiple time series have been added, we might want to plot
only some of them. This says which ones to plot.
The simulation is always included.
**nodelist** list (default None)
which nodes should be included in the network plot. By default
this is the entire network.
This also determines which nodes are on top of each other
(particularly if status_order is ``False``).
**status_order** list of statuses default ``False``
Each status will appear on top of all later statuses. If list
empty or ``False``, will ignore.
Any statuses not appearing in list will simply be below those on the
list and will not have priority by status.
**timelabel** string (default '$t$')
the horizontal label to be used on the time series plots
**pos** dict (default None)
overrides self.pos for this display (but does not overwrite
self.pos. Use set_pos if you want to do this)
**statuses_to_plot** list of statuses to plot.
If given, then the other nodes will be left invisible when plotting
but I think this requires networkx v2.3 or later.
****nx_kwargs**
any networkx keyword arguments to go into the network plot.
'''
# if not self.SIR and ts_plots:
# ts_plots = [x for x in ts_plots if x != 'R']
if frame_times is None:
frame_times = np.linspace(0,self._t_[-1], 101)
if statuses_to_plot is None:
statuses_to_plot = self._possible_statuses_
if ts_plots is None:
ts_plots = statuses_to_plot
if ts_plots:
fig = plt.figure(figsize=(10,4))
graph_ax = fig.add_subplot(121)
else:
fig = plt.figure()
graph_ax = fig.add_subplot(111)
initial_status = self.get_statuses(self.G, frame_times[0])
if pos is None:
if self.pos is None:
pos = nx.spring_layout(self.G)
else:
pos = self.pos
if nodelist is None:
nodelist = list(self.G.nodes())
random.shuffle(nodelist)
if status_order is False:
status_order = []
#First we draw all of the nodes with their original status, and without
#putting particular status on top. All nodes are in place, and their color
#can be updated at a later time.
#
#Then we select the nodes whose status puts them on top initially
#
#For each status that goes on top, we draw it in a way that we'll be
#able to redraw that status at a later time.
drawn_nodes = self._display_graph_(pos, initial_status, nodelist, False, statuses_to_plot, graph_ax, **nx_kwargs)
elevated = {status: [node for node in self.G if initial_status[node] == status] for status in status_order}
drawn_elevated = {}
for status in reversed(status_order):
drawn_elevated[status]=[self._draw_specific_status_(pos, elevated[status], status, graph_ax, **nx_kwargs)] #making each a list so that I can change the entry in the list while still passing the same object
#WARNING I'm defining a dict and while that definition is happening
#it's drawing things
if ts_plots:
ts_axes, time_markers = self._display_time_series_(fig, frame_times[0], ts_plots, ts_list, timelabel)
else:
ts_axes, time_markers = [], []
plt.tight_layout()
fargs = (pos, nodelist, drawn_nodes, drawn_elevated, status_order, graph_ax, ts_axes, time_markers, nx_kwargs)
ani = FuncAnimation(fig, self._update_ani_, frames = frame_times, fargs = fargs, repeat=False)
return ani
#to show, do
#simulation.animation()
#plt.show()
#to save do
#ani = simulation.animation()
#ani.save(filename, fps=5, extra_args=['-vcodec', 'libx264'])