package org.nd4j.jita.constant;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.ShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/jita/constant/ConstantProtector.class */
/**
 * Process-wide singleton that holds strong references to constant buffers.
 *
 * <p>Three stores are maintained:
 * <ul>
 *   <li>a flat list of bare {@link DataBuffer}s ("legacy" store),</li>
 *   <li>a flat list of {@code (DataBuffer, long[])} pairs,</li>
 *   <li>one map per device, keyed by {@link LongShapeDescriptor}, holding
 *       {@code (DataBuffer, long[])} pairs.</li>
 * </ul>
 *
 * <p>The number of per-device maps is taken from
 * {@code Nd4j.getAffinityManager().getNumberOfDevices()} each time
 * {@link #purgeProtector()} runs, so every {@code int i} argument below must be
 * in the range {@code [0, numberOfDevices)}.
 */
public class ConstantProtector {
    private static final ConstantProtector ourInstance = new ConstantProtector();

    /** Bare buffers registered through {@link #persistDataBuffer(DataBuffer)}. */
    private final List<DataBuffer> protectorLegacy = new CopyOnWriteArrayList<>();

    /**
     * Pairs registered through {@link #persistDataBuffer(Pair)}. Replaced wholesale by
     * {@link #purgeProtector()}, hence volatile so other threads observe the new list.
     */
    private volatile List<Pair<DataBuffer, long[]>> protector;

    /**
     * Per-device caches; the list index is the device index. The list itself is never
     * mutated after it has been assigned to this field - {@link #purgeProtector()}
     * builds a complete replacement and swaps the reference.
     */
    private volatile List<Map<LongShapeDescriptor, Pair<DataBuffer, long[]>>> deviceCache;

    /**
     * @return the single shared instance
     */
    public static ConstantProtector getInstance() {
        return ourInstance;
    }

    private ConstantProtector() {
        // Sets up the pair list and the per-device maps.
        purgeProtector();
    }

    /**
     * Drops every pair held in the flat pair list and in the per-device caches, and
     * re-creates one empty cache per device reported by the affinity manager.
     *
     * <p>The replacement device list is fully populated before it is published, so a
     * concurrent reader sees either the old caches or the complete new ones, never a
     * partially filled list.
     *
     * <p>NOTE(review): the legacy buffer list is not cleared here; this matches the
     * previous behaviour - confirm whether that is intentional.
     */
    public void purgeProtector() {
        int numberOfDevices = Nd4j.getAffinityManager().getNumberOfDevices();

        List<Map<LongShapeDescriptor, Pair<DataBuffer, long[]>>> freshCache = new ArrayList<>(numberOfDevices);
        for (int device = 0; device < numberOfDevices; device++) {
            freshCache.add(new ConcurrentHashMap<>());
        }

        this.protector = new CopyOnWriteArrayList<>();
        this.deviceCache = freshCache;
    }

    /**
     * Adds a bare buffer to the legacy store.
     *
     * @param dataBuffer buffer to retain
     */
    public void persistDataBuffer(DataBuffer dataBuffer) {
        this.protectorLegacy.add(dataBuffer);
    }

    /**
     * Adds a pair to the flat pair store.
     *
     * @param pair pair to retain
     */
    public void persistDataBuffer(Pair<DataBuffer, long[]> pair) {
        this.protector.add(pair);
    }

    /**
     * Stores a pair in the cache of device {@code i}, keyed by the
     * {@link LongShapeDescriptor} converted from {@code shapeDescriptor}.
     *
     * @param i               device index into the per-device cache list
     * @param shapeDescriptor descriptor converted via
     *                        {@link LongShapeDescriptor#fromShapeDescriptor(ShapeDescriptor)}
     * @param pair            value to store; replaces any existing entry for the same key
     * @throws IndexOutOfBoundsException if {@code i} is not a valid device index
     */
    public void persistDataBuffer(int i, ShapeDescriptor shapeDescriptor, Pair<DataBuffer, long[]> pair) {
        this.deviceCache.get(i).put(LongShapeDescriptor.fromShapeDescriptor(shapeDescriptor), pair);
    }

    /**
     * Stores a pair in the cache of device {@code i} under {@code longShapeDescriptor}.
     *
     * @param i                   device index into the per-device cache list
     * @param longShapeDescriptor cache key
     * @param pair                value to store; replaces any existing entry for the same key
     * @throws IndexOutOfBoundsException if {@code i} is not a valid device index
     */
    public void persistDataBuffer(int i, LongShapeDescriptor longShapeDescriptor, Pair<DataBuffer, long[]> pair) {
        this.deviceCache.get(i).put(longShapeDescriptor, pair);
    }

    /**
     * Looks up the pair cached for device {@code i} under the
     * {@link LongShapeDescriptor} converted from {@code shapeDescriptor}.
     *
     * @param i               device index into the per-device cache list
     * @param shapeDescriptor descriptor converted via
     *                        {@link LongShapeDescriptor#fromShapeDescriptor(ShapeDescriptor)}
     * @return the cached pair, or {@code null} if the device cache has no such key
     * @throws IndexOutOfBoundsException if {@code i} is not a valid device index
     */
    public Pair<DataBuffer, long[]> getDataBuffer(int i, ShapeDescriptor shapeDescriptor) {
        return this.deviceCache.get(i).get(LongShapeDescriptor.fromShapeDescriptor(shapeDescriptor));
    }

    /**
     * Looks up the pair cached for device {@code i} under {@code longShapeDescriptor}.
     *
     * @param i                   device index into the per-device cache list
     * @param longShapeDescriptor cache key
     * @return the cached pair, or {@code null} if the device cache has no such key
     * @throws IndexOutOfBoundsException if {@code i} is not a valid device index
     */
    public Pair<DataBuffer, long[]> getDataBuffer(int i, LongShapeDescriptor longShapeDescriptor) {
        return this.deviceCache.get(i).get(longShapeDescriptor);
    }

    /**
     * Tells whether device {@code i} has an entry for the
     * {@link LongShapeDescriptor} converted from {@code shapeDescriptor}.
     *
     * @param i               device index into the per-device cache list
     * @param shapeDescriptor descriptor converted via
     *                        {@link LongShapeDescriptor#fromShapeDescriptor(ShapeDescriptor)}
     * @return {@code true} if the key is present in that device's cache
     * @throws IndexOutOfBoundsException if {@code i} is not a valid device index
     */
    public boolean containsDataBuffer(int i, ShapeDescriptor shapeDescriptor) {
        return this.deviceCache.get(i).containsKey(LongShapeDescriptor.fromShapeDescriptor(shapeDescriptor));
    }

    /**
     * Tells whether device {@code i} has an entry for {@code longShapeDescriptor}.
     *
     * @param i                   device index into the per-device cache list
     * @param longShapeDescriptor cache key
     * @return {@code true} if the key is present in that device's cache
     * @throws IndexOutOfBoundsException if {@code i} is not a valid device index
     */
    public boolean containsDataBuffer(int i, LongShapeDescriptor longShapeDescriptor) {
        return this.deviceCache.get(i).containsKey(longShapeDescriptor);
    }
}
