(*
 * This file is part of Barista.
 * Copyright (C) 2007-2014 Xavier Clerc.
 *
 * Barista is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Barista is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *)

(* [register filter] adds [filter] to camlp4's structure-item filters
   ([Camlp4.PreCast.AstFilters]). *)
let register filter =
  Camlp4.PreCast.AstFilters.register_str_item_filter filter

module Make (Syntax : Camlp4.Sig.Camlp4Syntax) = struct
  open Camlp4.Sig
  include Syntax

  (* Shorthand for the 'unknown' location; used for every AST node
     synthesized by this extension. *)
  let ghost = Loc.ghost


  (* Symbolic constants for the UTF8-related functions.
     Each value is the expression AST (at the [ghost] location) naming a
     function that the generated code calls, so the UTF8 and UChar modules
     have to be visible where the extension is used. *)

  (* Expression [UTF8.of_string]; applied to hoisted string literals. *)
  let utf8_of_string = <:expr@ghost< UTF8.of_string >>

  (* Expression [UTF8.equal]; used in guards generated for string patterns. *)
  let utf8_equal = <:expr@ghost< UTF8.equal >>

  (* Expression [UChar.of_char]; applied to hoisted character literals. *)
  let uchar_of_char = <:expr@ghost< UChar.of_char >>

  (* Expression [UChar.equal]; used in guards generated for char patterns. *)
  let uchar_equal = <:expr@ghost< UChar.equal >>


  (* Construction of containers based on hash tables.

     [create p f] returns a pair [(add, prepend)] sharing one table of
     deduplicated constants:
     - [add x] registers [x] if it is not already present and returns the
       identifier bound to it, namely [p ^ string_of_int n] where [n] is
       the registration index of [x] (so equal values share one name);
     - [prepend si] returns [si] preceded by one [let id = f x] structure
       item per registered constant [x].
     The order of the generated bindings follows [Hashtbl.fold] and is
     therefore unspecified. *)
  let create p f =
    (* Strings are immutable: no defensive copy of [p] is needed
       ([String.copy] is deprecated and absent from OCaml 5). *)
    let name_of_index n = p ^ string_of_int n in
    let table = Hashtbl.create 17 in
    let add x =
      let index =
        try
          Hashtbl.find table x
        with Not_found ->
          (* New constant: its index is the number of constants so far. *)
          let n = Hashtbl.length table in
          Hashtbl.add table x n;
          n in
      name_of_index index in
    let prepend init =
      Hashtbl.fold
        (fun key index acc ->
          let _loc = Loc.ghost in
          let id = name_of_index index in
          let value = f key in
          let patt = Ast.(PaId (_loc, IdLid (_loc, id))) in
          let binding = Ast.(StVal (_loc, ReNil, BiEq (_loc, patt, value))) in
          Ast.StSem (_loc, binding, acc))
        table
        init in
    add, prepend


  (* Containers for UTF8 strings and characters.
     [add_string s] / [add_char c] return the name of a hoisted constant
     bound to [UTF8.of_string s] / [UChar.of_char c]; [prepend_strings] /
     [prepend_chars] put the corresponding definitions in front of a
     structure item (see [create]). *)
  let add_string, prepend_strings =
    let utf8_constant literal =
      let lit = Ast.ExStr (ghost, literal) in
      <:expr@ghost< $utf8_of_string$ $lit$ >> in
    create "__bartista__utf8__string__" utf8_constant

  let add_char, prepend_chars =
    let uchar_constant literal =
      let lit = Ast.ExChr (ghost, literal) in
      <:expr@ghost< $uchar_of_char$ $lit$ >> in
    create "__bartista__uchar__" uchar_constant


  (* Utility functions for construction of type/pattern tuples. *)

  (* Separator used by [ty_tuple]: product ([TyStar]) or sum ([TyOr]). *)
  type ty_sep = TyStar | TyOr

  (* [ty_tuple sep types] joins [types] with [sep], nesting to the left:
     [t1; t2; t3] gives ((t1 sep t2) sep t3). A single type is returned
     unchanged and the empty list gives [Ast.TyNil]. *)
  let ty_tuple sep types =
    let join =
      match sep with
      | TyStar -> (fun left right -> Ast.TySta (ghost, left, right))
      | TyOr -> (fun left right -> Ast.TyOr (ghost, left, right)) in
    match types with
    | [] -> Ast.TyNil ghost
    | first :: rest -> List.fold_left join first rest

  (* [pa_tuple cstr patts] builds the constructor pattern [Cstr], [Cstr p]
     or [Cstr (p1, ..., pn)] depending on the number of sub-patterns. *)
  let pa_tuple cstr patts =
    match patts with
    | [] -> <:patt@ghost< $uid:cstr$ >>
    | [single] -> <:patt@ghost< $uid:cstr$ $single$ >>
    | first :: rest ->
        let commas =
          List.fold_left (fun l r -> Ast.PaCom (ghost, l, r)) first rest in
        let tuple = Ast.PaTup (ghost, commas) in
        <:patt@ghost< $uid:cstr$ $tuple$ >>


  (* Instrument to add UTF8 constants declaration at file head.
     Registers a structure-item filter that, the first time it is applied
     to an item of a given file, returns that item preceded by the
     definitions of the constants registered so far through [add_string]
     and [add_char]; later items of the same file are returned unchanged. *)

  (* Names of the files that have already been instrumented. *)
  let instrumented = ref []
  let () =
    let instrument =
      object
        inherit Ast.map
        (* Overridden without calling the inherited method: the map does
           not descend into sub-items, only the item given to the filter
           is considered. *)
        method! str_item si =
          let loc = Ast.loc_of_str_item si in
          let file = Loc.file_name loc in
          if not (List.mem file !instrumented) then begin
            instrumented := file :: !instrumented;
            prepend_strings (prepend_chars si)
          end else
            si
      end in
    (* [Obj.magic] bridges this functor's [Ast.str_item] and the
       [Camlp4.PreCast.Ast.str_item] expected by [register].
       NOTE(review): relies on both types sharing one representation. *)
    register (Obj.magic (instrument#str_item))


  (* Format strings.
     Format strings given to BPRINTF/SPRINTF are parsed at preprocessing
     time into lists of [format_element]s (see [parse_format_string]). *)

  (* Field alignment: [Left] comes from a '-' flag, [Right] is implied by
     a width given without '-'. *)
  type align =
    | Left
    | Right

  (* Radix of an integer directive: 'd' or 'x'. *)
  type base =
    | Decimal
    | Hexadecimal

  (* Whether a string/character directive escapes its argument
     ('S', 'C') or prints it as is ('s', 'c'). *)
  type output =
    | Escaped
    | Normal

  (* Integer type of a directive: no prefix, 'l', 'L' or 'n'. *)
  type kind =
    | Int
    | Int32
    | Int64
    | Nativeint

  (* Padding character: zeroes when the width is written with a leading
     '0', spaces otherwise. *)
  type padding_character =
    | Pad_with_spaces
    | Pad_with_zeroes

  (* One element of a parsed format string: literal text, or a directive
     consuming one argument. The [int option] component is the width. *)
  type format_element =
    | Fragment of string
    | Integer of kind * (align option) * (int option) * (padding_character option) * base
    | Float of (align option) * (int option)
    | String of (align option) * (int option) * output
    | Character of (align option) * (int option) * output
    | Boolean of (align option) * (int option)

  (* [rev_simplify l] turns [l], a list of format elements in reverse
     order, into the same elements in forward order, concatenating
     adjacent [Fragment]s and dropping empty ones. *)
  let rev_simplify l =
    (* Pushes the pending text [s] onto [acc], unless it is empty. *)
    let flush s acc =
      if s = "" then acc else Fragment s :: acc in
    (* [s] is pending literal text that comes (in forward order) after
       every element still to be visited. *)
    let rec rs s acc = function
      | Fragment f :: tl -> rs (f ^ s) acc tl
      | hd :: tl -> rs "" (hd :: flush s acc) tl
      | [] -> flush s acc in
    rs "" [] l

  (* In-place increment of an integer reference. *)
  let (+=) r k = r := !r + k

  (* [parse_format_string s] parses the printf-like format string [s] into
     its elements, in order, with adjacent literal text merged.

     A directive is '%' followed by an optional '-' (left alignment), an
     optional decimal width (a leading '0' selects zero padding; a width
     without '-' implies right alignment) and one of: '%' (literal percent),
     'd'/'x' (int), 'ld'/'lx' (int32), 'Ld'/'Lx' (int64), 'nd'/'nx'
     (nativeint), 'f' (float), 's'/'S' (string, plain/escaped), 'c'/'C'
     (char, plain/escaped), 'b'/'B' (bool).

     Raises [Failure "invalid format string"] on a truncated directive, a
     repeated '-' or a '-' after the width, an unknown integer suffix, or a
     width that does not fit in an [int]; raises
     [Failure "invalid format character ..."] on an unknown directive. *)
  let parse_format_string s =
    let len = String.length s in
    let buff = Buffer.create len in
    let i = ref 0 in
    let res = ref [] in
    let fail () =
      failwith "invalid format string" in
    let check_none = function
      | Some _ -> fail ()
      | None -> () in
    let check_index idx =
      if idx >= len then fail () in
    let is_digit c = c >= '0' && c <= '9' in
    (* Parses the directive starting at [!i] (just after '%') and advances
       [i] past it; the flags and width read so far are accumulated in the
       arguments. *)
    let rec parse_directive align size pad_char =
      check_index !i;
      match s.[!i] with
      | '%' -> incr i; Fragment "%"
      | '-' -> check_none align; incr i; parse_directive (Some Left) size pad_char
      | '0'..'9' ->
          check_none size;
          let start = !i in
          while !i < len && is_digit s.[!i] do
            incr i
          done;
          let digits = String.sub s start (!i - start) in
          (* [int_of_string] only fails (with [Failure]) on overflow here. *)
          let width =
            try int_of_string digits with Failure _ -> fail () in
          let char =
            if digits.[0] = '0' then Pad_with_zeroes else Pad_with_spaces in
          let align = match align with Some _ -> align | None -> Some Right in
          parse_directive align (Some width) (Some char)
      | 'd' -> incr i; Integer (Int, align, size, pad_char, Decimal)
      | 'x' -> incr i; Integer (Int, align, size, pad_char, Hexadecimal)
      | ('l' | 'L' | 'n') as prefix ->
          let kind =
            match prefix with
            | 'l' -> Int32
            | 'L' -> Int64
            | _ -> Nativeint in
          check_index (succ !i);
          (match s.[succ !i] with
           | 'd' -> i += 2; Integer (kind, align, size, pad_char, Decimal)
           | 'x' -> i += 2; Integer (kind, align, size, pad_char, Hexadecimal)
           | _ -> fail ())
      | 'f' -> incr i; Float (align, size)
      | 's' -> incr i; String (align, size, Normal)
      | 'S' -> incr i; String (align, size, Escaped)
      | 'c' -> incr i; Character (align, size, Normal)
      | 'C' -> incr i; Character (align, size, Escaped)
      | 'b' | 'B' -> incr i; Boolean (align, size)
      | ch ->
          let msg = Printf.sprintf "invalid format character %C" ch in
          failwith msg in
    while !i < len do
      (match s.[!i] with
       | '%' ->
           (* Flush the pending literal text, then parse the directive. *)
           res := Fragment (Buffer.contents buff) :: !res;
           Buffer.clear buff;
           incr i;
           res := parse_directive None None None :: !res
       | c ->
           Buffer.add_char buff c;
           incr i)
    done;
    let last = Buffer.contents buff in
    if last <> "" then res := Fragment last :: !res;
    rev_simplify !res

  (* [list_of_expression e] flattens the application spine of [e]:
     [f a1 ... an] gives [an; ...; a1; f] (arguments from last to first,
     then the head); a non-application [e] gives the singleton [e]. *)
  let list_of_expression expr =
    let rec walk acc = function
      | Ast.ExApp (_, fn, arg) -> walk (arg :: acc) fn
      | head -> List.rev (head :: acc) in
    walk [] expr

  (* Identifier apparently reserved for a generated buffer.
     NOTE(review): not referenced anywhere else in this file; possibly dead. *)
  let buff_id = "__bartista__buffer"

  (* [xprintf format buff expressions] compiles the printf-like [format]
     applied to the argument expressions [expressions] (in order) into an
     expression made of chained [UTF8Buffer] calls, the result of each call
     being used as the buffer of the next one.

     - [buff = Some b]: output is appended to the buffer expression [b] and
       the generated expression is [ignore (...)].
     - [buff = None]: output goes to a fresh [UTF8Buffer.make ()] and the
       generated expression is [UTF8Buffer.contents (...)].

     Literal text, and the texts printed for booleans, are hoisted into
     UTF8 constants through [add_string].

     Raises [Failure] if [format] is invalid (see [parse_format_string]),
     if there are fewer arguments than directives, or if there are more
     ("too many parameters"). *)
  let xprintf format buff expressions =
    let format_elements = parse_format_string format in
    let expressions = ref expressions in
    let expr_counter = ref 0 in
    (* Pops the next argument expression, failing when none is left. *)
    let get_expression () =
      match !expressions with
      | hd :: tl ->
          incr expr_counter;
          expressions := tl;
          hd
      | [] ->
          let msg =
            Printf.sprintf "cannot get expression for index %d"
              !expr_counter in
          failwith msg in
    let _loc = Loc.ghost in
    let true_ = <:expr< true >> in
    let false_ = <:expr< false >> in
    let int x = let x = string_of_int x in <:expr< $int:x$ >> in
    (* Literal flag arguments passed to the [UTF8Buffer.print_*] functions. *)
    let left_of = function
      | Some Left -> true_
      | Some Right | None -> false_ in
    (* A missing width is encoded as -1. *)
    let width_of = function
      | Some x -> int x
      | None -> int (-1) in
    let zeroes_of = function
      | Some Pad_with_zeroes -> true_
      | Some Pad_with_spaces | None -> false_ in
    let hexa_of = function
      | Hexadecimal -> true_
      | Decimal -> false_ in
    let escape_of = function
      | Escaped -> true_
      | Normal -> false_ in
    (* Left-alignment, width, zero-padding and hexadecimal flags of a
       padded integer directive. *)
    let int_flags align pad pad_char base =
      left_of align, width_of pad, zeroes_of pad_char, hexa_of base in
    (* Names of the hoisted constants holding "true" and "false". *)
    let bool_texts () =
      let t = add_string "true" in
      let f = add_string "false" in
      t, f in
    (* Code appending one format element to the buffer expression [acc]. *)
    let instruction acc = function
      | Fragment "\\n" ->
          (* NOTE(review): the pattern is backslash followed by n; this
             presumably matches the raw source text of a newline escape
             forming a whole fragment. *)
          <:expr< UTF8Buffer.add_newline2 $acc$ >>
      | Fragment str ->
          let str = add_string str in
          <:expr< UTF8Buffer.add_string2 $acc$ $lid:str$ >>
      | Integer (Int, None, None, None, Decimal) ->
          <:expr< UTF8Buffer.add_int $acc$ $get_expression ()$ >>
      | Integer (Int, None, None, None, Hexadecimal) ->
          <:expr< UTF8Buffer.add_int_hexa $acc$ $get_expression ()$ >>
      | Integer (Int, align, pad, pad_char, base) ->
          let left, width, zeroes, hexa = int_flags align pad pad_char base in
          <:expr< UTF8Buffer.print_int $acc$ $left$ $width$ $zeroes$ $hexa$ $get_expression ()$ >>
      | Integer (Int32, None, None, None, Decimal) ->
          <:expr< UTF8Buffer.add_int32 $acc$ $get_expression ()$ >>
      | Integer (Int32, None, None, None, Hexadecimal) ->
          <:expr< UTF8Buffer.add_int32_hexa $acc$ $get_expression ()$ >>
      | Integer (Int32, align, pad, pad_char, base) ->
          let left, width, zeroes, hexa = int_flags align pad pad_char base in
          <:expr< UTF8Buffer.print_int32 $acc$ $left$ $width$ $zeroes$ $hexa$ $get_expression ()$ >>
      | Integer (Int64, None, None, None, Decimal) ->
          <:expr< UTF8Buffer.add_int64 $acc$ $get_expression ()$ >>
      | Integer (Int64, None, None, None, Hexadecimal) ->
          <:expr< UTF8Buffer.add_int64_hexa $acc$ $get_expression ()$ >>
      | Integer (Int64, align, pad, pad_char, base) ->
          let left, width, zeroes, hexa = int_flags align pad pad_char base in
          <:expr< UTF8Buffer.print_int64 $acc$ $left$ $width$ $zeroes$ $hexa$ $get_expression ()$ >>
      | Integer (Nativeint, None, None, None, Decimal) ->
          <:expr< UTF8Buffer.add_nativeint $acc$ $get_expression ()$ >>
      | Integer (Nativeint, None, None, None, Hexadecimal) ->
          <:expr< UTF8Buffer.add_nativeint_hexa $acc$ $get_expression ()$ >>
      | Integer (Nativeint, align, pad, pad_char, base) ->
          let left, width, zeroes, hexa = int_flags align pad pad_char base in
          <:expr< UTF8Buffer.print_nativeint $acc$ $left$ $width$ $zeroes$ $hexa$ $get_expression ()$ >>
      | Float (None, None) ->
          <:expr< UTF8Buffer.add_float $acc$ $get_expression ()$ >>
      | Float (align, pad) ->
          let left = left_of align in
          let width = width_of pad in
          <:expr< UTF8Buffer.print_float $acc$ $left$ $width$ $get_expression ()$ >>
      | String (None, None, Normal) ->
          <:expr< UTF8Buffer.add_string2 $acc$ $get_expression ()$ >>
      | String (None, None, Escaped) ->
          <:expr< UTF8Buffer.add_string2 $acc$ (UTF8.escape $get_expression ()$) >>
      | String (align, pad, output) ->
          let escape = escape_of output in
          let left = left_of align in
          let width = width_of pad in
          <:expr< UTF8Buffer.print_string $acc$ $escape$ $left$ $width$ $get_expression ()$ >>
      | Character (None, None, Normal) ->
          <:expr< UTF8Buffer.add_char2 $acc$ $get_expression ()$ >>
      | Character (align, pad, output) ->
          let escape = escape_of output in
          let left = left_of align in
          let width = width_of pad in
          <:expr< UTF8Buffer.print_char $acc$ $escape$ $left$ $width$ $get_expression ()$ >>
      | Boolean (None, None) ->
          let true_str, false_str = bool_texts () in
          <:expr< UTF8Buffer.add_string2 $acc$
                    (if $get_expression ()$ then $lid:true_str$ else $lid:false_str$) >>
      | Boolean (align, pad) ->
          (* Printed as the unescaped, padded text "true" or "false", using
             [print_string] (buffer, escape, left, width, string) as in the
             string case; this previously called [print_char] with missing
             and mistyped arguments. *)
          let true_str, false_str = bool_texts () in
          let left = left_of align in
          let width = width_of pad in
          <:expr< UTF8Buffer.print_string $acc$ $false_$ $left$ $width$
                    (if $get_expression ()$ then $lid:true_str$ else $lid:false_str$) >> in
    let start =
      match buff with
      | Some b -> b
      | None -> <:expr< UTF8Buffer.make () >> in
    let chain = List.fold_left instruction start format_elements in
    (match !expressions with
     | [] -> ()
     | _ :: _ -> failwith "too many parameters");
    let result =
      match buff with
      | Some _ -> <:expr< ignore $chain$ >>
      | None -> <:expr< UTF8Buffer.contents $chain$ >> in
    result

  (* Actual grammar extension. *)
  EXTEND Gram
    GLOBAL: sig_item str_item expr match_case0;
    (* "Exception" pattern.
       One constructor of a BARISTA_ERROR declaration in a signature:
       [Cstr] or [Cstr of t1 * ... * tn]. Returns the constructor name and
       its argument types, a tuple type being split into its components
       by [unroll]. *)
    sig_error_kind: [[
      id = UIDENT; l = OPT [ "of"; t = ctyp -> t] ->
        match l with
        | Some (Ast.TyTup (_, l)) ->
            let rec unroll = function
              | Ast.TySta (_, left, right) -> (unroll left) @ (unroll right)
              | x -> [x] in
            (id, unroll l)
        | Some x -> (id, [x])
        | None -> (id, [])
    ]];
    (* Signature form of BARISTA_ERROR: expands to the declarations of
       type [error] (one constructor per kind), exception
       [Exception of error] and [string_of_error : error -> string]. *)
    sig_item: [[
      "BARISTA_ERROR"; "="; OPT "|"; l = LIST1 sig_error_kind SEP "|" ->
        let kinds_for_typ =
          List.map
            (fun (cstr, params) ->
              let params = ty_tuple TyStar params in
              <:ctyp< $uid:cstr$ of $params$ >>)
            l in
        let typ = ty_tuple TyOr kinds_for_typ in
        let typ = Ast.TySum (ghost, typ) in
        <:sig_item< type error = $typ$;;
                    exception Exception of error;;
                    val string_of_error : error -> string;; >>
    ]];
    (* Named constructor argument of the structure form: [(name : type)]. *)
    str_error_elem: [[
      "("; id = LIDENT; ":"; typ = ctyp; ")" ->
        (id, typ)
    ]];
    (* One constructor of the structure form:
       [Cstr of (x1 : t1) * ... * (xn : tn) -> e] or [Cstr -> e], where
       [e] is the result of [string_of_error] for that constructor,
       evaluated with [x1] ... [xn] bound to its arguments. *)
    str_error_kind: [[
      id = UIDENT; l = OPT [ "of"; l = LIST1 str_error_elem SEP "*" -> l]; "->"; e = expr ->
        match l with
        | Some l -> (id, l, e)
        | None -> (id, [], e)
    ]];
    (* Structure form of BARISTA_ERROR: defines type [error], exception
       [Exception of error], [fail e] raising [Exception e],
       [string_of_error] (one match case per constructor) and registers a
       printer for [Exception] with [Printexc.register_printer]. *)
    str_item: [[
      "BARISTA_ERROR"; "="; OPT "|"; l = LIST1 str_error_kind SEP "|" ->
        let kinds_for_typ =
          List.map
            (fun (cstr, params, _) ->
              let params = ty_tuple TyStar (List.map snd params) in
              <:ctyp< $uid:cstr$ of $params$ >>)
            l in
        let typ = ty_tuple TyOr kinds_for_typ in
        let typ = Ast.TySum (ghost, typ) in
        let kinds_for_pat =
          List.map
            (fun (cstr, params, exp) ->
              let patt =
                pa_tuple cstr
                  (List.map (fun (x, _) -> <:patt< $lid:x$ >>) params) in
              Ast.McArr (Loc.ghost, patt, <:expr< >>, exp))
            l in
        let cases = <:match_case< $list:kinds_for_pat$ >> in
        <:str_item< type error = $typ$;;
                    exception Exception of error;;
                    let fail e = raise (Exception e);;
                    let string_of_error e = match e with $cases$;;
                    let () =
                      Printexc.register_printer
                        (function
                          | Exception e -> Some (string_of_error e)
                          | _ -> None) >>
    ]];
    (* UTF8 extensions (expressions) & printf.
       - [@"s"] and [@'c']: replaced by the name of a hoisted UTF8 string
         or UChar constant (see [add_string] and [add_char]).
       - BPRINTF ("fmt" (buf) args) and BPRINTF ("fmt" buf args): append
         the formatted arguments to buffer [buf]; the result is unit.
       - SPRINTF ("fmt" (arg1) args) and SPRINTF ("fmt" args): return the
         formatted arguments, as the contents of a fresh UTF8Buffer.
       The parenthesized forms keep the first expression apart from the
       application formed by the remaining arguments.
       NOTE(review): the bracketed groups below are separated by '|' at the
       level list, so each one is a separate grammar level rather than a
       rule of the "simple" level; confirm this layering is intended. *)
    expr: LEVEL "simple" [
      [ "@"; s = STRING ->
        let id = add_string s in
        <:expr< $lid:id$ >> ]
      | [ "@"; c = CHAR ->
        let id = add_char c in
        <:expr< $lid:id$ >> ]
      | [ "BPRINTF"; "("; fmt = STRING; "("; hd = expr; ")"; tl = OPT expr; ")" ->
          match tl with
          | Some x -> xprintf fmt (Some hd) (List.rev (list_of_expression x))
          | None -> xprintf fmt (Some hd) [] ]
      | [ "BPRINTF"; "("; fmt = STRING; l = expr; ")" ->
	  match List.rev (list_of_expression l) with
	  | hd :: tl -> xprintf fmt (Some hd) tl
	  | [] -> failwith "missing parameter" ]
      | [ "SPRINTF"; "("; fmt = STRING; "("; hd = expr; ")"; tl = OPT expr; ")" ->
          let expressions =
            match tl with
            | Some e -> [hd] @ (List.rev (list_of_expression e))
            | None -> [hd] in
          xprintf fmt None expressions ]
      | [ "SPRINTF"; "("; fmt = STRING; l = OPT expr; ")" ->
	  let expressions =
	    match l with
	    | Some e -> List.rev (list_of_expression e)
	    | None -> [] in
	  xprintf fmt None expressions ]
    ];
    (* UTF8 extensions (patterns).
       [@"s" -> e] and [@'c' -> e] become a case binding a fresh variable
       (the constant name followed by a prime), guarded by [UTF8.equal] or
       [UChar.equal] against the hoisted constant; an optional [when w]
       guard is conjoined with [&&]. *)
    match_case0: [
      [ "@"; s = STRING; "->"; e = expr ->
        let id = add_string s in
        let id' = id ^ "'" in
        <:match_case< $lid:id'$ when $utf8_equal$ $lid:id$ $lid:id'$ -> $e$ >> ]
      | [ "@"; s = STRING; "when"; w = expr; "->"; e = expr ->
          let id = add_string s in
          let id' = id ^ "'" in
          <:match_case< $lid:id'$ when ($utf8_equal$ $lid:id$ $lid:id'$) && $w$ -> $e$ >> ]
      | [ "@"; c = CHAR; "->"; e = expr ->
          let id = add_char c in
          let id' = id ^ "'" in
          <:match_case< $lid:id'$ when $uchar_equal$ $lid:id$ $lid:id'$ -> $e$ >> ]
      | [ "@"; c = CHAR; "when"; w = expr; "->"; e = expr ->
          let id = add_char c in
          let id' = id ^ "'" in
          <:match_case< $lid:id'$ when ($uchar_equal$ $lid:id$ $lid:id'$) && $w$ -> $e$ >> ]
    ];
  END
end


(* Actual registration of syntax extension and instrument.
   Applying [Camlp4.Register.OCamlSyntaxExtension] to [Make] registers the
   extension with camlp4; the application is performed only for its side
   effect, hence the local modules and the unit result. *)
let () =
  let module Extension_id = struct
    let name = "Barista-syntax-extension"
    let version = ""
  end in
  let module Registered =
    Camlp4.Register.OCamlSyntaxExtension (Extension_id) (Make) in
  ()
