(** Interpreter for the RTL intermediate language: it executes an
    [RTL.program] one small step at a time and yields the program's integer
    result together with the cost labels emitted along the way. *)

let error_prefix = "RTL interpret"

(* Report an interpreter error through [Error.global_error].  Kept
   eta-expanded so that its result type stays polymorphic (it is used where
   a [Val.t] and where a [Label.t] is expected). *)
let error s = Error.global_error error_prefix s

module Mem = Driver.RTLMemory
module Val = Mem.Value

(* Chunk size handed to [Mem.load] and [Mem.store]: the memory's
   [int_size]. *)
let chunk = Mem.int_size

module Eval = I8051.Eval (Val)

type memory = RTL.function_def Mem.memory

(* A local environment maps the registers of the function being executed to
   their current values. *)
type local_env = Val.t Register.Map.t

(* A call frame records what is needed to resume a caller after its callee
   returns: the registers that receive the results, the caller's graph, the
   label to resume at, the caller's stack pointer, its local environment and
   its carry. *)
type stack_frame =
  { ret_regs : Register.t list ;
    graph : RTL.graph ;
    pc : Label.t ;
    sp : Val.address ;
    lenv : local_env ;
    carry : Val.t }

(* Execution states.  Every state carries the call stack, the memory and the
   trace of cost labels emitted so far (most recent first):
   - [State]: executing the body of a function;
   - [CallState]: about to enter a function with the given arguments;
   - [ReturnState]: leaving a function with the given return values. *)
type state =
  | State of
      stack_frame list
      * RTL.graph
      * Label.t
      * Val.address (* sp *)
      * local_env
      * Val.t (* carry *)
      * memory
      * CostLabel.t list
  | CallState of
      stack_frame list
      * RTL.function_def
      * Val.t list (* args *)
      * memory
      * CostLabel.t list
  | ReturnState of
      stack_frame list
      * Val.t list (* return values *)
      * memory
      * CostLabel.t list

(* Render the registers of [lenv] whose value is not [Val.undef] as a
   sequence of "r = v " pieces, in the order [Register.Map.fold] visits
   them. *)
let string_of_local_env lenv =
  let collect reg value pieces =
    if Val.eq value Val.undef then pieces
    else (Register.print reg ^ " = " ^ Val.to_string value ^ " ") :: pieces
  in
  String.concat "" (List.rev (Register.Map.fold collect lenv []))

(* Render a list of values, each one preceded by a single space. *)
let string_of_args args =
  String.concat "" (List.map (fun v -> " " ^ Val.to_string v) args)

(* Dump a state on standard output (debugging aid). *)
let print_state st =
  match st with
  | State (_, _, lbl, sp, lenv, carry, mem, _) ->
    let sp_text = Val.string_of_address sp in
    let carry_text = Val.to_string carry in
    let lenv_text = string_of_local_env lenv in
    let mem_text = Mem.to_string mem in
    Printf.printf
      "Stack pointer: %s\n\nCarry: %s\n\nLocal environment:\n%s\n\nMemory:%s\nRegular state: %s\n\n%!"
      sp_text carry_text lenv_text mem_text lbl
  | CallState (_, _, args, mem, _) ->
    let mem_text = Mem.to_string mem in
    let args_text = string_of_args args in
    Printf.printf "Memory:%s\nCall state: %s\n\n%!" mem_text args_text
  | ReturnState (_, vs, mem, _) ->
    let mem_text = Mem.to_string mem in
    let vals_text = string_of_args vs in
    Printf.printf "Memory:%s\nReturn state: %s\n\n%!" mem_text vals_text

(* Definition of the function named [f], found through its global
   address. *)
let find_function mem f = Mem.find_fun_def mem (Mem.find_global mem f)

(* Current value of register [r]; a register absent from [lenv] is reported
   through [error]. *)
let get_local_value (lenv : local_env) (r : Register.t) : Val.t =
  if not (Register.Map.mem r lenv) then
    error ("Unknown local register \"" ^ Register.print r ^ "\".")
  else Register.Map.find r lenv

(* Values of the registers [args], in order. *)
let get_arg_values lenv args =
  List.map (fun r -> get_local_value lenv r) args

(* The two-register value (such as an address) held by [f1] and [f2]. *)
let get_local_addr lenv f1 f2 = get_arg_values lenv [f1 ; f2]

(* Bind the registers [rs] to the values [vs] pairwise, left to right.
   Raises [Invalid_argument] (from [List.fold_left2]) when the lists have
   different lengths. *)
let adds rs vs lenv =
  List.fold_left2 (fun env r v -> Register.Map.add r v env) lenv rs vs

(* Store [vs] into the destination registers [destrs] and continue at
   [lbl]. *)
let assign_state sfrs graph lbl sp lenv carry mem trace destrs vs =
  State (sfrs, graph, lbl, sp, adds destrs vs lenv, carry, mem, trace)

(* Continue at [lbl_true] or [lbl_false] depending on [v]; a value that is
   neither true nor false is reported through [error]. *)
let branch_state sfrs graph lbl_true lbl_false sp lenv carry mem trace v =
  let target =
    match Val.is_true v with
    | true -> lbl_true
    | false ->
      if Val.is_false v then lbl_false
      else error "Undefined conditional value."
  in
  State (sfrs, graph, target, sp, lenv, carry, mem, trace)

(* Interpret statements.
*)

(* Push a frame that resumes the caller at [lbl] (restoring [graph], [sp],
   [lenv] and [carry], and storing the callee's results into [ret_regs]),
   then enter [f_def] with the already evaluated [args]. *)
let enter_call sfrs graph sp lenv carry mem trace ret_regs lbl f_def args =
  let sf =
    { ret_regs = ret_regs ; graph = graph ; pc = lbl ; sp = sp ;
      lenv = lenv ; carry = carry } in
  CallState (sf :: sfrs, f_def, args, mem, trace)

(* Tail call: free the current function's stack area [sp] and enter [f_def]
   with the already evaluated [args] without pushing a frame, so the callee
   returns directly to the current function's caller. *)
let enter_tailcall sfrs sp mem trace f_def args =
  let mem = Mem.free mem sp in
  CallState (sfrs, f_def, args, mem, trace)

(* Execute the statement [stmt] of the current function and return the next
   state.  [sfrs] is the call stack, [graph] the current function's graph,
   [sp] its stack pointer, [lenv] its registers, [carry] the current carry,
   [mem] the memory and [trace] the cost labels emitted so far (most recent
   first). *)
let interpret_statement
    (sfrs : stack_frame list)
    (graph : RTL.graph)
    (sp : Val.address)
    (lenv : local_env)
    (carry : Val.t)
    (mem : memory)
    (stmt : RTL.statement)
    (trace : CostLabel.t list) :
    state =
  match stmt with
  | RTL.St_skip lbl ->
    State (sfrs, graph, lbl, sp, lenv, carry, mem, trace)
  | RTL.St_cost (cost_lbl, lbl) ->
    (* Record the label; the trace is reversed once execution ends. *)
    State (sfrs, graph, lbl, sp, lenv, carry, mem, cost_lbl :: trace)
  | RTL.St_addr (r1, r2, x, lbl) ->
    (* Addresses are held in a pair of registers. *)
    assign_state sfrs graph lbl sp lenv carry mem trace [r1 ; r2]
      (Mem.find_global mem x)
  | RTL.St_stackaddr (r1, r2, lbl) ->
    assign_state sfrs graph lbl sp lenv carry mem trace [r1 ; r2] sp
  | RTL.St_int (r, i, lbl) ->
    assign_state sfrs graph lbl sp lenv carry mem trace [r] [Val.of_int i]
  | RTL.St_move (destr, srcr, lbl) ->
    assign_state sfrs graph lbl sp lenv carry mem trace [destr]
      [get_local_value lenv srcr]
  | RTL.St_opaccs (opaccs, destr1, destr2, srcr1, srcr2, lbl) ->
    let (v1, v2) =
      Eval.opaccs opaccs
        (get_local_value lenv srcr1) (get_local_value lenv srcr2) in
    assign_state sfrs graph lbl sp lenv carry mem trace
      [destr1 ; destr2] [v1 ; v2]
  | RTL.St_op1 (op1, destr, srcr, lbl) ->
    let v = Eval.op1 op1 (get_local_value lenv srcr) in
    assign_state sfrs graph lbl sp lenv carry mem trace [destr] [v]
  | RTL.St_op2 (op2, destr, srcr1, srcr2, lbl) ->
    (* Binary operations consume the carry and produce a new one. *)
    let (v, carry) =
      Eval.op2 carry op2
        (get_local_value lenv srcr1) (get_local_value lenv srcr2) in
    assign_state sfrs graph lbl sp lenv carry mem trace [destr] [v]
  | RTL.St_clear_carry lbl ->
    State (sfrs, graph, lbl, sp, lenv, Val.zero, mem, trace)
  | RTL.St_set_carry lbl ->
    State (sfrs, graph, lbl, sp, lenv, Val.of_int 1, mem, trace)
  | RTL.St_load (destr, addr1, addr2, lbl) ->
    let addr = get_local_addr lenv addr1 addr2 in
    let v = Mem.load mem chunk addr in
    assign_state sfrs graph lbl sp lenv carry mem trace [destr] [v]
  | RTL.St_store (addr1, addr2, srcr, lbl) ->
    let addr = get_local_addr lenv addr1 addr2 in
    let mem = Mem.store mem chunk addr (get_local_value lenv srcr) in
    State (sfrs, graph, lbl, sp, lenv, carry, mem, trace)
  | RTL.St_call_id (f, args, ret_regs, lbl) ->
    let f_def = find_function mem f in
    let args = get_arg_values lenv args in
    enter_call sfrs graph sp lenv carry mem trace ret_regs lbl f_def args
  | RTL.St_call_ptr (f1, f2, args, ret_regs, lbl) ->
    let addr = get_local_addr lenv f1 f2 in
    let f_def = Mem.find_fun_def mem addr in
    let args = get_arg_values lenv args in
    enter_call sfrs graph sp lenv carry mem trace ret_regs lbl f_def args
  | RTL.St_tailcall_id (f, args) ->
    let f_def = find_function mem f in
    let args = get_arg_values lenv args in
    enter_tailcall sfrs sp mem trace f_def args
  | RTL.St_tailcall_ptr (f1, f2, args) ->
    let addr = get_local_addr lenv f1 f2 in
    let f_def = Mem.find_fun_def mem addr in
    let args = get_arg_values lenv args in
    enter_tailcall sfrs sp mem trace f_def args
  | RTL.St_cond (srcr, lbl_true, lbl_false) ->
    let v = get_local_value lenv srcr in
    branch_state sfrs graph lbl_true lbl_false sp lenv carry mem trace v
  | RTL.St_return rl ->
    let vl = get_arg_values lenv rl in
    (* Release the stack area allocated on entry. *)
    let mem = Mem.free mem sp in
    ReturnState (sfrs, vl, mem, trace)

module InterpretExternal = Primitive.Interpret (Mem)

(* Run the external (primitive) function [f] on [args].  Both kinds of
   primitive result ([V] values and [A] addresses) are returned with the
   same representation, alongside the updated memory. *)
let interpret_external mem f args =
  match InterpretExternal.t mem f args with
  | (mem', InterpretExternal.V vs) -> (mem', vs)
  | (mem', InterpretExternal.A addr) -> (mem', addr)

(* Initial register environment of a function: every register of [locals]
   starts undefined, then each parameter of [params] is bound to the value
   at the same position in [args].  A count mismatch between [params] and
   [args] is reported through [error] rather than escaping as
   [Invalid_argument] from [List.fold_left2]. *)
let init_locals
    (locals : Register.Set.t)
    (params : Register.t list)
    (args : Val.t list) :
    local_env =
  if List.length params <> List.length args then
    error "Wrong number of arguments in function call."
  else
    let undefine r lenv = Register.Map.add r Val.undef lenv in
    let lenv = Register.Set.fold undefine locals Register.Map.empty in
    adds params args lenv

(* Enter [f_def] with [args].  An internal function gets a fresh stack area
   of [f_stacksize], its initial registers and an undefined carry, and
   starts at its entry label.  An external function is run at once and
   control goes straight to the return state. *)
let state_after_call
    (sfrs : stack_frame list)
    (f_def : RTL.function_def)
    (args : Val.t list)
    (mem : memory)
    (trace : CostLabel.t list) :
    state =
  match f_def with
  | RTL.F_int def ->
    let (mem', sp) = Mem.alloc mem def.RTL.f_stacksize in
    let lenv = init_locals def.RTL.f_locals def.RTL.f_params args in
    State (sfrs, def.RTL.f_graph, def.RTL.f_entry, sp, lenv, Val.undef,
           mem', trace)
  | RTL.F_ext def ->
    let (mem', vs) = interpret_external mem def.AST.ef_tag args in
    ReturnState (sfrs, vs, mem', trace)

(* Resume the caller saved in [sf] after its callee returned [ret_vals].
   The i-th return register receives the i-th value (surplus values are
   ignored), and execution continues at the saved label with the saved
   graph, stack pointer and carry.  Fewer values than return registers is
   reported through [error]. *)
let state_after_return
    (sf : stack_frame)
    (sfrs : stack_frame list)
    (ret_vals : Val.t list)
    (mem : memory)
    (trace : CostLabel.t list) :
    state =
  (* Walk registers and values together: linear, unlike indexing the value
     list with [List.nth] for every register. *)
  let rec bind lenv regs vals =
    match regs, vals with
    | [], _ -> lenv
    | r :: regs, v :: vals -> bind (Register.Map.add r v lenv) regs vals
    | _ :: _, [] -> error "Not enough return values."
  in
  let lenv = bind sf.lenv sf.ret_regs ret_vals in
  State (sfrs, sf.graph, sf.pc, sf.sp, lenv, sf.carry, mem, trace)

(* Perform one execution step.  The final state (a return with an empty
   call stack) is never stepped: [iter_small_step] stops on it. *)
let small_step (st : state) : state =
  match st with
  | State (sfrs, graph, pc, sp, lenv, carry, mem, trace) ->
    let stmt =
      try Label.Map.find pc graph
      with Not_found -> error ("Unknown label \"" ^ pc ^ "\".") in
    interpret_statement sfrs graph sp lenv carry mem stmt trace
  | CallState (sfrs, f_def, args, mem, trace) ->
    state_after_call sfrs f_def args mem trace
  | ReturnState ([], _, _, _) ->
    (* End of execution; handled in iter_small_step. *)
    assert false
  | ReturnState (sf :: sfrs, ret_vals, mem, trace) ->
    state_after_return sf sfrs ret_vals mem trace

(* Integer result of the program computed from its return values: when
   there is at least one value and all of them are integers, their [Int32]
   casts are merged; otherwise the result is zero. *)
let compute_result vs =
  match vs with
  | _ :: _ when List.for_all Val.is_int vs ->
    let chunks =
      List.map (fun v -> IntValue.Int32.cast (Val.to_int_repr v)) vs in
    IntValue.Int32.merge chunks
  | _ -> IntValue.Int32.zero

(* Step from [st] until the outermost function returns, then give back the
   program's result and its cost labels in emission order.  When [debug] is
   set, each state is printed before it is stepped (the final return state
   is not), and so is the result. *)
let rec iter_small_step debug st =
  let print_and_return_result (res, cost_labels) =
    if debug then
      Printf.printf "Result = %s\n%!" (IntValue.Int32.to_string res) ;
    (res, cost_labels)
  in
  if debug then print_state st ;
  match small_step st with
  | ReturnState ([], vs, _, trace) ->
    print_and_return_result (compute_result vs, List.rev trace)
  | st' -> iter_small_step debug st'

(* Declare every global variable of [vars], given as (identifier, size)
   pairs, in [mem]. *)
let add_global_vars mem vars =
  List.fold_left
    (fun mem (id, size) -> Mem.add_var mem id (AST.SQ (AST.QInt size)) None)
    mem vars

(* Load every (name, definition) pair of [defs] into [mem]. *)
let add_fun_defs mem defs =
  List.fold_left
    (fun mem (f_id, f_def) -> Mem.add_fun_def mem f_id f_def)
    mem defs

(* The memory is initialized by loading the code into it, and by reserving
   space for the global variables. *)
let init_mem (p : RTL.program) : memory =
  add_global_vars (add_fun_defs Mem.empty p.RTL.functs) p.RTL.vars

(* Interpret [p] starting from its [main] function (called with no
   arguments) and return the integer result with the cost-label trace in
   emission order.  A program without [main] yields zero and an empty
   trace.  [debug] enables state dumps. *)
let interpret debug p =
  Printf.printf "*** RTL interpret ***\n%!" ;
  match p.RTL.main with
  | None -> (IntValue.Int32.zero, [])
  | Some main ->
    let mem = init_mem p in
    let main_def = find_function mem main in
    let st = CallState ([], main_def, [], mem, []) in
    iter_small_step debug st