跳转至

7.3 项目三:数据分析 Agent(Text-to-SQL)

实验目标

本节结束后,你能够: - 独立构建一个完整的 NL → SQL → 执行 → 可视化端到端数据分析 Agent,处理真实业务问题 - 理解并实现Schema 感知的核心机制:如何在百表级数据库中动态检索相关表,而不是把所有表结构塞进 Prompt - 掌握 Text-to-SQL 最关键的工程问题:SQL 自我修正循环、只读沙箱防护、慢查询熔断

核心学习点(3 个): 1. 动态 Schema 检索:用 Embedding 向量化表描述,按需注入上下文,解决"100 张表撑爆上下文"的核心难题 2. 自我修正循环:将 SQL 执行错误结构化后反馈给 LLM,让 Agent 自主纠错而非直接崩溃 3. LLM 驱动可视化决策:让模型判断用柱状图还是折线图,摘要解读数据意义,而不只是吐出一张表格


架构总览

graph TD
    A["用户自然语言问题\n'过去三个月各品类销售额趋势'"] --> B[Schema 检索器]

    subgraph "Schema感知层"
        B --> C{向量相似度检索}
        C --> D[(表结构向量索引\nEmbedding 预建)]
        C --> E[Top-K 相关表的\nDDL + 注释 + 样例数据]
    end

    subgraph "SQL生成层"
        E --> F[SQL 生成器\nLLM + Few-shot]
        F --> G[候选 SQL]
    end

    subgraph "安全执行层"
        G --> H{只读沙箱\n+ 慢查询检测}
        H -->|"超时 / 写操作"| I[❌ 拒绝执行]
        H -->|执行失败| J[错误结构化]
        J --> K[自我修正循环\n最多 3 次]
        K --> F
        H -->|执行成功| L[结果 DataFrame]
    end

    subgraph "可视化层"
        L --> M[图表类型推断\nLLM 决策]
        M --> N{图表类型}
        N -->|时序数据| O[折线图 Plotly]
        N -->|分类对比| P[柱状图 Plotly]
        N -->|占比| Q[饼图 Plotly]
        N -->|明细表格| R[表格渲染]
        O & P & Q & R --> S[摘要文字解读\nLLM 生成]
        S --> T["📊 最终输出\nSQL + 图表 + 摘要"]
    end

    style A fill:#e8f4f8,stroke:#2196F3
    style T fill:#e8f5e9,stroke:#4CAF50
    style I fill:#ffebee,stroke:#f44336

环境准备

# 创建虚拟环境(uv)
uv venv --python 3.11 && source .venv/bin/activate

# 安装依赖(锁定版本)
uv pip install \
  openai==1.51.0 \
  anthropic==0.34.0 \
  litellm==1.49.0 \
  pandas==2.2.3 \
  plotly==5.24.0 \
  pydantic==2.9.2 \
  tiktoken==0.7.0 \
  numpy==2.1.1 \
  sqlparse==0.5.1 \
  kaleido==0.2.1 \
  tenacity==9.0.0 \
  python-dotenv==1.0.1

Colab 用户:!pip install openai anthropic litellm pandas plotly pydantic tiktoken sqlparse kaleido tenacity python-dotenv 即可,无需创建虚拟环境

创建 .env 文件:

OPENAI_API_KEY=sk-...
# 或使用 DashScope(Qwen 系列)
DASHSCOPE_API_KEY=sk-...
# 或使用 DeepSeek
DEEPSEEK_API_KEY=sk-...

模型切换:修改 core_config.py 中的 ACTIVE_MODEL_KEY 即可全局切换模型,无需改动业务代码。


Step-by-Step 实现

Step 0:统一模型配置(core_config.py)

目标:所有模块通过 core_config.py 统一管理模型注册、API Key、base_url 等配置,避免硬编码。修改 ACTIVE_MODEL_KEY 即可全局切换模型。

# core_config.py
"""全局配置:模型注册表与项目常量"""
import os
from typing import TypedDict


class ModelConfig(TypedDict):
    litellm_id: str          # 模型标识符(OpenAI SDK 用 model 参数)
    price_in: float          # 每 1K input tokens 价格(美元)
    price_out: float         # 每 1K output tokens 价格(美元)
    max_tokens_limit: int    # 模型支持的最大 max_tokens
    api_key_env: str | None  # API Key 环境变量名
    base_url: str | None     # API 基础 URL(None 表示使用默认)


# 注册表:key 是界面显示名,value 是调用配置
MODEL_REGISTRY: dict[str, ModelConfig] = {
    "GPT-4o-mini": {
        "litellm_id": "gpt-4o-mini",
        "price_in": 0.00015,
        "price_out": 0.0006,
        "max_tokens_limit": 16384,
        "api_key_env": "OPENAI_API_KEY",
        "base_url": None,
    },
    "Qwen-Max": {
        "litellm_id": "qwen-plus",
        "price_in": 0.001,
        "price_out": 0.004,
        "max_tokens_limit": 4096,
        "api_key_env": "DASHSCOPE_API_KEY",
        "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
    },
    "DeepSeek-V3": {
        "litellm_id": "deepseek/deepseek-chat",
        "price_in": 0.00027,
        "price_out": 0.0011,
        "max_tokens_limit": 4096,
        "api_key_env": "DEEPSEEK_API_KEY",
        "base_url": None,
    },
}

# ✅ 当前激活模型 key — 修改此处全局切换
ACTIVE_MODEL_KEY: str = "Qwen-Max"

# 数据库配置
DB_PATH = os.getenv("DB_PATH", "ecommerce.db")
LARGE_SCHEMA_THRESHOLD = 20
EMBED_MODEL = "text-embedding-3-small"
QUERY_TIMEOUT_SECONDS = 10
MAX_RETRIES = 3
CHARTS_DIR = "charts"


def get_active_config() -> ModelConfig:
    """获取当前激活模型的完整配置"""
    return MODEL_REGISTRY[ACTIVE_MODEL_KEY]


def get_litellm_id(model_key: str | None = None) -> str:
    """获取指定模型(默认激活模型)的模型 ID"""
    key = model_key or ACTIVE_MODEL_KEY
    return MODEL_REGISTRY[key]["litellm_id"]


def get_api_key(model_key: str | None = None) -> str | None:
    """从环境变量读取指定模型的 API Key"""
    key = model_key or ACTIVE_MODEL_KEY
    env_var = MODEL_REGISTRY[key]["api_key_env"]
    return os.environ.get(env_var) if env_var else None


def get_base_url(model_key: str | None = None) -> str | None:
    """获取指定模型的 base_url(None 表示使用 SDK 默认值)"""
    key = model_key or ACTIVE_MODEL_KEY
    return MODEL_REGISTRY[key]["base_url"]


def get_model_list() -> list[str]:
    """获取所有已注册模型的显示名列表"""
    return list(MODEL_REGISTRY.keys())


def estimate_cost(model_key: str, input_tokens: int, output_tokens: int) -> float:
    """根据 Token 数估算调用费用(美元)"""
    cfg = MODEL_REGISTRY[model_key]
    return (
        input_tokens / 1000 * cfg["price_in"]
        + output_tokens / 1000 * cfg["price_out"]
    )

关键点: - 所有业务模块(schema_manager.pysql_generator.pysql_executor.pyvisualizer.py)通过 from core_config import get_litellm_id, get_api_key, get_base_url 获取配置 - 新增 sql_generator.get_openai_client() 辅助函数,统一创建 OpenAI 客户端,避免各模块重复初始化 - ACTIVE_MODEL_KEY 默认 "Qwen-Max",切换为 "GPT-4o-mini""DeepSeek-V3" 即可全局生效


Step 1:构建演示数据库

目标:创建一个足够真实的电商数据库,包含多张关联表,用于验证 Agent 各项能力——包括跨表 JOIN、聚合统计、时间序列分析。

# db_setup.py
"""
演示用电商数据库初始化脚本。
生产环境直接连接现有 PostgreSQL/MySQL,跳过此步骤。
"""
import sqlite3
import random
from datetime import datetime, timedelta
from pathlib import Path


def create_demo_database(db_path: str = "ecommerce.db") -> None:
    """创建包含真实业务语义的演示数据库。"""
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()

    # 建表 DDL
    cur.executescript("""
    CREATE TABLE IF NOT EXISTS categories (
        category_id   INTEGER PRIMARY KEY,
        name          TEXT NOT NULL,
        parent_id     INTEGER,
        created_at    TEXT DEFAULT (datetime('now'))
    );

    CREATE TABLE IF NOT EXISTS products (
        product_id    INTEGER PRIMARY KEY,
        name          TEXT NOT NULL,
        category_id   INTEGER REFERENCES categories(category_id),
        price         REAL NOT NULL,
        cost          REAL NOT NULL,
        stock         INTEGER DEFAULT 0,
        created_at    TEXT DEFAULT (datetime('now'))
    );

    CREATE TABLE IF NOT EXISTS users (
        user_id       INTEGER PRIMARY KEY,
        username      TEXT NOT NULL UNIQUE,
        email         TEXT NOT NULL UNIQUE,
        city          TEXT,
        register_date TEXT DEFAULT (datetime('now')),
        vip_level     INTEGER DEFAULT 0
    );

    CREATE TABLE IF NOT EXISTS orders (
        order_id      INTEGER PRIMARY KEY,
        user_id       INTEGER REFERENCES users(user_id),
        status        TEXT NOT NULL,
        total_amount  REAL NOT NULL,
        created_at    TEXT NOT NULL,
        paid_at       TEXT
    );

    CREATE TABLE IF NOT EXISTS order_items (
        item_id       INTEGER PRIMARY KEY,
        order_id      INTEGER REFERENCES orders(order_id),
        product_id    INTEGER REFERENCES products(product_id),
        quantity      INTEGER NOT NULL,
        unit_price    REAL NOT NULL,
        discount      REAL DEFAULT 0
    );

    CREATE TABLE IF NOT EXISTS user_events (
        event_id      INTEGER PRIMARY KEY,
        user_id       INTEGER REFERENCES users(user_id),
        product_id    INTEGER REFERENCES products(product_id),
        event_type    TEXT NOT NULL,
        event_time    TEXT NOT NULL
    );
    """)

    # 插入种子数据
    categories_data = [
        (1, "电子产品", None), (2, "服装", None), (3, "食品", None),
        (4, "手机", 1), (5, "笔记本", 1), (6, "耳机", 1),
        (7, "男装", 2), (8, "女装", 2),
    ]
    cur.executemany(
        "INSERT OR IGNORE INTO categories VALUES (?,?,?,datetime('now'))",
        categories_data,
    )

    random.seed(42)
    products_data = []
    for i in range(1, 51):
        cat_id = random.choice([4, 5, 6, 7, 8])
        price = round(random.uniform(50, 8000), 2)
        products_data.append((
            i, f"商品_{i:03d}", cat_id,
            price, round(price * 0.6, 2),
            random.randint(0, 500),
        ))
    cur.executemany(
        "INSERT OR IGNORE INTO products(product_id,name,category_id,price,cost,stock) VALUES (?,?,?,?,?,?)",
        products_data,
    )

    cities = ["北京", "上海", "广州", "深圳", "杭州", "成都", "武汉"]
    users_data = [(i, f"user_{i}", f"user_{i}@example.com",
                   random.choice(cities), random.randint(0, 3))
                  for i in range(1, 201)]
    cur.executemany(
        "INSERT OR IGNORE INTO users(user_id,username,email,city,vip_level) VALUES (?,?,?,?,?)",
        users_data,
    )

    base_date = datetime(2024, 1, 1)
    order_id = 1
    item_id = 1
    statuses = ["completed", "completed", "completed", "cancelled", "shipped"]
    for _ in range(800):
        user_id = random.randint(1, 200)
        created_at = base_date + timedelta(days=random.randint(0, 364),
                                           hours=random.randint(0, 23))
        status = random.choice(statuses)
        items_count = random.randint(1, 4)
        total = 0.0

        cur.execute(
            "INSERT OR IGNORE INTO orders VALUES (?,?,?,?,?,?)",
            (order_id, user_id, status, 0,
             created_at.isoformat(), created_at.isoformat() if status != "pending" else None),
        )

        for _ in range(items_count):
            prod = random.choice(products_data)
            qty = random.randint(1, 3)
            price = prod[3]
            disc = round(price * random.uniform(0, 0.15), 2)
            total += (price - disc) * qty
            cur.execute(
                "INSERT OR IGNORE INTO order_items VALUES (?,?,?,?,?,?)",
                (item_id, order_id, prod[0], qty, price, disc),
            )
            item_id += 1

        cur.execute("UPDATE orders SET total_amount=? WHERE order_id=?",
                    (round(total, 2), order_id))
        order_id += 1

    conn.commit()
    conn.close()
    print(f"✅ 演示数据库已创建:{db_path}{order_id-1} 笔订单)")


if __name__ == "__main__":
    create_demo_database()

运行:

python db_setup.py
# ✅ 演示数据库已创建:ecommerce.db(800 笔订单)


Step 2:Schema 感知——元数据提取与向量索引

目标:把数据库的"知识"结构化地提取出来,并用 Embedding 向量化,为大规模 Schema 的动态检索打基础。这是 Text-to-SQL 工程落地的最核心难题——100 张表的 DDL 直接塞进 Prompt 会耗费大量 Token 且降低注意力质量。

# schema_manager.py
"""
Schema 感知模块:
1. 从数据库动态提取 DDL + 注释 + 样例数据
2. 用 Embedding 向量化表描述,支持按问题语义动态检索相关表
"""
import sqlite3
import json
import hashlib
import re
import os
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import tiktoken
from openai import OpenAI
from dotenv import load_dotenv

from core_config import (
    get_api_key, get_base_url, get_litellm_id,
    EMBED_MODEL, LARGE_SCHEMA_THRESHOLD,
)

load_dotenv()


def _get_embedding_client() -> OpenAI:
    """创建用于 Embedding 的 OpenAI 客户端"""
    kwargs = {}
    key = get_api_key()
    url = get_base_url()
    if key:
        kwargs["api_key"] = key
    if url:
        kwargs["base_url"] = url
    return OpenAI(**kwargs)


@dataclass
class TableSchema:
    """单张表的完整 Schema 描述。"""
    table_name: str
    ddl: str                    # CREATE TABLE 语句
    columns: list[dict]         # [{name, type, nullable, comment}, ...]
    sample_rows: list[dict]     # 3-5 行样例数据
    row_count: int              # 表行数(量级参考)
    description: str = ""       # 自然语言描述(用于向量化)
    embedding: list[float] = field(default_factory=list)  # 向量(延迟计算)


class SchemaManager:
    """
    数据库 Schema 管理器。

    设计原则:
    - 小规模(<20 张表):直接把全量 Schema 注入 Prompt
    - 大规模(>=20 张表):向量检索 Top-K 相关表,动态注入
    """

    def __init__(self, db_path: str, embed_model: str = EMBED_MODEL):
        self.db_path = db_path
        self.embed_model = embed_model
        self.tables: dict[str, TableSchema] = {}
        self._load_schemas()

    def _get_connection(self) -> sqlite3.Connection:
        """获取只读数据库连接。"""
        conn = sqlite3.connect(f"file:{self.db_path}?mode=ro", uri=True)
        conn.row_factory = sqlite3.Row
        return conn

    def _load_schemas(self) -> None:
        """从数据库提取所有表的 Schema 信息。"""
        conn = self._get_connection()
        cur = conn.cursor()

        # 获取所有表名(排除 SQLite 内部表)
        tables = cur.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
        ).fetchall()

        for (table_name,) in tables:
            # 提取 DDL
            ddl_row = cur.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,)
            ).fetchone()
            ddl = ddl_row[0] if ddl_row else ""

            # 提取列信息(PRAGMA table_info 返回 cid/name/type/notnull/dflt_value/pk)
            col_rows = cur.execute(f"PRAGMA table_info({table_name})").fetchall()
            columns = [
                {
                    "name": row["name"],
                    "type": row["type"],
                    "nullable": not row["notnull"],
                    "is_pk": bool(row["pk"]),
                }
                for row in col_rows
            ]

            # 提取行数
            row_count = cur.execute(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]

            # 提取样例数据(最多 3 行)
            sample = cur.execute(f"SELECT * FROM {table_name} LIMIT 3").fetchall()
            col_names = [desc[0] for desc in cur.description]
            sample_rows = [dict(zip(col_names, row)) for row in sample]

            # 构建自然语言描述(用于向量化)
            col_descriptions = []
            for col in columns:
                # 从 DDL 匹配该列的注释
                pattern = rf"{col['name']}\s+\S+[^,\n]*--\s*(.+)"
                match = re.search(pattern, ddl, re.IGNORECASE)
                comment = match.group(1).strip() if match else ""
                col_descriptions.append(
                    f"{col['name']}{col['type']}{':' + comment if comment else ''}"
                )

            description = (
                f"表名:{table_name},"
                f"行数约 {row_count}。"
                f"字段:{', '.join(col_descriptions)}"
            )

            self.tables[table_name] = TableSchema(
                table_name=table_name,
                ddl=ddl,
                columns=columns,
                sample_rows=sample_rows,
                row_count=row_count,
                description=description,
            )

        conn.close()
        print(f"✅ 加载 {len(self.tables)} 张表的 Schema")

    def build_embeddings(self) -> None:
        """
        批量计算所有表描述的 Embedding。
        ⚠️  生产注意:建议缓存到本地文件(pickle/numpy),避免每次启动都调用 API。
        """
        cache_path = f"{self.db_path}.schema_embeddings.json"
        schema_hash = hashlib.md5(
            "".join(t.description for t in self.tables.values()).encode()
        ).hexdigest()

        if os.path.exists(cache_path):
            with open(cache_path) as f:
                cache = json.load(f)
            if cache.get("hash") == schema_hash:
                for table_name, emb in cache["embeddings"].items():
                    if table_name in self.tables:
                        self.tables[table_name].embedding = emb
                print(f"✅ 从缓存加载 Embedding({len(self.tables)} 张表)")
                return

        # 批量请求 Embedding
        descriptions = [t.description for t in self.tables.values()]
        table_names = list(self.tables.keys())

        client = _get_embedding_client()
        response = client.embeddings.create(
            model=self.embed_model,
            input=descriptions,
        )
        for i, item in enumerate(response.data):
            self.tables[table_names[i]].embedding = item.embedding

        # 写入缓存
        with open(cache_path, "w") as f:
            json.dump({
                "hash": schema_hash,
                "embeddings": {n: self.tables[n].embedding for n in table_names}
            }, f)
        print(f"✅ 构建并缓存 {len(self.tables)} 张表的 Embedding")

    def retrieve_relevant_tables(
        self, query: str, top_k: int = 5
    ) -> list[TableSchema]:
        """
        按查询语义检索最相关的 top_k 张表。

        小规模数据库(<阈值)直接返回全部表,跳过向量检索。
        """
        if len(self.tables) < LARGE_SCHEMA_THRESHOLD:
            return list(self.tables.values())

        # 确保 Embedding 已构建
        if not next(iter(self.tables.values())).embedding:
            self.build_embeddings()

        # 计算 query 的 Embedding
        client = _get_embedding_client()
        query_emb = np.array(
            client.embeddings.create(model=self.embed_model, input=[query])
            .data[0].embedding
        )

        # 余弦相似度排序
        scores = []
        for table in self.tables.values():
            table_emb = np.array(table.embedding)
            cos_sim = np.dot(query_emb, table_emb) / (
                np.linalg.norm(query_emb) * np.linalg.norm(table_emb) + 1e-9
            )
            scores.append((cos_sim, table))

        scores.sort(key=lambda x: x[0], reverse=True)
        selected = [t for _, t in scores[:top_k]]
        print(f"📋 检索到相关表:{[t.table_name for t in selected]}")
        return selected

    def format_schema_prompt(
        self, tables: list[TableSchema], include_samples: bool = True
    ) -> str:
        """
        将 Schema 信息格式化为 LLM 友好的 Prompt 片段。

        Token 预算控制:
        - DDL + 注释:每表约 200-400 Token
        - 样例数据(3行):每表约 100-200 Token
        - 5 张表合计约 1500-3000 Token,远低于 128K 上下文限制
        """
        enc = tiktoken.get_encoding("cl100k_base")
        parts = []

        for table in tables:
            part = f"### 表:{table.table_name}(约 {table.row_count} 行)\n"
            part += f"```sql\n{table.ddl}\n```\n"

            if include_samples and table.sample_rows:
                part += "\n**样例数据(前3行):**\n"
                for row in table.sample_rows[:3]:
                    row_str = ", ".join(
                        f"{k}={repr(v)}" for k, v in list(row.items())[:6]
                    )
                    part += f"  {row_str}\n"

            parts.append(part)

        full_prompt = "\n".join(parts)
        token_count = len(enc.encode(full_prompt))
        print(f"📊 Schema Prompt:{len(tables)} 张表,约 {token_count} Token")
        return full_prompt

关键点: - Schema Embedding 结果做本地文件缓存(按内容 Hash 失效),避免每次重建应用都白花钱调 API - format_schema_prompt 对样例数据只取前 3 行、前 6 个字段,防止宽表爆炸 Token 消耗 - 所有配置(LARGE_SCHEMA_THRESHOLDEMBED_MODEL)统一从 core_config 读取 - ⚠️ 生产环境中 PostgreSQL 可用 information_schema.columns 替代 PRAGMA table_info,字段注释从 pg_description 提取


Step 3:SQL 生成器——NL → SQL

目标:设计一个鲁棒的 SQL 生成器,通过 System Prompt 注入 Schema,通过 Few-shot 示例引导模型生成语法正确、语义准确的 SQL。同时用 Pydantic 强制结构化输出,避免模型把 SQL 夹杂在大段解释文字里返回。

# sql_generator.py
"""
NL → SQL 生成器。
使用 Structured Output(JSON Mode)确保输出可解析。
"""
import re
import json
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
from openai import OpenAI
from dotenv import load_dotenv

from core_config import get_litellm_id, get_api_key, get_base_url

load_dotenv()


def get_openai_client(
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
) -> OpenAI:
    """创建 OpenAI 客户端,使用 core_config 中的配置"""
    key = api_key or get_api_key()
    url = base_url or get_base_url()
    kwargs = {}
    if key:
        kwargs["api_key"] = key
    if url:
        kwargs["base_url"] = url
    return OpenAI(**kwargs)


class Dialect(str, Enum):
    SQLITE = "sqlite"
    POSTGRESQL = "postgresql"
    MYSQL = "mysql"


class SQLGenerationResult(BaseModel):
    """SQL 生成结果的结构化输出 Schema。"""
    sql: str = Field(description="生成的 SQL 查询语句,不含 markdown 代码块标记")
    explanation: str = Field(description="用一句话解释这条 SQL 的查询逻辑")
    confidence: float = Field(ge=0.0, le=1.0, description="生成置信度,0~1")
    ambiguities: list[str] = Field(
        default_factory=list,
        description="问题中存在的模糊点,如'最近'未指定具体时间范围"
    )


def _parse_json_response(text: str) -> dict:
    """从 LLM 返回的文本中提取 JSON,处理可能的格式问题"""
    # 尝试直接解析
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    # 尝试提取 JSON 块
    match = re.search(r'\{.*\}', text, re.DOTALL)
    if match:
        try:
            return json.loads(match.group())
        except json.JSONDecodeError:
            pass
    # 如果解析失败,返回默认值
    return {"sql": "", "explanation": "解析失败", "confidence": 0.0, "ambiguities": []}


SYSTEM_PROMPT_TEMPLATE = """你是一个专业的数据分析 SQL 专家。你的任务是将用户的自然语言问题转换为正确的 {dialect} SQL 查询。

## 数据库 Schema
{schema}

## 生成规则
1. 只生成 SELECT 查询,绝对不生成 INSERT/UPDATE/DELETE/DROP/ALTER 等写操作
2. 使用 {dialect} 方言语法(如 SQLite 日期函数用 strftime,PostgreSQL 用 DATE_TRUNC)
3. 金额字段保留 2 位小数:ROUND(amount, 2)
4. 时间范围若问题未指定,默认取最近 90 天
5. 结果集行数超过 100 行时,自动加 LIMIT 100
6. 对涉及多表查询,优先使用显式 JOIN 而非子查询(可读性更好)
7. 字段别名使用中文(方便最终用户读图)

## Few-shot 示例
用户:各城市的订单总金额是多少?
SQL:
SELECT u.city AS 城市, ROUND(SUM(o.total_amount), 2) AS 订单总金额
FROM orders o
JOIN users u ON o.user_id = u.user_id
WHERE o.status = 'completed'
GROUP BY u.city
ORDER BY 订单总金额 DESC

用户:上个月每天的新增用户数
SQL:
SELECT DATE(register_date) AS 日期, COUNT(*) AS 新增用户数
FROM users
WHERE register_date >= DATE('now', '-1 month', 'start of month')
  AND register_date < DATE('now', 'start of month')
GROUP BY DATE(register_date)
ORDER BY 日期
"""


class SQLGenerator:
    """自然语言到 SQL 的生成器。"""

    def __init__(
        self,
        model: str | None = None,
        dialect: Dialect = Dialect.SQLITE,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ):
        self.model = model or get_litellm_id()
        self.dialect = dialect
        self.client = get_openai_client(api_key=api_key, base_url=base_url)

    def generate(
        self,
        question: str,
        schema_prompt: str,
        conversation_history: list[dict] | None = None,
    ) -> SQLGenerationResult:
        """
        生成 SQL 查询。

        Args:
            question: 用户自然语言问题
            schema_prompt: 由 SchemaManager 生成的 Schema 描述
            conversation_history: 多轮对话历史(可选),支持追问场景
        """
        system_prompt = SYSTEM_PROMPT_TEMPLATE.format(
            dialect=self.dialect.value,
            schema=schema_prompt,
        )

        messages = [{"role": "system", "content": system_prompt}]

        if conversation_history:
            messages.extend(conversation_history[-6:])

        messages.append({"role": "user", "content": question})

        # 优先尝试使用结构化输出
        try:
            response = self.client.beta.chat.completions.parse(
                model=self.model,
                messages=messages,
                response_format=SQLGenerationResult,
                temperature=0.1,
            )
            result = response.choices[0].message.parsed
            result.sql = self._clean_sql(result.sql)
            return result
        except Exception:
            # 结构化输出失败时,回退到普通 completion + JSON 手动解析
            print("⚠️  结构化输出解析失败,回退到普通模式...")
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.1,
            )
            content = response.choices[0].message.content or ""
            data = _parse_json_response(content)
            result = SQLGenerationResult(**data)
            result.sql = self._clean_sql(result.sql)
            return result

    @staticmethod
    def _clean_sql(sql: str) -> str:
        """去除 markdown 代码块包裹,提取纯 SQL。"""
        sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE)
        sql = sql.strip().rstrip(";")
        return sql

关键点: - temperature=0.1 是经验值:SQL 生成是确定性任务,高 temperature 会引入随机语法错误 - 新增回退机制:优先使用 client.beta.chat.completions.parse(Structured Output),失败后自动回退到普通 chat.completions.create + 手动 JSON 解析(_parse_json_response)。这提高了对不支持 Structured Output 的模型(如 Qwen、DeepSeek)的兼容性 - ⚠️ Few-shot 示例选择策略:选与目标数据库业务最相关的示例,而不是通用的教科书例子,可提升 10-20% 准确率


Step 4:安全执行层——只读沙箱 + 慢查询检测 + 自我修正循环

目标:这是 Text-to-SQL 生产化的核心安全关卡。LLM 偶尔会生成 DROP TABLE、UPDATE 等写操作(幻觉),也会生成没有 WHERE 条件的全表扫描(慢查询)。必须在执行层拦截,而不是依赖 Prompt 约束。

# sql_executor.py
"""
安全 SQL 执行器:
- 只读沙箱:静态分析 + 连接级 read_only 保护双重防护
- 慢查询检测:超时熔断(默认 10s)
- 自我修正:执行失败后将错误反馈给 LLM,最多重试 3 次
"""
import sqlite3
import time
import signal
import sqlparse
import pandas as pd
from dataclasses import dataclass
from typing import Optional

from core_config import QUERY_TIMEOUT_SECONDS, MAX_RETRIES


# 禁止执行的关键字(大写统一比较)
FORBIDDEN_KEYWORDS = frozenset({
    "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
    "TRUNCATE", "REPLACE", "MERGE", "GRANT", "REVOKE",
    "ATTACH", "DETACH",  # SQLite 特有危险操作
})


@dataclass
class ExecutionResult:
    """SQL 执行结果。"""
    success: bool
    data: Optional[pd.DataFrame]   # 成功时的结果集
    error: Optional[str]           # 失败时的错误信息
    execution_time_ms: float       # 执行耗时(毫秒)
    row_count: int = 0


class SQLSafetyError(Exception):
    """SQL 安全检查失败异常。"""


class QueryTimeoutError(Exception):
    """查询超时异常。"""


def _check_sql_safety(sql: str) -> None:
    """
    静态安全分析:检查 SQL 是否包含禁止操作。

    使用 sqlparse 做词法分析,比正则更可靠(能处理注释、大小写、嵌套等)。
    """
    parsed = sqlparse.parse(sql)
    for statement in parsed:
        for token in statement.flatten():
            if token.ttype in (sqlparse.tokens.Keyword, sqlparse.tokens.Keyword.DDL,
                               sqlparse.tokens.Keyword.DML):
                if token.normalized.upper() in FORBIDDEN_KEYWORDS:
                    raise SQLSafetyError(
                        f"拒绝执行:检测到禁止操作 '{token.normalized.upper()}'。"
                        f"本系统仅允许 SELECT 查询。"
                    )

        # 额外检查 statement 类型(双重保险)
        stmt_type = statement.get_type()
        if stmt_type and stmt_type.upper() not in ("SELECT", "UNKNOWN", None):
            raise SQLSafetyError(
                f"拒绝执行:语句类型为 '{stmt_type}',仅允许 SELECT。"
            )


class SQLExecutor:
    """安全的 SQL 执行器,内置只读沙箱和超时熔断。"""

    def __init__(self, db_path: str, timeout_seconds: int = QUERY_TIMEOUT_SECONDS):
        self.db_path = db_path
        self.timeout = timeout_seconds

    def execute(self, sql: str) -> ExecutionResult:
        """
        执行 SQL 并返回结构化结果。

        安全层次:
        1. 静态词法分析(sqlparse)
        2. SQLite URI 只读模式连接
        3. 超时信号熔断(Unix 系统)
        """
        # 层次1:静态安全检查
        try:
            _check_sql_safety(sql)
        except SQLSafetyError as e:
            return ExecutionResult(
                success=False, data=None,
                error=str(e), execution_time_ms=0
            )

        start = time.monotonic()

        # 层次2:只读连接(SQLite URI 模式)
        conn = sqlite3.connect(f"file:{self.db_path}?mode=ro", uri=True)
        conn.row_factory = sqlite3.Row
        conn.execute(f"PRAGMA busy_timeout = {self.timeout * 1000}")

        try:
            # 层次3:信号超时熔断(仅 Unix/Mac,Windows 不支持 SIGALRM)
            def _timeout_handler(signum, frame):
                raise QueryTimeoutError(f"查询超时(>{self.timeout}s),疑似全表扫描或缺少索引。")

            try:
                signal.signal(signal.SIGALRM, _timeout_handler)
                signal.alarm(self.timeout)
            except (AttributeError, OSError):
                pass

            df = pd.read_sql_query(sql + ";", conn)

            try:
                signal.alarm(0)
            except (AttributeError, OSError):
                pass

            elapsed_ms = (time.monotonic() - start) * 1000

            if elapsed_ms > 3000:
                print(f"⚠️  慢查询警告:执行耗时 {elapsed_ms:.0f}ms,考虑添加索引")

            return ExecutionResult(
                success=True,
                data=df,
                error=None,
                execution_time_ms=elapsed_ms,
                row_count=len(df),
            )

        except QueryTimeoutError as e:
            return ExecutionResult(success=False, data=None,
                                   error=str(e), execution_time_ms=self.timeout * 1000)
        except Exception as e:
            return ExecutionResult(
                success=False, data=None,
                error=f"{type(e).__name__}: {e}",
                execution_time_ms=(time.monotonic() - start) * 1000,
            )
        finally:
            conn.close()


class SelfCorrectingExecutor:
    """
    自我修正执行器:将执行错误反馈给 LLM,驱动 SQL 重写。

    修正循环上限:3 次(防止无限循环消耗 Token 和时间)
    """

    def __init__(
        self,
        executor: SQLExecutor,
        sql_generator: "SQLGenerator",
        schema_prompt: str,
    ):
        self.executor = executor
        self.generator = sql_generator
        self.schema_prompt = schema_prompt

    def execute_with_correction(
        self, original_question: str, initial_sql: str
    ) -> tuple[ExecutionResult, str, int]:
        """
        执行 SQL,失败时自动修正。

        Returns:
            (ExecutionResult, 最终使用的SQL, 修正次数)
        """
        current_sql = initial_sql
        correction_history: list[dict] = []

        for attempt in range(MAX_RETRIES + 1):
            result = self.executor.execute(current_sql)

            if result.success:
                if attempt > 0:
                    print(f"✅ 第 {attempt} 次修正后成功")
                return result, current_sql, attempt

            if attempt == MAX_RETRIES:
                print(f"❌ 达到最大修正次数({MAX_RETRIES}),放弃")
                return result, current_sql, attempt

            print(f"🔄 第 {attempt + 1} 次修正,错误:{result.error[:100]}...")

            # 构造修正请求
            correction_prompt = f"""上一条 SQL 执行失败,请修正。

**原始问题:** {original_question}

**失败的 SQL:**
```sql
{current_sql}

错误信息: {result.error}

修正要求: - 分析错误原因后重新生成正确的 SQL - 确保语法符合 SQLite 方言 - 不要解释,直接返回修正后的 SQL """ correction_history.extend([ {"role": "user", "content": f"修正 SQL:{correction_prompt}"}, ])

        new_result = self.generator.generate(
            question=correction_prompt,
            schema_prompt=self.schema_prompt,
            conversation_history=correction_history,
        )
        current_sql = new_result.sql
        correction_history.append({
            "role": "assistant",
            "content": f"修正后的 SQL:{current_sql}"
        })

    return result, current_sql, MAX_RETRIES

`` **关键点**: -file:{path}?mode=ro是 SQLite 的**连接级只读保护**,不依赖 Prompt 约束,即使 LLM 生成了DROP TABLE也会在系统层面被拒绝 - 自我修正循环上限设为 3 次(通过core_config.MAX_RETRIES配置):实测 90% 的错误在 1-2 次修正内解决,超过 3 次通常是问题本身模糊,继续重试无意义 - ⚠️ Windows 不支持SIGALRM,需改用threading.Timer+conn.interrupt()` 实现超时


Step 5:结果可视化——LLM 推断图表类型 + Plotly 渲染

目标:让 LLM 分析查询结果的数据特征(是否有时间列?是分类对比还是占比?),自动选择最合适的图表类型,而不是硬编码规则。再用 Plotly 渲染并生成摘要解读。 ```python

visualizer.py

""" 数据可视化模块: - LLM 分析数据结构,推断最佳图表类型 - Plotly 渲染图表(支持交互式 HTML) - LLM 生成中文数据摘要 """ import json from typing import Optional import pandas as pd import plotly.express as px import plotly.graph_objects as go from enum import Enum from pydantic import BaseModel, Field from openai import OpenAI from pathlib import Path from dotenv import load_dotenv from core_config import get_litellm_id, get_api_key, get_base_url from sql_generator import get_openai_client load_dotenv() class ChartType(str, Enum): BAR = "bar" # 分类对比:各城市销售额 LINE = "line" # 时序趋势:每日订单量 PIE = "pie" # 占比:各品类销售额占比 SCATTER = "scatter" # 相关性:用户消费金额 vs 购买频次 TABLE = "table" # 纯明细:无法图形化的数据 class ChartDecision(BaseModel): """LLM 的图表类型决策。""" chart_type: ChartType x_column: str = Field(description="X 轴或标签列的列名") y_column: str = Field(description="Y 轴或数值列的列名") color_column: str | None = Field(None, description="分组/颜色列") reasoning: str = Field(description="选择此图表类型的理由") title: str = Field(description="图表标题") def infer_chart_type( question: str, df: pd.DataFrame, client: Optional[OpenAI] = None, model: Optional[str] = None, ) -> ChartDecision: """ 让 LLM 分析 DataFrame 结构,推断最佳图表类型。 """ col_info = { col: { "dtype": str(df[col].dtype), "sample": df[col].head(3).tolist(), "unique_count": int(df[col].nunique()), } for col in df.columns } prompt = f"""你是数据可视化专家,分析以下数据并选择最合适的图表类型。 用户原始问题: {question} 数据结构(共 {len(df)} 行,{len(df.columns)} 列): {json.dumps(col_info, ensure_ascii=False, indent=2)} 图表选择规则: - 如果有时间/日期列(dtype 为 object 且值格式如 2024-01、周一) → line - 如果是分类列(unique_count <= 20)对应数值列的对比 → bar - 如果有"占比"、"比例"、"份额"等语义,且分类 <= 10 → pie - 如果两个数值列之间有相关性分析意图 → scatter - 其他情况 → table 重要:x_column 和 y_column 必须是数据中实际存在的列名,区分大小写。 """ llm = client or get_openai_client() model_name = model or get_litellm_id() response = llm.beta.chat.completions.parse( model=model_name, messages=[{"role": "user", "content": prompt}], response_format=ChartDecision, temperature=0, ) return response.choices[0].message.parsed def _generate_summary( question: str, sql: str, df: pd.DataFrame, chart_type: ChartType, client: Optional[OpenAI] = None, model: Optional[str] = None, ) -> str: """让 LLM 用 2-3 句话解读数据核心洞察。""" summary_stats = df.describe(include="all").to_string() if len(df) <= 20 else ( f"共 {len(df)} 行数据。" f"主要数值列统计:{df.select_dtypes(include='number').describe().to_string()}" ) llm = client or get_openai_client() model_name = model or get_litellm_id() response = llm.chat.completions.create( model=model_name, messages=[{ "role": "user", "content": ( f"用户问题:{question}\n\n" f"查询 SQL:{sql}\n\n" f"数据统计:\n{summary_stats}\n\n" "请用 2-3 句话提炼核心业务洞察,语言面向业务人员," "直接说明最重要的发现,不要重复问题,不要提 SQL。" ) }], temperature=0.3, max_tokens=200, ) return response.choices[0].message.content.strip() class DataVisualizer: """数据可视化器:自动选型 + Plotly 渲染 + 摘要生成。""" def init( self, output_dir: str = "charts", client: Optional[OpenAI] = None, model: Optional[str] = None, ): self.output_dir = Path(output_dir) self.output_dir.mkdir(exist_ok=True) self.client = client or get_openai_client() self.model = model or get_litellm_id() def visualize( self, question: str, sql: str, df: pd.DataFrame, save_html: bool = True ) -> dict: """ 完整可视化流程:推断类型 → 渲染图表 → 生成摘要。 Returns: { "chart_type": str, "title": str, "html_path": str | None, "figure": go.Figure | None, "summary": str, } """ decision = _infer_chart_type(question, df, client=self.client, model=self.model) print(f"📊 图表决策:{decision.chart_type.value} — {decision.reasoning}") fig = self._render_chart(df, decision) html_path = None if save_html and fig is not None: safe_title = "".join( c for c in decision.title if c.isalnum() or c in "- " )[:50] html_path = str(self.output_dir / f"{safe_title}.html") fig.write_html(html_path) print(f"💾 图表已保存:{html_path}") summary = _generate_summary( question, sql, df, decision.chart_type, client=self.client, model=self.model, ) print(f"\n📝 数据摘要:\n{summary}") return { "chart_type": decision.chart_type.value, "title": decision.title, "html_path": html_path, "figure": fig, "summary": summary, "decision": decision, } def _render_chart(self, df: pd.DataFrame, decision: ChartDecision) -> go.Figure | None: """根据决策渲染 Plotly 图表。""" required_cols = [decision.x_column, decision.y_column] for col in required_cols: if col not in df.columns: matched = [c for c in df.columns if c.lower() == col.lower()] if matched: if col == decision.x_column: decision.x_column = matched[0] else: decision.y_column = matched[0] else: print(f"⚠️ 列名 '{col}' 不存在") decision.chart_type = ChartType.TABLE common_kwargs = dict( title=decision.title, template="plotly_white", ) if decision.chart_type == ChartType.BAR: fig = px.bar( df, x=decision.x_column, y=decision.y_column, color=decision.color_column, text_auto=".2s", common_kwargs, ) fig.update_layout(xaxis_tickangle=-30) elif decision.chart_type == ChartType.LINE: fig = px.line( df, x=decision.x_column, y=decision.y_column, color=decision.color_column, markers=True, common_kwargs, ) elif decision.chart_type == ChartType.PIE: fig = px.pie( df, names=decision.x_column, values=decision.y_column, hole=0.35, common_kwargs, ) elif decision.chart_type == ChartType.SCATTER: fig = px.scatter( df, x=decision.x_column, y=decision.y_column, color=decision.color_column, trendline="ols", common_kwargs, ) elif decision.chart_type == ChartType.TABLE: fig = go.Figure(data=[go.Table( header=dict(values=list(df.columns), fill_color='paleturquoise'), cells=dict(values=[df[col] for col in df.columns], fill_color='lavender') )]) fig.update_layout(title=decision.title) else: fig = None return fig ```

关键点: - 列名验证必不可少:LLM 有时会幻觉出不存在的列名,加模糊匹配兜底 - _infer_chart_type_generate_summary 均支持 client/model 参数注入,方便测试时传入 mock 客户端 - Table 图表使用简化的 go.Table 渲染(paleturquoise/lavender 配色),不再使用复杂的交替行底色 - ⚠️ Plotly write_html 生成的文件含完整 JS 包(约 3MB);生产环境可改用 include_plotlyjs='cdn' 减小文件体积


Step 6:组装运行——main.py

目标:将所有模块串联,通过 main.py 统一入口运行。脚本依次执行多个预设问题,展示完整的 NL → SQL → 执行 → 可视化流程。

# main.py
"""数据分析 Agent 主入口。"""
import os
from dotenv import load_dotenv

from core_config import get_litellm_id, get_api_key, get_base_url, DB_PATH
from schema_manager import SchemaManager
from sql_generator import SQLGenerator, Dialect, get_openai_client
from sql_executor import SQLExecutor, SelfCorrectingExecutor
from visualizer import DataVisualizer

load_dotenv()


def main():
    print("=== 数据分析 Agent(Text-to-SQL)===")

    db_path = os.getenv("DB_PATH", DB_PATH)

    schema_manager = SchemaManager(db_path)
    sql_generator = SQLGenerator(
        model=get_litellm_id(), dialect=Dialect.SQLITE,
        api_key=get_api_key(), base_url=get_base_url(),
    )
    executor = SQLExecutor(db_path)
    visualizer = DataVisualizer()

    questions = [
        "各城市的订单总金额是多少?",
        "过去三个月各品类销售额趋势",
        "VIP 用户的平均订单金额是多少?",
    ]

    for question in questions:
        print(f"\n---\n问题:{question}")

        tables = schema_manager.retrieve_relevant_tables(question)
        schema_prompt = schema_manager.format_schema_prompt(tables)

        print("🔧 生成 SQL...")
        sql_result = sql_generator.generate(question, schema_prompt)
        print(f"SQL:{sql_result.sql}")

        print("🔍 执行查询...")
        correcting_executor = SelfCorrectingExecutor(
            executor, sql_generator, schema_prompt
        )
        exec_result, final_sql, retries = correcting_executor.execute_with_correction(
            question, sql_result.sql
        )

        if exec_result.success:
            print(f"✅ 执行成功(重试 {retries} 次),返回 {exec_result.row_count} 行")
            print("📊 数据:")
            print(exec_result.data)

            print("\n📈 可视化...")
            viz_result = visualizer.visualize(question, final_sql, exec_result.data)
            print(f"摘要:{viz_result['summary']}")
        else:
            print(f"❌ 执行失败:{exec_result.error}")


if __name__ == "__main__":
    main()

运行:

python main.py


冒烟测试验证

# smoke_test.py
"""数据分析 Agent - 冒烟测试。使用 Mock LLM 避免真实 API 调用。"""
from dotenv import load_dotenv
import os
from unittest.mock import patch, MagicMock

from core_config import get_litellm_id, get_api_key, get_base_url
from db_setup import create_demo_database
from schema_manager import SchemaManager
from sql_generator import SQLGenerator, Dialect, SQLGenerationResult
from sql_executor import SQLExecutor, SelfCorrectingExecutor
from visualizer import DataVisualizer
from visualizer import ChartDecision, ChartType

load_dotenv()


def _make_mock_client():
    """创建 mock OpenAI client,返回结构化输出的 fake response"""
    mock = MagicMock()
    mock.beta.chat.completions.parse.return_value = MagicMock(
        choices=[MagicMock(
            message=MagicMock(
                parsed=SQLGenerationResult(
                    sql="SELECT u.city AS 城市, ROUND(SUM(o.total_amount), 2) AS 订单总金额 "
                        "FROM orders o JOIN users u ON o.user_id = u.user_id "
                        "WHERE o.status = 'completed' GROUP BY u.city ORDER BY 订单总金额 DESC",
                    explanation="按城市分组汇总已完成订单的总金额",
                    confidence=0.95,
                    ambiguities=[],
                )
            )
        )]
    )
    return mock


def run_smoke_test():
    print("=== 数据分析 Agent - 冒烟测试 ===")

    print("\n1. 创建演示数据库...")
    create_demo_database("test_ecommerce.db")

    print("\n2. Schema 管理器测试...")
    schema_manager = SchemaManager("test_ecommerce.db")
    print(f"   ✓ 加载 {len(schema_manager.tables)} 张表")

    print("\n3. SQL 生成测试(Mock LLM)...")
    mock_client = _make_mock_client()

    with patch("sql_generator.get_openai_client", return_value=mock_client):
        sql_generator = SQLGenerator(
            model=get_litellm_id(), dialect=Dialect.SQLITE,
            api_key=get_api_key(), base_url=get_base_url(),
        )
        tables = schema_manager.retrieve_relevant_tables("各城市订单金额")
        schema_prompt = schema_manager.format_schema_prompt(tables)
        result = sql_generator.generate("各城市的订单总金额是多少?", schema_prompt)
        print(f"   ✓ SQL 生成成功:{result.sql[:60]}...")

        print("\n4. SQL 执行测试...")
        executor = SQLExecutor("test_ecommerce.db")
        correcting_executor = SelfCorrectingExecutor(executor, sql_generator, schema_prompt)
        exec_result, final_sql, retries = correcting_executor.execute_with_correction(
            "各城市订单金额", result.sql
        )
        print(f"   ✓ 执行成功(重试 {retries} 次)")
        print(f"   返回 {exec_result.row_count} 行数据")

        print("\n5. 可视化测试(Mock LLM)...")
        mock_viz_client = MagicMock()
        mock_viz_client.beta.chat.completions.parse.return_value = MagicMock(
            choices=[MagicMock(
                message=MagicMock(
                    parsed=ChartDecision(
                        chart_type=ChartType.BAR,
                        x_column="城市",
                        y_column="订单总金额",
                        color_column=None,
                        reasoning="分类对比数据,适合柱状图",
                        title="各城市订单总金额对比",
                    )
                )
            )]
        )
        mock_viz_client.chat.completions.create.return_value = MagicMock(
            choices=[MagicMock(message=MagicMock(content="各城市中,北京订单总金额最高。"))]
        )

        visualizer = DataVisualizer(
            output_dir="test_charts",
            client=mock_viz_client,
            model="gpt-4o-mini",
        )
        viz_result = visualizer.visualize("各城市订单金额", final_sql, exec_result.data)
        print(f"   ✓ 图表类型:{viz_result['chart_type']}")
        print(f"   ✓ 摘要:{viz_result['summary']}")

    print("\n=== 测试完成 ✓ ===")

    os.remove("test_ecommerce.db")


if __name__ == "__main__":
    run_smoke_test()

运行:

python smoke_test.py
# === 数据分析 Agent - 冒烟测试 ===
# 1. 创建演示数据库...
# ✅ 演示数据库已创建:test_ecommerce.db(800 笔订单)
# ...
# === 测试完成 ✓ ===


项目文件结构

7.3 项目三:数据分析 Agent(Text-to-SQL)/
├── core_config.py          # 全局配置:模型注册表 + 项目常量
├── db_setup.py             # 演示电商数据库初始化
├── schema_manager.py       # Schema 提取 + Embedding 向量索引 + 动态检索
├── sql_generator.py        # NL → SQL 生成器(Structured Output + 回退机制)
├── sql_executor.py         # 安全执行器(只读沙箱 + 超时熔断 + 自我修正)
├── visualizer.py           # 自动图表选型 + Plotly 渲染 + 摘要生成
├── main.py                 # 主入口:串联所有模块
├── smoke_test.py           # 冒烟测试(Mock LLM)
├── requirements.txt        # 依赖清单
├── tests/
│   ├── __init__.py
│   └── test_main.py        # pytest 单元测试
└── _backup/                # 整理前的原始文件备份

常见报错与解决方案

报错信息 原因 解决方案
sqlite3.OperationalError: unable to open database file 数据库路径不存在或权限不足 先运行 python db_setup.py 创建数据库;检查路径是否为绝对路径
openai.BadRequestError: Invalid schema for response_format Pydantic 模型中包含不受支持的字段类型 检查 list[str] \| None 是否改为 Optional[list[str]];OpenAI Structured Output 对 Union 类型支持有限
sqlparse 解析后 stmt_typeNone 复杂嵌套 SQL(如 WITH CTE)被 sqlparse 识别为 UNKNOWN CTE 开头是 WITH 不是 SELECT,需在安全检查白名单中加 WITH 语句,并检查内层是否有写操作
signal.SIGALRMAttributeError Windows 系统不支持 SIGALRM 改用 concurrent.futures.ThreadPoolExecutor + future.result(timeout=N) 实现跨平台超时
图表列名报 KeyError LLM 在 ChartDecision 中幻觉了不存在的列名 已在 _render_chart 中加模糊匹配兜底;若仍失败检查 Pydantic 字段描述是否强调"必须使用原始列名"
ModuleNotFoundError: kaleido Plotly 导出静态图片需要 kaleido uv pip install kaleido==0.2.1;Colab 另需 !apt-get install -y chromium
自我修正 3 次后仍失败 问题本身超出数据库能力范围(如要求计算未存储的指标) 检查 result.ambiguities 字段,将模糊点反馈给用户,请求澄清

扩展练习(可选)

  1. 🟡 中等:接入 PostgreSQL 真实数据库db_path 替换为 PostgreSQL 连接串(postgresql://user:pass@host/db),用 sqlalchemy + pandas.read_sql 替代 sqlite3。注意:PostgreSQL 的 Schema 元数据从 information_schema.columnspg_description 提取,字段注释更丰富,能显著提升 SQL 生成准确率。参考问题:如何实现"只读用户"的数据库层隔离(而不是应用层 if 判断)?

  2. 🔴 困难:实现大规模 Schema 的分层检索 当数据库表达 200+ 张时,单层 Embedding 检索精度会下降(相关表被埋没)。尝试实现两级检索:先用表的业务领域(如"财务"、"供应链"、"用户")做粗检索,再在领域内按字段语义做精检索。提示:可以让 LLM 自动为每张表生成业务领域标签,存入 SQLite 元数据表,实现无需人工标注的自动分层。

  3. 🔴 困难:构建评估体系 Text-to-SQL 的核心挑战是如何衡量"SQL 是否正确"。尝试构建一个半自动评估流水线:①手工标注 50 个问题 + 对应正确 SQL;②执行 Agent 生成的 SQL 和标准 SQL,比较结果集(Execution Accuracy);③对于结果不同的案例,让 LLM 判断是"SQL 语义等价但写法不同"还是"真实错误"。目标:在演示数据库上达到 80%+ 的 Execution Accuracy。