(* Postgres: OCaml bindings for PostgreSQL
   Copyright (C) 2001  Alain Frisch         <Alain.Frisch@ens.fr>

   This library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License; see the file LGPL.
*)

(* [init] is implemented by the C stub "init_PQstub".  It is invoked exactly
   once, as a side effect of this module being initialised; what the stub
   sets up is not visible from this file. *)
external init: unit -> unit = "init_PQstub"
let () = init ()

(* Low-level view of a query result: an abstract handle plus the C stubs
   that inspect it.  These are raw primitives: no OCaml-side validation of
   tuple or field indices happens here (the [make_result] class adds it). *)
module Result =
struct
  (* Opaque result handle; values only come from the C stubs. *)
  type t

  (* Status of a result as reported by [status].
     NOTE(review): values of this type cross the C boundary, so the
     constructor order presumably has to match what the stubs expect
     (likely libpq's ExecStatusType) -- confirm before reordering. *)
  type status = 
    | Empty_query 
    | Command_ok
    | Tuples_ok
    | Copy_out
    | Copy_in
    | Bad_response
    | Nonfatal_error
    | Fatal_error

  (* Object identifiers are represented as plain OCaml integers. *)
  type oid = int

  (* [is_null r] tells whether the handle [r] is null; callers in this file
     use it to detect that no result was produced. *)
  external is_null: t -> bool = "res_isnull"
				  
  (* Status and error text of a result.  Each primitive below is bound to
     the C stub named in its declaration (presumably a thin wrapper around
     the libpq function of the same name -- confirm in the stub sources). *)
  external status: t -> status = "stub_PQresultStatus"
  external string_of_status: status -> string = "stub_PQresStatus" 
  external error: t -> string = "stub_PQresultErrorMessage"
  (* Dimensions of the result and whether its tuples are binary. *)
  external ntuples: t -> int = "stub_PQntuples"
  external nfields: t -> int = "stub_PQnfields"
  external binary_tuples: t -> bool = "stub_PQbinaryTuples"
  (* Per-field metadata; the [int] argument is the field index.
     [fnumber] maps a field name to its index; [make_result] treats a
     result of -1 as "no such field". *)
  external fname: t -> int -> string = "stub_PQfname"
  external fnumber: t -> string -> int ="stub_PQfnumber"
  external ftype: t -> int -> oid = "stub_PQftype"
  external fsize: t -> int -> int = "stub_PQfsize"
  external fmod: t -> int -> int = "stub_PQfmod"
  (* Per-cell accessors; arguments are the tuple index, then the field
     index (the order in which [make_result] passes them). *)
  external getvalue: t -> int -> int -> string = "stub_PQgetvalue"
  external getlength: t -> int -> int -> int = "stub_PQgetlength"
  external getisnull: t -> int -> int -> bool = "stub_PQgetisnull"
  (* Command-level information attached to the result. *)
  external cmd_status: t -> string = "stub_PQcmdStatus"
  external cmd_tuples: t -> string = "stub_PQcmdTuples"
  external oid_value: t -> oid = "stub_PQoidValue"
end

(* Failures reported by the object layer ([make_result] and [connection]).
   They are always raised wrapped in the [Error] exception below and can be
   rendered with [string_of_error]. *)
type error =
  (* (field, nfields): the rejected field index and the number of fields. *)
  | Field_out_of_range of int*int
  (* (tuple, ntuples): the rejected tuple index and the number of tuples. *)
  | Tuple_out_of_range of int*int
  (* Raised by [getvalue] when the result holds binary tuples. *)
  | Binary
  (* Carries the error message obtained from the connection. *)
  | ConnectionFailure of string
  (* (actual status, result error message, list of accepted statuses);
     raised by [connection#exec_expect]. *)
  | UnexpectedStatus of Result.status * string * (Result.status list)

(* The single exception used by this module to report an [error]. *)
exception Error of error

(* [string_of_error e] renders [e] as a human-readable message.
   For the two range errors the upper bound printed is the count minus one,
   i.e. the last valid index.  [ConnectionFailure] messages are returned
   unchanged. *)
let string_of_error err =
  match err with
  | Field_out_of_range (index, count) ->
      let last = count - 1 in
      Printf.sprintf "field number %i is out of range 0..%i" index last
  | Tuple_out_of_range (index, count) ->
      let last = count - 1 in
      Printf.sprintf "tuple number %i is out of range 0..%i" index last
  | Binary ->
      "this method doesn't accept binary tuples"
  | ConnectionFailure message ->
      message
  | UnexpectedStatus (actual, message, accepted) ->
      (* Accepted statuses are listed comma-separated, without spaces. *)
      let accepted_names =
        String.concat "," (List.map Result.string_of_status accepted) in
      let actual_name = Result.string_of_status actual in
      Printf.sprintf "Result status %s unexpected (expected status:%s); %s"
        actual_name accepted_names message

(* Interface of result objects, as implemented by [make_result]. *)
class type result =
object
  (* The underlying low-level handle. *)
  method internal: Result.t

  (* Status of the result and its associated error message. *)
  method status: Result.status
  method error: string

  (* Number of tuples (rows), number of fields (columns), and whether the
     tuples are binary. *)
  method ntuples: int
  method nfields: int
  method binary: bool

  (* Field metadata.  In [make_result], methods taking a field index raise
     [Error (Field_out_of_range _)] on a bad index, and [fnumber] raises
     [Not_found] when the name is unknown. *)
  method fname: int -> string
  method fnumber: string -> int
  method ftype: int -> Result.oid
  method fsize: int -> int
  method fmod: int -> int
  (* Cell accessors; arguments are tuple index then field index.  In
     [make_result] both indices are range-checked, and [getvalue] raises
     [Error Binary] on a binary result. *)
  method getvalue: int -> int -> string
  method getlength: int -> int -> int
  method getisnull: int -> int -> bool

  (* Command-level information. *)
  method cmd_status: string
  method cmd_tuples: string
  method oid_value: int

  (* Names of all fields, in field-index order. *)
  method get_fields_list: string list

  (* All values of one tuple, in field-index order. *)
  method get_tuple_list: int -> string list
  method get_tuple_array: int -> string array

  (* The whole result, one inner list/array per tuple. *)
  method get_list: string list list
  method get_array: string array array
end


(* [new make_result res] wraps the low-level handle [res] in an object of
   type [result].

   The number of fields, the number of tuples and the binary flag are read
   once, at construction.  Methods taking indices validate them and raise
   [Error (Field_out_of_range _)] or [Error (Tuple_out_of_range _)]; the
   field index is checked before the tuple index.

   NOTE(review): only [getvalue] rejects binary results; the bulk accessors
   ([get_tuple_list], [get_tuple_array], [get_list], [get_array]) call
   [Result.getvalue] without that check -- confirm whether this is intended. *)
class make_result res : result =
  let nfields = Result.nfields res in
  let ntuples = Result.ntuples res in
  let binary = Result.binary_tuples res in
  let check_field field =
    if not (0 <= field && field < nfields) then
      raise (Error (Field_out_of_range (field, nfields)))
  in
  let check_tuple tuple =
    if not (0 <= tuple && tuple < ntuples) then
      raise (Error (Tuple_out_of_range (tuple, ntuples)))
  in
  (* Values of tuple [tuple] as a list.  Cells are fetched from the last
     field down to the first, so consing yields field-index order. *)
  let row_as_list tuple =
    let rec collect field acc =
      if field < 0 then acc
      else collect (field - 1) (Result.getvalue res tuple field :: acc)
    in
    collect (nfields - 1) []
  in
  (* Values of tuple [tuple] as a fresh array, fetched in ascending order. *)
  let row_as_array tuple =
    Array.init nfields (fun field -> Result.getvalue res tuple field)
  in
object
  method internal = res

  method status = Result.status res
  method error = Result.error res
  method ntuples = ntuples
  method nfields = nfields
  method binary = binary

  method fname field =
    check_field field;
    Result.fname res field

  (* The stub signals an unknown name with -1, surfaced as [Not_found]. *)
  method fnumber s =
    let index = Result.fnumber res s in
    if index = -1 then raise Not_found else index

  method ftype field =
    check_field field;
    Result.ftype res field

  method fsize field =
    check_field field;
    Result.fsize res field

  method fmod field =
    check_field field;
    Result.fmod res field

  method getvalue tuple field =
    check_field field;
    check_tuple tuple;
    if binary then raise (Error Binary);
    Result.getvalue res tuple field

  method getlength tuple field =
    check_field field;
    check_tuple tuple;
    Result.getlength res tuple field

  method getisnull tuple field =
    check_field field;
    check_tuple tuple;
    Result.getisnull res tuple field

  method cmd_status = Result.cmd_status res
  method cmd_tuples = Result.cmd_tuples res
  method oid_value = Result.oid_value res

  (* Field names are computed on first request and then reused. *)
  val mutable cache_fields_list = None
  method get_fields_list =
    match cache_fields_list with
    | Some names -> names
    | None ->
        let rec collect field acc =
          if field < 0 then acc
          else collect (field - 1) (Result.fname res field :: acc)
        in
        let names = collect (nfields - 1) [] in
        cache_fields_list <- Some names;
        names

  method get_tuple_list tuple =
    check_tuple tuple;
    row_as_list tuple

  method get_tuple_array tuple =
    check_tuple tuple;
    row_as_array tuple

  (* The full list of rows is built on first request (last tuple first, so
     that consing restores tuple order) and then reused. *)
  val mutable cache_list = None
  method get_list =
    match cache_list with
    | Some rows -> rows
    | None ->
        let rec collect tuple acc =
          if tuple < 0 then acc
          else collect (tuple - 1) (row_as_list tuple :: acc)
        in
        let rows = collect (ntuples - 1) [] in
        cache_list <- Some rows;
        rows

  (* The matrix is built on first request and then reused: every call
     returns the same (mutable) array. *)
  val mutable cache_array = None
  method get_array =
    match cache_array with
    | Some matrix -> matrix
    | None ->
        let matrix = Array.init ntuples row_as_array in
        cache_array <- Some matrix;
        matrix

end

(* Low-level connection primitives.  Every function is bound to the C stub
   named in its declaration (presumably a thin wrapper around the libpq
   function of the same name -- confirm in the stub sources).  No null-handle
   or argument checking is done at this level; the [connection] class adds
   it. *)
module Connection =
struct
  (* Opaque connection handle. *)
  type t
  (* Connection status as returned by [status]. *)
  type status = Ok | Bad

  (* [connect conninfo] opens a connection described by [conninfo]. *)
  external connect: string -> t = "stub_PQconnectdb"
  (* Tells whether the handle is null. *)
  external is_null: t -> bool = "conn_isnull"
  external finish: t -> unit = "stub_PQfinish"
  external reset: t -> unit = "stub_PQreset"
  (* Accessors for the connection's parameters. *)
  external db: t -> string = "stub_PQdb"
  external user: t -> string = "stub_PQuser"
  external pass: t -> string = "stub_PQpass"
  external host: t -> string = "stub_PQhost"
  external port: t -> string = "stub_PQport"
  external tty: t -> string = "stub_PQtty"
  external options: t -> string = "stub_PQoptions"
  external status: t -> status = "stub_PQstatus"
  external error_message: t -> string = "stub_PQerrorMessage"
  external backend_pid: t -> int = "stub_PQbackendPID"

  (* Notifications and notice callback.
     NOTE(review): the pair returned by [notifies] is presumably
     (notification name, backend pid) -- confirm against the stub. *)
  external notifies: t -> (string * int) option = "stub_PQnotifies"
  external set_notice_processor: t -> (string -> unit) -> unit 
		       = "stub_PQsetNoticeProcessor"

  (* Non-blocking operation.  The integer results are raw status codes; the
     [connection] class treats a non-zero [set_nonblocking] result and any
     [consume_input] result other than 1 as failures. *)
  external set_nonblocking: t -> bool -> int = "stub_PQsetnonblocking"
  external is_nonblocking: t -> bool = "stub_PQisnonblocking"
  external consume_input: t -> int = "stub_PQconsumeInput"

  (* The [connection] class treats a non-zero [flush] result, a [socket]
     result of -1 and a [request_cancel] result of 0 as failures. *)
  external is_busy: t -> bool = "stub_PQisBusy"
  external flush: t -> int = "stub_PQflush"
  external socket: t -> int = "stub_PQsocket"
  external request_cancel: t -> int = "stub_PQrequestCancel"

  (* COPY support.  For the four-argument functions the arguments are the
     buffer, a start position in it and a length, as passed by the
     [connection] class. *)
  external getline: t -> string -> int -> int -> int = "stub_PQgetline"
  external getline_async: t -> string -> int -> int -> int = "stub_PQgetlineAsync"
  external putline: t -> string -> int = "stub_PQputline"
  external putnbytes: t -> string -> int -> int -> int = "stub_PQputnbytes"
  external endcopy: t -> int = "stub_PQendcopy"

  (* Query execution.  [exec] and [get_result] may yield a null
     [Result.t]; callers test it with [Result.is_null]. *)
  external exec: t -> string -> Result.t = "stub_PQexec"
  external send_query: t -> string -> int = "stub_PQsendQuery"
  external get_result: t -> Result.t = "stub_PQgetResult"
  external make_empty: t -> Result.status -> Result.t = "stub_PQmakeEmptyPGresult"
end

(* Low-level large-object primitives, each bound to the C stub named in its
   declaration (presumably wrapping the libpq lo_* function of similar
   name -- confirm in the stub sources).  Results are raw status codes;
   the [connection] class turns failures into exceptions. *)
module LargeObjects =
struct
  (* Descriptor of an opened large object. *)
  type t = int
  type conn = Connection.t
  type oid = Result.oid
  (* NOTE(review): no open mode is passed from OCaml; the mode is
     presumably fixed inside the stub -- confirm. *)
  external lo_open: conn -> oid -> t = "stub_lo_open"
  external close: conn -> t -> int = "stub_lo_close"
  (* [read] / [write]: arguments are the descriptor, a buffer, a start
     position in the buffer and a length, as passed by the [connection]
     class. *)
  external read: conn -> t -> string -> int -> int -> int = "stub_lo_read"
  external write: conn -> t -> string -> int -> int -> int = "stub_lo_write"
  (* NOTE(review): only a position is passed; the seek origin is presumably
     fixed inside the stub -- confirm. *)
  external seek: conn -> t -> int -> int = "stub_lo_lseek"
  external tell: conn -> t -> int = "stub_lo_tell"
  external create: conn -> oid = "stub_lo_creat"
  external unlink: conn -> oid -> oid = "stub_lo_unlink"
  (* [import] / [export] take a file name. *)
  external import: conn -> string -> oid = "stub_lo_import"
  external export: conn -> oid -> string -> int = "stub_lo_export"
end

(* [new connection conninfo] opens a connection described by [conninfo] and
   wraps it in an object.

   Construction raises [Error (ConnectionFailure msg)] (after releasing the
   handle with [Connection.finish]) when the status of the new connection
   is not [Connection.Ok].

   Unless stated otherwise, every method first checks that the handle is
   not null, and reports failures by raising [Error (ConnectionFailure msg)]
   where [msg] is the connection's current error message.  Methods taking a
   buffer, a position and a length raise [Invalid_argument] when the range
   does not lie inside the buffer. *)
class connection conninfo =
  let conn = Connection.connect conninfo in
  (* Raise a [ConnectionFailure] carrying the connection's current error
     message.  Never returns, so it can be used in a branch of any type. *)
  let signal_error () =
    raise (Error (ConnectionFailure (Connection.error_message conn))) in
  (* Fail if the underlying handle is null. *)
  let check_null () =
    if (Connection.is_null conn) then signal_error () in
  (* Reject a connection that did not come up properly.  The message is
     fetched before the handle is released. *)
  let () =
    if (Connection.status conn <> Connection.Ok) then
      (
        let s = Connection.error_message conn in
        Connection.finish conn;
        raise (Error (ConnectionFailure s))
      )
  in
object(self)
  (* The underlying low-level handle. *)
  method internal = conn

  (* Reset the connection only if its status is [Bad]; fail if the status
     is still not [Ok] afterwards. *)
  method may_reset =
    check_null ();
    if (Connection.status conn = Connection.Bad) then
      ( Connection.reset conn;
        if (Connection.status conn <> Connection.Ok) then
          (signal_error ()) )

  method close    = check_null (); Connection.finish conn
  method reset    = check_null (); Connection.reset conn

  (* Accessors *)

  method db       = check_null (); Connection.db conn
  method user     = check_null (); Connection.user conn
  method pass     = check_null (); Connection.pass conn
  method host     = check_null (); Connection.host conn
  method port     = check_null (); Connection.port conn
  method tty      = check_null (); Connection.tty conn
  method options  = check_null (); Connection.options conn
  method backend  = check_null (); Connection.backend_pid conn

  (* Notification *)
  method notification =
    check_null ();
    Connection.notifies conn

  (* Notice processor *)
  method set_notice_callback f =
    check_null ();
    Connection.set_notice_processor conn f

  (* Non blocking mode *)
  method set_nonblocking b =
    check_null ();
    if (Connection.set_nonblocking conn b <> 0) then signal_error ()

  method is_nonblocking =
    check_null ();
    Connection.is_nonblocking conn

  (* Any result other than 1 is treated as a failure. *)
  method consume_input =
    check_null ();
    if (Connection.consume_input conn <> 1) then signal_error ()

  method is_busy =
    check_null ();
    Connection.is_busy conn

  method flush =
    check_null ();
    if (Connection.flush conn <> 0) then signal_error ()

  (* Returns the socket number; -1 from the stub is reported as a failure. *)
  method socket =
    check_null ();
    let sock = Connection.socket conn in
    if (sock = -1) then signal_error () else sock

  method request_cancel =
    check_null ();
    if (Connection.request_cancel conn = 0) then signal_error ()

  (* Copy operations *)

  (* [getline buf pos len] reads into [buf] at [pos], at most [len]
     characters, and returns the raw status code of the stub. *)
  method getline buf pos len =
    check_null ();
    if (len<0) || (pos<0) || (pos+len>String.length buf) then
      invalid_arg "Postgres.connection#getline";
    Connection.getline conn buf pos len

  (* Same contract as [getline], but goes through the asynchronous
     primitive [Connection.getline_async] rather than the blocking one. *)
  method getline_async buf pos len =
    check_null ();
    if (len<0) || (pos<0) || (pos+len>String.length buf) then
      invalid_arg "Postgres.connection#getline_async";
    Connection.getline_async conn buf pos len

  method putline buf =
    check_null ();
    if (Connection.putline conn buf <>0) then signal_error ()

  method putnbytes buf pos len =
    check_null ();
    if (len<0) || (pos<0) || (pos+len>String.length buf) then
      invalid_arg "Postgres.connection#putnbytes";
    if (Connection.putnbytes conn buf pos len <>0) then signal_error ()

  method endcopy =
    check_null ();
    if (Connection.endcopy conn <>0) then signal_error ()

  (* [copy_out f] reads copied-out data line by line and calls [f] on each
     complete line.  Data is read in chunks into a 512-character scratch
     string; a chunk extends up to its first NUL character.
     - status 0: the line is complete; hand it to [f] and start a new line;
     - status 1: the line continues; read the next chunk and append it;
     - any other status: end the copy, then pass what was accumulated to [f].
     A line whose first chunk starts with "\." followed by NUL terminates
     the copy and is not passed to [f]. *)
  method copy_out f =
    check_null ();
    let buf = Buffer.create 1024 in
    let s = String.create 512 in

    let rec aux r =
      let zero = String.index s '\000' in
      Buffer.add_substring buf s 0 zero;
      (match r with
         | 0 -> f (Buffer.contents buf); Buffer.clear buf; ligne ()
         | 1 -> aux (Connection.getline conn s 0 (String.length s))
         | _ ->
             self#endcopy;
             f (Buffer.contents buf)
      )
    and ligne () =
      let r = Connection.getline conn s 0 (String.length s) in
      if (String.sub s 0 3 <> "\\.\000")
      then aux r
      else self#endcopy
    in
    ligne ()

  (* Write every copied-out line to [chan], each followed by a newline. *)
  method copy_out_channel chan =
    self#copy_out (fun s -> output_string chan s; output_string chan "\n")

  (* Send every line of [chan], then the "\." terminator, then end the
     copy.  The loop is left through [End_of_file]. *)
  method copy_in_channel chan =
    try while true do self#putline (input_line chan); self#putline "\n" done;
    with End_of_file ->
      self#putline "\\.\n";
      self#endcopy

  (* Request *)
  method empty_result status =
    check_null ();
    new make_result (Connection.make_empty conn status)

  (* Run [query] and wrap its result; a null result is a failure. *)
  method exec query =
    check_null ();
    let res = Connection.exec conn query in
    if (Result.is_null res) then signal_error () else new make_result res

  method send query =
    check_null ();
    if (Connection.send_query conn query <> 1) then signal_error ()

  (* [None] when the stub returns a null result. *)
  method get_result =
    check_null ();
    let res = Connection.get_result conn in
    if (Result.is_null res) then None else Some (new make_result res)

  (* Like [exec], but raises [Error (UnexpectedStatus _)] unless the status
     of the result belongs to the list [status]. *)
  method exec_expect query status =
    let res = self#exec query in
    if not (List.mem res#status status) then
      raise (Error (UnexpectedStatus (res#status,res#error,status)));
    res

  (* Large objects *)

  method lo_open oid =
    check_null ();
    let lo = LargeObjects.lo_open conn oid in
    if (lo = -1) then signal_error ();
    lo

  method lo_close oid =
    check_null ();
    if (LargeObjects.close conn oid = -1) then signal_error ()

  (* Returns the count reported by the stub; -1 is a failure. *)
  method lo_read lo buf pos len  =
    check_null ();
    if (len<0) || (pos<0) || (pos+len>String.length buf) then
      invalid_arg "Postgres.connection#lo_read";
    let read = LargeObjects.read conn lo buf pos len in
    if (read = -1) then signal_error ();
    read

  (* A write shorter than [len] is reported as a failure. *)
  method lo_write lo buf pos len  =
    check_null ();
    if (len<0) || (pos<0) || (pos+len>String.length buf) then
      invalid_arg "Postgres.connection#lo_write";
    let w = LargeObjects.write conn lo buf pos len in
    if (w < len) then signal_error ()

  method lo_write_string lo buf =
    self#lo_write lo buf 0 (String.length buf)

  method lo_seek lo pos =
    check_null ();
    if (LargeObjects.seek conn lo pos < 0) then signal_error ()

  method lo_create =
    check_null ();
    let lo =  LargeObjects.create conn in
    if (lo <= 0) then signal_error ();
    lo

  method lo_tell lo  =
    check_null ();
    let pos = LargeObjects.tell conn lo in
    if (pos = -1) then signal_error ();
    pos

  method lo_unlink oid =
    check_null ();
    let oid = LargeObjects.unlink conn oid in
    if (oid = -1) then signal_error ()

  method lo_import filename =
    check_null ();
    let oid = LargeObjects.import conn filename in
    if (oid = 0) then signal_error ();
    oid

  method lo_export oid filename =
    check_null ();
    if (LargeObjects.export conn oid filename <= 0) then signal_error ()


end


(* [conninfo ?host ?hostaddr ?port ?dbname ?user ?password ?options ?tty
   ?requiressl ()] builds a connection string of the form
   [key='value' key='value' ...] suitable for [new connection].

   Only the parameters actually supplied appear, always in the order of the
   signature; each entry is followed by one space.  With no parameter the
   result is the empty string.

   Values are always single-quoted.  Inside a value both the single quote
   and the backslash are preceded by a backslash: libpq treats the
   backslash as the escape character in quoted values, so leaving it
   unescaped would let a value ending in [\] swallow the closing quote. *)
let conninfo ?host ?hostaddr ?port ?dbname ?user ?password ?options ?tty ?requiressl () =
  let b = Buffer.create 512 in
  let add_escaped c =
    if c = '\'' || c = '\\' then Buffer.add_char b '\\';
    Buffer.add_char b c
  in
  let field name = function
    | None -> ()
    | Some x ->
        Buffer.add_string b name;
        Buffer.add_string b "='";
        String.iter add_escaped x;
        Buffer.add_string b "' "
  in
  field "host" host;
  field "hostaddr" hostaddr;
  field "port" port;
  field "dbname" dbname;
  field "user" user;
  field "password" password;
  field "options" options;
  field "tty" tty;
  field "requiressl" requiressl;
  Buffer.contents b


(* Raw escaping primitive, bound to the C stub "stub_PQescapeString".
   As used by [escape_substring]: [dest] is the output buffer and is
   followed by an integer passed as 0 (presumably a start offset in [dest]
   -- confirm in the stub); [src] is the input and is followed by a start
   position and a length; the result is used as the number of characters
   written to [dest].  No bounds checking is done at this level. *)
external escapeString: dest:string -> int -> src:string -> int -> int -> int = "stub_PQescapeString"

(* [escape_substring s pos len] returns the escaped form of the [len]
   characters of [s] starting at [pos], as produced by [escapeString].
   The scratch buffer holds [2 * len + 1] characters.
   Raises [Invalid_argument "escape_substring"] when the range does not
   lie inside [s]. *)
let escape_substring s pos len =
  let in_bounds = pos >= 0 && len >= 0 && pos + len <= String.length s in
  if not in_bounds then invalid_arg "escape_substring";
  let dest = String.create (2 * len + 1) in
  let written = escapeString ~dest 0 ~src:s pos len in
  String.sub dest 0 written

(* [escape_string s] escapes the whole of [s]; see [escape_substring]. *)
let escape_string s =
  let len = String.length s in
  escape_substring s 0 len
