import os
import pytest
import gc
os.environ['AKG_AGENTS_STREAM_OUTPUT'] = 'on'
from akg_agents.core.agent.utils.feature_extractor import FeatureExtractor
from akg_agents.op.config.config_validator import load_config
def _resolve_resource(rel_path):
    """Return an openable path for a repository-relative resource file.

    The path is first tried relative to the current working directory (the
    original behaviour, which only works when pytest is started from the
    repository root). If it is not found there, each ancestor directory of
    this test module is tried in turn. When nothing matches, ``rel_path`` is
    returned unchanged so ``open()`` raises the usual ``FileNotFoundError``.
    """
    if os.path.isfile(rel_path):
        return rel_path
    here = os.path.dirname(os.path.abspath(__file__))
    while True:
        candidate = os.path.join(here, rel_path)
        if os.path.isfile(candidate):
            return candidate
        parent = os.path.dirname(here)
        if parent == here:  # reached the filesystem root
            return rel_path
        here = parent


@pytest.mark.level0
@pytest.mark.use_model
@pytest.mark.asyncio
async def test_feature_extract():
    """Run FeatureExtractor on the relu torch/triton_ascend sample and check it yields features.

    Reads the reference framework implementation and the DSL implementation
    of ``relu`` from ``tests/op/resources``, builds a ``FeatureExtractor``
    from the ``agent_model_config`` section of the DSL config, awaits
    ``run()`` and asserts that the first returned element is non-empty.
    The extractor is always cleaned up, even if ``run()`` or the assertion
    fails.
    """
    framework = "torch"
    dsl = "triton_ascend"
    op_name = "relu"
    resource_dir = f"tests/op/resources/{op_name}_op"
    framework_code_path = _resolve_resource(f"{resource_dir}/{op_name}_{framework}.py")
    impl_code_path = _resolve_resource(f"{resource_dir}/{op_name}_{dsl}_{framework}.py")
    with open(framework_code_path, "r", encoding="utf-8") as f:
        framework_code = f.read()
    with open(impl_code_path, "r", encoding="utf-8") as f:
        impl_code = f.read()
    # Use an empty model config when the DSL config has no agent_model_config section.
    config = load_config(dsl).get("agent_model_config", {})
    feature = FeatureExtractor(
        model_config=config,
        impl_code=impl_code,
        framework_code=framework_code,
        dsl=dsl,
    )
    try:
        # run() yields a 3-tuple; only the first element (the feature text) is checked here.
        feature_res, _, _ = await feature.run()
        # Prints: "feature text of operator <op_name> returned by the model: ..."
        print(f"模型返回的算子{op_name}的特征文本:{feature_res}\n")
        # Without an assertion the test could never fail on an empty model answer.
        # NOTE(review): assumes an empty/None result means extraction failed — confirm.
        assert feature_res, f"FeatureExtractor returned no features for operator {op_name}"
    finally:
        # Best-effort release: prefer an explicit close(), else the async
        # context-manager exit hook, whichever the extractor provides.
        if hasattr(feature, "close"):
            await feature.close()
        elif hasattr(feature, "__aexit__"):
            await feature.__aexit__(None, None, None)
        gc.collect()