about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r--miasm/expression/expression_eval_abstract.py126
-rw-r--r--miasm/expression/expression_helper.py96
2 files changed, 113 insertions, 109 deletions
diff --git a/miasm/expression/expression_eval_abstract.py b/miasm/expression/expression_eval_abstract.py
index a3f9451d..809c72fd 100644
--- a/miasm/expression/expression_eval_abstract.py
+++ b/miasm/expression/expression_eval_abstract.py
@@ -17,7 +17,7 @@
 #
 from miasm.expression.expression import *
 import struct
-import logging 
+import logging
 import cPickle
 import numpy
 from miasm.expression.expression_helper import *
@@ -29,7 +29,7 @@ numpy.seterr(over='ignore', under='ignore')
 mymaxuint = {8:0xFFL,
              16:0xFFFFL,
              32:0xFFFFFFFFL,
-             64:0xFFFFFFFFFFFFFFFFL    
+             64:0xFFFFFFFFFFFFFFFFL
              }
 
 
@@ -100,7 +100,11 @@ class mpool():
     def keys(self):
         k = self.pool_id.keys() + [x[0] for x in self.pool_mem.values()]
         return k
-
+    def copy(self):
+        p = mpool()
+        p.pool_id = dict(self.pool_id)
+        p.pool_mem = dict(self.pool_mem)
+        return p
 
 class eval_abs:
     dict_size = {
@@ -116,23 +120,23 @@ class eval_abs:
             cpt^=tmp&1
             tmp>>=1
         return cpt
-        
+
     def my_bsf(self, a, default_val=0):
         tmp = 0
         for i in xrange(32):
             if a & (1<<i):
                 return i
-        
+
         return default_val
     def my_bsr(self, a, op_size, default_val = 0):
         tmp = 0
         for i in xrange(op_size-1, -1, -1):
             if a & (1<<i):
                 return i
-        
+
         return default_val
-            
-        
+
+
     def __init__(self, vars, func_read = None, func_write = None, log = None):
         self.pool = mpool()
         for v in vars:
@@ -152,7 +156,7 @@ class eval_abs:
             f = open(f,"w")
         self.log = None
         cPickle.dump(self, f)
-    
+
     @staticmethod
 
     def from_file(f, g):
@@ -167,21 +171,21 @@ class eval_abs:
         m.log = log
         new_pool = mpool()
         for x in m.pool:
-            
+
             if not str(x) in g:
                 xx = ExprId(str(x))
                 g[str(xx)] = xx
             else:
                 xx = x
-            
+
             xx = x
             print repr(g[str(xx)]), g[str(xx)]
-            
+
             if isinstance(m.pool[x], Expr):
                 new_pool[g[str(xx)]] = m.pool[x].reload_expr(g)
             else:
                 new_pool[g[str(xx)]] = m.pool[x]
-                
+
         m.pool = new_pool
         return m
 
@@ -221,17 +225,17 @@ class eval_abs:
         if ptr_diff <0:
             #    [a     ]
             #[b      ]XXX
-            
+
             sub_size = b.size + ptr_diff*8
             if sub_size >= a.size:
                 pass
             else:
                 ex = ExprOp('+', a.arg, ExprInt(uint32(sub_size/8)))
                 ex = expr_simp(self.eval_expr(ex, {}))
-                
+
                 rest_ptr = ex
                 rest_size = a.size - sub_size
-    
+
                 val = self.pool[a][sub_size:a.size]
                 out = [(ExprMem(rest_ptr, rest_size), val)]
         else:
@@ -248,7 +252,7 @@ class eval_abs:
                 out.append((ExprMem(a.arg, ptr_diff*8), val))
             #part Y
             if ptr_diff*8+b.size <a.size:
-                
+
                 ex = ExprOp('+', b.arg, ExprInt(uint32(b.size/8)))
                 ex = expr_simp(self.eval_expr(ex, {}))
 
@@ -256,9 +260,9 @@ class eval_abs:
                 rest_size = a.size - (ptr_diff*8 + b.size)
                 val = self.pool[a][ptr_diff*8 + b.size:a.size]
                 out.append((ExprMem(ex, val.get_size()), val))
-            
-            
-        return out    
+
+
+        return out
 
     #give mem stored overlapping requested mem ptr
     def get_mem_overlapping(self, e):
@@ -315,7 +319,7 @@ class eval_abs:
         ret = self.eval_expr_no_cache(e, eval_cache)
         ret.is_eval = True
         return ret
-        
+
 
 
     def eval_op_plus(self, args, op_size, cast_int):
@@ -325,7 +329,7 @@ class eval_abs:
     def eval_op_minus(self, args, op_size, cast_int):
         ret_value = args[0] - args[1]
         return ret_value
-    
+
     def eval_op_mult(self, args, op_size, cast_int):
         ret_value = (args[0] * args[1])
         return ret_value
@@ -348,8 +352,8 @@ class eval_abs:
         if c == 0:
             raise ValueError('div by 0')
         big = (a<<uint64(op_size))+b
-        ret_value =  big-c*(big/c) 
-        if ret_value>mymaxuint[op_size]:raise ValueError('Divide Error')        
+        ret_value =  big-c*(big/c)
+        if ret_value>mymaxuint[op_size]:raise ValueError('Divide Error')
         return ret_value
 
     def eval_op_idiv(self, args, op_size, cast_int):
@@ -453,7 +457,7 @@ class eval_abs:
         tmpa = uint64((args[0]<<1) | args[2])
         rez = (tmpa>>r)  | (tmpa << (op_size+uint64(1)-r))
         return rez
-    
+
     def eval_op_rotr_wflag_rez(self, args, op_size, cast_int):
         return self.eval_op_rotr_wflag(args, op_size, cast_int)>>1
     def eval_op_rotr_wflag_cf(self, args, op_size, cast_int):
@@ -463,7 +467,7 @@ class eval_abs:
         r = args[1]#&0x1F
         ret_value = ((args[0] &mymaxuint[op_size])<<r)
         return ret_value
-    
+
     def eval_op_rshift(self, args, op_size, cast_int):
         r = args[1]#&0x1F
         ret_value = ((args[0]&mymaxuint[op_size])>>r)
@@ -505,7 +509,7 @@ class eval_abs:
         return ExprOp("objbyid_default0", ExprInt(cast_int(args[0])))
 
 
-    
+
     deal_op = {'+':eval_op_plus,
                '-':eval_op_minus,
                '*':eval_op_mult,
@@ -536,9 +540,9 @@ class eval_abs:
                #XXX
                'objbyid_default0':objbyid_default0,
                }
-    
+
     op_size_no_check = ['<<<', '>>>', 'a<<', '>>', '<<',
-                        '<<<c_rez', '<<<c_cf', 
+                        '<<<c_rez', '<<<c_cf',
                         '>>>c_rez', '>>>c_cf',]
 
 
@@ -654,26 +658,26 @@ class eval_abs:
         for a in args:
             if isinstance(a, ExprTop):
                 return ExprTop()
-        
+
         for a in args:
             if not isinstance(a, ExprInt):
                 return ExprOp(e.op, *args)
-        
+
         args = [a.arg for a in args]
-        
+
         types_tab = [type(a) for a  in args]
         if types_tab.count(types_tab[0]) != len(args) and not e.op in self.op_size_no_check:
             raise ValueError('invalid cast %r %r'%(types_tab, args))
-        
+
         cast_int = types_tab[0]
         op_size = tab_int_size[types_tab[0]]
-        
+
 
         ret_value = self.deal_op[e.op](self, args, op_size, cast_int)
         if isinstance(ret_value, Expr):
             return ret_value
         return ExprInt(cast_int(ret_value))
-                   
+
     def eval_ExprCond(self, e, eval_cache = {}):
         cond = self.eval_expr(e.cond, eval_cache)
         src1 = self.eval_expr(e.src1, eval_cache)
@@ -681,23 +685,23 @@ class eval_abs:
 
         if isinstance(cond, ExprTop):
             return ExprCond(e.cond, src1, src2)
-        
+
         if isinstance(cond, ExprInt):
             if cond.arg == 0:
                 return src2
             else:
                 return src1
         return ExprCond(cond, src1, src2)
-       
+
     def eval_ExprSlice(self, e, eval_cache = {}):
         arg = expr_simp(self.eval_expr(e.arg, eval_cache))
         if isinstance(arg, ExprTop):
             return ExprTop()
-        
+
         if isinstance(arg, ExprMem):
             if e.start == 0 and e.stop == arg.size:
                 return arg
-                
+
             return ExprSlice(arg, e.start, e.stop)
         if isinstance(arg, ExprTop):
             return ExprTop()
@@ -709,7 +713,7 @@ class eval_abs:
             to_add = []
             return ExprSlice(arg, e.start, e.stop)
         return ExprSlice(arg, e.start, e.stop)
-            
+
     def eval_ExprCompose(self, e, eval_cache = {}):
         args = []
         for a in e.args:
@@ -731,8 +735,8 @@ class eval_abs:
                 is_int_cond+=3
                 continue
             is_int_cond+=1
-                
-        
+
+
         if not is_int and is_int_cond!=1:
             uu = ExprCompose([ExprSliceTo(a, e.args[i].start, e.args[i].stop) for i, a in enumerate(args)])
             return uu
@@ -740,7 +744,7 @@ class eval_abs:
         if not is_int:
             rez = 0L
             total_bit = 0
-            
+
             for i in xrange(len(e.args)):
                 if isinstance(args[i], ExprInt):
                     a = args[i].arg
@@ -756,19 +760,19 @@ class eval_abs:
                     total_bit+=e.args[i].stop-e.args[i].start
                     mycond, mysrc1, mysrc2 = a.cond, a.src1.arg&mask, a.src2.arg&mask
                     cond_i = i
-                    
+
             mysrc1|=rez
             mysrc2|=rez
-            
-            
-            
+
+
+
             if total_bit in tab_uintsize:
                 return self.eval_expr(ExprCond(mycond, ExprInt(tab_uintsize[total_bit](mysrc1)), ExprInt(tab_uintsize[total_bit](mysrc2))), eval_cache)
             else:
                 raise 'cannot return non rounb bytes rez! %X %X'%(total_bit, rez)
-                    
-                
-        
+
+
+
         rez = 0L
         total_bit = 0
         for i in xrange(len(e.args)):
@@ -782,10 +786,10 @@ class eval_abs:
             return ExprInt(tab_uintsize[total_bit](rez))
         else:
             raise 'cannot return non rounb bytes rez! %X %X'%(total_bit, rez)
-        
+
     def eval_ExprTop(self, e, eval_cache = {}):
         return e
-    
+
     def eval_expr_no_cache(self, e, eval_cache = {}):
         c = e.__class__
         deal_class = {ExprId: self.eval_ExprId,
@@ -801,13 +805,13 @@ class eval_abs:
 
     def get_instr_mod(self, exprs):
         pool_out = {}
-        
+
         eval_cache = {}
-        
+
         for e in exprs:
             if not isinstance(e, ExprAff):
                 raise TypeError('not affect', str(e))
-            
+
             src = self.eval_expr(e.src, eval_cache)
             if isinstance(e.dst, ExprMem):
                 a = self.eval_expr(e.dst.arg, eval_cache)
@@ -821,16 +825,16 @@ class eval_abs:
                     self.func_write(self, dst, src, pool_out)
                 else:
                     pool_out[dst] = src
-                
+
             elif isinstance(e.dst, ExprId):
                 pool_out[e.dst] = src
             elif isinstance(e.dst, ExprTop):
                 raise ValueError("affect in ExprTop")
             else:
                 raise ValueError("affected zarb", str(e.dst))
-                
 
-        return pool_out    
+
+        return pool_out
 
     def eval_instr(self, exprs):
         tmp_ops = self.get_instr_mod(exprs)
@@ -845,7 +849,7 @@ class eval_abs:
                     for xx, yy in diff_mem:
                         self.pool[xx] = yy
                 tmp = expr_simp(tmp_ops[op])
-                    
+
                 if isinstance(expr_simp(op.arg), ExprTop):
                     raise ValueError('xx')
                     continue
@@ -859,12 +863,12 @@ class eval_abs:
             if isinstance(op, ExprMem):
                 mem_dst.append(op)
 
-            
+
         return mem_dst
 
     def get_reg(self, r):
         return self.eval_expr(self.pool[r], {})
-        
+
 
 
 
diff --git a/miasm/expression/expression_helper.py b/miasm/expression/expression_helper.py
index 912250c3..324c1ca8 100644
--- a/miasm/expression/expression_helper.py
+++ b/miasm/expression/expression_helper.py
@@ -57,11 +57,11 @@ def merge_sliceto_slice(args):
     for a in args:
         if max_size == None or max_size < a.stop:
             max_size = a.stop
-        
-    
+
+
 
     #first simplify all num slices
-    
+
     final_sources = []
     sorted_s = []
     for x in sources_int.values():
@@ -79,7 +79,7 @@ def merge_sliceto_slice(args):
         while sorted_s:
             if sorted_s[-1][1].stop != start:
                 break
-            
+
             start = sorted_s[-1][1].start
 
             a = uint64((int(out.arg.arg) << (out.start - start )) + sorted_s[-1][1].arg.arg)
@@ -90,9 +90,9 @@ def merge_sliceto_slice(args):
         out_type = tab_size_int[max_size]
         out.arg.arg = out_type(out.arg.arg)
         final_sources.append((start, out))
-        
+
     final_sources_int = final_sources
-    
+
     #check if same sources have corresponding start/stop
     #is slice AND is sliceto
     simp_sources = []
@@ -110,27 +110,27 @@ def merge_sliceto_slice(args):
                     break
                 if sorted_s[-1][1].arg.stop != out.arg.start:
                     break
-                
+
                 start = sorted_s[-1][1].start
                 out.arg.start = sorted_s[-1][1].arg.start
                 sorted_s.pop()
             out.start = start
 
             final_sources.append((start, out))
-        
+
         simp_sources+=final_sources
 
     simp_sources+= final_sources_int
 
     for i, v in non_slice.items():
         simp_sources.append((i, v))
-    
+
     simp_sources.sort()
-    
+
     simp_sources = [x[1] for x in simp_sources]
     return simp_sources
-    
-    
+
+
 def expr_simp(e):
     if e.is_simp:
         return e
@@ -156,7 +156,7 @@ def expr_simp_w(e):
                 return expr_simp(e.src2)
             else:
                 return expr_simp(e.src1)
-                
+
         return ExprCond(expr_simp(e.cond), expr_simp(e.src1), expr_simp(e.src2))
     elif isinstance(e, ExprMem):
         if isinstance(e.arg, ExprTop):
@@ -209,12 +209,12 @@ def expr_simp_w(e):
             if isinstance(args[1], ExprInt) and args[1].arg == 0:
                 return args[1]
 
-        #A-(-123) =>A+123    
+        #A-(-123) =>A+123
         if op == '-' and isinstance(args[1], ExprInt) and int32(args[1].arg)<0 :
             op = '+'
             args[1] = ExprInt(-args[1].arg)
 
-        #A+(-123) =>A-123    
+        #A+(-123) =>A-123
         if op == '+' and isinstance(args[1], ExprInt) and int32(args[1].arg)<0 :
             op = '-'
             args[1] = ExprInt(-args[1].arg)
@@ -229,8 +229,8 @@ def expr_simp_w(e):
             else:
                 op = op2
                 args1 = args[0].args[1].arg - args[1].arg
-                    
-                
+
+
                 #if op == '-':
                 #    args1 = -args1
             args0 = args[0].args[0]
@@ -245,7 +245,7 @@ def expr_simp_w(e):
         #0 - (a-b) => b-a
         if op == '-' and isinstance(args[0], ExprInt) and args[0].arg == 0 and isinstance(args[1], ExprOp) and args[1].op == "-":
             return expr_simp(args[1].args[1] - args[1].args[0])
-            
+
         #a<<< x <<< y => a <<< (x+y) (ou <<< >>>)
         if op in ['<<<', '>>>'] and isinstance(args[1], ExprInt) and isinstance(args[0], ExprOp) and args[0].op in ['<<<', '>>>'] and isinstance(args[0].args[1], ExprInt):
             op1 = op
@@ -256,7 +256,7 @@ def expr_simp_w(e):
             else:
                 op = op2
                 args1 = args[0].args[1].arg - args[1].arg
-                    
+
             args0 = args[0].args[0]
             args = [args0, ExprInt(args1)]
 
@@ -270,10 +270,10 @@ def expr_simp_w(e):
         if op in ['<<<', '>>>'] and isinstance(args[0], ExprOp) and args[0].op in ['<<<', '>>>'] and args[1] == args[0].args[1]:
             oo = op, args[0].op
             if oo in [('<<<', '>>>'), ('>>>', '<<<')]:
-                
+
                 e = expr_simp(args[0].args[0])
                 return e
-                
+
 
         #( a + int1 ) - (b+int2) => a - (b+ (int1-int2))
         if op in ['+', '-'] and isinstance(args[0], ExprOp) and args[0].op in ['+', '-'] and isinstance(args[1], ExprOp) and args[1].op in ['+', '-'] and isinstance(args[0].args[1], ExprInt) and isinstance(args[1].args[1], ExprInt):
@@ -296,7 +296,7 @@ def expr_simp_w(e):
                               )
                        )
             e = expr_simp(e)
-            
+
             return e
 
         #(a - (a + XXX)) => 0-XXX
@@ -311,7 +311,7 @@ def expr_simp_w(e):
                        z,
                        args[1].args[1])
             e = expr_simp(e)
-            
+
             return e
 
 
@@ -324,7 +324,7 @@ def expr_simp_w(e):
                        z,
                        args[0].args[1])
             e = expr_simp(e)
-            
+
             return e
 
         #  ((a ^ b) ^ a) => b (or commut)
@@ -364,11 +364,11 @@ def expr_simp_w(e):
                 rest_a = args[0].args[0]
                 e = expr_simp(rest_a)
                 return e
-        
+
         # a<<< a.size => a
         if op in ['<<<', '>>>'] and isinstance(args[1], ExprInt) and args[1].arg == args[0].get_size():
             return expr_simp(args[0])
-        
+
         #!!a => a
         if op == '!' and isinstance(args[0], ExprOp) and args[0].op == '!':
             new_e = args[0].args[0]
@@ -393,11 +393,11 @@ def expr_simp_w(e):
                        ,
                        args[0].args[1])
             return expr_simp(e)
-        
-        
+
+
         if op == "&" and isinstance(args[0], ExprOp) and args[0].op == '!' and isinstance(args[1], ExprOp) and args[1].op == '!' and isinstance(args[0].args[0], ExprOp) and args[0].args[0].op == '&' and isinstance(args[1].args[0], ExprOp) and args[1].args[0].op == '&':
 
-            ##############1 
+            ##############1
             a1 = args[0].args[0].args[0]
             if isinstance(a1, ExprOp) and a1.op == '!':
                 a1 = a1.args[0]
@@ -413,7 +413,7 @@ def expr_simp_w(e):
                 b1 = ExprInt(~b1.arg)
             else:
                 b1 = None
-    
+
 
             a2 = args[1].args[0].args[0]
             b2 = args[1].args[0].args[1]
@@ -439,7 +439,7 @@ def expr_simp_w(e):
                 b1 = ExprInt(~b1.arg)
             else:
                 b1 = None
-    
+
 
             a2 = args[0].args[0].args[0]
             b2 = args[0].args[0].args[1]
@@ -448,7 +448,7 @@ def expr_simp_w(e):
             if a1 != None and b1 != None and a1 == a2 and b1 == b2:
                 new_e = ExprOp('^', a1, b1)
                 return expr_simp(new_e)
-        
+
 
         # (x & mask) >> shift whith mask < 2**shift => 0
         if op == ">>" and isinstance(args[1], ExprInt) and isinstance(args[0], ExprOp) and args[0].op == "&":
@@ -465,13 +465,13 @@ def expr_simp_w(e):
             new_e = ExprSlice(ExprOp('!', args[0].arg), args[0].start, args[0].stop)
             return expr_simp(new_e)
 
-        
+
         #! int
         if op == '!' and isinstance(args[0], ExprInt):
             a = args[0]
             e = ExprInt(tab_max_uint[a.get_size()]^a.arg)
             return e
-        
+
         #a^a=>0 | a-a =>0
         if op in ['^', '-'] and args[0] == args[1]:
             tmp =  ExprInt(tab_size_int[args[0].get_size()](0))
@@ -491,11 +491,11 @@ def expr_simp_w(e):
             if isinstance(args[0], ExprOp) and args[0].op == '|' and isinstance(args[0].args[1], ExprInt) and \
                args[0].args[1].arg != 0:
                 return ExprInt(tab_size_int[args[0].get_size()](0))
-                                     
+
 
         if op == 'parity' and isinstance(args[0], ExprInt):
             return ExprInt(tab_size_int[args[0].get_size()](parity(args[0].arg)))
-        
+
         new_e = ExprOp(op, *[expr_simp(x) for x in args])
         if new_e == e:
             return new_e
@@ -521,7 +521,7 @@ def expr_simp_w(e):
         elif isinstance(arg, ExprSlice):
             if e.stop-e.start > arg.stop-arg.start:
                 raise ValueError('slice in slice: getting more val', str(e))
-            
+
             new_e = ExprSlice(expr_simp(arg.arg), e.start + arg.start, e.start + arg.start + (e.stop - e.start))
             return expr_simp(new_e)
         elif isinstance(arg, ExprCompose):
@@ -545,7 +545,7 @@ def expr_simp_w(e):
 
 
 
-        
+
         return ExprSlice(arg, e.start, e.stop)
     elif isinstance(e, ExprSliceTo):
         if isinstance(e.arg, ExprTop):
@@ -561,7 +561,7 @@ def expr_simp_w(e):
                     return expr_simp(ExprSliceTo(ExprCompose([a]), e.start, e.stop))
 
 
-            
+
         return ExprSliceTo(expr_simp(e.arg), e.start, e.stop)
     elif isinstance(e, ExprCompose):
         #(.., a_to[x:y], a[:]_to[y:z], ..) => (.., a[x:z], ..)
@@ -591,9 +591,9 @@ def expr_simp_w(e):
 
         if simp:
             return expr_simp(ExprCompose(args))
-            
-            
-        
+
+
+
         all_top = True
         for a in e.args:
             if not isinstance(a, ExprTop):
@@ -605,7 +605,7 @@ def expr_simp_w(e):
         if ExprTop() in e.args:
             return ExprTop()
         """
-        
+
         args = merge_sliceto_slice(e.args)
         if len(args) == 1:
             a = args[0]
@@ -614,14 +614,14 @@ def expr_simp_w(e):
                     print a, a.arg.get_size(), a.stop
                     raise ValueError("cast in compose!", e)
                 return a.arg
-            
+
             uu = expr_simp(a.arg)
             return uu
         if len(args) != len(e.args):
             return expr_simp(ExprCompose(args))
         else:
             return ExprCompose(args)
-    
+
     else:
         raise 'bad expr'
 
@@ -653,6 +653,6 @@ def expr_replace(e, repl):
         return ExprCompose([expr_replace(x, repl) for x in e.args])
     else:
         raise 'bad expr'
-    
-    
-    
+
+
+