209 lines
7.3 KiB
OCaml
209 lines
7.3 KiB
OCaml
open Ast
|
|
|
|
(** Finite maps keyed by identifier strings. Used for the struct, function,
    global-variable and local-variable tables throughout the checker. *)
module StringMap = Map.Make (String)
|
|
|
|
(** Raised (via [fail_type]) with a human-readable message when the program
    is ill-typed; [type_check] converts it into an [Error] result. *)
exception Type_error of string
|
|
|
|
(** Signature of a top-level function, as recorded by [collect_ctx]. *)
type func_sig = {
  fparams : typ list;  (** parameter types, in declaration order *)
  fret : typ;  (** declared return type; [TVoid] is permitted here *)
}
|
|
|
|
(** Program-wide typing context, built once by [collect_ctx]. *)
type tc_ctx = {
  structs : typ StringMap.t StringMap.t;
      (** struct name -> (field name -> field type) *)
  funcs : func_sig StringMap.t;  (** function name -> signature *)
  globals : typ StringMap.t;  (** global variable name -> declared type *)
}
|
|
|
|
(** [fail_type msg] aborts type checking with [msg]; it never returns, so it
    can stand in for a value of any type.
    @raise Type_error always. *)
let fail_type msg =
  let err = Type_error msg in
  raise err
|
|
|
|
(** [expect_type got want] succeeds silently when [got] and [want] are equal
    according to [equal_typ].
    @raise Type_error with a "type mismatch" message otherwise. *)
let expect_type got want =
  match equal_typ got want with
  | true -> ()
  | false ->
      let msg =
        Printf.sprintf "type mismatch: got %s, expected %s" (string_of_typ got)
          (string_of_typ want)
      in
      fail_type msg
|
|
|
|
(** [validate_type structs allow_void ty] checks that [ty] is well formed:
    - [TVoid] is accepted only when [allow_void] is [true] (return types);
    - [TStruct n] requires [n] to be a key of [structs];
    - array element types may never be [TVoid], and are validated
      recursively with [allow_void = false].
    @raise Type_error when [ty] is not well formed. *)
let rec validate_type (structs : (typ StringMap.t) StringMap.t) (allow_void : bool) ty =
  match ty with
  | TInt | TBool -> ()
  | TVoid ->
      if not allow_void then fail_type "void is not a valid variable type"
  | TStruct name ->
      if not (StringMap.mem name structs) then fail_type ("unknown struct type: " ^ name)
  | TArray elem ->
      (* Report the dedicated message before recursing, which would
         otherwise reject [TVoid] with the generic variable message. *)
      if equal_typ elem TVoid then fail_type "array element type cannot be void";
      validate_type structs false elem
|
|
|
|
(** [collect_ctx prog] makes one left-to-right pass over the top-level items
    and records every struct's field table, every function's signature and
    every global's declared type. Types themselves are not validated here.
    @raise Type_error on a duplicate struct, a duplicate field within one
    struct, a duplicate function, or a duplicate global variable. *)
let collect_ctx (prog : program) : tc_ctx =
  let step ctx top =
    match top with
    | TopStruct s ->
        if StringMap.mem s.sname ctx.structs then fail_type ("duplicate struct: " ^ s.sname);
        let add_field field_map (ft, fname) =
          if StringMap.mem fname field_map then
            fail_type ("duplicate field " ^ fname ^ " in struct " ^ s.sname);
          StringMap.add fname ft field_map
        in
        let field_map = List.fold_left add_field StringMap.empty s.fields in
        { ctx with structs = StringMap.add s.sname field_map ctx.structs }
    | TopFunc f ->
        if StringMap.mem f.name ctx.funcs then fail_type ("duplicate function: " ^ f.name);
        let signature = { fparams = List.map fst f.params; fret = f.ret } in
        { ctx with funcs = StringMap.add f.name signature ctx.funcs }
    | TopGlobalVar (t, n, _) ->
        if StringMap.mem n ctx.globals then fail_type ("duplicate global variable: " ^ n);
        { ctx with globals = StringMap.add n t ctx.globals }
  in
  let empty = { structs = StringMap.empty; funcs = StringMap.empty; globals = StringMap.empty } in
  List.fold_left step empty prog
|
|
|
|
(** [lookup_var env x] is the type bound to variable [x] in [env].
    @raise Type_error if [x] is not bound. *)
let lookup_var env x =
  match StringMap.find x env with
  | t -> t
  | exception Not_found -> fail_type ("unknown variable: " ^ x)
|
|
|
|
(** [lookup_struct_field ctx sname fname] is the declared type of field
    [fname] in struct [sname].
    @raise Type_error if the struct is unknown, or it has no such field. *)
let lookup_struct_field ctx sname fname =
  let fields =
    match StringMap.find_opt sname ctx.structs with
    | Some fields -> fields
    | None -> fail_type ("unknown struct: " ^ sname)
  in
  match StringMap.find_opt fname fields with
  | Some t -> t
  | None -> fail_type ("unknown field " ^ fname ^ " on struct " ^ sname)
|
|
|
|
(** [infer_expr ctx env expr] computes the type of [expr], where [env] maps
    visible variable names to their types and [ctx] supplies struct and
    function signatures.

    Rules:
    - arithmetic ([Add]..[Mod]) takes and yields [TInt];
    - [And]/[Or] take and yield [TBool];
    - orderings ([Lt]..[Ge]) take [TInt] and yield [TBool];
    - [Eq]/[Ne] accept any two operands of the same type and yield [TBool];
    - an assignment requires a variable, array element or struct field on
      the left, and yields the left-hand type;
    - only direct calls [f(args)] are supported, with exact arity.

    Sub-expressions are inferred left to right, before the operator's own
    checks, so the first error reported is the same for every rewrite.
    @raise Type_error on the first ill-typed sub-expression. *)
let rec infer_expr ctx env expr =
  let infer = infer_expr ctx env in
  match expr with
  | IntLit _ -> TInt
  | BoolLit _ -> TBool
  | Var x -> lookup_var env x
  | Unop (Neg, operand) ->
      expect_type (infer operand) TInt;
      TInt
  | Unop (Not, operand) ->
      expect_type (infer operand) TBool;
      TBool
  | Binop (op, left, right) ->
      let t_left = infer left in
      let t_right = infer right in
      (* Both operands must have type [want]; the result has type [result]. *)
      let operands want result =
        expect_type t_left want;
        expect_type t_right want;
        result
      in
      (match op with
       | Add | Sub | Mul | Div | Mod -> operands TInt TInt
       | And | Or -> operands TBool TBool
       | Lt | Le | Gt | Ge -> operands TInt TBool
       | Eq | Ne ->
           if equal_typ t_left t_right then TBool
           else fail_type "equality operands must have same type")
  | Assign (target, value) ->
      let assignable =
        match target with Var _ | ArrayGet _ | StructGet _ -> true | _ -> false
      in
      if not assignable then fail_type "left side of assignment is not assignable";
      let t_target = infer target in
      expect_type (infer value) t_target;
      t_target
  | ArrayGet (arr, idx) ->
      expect_type (infer idx) TInt;
      begin match infer arr with
      | TArray elem -> elem
      | other -> fail_type ("indexing requires array, got " ^ string_of_typ other)
      end
  | StructGet (obj, fld) ->
      begin match infer obj with
      | TStruct sname -> lookup_struct_field ctx sname fld
      | other -> fail_type ("field access requires struct, got " ^ string_of_typ other)
      end
  | Call (callee, args) ->
      let fname =
        match callee with
        | Var n -> n
        | _ -> fail_type "only direct function calls are supported"
      in
      let { fparams; fret } =
        match StringMap.find_opt fname ctx.funcs with
        | Some s -> s
        | None -> fail_type ("unknown function: " ^ fname)
      in
      let expected = List.length fparams in
      let actual = List.length args in
      if expected <> actual then
        fail_type
          (Printf.sprintf "function %s expects %d arguments, got %d" fname expected actual);
      List.iter2 (fun arg pty -> expect_type (infer arg) pty) args fparams;
      fret
|
|
|
|
(** [check_stmt ctx ret env stmt] type-checks one statement inside a function
    whose declared return type is [ret], and returns the environment visible
    to the statements that follow it.

    Only [VarDecl] extends the environment. The initializer is checked
    before the new name is bound, so it cannot refer to itself. Bindings made
    inside [If] branches, [ForEach] bodies and nested [Block]s are discarded
    afterwards.
    @raise Type_error on the first error. *)
let rec check_stmt ctx ret env stmt =
  match stmt with
  | VarDecl (t, n, init) ->
      validate_type ctx.structs false t;
      (match init with
       | Some e -> expect_type (infer_expr ctx env e) t
       | None -> ());
      StringMap.add n t env
  | Expr e ->
      let (_ : typ) = infer_expr ctx env e in
      env
  | If (cond, then_branch, else_branch) ->
      expect_type (infer_expr ctx env cond) TBool;
      List.iter
        (fun branch -> ignore (check_block ctx ret env branch))
        [ then_branch; else_branch ];
      env
  | ForEach (it_t, it_name, iterable, body) ->
      validate_type ctx.structs false it_t;
      (match infer_expr ctx env iterable with
       | TArray elem_t -> expect_type elem_t it_t
       | t -> fail_type ("foreach expects array iterable, got " ^ string_of_typ t));
      (* The loop variable is visible only inside the body. *)
      let loop_env = StringMap.add it_name it_t env in
      ignore (check_block ctx ret loop_env body);
      env
  | Return value ->
      (* A bare [return] is treated as returning a [TVoid] value. *)
      let got = match value with None -> TVoid | Some e -> infer_expr ctx env e in
      expect_type got ret;
      env
  | Block stmts ->
      ignore (check_block ctx ret env stmts);
      env

(** [check_block ctx ret env stmts] checks [stmts] in order, threading the
    environment from each statement to the next, and returns the final one. *)
and check_block ctx ret env = function
  | [] -> env
  | first :: rest -> check_block ctx ret (check_stmt ctx ret env first) rest
|
|
|
|
(** [has_return_stmt s] is [true] when executing [s] is guaranteed to reach a
    [return] on every control-flow path, so control can never fall through
    past [s]. A statement list therefore definitely returns exactly when
    [List.exists has_return_stmt] holds for it.

    - [Return _] always returns.
    - [If] returns only if BOTH branches do; an empty branch never does.
    - [Block] returns if any statement in it does. Anything after that
      statement is unreachable.
    - [ForEach] never counts, because the iterable may be empty and then the
      body does not run at all.
    - Declarations and expression statements never return.

    Merely containing a [return] somewhere is not enough. For example,
    [if (c) { return 1; }] with no [else] can still fall off the end. *)
let rec has_return_stmt = function
  | Return _ -> true
  | If (_, then_branch, else_branch) ->
      List.exists has_return_stmt then_branch
      && List.exists has_return_stmt else_branch
  | Block body -> List.exists has_return_stmt body
  | ForEach _ -> false
  | VarDecl _ | Expr _ -> false
|
|
|
|
(** [check_program prog] type-checks a whole program.

    Phases:
    + [collect_ctx] gathers struct, function and global signatures,
      rejecting duplicates.
    + Every global's type and every function signature is validated; [void]
      is accepted only as a return type.
    + Each top-level item is checked in source order: struct field types,
      global initializers, then function bodies. A non-void function must
      also satisfy [has_return_stmt] for some statement of its body.

    @raise Type_error on the first error found. *)
let check_program (prog : program) =
  let ctx = collect_ctx prog in
  StringMap.iter (fun _ t -> validate_type ctx.structs false t) ctx.globals;
  StringMap.iter
    (fun _ sig_ ->
      List.iter (validate_type ctx.structs false) sig_.fparams;
      validate_type ctx.structs true sig_.fret)
    ctx.funcs;
  List.iter
    (function
      | TopStruct s ->
          List.iter (fun (t, _) -> validate_type ctx.structs false t) s.fields
      | TopGlobalVar (t, _, init) ->
          validate_type ctx.structs false t;
          (* NOTE(review): initializers see every global, including the one
             being declared and ones declared later in the file. Confirm that
             this matches the intended initialization semantics. *)
          let env = ctx.globals in
          (match init with None -> () | Some e -> expect_type (infer_expr ctx env e) t)
      | TopFunc f ->
          (* Detect duplicates among the parameters themselves only. Seeding
             this fold with the globals would wrongly report a parameter that
             merely shares a global's name as a "duplicate parameter". *)
          let params =
            List.fold_left
              (fun acc (t, n) ->
                validate_type ctx.structs false t;
                if StringMap.mem n acc then fail_type ("duplicate parameter name: " ^ n);
                StringMap.add n t acc)
              StringMap.empty f.params
          in
          (* Parameters shadow globals of the same name. *)
          let env =
            StringMap.union (fun _ param_t _global_t -> Some param_t) params ctx.globals
          in
          ignore (check_block ctx f.ret env f.body);
          if (not (equal_typ f.ret TVoid)) && not (List.exists has_return_stmt f.body) then
            fail_type ("non-void function " ^ f.name ^ " must return a value"))
    prog
|
|
|
|
(** [type_check prog] is [Ok ()] when [prog] is well typed. Otherwise it is
    [Error msg], where [msg] describes the first type error found. This
    function does not raise [Type_error]. *)
let type_check (prog : program) =
  match check_program prog with
  | () -> Ok ()
  | exception Type_error msg -> Error msg
|