(*
 * This file is part of Barista.
 * Copyright (C) 2007-2014 Xavier Clerc.
 *
 * Barista is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Barista is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *)


(* Error of this module: raised through [fail] by [print_to_buffer] when
   an archive listed in its input cannot be opened, traversed or closed.
   The message quotes the offending archive path. *)
BARISTA_ERROR =
  | Cannot_process_archive of (str : UTF8.t) ->
      Printf.sprintf "Cannot process archive %S"
        (UTF8.to_string_noerr str)

(* Total ordering on class names, shared by the set and map modules below
   so that both use exactly the same comparison. *)
module ClassNameOrd = struct
  type t = Name.for_class
  let compare = Name.compare_for_class
end

(* Sets of class names, ordered by [Name.compare_for_class]. *)
module ClassNameSet = Set.Make (ClassNameOrd)

(* Maps keyed by class names, ordered by [Name.compare_for_class]. *)
module ClassNameMap = Map.Make (ClassNameOrd)

(* Per-class information computed by [summary] from a class definition and
   turned into graph vertices and edges by [print_to_buffer]. *)
type summary = {
    full_name : Name.for_class; (* name of the class *)
    package : Name.for_package; (* package part of [full_name] *)
    extends : Name.for_class option; (* super-class, if any *)
    implements : Name.for_class list; (* implemented interfaces *)
    is_interface : bool; (* whether the [`Interface] access flag is set *)
    is_abstract : bool; (* whether the [`Abstract] access flag is set *)
    fields : int; (* length of the class definition's field list *)
    methods : int; (* length of the class definition's method list *)
    code_size : int; (* accumulated instruction size of all code attributes *)
    references : int ClassNameMap.t; (* referenced class -> occurrence count *)
  }

(* [summary cd] computes the [summary] of class definition [cd].

   Header information (name, package, super-class, interfaces, access
   flags, field and method counts) is read directly from [cd]. A traversal
   of every class, field, method, constructor and initializer attribute
   then:
   - adds up the instruction size of every code attribute into
     [code_size];
   - counts, for each class name, how many times it is referenced from
     instructions, exception handlers, [Exceptions] attributes and
     annotations (array types are unwrapped down to their element type).
   References from the class to itself are counted as well: nothing here
   filters them out.
   NOTE(review): field types and method signatures are only reached through
   the overridden [*_attribute] methods; this assumes the inherited
   [ClassTraversal.default_class_definition_iterator] methods do nothing
   else of interest -- confirm. *)
let summary cd =
  (* accumulated instruction size over all visited code attributes *)
  let code_size = ref 0 in
  (* class name -> number of references found so far *)
  let references = ref ClassNameMap.empty in
  (* increments the reference count of [x] (absent names start at 0) *)
  let add_reference x =
    let old = try ClassNameMap.find x !references with Not_found -> 0 in
    references := ClassNameMap.add x (succ old) !references in
  let iterator =
    object (self)
      inherit ClassTraversal.default_class_definition_iterator
      (* records the class named by a type descriptor; array types are
         unwrapped recursively, and any other descriptor is ignored *)
      method private descriptor (d : Descriptor.java_type) =
          match d with
          | `Class cn -> add_reference cn
          | `Array a -> self#descriptor (a :> Descriptor.java_type)
          | _ -> ()
      (* records the classes of a method descriptor: every type of the
         list [fst d] (parameters) and the type [snd d] (return) *)
      method private descriptor_method (d : Descriptor.for_method) =
        List.iter
          (fun elem ->
            self#descriptor (elem :> Descriptor.java_type))
          (fst d);
        self#descriptor (snd d)
      (* records the classes, array types and method types named by an
         instruction (allocations, casts, type tests, field accesses,
         invocations, constant loads); other instructions are ignored *)
      method private instruction (i : Instruction.t) =
        let open Instruction in
        match i with
        | ANEWARRAY (`Class_or_interface cn) -> add_reference cn
        | ANEWARRAY (`Array_type d) -> self#descriptor (d :> Descriptor.java_type)
        | CHECKCAST (`Class_or_interface cn) -> add_reference cn
        | CHECKCAST (`Array_type d) -> self#descriptor (d :> Descriptor.java_type)
        | GETFIELD (cn, _, d) ->
            add_reference cn;
            self#descriptor (d :> Descriptor.java_type)
        | GETSTATIC (cn, _, d) ->
            add_reference cn;
            self#descriptor (d :> Descriptor.java_type)
        | INSTANCEOF (`Class_or_interface cn) -> add_reference cn
        | INSTANCEOF (`Array_type d) -> self#descriptor (d :> Descriptor.java_type)
        | INVOKEINTERFACE (cn, _, d) ->
            add_reference cn;
            self#descriptor_method d
        | INVOKESPECIAL (cn, _, d) ->
            add_reference cn;
            self#descriptor_method d
        | INVOKESTATIC (cn, _, d) ->
            add_reference cn;
            self#descriptor_method d
        | INVOKEVIRTUAL (`Class_or_interface cn, _, d) ->
            add_reference cn;
            self#descriptor_method d
        | INVOKEVIRTUAL (`Array_type d, _, d') ->
            self#descriptor (d :> Descriptor.java_type);
            self#descriptor_method d'
        | LDC (`Class_or_interface cn) -> add_reference cn
        | LDC (`Array_type d) -> self#descriptor (d :> Descriptor.java_type)
        | LDC (`Method_type d) -> self#descriptor_method d
        | LDC_W (`Class_or_interface cn) -> add_reference cn
        | LDC_W (`Array_type d) -> self#descriptor (d :> Descriptor.java_type)
        | LDC_W (`Method_type d) -> self#descriptor_method d
        | MULTIANEWARRAY (`Class_or_interface cn, _) -> add_reference cn
        | MULTIANEWARRAY (`Array_type d, _) -> self#descriptor (d :> Descriptor.java_type)
        | NEW cn -> add_reference cn
        | NEWARRAY d -> self#descriptor d
        | PUTFIELD (cn, _, d) ->
            add_reference cn;
            self#descriptor (d :> Descriptor.java_type)
        | PUTSTATIC (cn, _, d) ->
            add_reference cn;
            self#descriptor (d :> Descriptor.java_type)
        | _ -> ()
      (* handles a code attribute: adds its instruction size (offset
         argument 0 -- presumably the starting offset), then visits its
         instructions, the caught class of each exception handler (a
         handler with no caught class references nothing), and the nested
         attributes of the code *)
      method private code (c : Attribute.code_value) =
        let open Attribute in
        code_size := !code_size + (Instruction.size_of_list 0 c.code);
        List.iter self#instruction c.code;
        List.iter
          (fun { caught; _ } ->
            match caught with
            | Some cn -> add_reference cn
            | None -> ())
          c.exception_table;
        List.iter
          (fun a -> self#attribute (a :> Attribute.t))
          c.attributes
      (* records the annotation class [fst a], then visits the value of
         each element pair in [snd a] *)
      method private annotation (a : Annotation.t) =
        add_reference (fst a);
        List.iter
          (fun (_, a) -> self#annotation_element_value a)
          (snd a)
      method private annotation_list l = List.iter (self#annotation) l
      method private annotation_list_list l = List.iter (self#annotation_list) l
      (* only the first two components of an extended (type) annotation,
         which form a plain [Annotation.t], are visited; the last two are
         ignored *)
      method private annotation_extended (a : Annotation.extended) =
        let x, y, _, _ = a in
        self#annotation (x, y)
      method private annotation_extended_list l = List.iter (self#annotation_extended) l
      (* enum values, class values, nested annotations and arrays may name
         classes; every other kind of element value is ignored *)
      method private annotation_element_value (a : Annotation.element_value) =
        let open Annotation in
        match a with
        | Enum_value (cn, _) -> add_reference cn
        | Class_value cn -> add_reference cn
        | Annotation_value a -> self#annotation a
        | Array_value l -> List.iter self#annotation_element_value l
        | _ -> ()
      (* dispatches on the attribute kind: only code, exceptions,
         annotations and annotation defaults contribute references. Every
         kind is listed explicitly (no catch-all) so the compiler reports
         any kind that is not handled here. *)
      method private attribute (a : Attribute.t) =
        match a with
        | `ConstantValue _ -> ()
        | `Code c -> self#code c
        | `Exceptions l -> List.iter add_reference l
        | `InnerClasses _ -> ()
        | `EnclosingMethod _ -> ()
        | `Synthetic -> ()
        | `Signature _ -> ()
        | `SourceFile _ -> ()
        | `SourceDebugExtension _ -> ()
        | `LineNumberTable _ -> ()
        | `LocalVariableTable _ -> ()
        | `LocalVariableTypeTable _ -> ()
        | `Deprecated -> ()
        | `RuntimeVisibleAnnotations l -> self#annotation_list l
        | `RuntimeInvisibleAnnotations l -> self#annotation_list l
        | `RuntimeVisibleParameterAnnotations l -> self#annotation_list_list l
        | `RuntimeInvisibleParameterAnnotations l -> self#annotation_list_list l
        | `AnnotationDefault a -> self#annotation_element_value a
        | `StackMapTable _ -> ()
        | `BootstrapMethods _ -> ()
        | `MethodParameters _ -> ()
        | `RuntimeVisibleTypeAnnotations l -> self#annotation_extended_list l
        | `RuntimeInvisibleTypeAnnotations l -> self#annotation_extended_list l
        | `Module _ -> ()
        | `ModuleRequires _ -> ()
        | `ModulePermits _ -> ()
        | `ModuleProvides _ -> ()
        | `Unknown _ -> ()
      (* every attribute of the class and of its members goes through
         [attribute] *)
      method! class_attribute a = self#attribute (a :> Attribute.t)
      method! field_attribute a = self#attribute (a :> Attribute.t)
      method! regular_method_attribute a = self#attribute (a :> Attribute.t)
      method! constructor_attribute a = self#attribute (a :> Attribute.t)
      method! initializer_attribute a = self#attribute (a :> Attribute.t)
    end in
  (* fills [code_size] and [references] *)
  ClassDefinition.iter iterator cd;
  { full_name = cd.ClassDefinition.name;
    package = fst (Name.split_class_name cd.ClassDefinition.name);
    extends = cd.ClassDefinition.extends;
    implements = cd.ClassDefinition.implements;
    is_interface = AccessFlag.mem_class `Interface cd.ClassDefinition.access_flags;
    is_abstract = AccessFlag.mem_class `Abstract cd.ClassDefinition.access_flags;
    fields = List.length cd.ClassDefinition.fields;
    methods = List.length cd.ClassDefinition.methods;
    code_size = !code_size;
    references = !references; }
  
(* [print_to_buffer buffer fmt l] builds the class dependency graph of the
   archives whose paths are listed in [l], and dumps it in format [fmt]
   into [buffer].

   Vertices, identified and labelled by the external form of the class
   name and clustered by package:
   - one "internal" vertex (property "external" = "false") per class found
     in the archives, carrying its package, interface/abstract flags,
     field and method counts, and code size;
   - one "external" vertex (property "external" = "true") per class that
     is referenced, extended or implemented but not found in the archives.

   Directed edges, from the summarized class:
   - "extends" to its super-class;
   - "implements" to each implemented interface;
   - "references" to each referenced class, with a "weight" property
     giving the number of references.

   Fails with [Cannot_process_archive path] (through [fail]) on the first
   archive that cannot be opened, traversed or summarized. *)
let print_to_buffer buffer fmt l =
  let summaries = ref [] in
  let iterator zip =
    object
      inherit ArchiveTraversal.default_archive_iterator zip
      method! class_definition cd =
        summaries := (summary cd) :: !summaries
    end in
  (* Summarizes every class of the archive at [path]. The archive is closed
     on the exception path too, so a failing traversal does not leave it
     open before the error is reported. *)
  let process_archive path =
    let arch =
      path
      |> Path.make_of_utf8
      |> ArchiveFile.make_of_path in
    (try
       ArchiveTraversal.iter (iterator arch) arch
     with e ->
       ArchiveFile.close_noerr arch;
       raise e);
    ArchiveFile.close_noerr arch in
  List.iter
    (fun elem ->
      (* any failure for this path is reported as [Cannot_process_archive] *)
      try
        process_archive elem
      with _ ->
        fail (Cannot_process_archive elem))
    l;
  let builder = Graph.make () in
  let utf8_of_bool = function
    | true -> @"true"
    | false -> @"false" in
  let utf8_of_int x =
    UTF8.of_string (string_of_int x) in
  (* internal vertices, i. e. classes appearing in archives
     NOTE(review): a class present in several archives yields several
     summaries, hence several [add_vertex] calls with the same id -- confirm
     how [Graph] handles duplicates. *)
  let internal_vertices =
    List.fold_left
      (fun acc elem ->
        let name = Name.external_utf8_for_class elem.full_name in
        (* [elem.package] is the package part of [elem.full_name] *)
        let package = Name.external_utf8_for_package elem.package in
        let properties = [
          @"external", @"false";
          @"package", package;
          @"interface", utf8_of_bool elem.is_interface;
          @"abstract", utf8_of_bool elem.is_abstract;
          @"fields", utf8_of_int elem.fields;
          @"methods", utf8_of_int elem.methods;
          @"code_size", utf8_of_int elem.code_size;
        ] in
        Graph.add_vertex
          ~id:name
          ~label:name
          ~cluster:package
          ~properties
          builder;
        ClassNameSet.add elem.full_name acc)
      ClassNameSet.empty
      !summaries in
  (* external vertices, i. e. classes NOT appearing in archives; each one
     is added at most once *)
  let external_vertices = ref ClassNameSet.empty in
  let add_if_not_present id =
    if not (ClassNameSet.mem id internal_vertices
          || ClassNameSet.mem id !external_vertices) then begin
      external_vertices := ClassNameSet.add id !external_vertices;
      let name = Name.external_utf8_for_class id in
      let package = fst (Name.split_class_name id) in
      let package = Name.external_utf8_for_package package in
      let properties = [
        @"external", @"true";
        @"package", package;
      ] in
      Graph.add_vertex
        ~id:name
        ~label:name
        ~cluster:package
        ~properties
        builder
    end in
  List.iter
    (fun elem ->
      ClassNameMap.iter
        (fun k _ -> add_if_not_present k)
        elem.references;
      (match elem.extends with
      | Some e -> add_if_not_present e
      | None -> ());
      List.iter
        add_if_not_present
        elem.implements)
    !summaries;
  (* edges for extends / implements; edge ids are "edge1", "edge2", ... *)
  let next = ref 0 in
  let next_id () =
    incr next;
    UTF8.of_string ("edge" ^ (string_of_int !next)) in
  List.iter
    (fun elem ->
      (match elem.extends with
      | Some parent ->
          Graph.add_edge
            ~id:(next_id ())
            ~directed:true
            ~label:@"extends"
            ~properties:[]
            ~vertices:[Name.external_utf8_for_class elem.full_name, @"";
                       Name.external_utf8_for_class parent, @""]
            builder
      | None -> ());
      List.iter
        (fun itf ->
          Graph.add_edge
            ~id:(next_id ())
            ~directed:true
            ~label:@"implements"
            ~properties:[]
            ~vertices:[Name.external_utf8_for_class elem.full_name, @"";
                       Name.external_utf8_for_class itf, @""]
            builder)
        elem.implements)
    !summaries;
  (* edges for references, weighted by the reference count *)
  List.iter
    (fun elem ->
      ClassNameMap.iter
        (fun k v ->
          let properties = [
            @"weight", utf8_of_int v
          ] in
          Graph.add_edge
            ~id:(next_id ())
            ~directed:true
            ~label:@"references"
            ~properties
            ~vertices:[Name.external_utf8_for_class elem.full_name, @"";
                       Name.external_utf8_for_class k, @""]
            builder)
        elem.references)
    !summaries;
  let graph = Graph.to_graph builder in
  (* declared types of the vertex and edge properties set above *)
  let vertice_types = [
    @"external", @"boolean";
    @"package", @"string";
    @"interface", @"boolean";
    @"abstract", @"boolean";
    @"fields", @"integer";
    @"methods", @"integer";
    @"code_size", @"integer";
  ] in
  let edge_types = [
    @"weight", @"integer";
  ] in
  let graph = Graph.({ graph with
                       vertice_types = map_of_list vertice_types;
                       edge_types = map_of_list edge_types }) in
  Graph.dump fmt graph buffer

(* [print_to_stream chan fmt l] renders the dependency graph of the
   archives listed in [l] in format [fmt] (see [print_to_buffer]), writes
   the rendered text to [chan] with [UTF8LineWriter.write_line], and
   flushes the writer. [chan] itself is not closed. *)
let print_to_stream chan fmt l =
  let rendered = UTF8Buffer.make () in
  print_to_buffer rendered fmt l;
  let out = UTF8LineWriter.make_of_stream chan in
  let text = UTF8Buffer.contents rendered in
  UTF8LineWriter.write_line out text;
  UTF8LineWriter.flush out

(* [print fmt l] is [print_to_stream] targeting the standard output. *)
let print output_format archive_paths =
  print_to_stream OutputStream.stdout output_format archive_paths
