package org.deeplearning4j.optimize.solvers.accumulation;

import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.optimize.api.StepFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/accumulation/BasicGradientsAccumulator.class */
/**
 * Synchronous {@link GradientsAccumulator} that aggregates gradients from a fixed number of worker
 * threads using a {@link CyclicBarrier} of {@code parties} participants.
 *
 * <p>Every one of the {@code parties} threads is expected to call {@link #storeUpdate} (and, likewise,
 * {@link #applyUpdate}) once per round. The calls rendezvous on a shared barrier. A single thread
 * does the shared work: the first one to claim {@link #firstOne}. A second barrier then releases
 * everybody. A thread that skips a round leaves the others blocked on the barrier.
 *
 * <p>Updates delivered through {@link #receiveUpdate} are summed into the {@link #updates} buffer
 * under the write side of {@link #updatesLock}. {@link #applyUpdate} applies that buffer to the
 * parameters under the read side, and then one thread zeroes it.
 */
public class BasicGradientsAccumulator implements GradientsAccumulator {
    private static final Logger log = LoggerFactory.getLogger(BasicGradientsAccumulator.class);
    protected MessageHandler handler;
    protected transient IndexedTail gradients;
    /** Buffer that {@link #storeUpdate} accumulates local gradients into before broadcasting. */
    protected transient INDArray storage;
    /** Sum of updates received via {@link #receiveUpdate}; lazily allocated on the first update. */
    protected transient INDArray updates;
    /** Number of successful broadcasts performed by {@link #storeUpdate}. */
    protected transient AtomicLong ownCounter;
    /** Number of calls to {@link #receiveUpdate}. */
    protected transient AtomicLong extCounter;
    protected long[] shape;
    protected char ordering;
    protected int parties;
    protected CyclicBarrier barrier;
    /** Id of the thread doing the shared work in the current round, or -1 when unclaimed. */
    protected AtomicLong firstOne;
    /** Arrays submitted to {@link #storeUpdate} during the current round. */
    protected List<INDArray> candidates;
    protected ReentrantReadWriteLock updatesLock;
    /** True when {@link #updates} holds something not yet applied/cleared. */
    protected AtomicBoolean hasSomething;

    /**
     * Creates an accumulator for {@code parties} threads that uses a {@link LocalHandler}.
     *
     * @param parties number of threads that will participate in every barrier round
     */
    public BasicGradientsAccumulator(int parties) {
        this(parties, new LocalHandler());
    }

    /**
     * Creates an accumulator for {@code parties} threads.
     *
     * @param parties number of threads that will participate in every barrier round; also passed to
     *     the {@link IndexedTail} constructor
     * @param handler handler used to broadcast accumulated updates; initialized with this instance
     *     once construction has otherwise completed
     * @throws NullPointerException if {@code handler} is null
     */
    public BasicGradientsAccumulator(int parties, @NonNull MessageHandler handler) {
        if (handler == null) {
            throw new NullPointerException("handler is marked non-null but is null");
        }
        this.ownCounter = new AtomicLong(0L);
        this.extCounter = new AtomicLong(0L);
        this.firstOne = new AtomicLong(-1L);
        this.candidates = new CopyOnWriteArrayList<>();
        this.updatesLock = new ReentrantReadWriteLock();
        this.hasSomething = new AtomicBoolean(false);
        this.gradients = new IndexedTail(parties);
        this.parties = parties;
        this.barrier = new CyclicBarrier(parties);
        this.handler = handler;
        // Initialize the handler last so that it never observes a half-constructed accumulator
        // (previously `parties` and `barrier` were still unset when it was called).
        this.handler.initialize(this);
    }

    /** Returns the {@link IndexedTail} currently used as the external source. */
    @Override
    public IndexedTail getExternalSource() {
        return this.gradients;
    }

    /**
     * Applies the externally received updates to {@code params}, then zeroes the update buffer.
     *
     * <p>Blocks until all {@code parties} threads have called one of the {@code applyUpdate} overloads.
     *
     * @param function step function invoked as {@code function.step(params, updates)} when there
     *     are pending updates
     * @param params parameters to update
     * @param grad not used by this implementation; the step is taken along the received updates
     * @param isFinalStep not used by this implementation
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray grad, boolean isFinalStep) {
        applyAccumulated(() -> function.step(params, this.updates));
    }

    /** No-op in this implementation. */
    @Override
    public void markExternalUpdates(boolean mark) {
    }

    /**
     * Applies the externally received updates to {@code params}, then zeroes the update buffer.
     *
     * <p>Blocks until all {@code parties} threads have called one of the {@code applyUpdate} overloads.
     *
     * @param function step function invoked as {@code function.step(params, updates, alpha)} when
     *     there are pending updates
     * @param params parameters to update
     * @param grad not used by this implementation; the step is taken along the received updates
     * @param alpha passed through as the step argument of {@link StepFunction#step}
     */
    @Override
    public void applyUpdate(StepFunction function, INDArray params, INDArray grad, double alpha) {
        applyAccumulated(() -> function.step(params, this.updates, alpha));
    }

    /**
     * Shared barrier protocol of both {@code applyUpdate} overloads: every thread runs
     * {@code stepAction} if updates are pending, then one thread clears the buffer.
     */
    private void applyAccumulated(Runnable stepAction) {
        final long threadId = Thread.currentThread().getId();
        try {
            this.updatesLock.readLock().lock();
            try {
                this.firstOne.compareAndSet(-1L, threadId);
                if (this.hasSomething.get()) {
                    stepAction.run();
                }
                // Wait until every party has applied the buffer before anyone clears it.
                this.barrier.await();
                if (this.firstOne.get() == threadId) {
                    // `updates` is allocated lazily by receiveUpdate(); it may still be null here.
                    if (this.updates != null) {
                        this.updates.assign(0.0);
                    }
                    this.hasSomething.set(false);
                    this.firstOne.set(-1L);
                }
            } finally {
                // Always release, otherwise receiveUpdate()/reset() would block forever on failure.
                this.updatesLock.readLock().unlock();
            }
            // Keep everyone in step until the buffer has been cleared.
            this.barrier.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Contributes {@code array} to the current round. Once all parties have arrived, one thread
     * combines the contributions into {@link #storage} via {@code Nd4j.accumulate} and hands the
     * result to the message handler.
     *
     * @param array this thread's contribution
     * @param iterationNumber forwarded as-is to {@code handler.broadcastUpdates}
     * @param epochNumber forwarded as-is to {@code handler.broadcastUpdates}
     */
    @Override
    public void storeUpdate(INDArray array, int iterationNumber, int epochNumber) {
        final long threadId = Thread.currentThread().getId();
        try {
            // Flush queued ops before the array is handed to another thread
            // (presumably so that `array` is fully computed; confirm against executioner semantics).
            Nd4j.getExecutioner().commit();
            this.firstOne.compareAndSet(-1L, threadId);
            this.candidates.add(array);
            this.barrier.await();
            if (this.firstOne.get() == threadId) {
                try {
                    if (this.storage == null) {
                        this.shape = array.shape();
                        this.ordering = array.ordering();
                        this.storage = createOutsideWorkspace(this.shape, this.ordering);
                    }
                    Nd4j.accumulate(this.storage, this.candidates);
                    Nd4j.getExecutioner().commit();
                    if (this.handler.broadcastUpdates(this.storage, iterationNumber, epochNumber)) {
                        this.ownCounter.getAndIncrement();
                    }
                } finally {
                    // Reset round state even on failure so the next round does not inherit a stale
                    // leader id or the previous round's arrays.
                    this.firstOne.set(-1L);
                    this.candidates.clear();
                }
            }
            this.barrier.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds an externally produced update into the pending update buffer, allocating the buffer
     * (outside of any workspace) with {@code array}'s shape and ordering on first use.
     *
     * @param array update to add
     */
    @Override
    public void receiveUpdate(INDArray array) {
        this.extCounter.getAndIncrement();
        this.updatesLock.writeLock().lock();
        try {
            if (this.updates == null) {
                this.updates = createOutsideWorkspace(array.shape(), array.ordering());
            }
            this.updates.addi(array);
            Nd4j.getExecutioner().commit();
            this.hasSomething.set(true);
        } finally {
            this.updatesLock.writeLock().unlock();
        }
    }

    /** Zeroes the local accumulation buffer and the pending update buffer, if allocated. */
    @Override
    public void reset() {
        this.updatesLock.writeLock().lock();
        try {
            if (this.storage != null) {
                this.storage.assign(0.0f);
            }
            if (this.updates != null) {
                this.updates.assign(0.0f);
            }
            // The buffer now holds nothing to apply; keep the flag consistent with applyUpdate().
            this.hasSomething.set(false);
        } finally {
            this.updatesLock.writeLock().unlock();
        }
    }

    /** No-op in this implementation. */
    @Override
    public void touch() {
    }

    /** Replaces the {@link IndexedTail} used as the external source. */
    @Override
    public void setExternalSource(IndexedTail indexedTail) {
        this.gradients = indexedTail;
    }

    /**
     * Always returns {@code false}; this implementation does not report pending updates here.
     * NOTE(review): {@link #hasSomething} tracks pending updates internally. Confirm with callers
     * whether this method should reflect it.
     */
    @Override
    public boolean hasAnything() {
        return false;
    }

    /** Allocates a new array of the given shape/ordering with workspaces scoped out. */
    private static INDArray createOutsideWorkspace(long[] shape, char ordering) {
        // try-with-resources skips close() when the resource is null, like the original null checks.
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return Nd4j.create(shape, ordering);
        }
    }
}
