#!/usr/bin/env python3
"""
Phase 2: Parse Compose UI Trees

Static analysis of @Composable function bodies into hierarchical JSON component trees.
Tracks brace-depth nesting, detects composable calls, extracts parameters, handles
conditional blocks, and maps Salt UI components.

Usage:
    python3 parse_compose_tree.py <project_root> <discovery_json> <output_dir>

Output:
    <output_dir>/compose_trees/<ScreenName>.json (one per screen composable)
"""

import json
import os
import re
import sys
import xml.etree.ElementTree as ET
from pathlib import Path

# Add parent dir so we can import salt_ui_mapping
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from salt_ui_mapping import get_component_info, is_non_ui_call, COMPOSE_NON_UI


class StringResourceResolver:
    """Resolves R.string.xxx references to actual string values from res/values*/ XML files."""

    # Android resource escapes: `\n` / `\t` become control characters; any other
    # escaped character (`\'`, `\"`, `\@`, `\\`, ...) stands for itself.
    _ESCAPE_RE = re.compile(r"\\(.)", re.DOTALL)
    _ESCAPE_MAP = {"n": "\n", "t": "\t"}
    # stringResource(R.string.xxx) or stringResource(id = R.string.xxx)
    _STRING_REF_RE = re.compile(r'stringResource\s*\(\s*(?:id\s*=\s*)?R\.string\.(\w+)')

    def __init__(self):
        # Resource name -> decoded display text.
        self.strings = {}

    def load(self, project_root: str):
        """Load all string resources from <project_root>/app/src/main/res/values*/ directories.

        Values from the default ``values/`` directory always win; qualified
        directories (``values-zh``, ...) only fill in names not seen yet.
        """
        res_base = os.path.join(project_root, "app", "src", "main", "res")
        if not os.path.isdir(res_base):
            return

        for values_dir_name in sorted(os.listdir(res_base)):
            if not values_dir_name.startswith("values"):
                continue
            values_dir = os.path.join(res_base, values_dir_name)
            if not os.path.isdir(values_dir):
                continue
            for fname in os.listdir(values_dir):
                if fname.endswith(".xml"):
                    self._parse_xml(os.path.join(values_dir, fname),
                                    is_default=(values_dir_name == "values"))

    @classmethod
    def _decode_text(cls, raw: str) -> str:
        """Convert raw <string> element text into the displayed value.

        Strips surrounding whitespace, removes one pair of enclosing
        unescaped double quotes (Android's whitespace-preserving form), and
        decodes backslash escapes such as ``\\'`` and ``\\n``.
        """
        text = raw.strip()
        if len(text) >= 2 and text.startswith('"') and text.endswith('"') and not text.endswith('\\"'):
            text = text[1:-1]
        return cls._ESCAPE_RE.sub(lambda m: cls._ESCAPE_MAP.get(m.group(1), m.group(1)), text)

    def _parse_xml(self, path: str, is_default: bool = False):
        """Record every <string name="..."> in ``path``; silently ignore malformed XML."""
        try:
            tree = ET.parse(path)
        except ET.ParseError:
            return
        for elem in tree.getroot():
            if elem.tag != "string":
                continue
            name = elem.get("name", "")
            if name and (is_default or name not in self.strings):
                # itertext() keeps text after inline markup such as <b>..</b>,
                # which elem.text alone would drop.
                self.strings[name] = self._decode_text("".join(elem.itertext()))

    def resolve(self, param_value: str) -> str:
        """Resolve stringResource(R.string.xxx) / stringResource(id = R.string.xxx) to its text.

        Returns ``param_value`` unchanged when there is no reference or the
        resource name is unknown.
        """
        m = self._STRING_REF_RE.search(param_value)
        if m and m.group(1) in self.strings:
            return self.strings[m.group(1)]
        return param_value

    def resolve_tree(self, nodes: list):
        """Recursively resolve string resources in tree node params (in place).

        Resolved values replace the param and are also recorded under the
        node's ``resolved_strings`` mapping.
        """
        if not nodes:
            return
        for node in nodes:
            if not isinstance(node, dict):
                continue
            params = node.get("params")
            if params:
                for key, val in params.items():
                    if isinstance(val, str) and "R.string." in val:
                        resolved = self.resolve(val)
                        if resolved != val:
                            # Overwriting an existing key is safe while iterating.
                            params[key] = resolved
                            node.setdefault("resolved_strings", {})[key] = resolved
            for child_key in ("children", "then_children", "else_children", "expanded_children"):
                self.resolve_tree(node.get(child_key, []))
            for branch in node.get("branches", []):
                self.resolve_tree(branch.get("children", []))


class CustomComposableExpander:
    """One-level expansion of custom composable definitions."""

    def __init__(self, project_root: str):
        self.project_root = project_root
        self._source_cache: dict[str, str] = {}
        self._kt_index: dict[str, str] | None = None

    def _build_kt_index(self):
        """Build index of all .kt files for name lookup."""
        if self._kt_index is not None:
            return
        self._kt_index = {}
        for root, _, files in os.walk(self.project_root):
            for f in files:
                if f.endswith(".kt"):
                    self._kt_index[f[:-3]] = os.path.join(root, f)

    def _read_source(self, path: str) -> str:
        if path not in self._source_cache:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                self._source_cache[path] = f.read()
        return self._source_cache[path]

    def _find_composable_source(self, name: str, context_source: str = "") -> str | None:
        """Find source file containing @Composable fun <name>(...)."""
        # First check in the context source (same file)
        if context_source and re.search(rf'\bfun\s+{re.escape(name)}\s*\(', context_source):
            return context_source

        # Search in kt index
        self._build_kt_index()

        # Direct file name match
        if name in self._kt_index:
            return self._read_source(self._kt_index[name])

        # Search in all files (expensive, limit to ui/ directories)
        for fname, fpath in self._kt_index.items():
            if "/ui/" not in fpath and "\\ui\\" not in fpath:
                continue
            src = self._read_source(fpath)
            if re.search(rf'@Composable\s[^{{]*\bfun\s+{re.escape(name)}\s*\(', src, re.DOTALL):
                return src

        return None

    def expand_custom_nodes(self, nodes: list, context_source: str = "", depth: int = 0):
        """Expand custom composable nodes one level deep."""
        if not nodes or depth > 1:
            return
        for node in nodes:
            if not isinstance(node, dict):
                continue

            if node.get("custom") and node.get("note") == "needs_resolution":
                comp_name = node.get("type", "")
                src = self._find_composable_source(comp_name, context_source)
                if src:
                    # Parse the composable function to get its direct children
                    parser = ComposeTreeParser(src, comp_name)
                    children = parser.parse_tree()
                    if children:
                        node["expanded_children"] = children
                        node["note"] = "auto_expanded_one_level"

            # Recurse into existing children (but don't expand their custom nodes further)
            self.expand_custom_nodes(node.get("children", []), context_source, depth + 1)
            self.expand_custom_nodes(node.get("then_children", []), context_source, depth + 1)
            self.expand_custom_nodes(node.get("else_children", []), context_source, depth + 1)
            for branch in node.get("branches", []):
                self.expand_custom_nodes(branch.get("children", []), context_source, depth + 1)


class ComposeTreeParser:
    """Parses a @Composable function body into a hierarchical component tree."""

    def __init__(self, source: str, function_name: str):
        self.source = source
        self.function_name = function_name
        self.lines = source.split("\n")

    def find_function_body(self) -> tuple[int, int] | None:
        """Find the start and end line indices of the @Composable function body."""
        # Find function declaration
        func_pattern = re.compile(rf'\bfun\s+{re.escape(self.function_name)}\s*\(')

        start_line = None
        for i, line in enumerate(self.lines):
            if func_pattern.search(line):
                start_line = i
                break

        if start_line is None:
            return None

        # Find the opening brace of the function body
        brace_depth = 0
        body_start = None
        for i in range(start_line, len(self.lines)):
            for ch in self.lines[i]:
                if ch == "{":
                    brace_depth += 1
                    if body_start is None:
                        body_start = i
                elif ch == "}":
                    brace_depth -= 1
                    if brace_depth == 0 and body_start is not None:
                        return (body_start, i)

        return None

    def extract_function_params(self) -> list[dict]:
        """Extract the function's parameters from its signature."""
        func_pattern = re.compile(rf'\bfun\s+{re.escape(self.function_name)}\s*\(')

        start_line = None
        for i, line in enumerate(self.lines):
            if func_pattern.search(line):
                start_line = i
                break

        if start_line is None:
            return []

        # Collect the full signature (may span multiple lines)
        sig_lines = []
        paren_depth = 0
        for i in range(start_line, min(start_line + 50, len(self.lines))):
            line = self.lines[i]
            sig_lines.append(line)
            paren_depth += line.count("(") - line.count(")")
            if paren_depth <= 0 and ")" in line:
                break

        signature = " ".join(sig_lines)
        # Extract parameter section
        m = re.search(r'\(\s*(.*)\s*\)', signature, re.DOTALL)
        if not m:
            return []

        param_text = m.group(1)
        params = []
        for param in self._split_params(param_text):
            param = param.strip()
            if not param:
                continue
            # Parse: name: Type = default
            pm = re.match(r'(\w+)\s*:\s*([^=]+?)(?:\s*=\s*(.+))?$', param.strip())
            if pm:
                params.append({
                    "name": pm.group(1).strip(),
                    "type": pm.group(2).strip(),
                    "default": pm.group(3).strip() if pm.group(3) else None,
                })

        return params

    def _split_params(self, text: str) -> list[str]:
        """Split parameter text by commas, respecting nested parens/angles/braces."""
        parts = []
        depth = 0
        current = []
        for ch in text:
            if ch in "(<{":
                depth += 1
                current.append(ch)
            elif ch in ")>}":
                depth -= 1
                current.append(ch)
            elif ch == "," and depth == 0:
                parts.append("".join(current))
                current = []
            else:
                current.append(ch)
        if current:
            parts.append("".join(current))
        return parts

    def parse_tree(self) -> dict | None:
        """Parse the function body into a hierarchical component tree."""
        bounds = self.find_function_body()
        if not bounds:
            return None

        body_start, body_end = bounds
        body_lines = self.lines[body_start:body_end + 1]
        body_text = "\n".join(body_lines)

        # Remove the outer function braces
        # Find first { and last }
        first_brace = body_text.index("{")
        last_brace = body_text.rindex("}")
        inner_text = body_text[first_brace + 1:last_brace]

        children = self._parse_block(inner_text, body_start + 1)
        return children

    def _parse_block(self, text: str, base_line: int = 0, _depth: int = 0) -> list[dict]:
        """Parse a block of Compose code into a list of component nodes."""
        if _depth > 20 or not text.strip():
            return []

        nodes = []
        lines = text.split("\n")
        i = 0

        while i < len(lines):
            line = lines[i].strip()

            # Skip empty lines and comments
            if not line or line.startswith("//") or line.startswith("*") or line.startswith("/*"):
                i += 1
                continue

            # Skip variable declarations that aren't UI
            if re.match(r'(val|var|fun )\s', line) and not re.match(r'(val|var)\s+\w+\s*=\s*[A-Z]', line):
                # Skip multi-line val/var blocks
                brace_d = line.count("{") - line.count("}")
                paren_d = line.count("(") - line.count(")")
                while (brace_d > 0 or paren_d > 0) and i + 1 < len(lines):
                    i += 1
                    line = lines[i].strip()
                    brace_d += line.count("{") - line.count("}")
                    paren_d += line.count("(") - line.count(")")
                i += 1
                continue

            # Check for conditional blocks: if (...) {
            if_match = re.match(r'if\s*\((.+)\)\s*\{?\s*$', line)
            if not if_match:
                # Multi-line if condition
                if re.match(r'if\s*\(', line):
                    collected = line
                    paren_d = line.count("(") - line.count(")")
                    while paren_d > 0 and i + 1 < len(lines):
                        i += 1
                        collected += " " + lines[i].strip()
                        paren_d += lines[i].count("(") - lines[i].count(")")
                    if_match = re.match(r'if\s*\((.+)\)\s*\{?\s*$', collected)

            if if_match:
                condition = if_match.group(1).strip()
                # Find the if-block
                if_block, end_i = self._extract_brace_block(lines, i)
                i = end_i + 1

                node = {
                    "type": "ConditionalBlock",
                    "condition": condition,
                    "then_children": self._parse_block(if_block, base_line + i, _depth + 1),
                }

                # Check for else
                if i < len(lines) and lines[i].strip().startswith("else"):
                    else_line = lines[i].strip()
                    if "else if" in else_line or "else if" in else_line:
                        # else if — treat as another conditional in else_children
                        pass
                    else_block, end_i = self._extract_brace_block(lines, i)
                    i = end_i + 1
                    node["else_children"] = self._parse_block(else_block, base_line + i, _depth + 1)

                # Only add if it has UI children
                if node.get("then_children") or node.get("else_children"):
                    nodes.append(node)
                continue

            # Check for when block: when (expr) {
            when_match = re.match(r'when\s*\((.+)\)\s*\{', line)
            if when_match:
                expr = when_match.group(1).strip()
                when_block, end_i = self._extract_brace_block(lines, i)
                i = end_i + 1
                branches = self._parse_when_branches(when_block, _depth)
                if branches:
                    nodes.append({
                        "type": "WhenBlock",
                        "expression": expr,
                        "branches": branches,
                    })
                continue

            # Check for composable call: CapitalizedName( or CapitalizedName {
            comp_match = re.match(r'([A-Z][a-zA-Z0-9]*)\s*(\(|\{)', line)
            if comp_match:
                comp_name = comp_match.group(1)

                # Skip non-UI calls
                if is_non_ui_call(comp_name):
                    brace_d = line.count("{") - line.count("}")
                    paren_d = line.count("(") - line.count(")")
                    while (brace_d > 0 or paren_d > 0) and i + 1 < len(lines):
                        i += 1
                        brace_d += lines[i].count("{") - lines[i].count("}")
                        paren_d += lines[i].count("(") - lines[i].count(")")
                    i += 1
                    continue

                # Extract the full call including params and trailing lambda
                call_text, end_i = self._extract_composable_call(lines, i)
                i = end_i + 1

                node = self._parse_composable_call(comp_name, call_text, _depth)
                if node:
                    nodes.append(node)
                continue

            # Check for function calls that might be composable: someFunction() where it returns composable
            # or method chains like items(...) { ... }
            items_match = re.match(r'(items|itemsIndexed|item|stickyHeader)\s*\(', line)
            if items_match:
                call_name = items_match.group(1)
                call_text, end_i = self._extract_composable_call(lines, i)
                i = end_i + 1
                # Parse children in the trailing lambda
                trailing_lambda = self._extract_trailing_lambda(call_text)
                children = []
                if trailing_lambda:
                    children = self._parse_block(trailing_lambda, base_line + i, _depth + 1)

                if children:
                    nodes.append({
                        "type": f"LazyListScope.{call_name}",
                        "params": self._extract_params_text(call_text),
                        "children": children,
                    })
                continue

            # Skip other lines
            i += 1

        return nodes

    def _extract_brace_block(self, lines: list[str], start_i: int) -> tuple[str, int]:
        """Extract a {...} block starting from line start_i. Returns (block_content, end_line_index)."""
        brace_depth = 0
        collecting = False
        block_lines = []
        i = start_i
        found_opening = False

        while i < len(lines):
            line = lines[i]
            line_content_parts = []
            for j, ch in enumerate(line):
                if ch == "{":
                    brace_depth += 1
                    if not found_opening:
                        found_opening = True
                        collecting = True
                        continue  # Don't include the opening brace char
                elif ch == "}":
                    brace_depth -= 1
                    if brace_depth == 0 and collecting:
                        # Add content before closing brace on this line
                        remaining = "".join(line_content_parts)
                        if remaining.strip():
                            block_lines.append(remaining)
                        return ("\n".join(block_lines), i)

                if collecting:
                    line_content_parts.append(ch)

            if collecting and found_opening:
                block_lines.append("".join(line_content_parts))
            i += 1

        # If no brace found, skip the line and return empty block
        if not found_opening:
            return ("", start_i)

        return ("\n".join(block_lines), max(i - 1, start_i))

    def _extract_composable_call(self, lines: list[str], start_i: int) -> tuple[str, int]:
        """Extract a full composable call including parens and trailing lambda."""
        call_lines = []
        brace_depth = 0
        paren_depth = 0
        i = start_i

        while i < len(lines):
            line = lines[i]
            call_lines.append(line)

            for j, ch in enumerate(line):
                if ch == "(":
                    paren_depth += 1
                elif ch == ")":
                    paren_depth -= 1
                elif ch == "{":
                    brace_depth += 1
                elif ch == "}":
                    brace_depth -= 1

            # Done when all parens and braces are closed
            if brace_depth <= 0 and paren_depth <= 0 and i > start_i:
                return ("\n".join(call_lines), i)
            if brace_depth == 0 and paren_depth == 0:
                return ("\n".join(call_lines), i)

            i += 1

        return ("\n".join(call_lines), i - 1)

    def _parse_composable_call(self, name: str, call_text: str, _depth: int = 0) -> dict | None:
        """Parse a composable call into a node with params and children."""
        comp_info = get_component_info(name)

        node = {
            "type": name,
        }

        if comp_info:
            node["salt_semantic"] = comp_info.get("type")
            node["library"] = comp_info.get("library")
            node["arkts_hint"] = comp_info.get("arkts_hint")
            if comp_info.get("clickable"):
                node["clickable"] = True
        else:
            node["custom"] = True
            node["note"] = "needs_resolution"

        # Extract parameters
        params = self._extract_params_text(call_text)
        if params:
            node["params"] = params

        # Extract actions (onClick, onValueChange, etc.)
        actions = self._extract_actions(call_text)
        if actions:
            node["actions"] = actions

        # Parse trailing lambda for children
        trailing_lambda = self._extract_trailing_lambda(call_text)
        if trailing_lambda:
            children = self._parse_block(trailing_lambda, 0, _depth + 1)
            if children:
                node["children"] = children

        return node

    def _extract_params_text(self, call_text: str) -> dict:
        """Extract named parameters from a composable call as raw strings."""
        params = {}

        # Find the parenthesized params section
        paren_start = call_text.find("(")
        if paren_start == -1:
            return params

        # Find matching close paren
        depth = 0
        paren_end = -1
        for i in range(paren_start, len(call_text)):
            if call_text[i] == "(":
                depth += 1
            elif call_text[i] == ")":
                depth -= 1
                if depth == 0:
                    paren_end = i
                    break

        if paren_end == -1:
            return params

        param_text = call_text[paren_start + 1:paren_end]

        # Split by commas at depth 0
        parts = self._split_at_depth_zero(param_text, ",")

        for part in parts:
            part = part.strip()
            if not part:
                continue
            # Named parameter: name = value
            eq_match = re.match(r'(\w+)\s*=\s*(.+)', part, re.DOTALL)
            if eq_match:
                pname = eq_match.group(1).strip()
                pvalue = eq_match.group(2).strip()
                # Skip lambda params (they'll be parsed as children)
                if pvalue.startswith("{"):
                    continue
                # Clean up the value
                pvalue = " ".join(pvalue.split())  # Normalize whitespace
                params[pname] = pvalue
            else:
                # Positional parameter
                part_clean = " ".join(part.split())
                if part_clean and not part_clean.startswith("{"):
                    # Use index as key for positional params
                    params[f"_pos_{len([k for k in params if k.startswith('_pos_')])}"] = part_clean

        return params

    def _extract_actions(self, call_text: str) -> list[dict]:
        """Extract event handler actions from a composable call."""
        actions = []

        # Find onClick, onValueChange, onChange, etc.
        event_patterns = [
            (r'onClick\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onClick"),
            (r'onLongClick\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onLongClick"),
            (r'onChange\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onChange"),
            (r'onValueChange\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onValueChange"),
            (r'onConfirm\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onConfirm"),
            (r'onDismissRequest\s*=\s*\{([^}]*(?:\{[^}]*\}[^}]*)*)\}', "onDismissRequest"),
        ]

        for pattern, event_name in event_patterns:
            for m in re.finditer(pattern, call_text, re.DOTALL):
                handler_body = m.group(1).strip()
                action = {"event": event_name}

                # Check for navigation
                nav_match = re.search(r'navController\.\w*[Nn]avigate\s*\(\s*(?:ScreenRoute\.(\w+)|"([^"]*)")', handler_body)
                if nav_match:
                    target = nav_match.group(1) or nav_match.group(2)
                    action["navigates_to"] = target

                # Check for startActivity
                activity_match = re.search(r'startActivity\s*\(.*?(\w+)::class', handler_body)
                if activity_match:
                    action["starts_activity"] = activity_match.group(1)

                # Store raw handler body (truncated)
                if len(handler_body) < 200:
                    action["handler_body"] = handler_body

                actions.append(action)

        return actions

    def _extract_trailing_lambda(self, call_text: str) -> str | None:
        """Extract the trailing lambda content from a composable call."""
        # Find the last { ... } block that's at the top level
        # First, skip past the parameter parens
        paren_depth = 0
        brace_depth = 0
        last_brace_start = -1
        in_parens = False

        for i, ch in enumerate(call_text):
            if ch == "(":
                paren_depth += 1
                in_parens = True
            elif ch == ")":
                paren_depth -= 1
            elif ch == "{" and paren_depth == 0:
                brace_depth += 1
                if brace_depth == 1:
                    last_brace_start = i
            elif ch == "}" and paren_depth == 0:
                brace_depth -= 1

        if last_brace_start == -1:
            return None

        # Extract content between the last top-level { and }
        depth = 0
        for i in range(last_brace_start, len(call_text)):
            if call_text[i] == "{":
                depth += 1
            elif call_text[i] == "}":
                depth -= 1
                if depth == 0:
                    content = call_text[last_brace_start + 1:i]
                    return content.strip() if content.strip() else None

        return None

    def _split_at_depth_zero(self, text: str, delimiter: str) -> list[str]:
        """Split text by delimiter only at depth 0 (outside nested parens/braces/angles)."""
        parts = []
        depth = 0
        current = []
        in_string = False
        escape_next = False

        for ch in text:
            if escape_next:
                current.append(ch)
                escape_next = False
                continue
            if ch == "\\":
                escape_next = True
                current.append(ch)
                continue
            if ch == '"' and not in_string:
                in_string = True
                current.append(ch)
                continue
            if ch == '"' and in_string:
                in_string = False
                current.append(ch)
                continue
            if in_string:
                current.append(ch)
                continue
            if ch in "({<[":
                depth += 1
            elif ch in ")}>]":
                depth -= 1
            if ch == delimiter[0] and depth == 0 and len(delimiter) == 1:
                parts.append("".join(current))
                current = []
                continue
            current.append(ch)

        if current:
            parts.append("".join(current))
        return parts

    def _parse_when_branches(self, when_block: str, _depth: int = 0) -> list[dict]:
        """Parse when block branches."""
        branches = []
        lines = when_block.split("\n")
        i = 0

        while i < len(lines):
            line = lines[i].strip()
            if not line or line.startswith("//"):
                i += 1
                continue

            # Match: condition -> { or condition -> expression
            branch_match = re.match(r'(.+?)\s*->\s*\{?\s*$', line)
            if branch_match:
                condition = branch_match.group(1).strip()
                if "{" in line:
                    block, end_i = self._extract_brace_block(lines, i)
                    i = end_i + 1
                    children = self._parse_block(block, 0, _depth + 1)
                else:
                    children = []
                    i += 1

                branches.append({
                    "condition": condition,
                    "children": children,
                })
            else:
                i += 1

        return branches


def process_screen(project_root: str, screen_name: str, source_file: str, output_dir: str,
                   resolver: StringResourceResolver | None = None,
                   expander: CustomComposableExpander | None = None) -> dict:
    """Build the tree JSON payload for a single screen composable.

    Args:
        project_root: absolute project directory.
        screen_name: name of the @Composable function to parse.
        source_file: path of the Kotlin file, relative to ``project_root``.
        output_dir: phase output directory (not used here; kept for callers).
        resolver: optional string-resource resolver applied to the tree.
        expander: optional custom-composable expander applied to the tree.

    Returns:
        ``{"error": ...}`` when the source file is missing, otherwise a dict
        with the composable name, source file, signature parameters and tree.
    """
    full_path = os.path.join(project_root, source_file)
    if not os.path.exists(full_path):
        return {"error": f"Source file not found: {source_file}"}

    with open(full_path, "r", encoding="utf-8", errors="ignore") as handle:
        kotlin_source = handle.read()

    tree_parser = ComposeTreeParser(kotlin_source, screen_name)
    signature_params = tree_parser.extract_function_params()
    component_tree = tree_parser.parse_tree()

    # Post-processing only makes sense for a non-empty tree.
    if component_tree:
        if resolver:
            resolver.resolve_tree(component_tree)
        if expander:
            expander.expand_custom_nodes(component_tree, context_source=kotlin_source)

    return {
        "composable_name": screen_name,
        "source_file": source_file,
        "function_parameters": signature_params,
        "compose_tree": component_tree or [],
    }


def _write_tree_json(path: str, data: dict) -> None:
    """Write one tree result to ``path`` as indented UTF-8 JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def main():
    """Run Phase 2: parse every discovered Compose entry point into a tree file.

    Reads the discovery JSON and writes into ``<output_dir>/compose_trees/``:

    * ``<Screen>.json`` for each NavHost registration;
    * ``Activity_<Activity>_<Fn>.json`` for Compose-based Activities (the
      named entry composable, or every @Composable in the file for inline
      ``setContent``);
    * ``Shell_<Name>.json`` for shell components.
    """
    if len(sys.argv) < 4:
        print("Usage: python3 parse_compose_tree.py <project_root> <discovery_json> <output_dir>")
        sys.exit(1)

    project_root = os.path.abspath(sys.argv[1])
    discovery_path = os.path.abspath(sys.argv[2])
    output_dir = os.path.abspath(sys.argv[3])

    trees_dir = os.path.join(output_dir, "compose_trees")
    os.makedirs(trees_dir, exist_ok=True)

    with open(discovery_path, "r", encoding="utf-8") as f:
        discovery = json.load(f)

    print("[Phase 2] Parsing Compose UI trees")

    print("  Loading string resources...")
    resolver = StringResourceResolver()
    resolver.load(project_root)
    print(f"  Loaded {len(resolver.strings)} string resources")

    expander = CustomComposableExpander(project_root)

    processed = 0
    errors = 0

    # --- Composable registrations (NavHost screens) ---
    for reg in discovery.get("compose_registrations", []):
        screen_name = reg.get("screen_composable")
        source_file = reg.get("source_file")

        if not screen_name or not source_file:
            print(f"  SKIP: {reg.get('route_constant')} (no screen composable or source)")
            continue

        print(f"  Parsing {screen_name} from {source_file}...")
        result = process_screen(project_root, screen_name, source_file, output_dir, resolver, expander)

        if "error" in result:
            print(f"    ERROR: {result['error']}")
            errors += 1
        else:
            tree_count = len(result.get("compose_tree", []))
            print(f"    OK: {tree_count} top-level nodes")
            processed += 1

        # Add route info
        result["route_constant"] = reg.get("route_constant")
        result["route_pattern"] = reg.get("full_route_pattern")
        result["arguments"] = reg.get("arguments", [])

        _write_tree_json(os.path.join(trees_dir, f"{screen_name}.json"), result)

    # --- Activity compose entries ---
    for activity in discovery.get("activities", []):
        if activity.get("ui_type") != "compose":
            continue
        entry_fn = activity.get("compose_entry_function")
        source_file = activity.get("source_file")
        simple_name = activity.get("simple_name", "Unknown")
        if not source_file:
            continue

        if entry_fn and "(inline)" not in entry_fn:
            # Named entry composable: parse it like a screen. Previously these
            # activities fell through the loop without being parsed at all.
            print(f"  Parsing Activity {simple_name} entry {entry_fn}...")
            result = process_screen(project_root, entry_fn, source_file, output_dir, resolver, expander)
            result["activity"] = activity.get("class")
            if "error" in result:
                print(f"    ERROR: {result['error']}")
                errors += 1
            else:
                processed += 1
            _write_tree_json(os.path.join(trees_dir, f"Activity_{simple_name}_{entry_fn}.json"), result)
            continue

        # Inline setContent (or no entry recorded): analyze the Activity source itself.
        print(f"  Parsing Activity {simple_name} (inline setContent)...")
        full_path = os.path.join(project_root, source_file)
        if not os.path.exists(full_path):
            print(f"    ERROR: Source file not found: {source_file}")
            errors += 1
            continue
        with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
            source = f.read()

        # Every @Composable in the file (e.g. MainActivityUI) is a candidate;
        # keep those that yield a non-empty tree.
        composable_fns = re.findall(r'@Composable\s+(?:private\s+)?fun\s+(\w+)\s*\(', source)
        for fn_name in composable_fns:
            tree = ComposeTreeParser(source, fn_name).parse_tree()
            if not tree:
                continue
            resolver.resolve_tree(tree)
            expander.expand_custom_nodes(tree, context_source=source)
            result = {
                "composable_name": fn_name,
                "source_file": source_file,
                "activity": activity.get("class"),
                "compose_tree": tree,
            }
            _write_tree_json(os.path.join(trees_dir, f"Activity_{simple_name}_{fn_name}.json"), result)
            processed += 1

    # --- Shell components ---
    for comp in discovery.get("shell_components", []):
        name = comp.get("name")
        source_file = comp.get("source_file")
        if not name or not source_file:
            continue

        print(f"  Parsing shell component {name}...")
        # Same resolution/expansion as screens, so shell trees are comparable.
        result = process_screen(project_root, name, source_file, output_dir, resolver, expander)
        result["shell_role"] = comp.get("role")
        if "error" in result:
            print(f"    ERROR: {result['error']}")
            errors += 1
        else:
            processed += 1

        _write_tree_json(os.path.join(trees_dir, f"Shell_{name}.json"), result)

    print("\n[Phase 2 Complete]")
    print(f"  Processed: {processed}")
    print(f"  Errors: {errors}")
    print(f"  Output: {trees_dir}/")


if __name__ == "__main__":
    # Script entry point: run Phase 2 using sys.argv (see module docstring for usage).
    main()