package org.testcontainers.ollama;

import com.github.dockerjava.api.DockerClient;
import com.github.dockerjava.api.model.DeviceRequest;
import com.github.dockerjava.api.model.Info;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.testcontainers.DockerClientFactory;
import org.testcontainers.containers.GenericContainer;
import org.testcontainers.utility.DockerImageName;

/* loaded from: input_file:org/testcontainers/ollama/OllamaContainer.class */
/**
 * Testcontainers module for Ollama.
 * <p>
 * Supported image: {@code ollama/ollama} (and images declared compatible with it).
 * <p>
 * Exposed ports: 11434 (HTTP API)
 * <p>
 * If the Docker daemon reports an {@code nvidia} runtime, the container is created with a GPU
 * device request so models can run on the host's GPUs.
 */
public class OllamaContainer extends GenericContainer<OllamaContainer> {

    /** Reference image that every user-supplied image name must be compatible with. */
    private static final DockerImageName DOCKER_IMAGE_NAME = DockerImageName.parse("ollama/ollama");

    /** Container-side port of the Ollama HTTP API (see {@link #getEndpoint()}). */
    private static final int OLLAMA_PORT = 11434;

    /**
     * Creates a container from an image reference string.
     *
     * @param image image reference, e.g. {@code "ollama/ollama:latest"}; must be compatible with
     *              {@code ollama/ollama}
     */
    public OllamaContainer(String image) {
        this(DockerImageName.parse(image));
    }

    /**
     * Creates a container from a parsed image name.
     * <p>
     * This queries the Docker daemon ({@code docker info}) during construction to detect whether
     * an {@code nvidia} runtime is available.
     *
     * @param dockerImageName image to run; construction fails if it is not compatible with
     *                        {@code ollama/ollama}
     */
    public OllamaContainer(DockerImageName dockerImageName) {
        super(dockerImageName);
        dockerImageName.assertCompatibleWith(DOCKER_IMAGE_NAME);

        // Only the runtime names matter here, so the value type is left as a wildcard.
        Map<String, ?> runtimes = this.dockerClient.infoCmd().exec().getRuntimes();
        if (runtimes != null && runtimes.containsKey("nvidia")) {
            withCreateContainerCmdModifier(createContainerCmd -> {
                // Request devices with the "gpu" capability; Docker treats a count of -1 as
                // "all available devices" (the API equivalent of `docker run --gpus all`).
                DeviceRequest allGpus = new DeviceRequest()
                    .withCapabilities(Collections.singletonList(Collections.singletonList("gpu")))
                    .withCount(-1);
                createContainerCmd.getHostConfig().withDeviceRequests(Collections.singletonList(allGpus));
            });
        }
        withExposedPorts(OLLAMA_PORT);
    }

    /**
     * Commits the running container's current state (for example, with pulled models) to a local
     * image so later runs can start from it.
     * <p>
     * This is a no-op in two cases:
     * <ul>
     *   <li>{@code imageName} is the same image this container was started from.</li>
     *   <li>A local image matching {@code imageName} already exists.</li>
     * </ul>
     *
     * @param imageName target image reference ({@code repository[:tag]}) to create
     */
    public void commitToImage(String imageName) {
        DockerImageName targetImage = DockerImageName.parse(imageName);
        if (DockerImageName.parse(getDockerImageName()).equals(targetImage)) {
            return;
        }

        DockerClient client = DockerClientFactory.instance().client();
        if (!client.listImagesCmd().withReferenceFilter(imageName).exec().isEmpty()) {
            // The target image already exists locally; do not overwrite it.
            return;
        }

        client
            .commitCmd(getContainerId())
            .withRepository(targetImage.getUnversionedPart())
            // Blank out the Testcontainers session label. Presumably this stops the committed
            // image from being treated as a per-session resource and cleaned up (TODO confirm).
            .withLabels(Collections.singletonMap("org.testcontainers.sessionId", ""))
            .withTag(targetImage.getVersionPart())
            .exec();
    }

    /**
     * Returns the host port mapped to the Ollama API port.
     *
     * @return host-side port number for container port 11434
     */
    public int getPort() {
        return getMappedPort(OLLAMA_PORT);
    }

    /**
     * Returns the base HTTP URL of the Ollama API as reachable from the test host.
     *
     * @return a URL of the form {@code http://<host>:<mappedPort>}
     */
    public String getEndpoint() {
        return "http://" + getHost() + ":" + getPort();
    }
}
