(*
 * This file is part of Barista.
 * Copyright (C) 2007-2014 Xavier Clerc.
 *
 * Barista is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Barista is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *)

open Consts

(* Local shorthand for the concatenation operator of the UTF8 module; used
   below to glue prefixes and suffixes onto rendered lines. *)
let (++) = UTF8.(++)


(* Constants *)

(* Two-space indentation written in front of every attribute by
   add_attribute. *)
let tab = @"  "

(* Method name built from the class_constructor constant; used as the name
   of Method.Constructor entries when printing methods. *)
let class_constructor_name = Name.make_for_method class_constructor

(* Method name built from the class_initializer constant; used as the name
   of Method.Initializer entries when printing methods. *)
let class_initializer_name = Name.make_for_method class_initializer


(* Functions *)

(* Renders a method declaration: the return type, a space, the method name,
   then the comma-separated parameter types between parentheses.
   desc is the pair (parameter types, return type). *)
let utf8_of_method_desc name desc =
  let render = Descriptor.external_utf8_of_java_type in
  let parameter_types, return_type = desc in
  let rendered_parameters =
    UTF8.concat_sep_map
      @","
      render
      (parameter_types :> Descriptor.java_type list) in
  let rendered_return = render return_type in
  let rendered_name = Name.utf8_for_method name in
  SPRINTF ("%s %s(%s)" rendered_return rendered_name rendered_parameters)

(* Renders a method reference as used at call sites: the method name, the
   comma-separated parameter types between parentheses, a colon, then the
   return type. desc is the pair (parameter types, return type). *)
let utf8_of_method_call name desc =
  let render = Descriptor.external_utf8_of_java_type in
  let parameter_types, return_type = desc in
  let rendered_parameters =
    UTF8.concat_sep_map
      @","
      render
      (parameter_types :> Descriptor.java_type list) in
  let rendered_return = render return_type in
  let rendered_name = Name.utf8_for_method name in
  SPRINTF ("%s(%s):%s" rendered_name rendered_parameters rendered_return)

(* Splits the attributes attached to a code attribute into seven groups,
   returned as a tuple in this order:
   - line number entries, sorted on their first component;
   - local variable entries, sorted on local_start;
   - local variable type entries, sorted on local_type_start;
   - the stack map table (when several are present, the last one wins);
   - runtime visible type annotations, in encounter order;
   - runtime invisible type annotations, in encounter order;
   - unknown attributes as (name, contents) pairs, in encounter order. *)
let extract_code_attributes l =
  let by_first_component (x, _) (y, _) = compare x y in
  let by_local_start
      { Attribute.local_start = x; _ }
      { Attribute.local_start = y; _ } =
    compare x y in
  let by_local_type_start
      { Attribute.local_type_start = x; _ }
      { Attribute.local_type_start = y; _ } =
    compare x y in
  let dispatch (lines, locals, local_types, frames, visible, invisible, unknown) attr =
    match attr with
    | `LineNumberTable entries ->
        (lines @ entries, locals, local_types, frames, visible, invisible, unknown)
    | `LocalVariableTable entries ->
        (lines, locals @ entries, local_types, frames, visible, invisible, unknown)
    | `LocalVariableTypeTable entries ->
        (lines, locals, local_types @ entries, frames, visible, invisible, unknown)
    | `StackMapTable entries ->
        (lines, locals, local_types, entries, visible, invisible, unknown)
    | `RuntimeVisibleTypeAnnotations entries ->
        (lines, locals, local_types, frames, visible @ entries, invisible, unknown)
    | `RuntimeInvisibleTypeAnnotations entries ->
        (lines, locals, local_types, frames, visible, invisible @ entries, unknown)
    | `Unknown (n, s) ->
        (* Accumulated in reverse; put back in order once the fold is done. *)
        (lines, locals, local_types, frames, visible, invisible, (n, s) :: unknown) in
  let lines, locals, local_types, frames, visible, invisible, unknown =
    List.fold_left dispatch ([], [], [], [], [], [], []) l in
  (List.sort by_first_component lines,
   List.sort by_local_start locals,
   List.sort by_local_type_start local_types,
   frames,
   visible,
   invisible,
   List.rev unknown)

(* Renders an annotation element value as a list of lines.
   Every line starts with prefix, followed by a space, the kind of the value
   and its payload. Scalars yield exactly one line; nested annotations and
   arrays are flattened into as many lines as they have leaves, array
   elements being tagged with their zero-based position. *)
let rec annotation_value_list prefix = function
  | Annotation.Boolean_value b ->
      if b then
        [prefix ++ @" boolean 1"]
      else
        [prefix ++ @" boolean 0"]
  | Annotation.Byte_value b ->
      [SPRINTF ("%s byte %d" prefix b)]
  | Annotation.Char_value c ->
      let c = UChar.to_code c in
      [SPRINTF ("%s char %d" prefix c)]
  | Annotation.Double_value d ->
      [SPRINTF ("%s double %f" prefix d)]
  | Annotation.Float_value f ->
      [SPRINTF ("%s float %f" prefix f)]
  | Annotation.Int_value i ->
      [SPRINTF ("%s int %ld" prefix i)]
  | Annotation.Long_value l ->
      [SPRINTF ("%s long %Ld" prefix l)]
  | Annotation.Short_value s ->
      [SPRINTF ("%s short %d" prefix s)]
  | Annotation.String_value s ->
      [SPRINTF ("%s string %S" prefix s)]
  | Annotation.Enum_value (cn, fn) ->
      [SPRINTF ("%s enum %s %s"
                  prefix
                  (Name.external_utf8_for_class cn)
                  (Name.utf8_for_field fn))]
  | Annotation.Class_value cn ->
      let cn = Name.external_utf8_for_class cn in
      [SPRINTF ("%s class %s" prefix cn)]
  | Annotation.Annotation_value a ->
      (* annotation_list appends the class name directly after its prefix,
         so the separating space has to be part of the prefix; without it
         the keyword and the class name would be fused together. *)
      let prefix' = prefix ++ @" annotation " in
      annotation_list prefix' a
  | Annotation.Array_value l ->
      (* Each element is rendered with its position appended to the prefix. *)
      List.flatten
        (List.mapi
           (fun idx x ->
             annotation_value_list (SPRINTF ("%s %d" prefix idx)) x)
           l)

(* Renders an annotation (class name, element/value pairs) as a list of
   lines. prefix is expected to already end with whatever separator is
   needed before the class name. An annotation without any element yields a
   single line made of the prefix, the class name and a trailing space;
   otherwise every element contributes the lines of its value, each starting
   with the prefix, the class name and the element identifier. *)
and annotation_list prefix (name, pairs) =
  let prefix' = SPRINTF ("%s%s " prefix (Name.external_utf8_for_class name)) in
  match pairs with
  | [] ->
      [prefix']
  | _ :: _ ->
      List.flatten
        (List.map
           (fun (id, vl) ->
             annotation_value_list (prefix' ++ id) vl)
           pairs)

(* Appends the parameter annotations l to buffer, one group per parameter.
   directive is the attribute keyword; every line is made of the directive,
   a space, the zero-based parameter index, a space, and then the annotation
   as rendered by annotation_list.
   The caller is expected to have already written the indentation of the
   very first line; the indentation of the first line of every subsequent
   parameter is written here, and continuation lines get theirs from the
   line separator. A parameter without annotations yields an empty line. *)
let add_parameter_annotations buffer directive l =
  List.iteri
    (fun no annotations ->
      let prefix = SPRINTF ("%s %d " directive no) in
      let lines =
        List.flatten (List.map (annotation_list prefix) annotations) in
      (match lines with
      | [] -> ()
      | _ :: _ -> if no > 0 then UTF8Buffer.add_string buffer tab);
      UTF8Buffer.add_endline buffer (UTF8.concat_sep @"\n  " lines))
    l

(* Appends the textual form of attribute a to buffer.
   The indentation held by tab is always written first; lines after the
   first one carry their own two-space indentation. Attributes matched by a
   unit-returning case (code-level tables, type annotations, bootstrap
   methods, ...) produce no directive at all, although the indentation has
   already been written at that point.
   For a Code attribute the output is, in order: the stack and locals sizes,
   the instructions (each one preceded by the line number and local variable
   entries that start applying at its offset), the exception table, the
   stack map frames, and finally the unknown nested attributes. *)
let add_attribute buffer a =
  UTF8Buffer.add_string buffer tab;
  match a with
  | `ConstantValue cv ->
      UTF8Buffer.add_string buffer @"@ConstantValue ";
      (match cv with
      | Attribute.Long_value lv ->
          BPRINTF ("%Ld\n" buffer lv)
      | Attribute.Float_value fv ->
          BPRINTF ("%f\n" buffer fv)
      | Attribute.Double_value dv ->
          BPRINTF ("%f\n" buffer dv)
      | Attribute.Boolean_value bv ->
          BPRINTF ("%d\n" buffer (if bv then 1 else 0))
      | Attribute.Byte_value bv ->
          BPRINTF ("%d\n" buffer bv)
      | Attribute.Character_value cv ->
          BPRINTF ("%d\n" buffer cv)
      | Attribute.Short_value sv ->
          BPRINTF ("%d\n" buffer sv)
      | Attribute.Integer_value iv ->
          BPRINTF ("%ld\n" buffer iv)
      | Attribute.String_value sv ->
          BPRINTF ("%S\n" buffer sv))
  | `Code cv ->
      let line_number_table, local_variable_table, local_variable_type_table, stack_map_table, _, _, unknowns =
        extract_code_attributes cv.Attribute.attributes in
      (* prev is the offset of the previously printed instruction (-1 before
         the first one) and ofs the offset of the current one. *)
      let prev = ref (-1) in
      let ofs = ref 0 in
      BPRINTF (".max_stack %d\n" buffer (cv.Attribute.max_stack :> int));
      BPRINTF ("  .max_locals %d\n" buffer (cv.Attribute.max_locals :> int));
      List.iter
        (fun i ->
          (* Print the first line number entry whose offset lies after the
             previous instruction and up to the current one, if any. *)
          (try
            let line = snd (List.find (fun ((x : Utils.u2), _) -> !prev < (x :> int) && !ofs >= (x :> int)) line_number_table) in
            BPRINTF ("  @LineNumberTable %d\n" buffer (line :> int));
          with Not_found -> ());
          (* Print the local variable entries whose range contains the
             current offset but not the previous one, that is those whose
             range is entered at this instruction. *)
          List.iter
            (fun { Attribute.local_start; local_length; local_name; local_descriptor; local_index } ->
              let start = (local_start : Utils.u2 :> int) in
              let length = (local_length : Utils.u2 :> int) in
              let index = (local_index : Utils.u2 :> int) in
              let finish = start + length in
              if (!ofs >= start && !ofs < finish)
                  && not (!prev >= start && !prev < finish) then
                begin
                  BPRINTF ("  @LocalVariableTable code%08d: code%08d: %s %s %d\n"
                             buffer
                             start
                             finish
                             local_name
                             (Descriptor.external_utf8_of_java_type (local_descriptor :> Descriptor.java_type))
                             index)
                end)
            local_variable_table;
          (* Same rule for the local variable type entries. *)
          List.iter
            (fun { Attribute.local_type_start; local_type_length; local_type_name; local_type_signature; local_type_index } ->
              let start = (local_type_start : Utils.u2 :> int) in
              let length = (local_type_length : Utils.u2 :> int) in
              let index = (local_type_index : Utils.u2 :> int) in
              let finish = start + length in
              if (!ofs >= start && !ofs < finish)
                  && not (!prev >= start && !prev < finish) then
                begin
                  BPRINTF ("  @LocalVariableTypeTable code%08d: code%08d: %s %S %d\n"
                             buffer
                             start
                             finish
                             local_type_name
                             (Signature.utf8_of_field_type_signature local_type_signature)
                             index)
                end)
            local_variable_type_table;
          (* sz is added to the offset once the instruction is printed;
             p holds the inline parameters and t the optional tail
             (offset tables). *)
          let sz, is_wide, mnemo, p, t = Instruction.decompile !ofs i in
          BPRINTF ("  code%08d: %s%s"
                     buffer
                     !ofs
                     (if is_wide then @"wide " else @"")
                     (UTF8.of_string mnemo));
          let add_field (cn, fn, d) =
            BPRINTF ("%s.%s:%s"
                       buffer
                       (Name.external_utf8_for_class cn)
                       (Name.utf8_for_field fn)
                       (Descriptor.external_utf8_of_java_type (d :> Descriptor.java_type))) in
          let add_method (cn, mn, d) =
            BPRINTF ("%s.%s"
                       buffer
                       (Name.external_utf8_for_class cn)
                       (utf8_of_method_call mn d)) in
          let add_method_handle = function
            | `getField (cn, fn, d) ->
                UTF8Buffer.add_string buffer @"getField%";
                add_field (cn, fn, d)
            | `getStatic (cn, fn, d) ->
                UTF8Buffer.add_string buffer @"getStatic%";
                add_field (cn, fn, d)
            | `putField (cn, fn, d) ->
                UTF8Buffer.add_string buffer @"putField%";
                add_field (cn, fn, d)
            | `putStatic (cn, fn, d) ->
                UTF8Buffer.add_string buffer @"putStatic%";
                add_field (cn, fn, d)
            | `invokeVirtual (cn, mn, d) ->
                UTF8Buffer.add_string buffer @"invokeVirtual%";
                add_method (cn, mn, d)
            | `invokeStatic (cn, mn, d) ->
                UTF8Buffer.add_string buffer @"invokeStatic%";
                add_method (cn, mn, d)
            | `invokeSpecial (cn, mn, d) ->
                UTF8Buffer.add_string buffer @"invokeSpecial%";
                add_method (cn, mn, d)
            | `newInvokeSpecial (cn, d) ->
                (* Printed as a call to the constructor of cn, with cn
                   itself as return type. *)
                UTF8Buffer.add_string buffer @"newInvokeSpecial%";
                add_method (cn, class_constructor_name, (d, `Class cn))
            | `invokeInterface (cn, mn, d) ->
                UTF8Buffer.add_string buffer @"invokeInterface%";
                add_method (cn, mn, d) in
          let add_method_type (params, return) =
            BPRINTF ("(%s):%s"
                       buffer
                       (UTF8.concat_sep_map @"," Descriptor.external_utf8_of_java_type (params :> Descriptor.java_type list))
                       (Descriptor.external_utf8_of_java_type return)) in
          (* Inline parameters, each one preceded by a space. Offsets are
             relative to the current instruction and printed as absolute
             code labels. *)
          List.iter (fun x ->
            UTF8Buffer.add_string buffer @" ";
            match x with
            | Instruction.Int_constant ic ->
                BPRINTF ("%Ld" buffer ic)
            | Instruction.Offset o ->
                BPRINTF ("code%08d:" buffer (!ofs + (Int32.to_int o)))
            | Instruction.Float_constant fc ->
                BPRINTF ("%f" buffer fc)
            | Instruction.String_constant sc ->
                BPRINTF ("%S" buffer sc)
            | Instruction.Class_name cn ->
                UTF8Buffer.add_string buffer (Name.external_utf8_for_class cn)
            | Instruction.Array_type at ->
                UTF8Buffer.add_string buffer at
            | Instruction.Primitive_type pt ->
                UTF8Buffer.add_string buffer (Descriptor.external_utf8_of_java_type pt)
            | Instruction.Field (cn, fn, d) ->
                add_field (cn, fn, d)
            | Instruction.Dynamic_method ((mh, args), mn, d) ->
                (* Bootstrap method handle, then its arguments, then the
                   name and descriptor of the dynamic call site. *)
                add_method_handle mh;
                List.iter
                  (fun arg ->
                    UTF8Buffer.add_string buffer @" ";
                    match arg with
                    | `String s ->
                        BPRINTF ("%S" buffer s)
                    | `Class cn ->
                        UTF8Buffer.add_string buffer (Name.external_utf8_for_class cn)
                    | `Integer i ->
                        BPRINTF ("int %ld" buffer i)
                    | `Long l ->
                        BPRINTF ("long %Ld" buffer l)
                    | `Float f ->
                        BPRINTF ("float %f" buffer f)
                    | `Double d ->
                        BPRINTF ("double %f" buffer d)
                    | `MethodHandle mh ->
                        add_method_handle mh
                    | `MethodType mt ->
                        add_method_type mt)
                  args;
                BPRINTF (" %s" buffer (utf8_of_method_call mn d))
            | Instruction.Method (cn, mn, d) ->
                add_method (cn, mn, d)
            | Instruction.Array_method (at, mn, d) ->
                BPRINTF ("%s.%s"
                           buffer
                           (Descriptor.external_utf8_of_java_type (at :> Descriptor.java_type))
                           (utf8_of_method_call mn d))
            | Instruction.Method_type_constant (params, return) ->
                add_method_type (params, return)
            | Instruction.Method_handle_constant mh ->
                add_method_handle mh)
            p;
          (* Tail of the instruction: one extra line per table entry. *)
          (match t with
          | Instruction.No_tail -> ()
          | Instruction.Match_offset_pairs l ->
              List.iter
                (fun (m, o) ->
                  let m = (m : Utils.s4 :> int32) in
                  let o = (o : Instruction.long_offset :> int32) in
                  UTF8Buffer.add_newline buffer;
                  BPRINTF ("    %ld => code%08d:" buffer m (!ofs + (Int32.to_int o))))
                l
          | Instruction.Long_offsets l ->
              List.iter
                (fun o ->
                  let o = (o : Instruction.long_offset :> int32) in
                  UTF8Buffer.add_newline buffer;
                  BPRINTF ("    => code%08d:" buffer (!ofs + (Int32.to_int o))))
                l);
          UTF8Buffer.add_endline buffer empty_utf8;
          prev := !ofs;
          ofs := !ofs + sz)
        cv.Attribute.code;
      (* Exception table: protected range, handler, and the caught class
         when there is one. *)
      List.iter
        (fun elem ->
          let start_pc = (elem.Attribute.try_start : Utils.u2 :> int) in
          let end_pc = (elem.Attribute.try_end : Utils.u2 :> int) in
          let handler_pc = (elem.Attribute.catch : Utils.u2 :> int) in
          BPRINTF ("  .catch code%08d: code%08d: code%08d:" buffer start_pc end_pc handler_pc);
          (match elem.Attribute.caught with
          | Some n -> BPRINTF (" %s\n" buffer (Name.external_utf8_for_class n))
          | None -> BPRINTF ("\n" buffer)))
        cv.Attribute.exception_table;
      let utf8_of_type = function
        | Attribute.Top_variable_info -> @"top"
        | Attribute.Integer_variable_info -> @"int"
        | Attribute.Float_variable_info -> @"float"
        | Attribute.Long_variable_info -> @"long"
        | Attribute.Double_variable_info -> @"double"
        | Attribute.Null_variable_info -> @"null"
        | Attribute.Uninitialized_this_variable_info -> @"uninit_this"
        | Attribute.Object_variable_info (`Class_or_interface n) ->
            Name.external_utf8_for_class n
        | Attribute.Object_variable_info (`Array_type at) ->
            Descriptor.external_utf8_of_java_type (at :> Descriptor.java_type)
        | Attribute.Uninitialized_variable_info ofs ->
            let ofs = (ofs :> int) in
            SPRINTF ("uninit code%08d:" ofs) in
      (* Stack map frames, one line each. *)
      List.iter
        (function
          | Attribute.Same_frame ofs ->
              BPRINTF ("  .frame code%08d: same\n" buffer (ofs :> int))
          | Attribute.Same_locals_1_stack_item_frame (ofs, t) ->
              BPRINTF ("  .frame code%08d: same_locals %s\n" buffer (ofs :> int) (utf8_of_type t))
          | Attribute.Chop_1_frame ofs ->
              BPRINTF ("  .frame code%08d: chop 1\n" buffer (ofs :> int))
          | Attribute.Chop_2_frame ofs ->
              BPRINTF ("  .frame code%08d: chop 2\n" buffer (ofs :> int))
          | Attribute.Chop_3_frame ofs ->
              BPRINTF ("  .frame code%08d: chop 3\n" buffer (ofs :> int))
          | Attribute.Append_1_frame (ofs, t1) ->
              BPRINTF ("  .frame code%08d: append %s\n" buffer (ofs :> int) (utf8_of_type t1))
          | Attribute.Append_2_frame (ofs, t1, t2) ->
              BPRINTF ("  .frame code%08d: append %s %s\n" buffer (ofs :> int) (utf8_of_type t1) (utf8_of_type t2))
          | Attribute.Append_3_frame (ofs, t1, t2, t3) ->
              BPRINTF ("  .frame code%08d: append %s %s %s\n" buffer (ofs :> int) (utf8_of_type t1) (utf8_of_type t2) (utf8_of_type t3))
          | Attribute.Full_frame (ofs, l1, l2) ->
              let utf8_of_type_list l = UTF8.concat_sep_map @" " utf8_of_type l in
              BPRINTF ("  .frame code%08d: full %s ~ %s\n" buffer (ofs :> int) (utf8_of_type_list l1) (utf8_of_type_list l2)))
        stack_map_table;
      (* Unknown nested attributes: indented like every other line of the
         code attribute, and with the name passed as-is, exactly as in the
         top-level Unknown case below. NOTE(review): this assumes the S
         conversion already takes care of quoting its argument, as every
         other use in this file suggests - confirm. *)
      List.iter
        (fun (n, b) ->
          BPRINTF ("  @Unknown %S %S\n" buffer n (UTF8.of_string (Bytes.as_string b))))
        unknowns
  | `Exceptions e ->
      let e = UTF8.concat_sep_map @" " Name.external_utf8_for_class e in
      BPRINTF ("@Exceptions %s\n" buffer e)
  | `InnerClasses l ->
      (* One line per inner class; a missing component is printed as 0. *)
      let l' =
        List.map
          (function { Attribute.inner_class; outer_class; inner_name; inner_flags } ->
            SPRINTF ("@InnerClass %s %s %s %s"
                       (match inner_class with
                       | Some cn -> Name.external_utf8_for_class cn
                       | None -> @"0")
                       (match outer_class with
                       | Some cn -> Name.external_utf8_for_class cn
                       | None -> @"0")
                       (match inner_name with
                       | Some n -> n
                       | None -> @"0")
                       (AccessFlag.list_to_utf8 (inner_flags :> AccessFlag.t list))))
          l in
      UTF8Buffer.add_endline
        buffer
        (UTF8.concat_sep @"\n  " l')
  | `EnclosingMethod { Attribute.innermost_class; enclosing_method } ->
      BPRINTF ("@EnclosingMethod %s%s\n"
                 buffer
                 (Name.external_utf8_for_class innermost_class)
                 (match enclosing_method with
                 | Some (n, d) ->
                     @" " ++ (utf8_of_method_call n d)
                 | None -> @""))
  | `Synthetic ->
      BPRINTF ("@Synthetic\n" buffer)
  | `Signature s ->
      let s' = (match s with
      | `Class cs -> Signature.utf8_of_class_signature cs
      | `Method ms -> Signature.utf8_of_method_signature ms
      | `Field fs -> Signature.utf8_of_field_type_signature fs) in
      BPRINTF ("@Signature %S\n" buffer s')
  | `SourceFile sf ->
      BPRINTF ("@SourceFile %S\n" buffer sf)
  | `SourceDebugExtension sde ->
      BPRINTF ("@SourceDebugExtension %S\n" buffer sde)
  | `LineNumberTable _ -> ()
  | `LocalVariableTable _ -> ()
  | `LocalVariableTypeTable _ -> ()
  | `Deprecated ->
      BPRINTF ("@Deprecated\n" buffer)
  | `RuntimeVisibleAnnotations l ->
      let l' = List.flatten
          (List.map
             (annotation_list @"@RuntimeVisibleAnnotations ")
             l) in
      UTF8Buffer.add_endline buffer (UTF8.concat_sep @"\n  " l')
  | `RuntimeInvisibleAnnotations l ->
      let l' = List.flatten
          (List.map
             (annotation_list @"@RuntimeInvisibleAnnotations ")
             l) in
      UTF8Buffer.add_endline buffer (UTF8.concat_sep @"\n  " l')
  | `RuntimeVisibleParameterAnnotations l ->
      add_parameter_annotations buffer @"@RuntimeVisibleParameterAnnotations" l
  | `RuntimeInvisibleParameterAnnotations l ->
      add_parameter_annotations buffer @"@RuntimeInvisibleParameterAnnotations" l
  | `RuntimeVisibleTypeAnnotations _ -> ()
  | `RuntimeInvisibleTypeAnnotations _ -> ()
  | `AnnotationDefault ad ->
      let l = annotation_value_list @"@AnnotationDefault" ad in
      UTF8Buffer.add_endline buffer (UTF8.concat_sep @"\n  " l)
  | `StackMapTable _ -> ()
  | `BootstrapMethods _ -> ()
  | `MethodParameters _ -> ()
  | `Module (mn, mv) ->
      BPRINTF ("@Module %S %S\n" buffer mn mv)
  | `ModuleRequires _ -> ()
  | `ModulePermits _ -> ()
  | `ModuleProvides _ -> ()
  | `Unknown (n, s) ->
      BPRINTF ("@Unknown %S %S\n" buffer n (UTF8.of_string (Bytes.as_string s)))

(* Looks up class s through a class loader built from the class path cp and
   appends its textual form to buffer, in this order: the class header
   (flags and name), the parent class if any, the implemented interfaces
   sorted by name, the class attributes, the fields and finally the methods.
   Attributes are always printed by decreasing significance, as defined by
   Attribute.compare_according_to_significance; fields and methods are
   ordered by their respective visibility comparison. *)
let disassemble_to_buffer buffer cp s =
  let loader = ClassLoader.make_of_class_path cp in
  let definition =
    (Lookup.for_class false ~open_packages:[ @"java.lang" ] loader s).Lookup.value in
  (* Prints a list of attributes, most significant first. *)
  let emit_attributes attributes =
    List.iter
      (add_attribute buffer)
      (List.sort Attribute.compare_according_to_significance attributes) in
  (* Class header. *)
  let class_flags =
    AccessFlag.list_to_utf8 (definition.ClassDefinition.access_flags :> AccessFlag.t list) in
  let class_name = Name.external_utf8_for_class definition.ClassDefinition.name in
  BPRINTF (".class %s%s\n" buffer class_flags class_name);
  (match definition.ClassDefinition.extends with
  | None -> ()
  | Some parent ->
      BPRINTF (".extends %s\n" buffer (Name.external_utf8_for_class parent)));
  let sorted_interfaces =
    List.sort
      UTF8.compare
      (List.map Name.external_utf8_for_class definition.ClassDefinition.implements) in
  List.iter
    (fun itf -> BPRINTF (".implements %s\n" buffer itf))
    sorted_interfaces;
  emit_attributes (definition.ClassDefinition.attributes :> Attribute.t list);
  (* Fields, separated from what precedes by an empty line. *)
  let fields = definition.ClassDefinition.fields in
  (match fields with
  | [] -> ()
  | _ :: _ -> UTF8Buffer.add_endline buffer @"");
  List.iter
    (fun field ->
      let field_flags =
        AccessFlag.list_to_utf8
          (List.sort AccessFlag.compare (field.Field.flags :> AccessFlag.t list)) in
      let field_type =
        Descriptor.external_utf8_of_java_type (field.Field.descriptor :> Descriptor.java_type) in
      let field_name = Name.utf8_for_field field.Field.name in
      BPRINTF ("\n.field %s%s %s\n" buffer field_flags field_type field_name);
      emit_attributes (field.Field.attributes :> Attribute.t list))
    (List.sort Field.compare_according_to_visibility fields);
  (* Methods, separated from what precedes by an empty line. Constructors
     and initializers are printed under their reserved names, with a void
     return type. *)
  let methods = definition.ClassDefinition.methods in
  (match methods with
  | [] -> ()
  | _ :: _ -> UTF8Buffer.add_endline buffer empty_utf8);
  List.iter
    (fun meth ->
      let method_flags, method_name, method_desc, method_attributes =
        match meth with
        | Method.Regular regular ->
            regular.Method.flags,
            regular.Method.name,
            regular.Method.descriptor,
            regular.Method.attributes
        | Method.Constructor cstr ->
            (cstr.Method.cstr_flags :> AccessFlag.for_method list),
            class_constructor_name,
            (cstr.Method.cstr_descriptor, `Void),
            cstr.Method.cstr_attributes
        | Method.Initializer init ->
            (init.Method.init_flags :> AccessFlag.for_method list),
            class_initializer_name,
            ([], `Void),
            init.Method.init_attributes in
      BPRINTF ("\n.method %s%s\n"
                 buffer
                 (AccessFlag.list_to_utf8 (method_flags :> AccessFlag.t list))
                 (utf8_of_method_desc method_name method_desc));
      emit_attributes (method_attributes :> Attribute.t list))
    (List.sort Method.compare_according_to_visibility methods)

(* Disassembles class s, looked up in class path cp, into a fresh buffer,
   then writes the whole text as one line to a line writer created over
   chan and flushes that writer. The writer is not closed here. *)
let disassemble_to_stream chan cp s =
  let text =
    let buffer = UTF8Buffer.make () in
    disassemble_to_buffer buffer cp s;
    UTF8Buffer.contents buffer in
  let out = UTF8LineWriter.make_of_stream chan in
  UTF8LineWriter.write_line out text;
  UTF8LineWriter.flush out

(* Disassembles class s, looked up in class path cp, onto the standard
   output stream. *)
let disassemble cp s =
  let chan = OutputStream.stdout in
  disassemble_to_stream chan cp s
