diff --git a/dune b/dune index 89ee5a3..0f7b651 100644 --- a/dune +++ b/dune @@ -68,7 +68,7 @@ (executable (name test) - (libraries lwt lwt.unix devkit extlib extunix libevent ocamlnet_lite ounit2 unix yojson) + (libraries lwt lwt.unix devkit extlib extunix libevent ocamlnet_lite ounit2 threads unix yojson) (modules test test_httpev)) ; uses 8GB+ RAM, so do not run as part of test suite diff --git a/files.ml b/files.ml index 659b145..3f26994 100644 --- a/files.ml +++ b/files.ml @@ -67,7 +67,7 @@ let mkdir_p ?(perm=0o755) path = in aux path -let save_as name ?(mode=0o644) f = +let save_as_regular name ?(mode=0o644) f = (* not using make_temp_file cause same dir is needed for atomic rename *) let temp = Printf.sprintf "%s.save.%d.tmp" name (U.gettid ()) in bracket (Unix.openfile temp [Unix.O_WRONLY;Unix.O_CREAT] mode) Unix.close begin fun fd -> @@ -81,3 +81,9 @@ let save_as name ?(mode=0o644) f = with exn -> Exn.suppress Unix.unlink temp; raise exn end + +let rec save_as name ?mode f = + match (Unix.lstat name).st_kind with + | Unix.S_LNK -> save_as (Unix.realpath name) ?mode f + | Unix.S_REG | (exception Unix.Unix_error (Unix.ENOENT, _, _)) -> save_as_regular name ?mode f + | _ -> Out_channel.with_open_gen [ Open_wronly ] 0 name f diff --git a/files.mli b/files.mli index 319ff47..b5ab3e6 100644 --- a/files.mli +++ b/files.mli @@ -22,11 +22,16 @@ val open_out_append_text : string -> out_channel val mkdir_p : ?perm:Unix.file_perm -> string -> unit (** [save_as filename ?mode f] is similar to - [Control.with_open_file_bin] except that writing is done to a - temporary file that will be renamed to [filename] after [f] has - succesfully terminated. Therefore this guarantee that either - [filename] will not be modified or will contain whatever [f] was - writing to it as a side-effect. + [Control.with_open_file_bin] for regular files, except that + writing is done to a temporary file that will be renamed to + [filename] after [f] has succesfully terminated. Therefore this + guarantee that either [filename] will not be modified or will + contain whatever [f] was writing to it as a side-effect. + + There is no such special treatment for special files (Unix.stat + kind not S_REG, e.g. devices, pipes, etc), instead they are + written to directly. Symlinks are followed (not overwritten in + place). Throws {!Unix.Unix_error} on broken symlinks. FIXME windows *) val save_as : string -> ?mode:Unix.file_perm -> (out_channel -> unit) -> unit diff --git a/test.ml b/test.ml index c9af39f..15fccb2 100644 --- a/test.ml +++ b/test.ml @@ -595,6 +595,80 @@ let () = assert_equal !accumulator 4; () +let with_temp_path name f = + let path = Filename.concat (Filename.get_temp_dir_name ()) name in + (try Sys.remove path with _ -> ()); + Fun.protect ~finally:(fun () -> try Sys.remove path with _ -> ()) (fun () -> f path) + +let () = test "Files.save_as writes to new regular file" @@ fun () -> + with_temp_path "test_save_as_new.txt" @@ fun path -> + Files.save_as path (fun oc -> output_string oc "hello\n"); + let content = In_channel.with_open_text path In_channel.input_all in + assert_equal ~printer:id "hello\n" content + +let () = test "Files.save_as overwrites existing regular file" @@ fun () -> + with_temp_path "test_save_as_overwrite.txt" @@ fun path -> + Out_channel.with_open_text path (fun oc -> output_string oc "old\n"); + Files.save_as path (fun oc -> output_string oc "new\n"); + let content = In_channel.with_open_text path In_channel.input_all in + assert_equal ~printer:id "new\n" content + +let () = test "Files.save_as no temp file left on success" @@ fun () -> + with_temp_path "test_save_as_no_temp.txt" @@ fun path -> + Files.save_as path (fun oc -> output_string oc "data\n"); + let temp = Printf.sprintf "%s.save.%d.tmp" path (U.gettid ()) in + assert_bool "temp file should not exist" (not (Sys.file_exists temp)) + +let () = test "Files.save_as no temp file left on failure" @@ fun () -> + with_temp_path "test_save_as_fail.txt" @@ fun path -> + (try Files.save_as path (fun _oc -> failwith "boom") with Failure _ -> ()); + let temp = Printf.sprintf "%s.save.%d.tmp" path (U.gettid ()) in + assert_bool "temp file should not exist" (not (Sys.file_exists temp)); + assert_bool "target file should not exist" (not (Sys.file_exists path)) + +let () = test "Files.save_as writes to /dev/null without error" @@ fun () -> + Files.save_as "/dev/null" (fun oc -> output_string oc "discarded\n"); + let temp = Printf.sprintf "/dev/null.save.%d.tmp" (U.gettid ()) in + assert_bool "temp file should not exist" (not (Sys.file_exists temp)); + assert_bool "/dev/null should be a char device" ((Unix.stat "/dev/null").st_kind = Unix.S_CHR) + +let () = test "Files.save_as writes to named pipe (FIFO)" @@ fun () -> + with_temp_path "test_save_as_fifo" @@ fun fifo_path -> + Unix.mkfifo fifo_path 0o644; + let received = ref "" in + let reader = Thread.create (fun () -> received := In_channel.with_open_text fifo_path In_channel.input_all) () in + Files.save_as fifo_path (fun oc -> output_string oc "fifo data\n"); + Thread.join reader; + assert_equal ~printer:id "fifo data\n" !received + +let () = test "Files.save_as no temp file created for FIFO" @@ fun () -> + with_temp_path "test_save_as_fifo2" @@ fun fifo_path -> + Unix.mkfifo fifo_path 0o644; + let reader = Thread.create (fun () -> ignore (In_channel.with_open_text fifo_path In_channel.input_all)) () in + Files.save_as fifo_path (fun oc -> output_string oc "data\n"); + Thread.join reader; + let temp = Printf.sprintf "%s.save.%d.tmp" fifo_path (U.gettid ()) in + assert_bool "temp file should not exist" (not (Sys.file_exists temp)) + +let () = test "Files.save_as writes through symlink without clobbering it" @@ fun () -> + with_temp_path "test_save_as_symlink_target.txt" @@ fun target -> + with_temp_path "test_save_as_symlink_link.txt" @@ fun link -> + Out_channel.with_open_text target (fun oc -> output_string oc "old\n"); + Unix.symlink target link; + Files.save_as link (fun oc -> output_string oc "new\n"); + assert_bool "symlink should still be a symlink" ((Unix.lstat link).st_kind = Unix.S_LNK); + let content = In_channel.with_open_text target In_channel.input_all in + assert_equal ~printer:id "new\n" content + +let () = test "Files.save_as fails on broken symlink" @@ fun () -> + with_temp_path "test_save_as_symlink_target.txt" @@ fun target -> + with_temp_path "test_save_as_symlink_link.txt" @@ fun link -> + Unix.symlink target link; + try + Files.save_as link (fun oc -> output_string oc "new\n"); + assert_failure "should fail on broken symlink" + with Unix.Unix_error(Unix.ENOENT, "realpath", _) -> () + let () = test "Logfmt" begin fun () -> let eq name expected got = assert_equal ~msg:name ~printer:(fun s -> sprintf "%S" s) expected got