spooky/_build/default/lib/spooky.ml

807 lines
21 KiB
OCaml
Raw Normal View History

2026-04-29 15:25:15 +00:00
(* Maps keyed by identifier names (struct, field, function and variable names). *)
module StringMap = Map.Make (String)
(* Raised by the lexer and parser on malformed input. *)
exception Parse_error of string
(* Raised by the type checker when a program is ill-typed. *)
exception Type_error of string
(* Source-level types. [TStruct n] is written [struct n]; [TArray t] is
   written [t[]] (see [string_of_typ]). *)
type typ =
| TInt
| TBool
| TVoid
| TStruct of string
| TArray of typ
(* Binary operators. [And]/[Or] are [&&]/[||], [Eq]/[Ne] are [==]/[!=],
   [Mod] is [%]; the rest are the usual arithmetic and relational operators. *)
type binop =
| Add
| Sub
| Mul
| Div
| Mod
| And
| Or
| Eq
| Ne
| Lt
| Le
| Gt
| Ge
(* Unary prefix operators: [Neg] is [-], [Not] is [!]. *)
type unop = Neg | Not
(* Expressions. Assignment is an expression; the type checker only accepts
   a variable, array element or struct field on its left side. The callee
   of [Call] is any expression syntactically, but only a plain [Var] naming
   a declared function passes type checking. *)
type expr =
| IntLit of int
| BoolLit of bool
| Var of string
| Binop of binop * expr * expr
| Unop of unop * expr
| Assign of expr * expr
| Call of expr * expr list
| ArrayGet of expr * expr
| StructGet of expr * string
(* Statements. [If] carries an empty else-list when there is no [else].
   [ForEach (t, x, e, body)] declares loop variable [x : t] over [e], which
   the type checker requires to have type [t[]]. *)
type stmt =
| VarDecl of typ * string * expr option
| Expr of expr
| If of expr * stmt list * stmt list
| ForEach of typ * string * expr * stmt list
| Return of expr option
| Block of stmt list
(* A function definition: name, typed parameters, return type and body. *)
type func = {
name : string;
params : (typ * string) list;
ret : typ;
body : stmt list;
}
(* A struct definition, written [struct sname { T f; ... };]. *)
type struct_def = {
sname : string;
fields : (typ * string) list;
}
(* Top-level items; a global variable may carry an initializer. *)
type top =
| TopStruct of struct_def
| TopFunc of func
| TopGlobalVar of typ * string * expr option
(* A parsed program: top-level items in source order. *)
type program = top list
(** [string_of_typ ty] renders [ty] in surface syntax, for example
    [int[][]] or [struct point]. Used in type error messages. *)
let rec string_of_typ ty =
  match ty with
  | TInt -> "int"
  | TBool -> "bool"
  | TVoid -> "void"
  | TStruct name -> "struct " ^ name
  | TArray elem -> string_of_typ elem ^ "[]"
(** [equal_typ a b] is structural equality on types: struct types are equal
    when their names are equal, array types when their element types are. *)
let rec equal_typ a b =
  match (a, b) with
  | TInt, TInt | TBool, TBool | TVoid, TVoid -> true
  | TStruct x, TStruct y -> String.equal x y
  | TArray x, TArray y -> equal_typ x y
  (* Listing every left constructor keeps the match exhaustive-by-construction:
     adding a new [typ] constructor makes the compiler flag this function. *)
  | (TInt | TBool | TVoid | TStruct _ | TArray _), _ -> false
(* Lexical tokens produced by [lex]. Reserved words get dedicated
   constructors (see [keyword_or_ident]). [TFor] followed by [TEach] and the
   single word [TForEach] are both accepted as the loop keyword by the
   parser. The lexer always terminates its output with [TEOF]. *)
type token =
(* keywords *)
| TIntKw
| TBoolKw
| TVoidKw
| TStructKw
| TIf
| TElse
| TFor
| TEach
| TForEach
| TIn
| TReturn
| TTrue
| TFalse
(* identifiers and decimal integer literals *)
| TIdent of string
| TIntLit of int
(* punctuation *)
| TLParen
| TRParen
| TLBrace
| TRBrace
| TLBracket
| TRBracket
| TSemicolon
| TComma
| TDot
(* operators *)
| TAssign
| TPlus
| TMinus
| TStar
| TSlash
| TPercent
| TAndAnd
| TOrOr
| TBang
| TEqEq
| TNe
| TLt
| TLe
| TGt
| TGe
(* end of input *)
| TEOF
(* Character classes used by [lex]. Only ASCII letters count as letters. *)
let is_space c =
  match c with
  | ' ' | '\t' | '\r' | '\n' -> true
  | _ -> false

let is_digit = function '0' .. '9' -> true | _ -> false

let is_ident_start = function
  | 'a' .. 'z' | 'A' .. 'Z' | '_' -> true
  | _ -> false

(* Characters allowed after the first one in an identifier. *)
let is_ident_char c = is_digit c || is_ident_start c
(** [keyword_or_ident s] maps a scanned word to its keyword token, or to
    [TIdent s] when [s] is not a reserved word. *)
let keyword_or_ident s =
  let keywords =
    [ ("int", TIntKw); ("bool", TBoolKw); ("void", TVoidKw);
      ("struct", TStructKw); ("if", TIf); ("else", TElse); ("for", TFor);
      ("each", TEach); ("foreach", TForEach); ("in", TIn);
      ("return", TReturn); ("true", TTrue); ("false", TFalse) ]
  in
  match List.assoc_opt s keywords with
  | Some kw -> kw
  | None -> TIdent s
(** [lex src] converts [src] into a token list terminated by [TEOF].
    Whitespace, line comments introduced by two slashes, and slash-star
    block comments are skipped.
    @raise Parse_error on an unterminated block comment, an unexpected
    character, or an integer literal that does not fit in [int]. *)
let lex (src : string) : token list =
  let n = String.length src in
  (* Returns the index just past the next newline, or [n] at end of input. *)
  let rec skip_line_comment i =
    if i >= n then i
    else if src.[i] = '\n' then i + 1
    else skip_line_comment (i + 1)
  in
  (* Returns the index just past the closing star-slash pair. *)
  let rec skip_block_comment i =
    if i + 1 >= n then raise (Parse_error "unterminated block comment")
    else if src.[i] = '*' && src.[i + 1] = '/' then i + 2
    else skip_block_comment (i + 1)
  in
  (* Scans digits starting at [j]; the literal itself starts at [i]. *)
  let rec read_number i j =
    if j < n && is_digit src.[j] then read_number i (j + 1)
    else
      let s = String.sub src i (j - i) in
      (* [int_of_string] would raise [Failure] on overflow, which escapes
         callers that only handle [Parse_error]. *)
      match int_of_string_opt s with
      | Some v -> (TIntLit v, j)
      | None -> raise (Parse_error ("integer literal out of range: " ^ s))
  in
  (* Scans identifier characters starting at [j]; the word starts at [i]. *)
  let rec read_ident i j =
    if j < n && is_ident_char src.[j] then read_ident i (j + 1)
    else
      let s = String.sub src i (j - i) in
      (keyword_or_ident s, j)
  in
  (* Tail-recursive main loop; tokens accumulate in reverse. *)
  let rec loop i acc =
    if i >= n then List.rev (TEOF :: acc)
    else if is_space src.[i] then loop (i + 1) acc
    else
      match src.[i] with
      | '/' when i + 1 < n && src.[i + 1] = '/' ->
          loop (skip_line_comment (i + 2)) acc
      | '/' when i + 1 < n && src.[i + 1] = '*' ->
          loop (skip_block_comment (i + 2)) acc
      | '(' -> loop (i + 1) (TLParen :: acc)
      | ')' -> loop (i + 1) (TRParen :: acc)
      | '{' -> loop (i + 1) (TLBrace :: acc)
      | '}' -> loop (i + 1) (TRBrace :: acc)
      | '[' -> loop (i + 1) (TLBracket :: acc)
      | ']' -> loop (i + 1) (TRBracket :: acc)
      | ';' -> loop (i + 1) (TSemicolon :: acc)
      | ',' -> loop (i + 1) (TComma :: acc)
      | '.' -> loop (i + 1) (TDot :: acc)
      | '+' -> loop (i + 1) (TPlus :: acc)
      | '-' -> loop (i + 1) (TMinus :: acc)
      | '*' -> loop (i + 1) (TStar :: acc)
      | '%' -> loop (i + 1) (TPercent :: acc)
      | '/' -> loop (i + 1) (TSlash :: acc)
      | '!' when i + 1 < n && src.[i + 1] = '=' -> loop (i + 2) (TNe :: acc)
      | '!' -> loop (i + 1) (TBang :: acc)
      | '=' when i + 1 < n && src.[i + 1] = '=' -> loop (i + 2) (TEqEq :: acc)
      | '=' -> loop (i + 1) (TAssign :: acc)
      | '&' when i + 1 < n && src.[i + 1] = '&' -> loop (i + 2) (TAndAnd :: acc)
      | '|' when i + 1 < n && src.[i + 1] = '|' -> loop (i + 2) (TOrOr :: acc)
      | '<' when i + 1 < n && src.[i + 1] = '=' -> loop (i + 2) (TLe :: acc)
      | '<' -> loop (i + 1) (TLt :: acc)
      | '>' when i + 1 < n && src.[i + 1] = '=' -> loop (i + 2) (TGe :: acc)
      | '>' -> loop (i + 1) (TGt :: acc)
      | c when is_digit c ->
          let tok, j = read_number i (i + 1) in
          loop j (tok :: acc)
      | c when is_ident_start c ->
          let tok, j = read_ident i (i + 1) in
          loop j (tok :: acc)
      | c ->
          let msg = Printf.sprintf "unexpected character: %c" c in
          raise (Parse_error msg)
  in
  loop 0 []
(* Parser cursor over the token array. [i] is the index of the next token
   to read; it may run past the end, where [peek] reports [TEOF]. *)
type parser_state = {
toks : token array;
mutable i : int;
}
(** [mk_state tokens] creates a cursor positioned at the first token. *)
let mk_state tokens = { toks = Array.of_list tokens; i = 0 }

(** [peek st] returns the current token without advancing; once the
    cursor is past the end of the array it returns [TEOF]. *)
let peek st =
  if st.i >= Array.length st.toks then TEOF else Array.get st.toks st.i

(** [consume st] returns the current token (as [peek] would) and advances
    the cursor by one. *)
let consume st =
  let current = peek st in
  st.i <- st.i + 1;
  current
(** [expect st tok] consumes the next token and checks that it equals [tok].
    Tokens are compared structurally, so payload-carrying tokens such as
    [TIdent x] must match exactly, payload included.
    @raise Parse_error if the consumed token differs from [tok]. *)
let expect st tok =
  (* Polymorphic equality is safe here: [token] holds only ints and strings.
     A hand-enumerated constructor table silently omitted TTrue, TFalse,
     TIdent and TIntLit, so expecting those could never succeed. *)
  if consume st <> tok then raise (Parse_error "unexpected token")
(** [expect_ident st] consumes the next token and returns its identifier
    name.
    @raise Parse_error if the token is not an identifier. *)
let expect_ident st =
  let tok = consume st in
  match tok with
  | TIdent name -> name
  | _other -> raise (Parse_error "expected identifier")

(** [starts_type tok] is true for tokens that can begin a type. *)
let starts_type tok =
  match tok with
  | TIntKw | TBoolKw | TVoidKw | TStructKw -> true
  | _ -> false
(** [parse_type st] parses a base type ([int], [bool], [void] or
    [struct Name]) followed by zero or more [[]] suffixes. Each suffix wraps
    the type in one more [TArray].
    @raise Parse_error if no type starts at the cursor. *)
let parse_type st =
  let base =
    match consume st with
    | TIntKw -> TInt
    | TBoolKw -> TBool
    | TVoidKw -> TVoid
    | TStructKw -> TStruct (expect_ident st)
    | _ -> raise (Parse_error "expected type")
  in
  (* Consume [[]] pairs, wrapping outward: [int[][]] is TArray (TArray TInt). *)
  let rec arrays t =
    match peek st with
    | TLBracket ->
        expect st TLBracket;
        expect st TRBracket;
        arrays (TArray t)
    | _ -> t
  in
  arrays base
(* Recursive-descent parser. Every [parse_*] function advances the shared
   cursor [st] and raises [Parse_error] on malformed input. Expression
   levels from loosest to tightest: assignment (right-associative), [||],
   [&&], [==] and [!=], relational, additive, multiplicative, unary prefix,
   then postfix call, index and field access. Binary levels are
   left-associative. *)

(** [parse_program st] parses top-level items until [TEOF]. *)
let rec parse_program st =
  let rec loop acc =
    match peek st with
    | TEOF -> List.rev acc
    | _ -> loop (parse_top st :: acc)
  in
  loop []

(** [parse_top st] parses one top-level item: a struct definition
    [struct N { T f; ... };], a global variable [T x;] or [T x = e;], or a
    function [T f(params) body]. *)
and parse_top st =
  match peek st with
  | TStructKw ->
      (* Remember where [struct] begins. If this is not a definition, the
         whole type is re-read by [parse_type] so that array suffixes such
         as [struct P[] xs;] are handled like every other type. *)
      let start = st.i in
      expect st TStructKw;
      let sname = expect_ident st in
      (match peek st with
      | TLBrace ->
          expect st TLBrace;
          let rec fields acc =
            match peek st with
            | TRBrace -> List.rev acc
            | _ ->
                let t = parse_type st in
                let n = expect_ident st in
                expect st TSemicolon;
                fields ((t, n) :: acc)
          in
          let fs = fields [] in
          expect st TRBrace;
          expect st TSemicolon;
          TopStruct { sname; fields = fs }
      | _ ->
          st.i <- start;
          let ty = parse_type st in
          parse_top_after_type st ty)
  | _ ->
      let ty = parse_type st in
      parse_top_after_type st ty

(** [parse_top_after_type st ty] parses the remainder of a global variable
    or function declaration whose type [ty] has already been read. *)
and parse_top_after_type st ty =
  let name = expect_ident st in
  match peek st with
  | TLParen ->
      expect st TLParen;
      let params = parse_params st in
      expect st TRParen;
      let body = parse_stmt_as_block st in
      TopFunc { name; params; ret = ty; body }
  | _ ->
      let init =
        match peek st with
        | TAssign ->
            expect st TAssign;
            Some (parse_expr st)
        | _ -> None
      in
      expect st TSemicolon;
      TopGlobalVar (ty, name, init)

(** [parse_params st] parses a possibly empty, comma-separated list of
    [T name] parameters. The closing parenthesis is left for the caller. *)
and parse_params st =
  match peek st with
  | TRParen -> []
  | _ ->
      let rec loop acc =
        let t = parse_type st in
        let n = expect_ident st in
        match peek st with
        | TComma ->
            expect st TComma;
            loop ((t, n) :: acc)
        | _ -> List.rev ((t, n) :: acc)
      in
      loop []

(** [parse_stmt_as_block st] parses a brace-delimited statement list, or a
    single statement returned as a one-element list. *)
and parse_stmt_as_block st =
  match peek st with
  | TLBrace ->
      expect st TLBrace;
      let rec loop acc =
        match peek st with
        | TRBrace ->
            expect st TRBrace;
            List.rev acc
        | _ -> loop (parse_stmt st :: acc)
      in
      loop []
  | _ -> [ parse_stmt st ]

(** [parse_stmt st] parses one statement. *)
and parse_stmt st =
  match peek st with
  | TLBrace -> Block (parse_stmt_as_block st)
  | TIf ->
      expect st TIf;
      expect st TLParen;
      let cond = parse_expr st in
      expect st TRParen;
      let then_body = parse_stmt_as_block st in
      let else_body =
        match peek st with
        | TElse ->
            expect st TElse;
            parse_stmt_as_block st
        | _ -> []
      in
      If (cond, then_body, else_body)
  | TForEach | TFor ->
      (* Accept both the one-word [foreach] and the two-word [for each]. *)
      (match peek st with
      | TForEach -> expect st TForEach
      | TFor ->
          expect st TFor;
          expect st TEach
      | _ -> ());
      expect st TLParen;
      let it_t = parse_type st in
      let it_name = expect_ident st in
      expect st TIn;
      let iterable = parse_expr st in
      expect st TRParen;
      let body = parse_stmt_as_block st in
      ForEach (it_t, it_name, iterable, body)
  | TReturn ->
      expect st TReturn;
      let v =
        match peek st with
        | TSemicolon -> None
        | _ -> Some (parse_expr st)
      in
      expect st TSemicolon;
      Return v
  | t when starts_type t ->
      let ty = parse_type st in
      let n = expect_ident st in
      let init =
        match peek st with
        | TAssign ->
            expect st TAssign;
            Some (parse_expr st)
        | _ -> None
      in
      expect st TSemicolon;
      VarDecl (ty, n, init)
  | _ ->
      let e = parse_expr st in
      expect st TSemicolon;
      Expr e

(** [parse_expr st] parses a full expression, including assignment. *)
and parse_expr st = parse_assignment st

(* Assignment is right-associative: [a = b = c] is [a = (b = c)]. *)
and parse_assignment st =
  let lhs = parse_or st in
  match peek st with
  | TAssign ->
      expect st TAssign;
      let rhs = parse_assignment st in
      Assign (lhs, rhs)
  | _ -> lhs

and parse_or st =
  let rec loop left =
    match peek st with
    | TOrOr ->
        expect st TOrOr;
        loop (Binop (Or, left, parse_and st))
    | _ -> left
  in
  loop (parse_and st)

and parse_and st =
  let rec loop left =
    match peek st with
    | TAndAnd ->
        expect st TAndAnd;
        loop (Binop (And, left, parse_equality st))
    | _ -> left
  in
  loop (parse_equality st)

and parse_equality st =
  let rec loop left =
    match peek st with
    | TEqEq ->
        expect st TEqEq;
        loop (Binop (Eq, left, parse_rel st))
    | TNe ->
        expect st TNe;
        loop (Binop (Ne, left, parse_rel st))
    | _ -> left
  in
  loop (parse_rel st)

and parse_rel st =
  let rec loop left =
    match peek st with
    | TLt ->
        expect st TLt;
        loop (Binop (Lt, left, parse_add st))
    | TLe ->
        expect st TLe;
        loop (Binop (Le, left, parse_add st))
    | TGt ->
        expect st TGt;
        loop (Binop (Gt, left, parse_add st))
    | TGe ->
        expect st TGe;
        loop (Binop (Ge, left, parse_add st))
    | _ -> left
  in
  loop (parse_add st)

and parse_add st =
  let rec loop left =
    match peek st with
    | TPlus ->
        expect st TPlus;
        loop (Binop (Add, left, parse_mul st))
    | TMinus ->
        expect st TMinus;
        loop (Binop (Sub, left, parse_mul st))
    | _ -> left
  in
  loop (parse_mul st)

and parse_mul st =
  let rec loop left =
    match peek st with
    | TStar ->
        expect st TStar;
        loop (Binop (Mul, left, parse_unary st))
    | TSlash ->
        expect st TSlash;
        loop (Binop (Div, left, parse_unary st))
    | TPercent ->
        expect st TPercent;
        loop (Binop (Mod, left, parse_unary st))
    | _ -> left
  in
  loop (parse_unary st)

and parse_unary st =
  match peek st with
  | TMinus ->
      expect st TMinus;
      Unop (Neg, parse_unary st)
  | TBang ->
      expect st TBang;
      Unop (Not, parse_unary st)
  | _ -> parse_postfix st

(* Postfix operators chain left to right: [f(x)[0].y]. *)
and parse_postfix st =
  let rec loop e =
    match peek st with
    | TLParen ->
        expect st TLParen;
        let args = parse_args st in
        expect st TRParen;
        loop (Call (e, args))
    | TLBracket ->
        expect st TLBracket;
        let idx = parse_expr st in
        expect st TRBracket;
        loop (ArrayGet (e, idx))
    | TDot ->
        expect st TDot;
        let fld = expect_ident st in
        loop (StructGet (e, fld))
    | _ -> e
  in
  loop (parse_primary st)

(** [parse_args st] parses a possibly empty, comma-separated argument list.
    The closing parenthesis is left for the caller. *)
and parse_args st =
  match peek st with
  | TRParen -> []
  | _ ->
      let rec loop acc =
        let e = parse_expr st in
        match peek st with
        | TComma ->
            expect st TComma;
            loop (e :: acc)
        | _ -> List.rev (e :: acc)
      in
      loop []

(** [parse_primary st] parses a literal, variable or parenthesized
    expression. *)
and parse_primary st =
  match consume st with
  | TIntLit n -> IntLit n
  | TTrue -> BoolLit true
  | TFalse -> BoolLit false
  | TIdent s -> Var s
  | TLParen ->
      let e = parse_expr st in
      expect st TRParen;
      e
  | _ -> raise (Parse_error "expected expression")
(** [parse_string src] lexes and parses [src]. It returns [Error msg] when
    either stage raises [Parse_error msg]. *)
let parse_string src =
  match parse_program (mk_state (lex src)) with
  | prog -> Ok prog
  | exception Parse_error msg -> Error msg
(* Declared signature of a function: parameter types and return type. *)
type func_sig = {
fparams : typ list;
fret : typ;
}
(* Global typing context built by [collect_ctx]. [structs] maps a struct
   name to its field-name to field-type map; [funcs] maps function names to
   their signatures; [globals] maps global variable names to their types. *)
type tc_ctx = {
structs : (typ StringMap.t) StringMap.t;
funcs : func_sig StringMap.t;
globals : typ StringMap.t;
}
(** [fail_type message] aborts type checking by raising [Type_error]. *)
let fail_type message = raise (Type_error message)

(** [expect_type got want] succeeds when [got] equals [want].
    @raise Type_error describing both types otherwise. *)
let expect_type got want =
  if equal_typ got want then ()
  else
    fail_type
      (Printf.sprintf "type mismatch: got %s, expected %s" (string_of_typ got)
         (string_of_typ want))
(** [validate_type structs allow_void ty] checks that [ty] is well formed:
    every struct it mentions is declared in [structs], arrays never hold
    [void], and [void] itself appears only when [allow_void] is true (as a
    function return type).
    @raise Type_error on the first violation. *)
let rec validate_type (structs : typ StringMap.t StringMap.t)
    (allow_void : bool) ty =
  match ty with
  | TInt | TBool -> ()
  | TVoid ->
      if not allow_void then fail_type "void is not a valid variable type"
  | TStruct name ->
      if not (StringMap.mem name structs) then
        fail_type ("unknown struct type: " ^ name)
  | TArray TVoid -> fail_type "array element type cannot be void"
  | TArray elem -> validate_type structs false elem
(** [collect_ctx prog] gathers every struct, function signature and global
    variable declared in [prog], in source order, without checking bodies
    or initializers.
    @raise Type_error on duplicate struct, field, function or global
    names. *)
let collect_ctx (prog : program) : tc_ctx =
  let empty =
    { structs = StringMap.empty; funcs = StringMap.empty; globals = StringMap.empty }
  in
  let add_top ctx top =
    match top with
    | TopStruct s ->
        if StringMap.mem s.sname ctx.structs then
          fail_type ("duplicate struct: " ^ s.sname);
        let add_field field_map (t, n) =
          if StringMap.mem n field_map then
            fail_type ("duplicate field " ^ n ^ " in struct " ^ s.sname);
          StringMap.add n t field_map
        in
        let field_map = List.fold_left add_field StringMap.empty s.fields in
        { ctx with structs = StringMap.add s.sname field_map ctx.structs }
    | TopFunc f ->
        if StringMap.mem f.name ctx.funcs then
          fail_type ("duplicate function: " ^ f.name);
        let signature = { fparams = List.map fst f.params; fret = f.ret } in
        { ctx with funcs = StringMap.add f.name signature ctx.funcs }
    | TopGlobalVar (t, n, _) ->
        if StringMap.mem n ctx.globals then
          fail_type ("duplicate global variable: " ^ n);
        { ctx with globals = StringMap.add n t ctx.globals }
  in
  List.fold_left add_top empty prog
(** [lookup_var env name] returns the type of variable [name] in [env].
    @raise Type_error if [name] is unbound. *)
let lookup_var env name =
  let found = StringMap.find_opt name env in
  match found with
  | Some ty -> ty
  | None -> fail_type ("unknown variable: " ^ name)

(** [lookup_struct_field ctx sname fname] returns the declared type of
    field [fname] of struct [sname].
    @raise Type_error if the struct or the field is unknown. *)
let lookup_struct_field ctx sname fname =
  let fields =
    match StringMap.find_opt sname ctx.structs with
    | Some fields -> fields
    | None -> fail_type ("unknown struct: " ^ sname)
  in
  match StringMap.find_opt fname fields with
  | Some ty -> ty
  | None -> fail_type ("unknown field " ^ fname ^ " on struct " ^ sname)
(** [infer_expr ctx env expr] computes the type of [expr], where [env] maps
    visible variable names to their types.
    @raise Type_error on the first ill-typed subexpression. *)
let rec infer_expr ctx env expr =
  let infer = infer_expr ctx env in
  match expr with
  | IntLit _ -> TInt
  | BoolLit _ -> TBool
  | Var x -> lookup_var env x
  | Unop (op, operand) ->
      (* Negation works on ints, logical not on bools; both return the
         operand type. *)
      let t = match op with Neg -> TInt | Not -> TBool in
      expect_type (infer operand) t;
      t
  | Binop (op, lhs, rhs) ->
      (* Both operands are inferred before any operator-specific check. *)
      let lt = infer lhs in
      let rt = infer rhs in
      let operands_of operand_t result_t =
        expect_type lt operand_t;
        expect_type rt operand_t;
        result_t
      in
      (match op with
      | Add | Sub | Mul | Div | Mod -> operands_of TInt TInt
      | And | Or -> operands_of TBool TBool
      | Lt | Le | Gt | Ge -> operands_of TInt TBool
      | Eq | Ne ->
          if not (equal_typ lt rt) then
            fail_type "equality operands must have same type";
          TBool)
  | Assign (target, value) ->
      (match target with
      | Var _ | ArrayGet _ | StructGet _ -> ()
      | _ -> fail_type "left side of assignment is not assignable");
      let target_t = infer target in
      let value_t = infer value in
      expect_type value_t target_t;
      target_t
  | ArrayGet (arr, idx) ->
      (* The index is checked before the indexed expression. *)
      expect_type (infer idx) TInt;
      (match infer arr with
      | TArray elem -> elem
      | t -> fail_type ("indexing requires array, got " ^ string_of_typ t))
  | StructGet (obj, fld) ->
      (match infer obj with
      | TStruct sname -> lookup_struct_field ctx sname fld
      | t -> fail_type ("field access requires struct, got " ^ string_of_typ t))
  | Call (callee, args) ->
      let fname =
        match callee with
        | Var n -> n
        | _ -> fail_type "only direct function calls are supported"
      in
      let signature =
        match StringMap.find_opt fname ctx.funcs with
        | Some s -> s
        | None -> fail_type ("unknown function: " ^ fname)
      in
      let expected = List.length signature.fparams in
      let given = List.length args in
      if given <> expected then
        fail_type
          (Printf.sprintf "function %s expects %d arguments, got %d" fname
             expected given);
      (* Lengths are equal here, so [List.iter2] cannot raise. *)
      List.iter2 (fun arg pty -> expect_type (infer arg) pty) args signature.fparams;
      signature.fret
(** [check_stmt ctx ret env stmt] type checks [stmt] inside a function whose
    return type is [ret]. It returns the environment for the following
    statements: only a [VarDecl] extends it. Nested bodies ([If] branches,
    [ForEach] and [Block]) are checked in their own scope.
    @raise Type_error on the first error. *)
let rec check_stmt ctx ret env stmt =
  (* Check a nested statement list; its declarations do not escape. *)
  let check_scope scope body = ignore (check_block ctx ret scope body) in
  match stmt with
  | VarDecl (t, name, init) ->
      validate_type ctx.structs false t;
      (* The initializer is checked before [name] comes into scope. *)
      (match init with
      | Some e -> expect_type (infer_expr ctx env e) t
      | None -> ());
      StringMap.add name t env
  | Expr e ->
      let (_ : typ) = infer_expr ctx env e in
      env
  | If (cond, then_body, else_body) ->
      expect_type (infer_expr ctx env cond) TBool;
      check_scope env then_body;
      check_scope env else_body;
      env
  | ForEach (it_t, it_name, iterable, body) ->
      validate_type ctx.structs false it_t;
      (match infer_expr ctx env iterable with
      | TArray elem_t -> expect_type elem_t it_t
      | other ->
          fail_type ("foreach expects array iterable, got " ^ string_of_typ other));
      check_scope (StringMap.add it_name it_t env) body;
      env
  | Return None ->
      expect_type TVoid ret;
      env
  | Return (Some e) ->
      expect_type (infer_expr ctx env e) ret;
      env
  | Block stmts ->
      check_scope env stmts;
      env

(** [check_block ctx ret env stmts] checks [stmts] in order, threading the
    environment, and returns the final environment. *)
and check_block ctx ret env stmts = List.fold_left (check_stmt ctx ret) env stmts
(** [has_return_stmt stmt] is true when a [Return] appears anywhere inside
    [stmt], including nested bodies. It does not check that every control
    path returns. *)
let rec has_return_stmt stmt =
  let any = List.exists has_return_stmt in
  match stmt with
  | Return _ -> true
  | VarDecl _ | Expr _ -> false
  | If (_, then_body, else_body) -> any then_body || any else_body
  | ForEach (_, _, _, body) | Block body -> any body
(** [check_program prog] type checks a whole program.

    All global types and function signatures are validated first, then each
    top-level item in source order. Global initializers see every global.
    Function bodies see the globals plus the parameters. A parameter may
    shadow a global of the same name, but two parameters of one function may
    not share a name. A non-void function must contain at least one [return]
    somewhere in its body; this is not an all-paths analysis.
    @raise Type_error on the first error found. *)
let check_program (prog : program) =
  let ctx = collect_ctx prog in
  StringMap.iter (fun _ t -> validate_type ctx.structs false t) ctx.globals;
  StringMap.iter
    (fun _ sig_ ->
      List.iter (validate_type ctx.structs false) sig_.fparams;
      validate_type ctx.structs true sig_.fret)
    ctx.funcs;
  List.iter
    (function
      | TopStruct s ->
          List.iter (fun (t, _) -> validate_type ctx.structs false t) s.fields
      | TopGlobalVar (t, _, init) ->
          validate_type ctx.structs false t;
          (match init with
          | None -> ()
          | Some e -> expect_type (infer_expr ctx ctx.globals e) t)
      | TopFunc f ->
          (* Track parameter names separately from the environment. The
             environment starts out holding every global, so testing it for
             duplicates would wrongly reject a parameter that merely shadows
             a global variable. *)
          let _seen, env =
            List.fold_left
              (fun (seen, env) (t, n) ->
                validate_type ctx.structs false t;
                if StringMap.mem n seen then
                  fail_type ("duplicate parameter name: " ^ n);
                (StringMap.add n () seen, StringMap.add n t env))
              (StringMap.empty, ctx.globals) f.params
          in
          ignore (check_block ctx f.ret env f.body);
          if (not (equal_typ f.ret TVoid))
             && not (List.exists has_return_stmt f.body)
          then fail_type ("non-void function " ^ f.name ^ " must return a value"))
    prog
(** [type_check prog] returns [Ok ()] when [prog] is well typed, otherwise
    [Error msg] carrying the message of the first [Type_error] raised. *)
let type_check (prog : program) =
  match check_program prog with
  | () -> Ok ()
  | exception Type_error msg -> Error msg

(** [parse_and_type_check src] parses and then type checks [src], returning
    the program on success. Each error message is prefixed with the stage
    that produced it. *)
let parse_and_type_check src =
  let with_prefix prefix = function
    | Ok v -> Ok v
    | Error e -> Error (prefix ^ e)
  in
  match with_prefix "Parse error: " (parse_string src) with
  | Error _ as failed -> failed
  | Ok prog ->
      (match with_prefix "Type error: " (type_check prog) with
      | Ok () -> Ok prog
      | Error e -> Error e)