// vavr-effect/src/main/java/com/github/tonivade/vavr/effect/IO.java
/*
 * Copyright (c) 2024, Antonio Gabriel Muñoz Conejo <antoniogmc at gmail dot com>
 * Distributed under the terms of the MIT License
 */
package com.github.tonivade.vavr.effect;

import static io.vavr.Function1.identity;

import java.time.Duration;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.function.Supplier;

import io.vavr.CheckedConsumer;
import io.vavr.CheckedFunction0;
import io.vavr.CheckedRunnable;
import io.vavr.Function1;
import io.vavr.Function2;
import io.vavr.PartialFunction;
import io.vavr.Tuple;
import io.vavr.Tuple2;
import io.vavr.collection.HashMap;
import io.vavr.collection.List;
import io.vavr.collection.Seq;
import io.vavr.concurrent.Future;
import io.vavr.concurrent.Promise;
import io.vavr.control.Either;
import io.vavr.control.Option;
import io.vavr.control.Try;

public sealed interface IO<T> {

  /** Shared effect that wraps {@code Unit.unit()}; this is the instance returned by {@link #unit()}. */
  IO<Unit> UNIT = pure(Unit.unit());

  /**
   * Starts evaluating this effect using {@code IOConnection.UNCANCELLABLE}.
   *
   * @return the future view of the promise handed to the interpreter
   */
  default Future<T> runAsync() {
    return runAsync(this, IOConnection.UNCANCELLABLE).future();
  }

  /**
   * Starts evaluating this effect after a {@link #forked(Executor)} step, so that the
   * continuation is resumed from a task submitted to {@code executor}.
   *
   * @param executor executor used for the initial hop
   * @return the future holding the outcome of the evaluation
   */
  default Future<T> runAsync(Executor executor) {
    return forked(executor).andThen(this).runAsync();
  }

  /**
   * Evaluates this effect via {@link #safeRunSync()} and unwraps the result with {@code Try.get()}.
   *
   * @return the computed value
   */
  default T unsafeRunSync() {
    // Try.get() throws when the Try is a failure, hence "unsafe".
    return safeRunSync().get();
  }

  /**
   * Evaluates this effect with {@link #runAsync()} and converts the resulting future to a {@link Try}.
   *
   * @return the outcome of the evaluation as a {@code Try}
   */
  default Try<T> safeRunSync() {
    // NOTE(review): presumably toTry() waits for the future to complete — confirm against Vavr.
    return runAsync().toTry();
  }

  /**
   * Evaluates this effect on {@code Future.DEFAULT_EXECUTOR} and reports the outcome to the callback.
   *
   * @param callback receives the outcome once the evaluation completes
   */
  default void safeRunAsync(Consumer<? super Try<? extends T>> callback) {
    safeRunAsync(Future.DEFAULT_EXECUTOR, callback);
  }

  /**
   * Evaluates this effect through {@link #runAsync(Executor)} and registers the callback
   * as completion listener of the resulting future.
   *
   * @param executor executor passed to {@link #runAsync(Executor)}
   * @param callback receives the outcome once the evaluation completes
   */
  default void safeRunAsync(Executor executor, Consumer<? super Try<? extends T>> callback) {
    runAsync(executor).onComplete(callback);
  }

  /**
   * Transforms the value of this effect. Implemented as a {@link #flatMap} whose function
   * wraps the mapped value with {@link #pure}.
   *
   * @param map function applied to the value
   * @param <R> type of the mapped value
   * @return an effect producing the mapped value
   */
  default <R> IO<R> map(Function1<? super T, ? extends R> map) {
    return flatMap(map.andThen(IO::pure));
  }

  /**
   * Chains a dependent effect after this one. Only builds a {@code FlatMapped} node;
   * nothing is evaluated here.
   *
   * @param map function producing the next effect from the value of this one
   * @param <R> type produced by the next effect
   * @return the composed effect
   */
  default <R> IO<R> flatMap(Function1<? super T, ? extends IO<? extends R>> map) {
    return new FlatMapped<>(this, map);
  }

  /**
   * Sequences {@code after} behind this effect, discarding the value produced by this one.
   *
   * @param after effect to continue with
   * @param <R> type produced by {@code after}
   * @return the composed effect
   */
  default <R> IO<R> andThen(IO<? extends R> after) {
    return flatMap(ignore -> after);
  }

  /**
   * Applies the function produced by {@code apply} to the value produced by this effect.
   * Both effects are combined with {@link #parMap2(Executor, IO, IO, Function2)} on
   * {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param apply effect producing the function to apply
   * @param <R> result type of the function
   * @return an effect producing the function result
   */
  default <R> IO<R> ap(IO<Function1<? super T, ? extends R>> apply) {
    return parMap2(Future.DEFAULT_EXECUTOR, this, apply, (v, a) -> a.apply(v));
  }

  /**
   * Exposes the outcome of this effect as a value: a success is mapped to
   * {@code Try.success} and an error is recovered into {@code Try.failure}.
   *
   * @return an effect producing the outcome as a {@link Try}
   */
  default IO<Try<T>> attempt() {
    return map(Try::success).recover(Try::failure);
  }

  /**
   * Same as {@link #attempt()} but with the outcome converted to an {@link Either}.
   *
   * @return an effect producing the error on the left or the value on the right
   */
  default IO<Either<Throwable, T>> either() {
    return attempt().map(Try::toEither);
  }

  /**
   * Like {@link #either()}, additionally mapping both sides of the resulting {@link Either}.
   *
   * @param mapError function applied to the error (left side)
   * @param mapper function applied to the value (right side)
   * @param <L> mapped error type
   * @param <R> mapped value type
   * @return an effect producing the mapped either
   */
  default <L, R> IO<Either<L, R>> either(Function1<? super Throwable, ? extends L> mapError,
                                         Function1<? super T, ? extends R> mapper) {
    return either().map(either -> either.bimap(mapError, mapper));
  }

  /**
   * Folds the outcome of this effect into a single value, using {@code mapError}
   * for an error and {@code mapper} for a value.
   *
   * @param mapError function applied to the error
   * @param mapper function applied to the value
   * @param <R> type of the folded result
   * @return an effect producing the folded result
   */
  default <R> IO<R> redeem(Function1<? super Throwable, ? extends R> mapError,
                           Function1<? super T, ? extends R> mapper) {
    return attempt().map(result -> result.fold(mapError, mapper));
  }

  /**
   * Effectful variant of {@link #redeem}: both handlers return the effect to continue with.
   *
   * @param mapError function producing the continuation for an error
   * @param mapper function producing the continuation for a value
   * @param <R> type produced by the continuation
   * @return the composed effect
   */
  default <R> IO<R> redeemWith(Function1<? super Throwable, IO<? extends R>> mapError,
                               Function1<? super T, IO<? extends R>> mapper) {
    return attempt().flatMap(result -> result.fold(mapError, mapper));
  }

  /**
   * Recovers from any error by replacing it with the value computed by {@code mapError}.
   *
   * @param mapError function turning the error into a replacement value
   * @return an effect with the recovery handler attached
   */
  default IO<T> recover(Function1<? super Throwable, ? extends T> mapError) {
    // The partial function is defined for every error (x -> true).
    return recoverWith(partialFunction(x -> true, mapError.andThen(IO::pure)));
  }

  /**
   * Recovers from errors of one specific class by replacing them with the value
   * computed by {@code function}.
   *
   * <p>The match uses {@code error.getClass().equals(type)}: only errors whose runtime
   * class is exactly {@code type} are handled; subclasses of {@code type} are not.
   *
   * @param type exact class of the errors to handle
   * @param function function turning the error into a replacement value
   * @param <X> type of the handled error
   * @return an effect with the recovery handler attached
   */
  @SuppressWarnings("unchecked")
  default <X extends Throwable> IO<T> recover(Class<X> type, Function1<? super X, ? extends T> function) {
    // The cast to X is guarded by the exact-class check performed by the matcher.
    return recoverWith(partialFunction(error -> error.getClass().equals(type), t -> function.andThen(IO::pure).apply((X) t)));
  }

  /**
   * Attaches an error handler to this effect. Only builds a {@code Recover} node.
   *
   * @param mapper partial function producing the effect to continue with for the errors it is defined at
   * @return an effect with the recovery handler attached
   */
  default IO<T> recoverWith(PartialFunction<? super Throwable, IO<? extends T>> mapper) {
    return new Recover<>(this, mapper);
  }

  /**
   * Measures the evaluation of this effect with {@link System#nanoTime()}.
   *
   * <p>The clock is read by a dedicated task right before this effect runs and read
   * again when the value of this effect is mapped.
   *
   * @return an effect producing the elapsed time paired with the value of this effect
   */
  default IO<Tuple2<Duration, T>> timed() {
    IO<Long> clock = task(System::nanoTime);
    return clock.flatMap(begin ->
        map(value -> Tuple.of(Duration.ofNanos(System.nanoTime() - begin), value)));
  }

  /**
   * Starts this effect on its own cancellable connection and immediately yields a
   * {@code Fiber} handle for it.
   *
   * @return an effect producing a fiber whose join waits on the started evaluation
   *         and whose cancel cancels the connection it runs on
   */
  default IO<Fiber<T>> fork() {
    return async(callback -> {
      IOConnection connection = IOConnection.cancellable();
      // Evaluation starts here; the promise is completed by the interpreter.
      Promise<T> promise = runAsync(this, connection);

      IO<T> join = fromFuture(promise.future());
      IO<Unit> cancel = exec(connection::cancel);

      // The fiber is delivered right away, without waiting for the promise.
      callback.accept(Try.success(new Fiber<>(join, cancel)));
    });
  }

  /**
   * Same as {@link #timeout(Executor, Duration)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param duration maximum time to wait
   * @return an effect that fails with {@link TimeoutException} when the time elapses first
   */
  default IO<T> timeout(Duration duration) {
    return timeout(Future.DEFAULT_EXECUTOR, duration);
  }

  /**
   * Races this effect against {@code sleep(duration)}.
   *
   * <p>If this effect finishes first, the sleep fiber is cancelled and the value returned.
   * If the sleep finishes first, the fiber of this effect is cancelled and the result is
   * a failure with a {@link TimeoutException}.
   *
   * @param executor executor passed to {@link #racePair}
   * @param duration maximum time to wait
   * @return an effect that fails with {@link TimeoutException} when the time elapses first
   */
  default IO<T> timeout(Executor executor, Duration duration) {
    // NOTE(review): sleep(duration) uses Future.DEFAULT_EXECUTOR rather than `executor` — confirm this is intended.
    return racePair(executor, this, sleep(duration)).flatMap(either -> either.fold(
        ta -> ta._2().cancel().map(x -> ta._1()),
        tb -> tb._1().cancel().flatMap(x -> raiseError(new TimeoutException()))));
  }

  /**
   * Shortcut for {@code repeat(1)}.
   *
   * @return the repeating effect
   */
  default IO<T> repeat() {
    return repeat(1);
  }

  /**
   * Repeats this effect following {@code Schedule.recurs(times)}, keeping the value of
   * the effect as result (the schedule is zipped right with {@code Schedule.identity()}).
   *
   * @param times argument passed to {@code Schedule.recurs}
   * @return the repeating effect
   */
  default IO<T> repeat(int times) {
    return repeat(Schedule.<T>recurs(times).zipRight(Schedule.identity()));
  }

  /**
   * Shortcut for {@code repeat(delay, 1)}.
   *
   * @param delay delay passed to the schedule
   * @return the repeating effect
   */
  default IO<T> repeat(Duration delay) {
    return repeat(delay, 1);
  }

  /**
   * Repeats this effect following {@code Schedule.recursSpaced(delay, times)}, keeping
   * the value of the effect as result.
   *
   * @param delay delay passed to {@code Schedule.recursSpaced}
   * @param times count passed to {@code Schedule.recursSpaced}
   * @return the repeating effect
   */
  default IO<T> repeat(Duration delay, int times) {
    return repeat(Schedule.<T>recursSpaced(delay, times).zipRight(Schedule.identity()));
  }

  /**
   * Repeats this effect following the given schedule; an error is re-raised unchanged.
   *
   * @param schedule schedule driving the repetition
   * @param <B> type produced by the schedule
   * @return an effect producing the output of the schedule
   */
  default <B> IO<B> repeat(Schedule<T, B> schedule) {
    return repeatOrElse(schedule, (e, b) -> raiseError(e));
  }

  /**
   * Repeats this effect following the given schedule, delegating to {@code orElse} on error.
   * Built on {@link #repeatOrElseEither}, merging both sides of the resulting either.
   *
   * @param schedule schedule driving the repetition
   * @param orElse handler receiving the error and an optional schedule output
   * @param <B> type produced by the schedule and by the handler
   * @return the repeating effect
   */
  default <B> IO<B> repeatOrElse(
      Schedule<T, B> schedule,
      Function2<Throwable, Option<B>, IO<B>> orElse) {
    return repeatOrElseEither(schedule, orElse).map(IO::merge);
  }

  /**
   * Repeats this effect following the given schedule. The actual logic lives in
   * {@code Repeat#run()}, which is not part of this file.
   *
   * @param schedule schedule driving the repetition
   * @param orElse handler receiving the error and an optional schedule output
   * @param <B> type produced by the schedule
   * @param <C> type produced by the handler
   * @return an effect producing the handler result on the left or the schedule output on the right
   */
  default <B, C> IO<Either<C, B>> repeatOrElseEither(
      Schedule<T, B> schedule,
      Function2<Throwable, Option<B>, IO<C>> orElse) {
    return new Repeat<>(this, schedule, orElse).run();
  }

  /**
   * Shortcut for {@code retry(1)}.
   *
   * @return the retrying effect
   */
  default IO<T> retry() {
    return retry(1);
  }

  /**
   * Retries this effect following {@code Schedule.recurs(maxRetries)}.
   *
   * @param maxRetries argument passed to {@code Schedule.recurs}
   * @return the retrying effect
   */
  default IO<T> retry(int maxRetries) {
    return retry(Schedule.recurs(maxRetries));
  }

  /**
   * Shortcut for {@code retry(delay, 1)}.
   *
   * @param delay delay passed to the schedule
   * @return the retrying effect
   */
  default IO<T> retry(Duration delay) {
    return retry(delay, 1);
  }

  /**
   * Retries this effect following {@code Schedule.recursSpaced(delay, maxRetries)}.
   *
   * @param delay delay passed to {@code Schedule.recursSpaced}
   * @param maxRetries count passed to {@code Schedule.recursSpaced}
   * @return the retrying effect
   */
  default IO<T> retry(Duration delay, int maxRetries) {
    return retry(Schedule.<Throwable>recursSpaced(delay, maxRetries));
  }

  /**
   * Retries this effect following the given schedule; when the handler is reached the
   * error is re-raised unchanged.
   *
   * @param schedule schedule fed with the errors of this effect
   * @param <B> type produced by the schedule
   * @return the retrying effect
   */
  default <B> IO<T> retry(Schedule<Throwable, B> schedule) {
    return retryOrElse(schedule, (e, b) -> raiseError(e));
  }

  /**
   * Retries this effect following the given schedule, delegating to {@code orElse}
   * with the error and the schedule output. Built on {@link #retryOrElseEither},
   * merging both sides of the resulting either.
   *
   * @param schedule schedule fed with the errors of this effect
   * @param orElse handler producing the fallback effect
   * @param <B> type produced by the schedule
   * @return the retrying effect
   */
  default <B> IO<T> retryOrElse(
      Schedule<Throwable, B> schedule,
      Function2<Throwable, B, IO<T>> orElse) {
    return retryOrElseEither(schedule, orElse).map(IO::merge);
  }

  /**
   * Retries this effect following the given schedule. The actual logic lives in
   * {@code Retry#run()}, which is not part of this file.
   *
   * @param schedule schedule fed with the errors of this effect
   * @param orElse handler receiving the error and the schedule output
   * @param <B> type produced by the handler
   * @param <C> type produced by the schedule
   * @return an effect producing the handler result on the left or the value of this effect on the right
   */
  default <B, C> IO<Either<B, T>> retryOrElseEither(
      Schedule<Throwable, C> schedule,
      Function2<Throwable, C, IO<B>> orElse) {
    return new Retry<>(this, schedule, orElse).run();
  }

  /**
   * Narrows {@code IO<? extends T>} to {@code IO<T>} with an unchecked cast; the same
   * instance is returned.
   *
   * @param value effect to narrow
   * @param <T> target value type
   * @return {@code value} typed as {@code IO<T>}
   */
  @SuppressWarnings("unchecked")
  static <T> IO<T> narrowK(IO<? extends T> value) {
    return (IO<T>) value;
  }

  /**
   * Lifts an already computed value into an effect.
   *
   * @param value the value; {@code Pure} rejects {@code null} with a {@link NullPointerException}
   * @param <T> type of the value
   * @return an effect producing {@code value}
   */
  static <T> IO<T> pure(T value) {
    return new Pure<>(value);
  }

  /**
   * Same as {@link #race(Executor, IO, IO)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param fa first contender
   * @param fb second contender
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @return an effect producing the value of whichever contender finished first
   */
  static <A, B> IO<Either<A, B>> race(IO<? extends A> fa, IO<? extends B> fb) {
    return race(Future.DEFAULT_EXECUTOR, fa, fb);
  }

  /**
   * Runs both effects with {@link #racePair}, cancels the fiber of the one that did not
   * finish first and keeps only the value of the winner.
   *
   * @param executor executor passed to {@link #racePair}
   * @param fa first contender
   * @param fb second contender
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @return an effect producing left when {@code fa} finished first, right when {@code fb} did
   */
  static <A, B> IO<Either<A, B>> race(Executor executor, IO<? extends A> fa, IO<? extends B> fb) {
    return racePair(executor, fa, fb).flatMap(either -> either.fold(
        ta -> ta._2().cancel().map(x -> Either.left(ta._1())),
        tb -> tb._1().cancel().map(x -> Either.right(tb._2()))));
  }

  /**
   * Starts both effects, each on its own cancellable connection and after a hop onto
   * {@code executor}, and completes with the outcome of the first one to finish.
   *
   * <p>The result pairs the value of the winner with a {@code Fiber} for the other effect:
   * joining the fiber waits for the other effect, cancelling it cancels the connection
   * that effect runs on. If the first effect to finish failed, the failure is reported.
   *
   * <p>The cancel token of the returned effect cancels both connections.
   *
   * @param executor executor onto which both effects are forked
   * @param fa first contender
   * @param fb second contender
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @return left with {@code fa}'s value and {@code fb}'s fiber, or right with
   *         {@code fa}'s fiber and {@code fb}'s value
   */
  static <A, B> IO<Either<Tuple2<A, Fiber<B>>, Tuple2<Fiber<A>, B>>> racePair(
      Executor executor, IO<? extends A> fa, IO<? extends B> fb) {
    return cancellable(callback -> {

      IOConnection connection1 = IOConnection.cancellable();
      IOConnection connection2 = IOConnection.cancellable();

      Promise<A> promiseA = runAsync(forked(executor).andThen(fa), connection1);
      Promise<B> promiseB = runAsync(forked(executor).andThen(fb), connection2);

      // Both listeners invoke the callback; the consumer supplied by the interpreter
      // (promise::tryComplete in executeAsync) keeps only the first completion.
      promiseA.future().onComplete(result -> callback.accept(
          result.map(a -> Either.left(
              Tuple.of(a, new Fiber<>(fromFuture(promiseB.future()), exec(connection2::cancel)))))));
      // The fiber for A must cancel A's own connection (connection1); using connection2
      // here would leave the losing effect running when its fiber is cancelled.
      promiseB.future().onComplete(result -> callback.accept(
          result.map(b -> Either.right(
              Tuple.of(new Fiber<>(fromFuture(promiseA.future()), exec(connection1::cancel)), b)))));

      return exec(() -> {
        try {
          connection1.cancel();
        } finally {
          // Always attempt to cancel the second connection, even if the first cancel throws.
          connection2.cancel();
        }
      });
    });
  }

  /**
   * Creates an effect that fails with the given error.
   *
   * @param error the error; {@code Failure} rejects {@code null} with a {@link NullPointerException}
   * @param <T> nominal value type of the failed effect
   * @return the failed effect
   */
  static <T> IO<T> raiseError(Throwable error) {
    return new Failure<>(error);
  }

  /**
   * Creates an effect that first sleeps for {@code delay} and then runs {@code lazy} as a task.
   *
   * @param delay time to sleep before running the computation
   * @param lazy computation producing the value
   * @param <T> type of the value
   * @return the delayed effect
   */
  static <T> IO<T> delay(Duration delay, CheckedFunction0<? extends T> lazy) {
    return sleep(delay).andThen(task(lazy));
  }

  /**
   * Defers the construction of an effect: the supplier is stored in a {@code Suspend}
   * node and only invoked by the interpreter.
   *
   * @param lazy supplier of the effect to run
   * @param <T> type of the value
   * @return the suspended effect
   */
  static <T> IO<T> suspend(Supplier<IO<? extends T>> lazy) {
    return new Suspend<>(lazy);
  }

  /**
   * Lifts a plain function into one returning an effect: the result of {@code task}
   * is wrapped with {@link #pure}. The function itself runs when the lifted function is applied.
   *
   * @param task function to lift
   * @param <T> input type
   * @param <R> result type
   * @return the lifted function
   */
  static <T, R> Function1<T, IO<R>> lift(Function1<T, R> task) {
    return task.andThen(IO::pure);
  }

  /**
   * Lifts a function returning {@link Option} into one returning an effect,
   * converting the option with {@link #fromOption}.
   *
   * @param function function to lift
   * @param <A> input type
   * @param <B> value type
   * @return the lifted function
   */
  public static <A, B> Function1<A, IO<B>> liftOption(Function1<? super A, ? extends Option<? extends B>> function) {
    return value -> fromOption(function.apply(value));
  }

  /**
   * Lifts a function returning {@link Try} into one returning an effect,
   * converting the try with {@link #fromTry}.
   *
   * @param function function to lift
   * @param <A> input type
   * @param <B> value type
   * @return the lifted function
   */
  public static <A, B> Function1<A, IO<B>> liftTry(Function1<? super A, ? extends Try<? extends B>> function) {
    return value -> fromTry(function.apply(value));
  }

  /**
   * Lifts a function returning {@link Either} into one returning an effect,
   * converting the either with {@link #fromEither}.
   *
   * @param function function to lift
   * @param <A> input type
   * @param <B> value type
   * @return the lifted function
   */
  public static <A, B> Function1<A, IO<B>> liftEither(Function1<? super A, ? extends Either<Throwable, ? extends B>> function) {
    return value -> fromEither(function.apply(value));
  }

  /**
   * Converts an {@link Option} to an effect: a defined option yields its value, an
   * empty one yields a failure with a {@link NoSuchElementException}.
   *
   * @param task option to convert
   * @param <T> type of the value
   * @return the resulting effect
   */
  static <T> IO<T> fromOption(Option<? extends T> task) {
    return fromEither(toEither(task));
  }

  /**
   * Converts a {@link Try} to an effect by way of {@code Try.toEither()} and {@link #fromEither}.
   *
   * @param task try to convert
   * @param <T> type of the value
   * @return the resulting effect
   */
  static <T> IO<T> fromTry(Try<? extends T> task) {
    return fromEither(task.toEither());
  }

  /**
   * Converts an {@link Either} to an effect: left becomes {@link #raiseError}, right becomes {@link #pure}.
   *
   * @param task either to convert
   * @param <T> type of the value
   * @return the resulting effect
   */
  static <T> IO<T> fromEither(Either<Throwable, ? extends T> task) {
    return task.fold(IO::raiseError, IO::pure);
  }

  /**
   * Creates an asynchronous effect that completes with the outcome of the given future.
   * The completion listener is registered on the future when the effect is evaluated.
   *
   * @param promise future to wait for
   * @param <T> type of the value
   * @return the resulting effect
   */
  static <T> IO<T> fromFuture(Future<? extends T> promise) {
    // Explicitly typed so the method reference resolves against the wildcard consumer type.
    CheckedConsumer<Consumer<? super Try<? extends T>>> callback = promise::onComplete;
    return async(callback);
  }

  /**
   * Creates an effect from a {@link CompletableFuture} by converting it to a Vavr
   * {@link Future} and delegating to {@link #fromFuture}.
   *
   * @param promise completable future to wait for
   * @param <T> type of the value
   * @return the resulting effect
   */
  static <T> IO<T> fromCompletableFuture(CompletableFuture<? extends T> promise) {
    return fromFuture(Future.fromCompletableFuture(promise));
  }

  /**
   * Same as {@link #sleep(Executor, Duration)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param duration time to sleep
   * @return the sleeping effect
   */
  static IO<Unit> sleep(Duration duration) {
    return sleep(Future.DEFAULT_EXECUTOR, duration);
  }

  /**
   * Creates a cancellable effect that completes once the future returned by
   * {@code FutureModule.sleep(executor, duration)} completes.
   *
   * @param executor executor passed to {@code FutureModule.sleep}
   * @param duration time to sleep
   * @return the sleeping effect
   */
  static IO<Unit> sleep(Executor executor, Duration duration) {
    return cancellable(callback -> {
      // The callback always receives a success, whatever the outcome of the sleep future.
      Future<Unit> sleep = FutureModule.sleep(executor,duration)
        .onComplete(result -> callback.accept(Try.success(Unit.unit())));
      // Cancel token: cancels the sleep future, allowing interruption.
      return exec(() -> sleep.cancel(true));
    });
  }

  /**
   * Creates an effect that runs the given side effect and produces {@code Unit}.
   *
   * @param task side effect to run when the effect is evaluated
   * @return the resulting effect
   */
  static IO<Unit> exec(CheckedRunnable task) {
    return task(asSupplier(task));
  }

  /**
   * Creates an effect from a lazy computation. The producer is stored in a
   * {@code Delay} node and only invoked by the interpreter.
   *
   * @param producer computation producing the value
   * @param <T> type of the value
   * @return the resulting effect
   */
  static <T> IO<T> task(CheckedFunction0<? extends T> producer) {
    return new Delay<>(producer);
  }

  /**
   * Creates an asynchronous effect whose registration function never invokes the callback.
   *
   * @param <T> nominal value type
   * @return an effect that is never completed by its own registration
   */
  static <T> IO<T> never() {
    return async(callback -> {});
  }

  /**
   * Same as {@link #forked(Executor)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @return the forking effect
   */
  static IO<Unit> forked() {
    return forked(Future.DEFAULT_EXECUTOR);
  }

  /**
   * Creates an asynchronous effect that is completed with {@code Unit} from a task
   * submitted to {@code executor}.
   *
   * @param executor executor that completes the effect
   * @return the forking effect
   */
  static IO<Unit> forked(Executor executor) {
    return async(callback -> executor.execute(() -> callback.accept(Try.success(Unit.unit()))));
  }

  /**
   * Creates an asynchronous effect from a registration function that has no cancellation
   * logic: the consumer is adapted with {@code asFunction}, so the cancel token is {@link #unit()}.
   *
   * @param callback registration function; it receives the consumer to call with the outcome
   * @param <T> type of the value
   * @return the asynchronous effect
   */
  static <T> IO<T> async(CheckedConsumer<Consumer<? super Try<? extends T>>> callback) {
    return cancellable(asFunction(callback));
  }

  /**
   * Creates an asynchronous effect from a registration function that returns the effect
   * to use as cancel token. Only builds an {@code Async} node.
   *
   * @param callback registration function; it receives the consumer to call with the outcome
   * @param <T> type of the value
   * @return the asynchronous effect
   */
  static <T> IO<T> cancellable(Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback) {
    return new Async<>(callback);
  }

  /**
   * Same as {@link #memoize(Executor, Function1)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param function function to memoize
   * @param <A> argument type
   * @param <T> value type
   * @return an effect producing the memoized function
   */
  static <A, T> IO<Function1<A, IO<T>>> memoize(Function1<A, IO<T>> function) {
    return memoize(Future.DEFAULT_EXECUTOR, function);
  }

  /**
   * Builds a memoized version of {@code function}, keeping one {@link Promise} per
   * argument in a map held by a {@code Reference}.
   *
   * <p>For an argument not yet in the map, the effect returned by {@code function} is
   * started on {@code executor} and its promise stored; for a known argument the stored
   * promise is reused. In both cases the caller gets an effect waiting on that promise.
   *
   * @param executor executor on which the memoized effects are started
   * @param function function to memoize
   * @param <A> argument type
   * @param <T> value type
   * @return an effect producing the memoized function
   */
  static <A, T> IO<Function1<A, IO<T>>> memoize(Executor executor, Function1<A, IO<T>> function) {
    var ref = Reference.make(HashMap.<A, Promise<T>>empty());
    return ref.map(r -> {
      // NOTE(review): the effect is started from inside the function given to modify();
      // whether modify() may invoke it more than once depends on Reference — confirm.
      Function1<A, IO<IO<T>>> result = a -> r.modify(map -> map.get(a).fold(() -> {
        Promise<T> promise = Promise.make();
        function.apply(a).safeRunAsync(executor, promise::tryComplete);
        return Tuple.of(fromFuture(promise.future()), map.put(a, promise));
      }, promise -> Tuple.of(fromFuture(promise.future()), map)));
      // Flatten IO<IO<T>> into IO<T>.
      return result.andThen(io -> io.flatMap(identity()));
    });
  }

  /**
   * Returns the shared {@link #UNIT} effect.
   *
   * @return the shared unit effect
   */
  static IO<Unit> unit() {
    return UNIT;
  }

  /**
   * Acquires a resource, uses it and releases it afterwards.
   *
   * <p>If {@code acquire} fails, the failure is reported and neither {@code use} nor
   * {@code release} is called. Otherwise the effect returned by {@code use} is run and,
   * whatever its outcome, the effect returned by {@code release} is run next. The outcome
   * reported is the one of {@code use}; the outcome of the release is ignored.
   *
   * <p>{@code use} and {@code release} are invoked inside a suspended effect, so an
   * exception thrown by either function is captured by the interpreter like any other
   * failure: a throwing {@code use} still triggers the release and fails the result, and
   * a throwing {@code release} is ignored like a failed release effect. Without that
   * wrapping the exception would escape from the future callback and the result would
   * never be delivered.
   *
   * <p>All three steps run on the same cancellable connection, which the cancel token
   * of the returned effect cancels.
   *
   * @param acquire effect producing the resource
   * @param use function producing the effect that uses the resource
   * @param release function producing the effect that releases the resource
   * @param <T> type of the resource
   * @param <R> type of the result
   * @return an effect producing the outcome of {@code use}
   */
  static <T, R> IO<R> bracket(IO<? extends T> acquire,
      Function1<? super T, ? extends IO<? extends R>> use, Function1<? super T, IO<Unit>> release) {
    return cancellable(callback -> {

      IOConnection cancellable = IOConnection.cancellable();

      Promise<? extends T> promise = runAsync(acquire, cancellable);

      promise.future()
        .onFailure(error -> callback.accept(Try.failure(error)))
        // suspend(...) defers use.apply/release.apply into the interpreter loop, where
        // thrown exceptions are turned into failed promises instead of being lost.
        .onSuccess(resource -> runAsync(IO.<R>suspend(() -> use.apply(resource)), cancellable).future()
          .onComplete(result -> runAsync(IO.<Unit>suspend(() -> release.apply(resource)), cancellable).future()
            .onComplete(ignore -> callback.accept(result))
        ));

      return exec(cancellable::cancel);
    });
  }

  /**
   * Variant of {@code bracket} whose release step is a plain side effect. The consumer
   * is adapted with {@code asFunction}, so it runs when the release function is applied.
   *
   * @param acquire effect producing the resource
   * @param use function producing the effect that uses the resource
   * @param release side effect releasing the resource
   * @param <T> type of the resource
   * @param <R> type of the result
   * @return an effect producing the outcome of {@code use}
   */
  static <T, R> IO<R> bracket(IO<? extends T> acquire,
      Function1<? super T, ? extends IO<? extends R>> use, CheckedConsumer<? super T> release) {
    return bracket(acquire, use, asFunction(release));
  }

  /**
   * Variant of {@code bracket} for {@link AutoCloseable} resources, released with {@code close()}.
   *
   * @param acquire effect producing the resource
   * @param use function producing the effect that uses the resource
   * @param <T> type of the resource
   * @param <R> type of the result
   * @return an effect producing the outcome of {@code use}
   */
  static <T extends AutoCloseable, R> IO<R> bracket(IO<? extends T> acquire,
      Function1<? super T, ? extends IO<? extends R>> use) {
    return bracket(acquire, use, AutoCloseable::close);
  }

  /**
   * Chains all the given effects one after another, in iteration order, discarding
   * their values.
   *
   * @param sequence effects to run
   * @return an effect that runs every element and produces {@code Unit}
   */
  static IO<Unit> sequence(Seq<IO<?>> sequence) {
    // Start from unit() and append each element with andThen, left to right.
    IO<?> chain = unit();
    for (IO<?> step : sequence) {
      chain = chain.andThen(step);
    }
    return chain.andThen(unit());
  }

  /**
   * Same as {@link #traverse(Executor, Seq)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param sequence effects to combine
   * @param <A> type of the values
   * @return an effect producing the values in the order of {@code sequence}
   */
  static <A> IO<Seq<A>> traverse(Seq<IO<A>> sequence) {
    return traverse(Future.DEFAULT_EXECUTOR, sequence);
  }

  /**
   * Combines the given effects into a single one producing all the values. Starting
   * from an empty list, each element is combined with the accumulated effect through
   * {@link #parMap2(Executor, IO, IO, Function2)}, appending its value.
   *
   * @param executor executor passed to {@code parMap2}
   * @param sequence effects to combine
   * @param <A> type of the values
   * @return an effect producing the values in the order of {@code sequence}
   */
  static <A> IO<Seq<A>> traverse(Executor executor, Seq<IO<A>> sequence) {
    IO<Seq<A>> collected = pure(List.empty());
    for (IO<A> item : sequence) {
      collected = parMap2(executor, collected, item, Seq::append);
    }
    return collected;
  }

  /**
   * Same as {@link #parMap2(Executor, IO, IO, Function2)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param fa first effect
   * @param fb second effect
   * @param mapper function combining both values
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @param <C> type of the combined value
   * @return an effect producing the combined value
   */
  static <A, B, C> IO<C> parMap2(IO<? extends A> fa, IO<? extends B> fb,
                              Function2<? super A, ? super B, ? extends C> mapper) {
    return parMap2(Future.DEFAULT_EXECUTOR, fa, fb, mapper);
  }

  /**
   * Starts both effects, each on its own cancellable connection and after a hop onto
   * {@code executor}, and combines their values with {@code mapper} once both are done.
   *
   * <p>The cancel token of the returned effect cancels both connections.
   *
   * @param executor executor onto which both effects are forked
   * @param fa first effect
   * @param fb second effect
   * @param mapper function combining both values
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @param <C> type of the combined value
   * @return an effect producing the combined value, or a failure if either effect failed
   */
  static <A, B, C> IO<C> parMap2(Executor executor, IO<? extends A> fa, IO<? extends B> fb,
                              Function2<? super A, ? super B, ? extends C> mapper) {
    return cancellable(callback -> {

      IOConnection connection1 = IOConnection.cancellable();
      IOConnection connection2 = IOConnection.cancellable();

      // Both evaluations are started before any listener is registered.
      Promise<A> promiseA = runAsync(forked(executor).andThen(fa), connection1);
      Promise<B> promiseB = runAsync(forked(executor).andThen(fb), connection2);

      // The listener on B is registered when A completes, so the callback fires after both are done.
      promiseA.future().onComplete(a -> promiseB.future().onComplete(b -> callback.accept(map2(a, b, mapper))));

      return exec(() -> {
        try {
          connection1.cancel();
        } finally {
          // Always attempt to cancel the second connection, even if the first cancel throws.
          connection2.cancel();
        }
      });
    });
  }

  /**
   * Same as {@link #tuple(Executor, IO, IO)} using {@code Future.DEFAULT_EXECUTOR}.
   *
   * @param fa first effect
   * @param fb second effect
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @return an effect producing both values as a tuple
   */
  static <A, B> IO<Tuple2<A, B>> tuple(IO<? extends A> fa, IO<? extends B> fb) {
    return tuple(Future.DEFAULT_EXECUTOR, fa, fb);
  }

  /**
   * Combines both effects with {@link #parMap2(Executor, IO, IO, Function2)}, pairing their values.
   *
   * @param executor executor passed to {@code parMap2}
   * @param fa first effect
   * @param fb second effect
   * @param <A> type produced by {@code fa}
   * @param <B> type produced by {@code fb}
   * @return an effect producing both values as a tuple
   */
  static <A, B> IO<Tuple2<A, B>> tuple(Executor executor, IO<? extends A> fa, IO<? extends B> fb) {
    return parMap2(executor, fa, fb, Tuple::of);
  }

  /**
   * Entry point of the interpreter: evaluates {@code current} on the given connection
   * with a fresh {@code CallStack} and a fresh {@link Promise}.
   *
   * @param current effect to evaluate
   * @param connection connection carrying the cancellation state
   * @param <T> type of the value
   * @return the promise completed with the outcome of the evaluation
   */
  private static <T> Promise<T> runAsync(IO<T> current, IOConnection connection) {
    return runAsync(current, connection, new CallStack<>(), Promise.make());
  }

  /**
   * Combines two outcomes: the mapper is applied only when both are successes;
   * otherwise the result is a failure, with {@code a} being inspected before {@code b}.
   *
   * @param a first outcome
   * @param b second outcome
   * @param mapper function combining both values
   * @return the combined outcome
   */
  private static <A, B, C> Try<? extends C> map2(Try<A> a, Try<B> b, Function2<? super A, ? super B, ? extends C> mapper) {
    return a.flatMap(x -> b.map(y -> mapper.apply(x, y)));
  }

  /**
   * Converts an {@link Option} to an {@link Either}.
   *
   * @param task option to convert
   * @return right with the value when defined, left with a new
   *         {@link NoSuchElementException} when empty
   */
  private static <T> Either<Throwable, T> toEither(Option<? extends T> task) {
    if (task.isDefined()) {
      return Either.right(task.get());
    }
    return Either.left(new NoSuchElementException());
  }

  /**
   * Extracts the value of an {@link Either} whose two sides share the same type.
   *
   * @param either either to collapse
   * @return the left value for a left, the right value for a right
   */
  private static <A> A merge(Either<A, A> either) {
    return either.isRight() ? either.get() : either.getLeft();
  }

  /**
   * Builds a {@link PartialFunction} from a predicate deciding where it is defined
   * and a function computing its result.
   *
   * @param matcher predicate backing {@code isDefinedAt}
   * @param function function backing {@code apply}
   * @return the partial function
   */
  @SuppressWarnings("serial")
  private static <A, B> PartialFunction<A, B> partialFunction(Predicate<? super A> matcher, Function1<? super A, ? extends B> function) {
    // Anonymous class rather than a lambda: PartialFunction declares two abstract methods.
    return new PartialFunction<>() {

      @Override
      public boolean isDefinedAt(A value) {
        return matcher.test(value);
      }

      @Override
      public B apply(A t) {
        // apply() does not consult the matcher itself.
        return function.apply(t);
      }
    };
  }

  /**
   * Adapts a {@link CheckedRunnable} to a {@link CheckedFunction0} that runs it and returns {@code Unit}.
   *
   * @param task side effect to adapt
   * @return the adapted function
   */
  private static CheckedFunction0<Unit> asSupplier(CheckedRunnable task) {
    return () -> {
      task.unchecked().run();
      return Unit.unit();
    };
  }

  /**
   * Adapts a {@link CheckedConsumer} to a function returning {@link #unit()}.
   *
   * <p>The consumer runs eagerly, at the moment the returned function is applied;
   * the effect returned afterwards is just the shared unit effect.
   *
   * @param release consumer to adapt
   * @return the adapted function
   */
  private static <T> Function1<T, IO<Unit>> asFunction(CheckedConsumer<? super T> release) {
    return t -> {
      release.unchecked().accept(t);
      return unit();
    };
  }

  /**
   * Interpreter loop: evaluates {@code current} step by step and completes {@code promise}
   * with the outcome.
   *
   * <p>Each iteration first reduces the current node with {@code unwrap} and then handles
   * the three node kinds it can return: a {@code Pure} completes the promise, an
   * {@code Async} is handed to {@code executeAsync}, and a {@code FlatMapped} is
   * rewritten so the loop can continue without recursion (except when its source is
   * asynchronous, in which case the loop is resumed from the completion of that source).
   *
   * <p>Any {@link Throwable} thrown during an iteration is offered to
   * {@code stack.tryHandle}; if a handler produces an effect the loop continues with it,
   * otherwise the promise is failed.
   *
   * @param current effect to evaluate
   * @param connection connection carrying the cancellation state
   * @param stack handler stack shared across the evaluation
   * @param promise promise to complete with the outcome
   * @return {@code promise}
   */
  @SuppressWarnings("unchecked")
  private static <T, U, V> Promise<T> runAsync(IO<T> current, IOConnection connection, CallStack<T> stack, Promise<T> promise) {
    while (true) {
      try {
        current = unwrap(current, stack, identity());

        if (current instanceof Pure<T> pure) {
          return promise.success(pure.value);
        }

        if (current instanceof Async<T> async) {
          return executeAsync(async, connection, promise);
        }

        if (current instanceof FlatMapped) {
          // NOTE(review): presumably opens a new handler scope in the CallStack — confirm.
          stack.push();

          var flatMapped = (FlatMapped<U, T>) current;
          // Handlers found while unwrapping the source are continued with flatMapped.next.
          IO<U> source = narrowK(unwrap(flatMapped.current, stack, u -> u.flatMap(flatMapped.next)));

          if (source instanceof Async<U> async) {
            Promise<U> nextPromise = Promise.make();

            // Resume the interpreter once the asynchronous source completes.
            // NOTE(review): a failure of the source goes straight to promise::failure
            // without consulting stack.tryHandle — confirm that handlers on the stack
            // are meant to be bypassed here.
            // NOTE(review): Vavr documents that exceptions thrown by an andThen action
            // are not propagated; if flatMapped.next.apply(u) throws, `promise` may
            // never complete — verify.
            nextPromise.future().andThen(tryU -> tryU.onFailure(promise::failure)
                .onSuccess(u -> runAsync(narrowK(flatMapped.next.apply(u)), connection, stack, promise)));

            executeAsync(async, connection, nextPromise);

            return promise;
          }

          if (source instanceof Pure<U> pure) {
            // Source already has a value: apply the continuation and keep looping.
            Function1<? super U, IO<T>> andThen = flatMapped.next.andThen(IO::narrowK);
            current = andThen.apply(pure.value);
          } else if (source instanceof FlatMapped) {
            // Re-associate nested flatMaps to the right so the next iteration sees the innermost source.
            FlatMapped<V, U> flatMapped2 = (FlatMapped<V, U>) source;
            current = flatMapped2.current.flatMap(a -> flatMapped2.next.apply(a).flatMap(flatMapped.next));
          }
        } else {
          stack.pop();
        }
      } catch (Throwable error) {
        Option<IO<T>> result = stack.tryHandle(error);

        if (result.isDefined()) {
          current = result.get();
        } else {
          return promise.failure(error);
        }
      }
    }
  }

  /**
   * Reduces {@code current} until it is a {@code Pure}, {@code FlatMapped} or
   * {@code Async} node, which is then returned.
   *
   * <p>Along the way: a {@code Failure} is thrown through {@code stack.sneakyThrow};
   * a {@code Recover} registers its handler on the stack (composed with {@code next})
   * and continues with the wrapped effect; a {@code Suspend} is replaced by the effect
   * its supplier returns; a {@code Delay} is executed and its value wrapped with {@link #pure}.
   *
   * @param current effect to reduce
   * @param stack handler stack receiving the recover handlers
   * @param next continuation applied to the effect produced by a recover handler
   * @return the reduced effect
   * @throws IllegalStateException if the node is of none of the known kinds
   */
  private static <T, U> IO<T> unwrap(IO<T> current, CallStack<U> stack, Function1<IO<? extends T>, IO<? extends U>> next) {
    while (true) {
      if (current instanceof Failure<T> failure) {
        return stack.sneakyThrow(failure.error);
      } else if (current instanceof Recover<T> recover) {
        stack.add(partialFunction(recover.mapper::isDefinedAt, recover.mapper.andThen(next)));
        current = recover.current;
      } else if (current instanceof Suspend<T> suspend) {
        current = narrowK(suspend.lazy.get());
      } else if (current instanceof Delay<T> delay) {
        // unchecked(): exceptions of the task propagate to the caller's catch block.
        return pure(delay.task.unchecked().get());
      } else if (current instanceof Pure) {
        return current;
      } else if (current instanceof FlatMapped) {
        return current;
      } else if (current instanceof Async) {
        return current;
      } else {
        throw new IllegalStateException();
      }
    }
  }

  /**
   * Runs the registration function of an {@code Async} node, wiring its outcome to
   * {@code promise} and its cancel token to {@code connection}.
   *
   * @param current asynchronous node to execute
   * @param connection connection carrying the cancellation state
   * @param promise promise completed by the callback handed to the node
   * @return {@code promise}
   */
  private static <T> Promise<T> executeAsync(Async<T> current, IOConnection connection, Promise<T> promise) {
    // If the connection state does not allow starting, fail with CancellationException
    // without invoking the registration function.
    // NOTE(review): exact meaning of startingNow/isRunnable is defined by StateIO — confirm.
    if (connection.isCancellable() && !connection.updateState(StateIO::startingNow).isRunnable()) {
      return promise.complete(Try.failure(new CancellationException()));
    }

    // tryComplete: only the first outcome reported by the node is kept.
    connection.setCancelToken(current.callback.apply(promise::tryComplete));

    // Once completed, replace the cancel token with the no-op UNIT effect.
    promise.future().andThen(x -> connection.setCancelToken(UNIT));

    // A cancellation requested while starting is carried out now.
    if (connection.isCancellable() && connection.updateState(StateIO::notStartingNow).isCancellingNow()) {
      connection.cancelNow();
    }

    return promise;
  }

  /**
   * Node holding an already computed value.
   *
   * @param value the value, never {@code null}
   */
  record Pure<T>(T value) implements IO<T> {

    /** Rejects a {@code null} value with a {@link NullPointerException}. */
    public Pure {
      Objects.requireNonNull(value);
    }

    @Override
    public String toString() {
      return "Pure(" + value + ")";
    }
  }

  /**
   * Node representing an effect that fails with {@code error}.
   *
   * @param error the error, never {@code null}
   */
  record Failure<T>(Throwable error) implements IO<T> {

    /** Rejects a {@code null} error with a {@link NullPointerException}. */
    public Failure {
      Objects.requireNonNull(error);
    }

    @Override
    public String toString() {
      return "Failure(" + error + ")";
    }
  }

  /**
   * Node representing {@code current} followed by the effect {@code next} builds from its value.
   *
   * @param current source effect, never {@code null}
   * @param next continuation, never {@code null}
   */
  record FlatMapped<T, R>(IO<? extends T> current, Function1<? super T, ? extends IO<? extends R>> next)  implements IO<R> {

    /** Rejects {@code null} components with a {@link NullPointerException}. */
    public FlatMapped {
      Objects.requireNonNull(current);
      Objects.requireNonNull(next);
    }

    @Override
    public String toString() {
      // The continuation has no useful textual form, hence the "?".
      return "FlatMapped(" + current + ", ?)";
    }
  }

  /**
   * Node holding a lazy computation that the interpreter executes.
   *
   * @param task the computation, never {@code null}
   */
  record Delay<T>(CheckedFunction0<? extends T> task) implements IO<T> {

    /** Rejects a {@code null} task with a {@link NullPointerException}. */
    public Delay {
      Objects.requireNonNull(task);
    }

    @Override
    public String toString() {
      return "Delay(?)";
    }
  }

  /**
   * Node representing an asynchronous effect.
   *
   * @param callback registration function, never {@code null}; it receives the consumer
   *        to call with the outcome and returns the effect used as cancel token
   */
  record Async<T>(Function1<Consumer<? super Try<? extends T>>, IO<Unit>> callback) implements IO<T> {

    /** Rejects a {@code null} callback with a {@link NullPointerException}. */
    public Async {
      Objects.requireNonNull(callback);
    }

    @Override
    public String toString() {
      return "Async(?)";
    }
  }

  /**
   * Node holding a supplier of the effect to run; the interpreter invokes the supplier.
   *
   * @param lazy supplier of the effect, never {@code null}
   */
  record Suspend<T>(Supplier<? extends IO<? extends T>> lazy) implements IO<T> {

    /** Rejects a {@code null} supplier with a {@link NullPointerException}. */
    public Suspend {
      Objects.requireNonNull(lazy);
    }

    @Override
    public String toString() {
      return "Suspend(?)";
    }
  }

  /**
   * Node representing {@code current} with an error handler attached.
   *
   * @param current wrapped effect, never {@code null}
   * @param mapper partial function producing the effect to continue with for the
   *        errors it is defined at, never {@code null}
   */
  record Recover<T>(IO<T> current, PartialFunction<? super Throwable, ? extends IO<? extends T>> mapper) implements IO<T> {

    /** Rejects {@code null} components with a {@link NullPointerException}. */
    public Recover {
      Objects.requireNonNull(current);
      Objects.requireNonNull(mapper);
    }

    @Override
    public String toString() {
      // The handler has no useful textual form, hence the "?".
      return "Recover(" + current + ", ?)";
    }
  }
}

/**
 * Package-private helpers to build delayed Vavr futures on top of a shared scheduler.
 */
interface FutureModule {

  /** Scheduler used only to time the hand-off of delayed tasks; created with a core pool size of 0. */
  ScheduledExecutorService SCHEDULER = Executors.newScheduledThreadPool(0);

  /**
   * Creates a future that completes with {@code Unit} through a delayed executor.
   *
   * @param executor executor that runs the completing task once the delay has elapsed
   * @param delay time to wait before the task is handed to {@code executor}
   * @return the future completed after the delay
   */
  static Future<Unit> sleep(Executor executor, Duration delay) {
    return Future.fromCompletableFuture(executor, CompletableFuture.supplyAsync(Unit::unit, delayedExecutor(delay, executor)));
  }

  /**
   * Wraps {@code executor} so that every submitted task is first parked on
   * {@link #SCHEDULER} for {@code delay} and only then passed to {@code executor}.
   *
   * @param delay time to wait; converted with {@code toMillis()}, so any sub-millisecond part is dropped
   * @param executor executor that finally runs the task
   * @return the delaying executor
   */
  static Executor delayedExecutor(Duration delay, Executor executor) {
    return task -> SCHEDULER.schedule(() -> executor.execute(task), delay.toMillis(), TimeUnit.MILLISECONDS);
  }
}