(*
 * This file is part of Barista.
 * Copyright (C) 2007-2014 Xavier Clerc.
 *
 * Barista is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Barista is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *)

open Consts


(* Types *)

(* Constants tracked by this module as ldc constraints: each value added to a
   [set] must end up at a low enough index in the encoded constant pool to be
   loadable by an ldc instruction. *)
type t =
  [ `Int of int32
  | `Float of float
  | `String of UTF8.t
  | `Class_or_interface of Name.for_class
  | `Array_type of Descriptor.array_type
  | `Method_type of Descriptor.for_method
  | `Method_handle of Bootstrap.method_handle ]

(* Mutable collection of constants; [add] keeps it free of duplicates with
   respect to [equal]. *)
type set = t list ref

(* Index of an entry in one of the two temporary pools built by [encode]:
   - [First i] is an index in the first pool, used unchanged in the result;
   - [Second i] is an index in the second pool, shifted by the length of the
     first pool when both are concatenated (see [convert]). *)
type index = First of Utils.u2 | Second of Utils.u2

(* Constant pool entries whose indices are not resolved yet; [convert] maps
   each constructor to the homonymous [ConstantPool] constructor. *)
type element =
  | Class of index
  | Fieldref of index * index
  | Methodref of index * index
  | InterfaceMethodref of index * index
  | String of index
  | Integer of int32
  (* bit pattern of the float, as computed by [Int32.bits_of_float] *)
  | Float of int32
(* Long and Double cannot appear in LDC
  | Long of int32 * int32
  | Double of int32 * int32
*)
  | NameAndType of index * index
  | UTF8 of UTF8.t
  | MethodHandle of Reference.kind * index
  | MethodType of index
(* InvokeDynamic and ModuleId cannot appear in LDC
  | InvokeDynamic of index * index
  | ModuleId of index * index
*)


(* Exception *)

(* Errors of this module (the [BARISTA_ERROR] syntax extension derives the
   error type and its message function from these cases):
   - [Too_large n]: the pool being built has (or would have) too many
     entries, [n] being the offending size;
   - [Too_many_constraints n]: the part of the pool that must be reachable
     by ldc needs [n] slots, which is more than [encode] allows. *)
BARISTA_ERROR =
  | Too_large of (x : int) ->
      Printf.sprintf "ldc constant pool is too large (%d)" x
  | Too_many_constraints of (x : int) ->
      Printf.sprintf "too many ldc constraints (%d)" x


(* Utilities *)

(* [equal x y] tests whether constants [x] and [y] are equal.
   Floats are compared on their IEEE-754 bit patterns rather than with [=]:
   [=] considers [0.0] and [-0.0] equal (so [add] would silently drop one of
   them, although they are encoded as distinct pool entries) and considers
   [nan] different from itself. *)
let equal x y =
  match (x, y) with
  | (`Int z1), (`Int z2) -> z1 = z2
  | (`Float z1), (`Float z2) ->
      Int64.compare (Int64.bits_of_float z1) (Int64.bits_of_float z2) = 0
  | (`String z1), (`String z2) -> UTF8.equal z1 z2
  | (`Class_or_interface z1), (`Class_or_interface z2) -> Name.equal_for_class z1 z2
  | (`Array_type z1), (`Array_type z2) -> Descriptor.equal_java_type (z1 :> Descriptor.java_type) (z2 :> Descriptor.java_type)
  | (`Method_type z1), (`Method_type z2) -> Descriptor.equal_for_method z1 z2
  | (`Method_handle z1), (`Method_handle z2) ->  Bootstrap.equal_method_handle z1 z2
  | _ -> false

(* [compare x y] is a total order on constants, returning [0] exactly when
   [equal x y] holds. Constants built with different constructors are
   ordered by the generic comparison. *)
let compare x y =
  match (x, y) with
  | (`Int z1), (`Int z2) -> Pervasives.compare z1 z2
  | (`Float z1), (`Float z2) ->
      (* numeric order first; ties ([0.0] vs [-0.0], NaNs) are broken on the
         bit patterns to stay consistent with [equal] *)
      (match Pervasives.compare z1 z2 with
      | 0 -> Int64.compare (Int64.bits_of_float z1) (Int64.bits_of_float z2)
      | c -> c)
  | (`String z1), (`String z2) -> UTF8.compare z1 z2
  | (`Class_or_interface z1), (`Class_or_interface z2) -> Name.compare_for_class z1 z2
  | (`Array_type z1), (`Array_type z2) -> Descriptor.compare_java_type (z1 :> Descriptor.java_type) (z2 :> Descriptor.java_type)
  | (`Method_type z1), (`Method_type z2) -> Descriptor.compare_for_method z1 z2
  | (`Method_handle z1), (`Method_handle z2) ->  Bootstrap.compare_method_handle z1 z2
  | _ -> Pervasives.compare x y

(* Hashes a constant: integers, floats and strings are hashed through the
   generic [Utils.universal_hash] applied to the whole constant, while the
   other constants delegate to the hash function of their payload's module. *)
let hash = function
  | (`Int _ | `Float _ | `String _) as scalar -> Utils.universal_hash scalar
  | `Class_or_interface name -> Name.hash_for_class name
  | `Array_type arr -> Descriptor.hash_java_type (arr :> Descriptor.java_type)
  | `Method_type desc -> Descriptor.hash_for_method desc
  | `Method_handle handle -> Bootstrap.hash_method_handle handle

(* Turns an [element] into a [ConstantPool.element] by resolving its indices:
   [First] indices are kept unchanged, while [Second] indices are offset by
   [base] (in [encode], the length of the first pool, which precedes the
   second one in the final pool). *)
let convert base e =
  let resolve = function
    | First i -> i
    | Second i -> Utils.u2 (base + (i :> int)) in
  match e with
  | Class i -> ConstantPool.Class (resolve i)
  | Fieldref (cls, nat) -> ConstantPool.Fieldref (resolve cls, resolve nat)
  | Methodref (cls, nat) -> ConstantPool.Methodref (resolve cls, resolve nat)
  | InterfaceMethodref (cls, nat) -> ConstantPool.InterfaceMethodref (resolve cls, resolve nat)
  | String i -> ConstantPool.String (resolve i)
  | Integer v -> ConstantPool.Integer v
  | Float v -> ConstantPool.Float v
(* no Long/Double case: such constants cannot appear in LDC
  | Long (x, y) -> ConstantPool.Long (x, y)
  | Double (x, y) -> ConstantPool.Double (x, y)
*)
  | NameAndType (name, desc) -> ConstantPool.NameAndType (resolve name, resolve desc)
  | UTF8 u -> ConstantPool.UTF8 u
  | MethodHandle (kind, i) -> ConstantPool.MethodHandle (kind, resolve i)
  | MethodType i -> ConstantPool.MethodType (resolve i)
(* no InvokeDynamic/ModuleId case: such constants cannot appear in LDC
  | InvokeDynamic (idx1, idx2) -> ConstantPool.InvokeDynamic (conv idx1, conv idx2)
  | ModuleId (idx1, idx2) -> ConstantPool.ModuleId (conv idx1, conv idx2)
*)

(* Filler value for unused pool slots; [pool_equal] recognizes it by
   physical equality, so it must remain this single shared value. *)
let dummy_element : element = UTF8 (@"Dummy-Constant-Pool-Entry")

(* Equality used when searching the temporary pools: the filler
   [dummy_element] (compared physically) never matches anything, UTF8
   entries are compared with [UTF8.equal], other entries structurally. *)
let pool_equal x y =
  if x == dummy_element then
    false
  else
    match x, y with
    | UTF8 u1, UTF8 u2 -> UTF8.equal u1 u2
    | _, _ -> x = y

(* Returns the index of [elem] in the extendable array [ext], adding it
   first if no entry is [pool_equal] to it. [Too_large] (carrying the
   current length of [ext]) is the error used if the array cannot grow.
   NOTE(review): meaning of the trailing [false] flag is defined by
   [ExtendableArray.add_if_not_found] — confirm there. *)
let add_if_not_found ext elem =
  let overflow = Exception (Too_large (ExtendableArray.length ext)) in
  let matches candidate = pool_equal candidate elem in
  ExtendableArray.add_if_not_found overflow matches ext elem dummy_element false


(* Functions *)

(* Creates a fresh, empty set of constants. *)
let make = fun () -> ref []

(* Adds constant [c] to set [s], unless an [equal] constant is already
   present (in which case [s] is left untouched). *)
let add c s =
  let already_present = List.exists (fun existing -> equal c existing) !s in
  if not already_present then s := c :: !s

(* Builds a constant pool in which every constant of [s] sits at an index
   lower than 256, so that it can be loaded through the one-byte index of
   ldc.

   Two temporary pools are filled:
   - the first pool receives the ldc-loadable entries themselves (and class
     entries, which are also looked up there first by [add_class_best]);
   - the second pool receives auxiliary entries (UTF8, NameAndType,
     Fieldref, Methodref, InterfaceMethodref, and classes not found in the
     first pool), which only need a two-byte index.
   The result is the concatenation of both pools, second-pool indices being
   shifted by the length of the first pool (see [convert]).

   Fails (via [fail]) with [Too_many_constraints n] if the first pool needs
   more than 256 slots, and with [Too_large n] if the whole pool has more
   slots than a class file allows. *)
let encode s =
  (* index 0 of a constant pool is unused: the first pool starts with one
     filler slot, and the copy loop below starts at 1 *)
  let fst_pool = ExtendableArray.make 1 128 dummy_element in
  let snd_pool = ExtendableArray.make 0 128 dummy_element in
  (* those are added to the second pool *)
  let add_utf8 u =
    let elem = UTF8 u in
    Second (add_if_not_found snd_pool elem) in
  (* those are added preferably to the first pool *)
  let add_class_best n =
    let name_index = add_utf8 (Name.internal_utf8_for_class n) in
    let elem = Class name_index in
    try
      First (ExtendableArray.find (fun x -> pool_equal x elem) fst_pool)
    with Not_found ->
      Second (add_if_not_found snd_pool elem) in
  (* those are also added to the second pool *)
  let add_name_and_type n t =
    let name_index = add_utf8 n in
    let type_index = add_utf8 t in
    let elem = NameAndType (name_index, type_index) in
    Second (add_if_not_found snd_pool elem) in
  let add_field cn n t =
    let class_index = add_class_best cn in
    let d = Descriptor.utf8_of_field t in
    let name_and_type_index = add_name_and_type (Name.utf8_for_field n) d in
    let elem = Fieldref (class_index, name_and_type_index) in
    Second (add_if_not_found snd_pool elem) in
  let add_method cn n t =
    let class_index = add_class_best cn in
    let d = Descriptor.utf8_of_method t in
    let name_and_type_index = add_name_and_type (Name.utf8_for_method n) d in
    let elem = Methodref (class_index, name_and_type_index) in
    Second (add_if_not_found snd_pool elem) in
  let add_interface_method cn n t =
    let class_index = add_class_best cn in
    let d = Descriptor.utf8_of_method t in
    let name_and_type_index = add_name_and_type (Name.utf8_for_method n) d in
    let elem = InterfaceMethodref (class_index, name_and_type_index) in
    Second (add_if_not_found snd_pool elem) in
  (* those are added to the first pool *)
  let add_integer i =
    let elem = Integer i in
    ignore (add_if_not_found fst_pool elem) in
  let add_float f =
    let elem = Float (Int32.bits_of_float f) in
    ignore (add_if_not_found fst_pool elem) in
  let add_string s =
    let v = add_utf8 s in
    let elem = String v in
    ignore (add_if_not_found fst_pool elem) in
  let add_class n =
    let name_index = add_utf8 (Name.internal_utf8_for_class n) in
    let elem = Class name_index in
    ignore (add_if_not_found fst_pool elem) in
  let add_array_class a =
    let d = Descriptor.internal_utf8_of_java_type (a :> Descriptor.java_type) in
    let name_index = add_utf8 d in
    let elem = Class name_index in
    ignore (add_if_not_found fst_pool elem) in
  let add_method_type t =
    let d = Descriptor.utf8_of_method t in
    let type_index = add_utf8 d in
    let elem = MethodType type_index in
    ignore (add_if_not_found fst_pool elem) in
  let add_method_handle r =
    let elem = match r with
    | `getField (cn, fn, fd) ->
        let reference_index = add_field cn fn fd in
        MethodHandle (Reference.REF_getField, reference_index)
    | `getStatic (cn, fn, fd) ->
        let reference_index = add_field cn fn fd in
        MethodHandle (Reference.REF_getStatic, reference_index)
    | `putField (cn, fn, fd) ->
        let reference_index = add_field cn fn fd in
        MethodHandle (Reference.REF_putField, reference_index)
    | `putStatic (cn, fn, fd) ->
        let reference_index = add_field cn fn fd in
        MethodHandle (Reference.REF_putStatic, reference_index)
    | `invokeVirtual (cn, mn, mt) ->
        let reference_index = add_method cn mn mt in
        MethodHandle (Reference.REF_invokeVirtual, reference_index)
    | `invokeStatic (cn, mn, mt) ->
        let reference_index = add_method cn mn mt in
        MethodHandle (Reference.REF_invokeStatic, reference_index)
    | `invokeSpecial (cn, mn, mt) ->
        let reference_index = add_method cn mn mt in
        MethodHandle (Reference.REF_invokeSpecial, reference_index)
    | `newInvokeSpecial (cn, pl) ->
        (* the referenced method is the instance initialization method
           <init>, whose descriptor must have a void return type (JVMS
           4.4.2): the class type is the type of the method handle, not the
           return type of the Methodref *)
        let mn = Name.make_for_method class_constructor in
        let mt = pl, `Void in
        let reference_index = add_method cn mn mt in
        MethodHandle (Reference.REF_newInvokeSpecial, reference_index)
    | `invokeInterface (cn, mn, mt) ->
        let reference_index = add_interface_method cn mn mt in
        MethodHandle (Reference.REF_invokeInterface, reference_index) in
    ignore (add_if_not_found fst_pool elem) in
  (* classes are added first to ensure they have the lowest indexes *)
  List.iter
    (function
      | `Class_or_interface x -> add_class x
      | _ -> ())
    !s;
  List.iter
    (function
      | `Int x -> add_integer x
      | `Float x -> add_float x
      | `String x -> add_string x
      | `Class_or_interface _ -> ()
      | `Array_type x -> add_array_class x
      | `Method_type x -> add_method_type x
      | `Method_handle x -> add_method_handle x)
    !s;
  (* concatenation of both pools can now be performed *)
  let fst_len = ExtendableArray.length fst_pool in
  let snd_len = ExtendableArray.length snd_pool in
  let tot_len = fst_len + snd_len in
  (* ldc takes a one-byte index, so first-pool slots (index 0 included) are
     limited to 256 *)
  if fst_len > 256 then fail (Too_many_constraints fst_len);
  (* [tot_len] counts slot 0 as well, hence equals constant_pool_count,
     which is a u2: at most 65535 *)
  if tot_len > 65535 then fail (Too_large tot_len);
  let res = ExtendableArray.make tot_len tot_len ConstantPool.dummy_element in
  for i = 1 to pred fst_len do
    let j = Utils.u2 i in
    ExtendableArray.set res j (convert fst_len (ExtendableArray.get fst_pool j));
  done;
  for i = 0 to pred snd_len do
    let j = Utils.u2 (fst_len + i) in
    let k = Utils.u2 i in
    ExtendableArray.set res j (convert fst_len (ExtendableArray.get snd_pool k));
  done;
  let res = ExtendableArray.to_array res in
  ConstantPool.make_extendable_from_array res
