#!/usr/bin/env vpython3
# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import os
import pprint
import requests
import sys
import json

from enum import Enum


class TestSelectionPhase(Enum):
  """Phases of Smart Test Selection that this script can run.

  The member names are the accepted values of the --test-selection-phase
  flag, and the members key the per-phase lookup tables defined below.
  """
  # Trigger represents the first phase when ML model is triggered with the
  # required context to generate tests' pass probability.
  TRIGGER = 1
  # Fetch represents the second phase when results generated by ML model are
  # fetched and used for test selection, based on a predetermined pass
  # probability.
  FETCH = 2


# Endpoint that every request made by this script is POSTed to.
API_URL = 'https://decisiongraph-pa.googleapis.com/v1/rundecisiongraph'
# Decision graph name sent as `graph.name` in the request, per phase.
DECISION_GRAPH_NAME = {
    TestSelectionPhase.TRIGGER: 'sts_chrome_trigger_model_root',
    TestSelectionPhase.FETCH: 'sts_chrome_fetch_results_root'
}
# Per-phase stage id templates. The three placeholders are filled with the
# change number, the patchset and the index of the batch being sent.
STAGE_ID = {
    TestSelectionPhase.TRIGGER: 'trigger_model_for_%d_%d_%d',
    TestSelectionPhase.FETCH: 'fetch_results_for_%d_%d_%d'
}
# Stage name sent next to the stage id in the request, per phase.
STAGE_NAME = {
    TestSelectionPhase.TRIGGER: 'sts_chrome_trigger_model',
    TestSelectionPhase.FETCH: 'sts_chrome_fetch_results'
}
# Project and branch placed in the identifier of every check that is sent.
PROJECT = 'chromium/src'
BRANCH = 'main'
# Only the first word of gerrit host, i.e., %s-review.googlesource.com
HOSTNAME = 'chromium'
# Sent as `execution_options.location`.
# NOTE(review): the meaning of the numeric value is defined by the
# decisiongraph API and is not visible in this file -- confirm before changing.
LOCATION_ENUM = 1
# Sent as `execution_options.address`.
STAGE_SERVICE_GSLB = 'blade:test-relevance-stage-service-prod-luci'
# TODO(crbug.com/405145095): Change this to a lower value after decisiongraph
# has moved to spanner queues.
MAX_DURATION_SECONDS = 900
# Client-side timeout of the HTTP request: the stage's maximum duration plus
# 60 seconds of slack.
TIMEOUT_SECONDS = MAX_DURATION_SECONDS + 60
# Sent as `execution_options.max_attempts`.
MAX_ATTEMPTS = 3
# Sent as `execution_options.blocking`.
# NOTE(review): as with LOCATION_ENUM, the meaning of the numeric value is
# defined by the API and is not visible in this file -- confirm before
# changing.
BLOCKING_ENUM = 2

# Maximum number of test targets sent in a single API request.
BATCH_SIZE = 5


def fetch_api_data(url, json_payload=None):
  """Sends an HTTP POST request to the url and returns the JSON response.

  The raw response body and the status code are printed before the status is
  checked, so they are logged for error responses as well.

  Args:
    url: The url to send request to.
    json_payload: The payload to send with request, serialized as JSON.

  Returns:
    The decoded JSON response, or None if the request fails, the server
    answers with a 4xx/5xx status, or the body is not valid JSON.
  """
  try:
    response = requests.post(url, json=json_payload, timeout=TIMEOUT_SECONDS)
    print(response.text)
    print(response.status_code)
    # Raise an HTTPError for bad responses (4xx and 5xx).
    response.raise_for_status()
    return response.json()
  except (requests.exceptions.RequestException, ValueError) as e:
    # ValueError covers a body that is not valid JSON: older `requests`
    # releases raise a plain ValueError subclass from json() instead of a
    # RequestException, which would otherwise escape as a traceback.
    # NOTE(review): requests' error messages can embed the request URL, so a
    # credential passed as a query parameter may end up in this log line --
    # confirm that this is acceptable for the logs this script writes to.
    print(f'An error occurred: {e}')
    return None


def load_config_from_json(file_path):
  """Reads and validates the configuration parameters from a JSON file.

  Every problem with the file (it is missing, is not valid JSON, does not hold
  a JSON object, lacks a required key, or holds a value of the wrong type) is
  reported on stdout and terminates the process with exit status 1.

  Args:
    file_path: Path to the JSON file.

  Returns:
    A dictionary containing the configuration parameters. 'change' and
    'patchset' are converted to int; 'build_id', 'builder' and 'api_key' are
    verified to be strings. Any other keys are passed through untouched.
  """
  try:
    with open(file_path, 'r', encoding='utf-8') as f:
      config_data = json.load(f)
  except FileNotFoundError:
    print(f'Error: Configuration file not found at {file_path}')
    sys.exit(1)
  except json.JSONDecodeError as e:
    print(f'Error: Could not decode JSON from {file_path}. Details: {e}')
    sys.exit(1)
  except Exception as e:
    print(
        f'Unexpected error occurred while reading the configuration file: {e}')
    sys.exit(1)

  # Valid JSON is not necessarily an object; the key lookups below need one.
  if not isinstance(config_data, dict):
    print(f'Error: Expected a JSON object at the top level of {file_path}.')
    sys.exit(1)

  # Validate required arguments needed to make API call.
  required_args = ['build_id', 'change', 'patchset', 'builder', 'api_key']
  missing_args = [arg for arg in required_args if arg not in config_data]

  if missing_args:
    print('Error: Missing required arguments in JSON config file.')
    for arg in missing_args:
      print(f'  - {arg}')
    sys.exit(1)

  # Type checking/conversion for API call parameters.
  for key in ('change', 'patchset'):
    try:
      config_data[key] = int(config_data[key])
    except (TypeError, ValueError, OverflowError):
      # int() raises TypeError for null/list/object values, ValueError for
      # non-numeric strings and OverflowError for an infinite float.
      print("Error: Invalid data type for 'change' or 'patchset' in JSON. "
            "Expected integers.")
      sys.exit(1)

  for key in ('build_id', 'builder', 'api_key'):
    if not isinstance(config_data[key], str):
      print(f"Error: '{key}' must be a string in the JSON configuration.")
      sys.exit(1)

  return config_data


def overwrite_filter_file(filter_file_dir, test_suite, tests_to_skip):
  """Overwrites a filter file with tests to skip.

  The file is named '<test_suite>.filter' and is truncated on every call. When
  there is at least one test to skip, a header comment is written followed by
  one '-<test_name>.*' line per distinct test name, in sorted order. When
  there is none, the file is left empty.

  Args:
    filter_file_dir: The directory containing the filter files.
    test_suite: The name of the test suite (e.g., 'browser_tests').
    tests_to_skip: A list of test names to be written to the file.

  Returns:
    True if the file was written, False if the directory does not exist or
    writing failed.
  """
  # The calling script is responsible for creating the directory.
  if not os.path.isdir(filter_file_dir):
    print(f'Error: Filter file directory not found at {filter_file_dir}')
    return False

  file_path = os.path.join(filter_file_dir, f'{test_suite}.filter')
  try:
    # De-duplicate once so the lines written and the count reported agree.
    unique_tests = sorted(set(tests_to_skip))
    # Overwrite the file with the new list of tests to skip.
    with open(file_path, 'w', encoding='utf-8') as f:
      if unique_tests:
        f.write(
            '# A list of tests to be skipped, generated by test selection.\n')
        f.writelines(f'-{test_name}.*\n' for test_name in unique_tests)
  except Exception as e:
    # Deliberately broad: any failure is reported to the caller through the
    # return value rather than raised.
    print(f'An error occurred while writing to {file_path}: {e}')
    return False

  print(f'Successfully wrote {len(unique_tests)} tests to skip to {file_path}')
  return True


def _build_checks(test_targets, canonical_builder, build_id):
  """Returns one request check entry per test target."""
  return [{
      'identifier': {
          'luci_test': {
              'project': PROJECT,
              'branch': BRANCH,
              'builder': canonical_builder,
              'test_suite': test_target,
          }
      },
      'run': {
          'luci_test': {
              'build_id': build_id
          }
      },
  } for test_target in test_targets]


def _build_payload(phase, stage_id, checks, change, patchset):
  """Returns the request body that runs one stage over one batch of checks."""
  return {
      'graph': {
          'name':
          DECISION_GRAPH_NAME[phase],
          'stages': [{
              'stage': {
                  'id': stage_id,
                  'name': STAGE_NAME[phase],
              },
              'execution_options': {
                  'location': LOCATION_ENUM,
                  'address': STAGE_SERVICE_GSLB,
                  'prepare': phase == TestSelectionPhase.TRIGGER,
                  'max_duration': {
                      'seconds': MAX_DURATION_SECONDS
                  },
                  'max_attempts': MAX_ATTEMPTS,
                  'blocking': BLOCKING_ENUM,
              },
          }],
      },
      'input': [{
          'stage': {
              'id': stage_id,
              'name': STAGE_NAME[phase],
          },
          'input': [{
              'checks': checks,
          }],
          'changes': {
              'changes': [{
                  'hostname': HOSTNAME,
                  'change_number': change,
                  'patchset': patchset,
              }],
          },
      }]
  }


def _write_filter_files(response_data, filter_file_dir):
  """Writes one filter file per returned check that lists children.

  Raises KeyError/IndexError/TypeError/ValueError if the response does not
  have the expected shape.

  Returns:
    True if every filter file was written successfully, False otherwise.
  """
  stage_outputs = response_data['outputs']
  # Validate explicitly instead of with `assert`, which is skipped under -O.
  if len(stage_outputs) != 1:
    raise ValueError(
        f'Expected exactly 1 stage output, got {len(stage_outputs)}')

  all_written = True
  for count, check in enumerate(stage_outputs[0]['checks']):
    print(f'processing check {count}')
    # Checks without children get no filter file written for them.
    if 'children' not in check:
      continue

    test_suite = check['identifier']['luciTest']['testSuite']
    tests_to_skip = [
        child['identifier']['luciTest']['testId']
        for child in check['children']
    ]
    if not overwrite_filter_file(filter_file_dir, test_suite, tests_to_skip):
      all_written = False
  return all_written


def main():
  """Triggers the model or fetches its results for the given test targets.

  The test targets are sent to the API in batches of BATCH_SIZE. In the
  TRIGGER phase each response is only printed. In the FETCH phase each
  response is turned into per-suite filter files under --filter-file-dir.

  The process exits with status 0 when every batch succeeded and with status
  1 when an API call or a filter file write failed. A FETCH response with an
  unexpected shape is printed and the parsing error is re-raised.
  """
  parser = argparse.ArgumentParser(
      description=
      'A script to trigger or fetch results for Smart Test Selection.')
  parser.add_argument('--test-targets',
                      required=True,
                      nargs='+',
                      type=str,
                      help='Name of the test targets e.g., browser_tests.')
  parser.add_argument(
      '--sts-config-file',
      required=True,
      type=str,
      help='Path to the JSON file containing config for smart test selection.')
  parser.add_argument(
      '--test-selection-phase',
      required=True,
      type=str.upper,
      choices=[p.name for p in TestSelectionPhase],
      help='The phase of test selection to run: TRIGGER or FETCH.')
  parser.add_argument(
      '--filter-file-dir',
      type=str,
      help='Directory to write test filter files. Required for FETCH phase.')
  args = parser.parse_args()

  # Set the phase based on the command-line flag.
  phase = TestSelectionPhase[args.test_selection_phase]

  # --filter-file-dir is only needed for the FETCH phase.
  if phase == TestSelectionPhase.FETCH and not args.filter_file_dir:
    parser.error('--filter-file-dir is required when phase is FETCH.')

  config = load_config_from_json(args.sts_config_file)

  build_id = config['build_id']
  change = config['change']
  patchset = config['patchset']
  builder = config['builder']
  api_key = config['api_key']

  # Find the corresponding "main" CQ builder name.
  canonical_builder = builder.removesuffix('-test-selection')

  test_target_batches = [
      args.test_targets[i:i + BATCH_SIZE]
      for i in range(0, len(args.test_targets), BATCH_SIZE)
  ]

  # The request URL does not depend on the batch, so build it once.
  request_with_key = '%s?key=%s' % (API_URL, api_key)

  return_status = 0
  for batch_idx, test_target_batch in enumerate(test_target_batches):
    print('batch num = %d' % batch_idx)
    checks = _build_checks(test_target_batch, canonical_builder, build_id)
    stage_id = STAGE_ID[phase] % (change, patchset, batch_idx)
    payload = _build_payload(phase, stage_id, checks, change, patchset)
    print('payload = ')
    pprint.pprint(payload)

    response_data = fetch_api_data(url=request_with_key, json_payload=payload)

    if not response_data:
      print('Failed to fetch data from the API.')
      return_status = 1
      continue

    if phase == TestSelectionPhase.TRIGGER:
      print('API Response:')
      pprint.pprint(response_data)
      continue

    # FETCH phase
    try:
      if not _write_filter_files(response_data, args.filter_file_dir):
        return_status = 1
    except (KeyError, IndexError, TypeError, ValueError):
      # Log the offending response for debugging, then keep the original
      # traceback by re-raising as-is.
      print('Error when parsing response from decisiongraph api')
      pprint.pprint(response_data)
      raise

  sys.exit(return_status)


# main() ends by calling sys.exit() itself, so the outer sys.exit() only takes
# effect if main() ever returns a status instead of exiting.
if __name__ == '__main__':
  sys.exit(main())