{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#  Copyright (c) Microsoft Corporation.\n",
    "#  Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, site\n",
    "from pathlib import Path\n",
    "\n",
    "################################# NOTE #################################\n",
    "#  Please be aware that if colab installs the latest numpy and pyqlib  #\n",
    "#  in this cell, users should RESTART the runtime in order to run the  #\n",
    "#  following cells successfully.                                       #\n",
    "########################################################################\n",
    "\n",
    "try:\n",
    "    import qlib\n",
    "except ImportError:\n",
    "    # qlib is missing: install it with %pip (instead of `! pip`) so the packages are\n",
    "    # installed into the environment of the running kernel, not whichever pip is on PATH\n",
    "    %pip install --upgrade numpy\n",
    "    %pip install pyqlib\n",
    "    if \"google.colab\" in sys.modules:\n",
    "        # The Google colab environment is a little outdated. We have to downgrade the pyyaml to make it compatible with other packages\n",
    "        %pip install pyyaml==5.4.1\n",
    "    # re-run the site initialisation so freshly installed site-packages paths are on sys.path\n",
    "    site.main()\n",
    "\n",
    "# Prefer the get_data.py that lives in ../scripts relative to the working directory\n",
    "scripts_dir = Path.cwd().parent.joinpath(\"scripts\")\n",
    "if not scripts_dir.joinpath(\"get_data.py\").exists():\n",
    "    # not found locally: download the script into a directory under the user's home\n",
    "    scripts_dir = Path(\"~/tmp/qlib_code/scripts\").expanduser().resolve()\n",
    "    scripts_dir.mkdir(parents=True, exist_ok=True)\n",
    "    import requests\n",
    "\n",
    "    with requests.get(\"https://raw.githubusercontent.com/microsoft/qlib/main/scripts/get_data.py\", timeout=10) as resp:\n",
    "        # fail loudly on an HTTP error instead of saving the error page as get_data.py\n",
    "        resp.raise_for_status()\n",
    "        with open(scripts_dir.joinpath(\"get_data.py\"), \"wb\") as fp:\n",
    "            fp.write(resp.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qlib\n",
    "import pandas as pd\n",
    "from qlib.constant import REG_CN\n",
    "from qlib.utils import exists_qlib_data, init_instance_by_config\n",
    "from qlib.workflow import R\n",
    "from qlib.workflow.record_temp import SignalRecord, PortAnaRecord\n",
    "from qlib.utils import flatten_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Default CN dataset location; it doubles as the download target when the data is missing.\n",
    "# Manual alternative: python scripts/get_data.py qlib_data_cn --target_dir ~/.qlib/qlib_data/cn_data\n",
    "provider_uri = \"~/.qlib/qlib_data/cn_data\"\n",
    "\n",
    "data_available = exists_qlib_data(provider_uri)\n",
    "if not data_available:\n",
    "    print(f\"Qlib data is not found in {provider_uri}\")\n",
    "    # make the get_data.py found by the setup cell importable, then fetch the dataset\n",
    "    sys.path.append(str(scripts_dir))\n",
    "    from get_data import GetData\n",
    "\n",
    "    downloader = GetData()\n",
    "    downloader.qlib_data(target_dir=provider_uri, region=REG_CN)\n",
    "\n",
    "qlib.init(provider_uri=provider_uri, region=REG_CN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `benchmark` is handed to the backtest config; `market` is the instrument universe of the data handler\n",
    "benchmark = \"SH000300\"\n",
    "market = \"csi300\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "###################################\n",
    "# train model\n",
    "###################################\n",
    "# keyword arguments for the Alpha158 handler: start/end dates, fit dates and the instrument universe\n",
    "data_handler_config = {\n",
    "    \"start_time\": \"2008-01-01\",\n",
    "    \"end_time\": \"2020-08-01\",\n",
    "    \"fit_start_time\": \"2008-01-01\",\n",
    "    \"fit_end_time\": \"2014-12-31\",\n",
    "    \"instruments\": market,\n",
    "}\n",
    "\n",
    "# LightGBM model description\n",
    "model_config = {\n",
    "    \"class\": \"LGBModel\",\n",
    "    \"module_path\": \"qlib.contrib.model.gbdt\",\n",
    "    \"kwargs\": {\n",
    "        \"loss\": \"mse\",\n",
    "        \"colsample_bytree\": 0.8879,\n",
    "        \"learning_rate\": 0.0421,\n",
    "        \"subsample\": 0.8789,\n",
    "        \"lambda_l1\": 205.6999,\n",
    "        \"lambda_l2\": 580.9768,\n",
    "        \"max_depth\": 8,\n",
    "        \"num_leaves\": 210,\n",
    "        \"num_threads\": 20,\n",
    "    },\n",
    "}\n",
    "\n",
    "# dataset description: the handler above plus the train / valid / test date segments\n",
    "handler_config = {\n",
    "    \"class\": \"Alpha158\",\n",
    "    \"module_path\": \"qlib.contrib.data.handler\",\n",
    "    \"kwargs\": data_handler_config,\n",
    "}\n",
    "segment_config = {\n",
    "    \"train\": (\"2008-01-01\", \"2014-12-31\"),\n",
    "    \"valid\": (\"2015-01-01\", \"2016-12-31\"),\n",
    "    \"test\": (\"2017-01-01\", \"2020-08-01\"),\n",
    "}\n",
    "dataset_config = {\n",
    "    \"class\": \"DatasetH\",\n",
    "    \"module_path\": \"qlib.data.dataset\",\n",
    "    \"kwargs\": {\n",
    "        \"handler\": handler_config,\n",
    "        \"segments\": segment_config,\n",
    "    },\n",
    "}\n",
    "\n",
    "# the full task is what gets logged as experiment parameters below\n",
    "task = {\n",
    "    \"model\": model_config,\n",
    "    \"dataset\": dataset_config,\n",
    "}\n",
    "\n",
    "# instantiate the model and the dataset from their configs\n",
    "model = init_instance_by_config(model_config)\n",
    "dataset = init_instance_by_config(dataset_config)\n",
    "\n",
    "# run the training inside the \"train_model\" experiment\n",
    "with R.start(experiment_name=\"train_model\"):\n",
    "    R.log_params(**flatten_dict(task))\n",
    "    model.fit(dataset)\n",
    "    R.save_objects(trained_model=model)\n",
    "    # keep the recorder id: the backtest cell uses it to load the saved model\n",
    "    rid = R.get_recorder().id"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# prediction, backtest & analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "###################################\n",
    "# prediction, backtest & analysis\n",
    "###################################\n",
    "executor_config = {\n",
    "    \"class\": \"SimulatorExecutor\",\n",
    "    \"module_path\": \"qlib.backtest.executor\",\n",
    "    \"kwargs\": {\n",
    "        \"time_per_step\": \"day\",\n",
    "        \"generate_portfolio_metrics\": True,\n",
    "    },\n",
    "}\n",
    "\n",
    "# NOTE: `model` here is the object bound by the training cell at the time this dict is built\n",
    "strategy_config = {\n",
    "    \"class\": \"TopkDropoutStrategy\",\n",
    "    \"module_path\": \"qlib.contrib.strategy.signal_strategy\",\n",
    "    \"kwargs\": {\n",
    "        \"model\": model,\n",
    "        \"dataset\": dataset,\n",
    "        \"topk\": 50,\n",
    "        \"n_drop\": 5,\n",
    "    },\n",
    "}\n",
    "\n",
    "backtest_config = {\n",
    "    \"start_time\": \"2017-01-01\",\n",
    "    \"end_time\": \"2020-08-01\",\n",
    "    \"account\": 100000000,\n",
    "    \"benchmark\": benchmark,\n",
    "    \"exchange_kwargs\": {\n",
    "        \"freq\": \"day\",\n",
    "        \"limit_threshold\": 0.095,\n",
    "        \"deal_price\": \"close\",\n",
    "        \"open_cost\": 0.0005,\n",
    "        \"close_cost\": 0.0015,\n",
    "        \"min_cost\": 5,\n",
    "    },\n",
    "}\n",
    "\n",
    "port_analysis_config = {\n",
    "    \"executor\": executor_config,\n",
    "    \"strategy\": strategy_config,\n",
    "    \"backtest\": backtest_config,\n",
    "}\n",
    "\n",
    "# run prediction and backtest inside the \"backtest_analysis\" experiment\n",
    "with R.start(experiment_name=\"backtest_analysis\"):\n",
    "    # load the model that the training experiment saved under its recorder id\n",
    "    train_recorder = R.get_recorder(recorder_id=rid, experiment_name=\"train_model\")\n",
    "    model = train_recorder.load_object(\"trained_model\")\n",
    "\n",
    "    # prediction, recorded on the recorder of the current experiment\n",
    "    recorder = R.get_recorder()\n",
    "    # keep the id: the analysis cell uses it to find this recorder again\n",
    "    ba_rid = recorder.id\n",
    "    sr = SignalRecord(model, dataset, recorder)\n",
    "    sr.generate()\n",
    "\n",
    "    # backtest & analysis on the same recorder\n",
    "    par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
    "    par.generate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# analyze graphs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from qlib.contrib.report import analysis_model, analysis_position\n",
    "from qlib.data import D\n",
    "\n",
    "# look up the recorder of the backtest run by the id saved in the backtest cell\n",
    "recorder = R.get_recorder(recorder_id=ba_rid, experiment_name=\"backtest_analysis\")\n",
    "print(recorder)\n",
    "\n",
    "# prediction scores stored on that recorder\n",
    "pred_df = recorder.load_object(\"pred.pkl\")\n",
    "\n",
    "# portfolio analysis artifacts stored on that recorder\n",
    "analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")\n",
    "positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
    "report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## analysis position"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# report graph for the portfolio report loaded from the backtest recorder\n",
    "analysis_position.report_graph(report_normal_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### risk analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# risk analysis graph from the port_analysis and report artifacts loaded from the backtest recorder\n",
    "analysis_position.risk_analysis_graph(analysis_df, report_normal_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## analysis model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# label column set of the \"test\" segment; its single column is renamed to \"label\"\n",
    "label_df = dataset.prepare(\"test\", col_set=\"label\")\n",
    "label_df.columns = [\"label\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### score IC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# join labels and predictions column-wise, then conform the result to the label index\n",
    "labels_and_scores = pd.concat([label_df, pred_df], axis=1, sort=True)\n",
    "pred_label = labels_and_scores.reindex(label_df.index)\n",
    "\n",
    "analysis_position.score_ic_graph(pred_label)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### model performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model performance graph on the joined label / prediction frame built in the score IC cell\n",
    "analysis_model.model_performance_graph(pred_label)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}