diff --git a/.vscode/settings.json b/.vscode/settings.json index b47cb4e..8d83c8c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,5 +2,23 @@ "chat.tools.terminal.autoApprove": { "dune": true, "ocamlobjinfo": true + }, + "python.autoComplete.extraPaths": [ + "${workspaceFolder}/sources/poky/bitbake/lib", + "${workspaceFolder}/sources/poky/meta/lib" + ], + "python.analysis.extraPaths": [ + "${workspaceFolder}/sources/poky/bitbake/lib", + "${workspaceFolder}/sources/poky/meta/lib" + ], + "files.associations": { + "*.smr": "scheme", + "*.fnl": "scheme", + "*.urn": "scheme", + "*.prf": "ini", + "*.dhall": "haskell", + "bsconfig.json": "jsonc", + "*.conf": "bitbake", + "*.inc": "bitbake" } } \ No newline at end of file diff --git a/bin/main.ml b/bin/main.ml index 0dd4af1..4ebdf11 100644 --- a/bin/main.ml +++ b/bin/main.ml @@ -37,15 +37,17 @@ let () = exit 1 | Ok ast -> ast in - (match Spooky.type_check ast with - | Error msg -> - prerr_endline ("type error: " ^ msg); - exit 1 - | Ok () -> ()); + let typed_ast = + match Spooky.type_check ast with + | Error msg -> + prerr_endline ("type error: " ^ msg); + exit 1 + | Ok prog -> prog + in let generated = match !generator with - | Json -> Spooky.generate_json ast - | C -> Spooky.generate_c ast + | Json -> Spooky.generate_json typed_ast + | C -> Spooky.generate_c typed_ast in match !output_path with | Some path -> diff --git a/examples/modules/math.spooky b/examples/modules/math.spooky index 39a290b..823f133 100644 --- a/examples/modules/math.spooky +++ b/examples/modules/math.spooky @@ -1,3 +1,3 @@ -int add(int a, int b) { +fn add(a: int, b: int) -> int { return a + b; } diff --git a/examples/sample.spooky b/examples/sample.spooky index b1e7ede..0246851 100644 --- a/examples/sample.spooky +++ b/examples/sample.spooky @@ -3,18 +3,18 @@ struct Point { int y; }; -int sum_array(int[] arr) { - int total = 0; +fn sum_array(arr: int[]) -> int { + let mut total = 0; foreach (int n in arr) { total = total + n; } return total; } -int main() { - Point p; - int[] nums; - int x = 1 + 2 * 3; +fn main() -> int { + let p: Point; + let nums: int[]; + let mut x = 1 + 2 * 3; p.x = x; if (x > 0) { x = x + p.x; diff --git a/examples/with_imports.spooky b/examples/with_imports.spooky index d38db5a..236d002 100644 --- a/examples/with_imports.spooky +++ b/examples/with_imports.spooky @@ -1,8 +1,8 @@ import "modules/types.spooky"; import "modules/math.spooky"; -int main() { - Point p; +fn main() -> int { + let p: Point; p.x = 1; p.y = 2; return add(p.x, p.y); diff --git a/lib/ast.ml b/lib/ast.ml index 0e64882..68d1d3a 100644 --- a/lib/ast.ml +++ b/lib/ast.ml @@ -4,6 +4,8 @@ type typ = | TVoid | TStruct of string | TArray of typ + | TComptime of typ (* #type — comptime-only *) + | TCombo of typ * typ (* type#type — linked runtime+comptime pair *) type binop = | Add @@ -34,7 +36,8 @@ type expr = | StructGet of expr * string type stmt = - | VarDecl of typ * string * expr option + | VarDecl of bool * string * typ option * expr option + (* is_mut, name, type_annotation, init *) | Expr of expr | If of expr * stmt list * stmt list | ForEach of typ * string * expr * stmt list @@ -43,7 +46,7 @@ type stmt = type func = { name : string; - params : (typ * string) list; + params : (string * typ) list; (* name, type *) ret : typ; body : stmt list; } @@ -56,7 +59,8 @@ type struct_def = { type top = | TopStruct of struct_def | TopFunc of func - | TopGlobalVar of typ * string * expr option + | TopGlobalVar of bool * string * typ option * expr option + (* is_mut, name, type_annotation, init *) type program = top list @@ -67,6 +71,8 @@ let string_of_typ = | TVoid -> "void" | TStruct n -> "struct " ^ n | TArray t -> go t ^ "[]" + | TComptime t -> "#" ^ go t + | TCombo (t1, t2) -> go t1 ^ "#" ^ go t2 in go @@ -106,9 +112,11 @@ let indent n = String.make (2 * n) ' ' let rec string_of_stmt ?(level = 0) st = let i = indent level in match st with - | VarDecl (t, n, None) -> Printf.sprintf "%sVarDecl(%s %s)" i (string_of_typ t) n - | VarDecl (t, n, Some e) -> - Printf.sprintf "%sVarDecl(%s %s = %s)" i (string_of_typ t) n (string_of_expr e) + | VarDecl (mut, n, t_opt, init_opt) -> + let mut_s = if mut then "mut " else "" in + let t_s = match t_opt with Some t -> ": " ^ string_of_typ t | None -> "" in + let init_s = match init_opt with Some e -> " = " ^ string_of_expr e | None -> "" in + Printf.sprintf "%sVarDecl(let %s%s%s%s)" i mut_s n t_s init_s | Expr e -> Printf.sprintf "%sExpr(%s)" i (string_of_expr e) | Return None -> Printf.sprintf "%sReturn" i | Return (Some e) -> Printf.sprintf "%sReturn(%s)" i (string_of_expr e) @@ -134,14 +142,15 @@ let string_of_top = function in if String.equal fields "" then Printf.sprintf "Struct %s" s.sname else Printf.sprintf "Struct %s\n%s" s.sname fields - | TopGlobalVar (t, n, init) -> - (match init with - | None -> Printf.sprintf "GlobalVar(%s %s)" (string_of_typ t) n - | Some e -> Printf.sprintf "GlobalVar(%s %s = %s)" (string_of_typ t) n (string_of_expr e)) + | TopGlobalVar (mut, n, t_opt, init_opt) -> + let mut_s = if mut then "mut " else "" in + let t_s = match t_opt with Some t -> ": " ^ string_of_typ t | None -> "" in + let init_s = match init_opt with Some e -> " = " ^ string_of_expr e | None -> "" in + Printf.sprintf "GlobalVar(let %s%s%s%s)" mut_s n t_s init_s | TopFunc f -> let params = f.params - |> List.map (fun (t, n) -> Printf.sprintf "%s %s" (string_of_typ t) n) + |> List.map (fun (n, t) -> Printf.sprintf "%s: %s" n (string_of_typ t)) |> String.concat ", " in let body = String.concat "\n" (List.map (string_of_stmt ~level:1) f.body) in @@ -155,4 +164,6 @@ let rec equal_typ a b = | TInt, TInt | TBool, TBool | TVoid, TVoid -> true | TStruct x, TStruct y -> String.equal x y | TArray x, TArray y -> equal_typ x y + | TComptime x, TComptime y -> equal_typ x y + | TCombo (a1, a2), TCombo (b1, b2) -> equal_typ a1 b1 && equal_typ a2 b2 | _ -> false diff --git a/lib/ast.mli b/lib/ast.mli index fe717da..fbcc376 100644 --- a/lib/ast.mli +++ b/lib/ast.mli @@ -4,6 +4,8 @@ type typ = | TVoid | TStruct of string | TArray of typ + | TComptime of typ + | TCombo of typ * typ type binop = | Add @@ -34,7 +36,7 @@ type expr = | StructGet of expr * string type stmt = - | VarDecl of typ * string * expr option + | VarDecl of bool * string * typ option * expr option | Expr of expr | If of expr * stmt list * stmt list | ForEach of typ * string * expr * stmt list @@ -43,7 +45,7 @@ type stmt = type func = { name : string; - params : (typ * string) list; + params : (string * typ) list; ret : typ; body : stmt list; } @@ -56,7 +58,7 @@ type struct_def = { type top = | TopStruct of struct_def | TopFunc of func - | TopGlobalVar of typ * string * expr option + | TopGlobalVar of bool * string * typ option * expr option type program = top list diff --git a/lib/generator_c.ml b/lib/generator_c.ml index 3547f86..4414780 100644 --- a/lib/generator_c.ml +++ b/lib/generator_c.ml @@ -6,6 +6,8 @@ let rec c_type = function | TVoid -> "void" | TStruct n -> "struct " ^ n | TArray t -> c_type t ^ "*" + | TComptime t -> c_type t (* comptime types lower to their runtime equivalent *) + | TCombo (t, _) -> c_type t (* combo types use the runtime (left) type in C *) let rec expr_to_c = function | IntLit n -> string_of_int n @@ -24,8 +26,9 @@ let rec expr_to_c = function let indent n = String.make (2 * n) ' ' let rec stmt_to_c ?(level = 1) = function - | VarDecl (t, n, None) -> Printf.sprintf "%s%s %s;" (indent level) (c_type t) n - | VarDecl (t, n, Some e) -> Printf.sprintf "%s%s %s = %s;" (indent level) (c_type t) n (expr_to_c e) + | VarDecl (_, n, Some t, None) -> Printf.sprintf "%s%s %s;" (indent level) (c_type t) n + | VarDecl (_, n, Some t, Some e) -> Printf.sprintf "%s%s %s = %s;" (indent level) (c_type t) n (expr_to_c e) + | VarDecl (_, n, None, _) -> Printf.sprintf "%s/* unresolved type for %s */" (indent level) n | Expr e -> Printf.sprintf "%s%s;" (indent level) (expr_to_c e) | Return None -> Printf.sprintf "%sreturn;" (indent level) | Return (Some e) -> Printf.sprintf "%sreturn %s;" (indent level) (expr_to_c e) @@ -64,7 +67,7 @@ let struct_to_c s = let func_to_c f = let params = f.params - |> List.map (fun (t, n) -> Printf.sprintf "%s %s" (c_type t) n) + |> List.map (fun (n, t) -> Printf.sprintf "%s %s" (c_type t) n) |> String.concat ", " in let body = f.body |> List.map (stmt_to_c ~level:1) |> String.concat "\n" in @@ -73,10 +76,12 @@ let func_to_c f = let top_to_c = function | TopStruct s -> struct_to_c s | TopFunc f -> func_to_c f - | TopGlobalVar (t, n, init) -> + | TopGlobalVar (_, n, Some t, init) -> (match init with | None -> Printf.sprintf "%s %s;" (c_type t) n | Some e -> Printf.sprintf "%s %s = %s;" (c_type t) n (expr_to_c e)) + | TopGlobalVar (_, n, None, _) -> + Printf.sprintf "/* unresolved type for global %s */" n let generate (prog : program) = let header = "#include \n\n" in diff --git a/lib/generator_json.ml b/lib/generator_json.ml index d527b20..d18f1a7 100644 --- a/lib/generator_json.ml +++ b/lib/generator_json.ml @@ -24,6 +24,8 @@ let rec typ_to_json = function | TVoid -> obj [ kv "kind" (q "void") ] | TStruct name -> obj [ kv "kind" (q "struct"); kv "name" (q name) ] | TArray t -> obj [ kv "kind" (q "array"); kv "elem" (typ_to_json t) ] + | TComptime t -> obj [ kv "kind" (q "comptime"); kv "inner" (typ_to_json t) ] + | TCombo (t1, t2) -> obj [ kv "kind" (q "combo"); kv "runtime" (typ_to_json t1); kv "comptime" (typ_to_json t2) ] let rec expr_to_json = function | IntLit n -> obj [ kv "node" (q "IntLit"); kv "value" (string_of_int n) ] @@ -45,9 +47,11 @@ let rec expr_to_json = function obj [ kv "node" (q "StructGet"); kv "target" (expr_to_json target); kv "field" (q field) ] let rec stmt_to_json = function - | VarDecl (t, n, init) -> + | VarDecl (is_mut, n, t_opt, init) -> obj - [ kv "node" (q "VarDecl"); kv "type" (typ_to_json t); kv "name" (q n); + [ kv "node" (q "VarDecl"); kv "mut" (if is_mut then "true" else "false"); + kv "name" (q n); + kv "type" (match t_opt with None -> "null" | Some t -> typ_to_json t); kv "init" (match init with None -> "null" | Some e -> expr_to_json e) ] | Expr e -> obj [ kv "node" (q "Expr"); kv "expr" (expr_to_json e) ] | If (cond, tbranch, ebranch) -> @@ -77,12 +81,14 @@ let top_to_json = function kv "params" (arr (List.map - (fun (t, n) -> obj [ kv "type" (typ_to_json t); kv "name" (q n) ]) + (fun (n, t) -> obj [ kv "name" (q n); kv "type" (typ_to_json t) ]) f.params)); kv "body" (arr (List.map stmt_to_json f.body)) ] - | TopGlobalVar (t, n, init) -> + | TopGlobalVar (is_mut, n, t_opt, init) -> obj - [ kv "node" (q "GlobalVar"); kv "type" (typ_to_json t); kv "name" (q n); + [ kv "node" (q "GlobalVar"); kv "mut" (if is_mut then "true" else "false"); + kv "name" (q n); + kv "type" (match t_opt with None -> "null" | Some t -> typ_to_json t); kv "init" (match init with None -> "null" | Some e -> expr_to_json e) ] let generate (prog : program) = diff --git a/lib/lexer.ml b/lib/lexer.ml index ab79660..386a8ef 100644 --- a/lib/lexer.ml +++ b/lib/lexer.ml @@ -12,6 +12,9 @@ type token = | TReturn | TTrue | TFalse + | TFn + | TLet + | TMut | TIdent of string | TIntLit of int | TLParen @@ -39,6 +42,9 @@ type token = | TGt | TGe | TEOF + | TArrow + | TColon + | THash exception Lex_error of string @@ -65,6 +71,9 @@ let keyword_or_ident s = | "return" -> TReturn | "true" -> TTrue | "false" -> TFalse + | "fn" -> TFn + | "let" -> TLet + | "mut" -> TMut | _ -> TIdent s let lex (src : string) : token list = @@ -108,6 +117,7 @@ let lex (src : string) : token list = | ',' -> loop (i + 1) (TComma :: acc) | '.' -> loop (i + 1) (TDot :: acc) | '+' -> loop (i + 1) (TPlus :: acc) + | '-' when i + 1 < n && src.[i + 1] = '>' -> loop (i + 2) (TArrow :: acc) | '-' -> loop (i + 1) (TMinus :: acc) | '*' -> loop (i + 1) (TStar :: acc) | '%' -> loop (i + 1) (TPercent :: acc) @@ -122,6 +132,8 @@ let lex (src : string) : token list = | '<' -> loop (i + 1) (TLt :: acc) | '>' when i + 1 < n && src.[i + 1] = '=' -> loop (i + 2) (TGe :: acc) | '>' -> loop (i + 1) (TGt :: acc) + | ':' -> loop (i + 1) (TColon :: acc) + | '#' -> loop (i + 1) (THash :: acc) | c when is_digit c -> let tok, j = read_number i (i + 1) in loop j (tok :: acc) diff --git a/lib/lexer.mli b/lib/lexer.mli index d5e3c5b..3e88e82 100644 --- a/lib/lexer.mli +++ b/lib/lexer.mli @@ -12,6 +12,9 @@ type token = | TReturn | TTrue | TFalse + | TFn + | TLet + | TMut | TIdent of string | TIntLit of int | TLParen @@ -39,6 +42,9 @@ type token = | TGt | TGe | TEOF + | TArrow + | TColon + | THash exception Lex_error of string diff --git a/lib/parser.ml b/lib/parser.ml index 0d6d75e..bb17718 100644 --- a/lib/parser.ml +++ b/lib/parser.ml @@ -56,6 +56,12 @@ let expect st tok = | TBoolKw, TBoolKw | TVoidKw, TVoidKw | TStructKw, TStructKw + | TFn, TFn + | TLet, TLet + | TMut, TMut + | TArrow, TArrow + | TColon, TColon + | THash, THash | TEOF, TEOF -> () | _ -> raise (Parse_error "unexpected token") @@ -64,25 +70,9 @@ let expect_ident st = | TIdent s -> s | _ -> raise (Parse_error "expected identifier") -let starts_builtin_type = function TIntKw | TBoolKw | TVoidKw | TStructKw -> true | _ -> false - -let rec skip_array_suffixes st j = - match peek_n st j with - | TLBracket -> - (match peek_n st (j + 1) with - | TRBracket -> skip_array_suffixes st (j + 2) - | _ -> j) - | _ -> j - -let looks_like_type_start st = - match peek st with - | t when starts_builtin_type t -> true - | TIdent _ -> - let j = skip_array_suffixes st 1 in - (match peek_n st j with TIdent _ -> true | _ -> false) - | _ -> false - -let rec parse_type st = +(* Parse a base type (no # prefix or # suffix). + Handles: int, bool, void, struct Name, TypeName, and [] suffixes. *) +let rec parse_base_type st = let base = match consume st with | TIntKw -> TInt @@ -102,6 +92,23 @@ let rec parse_type st = in arrays base +(* Parse a full type, including comptime forms: + #type -> TComptime type + type -> type + type#type -> TCombo (runtime_type, comptime_type) *) +and parse_type st = + match peek st with + | THash -> + expect st THash; + TComptime (parse_base_type st) + | _ -> + let t = parse_base_type st in + (match peek st with + | THash -> + expect st THash; + TCombo (t, parse_base_type st) + | _ -> t) + let rec parse_program st = let rec loop acc = match peek st with @@ -115,62 +122,66 @@ and parse_top st = | TStructKw -> expect st TStructKw; let sname = expect_ident st in - (match peek st with - | TLBrace -> - expect st TLBrace; - let rec fields acc = - match peek st with - | TRBrace -> List.rev acc - | _ -> - let t = parse_type st in - let n = expect_ident st in - expect st TSemicolon; - fields ((t, n) :: acc) - in - let fs = fields [] in - expect st TRBrace; - expect st TSemicolon; - TopStruct { sname; fields = fs } - | _ -> - let ty = TStruct sname in - parse_top_after_type st ty) - | _ -> - if not (looks_like_type_start st) then raise (Parse_error "expected top-level declaration"); - let ty = parse_type st in - parse_top_after_type st ty - -and parse_top_after_type st ty = - let name = expect_ident st in - match peek st with - | TLParen -> + expect st TLBrace; + let rec fields acc = + match peek st with + | TRBrace -> List.rev acc + | _ -> + let t = parse_base_type st in + let n = expect_ident st in + expect st TSemicolon; + fields ((t, n) :: acc) + in + let fs = fields [] in + expect st TRBrace; + expect st TSemicolon; + TopStruct { sname; fields = fs } + | TFn -> + expect st TFn; + let name = expect_ident st in expect st TLParen; let params = parse_params st in expect st TRParen; + expect st TArrow; + let ret = parse_type st in let body = parse_stmt_as_block st in - TopFunc { name; params; ret = ty; body } - | _ -> - let init = + TopFunc { name; params; ret; body } + | TLet -> + expect st TLet; + let is_mut = match peek st with - | TAssign -> - expect st TAssign; - Some (parse_expr st) + | TMut -> + expect st TMut; + true + | _ -> false + in + let name = expect_ident st in + let t_annot = + match peek st with + | TColon -> + expect st TColon; + Some (parse_type st) | _ -> None in + expect st TAssign; + let e = parse_expr st in expect st TSemicolon; - TopGlobalVar (ty, name, init) + TopGlobalVar (is_mut, name, t_annot, Some e) + | _ -> raise (Parse_error "expected top-level declaration (struct, fn, or let)") and parse_params st = match peek st with | TRParen -> [] | _ -> let rec loop acc = - let t = parse_type st in let n = expect_ident st in + expect st TColon; + let t = parse_type st in match peek st with | TComma -> expect st TComma; - loop ((t, n) :: acc) - | _ -> List.rev ((t, n) :: acc) + loop ((n, t) :: acc) + | _ -> List.rev ((n, t) :: acc) in loop [] @@ -229,9 +240,23 @@ and parse_stmt st = in expect st TSemicolon; Return v - | _ when looks_like_type_start st -> - let ty = parse_type st in + | TLet -> + expect st TLet; + let is_mut = + match peek st with + | TMut -> + expect st TMut; + true + | _ -> false + in let n = expect_ident st in + let t_annot = + match peek st with + | TColon -> + expect st TColon; + Some (parse_type st) + | _ -> None + in let init = match peek st with | TAssign -> @@ -240,7 +265,7 @@ and parse_stmt st = | _ -> None in expect st TSemicolon; - VarDecl (ty, n, init) + VarDecl (is_mut, n, t_annot, init) | _ -> let e = parse_expr st in expect st TSemicolon; diff --git a/lib/spooky.ml b/lib/spooky.ml index 7b62157..c67b8ba 100644 --- a/lib/spooky.ml +++ b/lib/spooky.ml @@ -23,7 +23,7 @@ let parse_and_type_check src = | Error e -> Error e | Ok prog -> (match type_check prog with - | Ok () -> Ok prog + | Ok annotated -> Ok annotated | Error e -> Error ("type error: " ^ e)) let parse_and_type_check_file path = @@ -31,7 +31,7 @@ let parse_and_type_check_file path = | Error e -> Error e | Ok prog -> (match type_check prog with - | Ok () -> Ok prog + | Ok annotated -> Ok annotated | Error e -> Error ("type error: " ^ e)) let generate_json = Generator_json.generate diff --git a/lib/spooky.mli b/lib/spooky.mli index 2b3092f..119c4bb 100644 --- a/lib/spooky.mli +++ b/lib/spooky.mli @@ -12,7 +12,7 @@ val parse_string : string -> (program, string) result val load_source_with_imports : string -> (string, string) result val parse_file : string -> (program, string) result val string_of_program : program -> string -val type_check : program -> (unit, string) result +val type_check : program -> (program, string) result val parse_and_type_check : string -> (program, string) result val parse_and_type_check_file : string -> (program, string) result val generate_json : program -> string diff --git a/lib/typechecker.ml b/lib/typechecker.ml index 8af198b..f4bd86f 100644 --- a/lib/typechecker.ml +++ b/lib/typechecker.ml @@ -12,7 +12,7 @@ type func_sig = { type tc_ctx = { structs : (typ StringMap.t) StringMap.t; funcs : func_sig StringMap.t; - globals : typ StringMap.t; + globals : (typ * bool) StringMap.t; } let fail_type msg = raise (Type_error msg) @@ -31,36 +31,21 @@ let rec validate_type (structs : (typ StringMap.t) StringMap.t) (allow_void : bo if equal_typ t TVoid then fail_type "array element type cannot be void"; validate_type structs false t | TInt | TBool -> () - -let collect_ctx (prog : program) : tc_ctx = - let rec collect tops structs funcs globals = - match tops with - | [] -> { structs; funcs; globals } - | TopStruct s :: tl -> - if StringMap.mem s.sname structs then fail_type ("duplicate struct: " ^ s.sname); - let fields = - List.fold_left - (fun acc (t, n) -> - if StringMap.mem n acc then fail_type ("duplicate field " ^ n ^ " in struct " ^ s.sname); - StringMap.add n t acc) - StringMap.empty s.fields - in - collect tl (StringMap.add s.sname fields structs) funcs globals - | TopFunc f :: tl -> - if StringMap.mem f.name funcs then fail_type ("duplicate function: " ^ f.name); - let sig_ = { fparams = List.map fst f.params; fret = f.ret } in - collect tl structs (StringMap.add f.name sig_ funcs) globals - | TopGlobalVar (t, n, _) :: tl -> - if StringMap.mem n globals then fail_type ("duplicate global variable: " ^ n); - collect tl structs funcs (StringMap.add n t globals) - in - collect prog StringMap.empty StringMap.empty StringMap.empty + | TComptime t -> validate_type structs false t + | TCombo (t1, t2) -> + validate_type structs false t1; + validate_type structs false t2 let lookup_var env x = match StringMap.find_opt x env with - | Some t -> t + | Some (t, _) -> t | None -> fail_type ("unknown variable: " ^ x) +let is_var_mutable env x = + match StringMap.find_opt x env with + | Some (_, m) -> m + | None -> false + let lookup_struct_field ctx sname fname = match StringMap.find_opt sname ctx.structs with | None -> fail_type ("unknown struct: " ^ sname) @@ -99,11 +84,20 @@ let rec infer_expr ctx env = function if not (equal_typ ta tb) then fail_type "equality operands must have same type"; TBool) | Assign (lhs, rhs) -> - (match lhs with Var _ | ArrayGet _ | StructGet _ -> () | _ -> fail_type "left side of assignment is not assignable"); - let tl = infer_expr ctx env lhs in - let tr = infer_expr ctx env rhs in - expect_type tr tl; - tl + (match lhs with + | Var x -> + if not (is_var_mutable env x) then + fail_type ("cannot assign to immutable variable: " ^ x); + let tl = lookup_var env x in + let tr = infer_expr ctx env rhs in + expect_type tr tl; + tl + | ArrayGet _ | StructGet _ -> + let tl = infer_expr ctx env lhs in + let tr = infer_expr ctx env rhs in + expect_type tr tl; + tl + | _ -> fail_type "left side of assignment is not assignable") | ArrayGet (arr, idx) -> expect_type (infer_expr ctx env idx) TInt; (match infer_expr ctx env arr with @@ -131,11 +125,57 @@ let rec infer_expr ctx env = function List.iter2 (fun arg pty -> expect_type (infer_expr ctx env arg) pty) args sig_.fparams; sig_.fret +let collect_ctx (prog : program) : tc_ctx = + let rec collect tops structs funcs globals = + match tops with + | [] -> { structs; funcs; globals } + | TopStruct s :: tl -> + if StringMap.mem s.sname structs then fail_type ("duplicate struct: " ^ s.sname); + let fields = + List.fold_left + (fun acc (t, n) -> + if StringMap.mem n acc then fail_type ("duplicate field " ^ n ^ " in struct " ^ s.sname); + StringMap.add n t acc) + StringMap.empty s.fields + in + collect tl (StringMap.add s.sname fields structs) funcs globals + | TopFunc f :: tl -> + if StringMap.mem f.name funcs then fail_type ("duplicate function: " ^ f.name); + let sig_ = { fparams = List.map snd f.params; fret = f.ret } in + collect tl structs (StringMap.add f.name sig_ funcs) globals + | TopGlobalVar (is_mut, n, t_annot, init_e) :: tl -> + if StringMap.mem n globals then fail_type ("duplicate global variable: " ^ n); + let partial_ctx = { structs; funcs; globals } in + let t = match t_annot with + | Some t -> t + | None -> + (match init_e with + | Some e -> infer_expr partial_ctx globals e + | None -> + fail_type ("global variable '" ^ n ^ "' requires a type annotation or initializer")) + in + collect tl structs funcs (StringMap.add n (t, is_mut) globals) + in + collect prog StringMap.empty StringMap.empty StringMap.empty + let rec check_stmt ctx ret env = function - | VarDecl (t, n, init) -> - validate_type ctx.structs false t; - (match init with None -> () | Some e -> expect_type (infer_expr ctx env e) t); - StringMap.add n t env + | VarDecl (is_mut, n, t_annot, init) -> + let t = match (t_annot, init) with + | (Some t, Some e) -> + validate_type ctx.structs false t; + expect_type (infer_expr ctx env e) t; + t + | (Some t, None) -> + validate_type ctx.structs false t; + t + | (None, Some e) -> + let t = infer_expr ctx env e in + validate_type ctx.structs false t; + t + | (None, None) -> + fail_type ("cannot determine type of '" ^ n ^ "': no type annotation and no initializer") + in + StringMap.add n (t, is_mut) env | Expr e -> ignore (infer_expr ctx env e); env @@ -149,7 +189,7 @@ let rec check_stmt ctx ret env = function (match infer_expr ctx env iterable with | TArray elem_t -> expect_type elem_t it_t | t -> fail_type ("foreach expects array iterable, got " ^ string_of_typ t)); - let env' = StringMap.add it_name it_t env in + let env' = StringMap.add it_name (it_t, false) env in ignore (check_block ctx ret env' body); env | Return None -> @@ -172,7 +212,7 @@ let rec has_return_stmt = function let check_program (prog : program) = let ctx = collect_ctx prog in - StringMap.iter (fun _ t -> validate_type ctx.structs false t) ctx.globals; + StringMap.iter (fun _ (t, _) -> validate_type ctx.structs false t) ctx.globals; StringMap.iter (fun _ sig_ -> List.iter (validate_type ctx.structs false) sig_.fparams; @@ -182,18 +222,21 @@ let check_program (prog : program) = (function | TopStruct s -> List.iter (fun (t, _) -> validate_type ctx.structs false t) s.fields - | TopGlobalVar (t, _, init) -> + | TopGlobalVar (_, n, t_annot, init) -> + let t = match t_annot with + | Some t -> t + | None -> fst (StringMap.find n ctx.globals) + in validate_type ctx.structs false t; - let env = ctx.globals in - (match init with None -> () | Some e -> expect_type (infer_expr ctx env e) t) + (match init with None -> () | Some e -> expect_type (infer_expr ctx ctx.globals e) t) | TopFunc f -> let env_with_globals = ctx.globals in let env = List.fold_left - (fun acc (t, n) -> + (fun acc (n, t) -> validate_type ctx.structs false t; if StringMap.mem n acc then fail_type ("duplicate parameter name: " ^ n); - StringMap.add n t acc) + StringMap.add n (t, false) acc) env_with_globals f.params in ignore (check_block ctx f.ret env f.body); @@ -201,8 +244,59 @@ let check_program (prog : program) = fail_type ("non-void function " ^ f.name ^ " must return a value")) prog -let type_check (prog : program) = +(* Annotation pass: fills in inferred types so generators always get fully-typed ASTs. *) +let rec annotate_stmts ctx env stmts = + List.fold_left_map (annotate_stmt ctx) env stmts + +and annotate_stmt ctx env stmt = + match stmt with + | VarDecl (is_mut, n, t_annot, init) -> + let t = match (t_annot, init) with + | (Some t, _) -> t + | (None, Some e) -> infer_expr ctx env e + | (None, None) -> TVoid (* unreachable after successful check_program *) + in + let env' = StringMap.add n (t, is_mut) env in + (env', VarDecl (is_mut, n, Some t, init)) + | If (cond, tbranch, ebranch) -> + let (_, tbranch') = annotate_stmts ctx env tbranch in + let (_, ebranch') = annotate_stmts ctx env ebranch in + (env, If (cond, tbranch', ebranch')) + | ForEach (it_t, it_name, iterable, body) -> + let env' = StringMap.add it_name (it_t, false) env in + let (_, body') = annotate_stmts ctx env' body in + (env, ForEach (it_t, it_name, iterable, body')) + | Block stmts -> + let (_, stmts') = annotate_stmts ctx env stmts in + (env, Block stmts') + | other -> (env, other) + +let annotate_top ctx = function + | TopFunc f -> + let env = + List.fold_left + (fun acc (n, t) -> StringMap.add n (t, false) acc) + ctx.globals f.params + in + let (_, body') = annotate_stmts ctx env f.body in + TopFunc { f with body = body' } + | TopGlobalVar (is_mut, n, t_annot, init) -> + let t = match t_annot with + | Some t -> t + | None -> + (match init with + | Some e -> infer_expr ctx ctx.globals e + | None -> TVoid) + in + TopGlobalVar (is_mut, n, Some t, init) + | other -> other + +let annotate_program prog = + let ctx = collect_ctx prog in + List.map (annotate_top ctx) prog + +let type_check (prog : program) : (program, string) result = try check_program prog; - Ok () + Ok (annotate_program prog) with Type_error msg -> Error msg diff --git a/lib/typechecker.mli b/lib/typechecker.mli index d88d573..cbf8d8b 100644 --- a/lib/typechecker.mli +++ b/lib/typechecker.mli @@ -1 +1 @@ -val type_check : Ast.program -> (unit, string) result +val type_check : Ast.program -> (Ast.program, string) result diff --git a/test/test_spooky.ml b/test/test_spooky.ml index f2bc81e..79f9079 100644 --- a/test/test_spooky.ml +++ b/test/test_spooky.ml @@ -4,18 +4,18 @@ struct Item { int value; }; -int fold(int[] xs) { - int total = 0; +fn fold(xs: int[]) -> int { + let mut total: int = 0; foreach (int x in xs) { total = total + x; } return total; } -int main() { - int[] xs; - Item it; - int y = 2 + 3 * 4; +fn main() -> int { + let xs: int[]; + let it: Item; + let mut y = 2 + 3 * 4; it.value = y; if (y >= 0) { y = fold(xs); @@ -28,14 +28,52 @@ int main() { let invalid_program = {| -int main() { - bool flag = true; - int x = 1; +fn main() -> int { + let flag = true; + let mut x = 1; x = flag; return x; } |} +let combo_valid_program = + {| +fn get_value() -> int#int { + let x = 1; + return x; +} + +fn do_work(v: int#int) -> int { + return 0; +} + +fn main() -> int { + let x = get_value(); + return do_work(x); +} +|} + +let combo_invalid_program = + {| +fn get_value() -> int { + return 1; +} + +fn get_comptime() -> #int { + return 2; +} + +fn do_work(v: int#int) -> int { + return 0; +} + +fn main() -> int { + let x = get_value(); + let y = get_comptime(); + return do_work(x); +} +|} + let test_valid_program () = match Spooky.parse_and_type_check valid_program with | Ok _ -> () @@ -46,6 +84,16 @@ let test_invalid_program () = | Ok _ -> failwith "expected type error, but got success" | Error _ -> () +let test_combo_valid_program () = + match Spooky.parse_and_type_check combo_valid_program with + | Ok _ -> () + | Error msg -> failwith ("expected valid combo program, got: " ^ msg) + +let test_combo_invalid_program () = + match Spooky.parse_and_type_check combo_invalid_program with + | Ok _ -> failwith "expected combo type error, but got success" + | Error _ -> () + let write_file path content = let oc = open_out_bin path in Fun.protect ~finally:(fun () -> close_out oc) (fun () -> output_string oc content) @@ -65,9 +113,9 @@ let test_imports () = ~finally:cleanup (fun () -> write_file (Filename.concat modules_dir "math.spooky") - "int add(int a, int b) { return a + b; }\n"; + "fn add(a: int, b: int) -> int { return a + b; }\n"; write_file (Filename.concat base "main.spooky") - "import \"modules/math.spooky\";\nint main() { return add(1, 2); }\n"; + "import \"modules/math.spooky\";\nfn main() -> int { return add(1, 2); }\n"; match Spooky.parse_and_type_check_file (Filename.concat base "main.spooky") with | Ok _ -> () | Error msg -> failwith ("expected valid import program, got: " ^ msg)) @@ -75,5 +123,7 @@ let test_imports () = let () = test_valid_program (); test_invalid_program (); + test_combo_valid_program (); + test_combo_invalid_program (); test_imports (); print_endline "All parser/type-check tests passed."