"""
This example shows how to simulate the OnlineManager based on rolling tasks.
"""
from pprint import pprint
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
import pandas as pd
from qlib.contrib.evaluate import backtest_daily
from qlib.contrib.evaluate import risk_analysis
from qlib.contrib.strategy import TopkDropoutStrategy
class OnlineSimulationExample:
    """Simulate an ``OnlineManager`` over a time range and backtest the produced signals.

    The instance wires together a ``RollingGen`` task generator, a trainer
    (``TrainerR`` or ``TrainerRM``) and a ``RollingStrategy`` inside an
    ``OnlineManager``. ``main`` runs the whole simulation; ``worker`` runs the
    trainer's worker loop when the trainer is a ``TrainerRM``.
    """

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=None,
        trainer="TrainerR",
    ):
        """
        Init OnlineSimulationExample.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB).
                Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating; passed to the OnlineManager
                as ``begin_time``. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating; passed to ``simulate`` in ``main``.
                Defaults to "2018-10-31".
            tasks (dict or list[dict], optional): a set of the task config waiting for rolling and
                training. Defaults to the CSI100 XGBoost and LightGBM online task configs.
            trainer (str, optional): which trainer to use, either "TrainerR" or "TrainerRM".
                Defaults to "TrainerR".

        Raises:
            NotImplementedError: if ``trainer`` is not one of the supported trainer names.
        """
        # Reject an unknown trainer up front, before qlib is initialised or any
        # other object is built, and say which value was rejected.
        supported_trainers = ("TrainerR", "TrainerRM")
        if trainer not in supported_trainers:
            raise NotImplementedError(
                f"Trainer {trainer!r} is not supported; expected one of {supported_trainers}."
            )

        if tasks is None:
            tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]

        self.exp_name = exp_name
        self.task_pool = task_pool
        self.start_time = start_time
        self.end_time = end_time

        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)

        # Rolling task generator; no extra dataset-modification function is supplied.
        self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None)

        if trainer == "TrainerRM":
            # TrainerRM additionally needs the task pool name.
            self.trainer = TrainerRM(self.exp_name, self.task_pool)
        else:
            # Only "TrainerR" can reach this branch thanks to the validation above.
            self.trainer = TrainerR(self.exp_name)

        self.rolling_online_manager = OnlineManager(
            RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = tasks

    def reset(self):
        """Clear state left by earlier runs.

        When the trainer is a ``TrainerRM``, ``remove()`` is called on the
        ``TaskManager`` of the task pool. In every case, each recorder listed
        by the experiment ``self.exp_name`` is deleted.
        """
        if isinstance(self.trainer, TrainerRM):
            TaskManager(self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.exp_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    def main(self):
        """Reset, simulate until ``self.end_time``, print results and backtest the signals.

        The signals returned by the online manager are fed, as a "score" column,
        into a ``TopkDropoutStrategy`` (topk=30, n_drop=3). The backtest spans the
        first to the last "datetime" value found in the signal index, and the risk
        analysis of the excess return (with and without cost) is pretty-printed.
        """
        print("========== reset ==========")
        self.reset()
        print("========== simulate ==========")
        self.rolling_online_manager.simulate(end_time=self.end_time)
        print("========== collect results ==========")
        # get_collector() returns a callable; calling it yields the collected results.
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        signals = self.rolling_online_manager.get_signals()
        print(signals)

        # Backtesting
        strategy_obj = TopkDropoutStrategy(topk=30, n_drop=3, signal=signals.to_frame("score"))
        signal_dates = signals.index.get_level_values("datetime")
        # NOTE(review): no benchmark is passed, so backtest_daily falls back to whatever
        # its default is -- confirm that it suits the CSI100 task configs used here.
        report_normal, _positions = backtest_daily(
            start_time=signal_dates.min(),
            end_time=signal_dates.max(),
            strategy=strategy_obj,
        )

        excess_return = report_normal["return"] - report_normal["bench"]
        analysis = {
            "excess_return_without_cost": risk_analysis(excess_return),
            "excess_return_with_cost": risk_analysis(excess_return - report_normal["cost"]),
        }
        analysis_df = pd.concat(analysis)
        pprint(analysis_df)

    def worker(self):
        """Run the trainer's worker when the trainer is a ``TrainerRM``.

        For any other trainer type only a notice is printed.
        """
        print("========== worker ==========")
        if isinstance(self.trainer, TrainerRM):
            self.trainer.worker()
        else:
            print(f"{type(self.trainer)} is not supported for worker.")
if __name__ == "__main__":
    # Hand the example class to python-fire so that it is driven from the command line.
    fire.Fire(OnlineSimulationExample)