Skip to content

API 参考

完整可运行示例: examples/17_python_bindings/python_bindings_demo.py 覆盖全部核心模块的 Python API 演示。

本文档提供 AXON 量化交易框架各顶层模块的速查表与关键 API 代码示例,帮助开发者快速定位所需功能。


1. 顶层模块速查表

模块 Crate 核心功能 主要类型
rl axon-rl Gymnasium 兼容的交易环境、动作/观测空间、奖励函数 TradingEnv, ActionSpace, ObservationSpace, RewardFn
llm axon-llm LLM 后端抽象、ReAct Agent、工具调用 LLMBackend, ReActAgent, ToolDefinition, Message
hpo axon-hpo 超参数优化(Optuna 集成)、Study / Trial 管理 HPOConfig, StudyConfig, TrialResult, SearchSpaceDef
walk_forward axon-walk-forward 滚动/扩展窗口交叉验证、稳定性分析 WalkForwardConfig, FoldResult, AggregatedMetrics
tracker axon-tracker 实验追踪(MLflow / 内存后端)、指标记录 ExperimentTracker, ParamValue, RunStatus
registry axon-registry 模型版本管理、阶段转换、回滚 ModelRegistry, ModelVersion, ModelStage, SemVer
distributed axon-distributed Ray 分布式训练、参数服务器、Checkpoint DistributedConfig, ClusterConfig, AlgorithmConfig
exchange axon-exchange 交易所适配器、WebSocket、限流器 ExchangeAdapter, ExchangeConfig, RateLimitConfig
explain axon-explain 可解释性:SHAP、反事实、报告生成 KernelSHAP, CounterfactualGenerator, ReportGenerator
ensemble axon-ensemble 模型集成:投票、加权、动态权重、堆叠 DynamicWeightedEnsemble, EnsembleManager, StackingEnsemble
inference axon-inference 模型推理引擎、热更新、多后端支持 InferenceEngine, ModelHotReloader, OnnxBackend, CandleBackend
backtest axon-backtest 事件驱动回测引擎、撮合、冲击模型 BacktestEngine, MatchingEngine, RunResult

2. 关键 API 代码示例

2.1 TradingEnv — 交易环境

TradingEnv 是 AXON 的核心 RL 环境,完全兼容 Gymnasium 接口。

from axon_quant import (
    TradingEnv, EnvConfig,
    DefaultObservationSpace, FeatureConfig, FeatureSource, NormalizerType,
    DiscreteActionSpace, TradingDirection,
    PnLReward, SharpeReward, MultiObjectiveReward,
    MarketBar,
)

# 1. 配置环境
config = EnvConfig(
    initial_capital=100_000.0,    # 初始资金 10 万 USDT
    transaction_cost=0.001,       # 交易成本 10 bps
    slippage=0.0005,              # 滑点 5 bps
    max_position_ratio=1.0,       # 最大满仓
    max_steps=1000,               # 每 episode 最大步数
    seed=None,                    # 随机种子
    symbol="BTCUSDT",             # 交易标的
    return_window=252,            # 收益率历史窗口(用于夏普计算)
)

# 2. 定义观测空间(特征工程)
obs_space = DefaultObservationSpace.new(
    window_size=20,               # 保留最近 20 个时间步
    features=[
        FeatureConfig(
            name="close",
            source=FeatureSource.PriceField("close"),
            normalizer=NormalizerType.ZScore,  # Z-Score 归一化
            clip_range=(-5.0, 5.0),            # 截断异常值
        ),
        FeatureConfig(
            name="volume",
            source=FeatureSource.VolumeField("volume"),
            normalizer=NormalizerType.ZScore,
        ),
        FeatureConfig(
            name="rsi",
            source=FeatureSource.RSI(14),      # 内置 RSI 计算
            normalizer=NormalizerType.MinMax,  # 映射到 [0, 1]
        ),
    ],
)

# 3. 定义动作空间(离散)
action_space = DiscreteActionSpace.new(
    n_quantity_bins=5,            # 5 个仓位档位: 20%/40%/60%/80%/100%
    direction=TradingDirection.Both,  # 允许做多和做空
)

# 4. 定义奖励函数(多目标)
reward_fn = MultiObjectiveReward([
    PnLReward(relative=True, scale=1.0),     # 相对收益率
    SharpeReward(risk_free_rate=0.02, window=20),  # 滚动夏普比率
])

# 5. 加载市场数据
market_data = load_bars("BTCUSDT", "1h", start="2024-01-01", end="2024-06-01")

# 6. 创建环境
env = TradingEnv.new(
    config=config,
    action_space=action_space,
    observation_space=obs_space,
    reward_fn=reward_fn,
    market_data=market_data,
)

# 7. 标准 Gymnasium 交互循环
obs = env.reset()
done = False
total_reward = 0.0

while not done:
    # 这里可以接入 RL 模型或规则策略
    action = model.predict(obs) if model else env.action_space.sample()

    obs, reward, done, info = env.step(action)
    total_reward += reward

    print(env.render())  # 输出: step=123/5000 | value=$102340.50 | pos=0.5000

print(f"Episode 总奖励: {total_reward:.2f}")
print(f"最终净值: {env.portfolio().portfolio_value:.2f}")

2.2 LLMBackend — LLM 后端

LLMBackend 是统一的 LLM 接口,支持 OpenAI、DeepSeek、本地推理服务等。

from axon_quant import LLMBackend, Message, ToolDefinition, LLMResponse

# 创建 OpenAI 后端
llm = OpenAIBackend(
    api_key="YOUR_API_KEY",
    model="deepseek-chat",        # 或 "gpt-4", "claude-3-opus"
    base_url="https://api.deepseek.com",
)

# 基础对话
messages = [
    Message(role="system", content="你是一位专业的量化交易分析师。"),
    Message(role="user", content="分析 BTC 当前的技术面。"),
]

response = await llm.complete(messages)
print(response.content)

# Function Calling(工具调用)
tools = [
    ToolDefinition(
        name="get_price",
        description="获取指定交易对的当前价格",
        parameters={
            "type": "object",
            "properties": {
                "symbol": {"type": "string", "description": "交易对,如 BTCUSDT"},
            },
            "required": ["symbol"],
        },
    ),
    ToolDefinition(
        name="get_rsi",
        description="计算指定交易对的 RSI 指标",
        parameters={
            "type": "object",
            "properties": {
                "symbol": {"type": "string"},
                "period": {"type": "integer", "default": 14},
            },
            "required": ["symbol"],
        },
    ),
]

response = await llm.complete_with_tools(messages, tools)

# 解析工具调用
if response.tool_calls:
    for call in response.tool_calls:
        print(f"调用工具: {call.name}, 参数: {call.arguments}")
        # 执行工具并返回结果...

# 错误处理
from axon_quant import LLMError

try:
    response = await llm.complete(messages)
except LLMError.RateLimited as e:
    print(f"被限流,建议等待 {e.retry_after} 秒")
    await asyncio.sleep(e.retry_after or 60)
except LLMError.ContextOverflow as e:
    print(f"上下文超限: {e.needed} > {e.limit}")
    # 截断历史消息或切换长上下文模型

2.3 Tracker — 实验追踪

ExperimentTracker 提供统一的实验记录接口,支持 MLflow 和内存后端。

from axon_quant import ExperimentTracker, MLflowTracker, MemoryTracker, ParamValue, RunStatus

# 创建 MLflow 追踪器(生产环境)
tracker = MLflowTracker(
    tracking_uri="http://localhost:5000",
    experiment_name="ppo_btc_trading",
    run_name="run_2024_06_18_v1",
)

# 或创建内存追踪器(测试/快速迭代)
# tracker = MemoryTracker.new()

# 记录超参数
tracker.log_param("learning_rate", ParamValue.Float(3e-4))
tracker.log_param("batch_size", ParamValue.Int(128))
tracker.log_param("hidden_size", ParamValue.Int(256))
tracker.log_param("env_symbol", ParamValue.String("BTCUSDT"))

# 批量记录参数
tracker.log_params([
    ("gamma", ParamValue.Float(0.99)),
    ("gae_lambda", ParamValue.Float(0.95)),
    ("clip_range", ParamValue.Float(0.2)),
])

# 记录指标(支持按 step 记录)
for step in range(1000):
    loss = train_step()
    tracker.log_metric("loss", loss, step=step)

    if step % 100 == 0:
        sharpe = evaluate_sharpe()
        tracker.log_metric("sharpe_ratio", sharpe, step=step)
        tracker.log_metric("portfolio_value", env.portfolio().portfolio_value, step=step)

# 记录直方图(如权重分布)
tracker.log_histogram("actor_weights", weights_flattened, step=1000)

# 记录图像(如收益曲线)
tracker.log_image("pnl_curve", png_bytes, format=ImageFormat.PNG, step=1000)

# 上传模型产物
tracker.log_artifact("model.onnx", Path("./models/model.onnx"))

# 设置标签
tracker.set_tag("model_type", "PPO")
tracker.set_tag("data_source", "binance_1h")

# 结束运行
tracker.finish(RunStatus.Success)

# 刷新缓冲区(确保数据已写入)
tracker.flush()

2.4 Registry — 模型注册表

ModelRegistry 管理模型的全生命周期:注册、阶段转换、回滚。

from axon_quant import (
    ModelRegistry, LocalStorage,
    ModelMetadata, ModelStage, SemVer, ModelSignature,
    VersionFilter,
)
from pathlib import Path

# 创建注册表(本地文件存储)
storage = LocalStorage.new(base_dir="./model_registry")
registry = ModelRegistry.new(storage)

# 注册新模型版本
metadata = ModelMetadata(
    tags={
        "algorithm": "PPO",
        "env": "BTCUSDT_1h",
        "sharpe": "1.85",
    },
    description="PPO 模型 v3,优化了夏普比率",
)

signature = ModelSignature(
    inputs=["observation: float32[1,64,128]"],
    outputs=["action_probs: float32[1,3]"],
)

model_version = await registry.register(
    name="ppo_btc_trading",
    artifact_path=Path("./models/ppo_v3.onnx"),
    metadata=metadata,
    signature=signature,
)
print(f"注册成功: {model_version.name}@{model_version.version}")
# 输出: ppo_btc_trading@1.0.0

# 获取最新版本
latest = await registry.get("ppo_btc_trading", version=None)
print(f"最新版本: {latest.version}, 阶段: {latest.stage}")

# 获取 Production 版本
prod = await registry.get_production("ppo_btc_trading")

# 阶段转换: Staging -> Production
await registry.transition_stage(
    name="ppo_btc_trading",
    version=SemVer.parse("1.0.0"),
    new_stage=ModelStage.Production,
)
# 注意: 提升到 Production 时,旧 Production 版本自动降级为 Archived

# 查询版本列表
versions = await registry.list_versions(
    name="ppo_btc_trading",
    filter=VersionFilter(
        stage=ModelStage.Production,
        min_version=SemVer.parse("1.0.0"),
        limit=10,
    ),
)

# 回滚到上一个 Production 版本
rolled_back = await registry.rollback("ppo_btc_trading")
print(f"回滚到: {rolled_back.version}")

# 下载模型产物
await registry.download_artifact(
    name="ppo_btc_trading",
    version=SemVer.parse("1.0.0"),
    dest=Path("./downloads/ppo_v1.onnx"),
)

# 列出所有模型
models = registry.list_models()
print(f"已注册模型: {models}")

2.5 InferenceEngine — 推理引擎

InferenceEngine 提供统一的模型推理接口,支持 ONNX、tch、Candle 三后端。

from axon_quant import (
    InferenceEngine, OnnxBackend, CandleBackend, TchBackend,
    ModelConfig, Device, InferenceBackend,
    Observation, Action,
)
from pathlib import Path

# 通用配置
config = ModelConfig(
    path="models/trading_model.onnx",
    backend=InferenceBackend.ONNX,
    device=Device.CUDA(0),        # 使用 GPU 0
    input_shape=[1, 64, 128],     # [batch, seq_len, features]
    output_dim=3,                 # Buy / Sell / Hold
    fp16=True,                    # 启用 FP16
    num_threads=4,                # CPU 线程数
)

# ONNX 后端
engine = OnnxBackend(config)
engine.load(Path(config.path))

# Candle 后端(纯 Rust,无 Python 依赖)
candle_config = ModelConfig(
    path="models/trading_model.safetensors",
    backend=InferenceBackend.CANDLE,
    device=Device.CPU,
    input_shape=[1, 4, 1],        # input_dim = 1*4*1 = 4
    output_dim=3,
    fp16=False,
    num_threads=4,
)
candle_engine = CandleBackend(candle_config)
candle_engine.load(Path(candle_config.path))

# 单条推理
obs = Observation(
    features=[0.5, -0.2, 1.1, 0.0, ...],  # 64*128=8192 维特征
    feature_names=[...],
    timestamp=1234567890,
)
action = engine.infer(obs)
print(f"预测动作: {action}")

# 批量推理(生产环境推荐)
observations = [obs1, obs2, obs3, obs4]
actions = engine.infer_batch(observations)
print(f"批量预测: {len(actions)} 个动作")

# 热更新(原子替换 session)
from axon_quant import ModelHotReloader

reloader = ModelHotReloader(engine, config)
reloader.spawn_watcher()  # 启动文件监控

# 手动触发重载
new_version = await reloader.reload()
print(f"模型已更新到版本 {new_version}")

# 订阅版本变化
version_rx = reloader.subscribe()
await version_rx.changed()
print(f"检测到新版本: {version_rx.borrow()}")

2.6 ExchangeAdapter — 交易所适配器

ExchangeAdapter 提供统一的交易所接口,目前支持 Binance 和 OKX。

from axon_quant import (
    BinanceAdapter, OkxAdapter,
    ExchangeConfig, ExchangeId,
    Symbol, Order, OrderId, OrderType, Side, TimeInForce,
    RateLimitConfig, ReconnectConfig,
    MarginType, PositionMode,
)
from decimal import Decimal

# Binance 配置
config = ExchangeConfig(
    exchange_id=ExchangeId.Binance,
    api_key="YOUR_API_KEY",
    api_secret="YOUR_API_SECRET",
    passphrase=None,
    testnet=True,
    rest_base_url="https://testnet.binance.vision",
    ws_url="wss://testnet.binance.vision/ws",
    rate_limit=RateLimitConfig(
        requests_per_second=10,
        orders_per_minute=60,
        ws_messages_per_second=50,
    ),
    reconnect=ReconnectConfig(
        max_retries=10,
        initial_backoff_ms=500,
        max_backoff_ms=30000,
        backoff_multiplier=2.0,
        circuit_breaker_threshold=5,
        circuit_breaker_reset_sec=60,
    ),
    position_endpoint="/fapi/v2/positionRisk",
    fapi_base_url="https://testnet.binancefuture.com",
)

# 创建并连接
adapter = BinanceAdapter(config)
await adapter.connect()

# 订阅行情
await adapter.subscribe([Symbol("BTCUSDT"), Symbol("ETHUSDT")])

# 获取行情通道
market_rx = adapter.market_data_rx()
while True:
    msg = await market_rx.recv()
    match msg.type:
        case "Ticker":
            print(f"[{msg.data.symbol}] 买 {msg.data.bid} / 卖 {msg.data.ask}")
        case "Trade":
            print(f"成交: {msg.data.price} x {msg.data.quantity}")

# 下单
order = Order(
    client_order_id=OrderId.new(),
    symbol=Symbol("BTCUSDT"),
    side=Side.Buy,
    order_type=OrderType.Market,
    price=None,
    quantity=Decimal("0.001"),
    time_in_force=TimeInForce.Gtc,
    exchange=ExchangeId.Binance,
    meta={"strategy": "momentum_v1"},
)
order_id = await adapter.send_order(order)

# 撤单
await adapter.cancel_order(order_id)

# 合约操作
await adapter.set_leverage("BTCUSDT", leverage=10)
await adapter.set_margin_type("BTCUSDT", MarginType.Isolated)
await adapter.set_position_mode(hedge_mode=True)

# 查询账户
account = await adapter.get_account_info()
print(f"总余额: {account.total_balance}, 可用: {account.available_balance}")

# 查询资金费率
funding = await adapter.get_funding_rate("BTCUSDT")
print(f"资金费率: {funding.rate}, 下次结算: {funding.next_funding_ms}")

3. 配置参数参考表

3.1 EnvConfig(交易环境)

参数 类型 默认值 说明
initial_capital f64 100_000.0 初始资金
transaction_cost f64 0.001 交易成本比例(10 bps)
slippage f64 0.0005 滑点比例(5 bps)
max_position_ratio f64 1.0 最大持仓比例(0.0 ~ 1.0)
max_steps usize 1000 每 episode 最大步数
seed Option<u64> None 随机种子
symbol String "BTCUSDT" 交易标的代码
return_window usize 252 收益率历史窗口大小

3.2 ExchangeConfig(交易所)

参数 类型 默认值 说明
exchange_id ExchangeId - 交易所标识(Binance / OKX)
api_key String - API 密钥
api_secret String - API 密钥
passphrase Option<String> None OKX 专用 passphrase
testnet bool true 是否使用测试网
rest_base_url String - REST API 基础 URL
ws_url String - WebSocket URL
rate_limit RateLimitConfig - 限流配置
reconnect ReconnectConfig - 重连配置
position_endpoint String "/fapi/v2/positionRisk" 持仓查询端点
fapi_base_url Option<String> None 合约 API 基础 URL

3.3 HPOConfig(超参数优化)

参数 类型 默认值 说明
study.study_name String - Study 名称
study.direction StudyDirection Maximize 优化方向
study.sampler SamplerConfig Tpe 采样器类型
study.pruner PrunerConfig MedianPruner 剪枝器类型
study.storage Option<String> None Optuna storage URL
search_space HashMap - 参数搜索空间定义
hpo.n_trials usize 50 总 trial 数
hpo.n_jobs usize 1 并行 trial 数
hpo.timeout_seconds Option<u64> None 总超时
hpo.early_stopping bool false 是否启用早停

3.4 WalkForwardConfig(滚动验证)

参数 类型 默认值 说明
train_size usize - 训练窗口大小
validation_size usize 0 验证窗口大小
test_size usize - 测试窗口大小
step_size usize - 滚动步长
window_type WindowType Expanding 窗口类型(Rolling / Expanding)
purge_gap usize 0 训练-测试间防泄漏间隔
embargo_pct f64 0.01 Embargo 百分比

3.5 DistributedConfig(分布式训练)

参数 类型 默认值 说明
cluster.num_workers usize - Worker 数量
cluster.num_cpus_per_worker usize 1 每 Worker CPU 数
cluster.num_gpus_per_worker f64 0.0 每 Worker GPU 数
cluster.cluster_address Option<String> None Ray 集群地址
algorithm.algorithm String - 算法名(PPO / SAC / DQN / IMPALA / APE_X)
algorithm.framework String "torch" 框架(torch / tensorflow)
resources.num_envs_per_worker usize - 每 Worker 环境数
resources.train_batch_size usize - 训练批大小
resources.sgd_minibatch_size usize - SGD minibatch 大小
fault_tolerance.checkpoint_interval_s u64 - Checkpoint 间隔(秒)
fault_tolerance.checkpoint_dir String - Checkpoint 保存目录

3.6 ModelConfig(推理引擎)

参数 类型 默认值 说明
path String - 模型文件路径
backend InferenceBackend - 后端类型(ONNX / TCH / CANDLE)
device Device - 设备(CPU / CUDA(n))
input_shape [usize; 3] - 输入形状 [batch, seq, features]
output_dim usize - 输出维度
fp16 bool false 是否启用 FP16
num_threads usize 4 CPU 推理线程数

4. 常用枚举速查

4.1 ActionSpace(动作空间)

from axon_quant import ActionSpace, DiscreteActionSpace, ContinuousActionSpace, TradingDirection

# 离散动作空间
discrete = ActionSpace.Discrete(
    DiscreteActionSpace.new(n_quantity_bins=5, direction=TradingDirection.Both)
)
# 动作索引: 0=Hold, 1-5=Buy(20%-100%), 6-10=Sell(20%-100%)

# 连续动作空间
continuous = ActionSpace.Continuous(
    ContinuousActionSpace.new(min=-1.0, max=1.0)
)
# -1.0 = 满仓做空, 0.0 = 空仓, 1.0 = 满仓做多

4.2 NormalizerType(归一化策略)

from axon_quant import NormalizerType

NormalizerType.ZScore    # (x - mean) / std,保留历史统计量
NormalizerType.MinMax    # (x - min) / (max - min) -> [0, 1]
NormalizerType.Robust    # (x - median) / IQR,抗异常值
NormalizerType.None      # 不归一化

4.3 ModelStage(模型阶段)

from axon_quant import ModelStage

ModelStage.Staging      # 新注册,待验证
ModelStage.Production   # 线上运行
ModelStage.Archived     # 旧版本归档
ModelStage.RolledBack   # 已回滚

4.4 OrderType / TimeInForce(订单类型)

from axon_quant import OrderType, TimeInForce

OrderType.Limit         # 限价单
OrderType.Market        # 市价单
OrderType.StopLoss      # 止损单
OrderType.StopLimit     # 限价止损单

TimeInForce.Gtc         # Good Till Cancelled
TimeInForce.Ioc         # Immediate Or Cancel
TimeInForce.Fok         # Fill Or Kill

5. 模块依赖关系

                    ┌─────────────────┐
                    │   Application   │
                    └────────┬────────┘
        ┌────────────────────┼────────────────────┐
        │                    │                    │
        ▼                    ▼                    ▼
┌──────────────┐   ┌──────────────┐   ┌──────────────┐
│   backtest   │   │   exchange   │   │   ensemble   │
└──────────────┘   └──────────────┘   └──────────────┘
        │                    │                    │
        └────────────────────┼────────────────────┘
                    ┌────────┴────────┐
                    │      rl         │
                    │  (TradingEnv)   │
                    └────────┬────────┘
        ┌────────────────────┼────────────────────┐
        │                    │                    │
        ▼                    ▼                    ▼
┌──────────────┐   ┌──────────────┐   ┌──────────────┐
│  inference   │   │     llm      │   │   explain    │
└──────────────┘   └──────────────┘   └──────────────┘
                    ┌────────┴────────┐
                    │  core types     │
                    └─────────────────┘

6. 版本兼容性

AXON 当前版本为 0.2.0,各 crate 版本统一:

Crate 版本 最低 Rust 版本
axon-core 0.2.0 1.96.0
axon-rl 0.2.0 1.96.0
axon-llm 0.2.0 1.96.0
axon-inference 0.2.0 1.96.0
axon-exchange 0.2.0 1.96.0
axon-ensemble 0.2.0 1.96.0
axon-explain 0.2.0 1.96.0
axon-backtest 0.2.0 1.96.0
axon-hpo 0.2.0 1.96.0
axon-walk-forward 0.2.0 1.96.0
axon-tracker 0.2.0 1.96.0
axon-registry 0.2.0 1.96.0
axon-distributed 0.2.0 1.96.0
axon-monitor 0.2.0 1.96.0
axon-risk 0.2.0 1.96.0
axon-compliance 0.2.0 1.96.0