about summary refs log tree commit diff stats
diff options
context:
space:
mode:
authorAjax <commial@gmail.com>2016-01-25 10:50:59 +0100
committerAjax <commial@gmail.com>2016-01-25 14:28:41 +0100
commit418d2f1600dd1b7dbe81537f9b24c1d4e2d450cb (patch)
treeaaaf354d02efc3463beed062f5ad55397be91879
parentb4d46020aff6aac57b79cfcb64011bdc608854d9 (diff)
downloadmiasm-418d2f1600dd1b7dbe81537f9b24c1d4e2d450cb.tar.gz
miasm-418d2f1600dd1b7dbe81537f9b24c1d4e2d450cb.zip
Graph; introduce copy and merge
-rw-r--r--miasm2/core/graph.py19
-rw-r--r--test/core/graph.py15
2 files changed, 34 insertions, 0 deletions
diff --git a/miasm2/core/graph.py b/miasm2/core/graph.py
index f38f71d6..ee5dc418 100644
--- a/miasm2/core/graph.py
+++ b/miasm2/core/graph.py
@@ -26,6 +26,25 @@ class DiGraph(object):
     def edges(self):
         return self._edges
 
+    def merge(self, graph):
+        """Merge the current graph with @graph
+        @graph: DiGraph instance
+        """
+        for node in graph._nodes:
+            self.add_node(node)
+        for edge in graph._edges:
+            self.add_edge(*edge)
+
+    def __add__(self, graph):
+        """Wrapper on `.merge`"""
+        self.merge(graph)
+        return self
+
+    def copy(self):
+        """Copy the current graph instance"""
+        graph = self.__class__()
+        return graph + self
+
     def __eq__(self, graph):
         if not isinstance(graph, self.__class__):
             return False
diff --git a/test/core/graph.py b/test/core/graph.py
index 269b721b..33a2fc6f 100644
--- a/test/core/graph.py
+++ b/test/core/graph.py
@@ -202,3 +202,18 @@ graph2.add_edge(2, 3)
 graph2.add_edge(1, 2)
 assert graph == graph2
 
+# Copy
+graph4 = graph.copy()
+assert graph == graph4
+
+# Merge
+graph3 = DiGraph()
+graph3.add_edge(3, 1)
+graph3.add_edge(1, 4)
+graph4 += graph3
+for node in graph3.nodes():
+    assert node in graph4.nodes()
+for edge in graph3.edges():
+    assert edge in graph4.edges()
+assert graph4.nodes() == graph.nodes().union(graph3.nodes())
+assert sorted(graph4.edges()) == sorted(graph.edges() + graph3.edges())