/* * Copyright (c) 2022, Antonio Gabriel Muñoz Conejo <antoniogmc at gmail dot com> * Distributed under the terms of the MIT License */ package com.github.tonivade.vavr.effect; import static io.vavr.Function1.identity; import java.time.Duration; import java.util.ArrayDeque; import java.util.Deque; import java.util.NoSuchElementException; import java.util.Objects; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.function.Supplier; import java.util.function.UnaryOperator; import io.vavr.CheckedConsumer; import io.vavr.CheckedRunnable; import io.vavr.Function1; import io.vavr.Function2; import io.vavr.PartialFunction; import io.vavr.Tuple; import io.vavr.Tuple2; import io.vavr.collection.HashMap; import io.vavr.collection.List; import io.vavr.collection.Seq; import io.vavr.concurrent.Future; import io.vavr.concurrent.Promise; import io.vavr.control.Either; import io.vavr.control.Option; import io.vavr.control.Try; public sealed interface IO<T> { IO<Unit> UNIT = pure(Unit.unit()); default Future<T> runAsync() { return runAsync(this, IOConnection.UNCANCELLABLE).future(); } default Future<T> runAsync(Executor executor) { return forked(executor).andThen(this).runAsync(); } default T unsafeRunSync() { return safeRunSync().get(); } default Try<T> safeRunSync() { return runAsync().toTry(); } default void safeRunAsync(Consumer<? super Try<? extends T>> callback) { safeRunAsync(Future.DEFAULT_EXECUTOR, callback); } default void safeRunAsync(Executor executor, Consumer<? super Try<? extends T>> callback) { runAsync(executor).onComplete(callback); } default <R> IO<R> map(Function1<? super T, ? extends R> map) { return flatMap(map.andThen(IO::pure)); } default <R> IO<R> flatMap(Function1<? super T, ? extends IO<? extends R>> map) { return new FlatMapped<>(this, map.andThen(IO::narrowK)); } default <R> IO<R> andThen(IO<? extends R> after) { return flatMap(ignore -> after); } default <R> IO<R> ap(IO<Function1<? super T, ? extends R>> apply) { return parMap2(Future.DEFAULT_EXECUTOR, this, apply, (v, a) -> a.apply(v)); } default IO<Try<T>> attempt() { return map(Try::success).recover(Try::failure); } default IO<Either<Throwable, T>> either() { return attempt().map(Try::toEither); } default <L, R> IO<Either<L, R>> either(Function1<? super Throwable, ? extends L> mapError, Function1<? super T, ? extends R> mapper) { return either().map(either -> either.bimap(mapError, mapper)); } default <R> IO<R> redeem(Function1<? super Throwable, ? extends R> mapError, Function1<? super T, ? extends R> mapper) { return attempt().map(result -> result.fold(mapError, mapper)); } default <R> IO<R> redeemWith(Function1<? super Throwable, IO<? extends R>> mapError, Function1<? super T, IO<? extends R>> mapper) { return attempt().flatMap(result -> result.fold(mapError, mapper)); } default IO<T> recover(Function1<? super Throwable, ? extends T> mapError) { return recoverWith(partialFunction(x -> true, mapError.andThen(IO::pure))); } @SuppressWarnings("unchecked") default <X extends Throwable> IO<T> recover(Class<X> type, Function1<? super X, ? extends T> function) { return recoverWith(partialFunction(error -> error.getClass().equals(type), t -> function.andThen(IO::pure).apply((X) t))); } @SuppressWarnings("serial") static <A, B> PartialFunction<A, B> partialFunction(Predicate<? super A> matcher, Function1<? super A, ? extends B> function) { return new PartialFunction<>() { @Override public boolean isDefinedAt(A value) { return matcher.test(value); } @Override public B apply(A t) { return function.apply(t); } }; } default IO<T> recoverWith(PartialFunction<? super Throwable, IO<? extends T>> mapper) { return new Recover<>(this, mapper); } default IO<Tuple2<Duration, T>> timed() { return IO.task(System::nanoTime).flatMap( start -> map(result -> Tuple.of(Duration.ofNanos(System.nanoTime() - start), result))); } default IO<Fiber<T>> fork() { return async(callback -> { IOConnection connection = IOConnection.cancellable(); Promise<T> promise = runAsync(this, connection); IO<T> join = fromFuture(promise.future()); IO<Unit> cancel = exec(connection::cancel); callback.accept(Try.success(Fiber.of(join, cancel))); }); } default IO<T> timeout(Duration duration) { return timeout(Future.DEFAULT_EXECUTOR, duration); } default IO<T> timeout(Executor executor, Duration duration) { return racePair(executor, this, sleep(duration)).flatMap(either -> either.fold( ta -> ta._2().cancel().map(x -> ta._1()), tb -> tb._1().cancel().flatMap(x -> IO.raiseError(new TimeoutException())))); } default IO<T> repeat() { return repeat(1); } default IO<T> repeat(int times) { return repeat(this, unit(), times); } default IO<T> repeat(Duration delay) { return repeat(delay, 1); } default IO<T> repeat(Duration delay, int times) { return repeat(this, sleep(delay), times); } default IO<T> retry() { return retry(1); } default IO<T> retry(int maxRetries) { return retry(this, unit(), maxRetries); } default IO<T> retry(Duration delay) { return retry(delay, 1); } default IO<T> retry(Duration delay, int maxRetries) { return retry(this, sleep(delay), maxRetries); } @SuppressWarnings("unchecked") static <T> IO<T> narrowK(IO<? extends T> value) { return (IO<T>) value; } static <T> IO<T> pure(T value) { return new Pure<>(value); } static <A, B> IO<Either<A, B>> race(IO<? extends A> fa, IO<? extends B> fb) { return race(Future.DEFAULT_EXECUTOR, fa, fb); } static <A, B> IO<Either<A, B>> race(Executor executor, IO<? extends A> fa, IO<? extends B> fb) { return racePair(executor, fa, fb).flatMap(either -> either.fold( ta -> ta._2().cancel().map(x -> Either.left(ta._1())), tb -> tb._1().cancel().map(x -> Either.right(tb._2())))); } static <A, B> IO<Either<Tuple2<A, Fiber<B>>, Tuple2<Fiber<A>, B>>> racePair(Executor executor, IO<? extends A> fa, IO<? extends B> fb) { return cancellable(callback -> { IOConnection connection1 = IOConnection.cancellable(); IOConnection connection2 = IOConnection.cancellable(); Promise<A> promiseA = runAsync(IO.forked(executor).andThen(fa), connection1); Promise<B> promiseB = runAsync(IO.forked(executor).andThen(fb), connection2); promiseA.future().onComplete(result -> callback.accept( result.map(a -> Either.left(Tuple.of(a, Fiber.of(IO.fromFuture(promiseB.future()), IO.exec(connection2::cancel))))))); promiseB.future().onComplete(result -> callback.accept( result.map(b -> Either.right(Tuple.of(Fiber.of(IO.fromFuture(promiseA.future()), IO.exec(connection2::cancel)), b))))); return IO.exec(() -> { try { connection1.cancel(); } finally { connection2.cancel(); } }); }); } static <T> IO<T> raiseError(Throwable error) { return new Failure<>(error); } static <T> IO<T> delay(Duration delay, Supplier<? extends T> lazy) { return sleep(delay).andThen(task(lazy)); } static <T> IO<T> suspend(Supplier<IO<? extends T>> lazy) { return new Suspend<>(lazy); } static <T, R> Function1<T, IO<R>> lift(Function1<T, R> task) { return task.andThen(IO::pure); } public static <A, B> Function1<A, IO<B>> liftOption(Function1<? super A, ? extends Option<? extends B>> function) { return value -> fromOption(function.apply(value)); } public static <A, B> Function1<A, IO<B>> liftTry(Function1<? super A, ? extends Try<? extends B>> function) { return value -> fromTry(function.apply(value)); } public static <A, B> Function1<A, IO<B>> liftEither(Function1<? super A, ? extends Either<Throwable, ? extends B>> function) { return value -> fromEither(function.apply(value)); } static <T> IO<T> fromOption(Option<? extends T> task) { return fromEither(toEither(task)); } static <T> IO<T> fromTry(Try<? extends T> task) { return fromEither(task.toEither()); } static <T> IO<T> fromEither(Either<Throwable, ? extends T> task) { return task.fold(IO::raiseError, IO::pure); } static <T> IO<T> fromFuture(Future<? extends T> promise) { CheckedConsumer<Consumer<? super Try<? extends T>>> callback = promise::onComplete; return async(callback); } static <T> IO<T> fromCompletableFuture(CompletableFuture<? extends T> promise) { return fromFuture(Future.fromCompletableFuture(promise)); } static IO<Unit> sleep(Duration duration) { return sleep(Future.DEFAULT_EXECUTOR, duration); } static IO<Unit> sleep(Executor executor, Duration duration) { return cancellable(callback -> { Future<Unit> sleep = FutureModule.sleep(executor,duration) .onComplete(result -> callback.accept(Try.success(Unit.unit()))); return IO.exec(() -> sleep.cancel(true)); }); } static IO<Unit> exec(CheckedRunnable task) { return task(asSupplier(task)); } static <T> IO<T> task(Supplier<? extends T> producer) { return new Delay<>(producer); } static <T> IO<T> never() { return async(callback -> {}); } static IO<Unit> forked() { return forked(Future.DEFAULT_EXECUTOR); } static IO<Unit> forked(Executor executor) { return async(callback -> executor.execute(() -> callback.accept(Try.success(Unit.unit())))); } static <T> IO<T> async(CheckedConsumer<Consumer<? super Try<? extends T>>> callback) { return cancellable(asFunction(callback)); } static <T> IO<T> cancellable(Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback) { return new Async<>(callback); } static <A, T> IO<Function1<A, IO<T>>> memoize(Function1<A, IO<T>> function) { return memoize(Future.DEFAULT_EXECUTOR, function); } static <A, T> IO<Function1<A, IO<T>>> memoize(Executor executor, Function1<A, IO<T>> function) { var ref = Ref.make(HashMap.<A, Promise<T>>empty()); return ref.map(r -> { Function1<A, IO<IO<T>>> result = a -> r.modify(map -> map.get(a).fold(() -> { Promise<T> promise = Promise.make(); function.apply(a).safeRunAsync(executor, promise::tryComplete); return Tuple.of(IO.fromFuture(promise.future()), map.put(a, promise)); }, promise -> Tuple.of(IO.fromFuture(promise.future()), map))); return result.andThen(io -> io.flatMap(identity())); }); } static IO<Unit> unit() { return UNIT; } static <T, R> IO<R> bracket(IO<? extends T> acquire, Function1<? super T, ? extends IO<? extends R>> use, Function1<? super T, IO<Unit>> release) { return cancellable(callback -> { IOConnection cancellable = IOConnection.cancellable(); Promise<? extends T> promise = runAsync(acquire, cancellable); promise.future() .onFailure(error -> callback.accept(Try.failure(error))) .onSuccess(resource -> runAsync(use.andThen(IO::narrowK).apply(resource), cancellable).future() .onComplete(result -> runAsync(release.andThen(IO::narrowK).apply(resource), cancellable).future() .onComplete(ignore -> callback.accept(result)) )); return IO.exec(cancellable::cancel); }); } static <T, R> IO<R> bracket(IO<? extends T> acquire, Function1<? super T, ? extends IO<? extends R>> use, CheckedConsumer<? super T> release) { return bracket(acquire, use, asFunction(release)); } static <T extends AutoCloseable, R> IO<R> bracket(IO<? extends T> acquire, Function1<? super T, ? extends IO<? extends R>> use) { return bracket(acquire, use, AutoCloseable::close); } static IO<Unit> sequence(Seq<IO<?>> sequence) { IO<?> initial = IO.unit(); return sequence.foldLeft(initial, (IO<?> a, IO<?> b) -> a.andThen(b)).andThen(IO.unit()); } static <A> IO<Seq<A>> traverse(Seq<IO<A>> sequence) { return traverse(Future.DEFAULT_EXECUTOR, sequence); } static <A> IO<Seq<A>> traverse(Executor executor, Seq<IO<A>> sequence) { return sequence.foldLeft(pure(List.empty()), (IO<Seq<A>> xs, IO<A> a) -> parMap2(executor, xs, a, Seq::append)); } static <A, B, C> IO<C> parMap2(IO<? extends A> fa, IO<? extends B> fb, Function2<? super A, ? super B, ? extends C> mapper) { return parMap2(Future.DEFAULT_EXECUTOR, fa, fb, mapper); } static <A, B, C> IO<C> parMap2(Executor executor, IO<? extends A> fa, IO<? extends B> fb, Function2<? super A, ? super B, ? extends C> mapper) { return cancellable(callback -> { IOConnection connection1 = IOConnection.cancellable(); IOConnection connection2 = IOConnection.cancellable(); Promise<A> promiseA = runAsync(IO.forked(executor).andThen(fa), connection1); Promise<B> promiseB = runAsync(IO.forked(executor).andThen(fb), connection2); promiseA.future().onComplete(a -> promiseB.future().onComplete(b -> callback.accept(map2(a, b, mapper)))); return IO.exec(() -> { try { connection1.cancel(); } finally { connection2.cancel(); } }); }); } static <A, B> IO<Tuple2<A, B>> tuple(IO<? extends A> fa, IO<? extends B> fb) { return tuple(Future.DEFAULT_EXECUTOR, fa, fb); } static <A, B> IO<Tuple2<A, B>> tuple(Executor executor, IO<? extends A> fa, IO<? extends B> fb) { return parMap2(executor, fa, fb, Tuple::of); } private static <T> Promise<T> runAsync(IO<T> current, IOConnection connection) { return runAsync(current, connection, new CallStack<>(), Promise.make()); } private static <A, B, C> Try<? extends C> map2(Try<A> a, Try<B> b, Function2<? super A, ? super B, ? extends C> mapper) { return a.flatMap(x -> b.map(y -> mapper.apply(x, y))); } private static <T> Either<Throwable, T> toEither(Option<? extends T> task) { return task.fold(() -> Either.left(new NoSuchElementException()), Either::right); } private static Supplier<Unit> asSupplier(CheckedRunnable task) { return () -> { task.unchecked().run(); return Unit.unit(); }; } private static <T> Function1<T, IO<Unit>> asFunction(CheckedConsumer<? super T> release) { return t -> { release.unchecked().accept(t); return unit(); }; } @SuppressWarnings("unchecked") private static <T, U, V> Promise<T> runAsync(IO<T> current, IOConnection connection, CallStack<T> stack, Promise<T> promise) { while (true) { try { current = unwrap(current, stack, identity()); if (current instanceof Pure<T> pure) { return promise.success(pure.value); } if (current instanceof Async<T> async) { return executeAsync(async, connection, promise); } if (current instanceof FlatMapped) { stack.push(); var flatMapped = (FlatMapped<U, T>) current; IO<U> source = IO.narrowK(unwrap(flatMapped.current, stack, u -> u.flatMap(flatMapped.next))); if (source instanceof Async<U> async) { Promise<U> nextPromise = Promise.make(); nextPromise.future().andThen(tryU -> tryU.onFailure(promise::failure) .onSuccess(u -> runAsync(IO.narrowK(flatMapped.next.apply(u)), connection, stack, promise))); executeAsync(async, connection, nextPromise); return promise; } if (source instanceof Pure<U> pure) { Function1<? super U, IO<T>> andThen = flatMapped.next.andThen(IO::narrowK); current = andThen.apply(pure.value); } else if (source instanceof FlatMapped) { FlatMapped<V, U> flatMapped2 = (FlatMapped<V, U>) source; current = flatMapped2.current.flatMap(a -> flatMapped2.next.apply(a).flatMap(flatMapped.next)); } } else { stack.pop(); } } catch (Throwable error) { Option<IO<T>> result = stack.tryHandle(error); if (result.isDefined()) { current = result.get(); } else { return promise.failure(error); } } } } private static <T, U> IO<T> unwrap(IO<T> current, CallStack<U> stack, Function1<IO<? extends T>, IO<? extends U>> next) { while (true) { if (current instanceof Failure<T> failure) { return stack.sneakyThrow(failure.error); } else if (current instanceof Recover<T> recover) { stack.add(partialFunction(recover.mapper::isDefinedAt, recover.mapper.andThen(next))); current = recover.current; } else if (current instanceof Suspend<T> suspend) { Supplier<IO<T>> andThen = () -> IO.narrowK(suspend.lazy.get()); current = andThen.get(); } else if (current instanceof Delay<T> delay) { return IO.pure(delay.task.get()); } else if (current instanceof Pure) { return current; } else if (current instanceof FlatMapped) { return current; } else if (current instanceof Async) { return current; } else { throw new IllegalStateException(); } } } private static <T> Promise<T> executeAsync(Async<T> current, IOConnection connection, Promise<T> promise) { if (connection.isCancellable() && !connection.updateState(StateIO::startingNow).isRunnable()) { return promise.complete(Try.failure(new CancellationException())); } connection.setCancelToken(current.callback.apply(promise::tryComplete)); promise.future().andThen(x -> connection.setCancelToken(UNIT)); if (connection.isCancellable() && connection.updateState(StateIO::notStartingNow).isCancellingNow()) { connection.cancelNow(); } return promise; } private static <T> IO<T> repeat(IO<T> self, IO<Unit> pause, int times) { return self.redeemWith(IO::raiseError, value -> { if (times > 0) { return pause.andThen(repeat(self, pause, times - 1)); } else return IO.pure(value); }); } private static <T> IO<T> retry(IO<T> self, IO<Unit> pause, int maxRetries) { return self.redeemWith(error -> { if (maxRetries > 0) { return pause.andThen(retry(self, pause.repeat(), maxRetries - 1)); } else return IO.raiseError(error); }, IO::pure); } final class Pure<T> implements IO<T> { private final T value; private Pure(T value) { this.value = Objects.requireNonNull(value); } @Override public String toString() { return "Pure(" + value + ")"; } } final class Failure<T> implements IO<T> { private final Throwable error; private Failure(Throwable error) { this.error = Objects.requireNonNull(error); } @Override public String toString() { return "Failure(" + error + ")"; } } final class FlatMapped<T, R> implements IO<R> { private final IO<? extends T> current; private final Function1<? super T, ? extends IO<? extends R>> next; private FlatMapped(IO<? extends T> current, Function1<? super T, ? extends IO<? extends R>> next) { this.current = Objects.requireNonNull(current); this.next = Objects.requireNonNull(next); } @Override public String toString() { return "FlatMapped(" + current + ", ?)"; } } final class Delay<T> implements IO<T> { private final Supplier<? extends T> task; private Delay(Supplier<? extends T> task) { this.task = Objects.requireNonNull(task); } @Override public String toString() { return "Delay(?)"; } } final class Async<T> implements IO<T> { private final Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback; private Async(Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback) { this.callback = Objects.requireNonNull(callback); } @Override public String toString() { return "Async(?)"; } } final class Suspend<T> implements IO<T> { private final Supplier<? extends IO<? extends T>> lazy; private Suspend(Supplier<? extends IO<? extends T>> lazy) { this.lazy = Objects.requireNonNull(lazy); } @Override public String toString() { return "Suspend(?)"; } } final class Recover<T> implements IO<T> { private final IO<T> current; private final PartialFunction<? super Throwable, ? extends IO<? extends T>> mapper; private Recover(IO<T> current, PartialFunction<? super Throwable, ? extends IO<? extends T>> mapper) { this.current = Objects.requireNonNull(current); this.mapper = Objects.requireNonNull(mapper); } @Override public String toString() { return "Recover(" + current + ", ?)"; } } } sealed interface IOConnection { IOConnection UNCANCELLABLE = new Uncancellable(); boolean isCancellable(); void setCancelToken(IO<Unit> cancel); void cancelNow(); void cancel(); StateIO updateState(UnaryOperator<StateIO> update); static IOConnection cancellable() { return new Cancellable(); } static final class Uncancellable implements IOConnection { private Uncancellable() { } @Override public boolean isCancellable() { return false; } @Override public void setCancelToken(IO<Unit> cancel) { // uncancellable } @Override public void cancelNow() { // uncancellable } @Override public void cancel() { // uncancellable } @Override public StateIO updateState(UnaryOperator<StateIO> update) { return StateIO.INITIAL; } } static final class Cancellable implements IOConnection { private IO<Unit> cancelToken; private final AtomicReference<StateIO> state = new AtomicReference<>(StateIO.INITIAL); private Cancellable() { } @Override public boolean isCancellable() { return true; } @Override public void setCancelToken(IO<Unit> cancel) { this.cancelToken = Objects.requireNonNull(cancel); } @Override public void cancelNow() { cancelToken.runAsync(); } @Override public void cancel() { if (state.getAndUpdate(StateIO::cancellingNow).isCancelable()) { cancelNow(); state.set(StateIO.CANCELLED); } } @Override public StateIO updateState(UnaryOperator<StateIO> update) { return state.updateAndGet(update::apply); } } } final class StateIO { static final StateIO INITIAL = new StateIO(false, false, false); static final StateIO CANCELLED = new StateIO(true, false, false); private final boolean cancelled; private final boolean cancellingNow; private final boolean startingNow; StateIO(boolean cancelled, boolean cancellingNow, boolean startingNow) { this.cancelled = cancelled; this.cancellingNow = cancellingNow; this.startingNow = startingNow; } boolean isCancelled() { return cancelled; } boolean isCancellingNow() { return cancellingNow; } boolean isStartingNow() { return startingNow; } StateIO cancellingNow() { return new StateIO(cancelled, true, startingNow); } StateIO startingNow() { return new StateIO(cancelled, cancellingNow, true); } StateIO notStartingNow() { return new StateIO(cancelled, cancellingNow, false); } boolean isCancelable() { return !cancelled && !cancellingNow && !startingNow; } boolean isRunnable() { return !cancelled && !cancellingNow; } } final class CallStack<T> { private StackItem<T> top = new StackItem<>(); void push() { top.push(); } void pop() { if (top.count() > 0) { top.pop(); } else { top = top.prev(); } } void add(PartialFunction<? super Throwable, ? extends IO<? extends T>> mapError) { if (top.count() > 0) { top.pop(); top = new StackItem<>(top); } top.add(mapError); } Option<IO<T>> tryHandle(Throwable error) { while (top != null) { top.reset(); Option<IO<T>> result = top.tryHandle(error); if (result.isDefined()) { return result; } else { top = top.prev(); } } return Option.none(); } // XXX: https://www.baeldung.com/java-sneaky-throws @SuppressWarnings("unchecked") <X extends Throwable, R> R sneakyThrow(Throwable t) throws X { throw (X) t; } } final class StackItem<T> { private int count = 0; private final Deque<PartialFunction<? super Throwable, ? extends IO<? extends T>>> recover = new ArrayDeque<>(); private final StackItem<T> prev; StackItem() { this(null); } StackItem(StackItem<T> prev) { this.prev = prev; } StackItem<T> prev() { return prev; } int count() { return count; } void push() { count++; } void pop() { count--; } void reset() { count = 0; } void add(PartialFunction<? super Throwable, ? extends IO<? extends T>> mapError) { recover.addFirst(mapError); } Option<IO<T>> tryHandle(Throwable error) { while (!recover.isEmpty()) { var mapError = recover.removeFirst(); if (mapError.isDefinedAt(error)) { return Option.some(mapError.andThen(IO::<T>narrowK).apply(error)); } } return Option.none(); } } interface FutureModule { ScheduledExecutorService SCHEDULER = Executors.newScheduledThreadPool(0); static Future<Unit> sleep(Executor executor, Duration delay) { return Future.fromCompletableFuture(executor, CompletableFuture.supplyAsync(Unit::unit, delayedExecutor(delay, executor))); } static Executor delayedExecutor(Duration delay, Executor executor) { return task -> SCHEDULER.schedule(() -> executor.execute(task), delay.toMillis(), TimeUnit.MILLISECONDS); } }