open Ast
module StringMap = Map.Make (String)

(** Raised for every type-checking failure; the payload is the
    human-readable message. *)
exception Type_error of string

(** Signature of a function: parameter types in declaration order and the
    declared return type. *)
type func_sig = {
  fparams : typ list;
  fret : typ;
}

(** Program-wide typing context produced by [collect_ctx]. *)
type tc_ctx = {
  structs : (typ StringMap.t) StringMap.t;
      (** struct name -> (field name -> field type) *)
  funcs : func_sig StringMap.t;  (** function name -> signature *)
  globals : (typ * bool) StringMap.t;
      (** global name -> (type, is-mutable flag) *)
}

(** [fail_type msg] aborts checking.
    @raise Type_error always, carrying [msg]. *)
let fail_type msg = raise (Type_error msg)

(** [expect_type got want] succeeds silently when both types are equal
    according to [equal_typ].
    @raise Type_error when they differ. *)
let expect_type got want =
  if not (equal_typ got want) then
    fail_type
      (Printf.sprintf "type mismatch: got %s, expected %s"
         (string_of_typ got) (string_of_typ want))

(** [validate_type structs allow_void t] checks that [t] is well formed:
    every struct it names is a key of [structs], and [TVoid] only appears
    at the top level when [allow_void] is set. Nested component types are
    always validated with void disallowed.
    @raise Type_error on the first violation. *)
let rec validate_type (structs : (typ StringMap.t) StringMap.t)
    (allow_void : bool) = function
  | TVoid when allow_void -> ()
  | TVoid -> fail_type "void is not a valid variable type"
  | TStruct n ->
    if not (StringMap.mem n structs) then
      fail_type ("unknown struct type: " ^ n)
  | TArray t ->
    (* Checked here so that void elements get an array-specific message. *)
    if equal_typ t TVoid then fail_type "array element type cannot be void";
    validate_type structs false t
  | TInt | TBool -> ()
  | TComptime t -> validate_type structs false t
  | TCombo (t1, t2) ->
    validate_type structs false t1;
    validate_type structs false t2

(** [lookup_var env x] returns the type bound to [x] in [env].
    @raise Type_error if [x] is unbound. *)
let lookup_var env x =
  match StringMap.find_opt x env with
  | Some (t, _) -> t
  | None -> fail_type ("unknown variable: " ^ x)

(** [is_var_mutable env x] returns the mutability flag bound to [x] in
    [env]. An unbound [x] yields [false]; callers that need to tell
    unbound from immutable should call [lookup_var] first. *)
let is_var_mutable env x =
  match StringMap.find_opt x env with
  | Some (_, m) -> m
  | None -> false

(** [lookup_struct_field ctx sname fname] returns the type of field
    [fname] of struct [sname].
    @raise Type_error if the struct or the field is unknown. *)
let lookup_struct_field ctx sname fname =
  match StringMap.find_opt sname ctx.structs with
  | None -> fail_type ("unknown struct: " ^ sname)
  | Some fields ->
    (match StringMap.find_opt fname fields with
     | None ->
       fail_type ("unknown field " ^ fname ^ " on struct " ^ sname)
     | Some t -> t)

(** [infer_expr ctx env e] computes the type of expression [e], where
    [env] maps variable names to [(type, is-mutable)] pairs.
    @raise Type_error on any ill-typed sub-expression. *)
let rec infer_expr ctx env = function
  | IntLit _ -> TInt
  | BoolLit _ -> TBool
  | Var x -> lookup_var env x
  | Unop (Neg, e) ->
    expect_type (infer_expr ctx env e) TInt;
    TInt
  | Unop (Not, e) ->
    expect_type (infer_expr ctx env e) TBool;
    TBool
  | Binop (op, a, b) ->
    (* Both operands are inferred before the operator is inspected. *)
    let ta = infer_expr ctx env a in
    let tb = infer_expr ctx env b in
    (match op with
     | Add | Sub | Mul | Div | Mod ->
       expect_type ta TInt;
       expect_type tb TInt;
       TInt
     | And | Or ->
       expect_type ta TBool;
       expect_type tb TBool;
       TBool
     | Lt | Le | Gt | Ge ->
       expect_type ta TInt;
       expect_type tb TInt;
       TBool
     | Eq | Ne ->
       if not (equal_typ ta tb) then
         fail_type "equality operands must have same type";
       TBool)
  | Assign (lhs, rhs) ->
    (match lhs with
     | Var x ->
       (* Resolve the variable first: [is_var_mutable] answers [false]
          for unbound names, which would otherwise make an undeclared
          target look like an immutable one. *)
       let tl = lookup_var env x in
       if not (is_var_mutable env x) then
         fail_type ("cannot assign to immutable variable: " ^ x);
       let tr = infer_expr ctx env rhs in
       expect_type tr tl;
       tl
     | ArrayGet _ | StructGet _ ->
       (* Element and field targets are only type-checked; no mutability
          check is made on the variable they are rooted in. *)
       let tl = infer_expr ctx env lhs in
       let tr = infer_expr ctx env rhs in
       expect_type tr tl;
       tl
     | _ -> fail_type "left side of assignment is not assignable")
  | ArrayGet (arr, idx) ->
    expect_type (infer_expr ctx env idx) TInt;
    (match infer_expr ctx env arr with
     | TArray t -> t
     | t -> fail_type ("indexing requires array, got " ^ string_of_typ t))
  | StructGet (obj, fld) ->
    (match infer_expr ctx env obj with
     | TStruct sname -> lookup_struct_field ctx sname fld
     | t ->
       fail_type ("field access requires struct, got " ^ string_of_typ t))
  | Call (callee, args) ->
    let fname =
      match callee with
      | Var n -> n
      | _ -> fail_type "only direct function calls are supported"
    in
    let sig_ =
      match StringMap.find_opt fname ctx.funcs with
      | Some s -> s
      | None -> fail_type ("unknown function: " ^ fname)
    in
    (* The arity check also guards [List.iter2] against
       [Invalid_argument]. *)
    if List.length args <> List.length sig_.fparams then
      fail_type
        (Printf.sprintf "function %s expects %d arguments, got %d" fname
           (List.length sig_.fparams) (List.length args));
    List.iter2
      (fun arg pty -> expect_type (infer_expr ctx env arg) pty)
      args sig_.fparams;
    sig_.fret

(** [collect_ctx prog] walks the top-level declarations in order and
    builds the typing context. Duplicate structs, struct fields,
    functions and globals are rejected. A global without a type
    annotation takes the type of its initializer, inferred against only
    the declarations that precede it.
    @raise Type_error on duplicates or on a global that has neither an
    annotation nor an initializer. *)
let collect_ctx (prog : program) : tc_ctx =
  let rec collect tops structs funcs globals =
    match tops with
    | [] -> { structs; funcs; globals }
    | TopStruct s :: tl ->
      if StringMap.mem s.sname structs then
        fail_type ("duplicate struct: " ^ s.sname);
      let fields =
        List.fold_left
          (fun acc (t, n) ->
            if StringMap.mem n acc then
              fail_type
                ("duplicate field " ^ n ^ " in struct " ^ s.sname);
            StringMap.add n t acc)
          StringMap.empty s.fields
      in
      collect tl (StringMap.add s.sname fields structs) funcs globals
    | TopFunc f :: tl ->
      if StringMap.mem f.name funcs then
        fail_type ("duplicate function: " ^ f.name);
      let sig_ = { fparams = List.map snd f.params; fret = f.ret } in
      collect tl structs (StringMap.add f.name sig_ funcs) globals
    | TopGlobalVar (is_mut, n, t_annot, init_e) :: tl ->
      if StringMap.mem n globals then
        fail_type ("duplicate global variable: " ^ n);
      (* Only what has been collected so far is visible to the
         initializer during inference. *)
      let partial_ctx = { structs; funcs; globals } in
      let t =
        match t_annot with
        | Some t -> t
        | None ->
          (match init_e with
           | Some e -> infer_expr partial_ctx globals e
           | None ->
             fail_type
               ("global variable '" ^ n
              ^ "' requires a type annotation or initializer"))
      in
      collect tl structs funcs (StringMap.add n (t, is_mut) globals)
  in
  collect prog StringMap.empty StringMap.empty StringMap.empty

(** [check_stmt ctx ret env stmt] type-checks one statement inside a
    function whose return type is [ret] and returns the environment that
    the following statement should see. Only [VarDecl] extends the
    environment; bindings made inside nested blocks do not escape.
    @raise Type_error on any violation. *)
let rec check_stmt ctx ret env = function
  | VarDecl (is_mut, n, t_annot, init) ->
    let t =
      match (t_annot, init) with
      | (Some t, Some e) ->
        validate_type ctx.structs false t;
        expect_type (infer_expr ctx env e) t;
        t
      | (Some t, None) ->
        validate_type ctx.structs false t;
        t
      | (None, Some e) ->
        let t = infer_expr ctx env e in
        (* Rejects, for example, binding the result of a void call. *)
        validate_type ctx.structs false t;
        t
      | (None, None) ->
        fail_type
          ("cannot determine type of '" ^ n
         ^ "': no type annotation and no initializer")
    in
    StringMap.add n (t, is_mut) env
  | Expr e ->
    ignore (infer_expr ctx env e);
    env
  | If (cond, tbranch, ebranch) ->
    expect_type (infer_expr ctx env cond) TBool;
    ignore (check_block ctx ret env tbranch);
    ignore (check_block ctx ret env ebranch);
    env
  | ForEach (it_t, it_name, iterable, body) ->
    validate_type ctx.structs false it_t;
    (match infer_expr ctx env iterable with
     | TArray elem_t -> expect_type elem_t it_t
     | t ->
       fail_type
         ("foreach expects array iterable, got " ^ string_of_typ t));
    (* The iteration variable is bound immutably for the body only. *)
    let env' = StringMap.add it_name (it_t, false) env in
    ignore (check_block ctx ret env' body);
    env
  | Return None ->
    expect_type TVoid ret;
    env
  | Return (Some e) ->
    expect_type (infer_expr ctx env e) ret;
    env
  | Block stmts ->
    ignore (check_block ctx ret env stmts);
    env

(** [check_block ctx ret env stmts] checks [stmts] in order, threading
    the environment from one statement to the next, and returns the final
    environment. *)
and check_block ctx ret env stmts =
  List.fold_left (check_stmt ctx ret) env stmts

(** [has_return_stmt stmt] is [true] when a [Return] occurs anywhere
    inside [stmt]. This is a syntactic presence test: it does not
    establish that every control-flow path returns. *)
let rec has_return_stmt = function
  | Return _ -> true
  | If (_, t, e) ->
    List.exists has_return_stmt t || List.exists has_return_stmt e
  | ForEach (_, _, _, body) | Block body -> List.exists has_return_stmt body
  | VarDecl _ | Expr _ -> false

(** [check_program prog] type-checks the whole program: global types,
    function signatures, struct fields, global initializers and function
    bodies. A non-void function must contain at least one [Return]
    (see [has_return_stmt]).
    @raise Type_error on the first error found. *)
let check_program (prog : program) =
  let ctx = collect_ctx prog in
  StringMap.iter
    (fun _ (t, _) -> validate_type ctx.structs false t)
    ctx.globals;
  StringMap.iter
    (fun _ sig_ ->
      List.iter (validate_type ctx.structs false) sig_.fparams;
      validate_type ctx.structs true sig_.fret)
    ctx.funcs;
  List.iter
    (function
      | TopStruct s ->
        List.iter (fun (t, _) -> validate_type ctx.structs false t) s.fields
      | TopGlobalVar (_, n, t_annot, init) ->
        let t =
          match t_annot with
          | Some t -> t
          | None -> fst (StringMap.find n ctx.globals)
        in
        validate_type ctx.structs false t;
        (match init with
         | None -> ()
         | Some e -> expect_type (infer_expr ctx ctx.globals e) t)
      | TopFunc f ->
        (* Parameter names are tracked in [seen], separately from the
           environment, so that a parameter may shadow a global (as local
           declarations and the annotation pass already allow) while true
           duplicates within one parameter list are still rejected. *)
        let _seen, env =
          List.fold_left
            (fun (seen, acc) (n, t) ->
              validate_type ctx.structs false t;
              if StringMap.mem n seen then
                fail_type ("duplicate parameter name: " ^ n);
              (StringMap.add n () seen, StringMap.add n (t, false) acc))
            (StringMap.empty, ctx.globals)
            f.params
        in
        ignore (check_block ctx f.ret env f.body);
        if
          (not (equal_typ f.ret TVoid))
          && not (List.exists has_return_stmt f.body)
        then
          fail_type
            ("non-void function " ^ f.name ^ " must return a value"))
    prog

(* Annotation pass: fills in inferred types so generators always get
   fully-typed ASTs. *)

(** [annotate_stmts ctx env stmts] annotates [stmts] in order and returns
    the final environment together with the rewritten statements. *)
let rec annotate_stmts ctx env stmts =
  List.fold_left_map (annotate_stmt ctx) env stmts

(** [annotate_stmt ctx env stmt] returns the environment for the next
    statement and a copy of [stmt] in which every [VarDecl] (including
    those in nested blocks) carries an explicit type. Expressions are
    left untouched. *)
and annotate_stmt ctx env stmt =
  match stmt with
  | VarDecl (is_mut, n, t_annot, init) ->
    let t =
      match (t_annot, init) with
      | (Some t, _) -> t
      | (None, Some e) -> infer_expr ctx env e
      | (None, None) ->
        TVoid (* unreachable after successful check_program *)
    in
    let env' = StringMap.add n (t, is_mut) env in
    (env', VarDecl (is_mut, n, Some t, init))
  | If (cond, tbranch, ebranch) ->
    let (_, tbranch') = annotate_stmts ctx env tbranch in
    let (_, ebranch') = annotate_stmts ctx env ebranch in
    (env, If (cond, tbranch', ebranch'))
  | ForEach (it_t, it_name, iterable, body) ->
    let env' = StringMap.add it_name (it_t, false) env in
    let (_, body') = annotate_stmts ctx env' body in
    (env, ForEach (it_t, it_name, iterable, body'))
  | Block stmts ->
    let (_, stmts') = annotate_stmts ctx env stmts in
    (env, Block stmts')
  | other -> (env, other)

(** [annotate_top ctx top] annotates one top-level declaration: function
    bodies are rewritten with [annotate_stmts], global variables receive
    an explicit type, and anything else is returned unchanged. *)
let annotate_top ctx = function
  | TopFunc f ->
    let env =
      List.fold_left
        (fun acc (n, t) -> StringMap.add n (t, false) acc)
        ctx.globals f.params
    in
    let (_, body') = annotate_stmts ctx env f.body in
    TopFunc { f with body = body' }
  | TopGlobalVar (is_mut, n, t_annot, init) ->
    let t =
      match t_annot with
      | Some t -> t
      | None ->
        (match init with
         | Some e -> infer_expr ctx ctx.globals e
         | None -> TVoid)
    in
    TopGlobalVar (is_mut, n, Some t, init)
  | other -> other

(** [annotate_program prog] rebuilds the context with [collect_ctx] and
    annotates every top-level declaration of [prog]. *)
let annotate_program prog =
  let ctx = collect_ctx prog in
  List.map (annotate_top ctx) prog

(** [type_check prog] checks [prog] and, on success, returns
    [Ok annotated]; a [Type_error] raised by either pass is turned into
    [Error msg]. Other exceptions are not caught. *)
let type_check (prog : program) : (program, string) result =
  try
    check_program prog;
    Ok (annotate_program prog)
  with Type_error msg -> Error msg