/*
 * Decompiled with CFR 0.152.
 */
package ml.dmlc.xgboost4j.scala.spark;

import java.io.Serializable;
import java.util.Map;
import ml.dmlc.xgboost4j.java.Communicator;
import ml.dmlc.xgboost4j.java.ITracker;
import ml.dmlc.xgboost4j.java.RabitTracker;
import ml.dmlc.xgboost4j.java.XGBoostError;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.DMatrix;
import ml.dmlc.xgboost4j.scala.EvalTrait;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager;
import ml.dmlc.xgboost4j.scala.ExternalCheckpointParams;
import ml.dmlc.xgboost4j.scala.ObjectiveTrait;
import ml.dmlc.xgboost4j.scala.spark.TrackerConf;
import ml.dmlc.xgboost4j.scala.spark.Watches;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParams;
import ml.dmlc.xgboost4j.scala.spark.XGBoostExecutionParamsFactory;
import ml.dmlc.xgboost4j.scala.spark.XGBoostStageLevel;
import ml.dmlc.xgboost4j.scala.spark.package$;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContext$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.rdd.RDDBarrier;
import org.apache.spark.resource.ResourceInformation;
import scala.;
import scala.$less$colon$less$;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.None$;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Some;
import scala.Tuple2;
import scala.collection.ArrayOps$;
import scala.collection.IterableOnce;
import scala.collection.IterableOnceOps;
import scala.collection.Iterator;
import scala.collection.StringOps$;
import scala.collection.immutable.Seq;
import scala.math.Ordering;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ModuleSerializationProxy;
import scala.runtime.ScalaRunTime$;

/**
 * Decompiled module class of the Scala singleton {@code object XGBoost} in package
 * {@code ml.dmlc.xgboost4j.scala.spark}; {@link #MODULE$} holds the single instance.
 *
 * <p>Drives distributed XGBoost training on Spark. {@link #trainDistributed} does the following:
 * <ul>
 *   <li>starts a {@link RabitTracker};</li>
 *   <li>runs one training task per partition of a barrier RDD;</li>
 *   <li>returns the booster and evaluation metrics produced by partition 0.</li>
 * </ul>
 * The stage-level scheduling helpers come from the {@link XGBoostStageLevel} trait through the
 * forwarder methods in this class.
 *
 * <p>NOTE(review): this is CFR decompiler output, not hand-written source. For example, the
 * {@code import scala.;} line and the {@code (.less.colon.less)} casts are not valid Java. Treat it
 * as a reading aid; real changes belong in the original Scala source.
 */
public final class XGBoost$
implements XGBoostStageLevel {
    /** The singleton instance of the Scala object. */
    public static final XGBoost$ MODULE$ = new XGBoost$();
    /** Logger named "XGBoostSpark"; this object's own methods reach it through {@link #logger()}. */
    private static final Log logger;
    /** Backing field for the logger declared by the {@link XGBoostStageLevel} trait. */
    private static Log ml$dmlc$xgboost4j$scala$spark$XGBoostStageLevel$$logger;

    // Static initializers run in textual order, so MODULE$ is already assigned when the trait
    // initializer receives it. The initializer presumably sets the trait logger through the
    // ..._setter_... method in this class (confirm in XGBoostStageLevel).
    static {
        XGBoostStageLevel.$init$(MODULE$);
        logger = LogFactory.getLog((String)"XGBoostSpark");
    }

    /** Trait forwarder: delegates to the static implementation generated for {@link XGBoostStageLevel}. */
    @Override
    public boolean isStandaloneOrLocalCluster(SparkConf conf) {
        return XGBoostStageLevel.isStandaloneOrLocalCluster$(this, conf);
    }

    /** Trait forwarder: delegates to the static implementation generated for {@link XGBoostStageLevel}. */
    @Override
    public boolean skipStageLevelScheduling(String sparkVersion, boolean runOnGpu, SparkConf conf) {
        return XGBoostStageLevel.skipStageLevelScheduling$(this, sparkVersion, runOnGpu, conf);
    }

    /**
     * Trait forwarder: delegates to the static implementation generated for {@link XGBoostStageLevel}.
     * {@link #trainDistributed} applies it to the barrier RDD of (booster, metrics) results.
     */
    @Override
    public RDD<Tuple2<Booster, scala.collection.immutable.Map<String, float[]>>> tryStageLevelScheduling(SparkContext sc, XGBoostExecutionParams xgbExecParams, RDD<Tuple2<Booster, scala.collection.immutable.Map<String, float[]>>> rdd) {
        return XGBoostStageLevel.tryStageLevelScheduling$(this, sc, xgbExecParams, rdd);
    }

    /** Getter for the logger field declared by the {@link XGBoostStageLevel} trait. */
    @Override
    public Log ml$dmlc$xgboost4j$scala$spark$XGBoostStageLevel$$logger() {
        return ml$dmlc$xgboost4j$scala$spark$XGBoostStageLevel$$logger;
    }

    /** Setter for the trait's logger field; presumably only called from {@code XGBoostStageLevel.$init$}. */
    @Override
    public final void ml$dmlc$xgboost4j$scala$spark$XGBoostStageLevel$_setter_$ml$dmlc$xgboost4j$scala$spark$XGBoostStageLevel$$logger_$eq(Log x$1) {
        ml$dmlc$xgboost4j$scala$spark$XGBoostStageLevel$$logger = x$1;
    }

    /** Returns this object's "XGBoostSpark" logger. */
    private Log logger() {
        return logger;
    }

    /**
     * Returns the GPU address that Spark assigned to the current task.
     *
     * <p>Reads the {@code "gpu"} entry of the task's resources. If several addresses are assigned,
     * a warning is logged and only the first one is used.
     *
     * <p>NOTE(review): two failure cases are not handled explicitly:
     * <ul>
     *   <li>a non-numeric address fails inside {@code StringOps.toInt} (presumably with
     *       NumberFormatException);</li>
     *   <li>an empty address array would fail in {@code head}.</li>
     * </ul>
     *
     * @return the first GPU address, parsed as an {@code int}
     * @throws RuntimeException if there is no current {@link TaskContext} (called outside a Spark
     *     task), or if no {@code "gpu"} resource was allocated to the task
     */
    public int getGPUAddrFromResources() {
        TaskContext tc = TaskContext$.MODULE$.get();
        if (tc == null) {
            throw new RuntimeException("Something wrong for task context");
        }
        scala.collection.immutable.Map resources = tc.resources();
        if (resources.contains((Object)"gpu")) {
            String[] addrs = ((ResourceInformation)resources.apply((Object)"gpu")).addresses();
            if (ArrayOps$.MODULE$.size$extension(Predef$.MODULE$.refArrayOps((Object[])addrs)) > 1) {
                this.logger().warn((Object)"XGBoost only supports 1 gpu per worker");
            }
            return StringOps$.MODULE$.toInt$extension(Predef$.MODULE$.augmentString((String)ArrayOps$.MODULE$.head$extension(Predef$.MODULE$.refArrayOps((Object[])addrs))));
        }
        throw new RuntimeException("gpu is not allocated by spark, please check if gpu scheduling is enabled");
    }

    /**
     * Builds the {@link Watches} for this partition and checks that a training matrix exists.
     *
     * @param buildWatchesFun supplier that constructs the partition's watches
     * @return the constructed watches
     * @throws XGBoostError if the watches map has no {@code "train"} entry. This is reported as an
     *     empty partition in the training data, and the message includes the partition ID.
     */
    private Watches buildWatchesAndCheck(Function0<Watches> buildWatchesFun) {
        Watches watches = (Watches)buildWatchesFun.apply();
        if (!watches.toMap().contains((Object)"train")) {
            throw new XGBoostError(new StringBuilder(64).append("detected an empty partition in the training data, partition ID:").append(" ").append(TaskContext$.MODULE$.getPartitionId()).toString());
        }
        return watches;
    }

    /**
     * Trains a booster on the current partition as one worker of a distributed XGBoost job.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Sets {@code DMLC_TASK_ID} in {@code rabitEnv} to the partition ID. This mutates the
     *       caller's map.</li>
     *   <li>Initializes the {@link Communicator} and builds the partition's watches.</li>
     *   <li>Allocates a {@code [watches.size()][numRounds]} metrics array.</li>
     *   <li>On GPU, adds {@code device=cuda:<id>} to the params. The id is 0 in local mode and
     *       otherwise comes from the task's Spark resources.</li>
     *   <li>Trains the booster.</li>
     * </ol>
     * Training uses {@code trainAndSaveCheckpoint} only on partition 0 when a checkpoint parameter
     * is configured. Every other case uses plain {@code train}. The communicator is shut down and
     * the watches are deleted in a {@code finally} block.
     *
     * @param buildWatches supplier for this partition's watches
     * @param xgbExecutionParam resolved execution parameters (rounds, early stopping, GPU, checkpoint)
     * @param rabitEnv tracker worker arguments passed to {@link Communicator#init}; modified in place
     * @param obj objective passed through to training
     * @param eval evaluation passed through to training
     * @param prevBooster booster to continue training from; {@link #trainDistributed} passes
     *     {@code null} when no checkpoint is configured
     * @return on partition 0, a single-element iterator of (booster, watch name -> metric row);
     *     on every other partition, an empty iterator
     * @throws XGBoostError rethrown after being logged if initialization or training fails
     */
    private Iterator<Tuple2<Booster, scala.collection.immutable.Map<String, float[]>>> buildDistributedBooster(Function0<Watches> buildWatches, XGBoostExecutionParams xgbExecutionParam, Map<String, Object> rabitEnv, ObjectiveTrait obj, EvalTrait eval, Booster prevBooster) {
        Iterator iterator;
        Watches watches = null;
        String taskId = Integer.toString(TaskContext$.MODULE$.getPartitionId());
        String attempt = Integer.toString(TaskContext$.MODULE$.get().attemptNumber());
        rabitEnv.put("DMLC_TASK_ID", taskId);
        int numRounds = xgbExecutionParam.numRounds();
        // Only partition 0 writes checkpoints, and only when a checkpoint parameter is configured.
        boolean makeCheckpoint = xgbExecutionParam.checkpointParam().isDefined() && StringOps$.MODULE$.toInt$extension(Predef$.MODULE$.augmentString(taskId)) == 0;
        try {
            try {
                Booster booster;
                Communicator.init(rabitEnv);
                watches = this.buildWatchesAndCheck(buildWatches);
                int numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds();
                // One row per watch, each a float[numRounds]. Passed to the training call below,
                // presumably to be filled with per-round evaluation results.
                float[][] metrics = (float[][])Array$.MODULE$.tabulate(watches.size(), (Function1 & Serializable)x$8 -> XGBoost$.$anonfun$buildDistributedBooster$1(numRounds, BoxesRunTime.unboxToInt((Object)x$8)), ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(Float.TYPE)));
                Option<ExternalCheckpointParams> externalCheckpointParams = xgbExecutionParam.checkpointParam();
                scala.collection.immutable.Map params = xgbExecutionParam.toMap();
                if (xgbExecutionParam.runOnGpu()) {
                    // Local mode always uses device 0; otherwise use the GPU Spark assigned to this task.
                    int gpuId = xgbExecutionParam.isLocal() ? 0 : this.getGPUAddrFromResources();
                    this.logger().info((Object)new StringBuilder(31).append("Leveraging gpu device ").append(gpuId).append(" to train").toString());
                    params = (scala.collection.immutable.Map)params.$plus(Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"device"), (Object)new StringBuilder(5).append("cuda:").append(gpuId).toString()));
                }
                Booster booster2 = booster = makeCheckpoint ? ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.trainAndSaveCheckpoint((DMatrix)watches.toMap().apply((Object)"train"), params, numRounds, watches.toMap(), metrics, obj, eval, numEarlyStoppingRounds, prevBooster, externalCheckpointParams) : ml.dmlc.xgboost4j.scala.XGBoost$.MODULE$.train((DMatrix)watches.toMap().apply((Object)"train"), params, numRounds, watches.toMap(), metrics, obj, eval, numEarlyStoppingRounds, prevBooster);
                // Only partition 0 emits a result. The zip pairs watch names with metric rows in the
                // map's key iteration order. NOTE(review): this assumes that order matches the row
                // order the training call used for metrics — confirm.
                iterator = TaskContext$.MODULE$.get().partitionId() == 0 ? scala.package$.MODULE$.Iterator().apply((Seq)ScalaRunTime$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)booster), (Object)((IterableOnceOps)watches.toMap().keys().zip((IterableOnce)Predef$.MODULE$.wrapRefArray((Object[])metrics))).toMap((.less.colon.less)$less$colon$less$.MODULE$.refl()))})) : scala.package$.MODULE$.Iterator().empty();
            }
            catch (XGBoostError xgbException) {
                // NOTE(review): `attempt` is Spark's TaskContext.attemptNumber(), which is 0 on the
                // first attempt. The message can therefore read "has failed 0 times".
                this.logger().error((Object)new StringBuilder(43).append("XGBooster worker ").append(taskId).append(" has failed ").append(attempt).append(" times due to ").toString(), (Throwable)xgbException);
                throw xgbException;
            }
        }
        finally {
            // Runs even if Communicator.init threw.
            // NOTE(review): confirm shutdown is safe without a successful init.
            Communicator.shutdown();
            if (watches != null) {
                watches.delete();
            }
        }
        return iterator;
    }

    /**
     * Runs {@code block} with a started {@link RabitTracker} and always stops the tracker afterwards.
     *
     * @param nWorkers number of workers the tracker is created for
     * @param conf source of the tracker's host IP, port and timeout
     * @param block work to run while the tracker is up; it receives the tracker
     * @return the value returned by {@code block}
     * @throws IllegalArgumentException via {@code Predef.require} if {@code tracker.start()} returns
     *     false
     */
    private <T> T withTracker(int nWorkers, TrackerConf conf, Function1<ITracker, T> block) {
        Object object;
        RabitTracker tracker = new RabitTracker(nWorkers, conf.hostIp(), conf.port(), conf.timeout());
        Predef$.MODULE$.require(tracker.start(), (Function0 & Serializable)() -> "FAULT: Failed to start tracker");
        try {
            object = block.apply((Object)tracker);
        }
        finally {
            tracker.stop();
        }
        return (T)object;
    }

    /**
     * Trains an XGBoost model across the Spark cluster and returns the result from partition 0.
     *
     * <p>Flow:
     * <ol>
     *   <li>Logs the XGBoost version and parameters, then resolves {@link XGBoostExecutionParams}.</li>
     *   <li>If a checkpoint is configured, calls {@code cleanUpHigherVersions(numRounds)} and
     *       {@code loadCheckpointAsScalaBooster()} on the checkpoint manager. The loaded booster is
     *       the starting booster; it is {@code null} when no checkpoint is configured.</li>
     *   <li>Builds the training RDD, plus an optional cached RDD, via {@code buildTrainingData}.</li>
     *   <li>Runs a barrier {@code mapPartitions} inside {@link #withTracker}. Each partition trains
     *       on the first {@code Watches} supplier it receives.</li>
     *   <li>Applies stage-level scheduling to the barrier RDD, repartitions it to one partition,
     *       collects, and takes the first (booster, metrics) pair.</li>
     *   <li>On success, calls {@code cleanPath()} on the checkpoint manager unless
     *       {@code skipCleanCheckpoint} is set.</li>
     * </ol>
     * The optional cached RDD is unpersisted whether training succeeds or fails.
     *
     * @param sc active Spark context
     * @param buildTrainingData maps the execution params to a pair: an RDD of watch suppliers to
     *     train on, and an optional RDD that is unpersisted afterwards
     * @param params user-supplied XGBoost parameters
     * @return (trained booster, watch name -> per-round metric row)
     * @throws XGBoostError declared. Any Throwable from the training phase is logged as
     *     "the job was aborted due to " and rethrown.
     */
    public Tuple2<Booster, scala.collection.immutable.Map<String, float[]>> trainDistributed(SparkContext sc, Function1<XGBoostExecutionParams, Tuple2<RDD<Function0<Watches>>, Option<RDD<?>>>> buildTrainingData, scala.collection.immutable.Map<String, Object> params) throws XGBoostError {
        Tuple2 tuple2;
        this.logger().info((Object)new StringBuilder(34).append("Running XGBoost ").append(package$.MODULE$.VERSION()).append(" with parameters:\n").append(params.mkString("\n")).toString());
        XGBoostExecutionParamsFactory xgbParamsFactory = new XGBoostExecutionParamsFactory(params, sc);
        XGBoostExecutionParams runtimeParams = xgbParamsFactory.buildXGBRuntimeParams();
        // Booster to resume from; orNull yields null when no checkpoint parameter is configured.
        Booster prevBooster = (Booster)runtimeParams.checkpointParam().map((Function1 & Serializable)checkpointParam -> {
            ExternalCheckpointManager checkpointManager = new ExternalCheckpointManager(checkpointParam.checkpointPath(), FileSystem.get((Configuration)sc.hadoopConfiguration()));
            checkpointManager.cleanUpHigherVersions(runtimeParams.numRounds());
            return checkpointManager.loadCheckpointAsScalaBooster();
        }).orNull((.less.colon.less)$less$colon$less$.MODULE$.refl());
        Tuple2 tuple22 = (Tuple2)buildTrainingData.apply((Object)runtimeParams);
        if (tuple22 == null) {
            throw new MatchError((Object)tuple22);
        }
        // The unpack / re-pack / unpack sequence below is how the decompiler renders a Scala
        // tuple-destructuring val.
        RDD trainingRDD = (RDD)tuple22._1();
        Option optionalCachedRDD = (Option)tuple22._2();
        Tuple2 tuple23 = new Tuple2((Object)trainingRDD, (Object)optionalCachedRDD);
        RDD trainingRDD2 = (RDD)tuple23._1();
        Option optionalCachedRDD2 = (Option)tuple23._2();
        try {
            try {
                Tuple2 tuple24 = (Tuple2)this.withTracker(runtimeParams.numWorkers(), runtimeParams.trackerConf(), (Function1 & Serializable)tracker -> {
                    Map rabitEnv = tracker.getWorkerArgs();
                    RDDBarrier qual$1 = trainingRDD2.barrier();
                    // Per-partition task. Only the first element of the partition is used; any
                    // further elements are ignored. An empty partition fails with "No Watches to train".
                    Function1 & Serializable x$1 = (Function1 & Serializable)iter -> {
                        None$ optionWatches;
                        block0: {
                            optionWatches = None$.MODULE$;
                            if (!iter.hasNext()) break block0;
                            optionWatches = new Some(iter.next());
                        }
                        return (Iterator)optionWatches.map((Function1 & Serializable)buildWatches -> MODULE$.buildDistributedBooster((Function0<Watches>)buildWatches, runtimeParams, rabitEnv, runtimeParams.obj(), runtimeParams.eval(), prevBooster)).getOrElse((Function0 & Serializable)() -> {
                            throw new RuntimeException("No Watches to train");
                        });
                    };
                    boolean x$2 = qual$1.mapPartitions$default$2();
                    RDD boostersAndMetrics = qual$1.mapPartitions((Function1)x$1, x$2, ClassTag$.MODULE$.apply(Tuple2.class));
                    RDD<Tuple2<Booster, scala.collection.immutable.Map<String, float[]>>> boostersAndMetricsWithRes = MODULE$.tryStageLevelScheduling(sc, runtimeParams, (RDD<Tuple2<Booster, scala.collection.immutable.Map<String, float[]>>>)boostersAndMetrics);
                    // Unused local; an artifact of decompilation.
                    boolean x$3 = true;
                    Ordering x$4 = boostersAndMetricsWithRes.repartition$default$2(1);
                    // buildDistributedBooster emits a result only on partition 0, so the collected
                    // array is expected to hold exactly that one pair.
                    // NOTE(review): an empty collect would throw ArrayIndexOutOfBoundsException here.
                    Tuple2 tuple2 = ((Tuple2[])boostersAndMetricsWithRes.repartition(1, x$4).collect())[0];
                    if (tuple2 == null) {
                        throw new MatchError((Object)tuple2);
                    }
                    Booster booster = (Booster)tuple2._1();
                    scala.collection.immutable.Map metrics = (scala.collection.immutable.Map)tuple2._2();
                    Tuple2 tuple22 = new Tuple2((Object)booster, (Object)metrics);
                    Booster booster2 = (Booster)tuple22._1();
                    scala.collection.immutable.Map metrics2 = (scala.collection.immutable.Map)tuple22._2();
                    return new Tuple2((Object)booster2, (Object)metrics2);
                });
                if (tuple24 == null) {
                    throw new MatchError((Object)tuple24);
                }
                Booster booster = (Booster)tuple24._1();
                scala.collection.immutable.Map metrics = (scala.collection.immutable.Map)tuple24._2();
                Tuple2 tuple25 = new Tuple2((Object)booster, (Object)metrics);
                Booster booster2 = (Booster)tuple25._1();
                scala.collection.immutable.Map metrics2 = (scala.collection.immutable.Map)tuple25._2();
                // Training succeeded: clean the checkpoint path unless skipCleanCheckpoint is set.
                runtimeParams.checkpointParam().foreach((Function1 & Serializable)cpParam -> {
                    XGBoost$.$anonfun$trainDistributed$6(runtimeParams, sc, cpParam);
                    return BoxedUnit.UNIT;
                });
                tuple2 = new Tuple2((Object)booster2, (Object)metrics2);
            }
            catch (Throwable t) {
                this.logger().error((Object)"the job was aborted due to ", t);
                throw t;
            }
        }
        finally {
            // Always release the optional cached training RDD.
            optionalCachedRDD2.foreach((Function1 & Serializable)x$12 -> x$12.unpersist(x$12.unpersist$default$1()));
        }
        return tuple2;
    }

    /**
     * Serialization hook: writes a {@link ModuleSerializationProxy} instead of this instance.
     * This presumably makes deserialization resolve to the singleton {@link #MODULE$}.
     */
    private Object writeReplace() {
        return new ModuleSerializationProxy(XGBoost$.class);
    }

    /**
     * Body of the {@code Array.tabulate} lambda in {@link #buildDistributedBooster}.
     * Allocates one zero-initialized metrics row of length {@code numRounds$1}; the row index
     * {@code x$8} is ignored.
     */
    public static final /* synthetic */ float[] $anonfun$buildDistributedBooster$1(int numRounds$1, int x$8) {
        return (float[])Array$.MODULE$.ofDim(numRounds$1, (ClassTag)ClassTag$.MODULE$.Float());
    }

    /**
     * Body of the post-training checkpoint cleanup lambda in {@link #trainDistributed}.
     * Unless {@code skipCleanCheckpoint} is set, it calls {@code cleanPath()} on a checkpoint
     * manager built for {@code cpParam.checkpointPath()} on the Hadoop filesystem from
     * {@code sc$1}.
     *
     * <p>NOTE(review): the skip flag is read from {@code runtimeParams$1.checkpointParam().get()}
     * rather than from {@code cpParam}. At the only call site, the {@code foreach} in
     * {@link #trainDistributed}, they are the same object.
     */
    public static final /* synthetic */ void $anonfun$trainDistributed$6(XGBoostExecutionParams runtimeParams$1, SparkContext sc$1, ExternalCheckpointParams cpParam) {
        if (!((ExternalCheckpointParams)runtimeParams$1.checkpointParam().get()).skipCleanCheckpoint()) {
            ExternalCheckpointManager checkpointManager = new ExternalCheckpointManager(cpParam.checkpointPath(), FileSystem.get((Configuration)sc$1.hadoopConfiguration()));
            checkpointManager.cleanPath();
            return;
        }
    }

    /** Private constructor: the only instance is {@link #MODULE$}. */
    private XGBoost$() {
    }
}

