about summary refs log tree commit diff stats
path: root/miasm2/expression/expression.py
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/expression/expression.py')
-rw-r--r--miasm2/expression/expression.py62
1 files changed, 46 insertions, 16 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py
index 591dc024..d31c509c 100644
--- a/miasm2/expression/expression.py
+++ b/miasm2/expression/expression.py
@@ -56,6 +56,38 @@ EXPRSLICE = 7
 EXPRCOMPOSE = 8
 
 
+priorities_list = [
+    [ '+' ],
+    [ '*', '/', '%'  ],
+    [ '**' ],
+    [ '-' ],	# Unary '-', associativity with + not handled
+]
+
+# dictionary from 'op' to priority, derived from above
+priorities = dict((op, prio)
+                  for prio, l in enumerate(priorities_list)
+                  for op in l)
+PRIORITY_MAX = len(priorities_list) - 1
+
+def should_parenthesize_child(child, parent):
+    if (isinstance(child, ExprId) or isinstance(child, ExprInt) or
+        isinstance(child, ExprCompose) or isinstance(child, ExprMem) or
+        isinstance(child, ExprSlice)):
+        return False
+    elif isinstance(child, ExprOp) and not child.is_infix():
+        return False
+    elif (isinstance(child, ExprCond) or isinstance(parent, ExprSlice)):
+        return True
+    elif (isinstance(child, ExprOp) and isinstance(parent, ExprOp)):
+        pri_child = priorities.get(child.op, -1)
+        pri_parent = priorities.get(parent.op, PRIORITY_MAX + 1)
+        return pri_child < pri_parent
+    else:
+        return True
+
+def str_protected_child(child, parent):
+    return ("(%s)" % child) if should_parenthesize_child(child, parent) else str(child)
+
 def visit_chk(visitor):
     "Function decorator launching callback on Expression visit"
     def wrapped(expr, callback, test_visit=lambda x: True):
@@ -687,7 +719,7 @@ class ExprCond(Expr):
         return Expr.get_object(cls, (cond, src1, src2))
 
     def __str__(self):
-        return "(%s?(%s,%s))" % (str(self._cond), str(self._src1), str(self._src2))
+        return "%s?(%s,%s)" % (str_protected_child(self._cond, self), str(self._src1), str(self._src2))
 
     def get_r(self, mem_read=False, cst_read=False):
         out_src1 = self.src1.get_r(mem_read, cst_read)
@@ -919,20 +951,13 @@ class ExprOp(Expr):
         return Expr.get_object(cls, (op, args))
 
     def __str__(self):
-        if self.is_associative():
-            return '(' + self._op.join([str(arg) for arg in self._args]) + ')'
-        if (self._op.startswith('call_func_') or
-            self._op == 'cpuid' or
-            len(self._args) > 2 or
-                self._op in ['parity', 'segm']):
-            return self._op + '(' + ', '.join([str(arg) for arg in self._args]) + ')'
-        if len(self._args) == 2:
-            return ('(' + str(self._args[0]) +
-                    ' ' + self.op + ' ' + str(self._args[1]) + ')')
-        else:
-            return reduce(lambda x, y: x + ' ' + str(y),
-                          self._args,
-                          '(' + str(self._op)) + ')'
+        if self._op == '-':		# Unary minus
+            return '-' + str_protected_child(self._args[0], self)
+        if self.is_associative() or self.is_infix():
+            return (' ' + self._op + ' ').join([str_protected_child(arg, self)
+                                                for arg in self._args])
+        return (self._op + '(' +
+                ', '.join([str(arg) for arg in self._args]) + ')')
 
     def get_r(self, mem_read=False, cst_read=False):
         return reduce(lambda elements, arg:
@@ -960,6 +985,11 @@ class ExprOp(Expr):
     def is_function_call(self):
         return self._op.startswith('call')
 
+    def is_infix(self):
+        return self._op in [ '-', '+', '*', '^', '&', '|', '>>', '<<',
+                             'a>>', '>>>', '<<<', '/', '%', '**',
+                             '<u', '<s', '<=u', '<=s', '==' ]
+
     def is_associative(self):
         "Return True iff current operation is associative"
         return (self._op in ['+', '*', '^', '&', '|'])
@@ -1026,7 +1056,7 @@ class ExprSlice(Expr):
         return Expr.get_object(cls, (arg, start, stop))
 
     def __str__(self):
-        return "%s[%d:%d]" % (str(self._arg), self._start, self._stop)
+        return "%s[%d:%d]" % (str_protected_child(self._arg, self), self._start, self._stop)
 
     def get_r(self, mem_read=False, cst_read=False):
         return self._arg.get_r(mem_read, cst_read)