about summary refs log tree commit diff stats
path: root/miasm2/expression
diff options
context:
space:
mode:
Diffstat (limited to 'miasm2/expression')
-rw-r--r--miasm2/expression/expression.py143
-rw-r--r--miasm2/expression/expression_helper.py6
-rw-r--r--miasm2/expression/expression_reduce.py10
-rw-r--r--miasm2/expression/parser.py54
-rw-r--r--miasm2/expression/simplifications_common.py20
5 files changed, 196 insertions, 37 deletions
diff --git a/miasm2/expression/expression.py b/miasm2/expression/expression.py
index 54cd5a2d..8e63e6a2 100644
--- a/miasm2/expression/expression.py
+++ b/miasm2/expression/expression.py
@@ -19,6 +19,7 @@
 # IR components are :
 #  - ExprInt
 #  - ExprId
+#  - ExprLoc
 #  - ExprAff
 #  - ExprCond
 #  - ExprMem
@@ -48,12 +49,13 @@ TOK_POS_STRICT = "Spos"
 # Hashing constants
 EXPRINT = 1
 EXPRID = 2
-EXPRAFF = 3
-EXPRCOND = 4
-EXPRMEM = 5
-EXPROP = 6
-EXPRSLICE = 7
-EXPRCOMPOSE = 8
+EXPRLOC = 3
+EXPRAFF = 4
+EXPRCOND = 5
+EXPRMEM = 6
+EXPROP = 7
+EXPRSLICE = 8
+EXPRCOMPOSE = 9
 
 
 priorities_list = [
@@ -115,6 +117,8 @@ class DiGraphExpr(DiGraph):
             return node.op
         elif isinstance(node, ExprId):
             return node.name
+        elif isinstance(node, ExprLoc):
+            return "%s" % node.loc_key
         elif isinstance(node, ExprMem):
             return "@%d" % node.size
         elif isinstance(node, ExprCompose):
@@ -141,6 +145,32 @@ class DiGraphExpr(DiGraph):
         return ""
 
 
+
+class LocKey(object):
+    def __init__(self, key):
+        self._key = key
+
+    key = property(lambda self: self._key)
+
+    def __hash__(self):
+        return hash(self._key)
+
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if self.__class__ is not other.__class__:
+            return False
+        return self.key == other.key
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __repr__(self):
+        return "<%s %d>" % (self.__class__.__name__, self._key)
+
+    def __str__(self):
+        return "loc_key_%d" % self.key
+
 # IR definitions
 
 class Expr(object):
@@ -383,6 +413,9 @@ class Expr(object):
     def is_id(self, name=None):
         return False
 
+    def is_loc(self, label=None):
+        return False
+
     def is_aff(self):
         return False
 
@@ -532,6 +565,7 @@ class ExprId(Expr):
         if size is None:
             warnings.warn('DEPRECATION WARNING: size is a mandatory argument: use ExprId(name, SIZE)')
             size = 32
+        assert isinstance(name, str)
         super(ExprId, self).__init__(size)
         self._name = name
 
@@ -584,6 +618,68 @@ class ExprId(Expr):
         return True
 
 
+class ExprLoc(Expr):
+
+    """An ExprLoc represent a Label in Miasm IR.
+    """
+
+    __slots__ = Expr.__slots__ + ["_loc_key"]
+
+    def __init__(self, loc_key, size):
+        """Create an identifier
+        @loc_key: int, label loc_key
+        @size: int, identifier's size
+        """
+        assert isinstance(loc_key, LocKey)
+        super(ExprLoc, self).__init__(size)
+        self._loc_key = loc_key
+
+    loc_key= property(lambda self: self._loc_key)
+
+    def __reduce__(self):
+        state = self._loc_key, self._size
+        return self.__class__, state
+
+    def __new__(cls, loc_key, size):
+        return Expr.get_object(cls, (loc_key, size))
+
+    def __str__(self):
+        return str(self._loc_key)
+
+    def get_r(self, mem_read=False, cst_read=False):
+        return set()
+
+    def get_w(self):
+        return set()
+
+    def _exprhash(self):
+        return hash((EXPRLOC, self._loc_key, self._size))
+
+    def _exprrepr(self):
+        return "%s(%r, %d)" % (self.__class__.__name__, self._loc_key, self._size)
+
+    def __contains__(self, expr):
+        return self == expr
+
+    @visit_chk
+    def visit(self, callback, test_visit=None):
+        return self
+
+    def copy(self):
+        return ExprLoc(self._loc_key, self._size)
+
+    def depth(self):
+        return 1
+
+    def graph_recursive(self, graph):
+        graph.add_node(self)
+
+    def is_loc(self, loc_key=None):
+        if loc_key is not None and self._loc_key != loc_key:
+            return False
+        return True
+
+
 class ExprAff(Expr):
 
     """An ExprAff represent an affection from an Expression to another one.
@@ -1226,10 +1322,11 @@ class ExprCompose(Expr):
 
 # Expression order for comparaison
 EXPR_ORDER_DICT = {ExprId: 1,
-                   ExprCond: 2,
-                   ExprMem: 3,
-                   ExprOp: 4,
-                   ExprSlice: 5,
+                   ExprLoc: 2,
+                   ExprCond: 3,
+                   ExprMem: 4,
+                   ExprOp: 5,
+                   ExprSlice: 6,
                    ExprCompose: 7,
                    ExprInt: 8,
                   }
@@ -1289,6 +1386,11 @@ def compare_exprs(expr1, expr2):
         if ret:
             return ret
         return cmp(expr1.size, expr2.size)
+    elif cls1 == ExprLoc:
+        ret = cmp(expr1.loc_key, expr2.loc_key)
+        if ret:
+            return ret
+        return cmp(expr1.size, expr2.size)
     elif cls1 == ExprAff:
         raise NotImplementedError(
             "Comparaison from an ExprAff not yet implemented")
@@ -1379,11 +1481,19 @@ def ExprInt_from(expr, i):
 def get_expr_ids_visit(expr, ids):
     """Visitor to retrieve ExprId in @expr
     @expr: Expr"""
-    if isinstance(expr, ExprId):
+    if expr.is_id():
         ids.add(expr)
     return expr
 
 
+def get_expr_locs_visit(expr, locs):
+    """Visitor to retrieve ExprLoc in @expr
+    @expr: Expr"""
+    if expr.is_loc():
+        locs.add(expr)
+    return expr
+
+
 def get_expr_ids(expr):
     """Retrieve ExprId in @expr
     @expr: Expr"""
@@ -1392,6 +1502,14 @@ def get_expr_ids(expr):
     return ids
 
 
+def get_expr_locs(expr):
+    """Retrieve ExprLoc in @expr
+    @expr: Expr"""
+    locs = set()
+    expr.visit(lambda x: get_expr_locs_visit(x, locs))
+    return locs
+
+
 def test_set(expr, pattern, tks, result):
     """Test if v can correspond to e. If so, update the context in result.
     Otherwise, return False
@@ -1431,6 +1549,9 @@ def match_expr(expr, pattern, tks, result=None):
     elif expr.is_id():
         return test_set(expr, pattern, tks, result)
 
+    elif expr.is_loc():
+        return test_set(expr, pattern, tks, result)
+
     elif expr.is_op():
 
         # expr need to be the same operation than pattern
diff --git a/miasm2/expression/expression_helper.py b/miasm2/expression/expression_helper.py
index 722d169d..2fe5e26d 100644
--- a/miasm2/expression/expression_helper.py
+++ b/miasm2/expression/expression_helper.py
@@ -268,6 +268,9 @@ class Variables_Identifier(object):
         elif isinstance(expr, m2_expr.ExprId):
             pass
 
+        elif isinstance(expr, m2_expr.ExprLoc):
+            pass
+
         elif isinstance(expr, m2_expr.ExprMem):
             self.find_variables_rec(expr.arg)
 
@@ -552,7 +555,8 @@ def possible_values(expr):
 
     # Terminal expression
     if (isinstance(expr, m2_expr.ExprInt) or
-            isinstance(expr, m2_expr.ExprId)):
+        isinstance(expr, m2_expr.ExprId) or
+        isinstance(expr, m2_expr.ExprLoc)):
         consvals.add(ConstrainedValue(frozenset(), expr))
     # Unary expression
     elif isinstance(expr, m2_expr.ExprSlice):
diff --git a/miasm2/expression/expression_reduce.py b/miasm2/expression/expression_reduce.py
index 45386ca2..22ac8d8d 100644
--- a/miasm2/expression/expression_reduce.py
+++ b/miasm2/expression/expression_reduce.py
@@ -4,8 +4,8 @@ Apply reduction rules to an Expression ast
 """
 
 import logging
-from miasm2.expression.expression import ExprInt, ExprId, ExprOp, ExprSlice,\
-    ExprCompose, ExprMem, ExprCond
+from miasm2.expression.expression import ExprInt, ExprId, ExprLoc, ExprOp, \
+    ExprSlice, ExprCompose, ExprMem, ExprCond
 
 log_reduce = logging.getLogger("expr_reduce")
 console_handler = logging.StreamHandler()
@@ -29,7 +29,7 @@ class ExprNode(object):
         expr = self.expr
         if self.info is not None:
             out = repr(self.info)
-        elif expr.is_int() or expr.is_id():
+        elif expr.is_int() or expr.is_id() or expr.is_loc():
             out = str(expr)
         elif expr.is_mem():
             out = "@%d[%r]" % (self.expr.size, self.arg)
@@ -76,7 +76,7 @@ class ExprReducer(object):
         @expr: Expression to analyze
         """
 
-        if isinstance(expr, (ExprId, ExprInt)):
+        if isinstance(expr, (ExprId, ExprLoc, ExprInt)):
             node = ExprNode(expr)
         elif isinstance(expr, (ExprMem, ExprSlice)):
             son = self.expr2node(expr.arg)
@@ -118,7 +118,7 @@ class ExprReducer(object):
 
         expr = node.expr
         log_reduce.debug("\t" * lvl + "Reduce...: %s", node.expr)
-        if isinstance(expr, (ExprId, ExprInt)):
+        if isinstance(expr, (ExprId, ExprInt, ExprLoc)):
             pass
         elif isinstance(expr, ExprMem):
             arg = self.categorize(node.arg, lvl=lvl + 1, **kwargs)
diff --git a/miasm2/expression/parser.py b/miasm2/expression/parser.py
index b3f3af1c..cbfd58d0 100644
--- a/miasm2/expression/parser.py
+++ b/miasm2/expression/parser.py
@@ -1,6 +1,6 @@
 import pyparsing
-from miasm2.expression.expression import ExprInt, ExprId, ExprSlice, ExprMem, \
-    ExprCond, ExprCompose, ExprOp, ExprAff
+from miasm2.expression.expression import ExprInt, ExprId, ExprLoc, ExprSlice, \
+    ExprMem, ExprCond, ExprCompose, ExprOp, ExprAff, LocKey
 
 integer = pyparsing.Word(pyparsing.nums).setParseAction(lambda t:
                                                         int(t[0]))
@@ -16,6 +16,7 @@ str_int = str_int_pos | str_int_neg
 
 STR_EXPRINT = pyparsing.Suppress("ExprInt")
 STR_EXPRID = pyparsing.Suppress("ExprId")
+STR_EXPRLOC = pyparsing.Suppress("ExprLoc")
 STR_EXPRSLICE = pyparsing.Suppress("ExprSlice")
 STR_EXPRMEM = pyparsing.Suppress("ExprMem")
 STR_EXPRCOND = pyparsing.Suppress("ExprCond")
@@ -23,11 +24,17 @@ STR_EXPRCOMPOSE = pyparsing.Suppress("ExprCompose")
 STR_EXPROP = pyparsing.Suppress("ExprOp")
 STR_EXPRAFF = pyparsing.Suppress("ExprAff")
 
+LOCKEY = pyparsing.Suppress("LocKey")
+
 STR_COMMA = pyparsing.Suppress(",")
 LPARENTHESIS = pyparsing.Suppress("(")
 RPARENTHESIS = pyparsing.Suppress(")")
 
 
+T_INF = pyparsing.Suppress("<")
+T_SUP = pyparsing.Suppress(">")
+
+
 string_quote = pyparsing.QuotedString(quoteChar="'", escChar='\\', escQuote='\\')
 string_dquote = pyparsing.QuotedString(quoteChar='"', escChar='\\', escQuote='\\')
 
@@ -36,26 +43,33 @@ string = string_quote | string_dquote
 
 expr = pyparsing.Forward()
 
-expr_int = pyparsing.Group(STR_EXPRINT + LPARENTHESIS + str_int + STR_COMMA + str_int + RPARENTHESIS)
-expr_id = pyparsing.Group(STR_EXPRID + LPARENTHESIS + string + STR_COMMA + str_int + RPARENTHESIS)
-expr_slice = pyparsing.Group(STR_EXPRSLICE + LPARENTHESIS + expr + STR_COMMA + str_int + STR_COMMA + str_int + RPARENTHESIS)
-expr_mem = pyparsing.Group(STR_EXPRMEM + LPARENTHESIS + expr + STR_COMMA + str_int + RPARENTHESIS)
-expr_cond = pyparsing.Group(STR_EXPRCOND + LPARENTHESIS + expr + STR_COMMA + expr + STR_COMMA + expr + RPARENTHESIS)
-expr_compose = pyparsing.Group(STR_EXPRCOMPOSE + LPARENTHESIS + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS)
-expr_op = pyparsing.Group(STR_EXPROP + LPARENTHESIS + string + STR_COMMA + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS)
-expr_aff = pyparsing.Group(STR_EXPRAFF + LPARENTHESIS + expr + STR_COMMA + expr + RPARENTHESIS)
-
-expr << (expr_int | expr_id | expr_slice | expr_mem | expr_cond | \
+expr_int = STR_EXPRINT + LPARENTHESIS + str_int + STR_COMMA + str_int + RPARENTHESIS
+expr_id = STR_EXPRID + LPARENTHESIS + string + STR_COMMA + str_int + RPARENTHESIS
+expr_loc = STR_EXPRLOC + LPARENTHESIS + T_INF + LOCKEY + str_int + T_SUP + STR_COMMA + str_int + RPARENTHESIS
+expr_slice = STR_EXPRSLICE + LPARENTHESIS + expr + STR_COMMA + str_int + STR_COMMA + str_int + RPARENTHESIS
+expr_mem = STR_EXPRMEM + LPARENTHESIS + expr + STR_COMMA + str_int + RPARENTHESIS
+expr_cond = STR_EXPRCOND + LPARENTHESIS + expr + STR_COMMA + expr + STR_COMMA + expr + RPARENTHESIS
+expr_compose = STR_EXPRCOMPOSE + LPARENTHESIS + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS
+expr_op = STR_EXPROP + LPARENTHESIS + string + STR_COMMA + pyparsing.delimitedList(expr, delim=',') + RPARENTHESIS
+expr_aff = STR_EXPRAFF + LPARENTHESIS + expr + STR_COMMA + expr + RPARENTHESIS
+
+expr << (expr_int | expr_id | expr_loc | expr_slice | expr_mem | expr_cond | \
          expr_compose | expr_op | expr_aff)
 
-expr_int.setParseAction(lambda t: ExprInt(*t[0]))
-expr_id.setParseAction(lambda t: ExprId(*t[0]))
-expr_slice.setParseAction(lambda t: ExprSlice(*t[0]))
-expr_mem.setParseAction(lambda t: ExprMem(*t[0]))
-expr_cond.setParseAction(lambda t: ExprCond(*t[0]))
-expr_compose.setParseAction(lambda t: ExprCompose(*t[0]))
-expr_op.setParseAction(lambda t: ExprOp(*t[0]))
-expr_aff.setParseAction(lambda t: ExprAff(*t[0]))
+def parse_loc_key(t):
+    assert len(t) == 2
+    loc_key, size = LocKey(t[0]), t[1]
+    return ExprLoc(loc_key, size)
+
+expr_int.setParseAction(lambda t: ExprInt(*t))
+expr_id.setParseAction(lambda t: ExprId(*t))
+expr_loc.setParseAction(parse_loc_key)
+expr_slice.setParseAction(lambda t: ExprSlice(*t))
+expr_mem.setParseAction(lambda t: ExprMem(*t))
+expr_cond.setParseAction(lambda t: ExprCond(*t))
+expr_compose.setParseAction(lambda t: ExprCompose(*t))
+expr_op.setParseAction(lambda t: ExprOp(*t))
+expr_aff.setParseAction(lambda t: ExprAff(*t))
 
 
 def str_to_expr(str_in):
diff --git a/miasm2/expression/simplifications_common.py b/miasm2/expression/simplifications_common.py
index 13b25ce2..149c5b8d 100644
--- a/miasm2/expression/simplifications_common.py
+++ b/miasm2/expression/simplifications_common.py
@@ -250,6 +250,26 @@ def simp_cst_propagation(e_s, expr):
             e_s(Y.msb()) == ExprInt(0, 1)):
             args = [args[0].args[0], X + Y]
 
+    # ((var >> int1) << int1) => var & mask
+    # ((var << int1) >> int1) => var & mask
+    if (op_name in ['<<', '>>'] and
+        args[0].is_op() and
+        args[0].op in ['<<', '>>'] and
+        op_name != args[0]):
+        var = args[0].args[0]
+        int1 = args[0].args[1]
+        int2 = args[1]
+        if int1 == int2 and int1.is_int() and int(int1) < expr.size:
+            if op_name == '>>':
+                mask = ExprInt((1 << (expr.size - int(int1))) - 1, expr.size)
+            else:
+                mask = ExprInt(
+                    ((1 << int(int1)) - 1) ^ ((1 << expr.size) - 1),
+                    expr.size
+                )
+            ret = var & mask
+            return ret
+
     # ((A & A.mask)
     if op_name == "&" and args[-1] == expr.mask:
         return ExprOp('&', *args[:-1])