#!/usr/bin/env python3
"""
Phase 2 v2: Parse Compose UI Trees using tree-sitter AST

Replaces regex+brace-depth parsing with proper AST analysis via tree-sitter-kotlin.
Handles all Kotlin patterns: lambda defaults, CompositionLocalProvider, nested control
flow, Modifier chains, and recursive custom composable expansion.

Usage:
    python3 parse_compose_tree_v2.py <project_root> <discovery_json> <output_dir>

Output:
    <output_dir>/compose_trees/<ScreenName>.json
"""

import json
import os
import re
import sys
from pathlib import Path

import tree_sitter
import tree_sitter_kotlin as tsk

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from salt_ui_mapping import get_component_info, is_non_ui_call, COMPOSE_NON_UI

KOTLIN = tree_sitter.Language(tsk.language())

# ---------------------------------------------------------------------------
# String resource resolver (same logic as v1, kept lightweight)
# ---------------------------------------------------------------------------
import xml.etree.ElementTree as ET

class StringResourceResolver:
    def __init__(self):
        self.strings: dict[str, str] = {}

    def load(self, project_root: str):
        res_base = os.path.join(project_root, "app", "src", "main", "res")
        if not os.path.isdir(res_base):
            return
        for d in sorted(os.listdir(res_base)):
            if d.startswith("values"):
                vdir = os.path.join(res_base, d)
                if os.path.isdir(vdir):
                    for f in os.listdir(vdir):
                        if f.endswith(".xml"):
                            self._parse(os.path.join(vdir, f), d == "values")

    def _parse(self, path, is_default):
        try:
            root = ET.parse(path).getroot()
        except ET.ParseError:
            return
        for e in root:
            if e.tag == "string":
                name = e.get("name", "")
                text = (e.text or "").strip()
                if name and (is_default or name not in self.strings):
                    self.strings[name] = text

    def resolve(self, text: str) -> str | None:
        m = re.search(r'R\.string\.(\w+)', text)
        if m and m.group(1) in self.strings:
            return self.strings[m.group(1)]
        return None

# ---------------------------------------------------------------------------
# AST helpers
# ---------------------------------------------------------------------------

def node_text(node) -> str:
    """Return the source text of *node* decoded as UTF-8, or ``""`` for a missing node.

    Undecodable bytes are replaced rather than raising.
    """
    if not node:
        return ""
    raw = node.text
    return raw.decode("utf-8", errors="replace")

def find_children(node, type_name: str):
    """Return all direct children of *node* whose ``type`` is *type_name*, in source order."""
    return list(filter(lambda kid: kid.type == type_name, node.children))

def find_child(node, type_name: str):
    """Return the first direct child of *node* whose ``type`` is *type_name*, else ``None``."""
    return next((kid for kid in node.children if kid.type == type_name), None)

def is_composable_name(name: str) -> bool:
    """Return True when *name* looks like a composable call.

    A composable name is non-empty, starts with an uppercase letter, and is not
    reported as a non-UI call by ``is_non_ui_call``.
    """
    if not name or not name[0].isupper():
        return False
    return not is_non_ui_call(name)

def get_call_name(call_node) -> str:
    """Return the callee name of a ``call_expression`` node.

    Only the first child that is an identifier, navigation expression or nested
    call is considered:

    * ``identifier``             -> its text, e.g. ``"Foo"``;
    * ``navigation_expression``  -> all of its direct identifiers joined with
      ``"."`` (e.g. ``"Modifier.padding"``), or its raw text if it has none;
    * ``call_expression``        -> recurse (trailing-lambda form ``Foo(...) { }``).

    Returns ``""`` when no such child exists.
    """
    for child in call_node.children:
        kind = child.type
        if kind == "identifier":
            return node_text(child)
        if kind == "call_expression":
            return get_call_name(child)
        if kind == "navigation_expression":
            segments = [node_text(sub) for sub in child.children if sub.type == "identifier"]
            if not segments:
                return node_text(child)
            return ".".join(segments)
    return ""

def get_simple_call_name(call_node) -> str:
    """Return the bare callee name of a ``call_expression`` (no receiver prefix).

    Lookup order:
    1. a direct ``identifier`` child (``Foo(...)``);
    2. a nested ``call_expression`` (trailing-lambda form), recursively;
    3. the last direct identifier of a ``navigation_expression`` (``a.b.foo(...)`` -> ``foo``).
    Returns ``""`` when none of these apply.
    """
    ident = find_child(call_node, "identifier")
    if ident is not None:
        return node_text(ident)

    nested = find_child(call_node, "call_expression")
    if nested:
        return get_simple_call_name(nested)

    nav = find_child(call_node, "navigation_expression")
    if not nav:
        return ""
    nav_ids = [node_text(part) for part in nav.children if part.type == "identifier"]
    return nav_ids[-1] if nav_ids else ""

# ---------------------------------------------------------------------------
# Modifier chain parser
# ---------------------------------------------------------------------------

def parse_modifier_chain(node) -> list[dict]:
    """Flatten a Modifier chain such as ``Modifier.fillMaxSize().padding(16.dp)``.

    Returns one ``{"name": ..., "args": {...}}`` entry per modifier call, in
    source order; see ``_collect_modifier_calls`` for the extra keys.
    """
    collected: list[dict] = []
    _collect_modifier_calls(node, collected)
    return collected

def _collect_modifier_calls(node, result: list):
    """Append one entry per modifier call in *node* to *result*, receiver-first.

    Each entry is ``{"name": <method>, "args": {...}}``.  A call with a trailing
    lambda (e.g. ``.clickable { }``) also gets ``"has_lambda": True`` and, when
    the lambda contains navigate/safeNavigate calls, ``"navigates_to": [...]``.
    The bare ``Modifier`` receiver itself is never emitted.
    """
    if node is None:
        return

    if node.type == "navigation_expression":
        # Property hop: continue with the call it is attached to, if any.
        receiver_call = find_child(node, "call_expression")
        if receiver_call:
            _collect_modifier_calls(receiver_call, result)
        return

    if node.type != "call_expression":
        return

    nav = find_child(node, "navigation_expression")
    args_node = find_child(node, "value_arguments")

    if not nav:
        # Receiver-less call such as `then(...)`: recorded only when it has arguments.
        if args_node:
            simple = get_simple_call_name(node)
            if simple and simple != "Modifier":
                result.append({"name": simple, "args": _extract_value_args(args_node)})
        return

    # Walk the receiver chain first so entries come out left-to-right.
    receiver_call = find_child(nav, "call_expression")
    if receiver_call:
        _collect_modifier_calls(receiver_call, result)

    idents = [node_text(part) for part in nav.children if part.type == "identifier"]
    method = idents[-1] if idents else ""
    if not method or method == "Modifier":
        return

    entry = {"name": method, "args": _extract_value_args(args_node) if args_node else {}}
    trailing = find_child(node, "annotated_lambda")
    if trailing:
        entry["has_lambda"] = True
        targets = _find_navigate_calls(trailing)
        if targets:
            entry["navigates_to"] = targets
    result.append(entry)

def _find_navigate_calls(node) -> list[str]:
    """Find navController.navigate(ScreenRoute.XXX) calls in a subtree."""
    targets = []
    text = node_text(node)
    for m in re.finditer(r'(?:navigate|safeNavigate)\s*\(\s*(?:ScreenRoute\.)?(\w+)', text):
        targets.append(m.group(1))
    return targets

# ---------------------------------------------------------------------------
# Value argument extraction
# ---------------------------------------------------------------------------

def _extract_value_args(args_node) -> dict:
    """Convert a ``value_arguments`` node into ``{key: argument source text}``.

    Named arguments (``name = value``) are keyed by their name.  Positional
    arguments get synthetic keys ``_pos_0``, ``_pos_1``, ... in order; blank ones
    are skipped and do not consume an index.  Returns ``{}`` for ``None``.
    """
    if args_node is None:
        return {}
    out: dict = {}
    next_pos = 0
    for arg in find_children(args_node, "value_argument"):
        label = find_child(arg, "identifier")
        has_eq = any(part.type == "=" for part in arg.children)

        if not (label and has_eq):
            text = node_text(arg).strip()
            if text:
                out[f"_pos_{next_pos}"] = text
                next_pos += 1
            continue

        # Everything after the first "=", minus further "=" tokens and the label itself.
        pieces = []
        seen_eq = False
        for part in arg.children:
            if part.type == "=":
                seen_eq = True
            elif seen_eq and part != label:
                pieces.append(node_text(part))
        value = " ".join(pieces).strip()
        if not value:
            # Regex fallback over the raw argument text.
            raw = node_text(arg)
            match = re.match(r'\w+\s*=\s*(.+)', raw, re.DOTALL)
            value = match.group(1).strip() if match else raw
        out[node_text(label)] = value
    return out

# ---------------------------------------------------------------------------
# Core: AST → Compose tree conversion
# ---------------------------------------------------------------------------

class ComposeTreeBuilder:
    """Parse one Kotlin source file with tree-sitter and turn @Composable bodies into UI trees.

    Tree nodes are plain dicts.  A composable call becomes ``{"type": <name>, ...}``
    with ``salt_semantic``/``library``/``arkts_hint``/``clickable`` for known
    components or ``custom: True`` otherwise, plus ``params``, ``resolved_strings``,
    ``modifiers``, ``actions``, ``content_params`` and ``children`` when non-empty.
    Control flow becomes ``ConditionalBlock``, ``WhenBlock``, ``ForLoop`` or
    ``LazyScope.<call>`` nodes.
    """

    MAX_DEPTH = 20
    # Named arguments treated as event handlers (recorded as actions, never as content).
    EVENT_PARAMS = frozenset({
        "onClick", "onLongClick", "onChange", "onValueChange",
        "onConfirm", "onDismissRequest", "onYes", "onNo", "onRefresh",
    })
    # Named arguments whose lambdas are not parsed as composable content.
    NON_CONTENT_PARAMS = EVENT_PARAMS | {"modifier"}
    # Wrappers that add no UI of their own; their content is inlined into the parent.
    TRANSPARENT_CONTAINERS = frozenset({"CompositionLocalProvider", "ProvideTextStyle", "MaterialTheme"})
    # Lowercase Lazy* DSL calls whose trailing lambda contains UI.
    LAZY_SCOPE_CALLS = frozenset({"items", "itemsIndexed", "item", "stickyHeader"})
    # Iteration helpers whose trailing lambda emits UI once per element/iteration.
    LOOP_CALLS = frozenset({"forEach", "forEachIndexed", "repeat"})
    # Child types never taken as an `if` condition in the paren-less fallback.
    _NON_CONDITION_TYPES = frozenset({
        "(", ")", "if", "else", "{", "block", "call_expression", "if_expression", "when_expression",
    })
    # when_entry children that carry a branch condition.
    _WHEN_CONDITION_TYPES = frozenset({"type_test", "range_test", "expression", "when_condition"})

    def __init__(self, source: bytes, resolver: StringResourceResolver | None = None):
        self.source = source
        self.resolver = resolver
        self.parser = tree_sitter.Parser(KOTLIN)
        self.tree = self.parser.parse(source)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_list(parsed) -> list[dict]:
        """Normalise a parse result (dict, list of dicts or None) to a list."""
        if not parsed:
            return []
        return parsed if isinstance(parsed, list) else [parsed]

    @staticmethod
    def _lambda_literal(annotated_lambda):
        """Return the ``lambda_literal`` inside an ``annotated_lambda`` node, or None."""
        return find_child(annotated_lambda, "lambda_literal") if annotated_lambda else None

    @staticmethod
    def _node_after(parent, token_type: str):
        """Return the child right after the first *token_type* child of *parent*, or None."""
        kids = parent.children
        for idx in range(len(kids) - 1):
            if kids[idx].type == token_type:
                return kids[idx + 1]
        return None

    @staticmethod
    def _paren_span(children):
        """Return indices of the first direct "(" child and the first ")" after it, or None.

        Nested parentheses live inside child subtrees, so the first direct ")"
        after "(" is the matching one.
        """
        open_idx = next((i for i, c in enumerate(children) if c.type == "("), None)
        if open_idx is None:
            return None
        close_idx = next(
            (i for i in range(open_idx + 1, len(children)) if children[i].type == ")"), None
        )
        if close_idx is None:
            return None
        return open_idx, close_idx

    @staticmethod
    def _strip_outer_parens(text: str) -> str:
        """Remove one pair of parentheses wrapping all of *text*, if present.

        Unlike ``str.strip("()")`` this keeps call parentheses:
        ``"(foo())"`` -> ``"foo()"``, ``"(a) + (b)"`` is left unchanged.
        """
        text = text.strip()
        if len(text) < 2 or text[0] != "(" or text[-1] != ")":
            return text
        depth = 0
        for i, ch in enumerate(text):
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0 and i < len(text) - 1:
                    return text  # the leading "(" closes early: not a wrapper
        return text[1:-1].strip()

    def _source_text(self, start_byte: int, end_byte: int) -> str:
        """Return the stripped source text between two byte offsets."""
        return self.source[start_byte:end_byte].decode("utf-8", errors="replace").strip()

    # -------------------------------------------------------- function discovery

    def find_composable_functions(self) -> list[tuple[str, object]]:
        """Return ``(name, function_declaration node)`` for every @Composable fun, in source order."""
        results = []
        self._find_composable_fns(self.tree.root_node, results)
        return results

    def _find_composable_fns(self, node, results):
        # NOTE: a byte search over the modifiers text, so e.g. @ComposableTarget also matches.
        if node.type == "function_declaration":
            mods = find_child(node, "modifiers")
            if mods and b"@Composable" in mods.text:
                name = find_child(node, "identifier")
                if name:
                    results.append((node_text(name), node))
        for child in node.children:
            self._find_composable_fns(child, results)

    def extract_function_params(self, func_node) -> list[dict]:
        """Return ``[{"name", "type", "default"}, ...]`` for the function's parameters.

        ``type`` is the declared type text ("" if not found).  ``default`` is the
        default-value source text, or None when the parameter has no default.
        """
        params_node = find_child(func_node, "function_value_parameters")
        if not params_node:
            return []
        siblings = params_node.children
        params = []
        for i, child in enumerate(siblings):
            if child.type != "parameter":
                continue
            name = find_child(child, "identifier")
            # user_type / function_type first; otherwise whatever follows ":" (covers
            # nullable and other type forms, which were previously reported as "").
            type_node = (find_child(child, "user_type")
                         or find_child(child, "function_type")
                         or self._node_after(child, ":"))
            # A default is `parameter "=" expression` at this level, so it must be the
            # sibling right after *this* parameter.  Matching any "=" on the same line
            # gave `a` the default of `b` in `fun F(a: Int, b: Int = 2)`.
            default = None
            if i + 2 < len(siblings) and siblings[i + 1].type == "=":
                default = node_text(siblings[i + 2]).strip()
            else:
                inner_default = self._node_after(child, "=")
                if inner_default is not None:
                    default = node_text(inner_default).strip()
            params.append({
                "name": node_text(name) if name else "",
                "type": node_text(type_node) if type_node else "",
                "default": default,
            })
        return params

    def parse_function_body(self, func_node) -> list[dict]:
        """Parse a function body into a list of UI tree nodes.

        Supports block bodies (``fun F() { ... }``) and expression bodies
        (``fun F() = Column { ... }``).
        """
        body = find_child(func_node, "function_body")
        if not body:
            return []
        block = find_child(body, "block")
        if block:
            return self._parse_block(block)
        # Expression body: the "=" token yields nothing, the expression is parsed.
        return self._parse_block(body)

    # ------------------------------------------------------------- tree building

    def _parse_block(self, block_node, depth=0) -> list[dict]:
        """Parse the statements of a block into a flat list of UI nodes."""
        if depth > self.MAX_DEPTH:
            return []
        nodes: list[dict] = []
        for child in block_node.children:
            if child.type in ("{", "}"):
                continue
            nodes.extend(self._as_list(self._extract_ui_node(child, depth)))
        return nodes

    def _extract_ui_node(self, node, depth=0) -> dict | list[dict] | None:
        """Extract a UI node (or a list, for transparent containers) from one statement."""
        if depth > self.MAX_DEPTH:
            return None
        kind = node.type
        if kind == "call_expression":
            return self._parse_call_expression(node, depth)
        if kind == "if_expression":
            return self._parse_if_expression(node, depth)
        if kind == "when_expression":
            return self._parse_when_expression(node, depth)
        if kind == "for_statement":
            return self._parse_for_statement(node, depth)
        # property_declaration (state) is handled by Phase 3a; other statements carry no UI.
        return None

    def _parse_call_expression(self, node, depth) -> dict | list[dict] | None:
        """Parse a call_expression into a tree node.

        tree-sitter represents `Foo(params) { content }` as:
            call_expression               <- OUTER
              call_expression             <- INNER: has identifier + value_arguments
                identifier = 'Foo'
                value_arguments(...)
              annotated_lambda            <- trailing lambda on OUTER

        Simple calls without trailing lambda:
            call_expression
              identifier = 'Foo'
              value_arguments(...)

        Returns a dict for composable / lazy-scope / loop calls, a *list* of
        children for transparent containers (the caller splices it in), or None
        for non-UI calls.
        """
        inner_call = find_child(node, "call_expression")
        trailing_lambda = find_child(node, "annotated_lambda")

        if inner_call and trailing_lambda:
            # Chained pattern: params from inner, content from trailing lambda
            name = get_simple_call_name(inner_call)
            args_node = find_child(inner_call, "value_arguments")
        else:
            name = get_simple_call_name(node)
            args_node = find_child(node, "value_arguments")

        if not name:
            return None

        content = self._lambda_literal(trailing_lambda)

        if name in self.TRANSPARENT_CONTAINERS:
            return self._parse_lambda_body(content, depth + 1) if content else None

        if name in self.LOOP_CALLS and content:
            # `items.forEach { Row(it) }` / `repeat(3) { ... }` emit UI per iteration.
            children = self._parse_lambda_body(content, depth + 1)
            if not children:
                return None
            return {
                "type": "ForLoop",
                "expression": self._source_text(node.start_byte, trailing_lambda.start_byte),
                "children": children,
            }

        if name in self.LAZY_SCOPE_CALLS:
            # Handled regardless of is_non_ui_call(): these lowercase DSL calls hold UI.
            children = self._parse_lambda_body(content, depth + 1) if content else []
            if not children:
                return None
            return {
                "type": f"LazyScope.{name}",
                "params": _extract_value_args(args_node) if args_node else {},
                "children": children,
            }

        # Lowercase, underscore-led and known non-UI calls are not composables.
        if not is_composable_name(name):
            return None

        comp_info = get_component_info(name)
        result: dict = {"type": name}
        if comp_info:
            result["salt_semantic"] = comp_info.get("type")
            result["library"] = comp_info.get("library")
            result["arkts_hint"] = comp_info.get("arkts_hint")
            if comp_info.get("clickable"):
                result["clickable"] = True
        else:
            result["custom"] = True

        if args_node:
            params = _extract_value_args(args_node)
            if params:
                result["params"] = params
                resolved = self._resolve_string_params(params)
                if resolved:
                    result["resolved_strings"] = resolved
                    params.update(resolved)

        modifiers = self._extract_modifiers(result.get("params", {}), args_node)
        if modifiers:
            result["modifiers"] = modifiers

        actions = self._extract_actions_from(args_node)
        if actions:
            result["actions"] = actions

        content_params = self._extract_content_params(args_node, depth)
        if content_params:
            result["content_params"] = content_params

        if content:
            children = self._parse_lambda_body(content, depth + 1)
            if children:
                result["children"] = children

        return result

    def _resolve_string_params(self, params: dict) -> dict:
        """Map param name -> resolved text for params that reference ``R.string.*``."""
        if not self.resolver:
            return {}
        resolved = {}
        for key, value in params.items():
            if isinstance(value, str) and "R.string." in value:
                text = self.resolver.resolve(value)
                if text:
                    resolved[key] = text
        return resolved

    def _extract_modifiers(self, params: dict, args_node) -> list[dict]:
        """Parse the ``modifier = Modifier...`` argument into a structured chain."""
        modifier_text = params.get("modifier", "")
        if not (modifier_text and "Modifier" in modifier_text and args_node):
            return []
        found: list[dict] = []
        for va in find_children(args_node, "value_argument"):
            va_name = find_child(va, "identifier")
            if not (va_name and node_text(va_name) == "modifier"):
                continue
            for c in va.children:
                if c.type in ("call_expression", "navigation_expression"):
                    mods = parse_modifier_chain(c)
                    if mods:
                        found = mods
                    break
        return found

    def _extract_content_params(self, args_node, depth) -> dict:
        """Parse composable lambda arguments (e.g. ``topBar = { ... }``) into children."""
        content_params: dict = {}
        if not args_node:
            return content_params
        for va in find_children(args_node, "value_argument"):
            va_name = find_child(va, "identifier")
            if not va_name:
                continue
            pname = node_text(va_name)
            if pname in self.NON_CONTENT_PARAMS:
                continue
            lit = find_child(va, "lambda_literal")
            if lit:
                children = self._parse_lambda_body(lit, depth + 1)
                if children:
                    content_params[pname] = children
        return content_params

    def _extract_actions_from(self, args_node) -> list[dict]:
        """Extract event handler actions (onClick etc.) from value_arguments."""
        if not args_node:
            return []
        actions = []
        for va in find_children(args_node, "value_argument"):
            name_id = find_child(va, "identifier")
            if not name_id:
                continue
            pname = node_text(name_id)
            if pname not in self.EVENT_PARAMS:
                continue

            action = {"event": pname}
            handler_text = node_text(va)

            nav_targets = _find_navigate_calls(va)
            if nav_targets:
                action["navigates_to"] = nav_targets[0]

            activity_match = re.search(r'(\w+Activity)::class', handler_text)
            if activity_match:
                action["starts_activity"] = activity_match.group(1)

            m = re.search(r'=\s*\{(.*)\}', handler_text, re.DOTALL)
            if m:
                body = m.group(1).strip()
                if len(body) < 150:  # keep only short, readable handlers
                    action["handler_body"] = body

            actions.append(action)
        return actions

    def _parse_lambda_body(self, lambda_lit, depth) -> list[dict]:
        """Parse a lambda_literal's statements into a flat list of UI nodes."""
        nodes: list[dict] = []
        for child in lambda_lit.children:
            if child.type in ("{", "}", "lambda_parameters", "->"):
                continue
            nodes.extend(self._as_list(self._extract_ui_node(child, depth)))
        return nodes

    def _parse_if_expression(self, node, depth) -> dict | None:
        """Parse if/else into a ``ConditionalBlock`` node (None if no branch has UI).

        An ``else if`` chain becomes a nested ConditionalBlock as the only
        ``else_children`` entry.
        """
        children = node.children
        span = self._paren_span(children)
        if span:
            open_idx, close_idx = span
            # Raw text between the parentheses: keeps call conditions like
            # `items.isEmpty()`, which the type-based heuristic dropped.
            condition = self._source_text(children[open_idx].end_byte, children[close_idx].start_byte)
            branch_nodes = children[close_idx + 1:]
        else:
            condition = self._fallback_if_condition(children)
            branch_nodes = children

        then_children: list[dict] = []
        else_children: list[dict] = []
        found_else = False
        for child in branch_nodes:
            if child.type == "else":
                found_else = True
                continue
            if child.type == "block":
                parsed = self._parse_block(child, depth + 1)
                if found_else:
                    else_children = parsed
                else:
                    then_children = parsed
            elif child.type == "if_expression" and found_else:
                inner = self._parse_if_expression(child, depth + 1)
                if inner:
                    else_children = [inner]
            elif child.type == "call_expression":
                # Single-line if: if (cond) SomeComposable()
                parsed = self._parse_call_expression(child, depth + 1)
                if parsed:
                    # Transparent containers return a list: splice it, don't nest it.
                    if found_else:
                        else_children = self._as_list(parsed)
                    else:
                        then_children = self._as_list(parsed)

        if not then_children and not else_children:
            return None

        result = {
            "type": "ConditionalBlock",
            "condition": condition,
            "then_children": then_children,
        }
        if else_children:
            result["else_children"] = else_children
        return result

    def _fallback_if_condition(self, children) -> str:
        """Condition guess for grammars without paren tokens: first non-body child's text."""
        for child in children:
            if child.type in self._NON_CONDITION_TYPES:
                continue
            text = self._strip_outer_parens(node_text(child))
            if text:
                return text
        return ""

    def _parse_when_expression(self, node, depth) -> dict | None:
        """Parse a when expression into a ``WhenBlock`` node (None if it has no branches)."""
        subject_node = find_child(node, "when_subject")
        subject = self._strip_outer_parens(node_text(subject_node)) if subject_node else ""

        branches = []
        for entry in find_children(node, "when_entry"):
            conditions: list[str] = []
            children: list[dict] = []
            for child in entry.children:
                kind = child.type
                if kind in self._WHEN_CONDITION_TYPES:
                    # `A, B -> ...` has several conditions; keep them all.
                    conditions.append(node_text(child))
                elif kind == "block":
                    children = self._parse_block(child, depth + 1)
                elif kind in ("call_expression", "if_expression", "when_expression"):
                    parsed = self._extract_ui_node(child, depth + 1)
                    if parsed:
                        children = self._as_list(parsed)
            condition = ", ".join(conditions)
            if condition or children:
                branches.append({"condition": condition, "children": children})

        if not branches:
            return None

        return {
            "type": "WhenBlock",
            "expression": subject,
            "branches": branches,
        }

    def _parse_for_statement(self, node, depth) -> dict | None:
        """Parse a for loop into a ``ForLoop`` node (None if its body has no UI)."""
        kids = node.children
        children: list[dict] = []
        block = find_child(node, "block")
        if block:
            children = self._parse_block(block, depth + 1)
            header_end = block.start_byte
        else:
            # Single-statement body: `for (x in xs) Item(x)` — body follows the ")".
            span = self._paren_span(kids)
            header_end = kids[span[1]].end_byte if span else node.end_byte
            for child in (kids[span[1] + 1:] if span else []):
                children.extend(self._as_list(self._extract_ui_node(child, depth + 1)))
        if not children:
            return None
        return {
            "type": "ForLoop",
            # Byte slice up to the body, so a "{" inside the header
            # (e.g. `xs.filter { it.ok }`) no longer truncates it.
            "expression": self._source_text(node.start_byte, header_end),
            "children": children,
        }



# ---------------------------------------------------------------------------
# Custom composable expander (recursive via tree-sitter)
# ---------------------------------------------------------------------------

class RecursiveExpander:
    """Recursively expand custom composables by parsing their source definitions.

    A node marked ``"custom": True`` with no ``children`` gets the parsed body of
    its ``@Composable fun`` (searched across all ``.kt`` files under
    *project_root*) attached as ``expanded_children``.
    """

    # Suffixes stripped from a composable name to guess its parent file
    # (e.g. HomeScreenContent -> HomeScreen.kt).
    _FILE_SUFFIXES = ("Content", "Body", "View", "UI", "Impl", "Internal")
    # Node keys whose values are lists of child nodes.
    _CHILD_KEYS = ("children", "expanded_children", "then_children", "else_children")

    def __init__(self, project_root: str, resolver: StringResourceResolver | None):
        self.project_root = project_root
        self.resolver = resolver
        self._source_cache: dict[str, bytes] = {}
        self._tree_cache: dict[str, ComposeTreeBuilder] = {}
        self._kt_index: dict[str, str] | None = None
        self._expanding: set[str] = set()  # cycle guard
        # path -> {composable name: function node}; avoids re-walking every AST per lookup
        self._fn_cache: dict[str, dict[str, object]] = {}
        # composable name -> (builder, fn_node) or None; misses are cached too
        self._def_cache: dict[str, tuple[ComposeTreeBuilder, object] | None] = {}

    def _build_index(self):
        """Map Kotlin file stem -> path for every ``.kt`` file under the project (built once).

        When two files share a stem, the later one in sorted walk order wins.
        """
        if self._kt_index is not None:
            return
        self._kt_index = {}
        for root, dirs, files in os.walk(self.project_root):
            # Deterministic order (os.walk order is filesystem-dependent); skip hidden
            # directories such as .git/.gradle/.idea.
            dirs[:] = sorted(d for d in dirs if not d.startswith("."))
            for f in sorted(files):
                if f.endswith(".kt"):
                    self._kt_index[f[:-3]] = os.path.join(root, f)

    def _get_builder(self, path: str) -> ComposeTreeBuilder:
        """Return the cached parser for *path*, reading and parsing it on first use."""
        if path not in self._tree_cache:
            if path not in self._source_cache:
                with open(path, "rb") as f:
                    self._source_cache[path] = f.read()
            self._tree_cache[path] = ComposeTreeBuilder(self._source_cache[path], self.resolver)
        return self._tree_cache[path]

    def _functions_in(self, path: str) -> dict[str, object]:
        """Return {composable name: fn node} for *path* (cached; {} if unreadable)."""
        fns = self._fn_cache.get(path)
        if fns is None:
            fns = {}
            try:
                builder = self._get_builder(path)
            except OSError:
                pass  # unreadable file: best effort, treat as defining nothing
            else:
                for fn_name, fn_node in builder.find_composable_functions():
                    fns.setdefault(fn_name, fn_node)  # first definition in a file wins
            self._fn_cache[path] = fns
        return fns

    def _candidate_files(self, name: str) -> list[str]:
        """Return every indexed file, most likely locations of *name* first, without duplicates."""
        candidates: list[str] = []
        seen: set[str] = set()

        def add(path: str):
            if path not in seen:
                seen.add(path)
                candidates.append(path)

        # 1. File named after the composable (e.g., HomeScreenContent.kt)
        if name in self._kt_index:
            add(self._kt_index[name])
        # 2. File named after the base screen (HomeScreenContent -> HomeScreen.kt).
        #    removesuffix: only a trailing suffix is stripped, not every occurrence.
        for suffix in self._FILE_SUFFIXES:
            base = name.removesuffix(suffix)
            if base != name and base in self._kt_index:
                add(self._kt_index[base])
        # 3. All remaining Kotlin files (composables often live in the caller's file).
        for fpath in self._kt_index.values():
            add(fpath)
        return candidates

    def _find_composable_def(self, name: str) -> tuple[ComposeTreeBuilder, object] | None:
        """Find a @Composable fun definition by name; returns (builder, fn node) or None."""
        if name in self._def_cache:
            return self._def_cache[name]
        self._build_index()
        found = None
        for fpath in self._candidate_files(name):
            fn_node = self._functions_in(fpath).get(name)
            if fn_node is not None:
                found = (self._tree_cache[fpath], fn_node)
                break
        self._def_cache[name] = found
        return found

    def _expand_custom(self, node: dict, max_depth: int, child_depth: int) -> bool:
        """Attach ``expanded_children`` to a childless custom node and recurse into them.

        Returns True when the node was expanded (and its new children already walked).
        """
        if not node.get("custom") or "children" in node or "expanded_children" in node:
            return False
        comp_name = node.get("type", "")
        if not comp_name or comp_name in self._expanding:
            return False
        self._expanding.add(comp_name)
        try:
            found = self._find_composable_def(comp_name)
            if not found:
                return False
            builder, fn_node = found
            children = builder.parse_function_body(fn_node)
            if not children:
                return False
            node["expanded_children"] = children
            node.pop("custom", None)
            node["expanded_from"] = comp_name
            # Walk while comp_name is still in the guard so self-recursion stops here.
            self.expand_tree(children, max_depth, child_depth)
            return True
        finally:
            self._expanding.discard(comp_name)

    def expand_tree(self, nodes: list[dict], max_depth: int = 3, current_depth: int = 0):
        """Expand custom composable nodes in *nodes* in place, up to *max_depth* levels.

        Descends into child lists, ``when`` branches and ``content_params``
        (e.g. ``Scaffold(topBar = { MyTopBar() })``); every nesting level counts
        toward *max_depth*.
        """
        if current_depth >= max_depth:
            return
        next_depth = current_depth + 1
        for node in nodes:
            if not isinstance(node, dict):
                continue
            expanded_now = self._expand_custom(node, max_depth, next_depth)
            for key in self._CHILD_KEYS:
                if key == "expanded_children" and expanded_now:
                    continue  # already walked under the cycle guard; don't walk twice
                if key in node:
                    self.expand_tree(node[key], max_depth, next_depth)
            for branch in node.get("branches", []):
                self.expand_tree(branch.get("children", []), max_depth, next_depth)
            for content in node.get("content_params", {}).values():
                self.expand_tree(content, max_depth, next_depth)


# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------

def process_screen(project_root: str, screen_name: str, source_file: str,
                   resolver: StringResourceResolver | None = None) -> dict:
    """Parse composable *screen_name* from *source_file* into a tree JSON dict.

    *source_file* is relative to *project_root*.  Returns:
      * ``{"error": ...}`` when the file does not exist;
      * ``composable_name``/``source_file``/``function_parameters``/``compose_tree``
        when the function is found;
      * otherwise a fallback with a ``note``; if any composable in the file
        produced a tree, all of them are listed under ``all_composables`` and
        ``compose_tree`` is that tree only when exactly one exists.
    """
    full_path = os.path.join(project_root, source_file)
    if not os.path.exists(full_path):
        return {"error": f"Source file not found: {source_file}"}

    with open(full_path, "rb") as fh:
        raw_source = fh.read()

    builder = ComposeTreeBuilder(raw_source, resolver)
    functions = builder.find_composable_functions()
    target = next((fn for fn_name, fn in functions if fn_name == screen_name), None)

    if not target:
        # Target missing: fall back to every composable that yields a tree.
        trees = {}
        for fn_name, fn in functions:
            parsed = builder.parse_function_body(fn)
            if parsed:
                trees[fn_name] = parsed
        if not trees:
            return {
                "composable_name": screen_name,
                "source_file": source_file,
                "compose_tree": [],
                "note": "No composable functions found",
            }
        return {
            "composable_name": screen_name,
            "source_file": source_file,
            "note": f"Target function '{screen_name}' not found, extracted {len(trees)} composables",
            "all_composables": dict(trees),
            "compose_tree": next(iter(trees.values())) if len(trees) == 1 else [],
        }

    return {
        "composable_name": screen_name,
        "source_file": source_file,
        "function_parameters": builder.extract_function_params(target),
        "compose_tree": builder.parse_function_body(target) or [],
    }


def main():
    """Run Phase 2: build Compose UI tree JSON files for a discovered project.

    Usage: ``parse_compose_tree_v2.py <project_root> <discovery_json> <output_dir>``

    Reads the Phase 1 discovery JSON and writes one file per composable to
    ``<output_dir>/compose_trees/``:

    * ``<ScreenName>.json`` for each entry in ``compose_registrations``;
    * ``Shell_<name>.json`` for each entry in ``shell_components``;
    * ``Activity_<simple_name>_<fn>.json`` for every composable with a
      non-empty tree in each activity whose ``ui_type`` is ``"compose"``.

    A second pass then runs ``RecursiveExpander`` over every ``.json`` file in
    that directory and rewrites any file whose tree gained expanded nodes.
    Exits with status 1 when fewer than three arguments are supplied.
    """
    if len(sys.argv) < 4:
        print("Usage: python3 parse_compose_tree_v2.py <project_root> <discovery_json> <output_dir>")
        sys.exit(1)

    project_root = os.path.abspath(sys.argv[1])
    discovery_path = os.path.abspath(sys.argv[2])
    output_dir = os.path.abspath(sys.argv[3])

    trees_dir = os.path.join(output_dir, "compose_trees")
    os.makedirs(trees_dir, exist_ok=True)

    with open(discovery_path, "r", encoding="utf-8") as f:
        discovery = json.load(f)

    print("[Phase 2 v2] Parsing Compose UI trees with tree-sitter AST")

    # Initialize resolver
    print("  Loading string resources...")
    resolver = StringResourceResolver()
    resolver.load(project_root)
    print(f"  Loaded {len(resolver.strings)} string resources")

    # Pre-scan project widgets so they get recognized in the tree.
    # NOTE(review): the return value is only used for the log line below;
    # presumably scan_project_widgets registers the widgets inside
    # salt_ui_mapping as a side effect — confirm.
    from salt_ui_mapping import scan_project_widgets
    project_widgets = scan_project_widgets(project_root)
    if project_widgets:
        print(f"  Loaded {len(project_widgets)} project-specific widget definitions")

    processed = 0
    errors = 0
    total_nodes = 0

    # Keys under which a tree node nests further node lists. "branches" is
    # handled separately because each branch wraps its own "children" list.
    child_keys = ("children", "expanded_children", "then_children", "else_children")

    def iter_nodes(nodes):
        """Yield every dict node in *nodes*, recursing into nested child lists.

        Non-dict items are skipped. A missing or ``None`` child list is
        treated as empty.
        """
        for n in nodes or ():
            if not isinstance(n, dict):
                continue
            yield n
            for key in child_keys:
                yield from iter_nodes(n.get(key))
            for branch in n.get("branches", []):
                yield from iter_nodes(branch.get("children"))

    def count_nodes(nodes) -> int:
        """Count all dict nodes in a compose tree."""
        return sum(1 for _ in iter_nodes(nodes))

    def count_expanded(nodes) -> int:
        """Count nodes that have been expanded (have an ``expanded_from`` field)."""
        return sum(1 for n in iter_nodes(nodes) if n.get("expanded_from"))

    def write_json(path, data):
        """Write *data* to *path* as pretty-printed UTF-8 JSON."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

    # Process composable registrations
    for reg in discovery.get("compose_registrations", []):
        screen_name = reg.get("screen_composable")
        source_file = reg.get("source_file")
        if not screen_name or not source_file:
            continue

        result = process_screen(project_root, screen_name, source_file, resolver)
        nc = count_nodes(result.get("compose_tree", []))
        total_nodes += nc

        if "error" in result:
            print(f"  ERROR {screen_name}: {result['error']}")
            errors += 1
        else:
            print(f"  {screen_name}: {nc} nodes")
            processed += 1

        result["route_constant"] = reg.get("route_constant")
        result["route_pattern"] = reg.get("full_route_pattern")
        result["arguments"] = reg.get("arguments", [])

        write_json(os.path.join(trees_dir, f"{screen_name}.json"), result)

    # Process shell components (errors are reported the same way as for
    # registrations instead of being counted as processed).
    for comp in discovery.get("shell_components", []):
        name = comp.get("name")
        source_file = comp.get("source_file")
        if not name or not source_file:
            continue

        result = process_screen(project_root, name, source_file, resolver)
        nc = count_nodes(result.get("compose_tree", []))
        total_nodes += nc
        result["shell_role"] = comp.get("role")

        if "error" in result:
            print(f"  ERROR Shell_{name}: {result['error']}")
            errors += 1
        else:
            print(f"  Shell_{name}: {nc} nodes")
            processed += 1

        write_json(os.path.join(trees_dir, f"Shell_{name}.json"), result)

    # Process Activities
    for activity in discovery.get("activities", []):
        if activity.get("ui_type") != "compose":
            continue
        source_file = activity.get("source_file")
        if not source_file:
            continue
        # Fall back to the file stem so output names never contain "None".
        simple_name = (activity.get("simple_name")
                       or os.path.splitext(os.path.basename(source_file))[0])

        full_path = os.path.join(project_root, source_file)
        if not os.path.isfile(full_path):
            print(f"  ERROR Activity_{simple_name}: Source file not found: {source_file}")
            errors += 1
            continue

        with open(full_path, "rb") as f:
            source = f.read()

        builder = ComposeTreeBuilder(source, resolver)
        for fn_name, fn_node in builder.find_composable_functions():
            tree = builder.parse_function_body(fn_node)
            if not tree:
                continue
            nc = count_nodes(tree)
            total_nodes += nc
            out = {
                "composable_name": fn_name,
                "source_file": source_file,
                "activity": activity.get("class"),
                "compose_tree": tree,
            }
            write_json(os.path.join(trees_dir, f"Activity_{simple_name}_{fn_name}.json"), out)
            processed += 1
            print(f"  Activity_{simple_name}_{fn_name}: {nc} nodes")

    # Recursive expansion
    print("\n  Expanding custom composables...")
    expander = RecursiveExpander(project_root, resolver)
    expanded_count = 0

    # Sorted for a deterministic processing order. NOTE: this scans the whole
    # directory, so .json files left by earlier runs are expanded too.
    for fname in sorted(os.listdir(trees_dir)):
        if not fname.endswith(".json"):
            continue
        fpath = os.path.join(trees_dir, fname)
        with open(fpath, "r", encoding="utf-8") as f:
            data = json.load(f)
        tree = data.get("compose_tree", [])
        before_expanded = count_expanded(tree)
        # expand_tree's return value is unused, so it is relied on to mutate
        # `tree` in place.
        expander.expand_tree(tree, max_depth=5)
        newly_expanded = count_expanded(tree) - before_expanded
        if newly_expanded > 0:
            expanded_count += newly_expanded
            data["compose_tree"] = tree
            write_json(fpath, data)

    print(f"  Expanded {expanded_count} custom composables")

    print("\n[Phase 2 v2 Complete]")
    print(f"  Processed: {processed}")
    print(f"  Errors: {errors}")
    print(f"  Total nodes: {total_nodes}")
    print(f"  Custom expanded: {expanded_count}")
    print(f"  Output: {trees_dir}/")


def _count_custom(nodes) -> int:
    c = 0
    for n in nodes:
        if isinstance(n, dict):
            if n.get("custom"):
                c += 1
            for key in ("children", "expanded_children", "then_children", "else_children"):
                c += _count_custom(n.get(key, []))
            for branch in n.get("branches", []):
                c += _count_custom(branch.get("children", []))
    return c


if __name__ == "__main__":
    # Run the Phase 2 pipeline only when executed as a script, not on import.
    main()