(* Front end for a small C-like language: a hand-written lexer, a
   recursive-descent parser producing the AST below, and a type checker.
   Entry point: [parse_and_type_check]. *)

module StringMap = Map.Make (String)

(** Raised by the lexer and parser on malformed input. *)
exception Parse_error of string

(** Raised by the type checker on ill-typed programs. *)
exception Type_error of string

(* ------------------------------------------------------------------ *)
(* AST                                                                  *)
(* ------------------------------------------------------------------ *)

type typ =
  | TInt
  | TBool
  | TVoid
  | TStruct of string
  | TArray of typ

type binop =
  | Add
  | Sub
  | Mul
  | Div
  | Mod
  | And
  | Or
  | Eq
  | Ne
  | Lt
  | Le
  | Gt
  | Ge

type unop = Neg | Not

type expr =
  | IntLit of int
  | BoolLit of bool
  | Var of string
  | Binop of binop * expr * expr
  | Unop of unop * expr
  | Assign of expr * expr
  | Call of expr * expr list
  | ArrayGet of expr * expr
  | StructGet of expr * string

type stmt =
  | VarDecl of typ * string * expr option
  | Expr of expr
  | If of expr * stmt list * stmt list
  | ForEach of typ * string * expr * stmt list
  | Return of expr option
  | Block of stmt list

type func = {
  name : string;
  params : (typ * string) list;
  ret : typ;
  body : stmt list;
}

type struct_def = {
  sname : string;
  fields : (typ * string) list;
}

type top =
  | TopStruct of struct_def
  | TopFunc of func
  | TopGlobalVar of typ * string * expr option

type program = top list

(** Source-like rendering of a type, used in error messages
    (e.g. ["struct P[]"]). *)
let rec string_of_typ = function
  | TInt -> "int"
  | TBool -> "bool"
  | TVoid -> "void"
  | TStruct n -> "struct " ^ n
  | TArray t -> string_of_typ t ^ "[]"

(** Structural equality on types. *)
let rec equal_typ a b =
  match (a, b) with
  | TInt, TInt | TBool, TBool | TVoid, TVoid -> true
  | TStruct x, TStruct y -> String.equal x y
  | TArray x, TArray y -> equal_typ x y
  | _ -> false

(* ------------------------------------------------------------------ *)
(* Lexer                                                                *)
(* ------------------------------------------------------------------ *)

type token =
  | TIntKw
  | TBoolKw
  | TVoidKw
  | TStructKw
  | TIf
  | TElse
  | TFor
  | TEach
  | TForEach
  | TIn
  | TReturn
  | TTrue
  | TFalse
  | TIdent of string
  | TIntLit of int
  | TLParen
  | TRParen
  | TLBrace
  | TRBrace
  | TLBracket
  | TRBracket
  | TSemicolon
  | TComma
  | TDot
  | TAssign
  | TPlus
  | TMinus
  | TStar
  | TSlash
  | TPercent
  | TAndAnd
  | TOrOr
  | TBang
  | TEqEq
  | TNe
  | TLt
  | TLe
  | TGt
  | TGe
  | TEOF

let is_space = function ' ' | '\t' | '\r' | '\n' -> true | _ -> false
let is_digit c = c >= '0' && c <= '9'
let is_ident_start c = (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c = '_'
let is_ident_char c = is_ident_start c || is_digit c

(** Maps reserved words to their keyword token; anything else is an
    identifier. *)
let keyword_or_ident s =
  match s with
  | "int" -> TIntKw
  | "bool" -> TBoolKw
  | "void" -> TVoidKw
  | "struct" -> TStructKw
  | "if" -> TIf
  | "else" -> TElse
  | "for" -> TFor
  | "each" -> TEach
  | "foreach" -> TForEach
  | "in" -> TIn
  | "return" -> TReturn
  | "true" -> TTrue
  | "false" -> TFalse
  | _ -> TIdent s

(** [lex src] tokenizes [src] into a list that always ends with [TEOF].
    Whitespace, [// line] and [/* block */] comments are skipped.
    @raise Parse_error on an unexpected character, an unterminated block
    comment, or an integer literal that does not fit in [int]. *)
let lex (src : string) : token list =
  let n = String.length src in
  let next_is i c = i + 1 < n && src.[i + 1] = c in
  let rec skip_line_comment i =
    if i >= n then i
    else if src.[i] = '\n' then i + 1
    else skip_line_comment (i + 1)
  in
  let rec skip_block_comment i =
    if i + 1 >= n then raise (Parse_error "unterminated block comment")
    else if src.[i] = '*' && src.[i + 1] = '/' then i + 2
    else skip_block_comment (i + 1)
  in
  let rec read_number start j =
    if j < n && is_digit src.[j] then read_number start (j + 1)
    else
      let s = String.sub src start (j - start) in
      (* int_of_string raises Failure on overflow, which parse_string would
         not catch; report it as an ordinary parse error instead. *)
      match int_of_string_opt s with
      | Some v -> (TIntLit v, j)
      | None -> raise (Parse_error ("integer literal out of range: " ^ s))
  in
  let rec read_ident start j =
    if j < n && is_ident_char src.[j] then read_ident start (j + 1)
    else (keyword_or_ident (String.sub src start (j - start)), j)
  in
  let rec loop i acc =
    if i >= n then List.rev (TEOF :: acc)
    else if is_space src.[i] then loop (i + 1) acc
    else
      match src.[i] with
      | '/' when next_is i '/' -> loop (skip_line_comment (i + 2)) acc
      | '/' when next_is i '*' -> loop (skip_block_comment (i + 2)) acc
      | '/' -> loop (i + 1) (TSlash :: acc)
      | '(' -> loop (i + 1) (TLParen :: acc)
      | ')' -> loop (i + 1) (TRParen :: acc)
      | '{' -> loop (i + 1) (TLBrace :: acc)
      | '}' -> loop (i + 1) (TRBrace :: acc)
      | '[' -> loop (i + 1) (TLBracket :: acc)
      | ']' -> loop (i + 1) (TRBracket :: acc)
      | ';' -> loop (i + 1) (TSemicolon :: acc)
      | ',' -> loop (i + 1) (TComma :: acc)
      | '.' -> loop (i + 1) (TDot :: acc)
      | '+' -> loop (i + 1) (TPlus :: acc)
      | '-' -> loop (i + 1) (TMinus :: acc)
      | '*' -> loop (i + 1) (TStar :: acc)
      | '%' -> loop (i + 1) (TPercent :: acc)
      | '!' when next_is i '=' -> loop (i + 2) (TNe :: acc)
      | '!' -> loop (i + 1) (TBang :: acc)
      | '=' when next_is i '=' -> loop (i + 2) (TEqEq :: acc)
      | '=' -> loop (i + 1) (TAssign :: acc)
      | '&' when next_is i '&' -> loop (i + 2) (TAndAnd :: acc)
      | '|' when next_is i '|' -> loop (i + 2) (TOrOr :: acc)
      | '<' when next_is i '=' -> loop (i + 2) (TLe :: acc)
      | '<' -> loop (i + 1) (TLt :: acc)
      | '>' when next_is i '=' -> loop (i + 2) (TGe :: acc)
      | '>' -> loop (i + 1) (TGt :: acc)
      | c when is_digit c ->
          let tok, j = read_number i (i + 1) in
          loop j (tok :: acc)
      | c when is_ident_start c ->
          let tok, j = read_ident i (i + 1) in
          loop j (tok :: acc)
      | c -> raise (Parse_error (Printf.sprintf "unexpected character: %c" c))
  in
  loop 0 []

(* ------------------------------------------------------------------ *)
(* Parser                                                               *)
(* ------------------------------------------------------------------ *)

(** Token cursor. Reading past the end yields [TEOF]. *)
type parser_state = {
  toks : token array;
  mutable i : int;
}

let mk_state toks = { toks = Array.of_list toks; i = 0 }
let peek st = if st.i < Array.length st.toks then st.toks.(st.i) else TEOF

let consume st =
  let t = peek st in
  st.i <- st.i + 1;
  t

(** Consumes the next token and fails unless it equals [tok]. The token
    is consumed even on failure. Tokens carry only ints and strings, so
    structural equality is well defined.
    @raise Parse_error ["unexpected token"] on mismatch. *)
let expect st tok =
  if consume st <> tok then raise (Parse_error "unexpected token")

let expect_ident st =
  match consume st with
  | TIdent s -> s
  | _ -> raise (Parse_error "expected identifier")

let starts_type = function
  | TIntKw | TBoolKw | TVoidKw | TStructKw -> true
  | _ -> false

(** Wraps [base] in one [TArray] per trailing ["[]"] pair. *)
let rec parse_array_suffix st base =
  match peek st with
  | TLBracket ->
      expect st TLBracket;
      expect st TRBracket;
      parse_array_suffix st (TArray base)
  | _ -> base

(** Parses a base type ([int], [bool], [void], [struct Name]) followed by
    any number of ["[]"] suffixes. *)
let parse_type st =
  let base =
    match consume st with
    | TIntKw -> TInt
    | TBoolKw -> TBool
    | TVoidKw -> TVoid
    | TStructKw -> TStruct (expect_ident st)
    | _ -> raise (Parse_error "expected type")
  in
  parse_array_suffix st base

(** Parses one left-associative precedence level. [next] parses operands,
    and [ops] maps operator tokens at this level to their [binop]. *)
let parse_left_assoc st ~next ops =
  let rec loop left =
    match List.assoc_opt (peek st) ops with
    | Some op ->
        ignore (consume st);
        loop (Binop (op, left, next st))
    | None -> left
  in
  loop (next st)

let rec parse_program st =
  let rec loop acc =
    match peek st with
    | TEOF -> List.rev acc
    | _ -> loop (parse_top st :: acc)
  in
  loop []

(** A top-level item is a struct definition ([struct N { fields };]), a
    function, or a global variable. *)
and parse_top st =
  match peek st with
  | TStructKw -> (
      expect st TStructKw;
      let sname = expect_ident st in
      match peek st with
      | TLBrace ->
          expect st TLBrace;
          let rec fields acc =
            match peek st with
            | TRBrace -> List.rev acc
            | _ ->
                let t = parse_type st in
                let n = expect_ident st in
                expect st TSemicolon;
                fields ((t, n) :: acc)
          in
          let fs = fields [] in
          expect st TRBrace;
          expect st TSemicolon;
          TopStruct { sname; fields = fs }
      | _ ->
          (* "struct N" used as a type: honour "[]" suffixes exactly as
             parse_type does, so "struct N[] xs;" works at top level. *)
          let ty = parse_array_suffix st (TStruct sname) in
          parse_top_after_type st ty)
  | _ ->
      let ty = parse_type st in
      parse_top_after_type st ty

(** After a type: a name, then either a parameter list plus body (a
    function) or an optional initializer plus [;] (a global). *)
and parse_top_after_type st ty =
  let name = expect_ident st in
  match peek st with
  | TLParen ->
      expect st TLParen;
      let params = parse_params st in
      expect st TRParen;
      let body = parse_stmt_as_block st in
      TopFunc { name; params; ret = ty; body }
  | _ ->
      let init = parse_opt_init st in
      expect st TSemicolon;
      TopGlobalVar (ty, name, init)

(** Optional ["= expr"] initializer. *)
and parse_opt_init st =
  match peek st with
  | TAssign ->
      expect st TAssign;
      Some (parse_expr st)
  | _ -> None

and parse_params st =
  match peek st with
  | TRParen -> []
  | _ ->
      let rec loop acc =
        let t = parse_type st in
        let n = expect_ident st in
        match peek st with
        | TComma ->
            expect st TComma;
            loop ((t, n) :: acc)
        | _ -> List.rev ((t, n) :: acc)
      in
      loop []

(** A braced statement list, or a single statement wrapped in a list. *)
and parse_stmt_as_block st =
  match peek st with
  | TLBrace ->
      expect st TLBrace;
      let rec loop acc =
        match peek st with
        | TRBrace ->
            expect st TRBrace;
            List.rev acc
        | _ -> loop (parse_stmt st :: acc)
      in
      loop []
  | _ -> [ parse_stmt st ]

and parse_stmt st =
  match peek st with
  | TLBrace -> Block (parse_stmt_as_block st)
  | TIf ->
      expect st TIf;
      expect st TLParen;
      let cond = parse_expr st in
      expect st TRParen;
      let then_body = parse_stmt_as_block st in
      let else_body =
        match peek st with
        | TElse ->
            expect st TElse;
            parse_stmt_as_block st
        | _ -> []
      in
      If (cond, then_body, else_body)
  | TForEach | TFor ->
      (* Both "foreach" and the two-word "for each" spelling are accepted. *)
      (match consume st with TFor -> expect st TEach | _ -> ());
      expect st TLParen;
      let it_t = parse_type st in
      let it_name = expect_ident st in
      expect st TIn;
      let iterable = parse_expr st in
      expect st TRParen;
      let body = parse_stmt_as_block st in
      ForEach (it_t, it_name, iterable, body)
  | TReturn ->
      expect st TReturn;
      let v =
        match peek st with TSemicolon -> None | _ -> Some (parse_expr st)
      in
      expect st TSemicolon;
      Return v
  | t when starts_type t ->
      let ty = parse_type st in
      let n = expect_ident st in
      let init = parse_opt_init st in
      expect st TSemicolon;
      VarDecl (ty, n, init)
  | _ ->
      let e = parse_expr st in
      expect st TSemicolon;
      Expr e

and parse_expr st = parse_assignment st

(** Assignment is right-associative and has the lowest precedence. *)
and parse_assignment st =
  let lhs = parse_or st in
  match peek st with
  | TAssign ->
      expect st TAssign;
      Assign (lhs, parse_assignment st)
  | _ -> lhs

and parse_or st = parse_left_assoc st ~next:parse_and [ (TOrOr, Or) ]
and parse_and st = parse_left_assoc st ~next:parse_equality [ (TAndAnd, And) ]

and parse_equality st =
  parse_left_assoc st ~next:parse_rel [ (TEqEq, Eq); (TNe, Ne) ]

and parse_rel st =
  parse_left_assoc st ~next:parse_add
    [ (TLt, Lt); (TLe, Le); (TGt, Gt); (TGe, Ge) ]

and parse_add st =
  parse_left_assoc st ~next:parse_mul [ (TPlus, Add); (TMinus, Sub) ]

and parse_mul st =
  parse_left_assoc st ~next:parse_unary
    [ (TStar, Mul); (TSlash, Div); (TPercent, Mod) ]

and parse_unary st =
  match peek st with
  | TMinus ->
      expect st TMinus;
      Unop (Neg, parse_unary st)
  | TBang ->
      expect st TBang;
      Unop (Not, parse_unary st)
  | _ -> parse_postfix st

(** Calls, indexing and field access, applied left to right. *)
and parse_postfix st =
  let rec loop e =
    match peek st with
    | TLParen ->
        expect st TLParen;
        let args = parse_args st in
        expect st TRParen;
        loop (Call (e, args))
    | TLBracket ->
        expect st TLBracket;
        let idx = parse_expr st in
        expect st TRBracket;
        loop (ArrayGet (e, idx))
    | TDot ->
        expect st TDot;
        let fld = expect_ident st in
        loop (StructGet (e, fld))
    | _ -> e
  in
  loop (parse_primary st)

and parse_args st =
  match peek st with
  | TRParen -> []
  | _ ->
      let rec loop acc =
        let e = parse_expr st in
        match peek st with
        | TComma ->
            expect st TComma;
            loop (e :: acc)
        | _ -> List.rev (e :: acc)
      in
      loop []

and parse_primary st =
  match consume st with
  | TIntLit n -> IntLit n
  | TTrue -> BoolLit true
  | TFalse -> BoolLit false
  | TIdent s -> Var s
  | TLParen ->
      let e = parse_expr st in
      expect st TRParen;
      e
  | _ -> raise (Parse_error "expected expression")

(** Lexes and parses a whole program. Lexer and parser failures are
    returned as [Error msg]. *)
let parse_string src =
  try
    let st = mk_state (lex src) in
    Ok (parse_program st)
  with Parse_error msg -> Error msg

(* ------------------------------------------------------------------ *)
(* Type checker                                                         *)
(* ------------------------------------------------------------------ *)

type func_sig = {
  fparams : typ list;
  fret : typ;
}

(** Program-wide declarations: struct name -> (field name -> type),
    function signatures, and global variable types. *)
type tc_ctx = {
  structs : typ StringMap.t StringMap.t;
  funcs : func_sig StringMap.t;
  globals : typ StringMap.t;
}

let fail_type msg = raise (Type_error msg)

let expect_type got want =
  if not (equal_typ got want) then
    fail_type
      (Printf.sprintf "type mismatch: got %s, expected %s" (string_of_typ got)
         (string_of_typ want))

(** Checks that every struct named in a type is declared. [void] is
    accepted only at the outermost level and only when [allow_void] is
    [true] (function return types); it is never allowed as an array
    element type. *)
let rec validate_type (structs : typ StringMap.t StringMap.t) (allow_void : bool)
    = function
  | TInt | TBool -> ()
  | TVoid ->
      if not allow_void then fail_type "void is not a valid variable type"
  | TStruct n ->
      if not (StringMap.mem n structs) then
        fail_type ("unknown struct type: " ^ n)
  | TArray TVoid -> fail_type "array element type cannot be void"
  | TArray t -> validate_type structs false t

(** First pass: gather struct layouts, function signatures and global
    types, rejecting duplicate structs, fields, functions and globals.
    Because every declaration is collected up front, forward references
    between top-level items are allowed. *)
let collect_ctx (prog : program) : tc_ctx =
  let rec collect tops structs funcs globals =
    match tops with
    | [] -> { structs; funcs; globals }
    | TopStruct s :: tl ->
        if StringMap.mem s.sname structs then
          fail_type ("duplicate struct: " ^ s.sname);
        let fields =
          List.fold_left
            (fun acc (t, n) ->
              if StringMap.mem n acc then
                fail_type ("duplicate field " ^ n ^ " in struct " ^ s.sname);
              StringMap.add n t acc)
            StringMap.empty s.fields
        in
        collect tl (StringMap.add s.sname fields structs) funcs globals
    | TopFunc f :: tl ->
        if StringMap.mem f.name funcs then
          fail_type ("duplicate function: " ^ f.name);
        let sig_ = { fparams = List.map fst f.params; fret = f.ret } in
        collect tl structs (StringMap.add f.name sig_ funcs) globals
    | TopGlobalVar (t, n, _) :: tl ->
        if StringMap.mem n globals then
          fail_type ("duplicate global variable: " ^ n);
        collect tl structs funcs (StringMap.add n t globals)
  in
  collect prog StringMap.empty StringMap.empty StringMap.empty

let lookup_var env x =
  match StringMap.find_opt x env with
  | Some t -> t
  | None -> fail_type ("unknown variable: " ^ x)

let lookup_struct_field ctx sname fname =
  match StringMap.find_opt sname ctx.structs with
  | None -> fail_type ("unknown struct: " ^ sname)
  | Some fields -> (
      match StringMap.find_opt fname fields with
      | None -> fail_type ("unknown field " ^ fname ^ " on struct " ^ sname)
      | Some t -> t)

(** Infers the type of an expression under variable environment [env].
    Only direct calls by name are supported. *)
let rec infer_expr ctx env = function
  | IntLit _ -> TInt
  | BoolLit _ -> TBool
  | Var x -> lookup_var env x
  | Unop (Neg, e) ->
      expect_type (infer_expr ctx env e) TInt;
      TInt
  | Unop (Not, e) ->
      expect_type (infer_expr ctx env e) TBool;
      TBool
  | Binop (op, a, b) -> (
      let ta = infer_expr ctx env a in
      let tb = infer_expr ctx env b in
      match op with
      | Add | Sub | Mul | Div | Mod ->
          expect_type ta TInt;
          expect_type tb TInt;
          TInt
      | And | Or ->
          expect_type ta TBool;
          expect_type tb TBool;
          TBool
      | Lt | Le | Gt | Ge ->
          expect_type ta TInt;
          expect_type tb TInt;
          TBool
      | Eq | Ne ->
          if not (equal_typ ta tb) then
            fail_type "equality operands must have same type";
          TBool)
  | Assign (lhs, rhs) ->
      (match lhs with
       | Var _ | ArrayGet _ | StructGet _ -> ()
       | _ -> fail_type "left side of assignment is not assignable");
      let tl = infer_expr ctx env lhs in
      let tr = infer_expr ctx env rhs in
      expect_type tr tl;
      tl
  | ArrayGet (arr, idx) -> (
      expect_type (infer_expr ctx env idx) TInt;
      match infer_expr ctx env arr with
      | TArray t -> t
      | t -> fail_type ("indexing requires array, got " ^ string_of_typ t))
  | StructGet (obj, fld) -> (
      match infer_expr ctx env obj with
      | TStruct sname -> lookup_struct_field ctx sname fld
      | t -> fail_type ("field access requires struct, got " ^ string_of_typ t))
  | Call (callee, args) ->
      let fname =
        match callee with
        | Var n -> n
        | _ -> fail_type "only direct function calls are supported"
      in
      let sig_ =
        match StringMap.find_opt fname ctx.funcs with
        | Some s -> s
        | None -> fail_type ("unknown function: " ^ fname)
      in
      let expected = List.length sig_.fparams in
      let got = List.length args in
      if got <> expected then
        fail_type
          (Printf.sprintf "function %s expects %d arguments, got %d" fname
             expected got);
      List.iter2
        (fun arg pty -> expect_type (infer_expr ctx env arg) pty)
        args sig_.fparams;
      sig_.fret

(** Checks one statement against return type [ret] and returns the
    environment for the following statements. Only [VarDecl] extends it;
    nested blocks, branches and loop bodies get their own scope. *)
let rec check_stmt ctx ret env = function
  | VarDecl (t, n, init) ->
      validate_type ctx.structs false t;
      (match init with
       | None -> ()
       | Some e -> expect_type (infer_expr ctx env e) t);
      StringMap.add n t env
  | Expr e ->
      ignore (infer_expr ctx env e);
      env
  | If (cond, tbranch, ebranch) ->
      expect_type (infer_expr ctx env cond) TBool;
      ignore (check_block ctx ret env tbranch);
      ignore (check_block ctx ret env ebranch);
      env
  | ForEach (it_t, it_name, iterable, body) ->
      validate_type ctx.structs false it_t;
      (match infer_expr ctx env iterable with
       | TArray elem_t -> expect_type elem_t it_t
       | t ->
           fail_type ("foreach expects array iterable, got " ^ string_of_typ t));
      ignore (check_block ctx ret (StringMap.add it_name it_t env) body);
      env
  | Return None ->
      expect_type TVoid ret;
      env
  | Return (Some e) ->
      expect_type (infer_expr ctx env e) ret;
      env
  | Block stmts ->
      ignore (check_block ctx ret env stmts);
      env

and check_block ctx ret env stmts = List.fold_left (check_stmt ctx ret) env stmts

(** [true] if a [return] occurs anywhere inside the statement, even on
    paths that may not execute. Kept for callers outside this file; the
    checker itself uses [block_always_returns]. *)
let rec has_return_stmt = function
  | Return _ -> true
  | If (_, t, e) -> List.exists has_return_stmt t || List.exists has_return_stmt e
  | ForEach (_, _, _, body) | Block body -> List.exists has_return_stmt body
  | VarDecl _ | Expr _ -> false

(** [true] when every execution path through the statement ends in a
    [return]. A one-armed [if] has an empty else list, which never
    returns. A [foreach] body may run zero times, so it never counts. *)
let rec stmt_always_returns = function
  | Return _ -> true
  | If (_, t, e) -> block_always_returns t && block_always_returns e
  | Block body -> block_always_returns body
  | ForEach _ | VarDecl _ | Expr _ -> false

and block_always_returns stmts = List.exists stmt_always_returns stmts

(** Type-checks a whole program.
    @raise Type_error on the first error found. *)
let check_program (prog : program) =
  let ctx = collect_ctx prog in
  StringMap.iter (fun _ t -> validate_type ctx.structs false t) ctx.globals;
  StringMap.iter
    (fun _ sig_ ->
      List.iter (validate_type ctx.structs false) sig_.fparams;
      validate_type ctx.structs true sig_.fret)
    ctx.funcs;
  List.iter
    (function
      | TopStruct s ->
          List.iter (fun (t, _) -> validate_type ctx.structs false t) s.fields
      | TopGlobalVar (t, _, init) -> (
          validate_type ctx.structs false t;
          match init with
          | None -> ()
          | Some e -> expect_type (infer_expr ctx ctx.globals e) t)
      | TopFunc f ->
          (* Parameters may shadow globals. Only clashes among the
             parameters themselves are errors, so duplicates are tracked
             in a separate set rather than in the global-seeded env. *)
          let _seen, env =
            List.fold_left
              (fun (seen, env) (t, n) ->
                validate_type ctx.structs false t;
                if StringMap.mem n seen then
                  fail_type ("duplicate parameter name: " ^ n);
                (StringMap.add n () seen, StringMap.add n t env))
              (StringMap.empty, ctx.globals)
              f.params
          in
          ignore (check_block ctx f.ret env f.body);
          if (not (equal_typ f.ret TVoid)) && not (block_always_returns f.body)
          then fail_type ("non-void function " ^ f.name ^ " must return a value"))
    prog

let type_check (prog : program) =
  try
    check_program prog;
    Ok ()
  with Type_error msg -> Error msg

(** Parses and type-checks [src]. Errors are prefixed with ["Parse error: "]
    or ["Type error: "]. *)
let parse_and_type_check src =
  match parse_string src with
  | Error e -> Error ("Parse error: " ^ e)
  | Ok prog -> (
      match type_check prog with
      | Error e -> Error ("Type error: " ^ e)
      | Ok () -> Ok prog)