/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
 */
macro package magic.dsl

import std.ast.*
import std.collection.{ArrayList, map, collectArrayList}
import std.fs.File

/**
 * Entry point of the `@template` transformation.
 *
 * The annotated function keeps its parameters and generic clauses but is renamed
 * (its name gets `TEMPLATE_FUNC_NAME_POSTFIX` appended) and its body is rewritten
 * so that every independent string literal is appended to a prompt builder whose
 * contents become the function's `String` result.
 * For example,
 * ```
 * @template
 * func foo(a: String, b: String) { ... }
 * ```
 * becomes, roughly,
 * ```
 * func foo<TEMPLATE_FUNC_NAME_POSTFIX>(a: String, b: String): String {
 *     let <PROMPT_BUILDER> = StringBuilder()
 *     ...
 *     return <PROMPT_BUILDER>.toString()
 * }
 * ```
 * If the input is not a function declaration, an error is printed and the
 * input tokens are returned unchanged.
 */
func transformTemplate(input: Tokens): Tokens {
    // Step 1: rewrite `$id(...)` calls and independent string literals in the body.
    let rewritten = transformStringLiteral(input, isDecl: true)
    let parsed = parseDecl(rewritten)
    if (!parsed.isFuncDecl()) {
        eprintln("@template should be called on function declarations")
        // Fall back to the original, untransformed tokens.
        return input
    }
    // Step 2: rename the function and wrap its body with the prompt builder.
    return transformTemplateFunc(parsed.asFuncDecl())
}

/**
 * Rewrites independent string literals into prompt-builder appends.
 *
 * 1. `$id(...)` calls are first desugared into the string literal `"${id(...)}"`
 *    (see transformTemplateFuncCall), so they are handled like any other literal.
 * 2. The desugared tokens are parsed (as a declaration when `isDecl` is true,
 *    otherwise as an expression) and `collectIndependentStringLiteral` marks which
 *    STRING_LITERAL / MULTILINE_STRING tokens are independent. Presumably these are
 *    literals standing alone as expressions; confirm against that helper.
 *    The AST cannot be mutated in place, so the tokens are rewritten instead.
 * 3. Each independent literal `lit` is replaced by `PROMPT_BUILDER.append(lit)`;
 *    every other token is copied through unchanged.
 */
func transformStringLiteral(input: Tokens, isDecl!: Bool): Tokens {
    let desugared = transformTemplateFuncCall(input)
    let root: Node = if (!isDecl) {
        parseExpr(desugared)
    } else {
        parseDecl(desugared)
    }
    let collector = collectIndependentStringLiteral(root)

    let output = Tokens()
    for (tok in desugared) {
        let isStringLike = tok.kind == TokenKind.STRING_LITERAL || tok.kind == TokenKind.MULTILINE_STRING
        // Only string-like tokens are submitted to the collector.
        if (isStringLike && collector.isIndependentStringLiteral(tok)) {
            output.append(quote($PROMPT_BUILDER.append($tok)))
        } else {
            output.append(tok)
        }
    }
    return output
}

/**
 * Desugars calls written with a dollar-identifier, `$id(...)`, into the string
 * literal `"${id(...)}"`, so that `id` (a prompt function) is called and its result
 * is interpolated. All other tokens are copied through unchanged.
 *
 * Throws whatever extractFuncArguments throws when a DOLLAR_IDENTIFIER is not
 * followed by a parenthesised argument list.
 */
func transformTemplateFuncCall(input: Tokens): Tokens {
    let output = Tokens()
    var idx = 0
    while (idx < input.size) {
        let tok = input[idx]
        if (tok.kind != TokenKind.DOLLAR_IDENTIFIER) {
            output.append(tok)
            idx += 1
            continue
        }
        let (args, closeIdx) = extractFuncArguments(input, idx)
        // Drop the leading `$` to obtain the plain function identifier.
        let bareName = Token(TokenKind.IDENTIFIER, tok.value[1..tok.value.size])
        let call = quote($bareName($args))
        output.append(Token(TokenKind.STRING_LITERAL, "\${" + call.toString() + "}"))
        // Resume right after the closing parenthesis of the call.
        idx = closeIdx + 1
    }
    return output
}

/**
 * Extracts the argument tokens of a `$id(...)` call.
 *
 * `begin` is the index of the DOLLAR_IDENTIFIER token. The token right after it
 * must be an LPAREN. Parentheses are matched by nesting depth, so arguments such
 * as `g(x)` in `$f(g(x))` are kept intact instead of stopping at the first RPAREN.
 *
 * Returns the tokens strictly between the opening LPAREN and its matching RPAREN,
 * together with the index of that matching RPAREN.
 *
 * Throws DslException if no LPAREN follows the identifier (including when the
 * identifier is the last token), or if the parenthesis is never closed.
 */
func extractFuncArguments(input: Tokens, begin: Int64): (Tokens, Int64) {
    let openPos = begin + 1
    // Bounds check first: the identifier may be the last token of the stream.
    if (openPos >= input.size || input[openPos].kind != TokenKind.LPAREN) {
        throw DslException("A LPAREN is expected")
    }

    let result = Tokens()
    // Number of currently open inner parentheses (excluding the call's own LPAREN).
    var depth = 0
    var currPos = openPos + 1
    while (currPos < input.size) {
        let token = input[currPos]
        if (token.kind == TokenKind.LPAREN) {
            depth += 1
        } else if (token.kind == TokenKind.RPAREN) {
            if (depth == 0) {
                // Matching RPAREN of the call itself.
                return (result, currPos)
            }
            depth -= 1
        }
        result.append(token)
        currPos += 1
    }
    throw DslException("A RPAREN is expected")
}

/**
 * Turns a (string-literal-rewritten) template function into the generated function.
 *
 * The result keeps the modifiers, generic type parameters, parameters and generic
 * constraints. Its name is the original name followed by
 * TEMPLATE_FUNC_NAME_POSTFIX, and its return type is whatever checkFuncRetString
 * yields. The body first declares `PROMPT_BUILDER` as a fresh `StringBuilder`,
 * then runs the original statements, and finally returns the builder's contents.
 *
 * If the function contains a `return` expression, an error is printed and empty
 * Tokens are returned. Presumably this drops the declaration from the macro output.
 */
func transformTemplateFunc(funcDecl: FuncDecl): Tokens {
    // An early `return` would skip the trailing `return PROMPT_BUILDER.toString()`
    // generated below, so such functions are rejected.
    if (hasReturnExpr(funcDecl)) {
        eprintln("Template functions cannot contain return expressions.")
        return Tokens()
    }
    let modifiers = funcDecl.modifiers
    let renamed = newIdentifierToken("${funcDecl.identifier.value}${TEMPLATE_FUNC_NAME_POSTFIX}")
    let params = funcDecl.funcParams
    let returnType = checkFuncRetString(funcDecl)
    let bodyNodes = funcDecl.block.nodes
    let typeParams = getGenericTypeParamsTokens(funcDecl)
    let constraints = getGenericConstraintsTokens(funcDecl)

    return quote (
        $modifiers func $renamed $typeParams ( $params ): $returnType $constraints {
            let $PROMPT_BUILDER = StringBuilder()
            $bodyNodes
            return $PROMPT_BUILDER.toString()
        }
    )
}

/**
 * Checks that a prompt function declares `String` as its return type.
 *
 * A mismatch is only reported with `eprintln`; it does not abort the expansion.
 * An omitted return type is accepted silently. Reading `declType` then throws
 * ASTException, which is caught here.
 *
 * Returns an IDENTIFIER token `String` in every case. It is used as the return
 * type of the generated function.
 */
func checkFuncRetString(funcDecl: FuncDecl): Token {
    try {
        // Reuse isString so every mismatch (non-reference type or a reference type
        // other than String) yields the same diagnostic, naming the function.
        if (!isString(funcDecl.declType)) {
            eprintln("Function ${funcDecl.identifier.value} should return String")
        }
    } catch (_: ASTException) {
        // No declared return type: the generated function will return String.
    }
    return Token(TokenKind.IDENTIFIER, "String")
}

/**
 * Prepends the tokens of an included file to the body of the function in `input`.
 *
 * The signature is re-emitted unchanged: modifiers, name, generic type parameters,
 * parameters, return type (via getRetTypeTokens) and generic constraints.
 * The new body is `includeTokens` followed by the original body statements.
 * Both therefore share the function's top-level scope.
 */
func concatPromptTokens(input: Tokens, includeTokens: Tokens): Tokens {
    let funcDecl: FuncDecl = parseFuncDecl(input)
    let modifierTokens = funcDecl.modifiers
    let nameTokens = funcDecl.identifier
    let paramListTokens = funcDecl.funcParams
    let retTypeTokens = getRetTypeTokens(funcDecl)
    // Splice the body statements rather than the Block node, matching
    // transformTemplateFunc. Interpolating the Block itself would emit its own
    // braces and nest a second `{ ... }` inside the new body.
    let funcBodyNodes = funcDecl.block.nodes
    let genericTypeParamTokens = getGenericTypeParamsTokens(funcDecl)
    let genericConstraintTokens = getGenericConstraintsTokens(funcDecl)

    // The parameter list is interpolated without surrounding parentheses
    // (transformTemplateFunc wraps it the same way), so they are written explicitly.
    return quote (
        $modifierTokens func $nameTokens $genericTypeParamTokens ( $paramListTokens ): $retTypeTokens $genericConstraintTokens {
            $includeTokens
            $funcBodyNodes
        }
    )
}

/**
 * Returns true when every parameter of `funcDecl` is declared with type `String`
 * (as judged by isString). A function without parameters yields true.
 * Checking stops at the first non-String parameter.
 */
func checkAllParamOfStringType(funcDecl: FuncDecl): Bool {
    for (param in funcDecl.funcParams where !isString(param.paramType)) {
        return false
    }
    return true
}

/**
 * Returns true when `ty` is a reference type whose identifier is the plain
 * IDENTIFIER token `String`. Any other type node yields false.
 */
func isString(ty: TypeNode): Bool {
    if (ty.isRefType()) {
        let typeName = ty.asRefType().identifier
        return typeName.kind == TokenKind.IDENTIFIER && typeName.value == "String"
    }
    return false
}

/**
 * Returns the names of `funcDecl`'s parameters, in declaration order.
 */
func getFuncParamNameList(funcDecl: FuncDecl): ArrayList<String> {
    return funcDecl.funcParams
        |> map { param: FuncParam => param.identifier.value }
        |> collectArrayList
}

/**
 * Returns the number of parameters declared by `funcDecl`.
 */
func getFuncParamSize(funcDecl: FuncDecl): Int64 {
    let params = funcDecl.funcParams
    params.size
}