(*
 * This file is part of Barista.
 * Copyright (C) 2007-2014 Xavier Clerc.
 *
 * Barista is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Barista is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *)

open Consts

(* Infix concatenation of UTF8 strings, re-exported locally from [UTF8]. *)
let (++) left right = UTF8.(++) left right


(* Java types definition *)

(* A Java type as it appears in descriptors: one of the eight primitive
   types, [`Void], a class type (designated by its name) or an array type.
   The constraint on the array parameter makes array elements
   [non_void_java_type]s, so that arrays of void are ruled out by typing. *)
type java_type =
  [ `Boolean
  | `Byte
  | `Char
  | `Double
  | `Float
  | `Int
  | `Long
  | `Short
  | `Void
  | `Class of Name.for_class
  | `Array of 'a ] constraint 'a = non_void_java_type
(* Same as [java_type] without [`Void]: the types usable as array elements
   and, in this module, as field and method parameter types. *)
and non_void_java_type =
  [ `Boolean
  | `Byte
  | `Char
  | `Double
  | `Float
  | `Int
  | `Long
  | `Short
  | `Class of Name.for_class
  | `Array of 'a ] constraint 'a = non_void_java_type

(* The subset of Java types made only of array types.  The element type
   ranges over every non-void type, nested arrays included. *)
type array_type =
  [ `Array of 'a ] constraint 'a = [ `Boolean
                                   | `Byte
                                   | `Char
                                   | `Double
                                   | `Float
                                   | `Int
                                   | `Long
                                   | `Short
                                   | `Class of Name.for_class
                                   | `Array of 'a ]


(* Exception *)

(* Errors reported by the functions of this module through [fail].  Each
   case pairs an error constructor with its human-readable message.
   NOTE(review): [Invalid_local_variable_type] is not raised by the code
   in this file; presumably it is used by other modules. *)
BARISTA_ERROR =
  | Invalid_class_name -> "invalid class name"
  | Invalid_array_element_type -> "invalid array element type (void)"
  | Array_with_too_many_dimensions -> "array with more than 255 dimensions"
  | Invalid_descriptor_string -> "invalid descriptor string"
  | Empty_descriptor_string -> "empty descriptor string"
  | Invalid_field_type -> "invalid field type (void)"
  | Invalid_local_variable_type -> "invalid local variable type (void)"
  | Invalid_method_descriptor -> "invalid method descriptor"
  | Invalid_method_parameter_type -> "invalid parameter type (void)"
  | Void_not_allowed -> "void is not allowed here"


(* Utility functions *)

(* Tells whether a Java type is one of the eight primitive types.
   [`Void], class types and array types are not primitive. *)
let is_primitive jt =
  match jt with
  | `Void | `Class _ | `Array _ -> false
  | `Boolean | `Byte | `Char | `Double
  | `Float | `Int | `Long | `Short -> true

(* Returns its argument unchanged (rebuilt with a non-void type) unless it
   is [`Void], in which case it fails with [err]. *)
let filter_void err jt =
  match jt with
  | `Void -> fail err
  | `Array elem -> `Array elem
  | `Class cls -> `Class cls
  | `Boolean -> `Boolean
  | `Byte -> `Byte
  | `Char -> `Char
  | `Double -> `Double
  | `Float -> `Float
  | `Int -> `Int
  | `Long -> `Long
  | `Short -> `Short

(* Returns its argument if it is an array type, and fails with [err]
   for any other type (primitive, void or class). *)
let filter_non_array err jt =
  match jt with
  | `Array elem -> `Array elem
  | `Boolean | `Byte | `Char | `Double | `Float
  | `Int | `Long | `Short | `Void | `Class _ -> fail err

(* Parses one Java type written in internal (descriptor) form from [str],
   starting at character index [i].  Returns the parsed type together with
   the index of the first character following it, so that calls can be
   chained to read a sequence of types.

   Fails with [Empty_descriptor_string] when the input ends before a
   complete type has been read, [Invalid_descriptor_string] on an
   unexpected character, [Invalid_class_name] on a class type that is
   unterminated or whose name is rejected by [Name], [Invalid_array_element_type]
   for an array of void, and [Array_with_too_many_dimensions] beyond 255
   array dimensions. *)
let java_type_of_partial_utf8 str i =
  let len = UTF8.length str in
  (* [depth] is the number of array dimensions consumed so far, [pos] the
     index of the character to read. *)
  let rec parse depth pos =
    if pos >= len then
      fail Empty_descriptor_string
    else begin
      let ch = UTF8.get str pos in
      let next = succ pos in
      (* single-character descriptors: primitive types and void *)
      let simple t = fun _ -> (t, next) in
      (* class type: "L<internal class name>;" *)
      let class_type _ =
        try
          let semi = UTF8.index_from str next semi_colon in
          let name =
            Name.make_for_class_from_internal
              (UTF8.substring str next (pred semi)) in
          (`Class name, succ semi)
        with
        | Not_found
        | Name.Exception _ -> fail Invalid_class_name in
      (* array type: "[<element type>", at most 255 dimensions *)
      let array_type _ =
        if depth >= 255 then
          fail Array_with_too_many_dimensions
        else begin
          let elem, stop = parse (succ depth) next in
          (`Array (filter_void Invalid_array_element_type elem), stop)
        end in
      Utils.switch
        UChar.equal
        [ capital_z, simple `Boolean;
          capital_b, simple `Byte;
          capital_c, simple `Char;
          capital_d, simple `Double;
          capital_f, simple `Float;
          capital_i, simple `Int;
          capital_j, simple `Long;
          capital_s, simple `Short;
          capital_v, simple `Void;
          capital_l, class_type;
          opening_square_bracket, array_type ]
        (fun _ -> fail Invalid_descriptor_string)
        ch
    end in
  parse 0 i

(* Parses a complete internal descriptor string as a Java type.  Fails with
   [Invalid_descriptor_string] if characters remain after the type, and
   with the errors of [java_type_of_partial_utf8] otherwise. *)
let java_type_of_internal_utf8 s =
  let (parsed, stop) = java_type_of_partial_utf8 s 0 in
  if stop <> UTF8.length s then fail Invalid_descriptor_string;
  parsed

(* Encodes a Java type in internal (descriptor) form, e.g. "I",
   "Ljava/lang/String;" or "[[D".  Fails with
   [Array_with_too_many_dimensions] beyond 255 array dimensions. *)
let internal_utf8_of_java_type jt =
  let rec encode depth t =
    match t with
    | `Array elem ->
        if depth >= 255 then
          fail Array_with_too_many_dimensions
        else
          @"[" ++ (encode (succ depth) (elem :> java_type))
    | `Class cls -> @"L" ++ (Name.internal_utf8_for_class cls) ++ @";"
    | `Void -> @"V"
    | `Short -> @"S"
    | `Long -> @"J"
    | `Int -> @"I"
    | `Float -> @"F"
    | `Double -> @"D"
    | `Char -> @"C"
    | `Byte -> @"B"
    | `Boolean -> @"Z" in
  encode 0 jt

(* Renders a Java type as in Java source code, using fully qualified class
   names and a "[]" suffix per array dimension. *)
let rec external_utf8_of_java_type jt =
  match jt with
  | `Array elem -> (external_utf8_of_java_type (elem :> java_type)) ++ @"[]"
  | `Class cls -> Name.external_utf8_for_class cls
  | `Void -> @"void"
  | `Short -> @"short"
  | `Long -> @"long"
  | `Int -> @"int"
  | `Float -> @"float"
  | `Double -> @"double"
  | `Char -> @"char"
  | `Byte -> @"byte"
  | `Boolean -> @"boolean"

(* Same as [external_utf8_of_java_type], except that the outermost array
   dimension is printed as "..." (varargs notation) instead of "[]". *)
let external_utf8_of_java_type_varargs jt =
  match jt with
  | `Array elem -> (external_utf8_of_java_type (elem :> java_type)) ++ @"..."
  | `Class cls -> Name.external_utf8_for_class cls
  | `Void -> @"void"
  | `Short -> @"short"
  | `Long -> @"long"
  | `Int -> @"int"
  | `Float -> @"float"
  | `Double -> @"double"
  | `Char -> @"char"
  | `Byte -> @"byte"
  | `Boolean -> @"boolean"

(* Renders a Java type as in Java source code, printing class names through
   [Name.short_utf8_for_class] and a "[]" suffix per array dimension. *)
let rec short_utf8_of_java_type jt =
  match jt with
  | `Array elem -> (short_utf8_of_java_type (elem :> java_type)) ++ @"[]"
  | `Class cls -> Name.short_utf8_for_class cls
  | `Void -> @"void"
  | `Short -> @"short"
  | `Long -> @"long"
  | `Int -> @"int"
  | `Float -> @"float"
  | `Double -> @"double"
  | `Char -> @"char"
  | `Byte -> @"byte"
  | `Boolean -> @"boolean"

(* Same as [short_utf8_of_java_type], except that the outermost array
   dimension is printed as "..." (varargs notation) instead of "[]".
   Non-array types are rendered exactly as by [short_utf8_of_java_type],
   so class names go through [Name.short_utf8_for_class]. *)
let short_utf8_of_java_type_varargs = function
  | `Boolean -> @"boolean"
  | `Byte -> @"byte"
  | `Char -> @"char"
  | `Double -> @"double"
  | `Float -> @"float"
  | `Int -> @"int"
  | `Long -> @"long"
  | `Short -> @"short"
  | `Void -> @"void"
  (* must agree with the [`Class] case of [short_utf8_of_java_type] *)
  | `Class n -> Name.short_utf8_for_class n
  | `Array jt -> (short_utf8_of_java_type (jt :> java_type)) ++ @"..."

(* Renders a Java type as in Java source code, printing only the part of
   class names returned as second component by [Name.split_class_name],
   and a "[]" suffix per array dimension. *)
let rec shortest_utf8_of_java_type jt =
  match jt with
  | `Array elem -> (shortest_utf8_of_java_type (elem :> java_type)) ++ @"[]"
  | `Class cls ->
      let (_, simple) = Name.split_class_name cls in
      Name.external_utf8_for_class simple
  | `Void -> @"void"
  | `Short -> @"short"
  | `Long -> @"long"
  | `Int -> @"int"
  | `Float -> @"float"
  | `Double -> @"double"
  | `Char -> @"char"
  | `Byte -> @"byte"
  | `Boolean -> @"boolean"

(* Same as [shortest_utf8_of_java_type], except that the outermost array
   dimension is printed as "..." (varargs notation) instead of "[]".
   Non-array types are rendered exactly as by [shortest_utf8_of_java_type],
   so class names are reduced through [Name.split_class_name]. *)
let shortest_utf8_of_java_type_varargs = function
  | `Boolean -> @"boolean"
  | `Byte -> @"byte"
  | `Char -> @"char"
  | `Double -> @"double"
  | `Float -> @"float"
  | `Int -> @"int"
  | `Long -> @"long"
  | `Short -> @"short"
  | `Void -> @"void"
  (* must agree with the [`Class] case of [shortest_utf8_of_java_type] *)
  | `Class n -> Name.external_utf8_for_class (snd (Name.split_class_name n))
  | `Array jt -> (shortest_utf8_of_java_type (jt :> java_type)) ++ @"..."

(* Parses a Java type written as in Java source code, e.g. "int",
   "java.lang.String" or "byte[][]": a base name followed by zero or more
   "[]" pairs.

   Fails with [Empty_descriptor_string] on an empty string,
   [Invalid_descriptor_string] if the string contains characters other than
   identifier parts, '.', '$', '[' and ']' or does not start with a letter,
   [Array_with_too_many_dimensions] beyond 255 trailing "[]" pairs, and
   [Invalid_array_element_type] for an array of void.  Errors raised by
   [Name.make_for_class_from_external] are propagated unchanged. *)
let java_type_of_external_utf8 s =
  let rec make_array n x =
    if n = 0 then
      x
    else
      `Array (make_array (pred n) x) in
  (* characters that may appear anywhere in an external type name *)
  let is_allowed ch =
    (UChar.is_identifier_part ch)
    || (UChar.equal dot ch)
    || (UChar.equal dollar ch)
    || (UChar.equal opening_square_bracket ch)
    || (UChar.equal closing_square_bracket ch) in
  let l = UTF8.length s in
  (* Without this guard an empty string passes the scan below ([!i = l])
     and [UTF8.get s 0] is evaluated out of bounds, raising an exception
     unrelated to descriptors. *)
  if l = 0 then fail Empty_descriptor_string;
  let i = ref 0 in
  while !i < l && is_allowed (UTF8.get s !i) do
    incr i
  done;
  if !i = l && UChar.is_letter (UTF8.get s 0) then
    (* strip trailing "[]" pairs, counting them as array dimensions;
       afterwards [!j] is the index of the last character of the base name *)
    let j = ref (pred l) in
    let dims = ref 0 in
    while (!j - 1 >= 0)
        && (UChar.equal closing_square_bracket (UTF8.get s !j))
        && (UChar.equal opening_square_bracket (UTF8.get s (!j - 1))) do
      incr dims;
      decr j;
      decr j
    done;
    if !dims > 255 then fail Array_with_too_many_dimensions;
    let prefix = UTF8.substring s 0 !j in
    let base = match (try UTF8.to_string prefix with _ -> "") with
    | "boolean" -> `Boolean
    | "byte" -> `Byte
    | "char" -> `Char
    | "double" -> `Double
    | "float" -> `Float
    | "int" -> `Int
    | "long" -> `Long
    | "short" -> `Short
    | "void" -> `Void
    | _ -> `Class (Name.make_for_class_from_external prefix) in
    if !dims = 0 then
      base
    else
      let array = make_array !dims (filter_void Invalid_array_element_type base) in
      (array :> java_type)
  else
    fail Invalid_descriptor_string

(* Structural equality over Java types; class names are compared with
   [Name.equal_for_class]. *)
let rec equal_java_type x y =
  match x, y with
  | `Class c1, `Class c2 -> Name.equal_for_class c1 c2
  | `Array e1, `Array e2 -> equal_java_type (e1 :> java_type) (e2 :> java_type)
  | (`Boolean, `Boolean)
  | (`Byte, `Byte)
  | (`Char, `Char)
  | (`Double, `Double)
  | (`Float, `Float)
  | (`Int, `Int)
  | (`Long, `Long)
  | (`Short, `Short)
  | (`Void, `Void) -> true
  | _, _ -> false

(* Total ordering over Java types: same-kind class or array types are
   compared component-wise (class names through [Name.compare_for_class]);
   types of different kinds are ordered by [Pervasives.compare]. *)
let rec compare_java_type x y =
  match x, y with
  | `Class c1, `Class c2 -> Name.compare_for_class c1 c2
  | `Array e1, `Array e2 -> compare_java_type (e1 :> java_type) (e2 :> java_type)
  | (`Boolean, `Boolean)
  | (`Byte, `Byte)
  | (`Char, `Char)
  | (`Double, `Double)
  | (`Float, `Float)
  | (`Int, `Int)
  | (`Long, `Long)
  | (`Short, `Short)
  | (`Void, `Void) -> 0
  | _, _ -> Pervasives.compare x y

(* Hash function compatible with [equal_java_type]: a fixed code per
   primitive/void type, offset codes for class and array types. *)
let rec hash_java_type x =
  match x with
  | `Array elem -> 10 + (hash_java_type (elem :> java_type))
  | `Class cls -> 9 + (Name.hash_for_class cls)
  | `Void -> 8
  | `Short -> 7
  | `Long -> 6
  | `Int -> 5
  | `Float -> 4
  | `Double -> 3
  | `Char -> 2
  | `Byte -> 1
  | `Boolean -> 0


(* Field descriptors *)

(* Type of a field descriptor: any Java type except void. *)
type for_field = non_void_java_type

(* Parses an internal descriptor such as "I" or "[Ljava/lang/Object;" as a
   field type; fails with [Invalid_field_type] on void. *)
let field_of_utf8 str =
  filter_void Invalid_field_type (java_type_of_internal_utf8 str)

(* Encodes a field type as an internal descriptor string. *)
let utf8_of_field fd =
  let as_java_type = (fd :> java_type) in
  internal_utf8_of_java_type as_java_type

(* Like [java_type_of_external_utf8], but fails with [Void_not_allowed]
   when the parsed type is void. *)
let java_type_of_external_utf8_no_void s =
  filter_void Void_not_allowed (java_type_of_external_utf8 s)

(* Like [java_type_of_internal_utf8], but fails with [Void_not_allowed]
   when the parsed type is void. *)
let java_type_of_internal_utf8_no_void s =
  filter_void Void_not_allowed (java_type_of_internal_utf8 s)

(* Equality over field types, delegating to [equal_java_type]. *)
let equal_for_field x y =
  let x = (x :> java_type) in
  let y = (y :> java_type) in
  equal_java_type x y

(* Ordering over field types, delegating to [compare_java_type]. *)
let compare_for_field x y =
  let x = (x :> java_type) in
  let y = (y :> java_type) in
  compare_java_type x y

(* Hash over field types, delegating to [hash_java_type]. *)
let hash_for_field x =
  let as_java_type = (x :> java_type) in
  hash_java_type as_java_type


(* Method descriptors *)

(* Type of a method parameter descriptor: any Java type except void. *)
type for_parameter = non_void_java_type

(* Parses an internal descriptor string as a method parameter type.  A void
   parameter is reported with [Invalid_method_parameter_type], the same
   error [method_of_utf8] uses, rather than the field-specific
   [Invalid_field_type]. *)
let parameter_of_utf8 str =
  filter_void Invalid_method_parameter_type (java_type_of_internal_utf8 str)

(* Encodes a method parameter type as an internal descriptor string. *)
let utf8_of_parameter p =
  internal_utf8_of_java_type (p :> java_type)

(* Equality over parameter types (same representation as field types). *)
let equal_for_parameter = equal_for_field

(* Ordering over parameter types (same representation as field types). *)
let compare_for_parameter = compare_for_field

(* Hash over parameter types (same representation as field types). *)
let hash_for_parameter = hash_for_field

(* A method descriptor: the parameter types, in declaration order, and the
   return type (which may be [`Void]). *)
type for_method = (for_parameter list) * java_type

(* Parses a method descriptor of the form "(<parameter types>)<return type>",
   e.g. "(I[Ljava/lang/String;)V", into its parameter types (in order) and
   return type.  The return type is parsed first.

   Fails with [Invalid_method_descriptor] when the string is shorter than
   three characters, does not start with '(', has no ')', has characters
   after the return type, or when the parameter types do not end exactly
   at the first ')'.  Fails with [Invalid_method_parameter_type] on a void
   parameter; errors of [java_type_of_partial_utf8] are propagated. *)
let method_of_utf8 str =
  let len = UTF8.length str in
  let well_opened =
    (len > 2) && (UChar.equal opening_parenthesis (UTF8.get str 0)) in
  if not well_opened then fail Invalid_method_descriptor;
  let close =
    (try UTF8.index_from str 1 closing_parenthesis
     with Not_found -> fail Invalid_method_descriptor) in
  let (ret, stop) = java_type_of_partial_utf8 str (succ close) in
  if stop <> len then fail Invalid_method_descriptor;
  (* reads parameter types from [pos] until the closing parenthesis;
     [acc] holds the types read so far, most recent first *)
  let rec parse_params pos acc =
    if pos < close then begin
      let t, next = java_type_of_partial_utf8 str pos in
      parse_params next ((filter_void Invalid_method_parameter_type t) :: acc)
    end else if pos = close then
      List.rev acc
    else
      fail Invalid_method_descriptor in
  (parse_params 1 [], ret)

(* Encodes a method descriptor in internal form:
   "(" ^ parameter descriptors ^ ")" ^ return descriptor. *)
let utf8_of_method (parameters, ret) =
  let encoded_ret = internal_utf8_of_java_type ret in
  let encoded_params = UTF8.concat (List.map utf8_of_parameter parameters) in
  @"(" ++ encoded_params ++ @")" ++ encoded_ret

(* Equality over method descriptors: same parameter types, in the same
   order, and same return type. *)
let equal_for_method (xp, xr) (yp, yr) =
  Utils.list_equal equal_java_type (xp :> java_type list) (yp :> java_type list)
  && equal_java_type xr yr

(* Ordering over method descriptors: parameter lists first, then return
   types to break ties. *)
let compare_for_method (xp, xr) (yp, yr) =
  match Utils.list_compare compare_java_type
          (xp :> java_type list) (yp :> java_type list) with
  | 0 -> compare_java_type xr yr
  | res -> res

(* Hash over method descriptors: hash of the parameter list plus hash of
   the return type. *)
let hash_for_method (xp, xr) =
  let params_hash = Utils.list_hash hash_java_type (xp :> java_type list) in
  params_hash + (hash_java_type xr)
