package org.apache.kafka.streams.processor.internals.assignment;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.BiPredicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.processor.TaskId;
import org.apache.kafka.streams.processor.internals.TopologyMetadata;
import org.apache.kafka.streams.processor.internals.assignment.RackAwareTaskAssignor;

/**
 * Builds a min-cost-flow graph whose solution keeps each client's share of every subtopology's tasks
 * proportional to the number of tasks it currently holds.
 *
 * <p>Node ids used in the graph:
 * <ul>
 *   <li>{@code 0 .. T-1}: one node per task to assign, numbered by visiting subtopologies in sorted order and,
 *       within each, the tasks to assign in sorted order;</li>
 *   <li>{@code taskIdList.size() + clientList.size() * g + c}: first-stage node of client {@code c} for
 *       subtopology index {@code g};</li>
 *   <li>{@code taskIdList.size() + clientList.size() * G + c}: second-stage node of client {@code c}, where
 *       {@code G} is the number of subtopologies;</li>
 *   <li>{@code taskIdList.size() + clientList.size() * G + clientList.size()}: the sink;</li>
 *   <li>{@code -1}: the source.</li>
 * </ul>
 *
 * <p>Edges: source &rarr; task (capacity 1, cost 0); task &rarr; first-stage client node of the task's
 * subtopology (capacity 1, cost from the {@code CostFunction}); first-stage &rarr; second-stage client node
 * (capacity = the client's proportional share of that subtopology, see {@code constructEdges}); second-stage
 * client node &rarr; sink (capacity = number of tasks the client originally held).
 */
public class BalanceSubtopologyGraphConstructor implements RackAwareGraphConstructor {
    /** Node id used as the flow source. */
    private static final int SOURCE_NODE_ID = -1;

    private final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup;

    /**
     * @param tasksForTopicGroup tasks of each subtopology; every task passed to {@link #constructTaskGraph} must
     *                           appear in one of these sets
     */
    public BalanceSubtopologyGraphConstructor(final Map<TopologyMetadata.Subtopology, Set<TaskId>> tasksForTopicGroup) {
        this.tasksForTopicGroup = tasksForTopicGroup;
    }

    @Override // org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor
    public int getSinkNodeID(final List<TaskId> taskIdList,
                             final List<UUID> clientList,
                             final Map<TopologyMetadata.Subtopology, Set<TaskId>> subtopologyTasks) {
        return clientList.size() + taskIdList.size() + clientList.size() * subtopologyTasks.size();
    }

    @Override // org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor
    public int getClientNodeId(final int clientIndex,
                               final List<TaskId> taskIdList,
                               final List<UUID> clientList,
                               final int topicGroupIndex) {
        return taskIdList.size() + clientList.size() * topicGroupIndex + clientIndex;
    }

    @Override // org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor
    public int getClientIndex(final int clientNodeId,
                              final List<TaskId> taskIdList,
                              final List<UUID> clientList,
                              final int topicGroupIndex) {
        // Inverse of getClientNodeId for the given subtopology index.
        return clientNodeId - taskIdList.size() - clientList.size() * topicGroupIndex;
    }

    /** Id of the second-stage node of the client at {@code clientIndex}; these follow all first-stage nodes. */
    private static int getSecondStageClientNodeId(final List<TaskId> taskIdList,
                                                  final List<UUID> clientList,
                                                  final Map<TopologyMetadata.Subtopology, Set<TaskId>> subtopologyTasks,
                                                  final int clientIndex) {
        return taskIdList.size() + clientList.size() * subtopologyTasks.size() + clientIndex;
    }

    /**
     * Builds the flow graph and runs max flow on it.
     *
     * <p>Side effects: {@code originalAssignedTaskNumber} is incremented, per client, by the number of tasks in
     * {@code taskIdList} for which {@code hasAssignedTask} holds; {@code taskClientMap} receives
     * {@code task -> client} for every such pair.
     *
     * @throws IllegalStateException    if a task is not in any subtopology, or max flow differs from the number
     *                                  of tasks
     * @throws IllegalArgumentException if {@code hasReplica} is false and a task is held by more than one client
     */
    @Override // org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor
    public Graph<Integer> constructTaskGraph(final List<UUID> clientList,
                                             final List<TaskId> taskIdList,
                                             final Map<UUID, ClientState> clientStates,
                                             final Map<TaskId, UUID> taskClientMap,
                                             final Map<UUID, Integer> originalAssignedTaskNumber,
                                             final BiPredicate<ClientState, TaskId> hasAssignedTask,
                                             final RackAwareTaskAssignor.CostFunction costFunction,
                                             final int trafficCost,
                                             final int nonOverlapCost,
                                             final boolean hasReplica,
                                             final boolean isStandby) {
        validateTasks(taskIdList);

        final Graph<Integer> graph = new Graph<>();

        // Count how many of the tasks to assign each client currently holds; these counts size the balance edges.
        for (final TaskId taskId : taskIdList) {
            for (final Map.Entry<UUID, ClientState> entry : clientStates.entrySet()) {
                if (hasAssignedTask.test(entry.getValue(), taskId)) {
                    originalAssignedTaskNumber.merge(entry.getKey(), 1, Integer::sum);
                }
            }
        }

        constructEdges(graph, taskIdList, clientList, clientStates, taskClientMap, originalAssignedTaskNumber,
            hasAssignedTask, costFunction, trafficCost, nonOverlapCost, hasReplica, isStandby);

        // Every task must be routable from source to sink before min-cost optimisation can run.
        final long maxFlow = graph.calculateMaxFlow();
        if (maxFlow != taskIdList.size()) {
            throw new IllegalStateException("max flow calculated: " + maxFlow + " doesn't match taskSize: " + taskIdList.size());
        }
        return graph;
    }

    /**
     * Reads the solved flow back into client assignments.
     *
     * <p>Task nodes are visited in the same order used by {@code constructEdges}, so {@code taskNodeId} and
     * {@code topicGroupIndex} line up with the graph.
     *
     * @return true if any task moved (as reported by {@code assignTaskToClient})
     */
    @Override // org.apache.kafka.streams.processor.internals.assignment.RackAwareGraphConstructor
    public boolean assignTaskFromMinCostFlow(final Graph<Integer> graph,
                                             final List<UUID> clientList,
                                             final List<TaskId> taskIdList,
                                             final Map<UUID, ClientState> clientStates,
                                             final Map<UUID, Integer> originalAssignedTaskNumber,
                                             final Map<TaskId, UUID> taskClientMap,
                                             final BiConsumer<ClientState, TaskId> assignTask,
                                             final BiConsumer<ClientState, TaskId> unAssignTask,
                                             final BiPredicate<ClientState, TaskId> hasAssignedTask) {
        final List<List<TaskId>> tasksBySubtopology = tasksToAssignBySubtopology(taskIdList);

        int taskNodeId = 0;
        int tasksAssigned = 0;
        boolean taskMoved = false;
        for (int topicGroupIndex = 0; topicGroupIndex < tasksBySubtopology.size(); topicGroupIndex++) {
            for (final TaskId taskId : tasksBySubtopology.get(topicGroupIndex)) {
                final KeyValue<Boolean, Integer> movedAndAssigned = assignTaskToClient(graph, taskId, taskNodeId,
                    topicGroupIndex, clientStates, clientList, taskIdList, taskClientMap, assignTask, unAssignTask);
                taskMoved |= movedAndAssigned.key.booleanValue();
                tasksAssigned += movedAndAssigned.value.intValue();
                taskNodeId++;
            }
        }

        validateAssignedTask(taskIdList, tasksAssigned, clientStates, originalAssignedTaskNumber, hasAssignedTask);
        return taskMoved;
    }

    /**
     * Groups the given tasks by subtopology, fixing the task-node numbering.
     *
     * <p>Subtopologies are visited in sorted order and tasks within each are sorted; only tasks contained in
     * {@code taskIds} are kept. Every subtopology yields an entry (possibly empty), so list indexes equal
     * subtopology indexes as used by {@link #getClientNodeId}.
     */
    private List<List<TaskId>> tasksToAssignBySubtopology(final Collection<TaskId> taskIds) {
        final Set<TaskId> wanted = new HashSet<>(taskIds);
        return new TreeMap<>(tasksForTopicGroup).values().stream()
            .map(groupTasks -> new TreeSet<>(groupTasks).stream()
                .filter(wanted::contains)
                .collect(Collectors.toList()))
            .collect(Collectors.toList());
    }

    /** Ensures every task to assign belongs to some subtopology; otherwise it would get no task node. */
    private void validateTasks(final List<TaskId> taskIdList) {
        final Set<TaskId> tasksInSubtopologies = tasksForTopicGroup.values().stream()
            .flatMap(Collection::stream)
            .collect(Collectors.toSet());
        for (final TaskId taskId : taskIdList) {
            if (!tasksInSubtopologies.contains(taskId)) {
                throw new IllegalStateException("Task " + taskId + " not in tasksForTopicGroup");
            }
        }
    }

    /**
     * Adds all edges to {@code graph} and sets its source and sink.
     *
     * <p>Edges are added in this order: task &rarr; first-stage client (per subtopology, per client), then
     * first-stage &rarr; second-stage, then source &rarr; task, then second-stage &rarr; sink.
     */
    private void constructEdges(final Graph<Integer> graph,
                                final List<TaskId> taskIdList,
                                final List<UUID> clientList,
                                final Map<UUID, ClientState> clientStates,
                                final Map<TaskId, UUID> taskClientMap,
                                final Map<UUID, Integer> originalAssignedTaskNumber,
                                final BiPredicate<ClientState, TaskId> hasAssignedTask,
                                final RackAwareTaskAssignor.CostFunction costFunction,
                                final int trafficCost,
                                final int nonOverlapCost,
                                final boolean hasReplica,
                                final boolean isStandby) {
        final int sinkId = getSinkNodeID(taskIdList, clientList, tasksForTopicGroup);
        final List<List<TaskId>> tasksBySubtopology = tasksToAssignBySubtopology(taskIdList);

        // Task node id of the first task in the current subtopology.
        int groupFirstTaskNodeId = 0;
        for (int topicGroupIndex = 0; topicGroupIndex < tasksBySubtopology.size(); topicGroupIndex++) {
            final List<TaskId> groupTasks = tasksBySubtopology.get(topicGroupIndex);
            for (int clientIndex = 0; clientIndex < clientList.size(); clientIndex++) {
                final UUID processId = clientList.get(clientIndex);
                final int clientNodeId = getClientNodeId(clientIndex, taskIdList, clientList, topicGroupIndex);

                int taskNodeId = groupFirstTaskNodeId;
                for (final TaskId taskId : groupTasks) {
                    final boolean inCurrentAssignment = hasAssignedTask.test(clientStates.get(processId), taskId);
                    graph.addEdge(taskNodeId, clientNodeId, 1,
                        costFunction.getCost(taskId, processId, inCurrentAssignment, trafficCost, nonOverlapCost, isStandby), 0);
                    taskNodeId++;
                    if (inCurrentAssignment) {
                        if (!hasReplica && taskClientMap.containsKey(taskId)) {
                            throw new IllegalArgumentException("Task " + taskId + " assigned to multiple clients " + processId + ", " + taskClientMap.get(taskId));
                        }
                        taskClientMap.put(taskId, processId);
                    }
                }

                if (!groupTasks.isEmpty()) {
                    // Cap this client's take of the subtopology at its proportional share of its original load:
                    // ceil(originalCount / totalTasks * tasksInSubtopology). Clients with no original tasks get 0.
                    final int capacity = originalAssignedTaskNumber.containsKey(processId)
                        ? (int) Math.ceil(originalAssignedTaskNumber.get(processId).intValue() * 1.0d / taskIdList.size() * groupTasks.size())
                        : 0;
                    graph.addEdge(clientNodeId,
                        getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, clientIndex),
                        capacity, 0, 0);
                }
            }
            groupFirstTaskNodeId += groupTasks.size();
        }

        // After the loop, groupFirstTaskNodeId equals the total number of task nodes.
        for (int taskNodeId = 0; taskNodeId < groupFirstTaskNodeId; taskNodeId++) {
            graph.addEdge(SOURCE_NODE_ID, taskNodeId, 1, 0, 0);
        }

        // Each client can absorb exactly as many tasks as it originally held.
        for (int clientIndex = 0; clientIndex < clientList.size(); clientIndex++) {
            graph.addEdge(getSecondStageClientNodeId(taskIdList, clientList, tasksForTopicGroup, clientIndex),
                sinkId, originalAssignedTaskNumber.getOrDefault(clientList.get(clientIndex), 0).intValue(), 0, 0);
        }

        graph.setSourceNode(SOURCE_NODE_ID);
        graph.setSinkNode(sinkId);
    }
}
