/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.agents.integrations.chatmodels.ollama;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.github.ollama4j.OllamaAPI;
import io.github.ollama4j.exceptions.RoleNotFoundException;
import io.github.ollama4j.models.chat.OllamaChatMessage;
import io.github.ollama4j.models.chat.OllamaChatMessageRole;
import io.github.ollama4j.models.chat.OllamaChatResult;
import io.github.ollama4j.tools.Tools;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.flink.agents.api.chat.messages.ChatMessage;
import org.apache.flink.agents.api.chat.messages.MessageRole;
import org.apache.flink.agents.api.chat.model.BaseChatModelConnection;
import org.apache.flink.agents.api.resource.Resource;
import org.apache.flink.agents.api.resource.ResourceDescriptor;
import org.apache.flink.agents.api.resource.ResourceType;
import org.apache.flink.agents.api.tools.Tool;
import org.apache.flink.agents.api.tools.ToolParameters;

/**
 * Chat model connection backed by an Ollama server, accessed through the ollama4j {@code OllamaAPI}
 * client.
 *
 * <p>Descriptor arguments read by this connection:
 * <ul>
 *   <li>{@code endpoint} (required) - base URL handed to the {@code OllamaAPI} constructor.</li>
 *   <li>{@code maxChatToolCallRetries} (optional, default 10) - forwarded to
 *       {@code OllamaAPI.setMaxChatToolCallRetries}.</li>
 *   <li>{@code requestTimeout} (optional, default 10) - forwarded to
 *       {@code OllamaAPI.setRequestTimeoutSeconds}.</li>
 * </ul>
 *
 * <p>NOTE(review): {@link #chat} registers tools on the single shared {@code client}, so concurrent
 * calls with different tool lists may interfere with each other; confirm whether an instance is
 * only ever used from one thread.
 */
public class OllamaChatModelConnection extends BaseChatModelConnection {
    private static final int DEFAULT_MAX_CHAT_TOOL_CALL_RETRIES = 10;
    private static final int DEFAULT_REQUEST_TIMEOUT_SECONDS = 10;

    /** Matches {@code <think>...</think>} blocks; DOTALL lets a block span multiple lines. */
    private static final Pattern THINK_PATTERN =
            Pattern.compile("<think>(.*?)</think>", Pattern.DOTALL);

    /** Shared parser for tool input schemas; an ObjectMapper is reusable once configured. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private final OllamaAPI client;

    /**
     * Creates a connection from a resource descriptor.
     *
     * @param descriptor descriptor carrying the {@code endpoint} and optional tuning arguments
     * @param getResource resolver for other resources, passed to the base class
     * @throws IllegalArgumentException if {@code endpoint} is missing or blank, or a numeric
     *     argument is present but not a number
     */
    public OllamaChatModelConnection(ResourceDescriptor descriptor, BiFunction<String, ResourceType, Resource> getResource) {
        super(descriptor, getResource);
        String endpoint = (String) descriptor.getArgument("endpoint");
        if (endpoint == null || endpoint.isBlank()) {
            throw new IllegalArgumentException("endpoint should not be null or empty.");
        }
        this.client = new OllamaAPI(endpoint);
        this.client.setMaxChatToolCallRetries(
                intArgument(descriptor, "maxChatToolCallRetries", DEFAULT_MAX_CHAT_TOOL_CALL_RETRIES));
        this.client.setRequestTimeoutSeconds(
                intArgument(descriptor, "requestTimeout", DEFAULT_REQUEST_TIMEOUT_SECONDS));
    }

    /**
     * Convenience constructor that only specifies the Ollama endpoint.
     *
     * @param endpoint base URL of the Ollama server
     * @param getResource resolver for other resources, passed to the base class
     * @throws IllegalArgumentException if {@code endpoint} is null or blank
     */
    public OllamaChatModelConnection(String endpoint, BiFunction<String, ResourceType, Resource> getResource) {
        // singletonMap (unlike Map.of) tolerates a null value, so a null endpoint reaches the
        // descriptor constructor's validation and yields IllegalArgumentException, not an NPE.
        this(new ResourceDescriptor(OllamaChatModelConnection.class.getName(),
                Collections.singletonMap("endpoint", endpoint)), getResource);
    }

    /**
     * Reads an optional integer argument, accepting any {@link Number} (e.g. a Long from config).
     *
     * @return the argument's int value, or {@code defaultValue} when it is absent
     * @throws IllegalArgumentException if the argument is present but not a number
     */
    private static int intArgument(ResourceDescriptor descriptor, String name, int defaultValue) {
        Object value = descriptor.getArgument(name);
        if (value == null) {
            return defaultValue;
        }
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException(name + " should be a number, but was: " + value);
        }
        return ((Number) value).intValue();
    }

    /**
     * Registers each tool with the ollama4j client, translating its JSON input schema into
     * ollama4j prompt-function definitions. A null or empty list registers nothing.
     */
    private void registerTools(List<Tool> tools) {
        if (tools == null || tools.isEmpty()) {
            return;
        }
        try {
            for (Tool tool : tools) {
                Tools.PromptFuncDefinition.Parameters parameters = Tools.PromptFuncDefinition.Parameters.builder()
                        .type("object")
                        .properties(buildProperties(parseInputSchema(tool)))
                        .build();
                Tools.PromptFuncDefinition.PromptFuncSpec functionSpec = Tools.PromptFuncDefinition.PromptFuncSpec.builder()
                        .name(tool.getName())
                        .description(tool.getDescription())
                        .parameters(parameters)
                        .build();
                Tools.PromptFuncDefinition toolPrompt = Tools.PromptFuncDefinition.builder()
                        .type("prompt")
                        .function(functionSpec)
                        .build();
                Tools.ToolSpecification toolSpec = Tools.ToolSpecification.builder()
                        .functionName(tool.getName())
                        .functionDescription(tool.getDescription())
                        .toolPrompt(toolPrompt)
                        .toolFunction(arguments -> tool.call(new ToolParameters(arguments)))
                        .build();
                this.client.registerTool(toolSpec);
            }
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Parses a tool's JSON input schema into a generic map (may be null for a JSON "null"). */
    private static Map<String, Object> parseInputSchema(Tool tool) {
        try {
            return MAPPER.readValue(
                    tool.getMetadata().getInputSchema(), new TypeReference<Map<String, Object>>() {});
        } catch (Exception e) {
            throw new IllegalArgumentException("Invalid input schema for tool " + tool.getName(), e);
        }
    }

    /**
     * Converts the {@code properties}/{@code required} sections of a JSON schema into ollama4j
     * property definitions. Missing {@code properties} yields an empty map (a tool without
     * parameters); missing {@code required} marks every parameter as optional.
     */
    private static Map<String, Tools.PromptFuncDefinition.Property> buildProperties(Map<String, Object> schema) {
        Map<String, Tools.PromptFuncDefinition.Property> propertiesMap = new HashMap<>();
        if (schema == null) {
            return propertiesMap;
        }
        Map<?, ?> properties = (Map<?, ?>) schema.get("properties");
        if (properties == null) {
            return propertiesMap;
        }
        List<?> required = (List<?>) schema.get("required");
        if (required == null) {
            required = Collections.emptyList();
        }
        for (Map.Entry<?, ?> entry : properties.entrySet()) {
            String paramName = (String) entry.getKey();
            Map<?, ?> paramSchema = (Map<?, ?>) entry.getValue();
            String type = (String) paramSchema.get("type");
            String description = (String) paramSchema.get("description");
            propertiesMap.put(paramName, Tools.PromptFuncDefinition.Property.builder()
                    .type(type)
                    .description(description)
                    .required(required.contains(paramName))
                    .build());
        }
        return propertiesMap;
    }

    /**
     * Maps a framework chat message onto an ollama4j message by role name. The lower-casing uses
     * Locale.ROOT so role lookup does not depend on the JVM's default locale (e.g. Turkish
     * dotless-i would turn "ASSISTANT" into "assıstant").
     */
    private OllamaChatMessage convertToOllamaChatMessages(ChatMessage message) {
        MessageRole role = message.getRole();
        try {
            OllamaChatMessageRole ollamaRole = OllamaChatMessageRole.getRole(role.name().toLowerCase(Locale.ROOT));
            return new OllamaChatMessage(ollamaRole, message.getContent());
        } catch (RoleNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sends the conversation to Ollama and returns the assistant reply.
     *
     * @param messages conversation history to send
     * @param tools tools to register on the client before the call; may be null or empty
     * @param arguments must contain a {@code "model"} entry naming the Ollama model
     * @return assistant message with {@code <think>} content moved to the "reasoning" extra arg
     * @throws IllegalArgumentException if no model is given
     * @throws RuntimeException wrapping any checked failure from the client
     */
    @Override
    public ChatMessage chat(List<ChatMessage> messages, List<Tool> tools, Map<String, Object> arguments) {
        Object model = arguments == null ? null : arguments.get("model");
        if (model == null) {
            throw new IllegalArgumentException("model should be specified in arguments.");
        }
        try {
            this.registerTools(tools);
            List<OllamaChatMessage> ollamaChatMessages = messages.stream()
                    .map(this::convertToOllamaChatMessages)
                    .collect(Collectors.toList());
            OllamaChatResult ollamaChatResult = this.client.chat((String) model, ollamaChatMessages);
            return this.extractReasoning(ollamaChatResult.getResponse());
        } catch (RuntimeException e) {
            // Already unchecked: propagate without adding another wrapping layer.
            throw e;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Preserve the interrupt status for callers further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
    }

    /**
     * Splits a raw model response into visible content and reasoning.
     *
     * <p>Text inside every {@code <think>...</think>} block is collected (stripped, one block per
     * line) into the "reasoning" extra arg; the blocks are removed from the returned content.
     * A null response is treated as empty.
     */
    private ChatMessage extractReasoning(String response) {
        String text = response == null ? "" : response;
        Matcher matcher = THINK_PATTERN.matcher(text);
        StringBuilder reasoning = new StringBuilder();
        while (matcher.find()) {
            String segment = matcher.group(1).strip();
            if (segment.isEmpty()) {
                continue;
            }
            if (reasoning.length() > 0) {
                reasoning.append('\n');
            }
            reasoning.append(segment);
        }
        // replaceAll resets the matcher before scanning, so the find() loop above does not affect it.
        String content = matcher.replaceAll("").strip();
        ChatMessage responseMessage = ChatMessage.assistant(content);
        Map<String, Object> extraArgs = new HashMap<>();
        extraArgs.put("reasoning", reasoning.toString());
        responseMessage.setExtraArgs(extraArgs);
        return responseMessage;
    }
}

