/*
 * Copyright (c) 2024, Antonio Gabriel Muñoz Conejo <antoniogmc at gmail dot com>
 * Distributed under the terms of the MIT License
 */
package com.github.tonivade.vavr.effect;

import static io.vavr.Function1.identity;

import java.time.Duration;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;

import io.vavr.CheckedConsumer;
import io.vavr.CheckedFunction0;
import io.vavr.CheckedRunnable;
import io.vavr.Function1;
import io.vavr.Function2;
import io.vavr.PartialFunction;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.collection.HashMap;
import io.vavr.collection.List;
import io.vavr.collection.Seq;
import io.vavr.concurrent.Future;
import io.vavr.concurrent.Promise;
import io.vavr.control.Either;
import io.vavr.control.Option;
import io.vavr.control.Try;

/**
 * Description of a computation producing a value of type {@code T}.
 *
 * <p>An {@code IO} is a tree of the nested record types declared at the bottom of this
 * interface ({@link Pure}, {@link Failure}, {@link FlatMapped}, {@link Delay}, {@link Async},
 * {@link Suspend}, {@link Recover}). Building an {@code IO} only allocates nodes; the tree is
 * interpreted by the private {@code runAsync} loop when one of the {@code run*} methods is
 * called.
 *
 * @param <T> type of the value produced
 */
public sealed interface IO<T> {

  /** Shared instance holding {@code Unit.unit()}. */
  IO<Unit> UNIT = pure(Unit.unit());

  /**
   * Interprets this program using {@code IOConnection.UNCANCELLABLE}.
   *
   * @return the future of the promise completed by the interpreter
   */
  default Future<T> runAsync() {
    return runAsync(this, IOConnection.UNCANCELLABLE).future();
  }

  /**
   * Prepends {@link #forked(Executor)} to this program and runs the result with
   * {@link #runAsync()}.
   *
   * @param executor executor that receives the initial hand-off task
   * @return the future with the result
   */
  default Future<T> runAsync(Executor executor) {
    return forked(executor).andThen(this).runAsync();
  }

  /**
   * Runs the program and extracts the value with {@code Try.get()}.
   *
   * @return the computed value
   */
  default T unsafeRunSync() {
    return safeRunSync().get();
  }

  /**
   * Runs the program and converts the resulting future with {@code toTry()}.
   * NOTE(review): presumably blocks until the future completes; confirm against Vavr docs.
   *
   * @return the outcome as a {@code Try}
   */
  default Try<T> safeRunSync() {
    return runAsync().toTry();
  }

  /**
   * Same as {@link #safeRunAsync(Executor, Consumer)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param callback receives the outcome
   */
  default void safeRunAsync(Consumer<? super Try<? extends T>> callback) {
    safeRunAsync(Future.DEFAULT_EXECUTOR, callback);
  }

  /**
   * Runs the program via {@link #runAsync(Executor)} and registers {@code callback} with
   * {@code onComplete}.
   *
   * @param executor executor passed to {@link #runAsync(Executor)}
   * @param callback receives the outcome
   */
  default void safeRunAsync(Executor executor, Consumer<? super Try<? extends T>> callback) {
    runAsync(executor).onComplete(callback);
  }

  /**
   * Transforms the produced value.
   *
   * @param map function applied to the value
   * @param <R> result type
   * @return a program producing the mapped value wrapped with {@link #pure(Object)}
   */
  default <R> IO<R> map(Function1<? super T, ? extends R> map) {
    return flatMap(map.andThen(IO::pure));
  }

  /**
   * Sequences this program with the one returned by {@code map}.
   *
   * @param map function producing the next program from this program's value
   * @param <R> result type
   * @return a {@link FlatMapped} node
   */
  default <R> IO<R> flatMap(Function1<? super T, ? extends IO<? extends R>> map) {
    return new FlatMapped<>(this, map);
  }

  /**
   * Sequences this program with {@code after}, discarding this program's value.
   *
   * @param after program to run next
   * @param <R> result type
   * @return the sequenced program
   */
  default <R> IO<R> andThen(IO<? extends R> after) {
    return flatMap(ignore -> after);
  }

  /**
   * Combines this program and {@code apply} with
   * {@link #parMap2(Executor, IO, IO, Function2)} on {@code Future.DEFAULT_EXECUTOR},
   * applying the produced function to the produced value.
   *
   * @param apply program producing the function
   * @param <R> result type
   * @return the combined program
   */
  default <R> IO<R> ap(IO<Function1<? super T, ? extends R>> apply) {
    return parMap2(Future.DEFAULT_EXECUTOR, this, apply, (v, a) -> a.apply(v));
  }

  /**
   * Captures the outcome as a value: results are mapped with {@code Try::success} and
   * errors are recovered with {@code Try::failure}.
   *
   * @return a program producing the {@code Try}
   */
  default IO<Try<T>> attempt() {
    return map(Try::success).recover(Try::failure);
  }

  /**
   * Like {@link #attempt()} but converted with {@code Try::toEither}.
   *
   * @return a program producing the {@code Either}
   */
  default IO<Either<Throwable, T>> either() {
    return attempt().map(Try::toEither);
  }

  /**
   * Like {@link #either()} with both sides transformed through {@code bimap}.
   *
   * @param mapError function applied to the error side
   * @param mapper function applied to the value side
   * @param <L> mapped error type
   * @param <R> mapped value type
   * @return a program producing the mapped {@code Either}
   */
  default <L, R> IO<Either<L, R>> either(Function1<? super Throwable, ? extends L> mapError,
      Function1<? super T, ? extends R> mapper) {
    return either().map(either -> either.bimap(mapError, mapper));
  }

  /**
   * Folds the outcome of {@link #attempt()} into a single value.
   *
   * @param mapError function applied on failure
   * @param mapper function applied on success
   * @param <R> result type
   * @return a program producing the folded value
   */
  default <R> IO<R> redeem(Function1<? super Throwable, ? extends R> mapError,
      Function1<? super T, ? extends R> mapper) {
    return attempt().map(result -> result.fold(mapError, mapper));
  }

  /**
   * Folds the outcome of {@link #attempt()} into a follow-up program.
   *
   * @param mapError function producing the program to run on failure
   * @param mapper function producing the program to run on success
   * @param <R> result type
   * @return the sequenced program
   */
  default <R> IO<R> redeemWith(Function1<? super Throwable, IO<? extends R>> mapError,
      Function1<? super T, IO<? extends R>> mapper) {
    return attempt().flatMap(result -> result.fold(mapError, mapper));
  }

  /**
   * Installs a handler, defined for every error, that replaces the error with a value.
   *
   * @param mapError function producing the replacement value
   * @return a program with the handler installed
   */
  default IO<T> recover(Function1<? super Throwable, ? extends T> mapError) {
    return recoverWith(partialFunction(x -> true, mapError.andThen(IO::pure)));
  }

  /**
   * Installs a handler that is defined only when the error's runtime class is exactly
   * {@code type} (the check is {@code error.getClass().equals(type)}, so subclasses of
   * {@code type} are not matched).
   *
   * @param type exact error class to match
   * @param function function producing the replacement value
   * @param <X> error type
   * @return a program with the handler installed
   */
  @SuppressWarnings("unchecked")
  default <X extends Throwable> IO<T> recover(Class<X> type,
      Function1<? super X, ? extends T> function) {
    return recoverWith(partialFunction(
        error -> error.getClass().equals(type),
        t -> function.andThen(IO::pure).apply((X) t)));
  }

  /**
   * Installs an error handler given as a partial function producing a program.
   *
   * @param mapper handler; consulted through {@code isDefinedAt} by the interpreter
   * @return a {@link Recover} node
   */
  default IO<T> recoverWith(PartialFunction<? super Throwable, IO<? extends T>> mapper) {
    return new Recover<>(this, mapper);
  }

  /**
   * Pairs the value with the elapsed time, computed from {@code System.nanoTime()} read
   * before this program and again when its value is mapped.
   *
   * @return a program producing {@code (elapsed, value)}
   */
  default IO<Tuple2<Duration, T>> timed() {
    return task(System::nanoTime).flatMap(
        start -> map(result -> Tuple.of(Duration.ofNanos(System.nanoTime() - start), result)));
  }

  /**
   * Starts this program on a fresh cancellable connection and immediately yields a
   * {@code Fiber} built from a join program (reading the run's future) and a cancel program
   * (cancelling the connection).
   *
   * @return a program producing the fiber
   */
  default IO<Fiber<T>> fork() {
    return async(callback -> {
      IOConnection connection = IOConnection.cancellable();
      Promise<T> promise = runAsync(this, connection);
      IO<T> join = fromFuture(promise.future());
      IO<Unit> cancel = exec(connection::cancel);
      callback.accept(Try.success(new Fiber<>(join, cancel)));
    });
  }

  /**
   * Same as {@link #timeout(Executor, Duration)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param duration maximum time to wait
   * @return the bounded program
   */
  default IO<T> timeout(Duration duration) {
    return timeout(Future.DEFAULT_EXECUTOR, duration);
  }

  /**
   * Races this program against {@code sleep(duration)}. If this program finishes first the
   * sleep fiber is cancelled and the value returned; otherwise this program's fiber is
   * cancelled and the result fails with {@link TimeoutException}.
   *
   * <p>Note that the sleep is created with the single-argument {@link #sleep(Duration)}
   * overload, so it is scheduled through {@code Future.DEFAULT_EXECUTOR} rather than
   * {@code executor}.
   *
   * @param executor executor handed to {@link #racePair(Executor, IO, IO)}
   * @param duration maximum time to wait
   * @return the bounded program
   */
  default IO<T> timeout(Executor executor, Duration duration) {
    return racePair(executor, this, sleep(duration)).flatMap(either -> either.fold(
        ta -> ta._2().cancel().map(x -> ta._1()),
        tb -> tb._1().cancel().flatMap(x -> raiseError(new TimeoutException()))));
  }

  /**
   * Shorthand for {@code repeat(1)}.
   *
   * @return the repeating program
   */
  default IO<T> repeat() {
    return repeat(1);
  }

  /**
   * Repeats according to {@code Schedule.recurs(times)}, keeping the produced value.
   *
   * @param times argument passed to {@code Schedule.recurs}
   * @return the repeating program
   */
  default IO<T> repeat(int times) {
    return repeat(Schedule.<T>recurs(times).zipRight(Schedule.identity()));
  }

  /**
   * Shorthand for {@code repeat(delay, 1)}.
   *
   * @param delay argument passed to {@code Schedule.recursSpaced}
   * @return the repeating program
   */
  default IO<T> repeat(Duration delay) {
    return repeat(delay, 1);
  }

  /**
   * Repeats according to {@code Schedule.recursSpaced(delay, times)}, keeping the produced
   * value.
   *
   * @param delay argument passed to {@code Schedule.recursSpaced}
   * @param times argument passed to {@code Schedule.recursSpaced}
   * @return the repeating program
   */
  default IO<T> repeat(Duration delay, int times) {
    return repeat(Schedule.<T>recursSpaced(delay, times).zipRight(Schedule.identity()));
  }

  /**
   * Repeats according to {@code schedule}; errors are re-raised unchanged.
   *
   * @param schedule schedule driving the repetition
   * @param <B> schedule output type
   * @return the repeating program
   */
  default <B> IO<B> repeat(Schedule<T, B> schedule) {
    return repeatOrElse(schedule, (e, b) -> raiseError(e));
  }

  /**
   * Repeats according to {@code schedule}, delegating to {@code orElse} as done by
   * {@link #repeatOrElseEither(Schedule, Function2)} and merging both sides.
   *
   * @param schedule schedule driving the repetition
   * @param orElse fallback receiving the error and an optional schedule output
   * @param <B> schedule output type
   * @return the repeating program
   */
  default <B> IO<B> repeatOrElse(
      Schedule<T, B> schedule,
      Function2<Throwable, Option<B>, IO<B>> orElse) {
    return repeatOrElseEither(schedule, orElse).map(IO::merge);
  }

  /**
   * Builds a {@code Repeat} from this program, {@code schedule} and {@code orElse} and
   * returns its {@code run()} program.
   *
   * @param schedule schedule driving the repetition
   * @param orElse fallback receiving the error and an optional schedule output
   * @param <B> schedule output type
   * @param <C> fallback output type
   * @return the repeating program
   */
  default <B, C> IO<Either<C, B>> repeatOrElseEither(
      Schedule<T, B> schedule,
      Function2<Throwable, Option<B>, IO<C>> orElse) {
    return new Repeat<>(this, schedule, orElse).run();
  }

  /**
   * Shorthand for {@code retry(1)}.
   *
   * @return the retrying program
   */
  default IO<T> retry() {
    return retry(1);
  }

  /**
   * Retries according to {@code Schedule.recurs(maxRetries)}.
   *
   * @param maxRetries argument passed to {@code Schedule.recurs}
   * @return the retrying program
   */
  default IO<T> retry(int maxRetries) {
    return retry(Schedule.recurs(maxRetries));
  }

  /**
   * Shorthand for {@code retry(delay, 1)}.
   *
   * @param delay argument passed to {@code Schedule.recursSpaced}
   * @return the retrying program
   */
  default IO<T> retry(Duration delay) {
    return retry(delay, 1);
  }

  /**
   * Retries according to {@code Schedule.recursSpaced(delay, maxRetries)}.
   *
   * @param delay argument passed to {@code Schedule.recursSpaced}
   * @param maxRetries argument passed to {@code Schedule.recursSpaced}
   * @return the retrying program
   */
  default IO<T> retry(Duration delay, int maxRetries) {
    return retry(Schedule.<Throwable>recursSpaced(delay, maxRetries));
  }

  /**
   * Retries according to {@code schedule}; the fallback re-raises the error unchanged.
   *
   * @param schedule schedule fed with errors
   * @param <B> schedule output type
   * @return the retrying program
   */
  default <B> IO<T> retry(Schedule<Throwable, B> schedule) {
    return retryOrElse(schedule, (e, b) -> raiseError(e));
  }

  /**
   * Retries according to {@code schedule}, delegating to {@code orElse} as done by
   * {@link #retryOrElseEither(Schedule, Function2)} and merging both sides.
   *
   * @param schedule schedule fed with errors
   * @param orElse fallback receiving the error and the schedule output
   * @param <B> schedule output type
   * @return the retrying program
   */
  default <B> IO<T> retryOrElse(
      Schedule<Throwable, B> schedule,
      Function2<Throwable, B, IO<T>> orElse) {
    return retryOrElseEither(schedule, orElse).map(IO::merge);
  }

  /**
   * Builds a {@code Retry} from this program, {@code schedule} and {@code orElse} and
   * returns its {@code run()} program.
   *
   * @param schedule schedule fed with errors
   * @param orElse fallback receiving the error and the schedule output
   * @param <B> fallback output type
   * @param <C> schedule output type
   * @return the retrying program
   */
  default <B, C> IO<Either<B, T>> retryOrElseEither(
      Schedule<Throwable, C> schedule,
      Function2<Throwable, C, IO<B>> orElse) {
    return new Retry<>(this, schedule, orElse).run();
  }

  /**
   * Unchecked widening cast from {@code IO<? extends T>} to {@code IO<T>}.
   *
   * @param value program to cast
   * @param <T> target value type
   * @return the same instance
   */
  @SuppressWarnings("unchecked")
  static <T> IO<T> narrowK(IO<? extends T> value) {
    return (IO<T>) value;
  }

  /**
   * Wraps an already computed value.
   *
   * @param value the value; {@link Pure} rejects {@code null}
   * @param <T> value type
   * @return a {@link Pure} node
   */
  static <T> IO<T> pure(T value) {
    return new Pure<>(value);
  }

  /**
   * Same as {@link #race(Executor, IO, IO)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param fa first contender
   * @param fb second contender
   * @param <A> first value type
   * @param <B> second value type
   * @return a program producing the winner's value
   */
  static <A, B> IO<Either<A, B>> race(IO<? extends A> fa, IO<? extends B> fb) {
    return race(Future.DEFAULT_EXECUTOR, fa, fb);
  }

  /**
   * Runs both programs through {@link #racePair(Executor, IO, IO)}, cancels the fiber of the
   * other contender and yields the winner's value on the matching side.
   *
   * @param executor executor both contenders are forked on
   * @param fa first contender (left)
   * @param fb second contender (right)
   * @param <A> first value type
   * @param <B> second value type
   * @return a program producing the winner's value
   */
  static <A, B> IO<Either<A, B>> race(Executor executor, IO<? extends A> fa, IO<? extends B> fb) {
    return racePair(executor, fa, fb).flatMap(either -> either.fold(
        ta -> ta._2().cancel().map(x -> Either.left(ta._1())),
        tb -> tb._1().cancel().map(x -> Either.right(tb._2()))));
  }

  /**
   * Starts both programs, each forked on {@code executor} with its own cancellable
   * connection, and reports completions to the callback: a completion of {@code fa} is
   * reported as a left of its value plus a fiber for {@code fb}, a completion of {@code fb}
   * as a right of a fiber for {@code fa} plus its value. Each fiber joins on the other
   * contender's future and cancels the other contender's connection.
   *
   * <p>Cancelling the returned program cancels both connections.
   *
   * @param executor executor both contenders are forked on
   * @param fa first contender
   * @param fb second contender
   * @param <A> first value type
   * @param <B> second value type
   * @return a program producing the winner's value together with the loser's fiber
   */
  static <A, B> IO<Either<Tuple2<A, Fiber<B>>, Tuple2<Fiber<A>, B>>> racePair(
      Executor executor, IO<? extends A> fa, IO<? extends B> fb) {
    return cancellable(callback -> {
      IOConnection connection1 = IOConnection.cancellable();
      IOConnection connection2 = IOConnection.cancellable();

      Promise<A> promiseA = runAsync(forked(executor).andThen(fa), connection1);
      Promise<B> promiseB = runAsync(forked(executor).andThen(fb), connection2);

      promiseA.future().onComplete(result -> callback.accept(
          result.map(a -> Either.left(
              Tuple.of(a, new Fiber<>(fromFuture(promiseB.future()), exec(connection2::cancel)))))));
      // fa runs on connection1, so the fiber handed out for fa must cancel connection1;
      // cancelling connection2 here would leave the losing fa running after race/timeout.
      promiseB.future().onComplete(result -> callback.accept(
          result.map(b -> Either.right(
              Tuple.of(new Fiber<>(fromFuture(promiseA.future()), exec(connection1::cancel)), b)))));

      return exec(() -> {
        try {
          connection1.cancel();
        } finally {
          // still cancel the second contender if cancelling the first one throws
          connection2.cancel();
        }
      });
    });
  }

  /**
   * Creates a program that fails with {@code error}.
   *
   * @param error the error; {@link Failure} rejects {@code null}
   * @param <T> value type
   * @return a {@link Failure} node
   */
  static <T> IO<T> raiseError(Throwable error) {
    return new Failure<>(error);
  }

  /**
   * Runs {@link #sleep(Duration)} followed by {@link #task(CheckedFunction0)}.
   *
   * @param delay time to sleep first
   * @param lazy computation to run afterwards
   * @param <T> value type
   * @return the delayed program
   */
  static <T> IO<T> delay(Duration delay, CheckedFunction0<? extends T> lazy) {
    return sleep(delay).andThen(task(lazy));
  }

  /**
   * Defers the construction of a program until interpretation.
   *
   * @param lazy supplier of the program
   * @param <T> value type
   * @return a {@link Suspend} node
   */
  static <T> IO<T> suspend(Supplier<IO<? extends T>> lazy) {
    return new Suspend<>(lazy);
  }

  /**
   * Lifts a plain function into one returning its result wrapped with {@link #pure(Object)}.
   *
   * @param task function to lift
   * @param <T> input type
   * @param <R> result type
   * @return the lifted function
   */
  static <T, R> Function1<T, IO<R>> lift(Function1<T, R> task) {
    return task.andThen(IO::pure);
  }

  /**
   * Lifts an {@code Option}-returning function using {@link #fromOption(Option)}.
   *
   * @param function function to lift
   * @param <A> input type
   * @param <B> result type
   * @return the lifted function
   */
  public static <A, B> Function1<A, IO<B>> liftOption(
      Function1<? super A, ? extends Option<? extends B>> function) {
    return value -> fromOption(function.apply(value));
  }

  /**
   * Lifts a {@code Try}-returning function using {@link #fromTry(Try)}.
   *
   * @param function function to lift
   * @param <A> input type
   * @param <B> result type
   * @return the lifted function
   */
  public static <A, B> Function1<A, IO<B>> liftTry(
      Function1<? super A, ? extends Try<? extends B>> function) {
    return value -> fromTry(function.apply(value));
  }

  /**
   * Lifts an {@code Either}-returning function using {@link #fromEither(Either)}.
   *
   * @param function function to lift
   * @param <A> input type
   * @param <B> result type
   * @return the lifted function
   */
  public static <A, B> Function1<A, IO<B>> liftEither(
      Function1<? super A, ? extends Either<Throwable, ? extends B>> function) {
    return value -> fromEither(function.apply(value));
  }

  /**
   * Converts an {@code Option}; an empty option becomes a failure with
   * {@link NoSuchElementException}.
   *
   * @param task the option
   * @param <T> value type
   * @return the corresponding program
   */
  static <T> IO<T> fromOption(Option<? extends T> task) {
    return fromEither(toEither(task));
  }

  /**
   * Converts a {@code Try} through {@code toEither()}.
   *
   * @param task the try
   * @param <T> value type
   * @return the corresponding program
   */
  static <T> IO<T> fromTry(Try<? extends T> task) {
    return fromEither(task.toEither());
  }

  /**
   * Converts an {@code Either}: left becomes {@link #raiseError(Throwable)}, right becomes
   * {@link #pure(Object)}.
   *
   * @param task the either
   * @param <T> value type
   * @return the corresponding program
   */
  static <T> IO<T> fromEither(Either<Throwable, ? extends T> task) {
    return task.fold(IO::raiseError, IO::pure);
  }

  /**
   * Creates an asynchronous program that registers its callback with the future's
   * {@code onComplete}.
   *
   * @param promise the future to observe
   * @param <T> value type
   * @return the corresponding program
   */
  static <T> IO<T> fromFuture(Future<? extends T> promise) {
    CheckedConsumer<Consumer<? super Try<? extends T>>> callback = promise::onComplete;
    return async(callback);
  }

  /**
   * Converts a {@link CompletableFuture} via {@code Future.fromCompletableFuture}.
   *
   * @param promise the completable future
   * @param <T> value type
   * @return the corresponding program
   */
  static <T> IO<T> fromCompletableFuture(CompletableFuture<? extends T> promise) {
    return fromFuture(Future.fromCompletableFuture(promise));
  }

  /**
   * Same as {@link #sleep(Executor, Duration)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param duration time to sleep
   * @return the sleeping program
   */
  static IO<Unit> sleep(Duration duration) {
    return sleep(Future.DEFAULT_EXECUTOR, duration);
  }

  /**
   * Creates a cancellable program backed by {@code FutureModule.sleep}. Whatever the outcome
   * of that future, the callback is invoked with a successful {@code Unit}. The cancel token
   * calls {@code cancel(true)} on the future.
   *
   * @param executor executor passed to {@code FutureModule.sleep}
   * @param duration time to sleep
   * @return the sleeping program
   */
  static IO<Unit> sleep(Executor executor, Duration duration) {
    return cancellable(callback -> {
      Future<Unit> sleep = FutureModule.sleep(executor,duration)
          .onComplete(result -> callback.accept(Try.success(Unit.unit())));
      return exec(() -> sleep.cancel(true));
    });
  }

  /**
   * Wraps a side-effecting runnable as a program producing {@code Unit}.
   *
   * @param task the effect
   * @return the program
   */
  static IO<Unit> exec(CheckedRunnable task) {
    return task(asSupplier(task));
  }

  /**
   * Wraps a computation that is evaluated by the interpreter.
   *
   * @param producer the computation
   * @param <T> value type
   * @return a {@link Delay} node
   */
  static <T> IO<T> task(CheckedFunction0<? extends T> producer) {
    return new Delay<>(producer);
  }

  /**
   * Creates an asynchronous program whose callback is never invoked.
   *
   * @param <T> value type
   * @return the program
   */
  static <T> IO<T> never() {
    return async(callback -> {});
  }

  /**
   * Same as {@link #forked(Executor)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @return the program
   */
  static IO<Unit> forked() {
    return forked(Future.DEFAULT_EXECUTOR);
  }

  /**
   * Creates an asynchronous program that completes with {@code Unit} from a task submitted to
   * {@code executor}.
   *
   * @param executor executor receiving the completion task
   * @return the program
   */
  static IO<Unit> forked(Executor executor) {
    return async(callback -> executor.execute(() -> callback.accept(Try.success(Unit.unit()))));
  }

  /**
   * Creates an asynchronous program from a callback registration; its cancel token is
   * {@link #unit()}.
   *
   * @param callback registration receiving the completion consumer
   * @param <T> value type
   * @return the program
   */
  static <T> IO<T> async(CheckedConsumer<Consumer<? super Try<? extends T>>> callback) {
    return cancellable(asFunction(callback));
  }

  /**
   * Creates an asynchronous program from a callback registration that returns the program to
   * use as cancel token.
   *
   * @param callback registration receiving the completion consumer and returning the token
   * @param <T> value type
   * @return an {@link Async} node
   */
  static <T> IO<T> cancellable(Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback) {
    return new Async<>(callback);
  }

  /**
   * Same as {@link #memoize(Executor, Function1)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param function function to memoize
   * @param <A> input type
   * @param <T> value type
   * @return a program producing the memoized function
   */
  static <A, T> IO<Function1<A, IO<T>>> memoize(Function1<A, IO<T>> function) {
    return memoize(Future.DEFAULT_EXECUTOR, function);
  }

  /**
   * Produces a function that keeps one {@code Promise} per argument in a map held by a
   * {@code Reference}. When the map has no entry for an argument, {@code function} is applied,
   * its program is started on {@code executor} and the promise is stored; otherwise the stored
   * promise is reused. In both cases the returned program reads the promise's future.
   *
   * @param executor executor used to run the programs returned by {@code function}
   * @param function function to memoize
   * @param <A> input type
   * @param <T> value type
   * @return a program producing the memoized function
   */
  static <A, T> IO<Function1<A, IO<T>>> memoize(Executor executor, Function1<A, IO<T>> function) {
    var ref = Reference.make(HashMap.<A, Promise<T>>empty());
    return ref.map(r -> {
      Function1<A, IO<IO<T>>> result = a -> r.modify(map -> map.get(a).fold(() -> {
        Promise<T> promise = Promise.make();
        function.apply(a).safeRunAsync(executor, promise::tryComplete);
        return Tuple.of(fromFuture(promise.future()), map.put(a, promise));
      }, promise -> Tuple.of(fromFuture(promise.future()), map)));
      return result.andThen(io -> io.flatMap(identity()));
    });
  }

  /**
   * Returns the shared {@link #UNIT} instance.
   *
   * @return {@link #UNIT}
   */
  static IO<Unit> unit() {
    return UNIT;
  }

  /**
   * Acquires a resource, uses it and releases it. If {@code acquire} fails the failure is
   * reported and nothing else runs. Otherwise {@code use} is run, then {@code release} is run
   * whatever the outcome of {@code use}, and finally the outcome of {@code use} is reported
   * (the outcome of {@code release} is ignored). All three steps share one cancellable
   * connection, which the cancel token cancels.
   *
   * @param acquire program producing the resource
   * @param use function producing the program that uses the resource
   * @param release function producing the program that releases the resource
   * @param <T> resource type
   * @param <R> result type
   * @return the bracketed program
   */
  static <T, R> IO<R> bracket(IO<? extends T> acquire,
      Function1<? super T, ? extends IO<? extends R>> use,
      Function1<? super T, IO<Unit>> release) {
    return cancellable(callback -> {
      IOConnection cancellable = IOConnection.cancellable();
      Promise<? extends T> promise = runAsync(acquire, cancellable);
      promise.future()
          .onFailure(error -> callback.accept(Try.failure(error)))
          .onSuccess(resource -> runAsync(use.apply(resource), cancellable).future()
              .onComplete(result -> runAsync(release.apply(resource), cancellable).future()
                  .onComplete(ignore -> callback.accept(result))
          ));
      return exec(cancellable::cancel);
    });
  }

  /**
   * Variant of {@link #bracket(IO, Function1, Function1)} taking the release step as a
   * consumer.
   *
   * @param acquire program producing the resource
   * @param use function producing the program that uses the resource
   * @param release effect that releases the resource
   * @param <T> resource type
   * @param <R> result type
   * @return the bracketed program
   */
  static <T, R> IO<R> bracket(IO<? extends T> acquire,
      Function1<? super T, ? extends IO<? extends R>> use,
      CheckedConsumer<? super T> release) {
    return bracket(acquire, use, asFunction(release));
  }

  /**
   * Variant of {@link #bracket(IO, Function1, CheckedConsumer)} that releases with
   * {@link AutoCloseable#close()}.
   *
   * @param acquire program producing the resource
   * @param use function producing the program that uses the resource
   * @param <T> resource type
   * @param <R> result type
   * @return the bracketed program
   */
  static <T extends AutoCloseable, R> IO<R> bracket(IO<? extends T> acquire,
      Function1<? super T, ? extends IO<? extends R>> use) {
    return bracket(acquire, use, AutoCloseable::close);
  }

  /**
   * Chains all programs with {@link #andThen(IO)}, starting from and finishing with
   * {@link #unit()}.
   *
   * @param sequence programs to chain
   * @return the chained program
   */
  static IO<Unit> sequence(Seq<IO<?>> sequence) {
    IO<?> initial = unit();
    return sequence.foldLeft(initial, (IO<?> a, IO<?> b) -> a.andThen(b)).andThen(unit());
  }

  /**
   * Same as {@link #traverse(Executor, Seq)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param sequence programs to combine
   * @param <A> value type
   * @return a program producing the collected values
   */
  static <A> IO<Seq<A>> traverse(Seq<IO<A>> sequence) {
    return traverse(Future.DEFAULT_EXECUTOR, sequence);
  }

  /**
   * Folds the programs with {@link #parMap2(Executor, IO, IO, Function2)}, appending each
   * value to the accumulated sequence (starting from an empty {@code List}).
   *
   * @param executor executor handed to {@code parMap2}
   * @param sequence programs to combine
   * @param <A> value type
   * @return a program producing the collected values
   */
  static <A> IO<Seq<A>> traverse(Executor executor, Seq<IO<A>> sequence) {
    return sequence.foldLeft(pure(List.empty()),
        (IO<Seq<A>> xs, IO<A> a) -> parMap2(executor, xs, a, Seq::append));
  }

  /**
   * Same as {@link #parMap2(Executor, IO, IO, Function2)} using
   * {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param fa first program
   * @param fb second program
   * @param mapper function combining both values
   * @param <A> first value type
   * @param <B> second value type
   * @param <C> result type
   * @return the combined program
   */
  static <A, B, C> IO<C> parMap2(IO<? extends A> fa, IO<? extends B> fb,
      Function2<? super A, ? super B, ? extends C> mapper) {
    return parMap2(Future.DEFAULT_EXECUTOR, fa, fb, mapper);
  }

  /**
   * Starts both programs, each forked on {@code executor} with its own cancellable
   * connection, and once both futures have completed reports the combination of their
   * outcomes. Cancelling the returned program cancels both connections.
   *
   * @param executor executor both programs are forked on
   * @param fa first program
   * @param fb second program
   * @param mapper function combining both values
   * @param <A> first value type
   * @param <B> second value type
   * @param <C> result type
   * @return the combined program
   */
  static <A, B, C> IO<C> parMap2(Executor executor, IO<? extends A> fa, IO<? extends B> fb,
      Function2<? super A, ? super B, ? extends C> mapper) {
    return cancellable(callback -> {
      IOConnection connection1 = IOConnection.cancellable();
      IOConnection connection2 = IOConnection.cancellable();

      Promise<A> promiseA = runAsync(forked(executor).andThen(fa), connection1);
      Promise<B> promiseB = runAsync(forked(executor).andThen(fb), connection2);

      promiseA.future().onComplete(
          a -> promiseB.future().onComplete(b -> callback.accept(map2(a, b, mapper))));

      return exec(() -> {
        try {
          connection1.cancel();
        } finally {
          // still cancel the second program if cancelling the first one throws
          connection2.cancel();
        }
      });
    });
  }

  /**
   * Same as {@link #tuple(Executor, IO, IO)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param fa first program
   * @param fb second program
   * @param <A> first value type
   * @param <B> second value type
   * @return a program producing both values as a tuple
   */
  static <A, B> IO<Tuple2<A, B>> tuple(IO<? extends A> fa, IO<? extends B> fb) {
    return tuple(Future.DEFAULT_EXECUTOR, fa, fb);
  }

  /**
   * Combines both programs with {@link #parMap2(Executor, IO, IO, Function2)} into a tuple.
   *
   * @param executor executor handed to {@code parMap2}
   * @param fa first program
   * @param fb second program
   * @param <A> first value type
   * @param <B> second value type
   * @return a program producing both values as a tuple
   */
  static <A, B> IO<Tuple2<A, B>> tuple(Executor executor, IO<? extends A> fa, IO<? extends B> fb) {
    return parMap2(executor, fa, fb, Tuple::of);
  }

  /** Starts the interpreter with a fresh call stack and a fresh promise. */
  private static <T> Promise<T> runAsync(IO<T> current, IOConnection connection) {
    return runAsync(current, connection, new CallStack<>(), Promise.make());
  }

  /** Combines two outcomes with {@code mapper}; the result is a failure if either one is. */
  private static <A, B, C> Try<? extends C> map2(Try<A> a, Try<B> b,
      Function2<? super A, ? super B, ? extends C> mapper) {
    return a.flatMap(x -> b.map(y -> mapper.apply(x, y)));
  }

  /** Converts an option to an either, using {@link NoSuchElementException} when empty. */
  private static <T> Either<Throwable, T> toEither(Option<? extends T> task) {
    return task.fold(() -> Either.left(new NoSuchElementException()), Either::right);
  }

  /** Extracts the value of an either whose two sides have the same type. */
  private static <A> A merge(Either<A, A> either) {
    return either.fold(Function1.identity(), Function1.identity());
  }

  /** Builds a {@link PartialFunction} from a domain predicate and a function. */
  @SuppressWarnings("serial")
  private static <A, B> PartialFunction<A, B> partialFunction(Predicate<? super A> matcher,
      Function1<? super A, ? extends B> function) {
    return new PartialFunction<>() {

      @Override
      public boolean isDefinedAt(A value) {
        return matcher.test(value);
      }

      @Override
      public B apply(A t) {
        return function.apply(t);
      }
    };
  }

  /** Adapts a runnable to a supplier returning {@code Unit}. */
  private static CheckedFunction0<Unit> asSupplier(CheckedRunnable task) {
    return () -> {
      task.unchecked().run();
      return Unit.unit();
    };
  }

  /** Adapts a consumer to a function returning {@link #unit()}. */
  private static <T> Function1<T, IO<Unit>> asFunction(CheckedConsumer<? super T> release) {
    return t -> {
      release.unchecked().accept(t);
      return unit();
    };
  }

  /**
   * Interpreter loop. Repeatedly unwraps {@code current} and dispatches on the node type until
   * the promise is completed or evaluation is handed over to an asynchronous callback. Any
   * throwable raised during a step is offered to {@code stack.tryHandle}; if no handler is
   * defined the promise is failed with it.
   */
  @SuppressWarnings("unchecked")
  private static <T, U, V> Promise<T> runAsync(IO<T> current, IOConnection connection,
      CallStack<T> stack, Promise<T> promise) {
    while (true) {
      try {
        current = unwrap(current, stack, identity());

        if (current instanceof Pure<T> pure) {
          return promise.success(pure.value);
        }

        if (current instanceof Async<T> async) {
          return executeAsync(async, connection, promise);
        }

        if (current instanceof FlatMapped) {
          stack.push();

          var flatMapped = (FlatMapped<U, T>) current;
          IO<U> source = narrowK(unwrap(flatMapped.current, stack, u -> u.flatMap(flatMapped.next)));

          if (source instanceof Async<U> async) {
            // the source completes later: resume the interpreter from its callback
            Promise<U> nextPromise = Promise.make();

            nextPromise.future().andThen(tryU -> tryU.onFailure(promise::failure)
                .onSuccess(u -> runAsync(narrowK(flatMapped.next.apply(u)), connection, stack, promise)));

            executeAsync(async, connection, nextPromise);

            return promise;
          }

          if (source instanceof Pure<U> pure) {
            Function1<? super U, IO<T>> andThen = flatMapped.next.andThen(IO::narrowK);
            current = andThen.apply(pure.value);
          } else if (source instanceof FlatMapped) {
            // re-associate nested flatMaps to the right before the next iteration
            FlatMapped<V, U> flatMapped2 = (FlatMapped<V, U>) source;
            current = flatMapped2.current.flatMap(a -> flatMapped2.next.apply(a).flatMap(flatMapped.next));
          }
        } else {
          stack.pop();
        }
      } catch (Throwable error) {
        Option<IO<T>> result = stack.tryHandle(error);

        if (result.isDefined()) {
          current = result.get();
        } else {
          return promise.failure(error);
        }
      }
    }
  }

  /**
   * Reduces {@code current} until it is a {@link Pure}, {@link FlatMapped} or {@link Async}
   * node: failures are thrown through {@code stack.sneakyThrow}, recover nodes register their
   * handler (composed with {@code next}) on the stack, suspended programs are forced and
   * delayed computations are evaluated into a {@link Pure}.
   */
  private static <T, U> IO<T> unwrap(IO<T> current, CallStack<U> stack,
      Function1<IO<? extends T>, IO<? extends U>> next) {
    while (true) {
      if (current instanceof Failure<T> failure) {
        return stack.sneakyThrow(failure.error);
      } else if (current instanceof Recover<T> recover) {
        stack.add(partialFunction(recover.mapper::isDefinedAt, recover.mapper.andThen(next)));
        current = recover.current;
      } else if (current instanceof Suspend<T> suspend) {
        current = narrowK(suspend.lazy.get());
      } else if (current instanceof Delay<T> delay) {
        return pure(delay.task.unchecked().get());
      } else if (current instanceof Pure) {
        return current;
      } else if (current instanceof FlatMapped) {
        return current;
      } else if (current instanceof Async) {
        return current;
      } else {
        throw new IllegalStateException();
      }
    }
  }

  /**
   * Registers the asynchronous node's callback so that it completes {@code promise}. On a
   * cancellable connection that is not runnable the promise is failed with
   * {@link CancellationException} instead. The token returned by the registration is stored as
   * the connection's cancel token and replaced by {@link #UNIT} once the promise completes. If
   * a cancellation was requested while starting, {@code cancelNow()} is invoked.
   */
  private static <T> Promise<T> executeAsync(Async<T> current, IOConnection connection,
      Promise<T> promise) {
    if (connection.isCancellable() && !connection.updateState(StateIO::startingNow).isRunnable()) {
      return promise.complete(Try.failure(new CancellationException()));
    }

    connection.setCancelToken(current.callback.apply(promise::tryComplete));

    promise.future().andThen(x -> connection.setCancelToken(UNIT));

    if (connection.isCancellable() && connection.updateState(StateIO::notStartingNow).isCancellingNow()) {
      connection.cancelNow();
    }

    return promise;
  }

  /** Node holding an already computed, non-null value. */
  record Pure<T>(T value) implements IO<T> {

    public Pure {
      Objects.requireNonNull(value);
    }

    @Override
    public String toString() {
      return "Pure(" + value + ")";
    }
  }

  /** Node holding a non-null error. */
  record Failure<T>(Throwable error) implements IO<T> {

    public Failure {
      Objects.requireNonNull(error);
    }

    @Override
    public String toString() {
      return "Failure(" + error + ")";
    }
  }

  /** Node sequencing {@code current} with the program produced by {@code next}. */
  record FlatMapped<T, R>(IO<? extends T> current,
      Function1<? super T, ? extends IO<? extends R>> next) implements IO<R> {

    public FlatMapped {
      Objects.requireNonNull(current);
      Objects.requireNonNull(next);
    }

    @Override
    public String toString() {
      return "FlatMapped(" + current + ", ?)";
    }
  }

  /** Node holding a computation evaluated by the interpreter. */
  record Delay<T>(CheckedFunction0<? extends T> task) implements IO<T> {

    public Delay {
      Objects.requireNonNull(task);
    }

    @Override
    public String toString() {
      return "Delay(?)";
    }
  }

  /** Node holding a callback registration that returns a cancel token. */
  record Async<T>(Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback) implements IO<T> {

    public Async {
      Objects.requireNonNull(callback);
    }

    @Override
    public String toString() {
      return "Async(?)";
    }
  }

  /** Node holding a supplier of the program to run. */
  record Suspend<T>(Supplier<? extends IO<? extends T>> lazy) implements IO<T> {

    public Suspend {
      Objects.requireNonNull(lazy);
    }

    @Override
    public String toString() {
      return "Suspend(?)";
    }
  }

  /** Node attaching an error handler to {@code current}. */
  record Recover<T>(IO<T> current,
      PartialFunction<? super Throwable, ? extends IO<? extends T>> mapper) implements IO<T> {

    public Recover {
      Objects.requireNonNull(current);
      Objects.requireNonNull(mapper);
    }

    @Override
    public String toString() {
      return "Recover(" + current + ", ?)";
    }
  }
}

/** Package-private helpers for time-based futures used by {@code IO.sleep}. */
interface FutureModule {

  /** Scheduler used to delay the hand-off of tasks; created with a core pool size of 0. */
  ScheduledExecutorService SCHEDULER = Executors.newScheduledThreadPool(0);

  /**
   * Creates a future producing {@code Unit} through a delayed executor.
   *
   * @param executor executor that finally runs the task
   * @param delay time to wait before handing the task to {@code executor}
   * @return the future
   */
  static Future<Unit> sleep(Executor executor, Duration delay) {
    return Future.fromCompletableFuture(executor,
        CompletableFuture.supplyAsync(Unit::unit, delayedExecutor(delay, executor)));
  }

  /**
   * Creates an executor that schedules each task on {@link #SCHEDULER} and forwards it to
   * {@code executor} after {@code delay.toMillis()} milliseconds.
   *
   * @param delay time to wait; converted to whole milliseconds
   * @param executor executor that finally runs the task
   * @return the delaying executor
   */
  static Executor delayedExecutor(Duration delay, Executor executor) {
    return task -> SCHEDULER.schedule(() -> executor.execute(task), delay.toMillis(), TimeUnit.MILLISECONDS);
  }
}