Python异步任务队列神器:arq从入门到实战完全指南

一、arq库基础认知

1.1 库核心用途

arq是一款基于Python asyncio与Redis开发的异步任务队列库,专注处理异步、延迟、后台耗时任务,完美适配异步Web框架,是异步架构中任务调度、异步解耦的核心工具。

1.2 工作原理

arq以Redis为中间件存储任务与状态,生产者通过异步接口投递任务,Worker异步监听Redis队列,自动拉取并执行任务,全程基于asyncio实现非阻塞运行,任务执行、重试、结果存储都依托Redis高效完成。

1.3 优缺点

优点:纯异步非阻塞、轻量无冗余依赖、API简洁易上手、兼容FastAPI/Starlette等异步框架、支持任务重试与延迟执行、性能优异。
缺点:仅支持Redis作为后端、功能比Celery精简、不支持多消息队列、复杂任务调度能力较弱。

1.4 License类型

MIT License,开源免费可商用。

二、arq环境安装与基础配置

2.1 安装arq与依赖

arq核心依赖Redis,安装命令:

pip install arq redis

安装完成后可通过命令验证版本:

arq --version

确保本地或远程Redis服务正常运行,arq默认连接本地127.0.0.1:6379无密码Redis实例。

2.2 基础连接配置

arq支持自定义Redis连接参数,包括主机、端口、密码、数据库索引,创建基础配置文件:

# redis_config.py
import asyncio
from arq import create_pool
from arq.connections import RedisSettings

# 自定义Redis连接配置
redis_settings = RedisSettings(
    host="127.0.0.1",
    port=6379,
    password="",  # 有密码则填写
    database=0,
    timeout=5
)

# 测试Redis连接
async def test_redis_connection():
    redis = await create_pool(redis_settings)
    print("Redis连接成功")
    await redis.close()

if __name__ == "__main__":
    asyncio.run(test_redis_connection())

代码说明:通过RedisSettings定义连接参数,create_pool创建异步连接池,实现arq与Redis的基础通信。

三、arq核心使用方式与代码实例

3.1 定义基础异步任务

arq任务必须是异步函数,这是其核心特性,创建任务文件:

# tasks.py
import asyncio
from arq import ArqRedis

# 基础异步任务
async def simple_task(ctx: ArqRedis, content: str) -> str:
    """
    简单异步任务
    :param ctx: 任务上下文,包含Redis连接等信息
    :param content: 任务传入参数
    :return: 任务执行结果
    """
    print(f"开始执行简单任务:{content}")
    # 模拟异步耗时操作
    await asyncio.sleep(2)
    result = f"任务执行完成,内容:{content}"
    print(result)
    return result

# 任务注册:Worker启动时加载的任务列表
async def startup(ctx):
    """Worker启动时执行的钩子函数"""
    print("arq Worker启动成功")

async def shutdown(ctx):
    """Worker关闭时执行的钩子函数"""
    print("arq Worker关闭成功")

# Worker配置类
class WorkerSettings:
    # 注册可执行的任务函数
    functions = [simple_task]
    # 启动与关闭钩子
    on_startup = startup
    on_shutdown = shutdown
    # 绑定Redis配置
    redis_settings = RedisSettings(host="127.0.0.1", port=6379)

代码说明:定义异步任务函数,通过WorkerSettings注册任务,配置启动关闭钩子,Worker会自动加载注册的任务。

3.2 启动arq Worker

Worker是任务消费者,负责监听队列并执行任务,命令行启动:

arq tasks.WorkerSettings

启动成功后会输出:arq Worker启动成功,持续监听Redis任务队列。

3.3 投递任务到arq队列

创建生产者脚本,异步投递任务到arq:

# producer.py
import asyncio
from arq import create_pool
from tasks import redis_settings

# 投递任务
async def enqueue_task():
    # 创建Redis连接池
    redis = await create_pool(redis_settings)
    # 投递任务:任务函数名 + 参数
    job = await redis.enqueue_job("simple_task", "Hello arq异步任务队列")
    # 获取任务ID
    job_id = job.job_id
    print(f"任务投递成功,任务ID:{job_id}")

    # 等待任务执行完成并获取结果
    job_result = await job.result(timeout=10)
    print(f"任务执行结果:{job_result}")

    await redis.close()

if __name__ == "__main__":
    asyncio.run(enqueue_task())

代码说明:通过enqueue_job投递任务,传入任务函数名与参数,可通过job.result()同步等待任务结果,适用于需要获取返回值的场景。

运行producer.py,Worker端会打印任务执行信息,生产者输出:

任务投递成功,任务ID:xxxxxxxx
任务执行结果:任务执行完成,内容:Hello arq异步任务队列

3.4 延迟任务与定时任务

arq支持延迟执行任务,指定延迟秒数:

# 延迟任务:5秒后执行
async def delay_task(ctx: ArqRedis, msg: str):
    await asyncio.sleep(1)
    print(f"延迟任务执行:{msg}")
    return f"延迟任务结果:{msg}"

# 在tasks.py的WorkerSettings中添加任务
functions = [simple_task, delay_task]

投递延迟任务:

# producer.py中添加
async def enqueue_delay_task():
    redis = await create_pool(redis_settings)
    # 延迟5秒执行
    job = await redis.enqueue_job("delay_task", "5秒后执行的延迟任务", _defer_by=5)
    print(f"延迟任务投递成功,任务ID:{job.job_id}")
    await redis.close()

代码说明:_defer_by参数指定延迟秒数,arq会自动计算执行时间,到期后执行任务。

3.5 任务重试机制

arq内置任务重试功能,应对任务执行失败场景,配置重试次数与延迟:

# 可重试任务
async def retry_task(ctx: ArqRedis):
    print("执行可重试任务")
    # 模拟任务执行失败
    raise Exception("任务执行异常,触发重试")

# 自定义任务配置
retry_task.arq_kwargs = {
    "max_tries": 3,        # 最大重试次数
    "retry_delay": 2       # 重试间隔2秒
}

WorkerSettings中注册retry_task,投递任务后,Worker会自动重试3次,适合接口调用、数据同步等不稳定任务。

3.6 获取任务状态与结果

arq支持查询任务状态、结果、异常信息:

async def get_task_status(job_id: str):
    redis = await create_pool(redis_settings)
    # 获取任务对象
    job = await redis.job(job_id)
    if not job:
        print("任务不存在")
        return

    # 获取任务状态
    status = await job.status()
    print(f"任务状态:{status}")

    # 获取任务结果(未完成会抛出异常)
    try:
        result = await job.result()
        print(f"任务结果:{result}")
    except Exception as e:
        print(f"任务未完成或异常:{e}")

    await redis.close()

任务状态包括:pending(等待中)、running(执行中)、complete(完成)、failed(失败)。

四、arq与FastAPI集成实战案例

4.1 集成背景

FastAPI是主流异步Web框架,与arq天然兼容,可实现Web接口异步处理耗时任务,如发送邮件、生成报表、数据爬取等。

4.2 集成代码实现

# fastapi_arq.py
from fastapi import FastAPI
from arq import create_pool
from arq.connections import RedisSettings
import asyncio

app = FastAPI(title="arq+FastAPI异步任务实战")

# Redis配置
redis_settings = RedisSettings(host="127.0.0.1", port=6379)

# 定义异步任务
async def send_email_task(ctx, email: str, content: str):
    """模拟异步发送邮件任务"""
    await asyncio.sleep(3)
    print(f"向{email}发送邮件:{content}")
    return f"邮件发送成功:{email}"

# Worker配置
class WorkerSettings:
    functions = [send_email_task]
    redis_settings = redis_settings

# 应用启动时创建arq连接池
@app.on_event("startup")
async def startup_event():
    app.state.arq_redis = await create_pool(redis_settings)

# 应用关闭时关闭连接
@app.on_event("shutdown")
async def shutdown_event():
    await app.state.arq_redis.close()

# 接口:投递发送邮件任务
@app.post("/send-email")
async def send_email(email: str, content: str):
    job = await app.state.arq_redis.enqueue_job("send_email_task", email, content)
    return {
        "code": 200,
        "msg": "邮件任务投递成功",
        "job_id": job.job_id
    }

# 接口:查询任务状态
@app.get("/task-status/{job_id}")
async def get_task_status(job_id: str):
    job = await app.state.arq_redis.job(job_id)
    if not job:
        return {"code": 404, "msg": "任务不存在"}
    status = await job.status()
    result = None
    if status == "complete":
        result = await job.result()
    return {
        "code": 200,
        "job_id": job_id,
        "status": status,
        "result": result
    }

4.3 启动与访问

  1. 启动arq Worker:
arq fastapi_arq.WorkerSettings
  1. 启动FastAPI服务:
uvicorn fastapi_arq:app --reload
  1. 访问接口:
  • 投递任务:POST http://127.0.0.1:8000/[email protected]&content=测试邮件
  • 查询任务:GET http://127.0.0.1:8000/task-status/任务ID

该案例实现了Web接口与异步任务解耦,用户请求无需等待耗时任务完成,提升接口响应速度与系统并发能力。

五、arq高级特性与实际场景优化

5.1 任务结果过期设置

arq默认永久存储任务结果,可配置过期时间释放Redis空间:

# 任务结果1小时后过期
job = await redis.enqueue_job("simple_task", "测试过期", _expires=3600)

5.2 多任务队列隔离

arq支持自定义队列名称,实现不同业务任务隔离:

# 投递到订单队列
job = await redis.enqueue_job("order_task", "订单数据", _queue="order_queue")

# Worker监听指定队列
class WorkerSettings:
    queue_name = "order_queue"
    functions = [order_task]

5.3 任务并发控制

调整Worker并发数,适配不同服务器性能:

class WorkerSettings:
    functions = [simple_task, delay_task]
    redis_settings = redis_settings
    max_jobs = 10  # 最大并发执行10个任务

相关资源

  • Pypi地址:https://pypi.org/project/arq
  • Github地址:https://github.com/samuelcolvin/arq
  • 官方文档地址:https://arq-docs.helpmanual.io/

关注我,每天分享一个实用的Python自动化工具。

Python数据质量监控神器:whylogs从入门到实战全教程

一、whylogs库基础认知

1.1 库核心用途

whylogs是WhyLabs团队开源的轻量级数据日志与数据质量监控Python库,核心用于自动化生成数据剖面日志、持续追踪数据特征、实时检测数据漂移与异常值,广泛适配数据分析、机器学习流水线、数据工程等场景,无需复杂配置即可嵌入现有Python项目完成数据质量管控。

1.2 工作原理

whylogs通过轻量级流式计算方式,对输入数据(DataFrame、字典、数据流等)进行统计聚合,生成轻量化、可序列化的数据剖面(Profile),不存储原始数据,仅保留分布、计数、缺失值、类型等统计信息,支持离线存储、云端同步与多版本对比,实现无侵入式数据监控。

1.3 优缺点

优点:轻量化无性能损耗、支持流式与批量数据、兼容主流数据框架、隐私安全不存原始数据、可对接可视化平台、开箱即用。
缺点:复杂自定义规则需二次开发、极端小众数据类型支持有限、纯离线模式缺少自动告警扩展。

1.4 License类型

whylogs采用Apache-2.0 License,属于宽松开源协议,允许商业使用、修改、分发与二次发布。

二、whylogs安装与环境准备

2.1 基础安装命令

whylogs支持Python 3.7及以上版本,使用pip即可快速安装,执行以下命令:

pip install whylogs

安装过程会自动依赖numpy、pandas、pyarrow等基础数据处理库,无需额外手动配置。

2.2 验证安装成功

安装完成后,可通过简单导入语句验证是否正常:

# 验证whylogs安装
import whylogs as why

# 无报错则说明安装成功
print("whylogs 安装成功,版本:", why.__version__)

运行代码后输出对应版本号,即代表环境配置完成。

2.3 扩展依赖安装

若需对接云端WhyLabs平台或增强可视化能力,可安装扩展包:

pip install whylogs[viz] whylogs[whylabs]

whylogs[viz]提供本地剖面可视化能力,whylogs[whylabs]支持数据剖面云端上传与集中管理。

三、whylogs核心功能与基础使用

3.1 快速生成数据剖面

数据剖面是whylogs的核心产物,包含数据完整统计信息,支持pandas DataFrame、列表、字典等多种数据格式。

import pandas as pd
import whylogs as why

# 构造示例数据集
data = {
    "用户ID": [1001, 1002, 1003, 1004, None, 1006],
    "消费金额": [99.5, 199.0, 59.8, None, 299.0, 88.0],
    "会员等级": ["普通", "高级", "普通", "高级", "普通", "普通"],
    "购买次数": [3, 5, 2, 7, 1, 4]
}
df = pd.DataFrame(data)

# 使用whylogs生成数据剖面
results = why.log(df)

# 获取剖面对象
profile = results.profile()

# 查看剖面数据
print("数据剖面生成完成!")
profile.view()

代码说明:通过why.log()方法传入DataFrame,自动完成数据统计分析,profile.view()可在控制台输出数据基础指标,包括缺失值数量、数据类型、唯一值计数等。

3.2 查看数据详细统计指标

生成剖面后,可提取单列或全量详细统计信息,包括缺失率、最大值、最小值、均值、分位数等。

# 获取数据剖面视图
profile_view = profile.view()

# 查看全量数据列的统计指标
full_stats = profile_view.to_pandas()
print("全量数据统计指标:")
print(full_stats)

# 单独提取指定列指标
amount_stats = profile_view.get_column("消费金额")
print("\n消费金额字段详细统计:")
print("缺失值数量:", amount_stats.missing.value)
print("最大值:", amount_stats.max.value)
print("最小值:", amount_stats.min.value)
print("均值:", amount_stats.mean.value)

代码说明:profile_view.to_pandas()将统计结果转为DataFrame,方便二次处理;get_column()可精准定位目标字段,获取针对性质量指标。

3.3 流式数据实时监控

whylogs支持流式数据处理,适用于实时数据 pipelines、日志流、接口数据等场景。

# 初始化流式记录器
writer = why.logger(mode="streaming", name="stream_demo", interval=5, when="count")

# 接入第一批次数据
batch1 = pd.DataFrame({
    "访问IP": ["192.168.1.1", "192.168.1.2", "192.168.1.3"],
    "响应时间": [120, 200, 150]
})
writer.log(batch1)

# 接入第二批次数据
batch2 = pd.DataFrame({
    "访问IP": ["192.168.1.4", None, "192.168.1.5"],
    "响应时间": [300, 180, None]
})
writer.log(batch2)

# 关闭记录器并生成最终剖面
stream_profile = writer.close()
print("流式数据剖面生成完成")
stream_profile.view().to_pandas()

代码说明:mode="streaming"开启流式模式,interval=5表示每5条数据自动聚合一次,适合持续产生的实时数据质量监控。

3.4 数据剖面持久化存储

生成的数据剖面可序列化保存为本地文件,方便后续对比、回溯与共享。

# 保存剖面到本地文件
profile.write(path="data_profile.bin")

# 从本地文件加载剖面
loaded_profile = why.read(path="data_profile.bin")

print("加载的历史数据剖面:")
loaded_profile.view().to_pandas()

代码说明:剖面文件体积极小,仅存储统计信息,不占用大量存储空间,适合长期归档。

四、数据漂移检测与多版本对比

4.1 数据漂移检测基础使用

在机器学习场景中,训练数据与在线推理数据的分布差异(数据漂移)会严重影响模型效果,whylogs可快速检测该问题。

# 构造训练数据(基准数据)
train_data = pd.DataFrame({
    "特征A": [10, 12, 11, 13, 12, 10, 11],
    "特征B": [0.5, 0.6, 0.5, 0.7, 0.6, 0.5, 0.6]
})

# 构造推理数据(待检测数据,存在分布偏移)
infer_data = pd.DataFrame({
    "特征A": [18, 19, 20, 17, 18, 19],
    "特征B": [0.1, 0.2, 0.1, 0.3, 0.2, 0.1]
})

# 生成两个数据剖面
train_profile = why.log(train_data).profile()
infer_profile = why.log(infer_data).profile()

# 对比检测数据漂移
from whylogs.core.metrics.metrics import Metric
from whylogs.core.view import DatasetProfileView

train_view = train_profile.view()
infer_view = infer_profile.view()

# 执行漂移检测
drift_report = train_view.compare(infer_view).drift_report()
print("数据漂移检测报告:")
print(drift_report.to_pandas())

代码说明:通过基准剖面与待检测剖面对比,自动计算分布差异,输出漂移评分与漂移等级,帮助快速定位异常特征。

4.2 可视化漂移对比结果

安装扩展依赖后,可直接在Python环境中生成交互式漂移对比图表:

# 生成交互式漂移对比可视化
train_view.compare(infer_view).visualize()

代码说明:运行后会生成包含分布直方图、漂移指数的交互式页面,直观展示数据差异,无需手动绘图。

五、结合机器学习流水线实战案例

5.1 场景说明

本案例模拟完整机器学习流程:数据读取→剖面记录→模型训练→推理数据监控→漂移告警,覆盖实际项目完整链路。

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
import whylogs as why

# 1. 加载并拆分数据集(模拟业务数据)
np.random.seed(42)
raw_data = pd.DataFrame({
    "广告曝光量": np.random.randint(1000, 5000, 100),
    "点击量": np.random.randint(100, 500, 100),
    "转化量": np.random.randint(10, 100, 100)
})
raw_data["转化量"].iloc[80:100] = None  # 人为添加缺失值

# 记录原始数据剖面
raw_profile = why.log(raw_data).profile()
raw_profile.write("raw_data_profile.bin")

# 2. 数据预处理
clean_data = raw_data.dropna()
X = clean_data[["广告曝光量", "点击量"]]
y = clean_data["转化量"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 记录训练集数据剖面
train_profile = why.log(X_train).profile()
train_profile.write("train_data_profile.bin")

# 3. 模型训练
model = LinearRegression()
model.fit(X_train, y_train)

# 4. 模拟在线推理(含异常数据)
online_data = pd.DataFrame({
    "广告曝光量": [9000, 8500, 8800, 1500, 1300],
    "点击量": [200, 180, 190, 800, 750]  # 点击量异常偏高
})

# 记录推理数据剖面
online_profile = why.log(online_data).profile()
online_profile.write("online_data_profile.bin")

# 5. 检测推理数据与训练数据的漂移
drift_result = train_profile.view().compare(online_profile.view()).drift_report()
print("模型推理数据漂移检测:")
print(drift_result.to_pandas())

代码说明:本案例完整复现工业级数据监控流程,从原始数据到模型推理全程记录剖面,及时发现异常数据与分布偏移。

5.2 异常数据自动过滤

基于whylogs检测结果,可实现自动过滤异常推理数据,保障模型稳定性:

# 获取漂移检测结果
drift_df = drift_result.to_pandas()
high_drift_columns = drift_df[drift_df["drift_score"] > 0.6]["column"].tolist()

if high_drift_columns:
    print(f"检测到高漂移字段:{high_drift_columns},自动过滤异常数据")
    # 过滤异常数据
    filtered_online_data = online_data.copy()
    for col in high_drift_columns:
        # 基于训练数据统计值设置阈值
        train_mean = train_profile.view().get_column(col).mean.value
        filtered_online_data = filtered_online_data[filtered_online_data[col] < train_mean * 3]
    print("过滤后数据:")
    print(filtered_online_data)
else:
    print("数据正常,可直接推理")

代码说明:通过漂移分数设置阈值,自动识别高风险字段并过滤异常数据,减少错误输入对模型的影响。

六、集成WhyLabs云端平台

6.1 云端上传数据剖面

whylabs支持将本地剖面上传至云端,实现多项目集中监控、历史回溯、自动告警:

import os
import whylogs as why
from whylogs.api.whylabs.session import WhyLabsSession

# 配置云端密钥(需在WhyLabs官网注册获取)
os.environ["WHYLABS_API_KEY"] = "你的API_KEY"
os.environ["WHYLABS_DEFAULT_DATASET_ID"] = "你的数据集ID"

# 上传剖面到云端
profile.writer("whylabs").write()
print("数据剖面已成功上传至WhyLabs云端")

代码说明:云端平台提供可视化看板、定时监控、团队协作功能,适合企业级数据质量管理。

七、相关资源

  • Pypi地址:https://pypi.org/project/whylogs/
  • Github地址:https://github.com/whylabs/whylogs
  • 官方文档地址:https://docs.whylabs.ai/docs/whylogs/

关注我,每天分享一个实用的Python自动化工具。

Python 机器学习流水线神器:ZenML 从入门到实战全教程

一、ZenML 库概述

ZenML 是一款面向机器学习与 MLOps 领域的开源 Python 库,核心用于构建可复用、可复现、可迁移的端到端 ML 流水线,屏蔽底层环境差异,统一本地、云端、分布式集群的流水线执行逻辑。其基于流水线与步骤抽象设计,将数据读取、预处理、训练、评估、部署拆分为可编排步骤,底层通过配置文件管理运行时环境与组件。优点是轻量化、易上手、跨平台兼容、支持多框架协同,缺点是复杂分布式调度能力弱于 Kubeflow。采用 Apache License 2.0 开源许可,商用友好。

二、ZenML 安装与初始化环境

2.1 基础安装

在使用 ZenML 前,我们需要通过 pip 完成安装,打开命令行执行以下指令:

pip install zenml

该命令会安装 ZenML 核心库以及基础依赖,适合本地快速体验与开发。如果需要对接云服务、数据库、分布式训练等扩展功能,还可以安装对应扩展包。

2.2 初始化 ZenML 环境

安装完成后,必须先初始化 ZenML 工作环境,这一步会创建本地配置文件、数据库、存储目录等核心结构,是后续所有操作的前提。

zenml init

执行成功后,会在当前目录生成 .zenml 隐藏文件夹,用于存储流水线配置、运行记录、元数据等信息。

2.3 安装常用扩展组件

机器学习流水线通常需要对接数据、模型、可视化工具,因此我们安装常用扩展组件:

# 安装可视化、数据处理、模型训练相关扩展
pip install "zenml[server,data,model,tensorflow,sklearn]"

安装完成后,可以启动 ZenML 本地服务,用于查看流水线运行状态、元数据、实验记录等:

zenml up

启动成功后,默认访问地址为 http://127.0.0.1:8237,打开浏览器即可进入 ZenML 可视化控制台。

三、ZenML 核心概念与基础使用

3.1 核心概念解析

  1. 步骤(Step):流水线中最小执行单元,例如数据加载、数据清洗、模型训练、模型评估,每个步骤都是独立函数,通过装饰器标记。
  2. 流水线(Pipeline):由多个步骤按逻辑顺序组合而成,定义完整机器学习工作流,一次定义可多次运行。
  3. 工件(Artifact):步骤之间传递的数据或模型,ZenML 自动管理工件的存储、读取、版本管理,无需手动处理文件读写。
  4. 栈(Stack):定义流水线运行环境,包括编排引擎、元数据存储、工件存储、部署引擎等,本地默认使用本地栈,可无缝切换云端栈。
  5. 运行(Run):流水线的一次执行过程,所有步骤日志、结果、指标都会被记录,支持回溯查看。

3.2 第一个 ZenML 流水线

我们从最简单的示例开始,创建两个步骤并组合成流水线,理解 ZenML 的基础用法。

3.2.1 代码实现

# 导入核心装饰器
from zenml import step, pipeline

# 定义第一个步骤:生成数据
@step
def generate_data() -> int:
    """生成一个整数数据"""
    data = 100
    print(f"生成数据:{data}")
    return data

# 定义第二个步骤:处理数据
@step
def process_data(input_data: int) -> int:
    """对输入数据进行处理,乘以2"""
    result = input_data * 2
    print(f"处理后数据:{result}")
    return result

# 定义流水线:组合步骤
@pipeline
def simple_ml_pipeline():
    """最简单的 ZenML 流水线"""
    data = generate_data()
    output = process_data(data)

# 运行流水线
if __name__ == "__main__":
    simple_ml_pipeline()

3.2.2 代码说明

  1. 使用 @step 装饰器将普通函数标记为 ZenML 步骤,函数的输入输出会自动被 ZenML 管理为工件。
  2. 使用 @pipeline 装饰器将步骤组合为流水线,内部按顺序调用步骤,自动处理数据传递。
  3. 运行脚本后,ZenML 会自动记录运行日志、步骤执行顺序、数据传递结果,可在控制台查看。

执行代码后,命令行会输出执行过程,浏览器控制台会新增一条流水线运行记录,展示每个步骤的执行状态、耗时、输出结果。

四、基于 Sklearn 的机器学习实战流水线

4.1 实战场景说明

本案例使用经典鸢尾花数据集,构建完整机器学习流水线,包含:数据加载、数据划分、模型训练、模型评估四个核心步骤,使用 Sklearn 实现算法,ZenML 完成流水线编排与管理。

4.2 完整代码实现

from zenml import step, pipeline
from zenml.client import Client
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# 步骤1:加载数据集
@step
def load_dataset() -> pd.DataFrame:
    """加载鸢尾花数据集并转换为DataFrame"""
    iris = load_iris()
    data = pd.DataFrame(iris.data, columns=iris.feature_names)
    data["target"] = iris.target
    print("数据集加载完成,形状:", data.shape)
    return data

# 步骤2:划分训练集和测试集
@step
def split_dataset(
    data: pd.DataFrame, test_size: float = 0.2
) -> tuple[pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """划分训练集和测试集"""
    X = data.drop("target", axis=1)
    y = data["target"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42
    )
    print(f"训练集样本数:{len(X_train)},测试集样本数:{len(X_test)}")
    return X_train, X_test, y_train, y_test

# 步骤3:训练随机森林模型
@step
def train_model(
    X_train: pd.DataFrame, y_train: pd.Series
) -> RandomForestClassifier:
    """使用随机森林算法训练模型"""
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    print("模型训练完成")
    return model

# 步骤4:评估模型
@step
def evaluate_model(
    model: RandomForestClassifier, X_test: pd.DataFrame, y_test: pd.Series
) -> float:
    """评估模型并输出准确率"""
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print(f"模型测试集准确率:{acc:.4f}")
    return acc

# 定义完整机器学习流水线
@pipeline
def iris_classification_pipeline(test_size: float = 0.2):
    """鸢尾花分类完整流水线"""
    data = load_dataset()
    X_train, X_test, y_train, y_test = split_dataset(data, test_size=test_size)
    model = train_model(X_train, y_train)
    accuracy = evaluate_model(model, X_test, y_test)

# 运行流水线
if __name__ == "__main__":
    # 执行流水线
    run = iris_classification_pipeline(test_size=0.3)

    # 查看运行结果
    client = Client()
    latest_run = client.get_pipeline_run("iris_classification_pipeline")
    print(f"最新运行ID:{latest_run.id}")
    print(f"最终准确率:{latest_run.steps['evaluate_model'].output.read()}")

4.3 代码说明

  1. 四个步骤分别承担数据、划分、训练、评估职责,解耦代码结构,便于单独修改、调试、复用。
  2. 步骤之间自动传递 DataFrame、模型、数组等复杂对象,无需手动保存文件、读取文件。
  3. 流水线支持传入参数(如 test_size),可灵活调整配置,多次运行对比结果。
  4. 通过 ZenML Client 可以读取历史运行结果、步骤输出、元数据,便于后续自动化分析。

运行代码后,控制台会输出数据集信息、样本划分结果、模型训练状态与最终准确率,同时所有信息会同步到 ZenML 控制台,可查看流水线 DAG 图、步骤耗时、模型指标、数据版本等。

五、基于 TensorFlow 的深度学习流水线实战

5.1 实战场景说明

使用简单神经网络对鸢尾花数据集进行分类,展示 ZenML 对接深度学习框架的能力,步骤包括:数据加载、数据预处理、模型构建、模型训练、模型评估。

5.2 完整代码实现

from zenml import step, pipeline
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical

# 步骤1:加载数据
@step
def load_iris_data() -> pd.DataFrame:
    iris = load_iris()
    df = pd.DataFrame(iris.data, columns=iris.feature_names)
    df["target"] = iris.target
    return df

# 步骤2:数据预处理
@step
def preprocess_data(
    df: pd.DataFrame
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    X = df.drop("target", axis=1).values
    y = df["target"].values

    # 标准化
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # 独热编码
    y = to_categorical(y, num_classes=3)

    # 划分数据集
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    return X_train, X_test, y_train, y_test

# 步骤3:构建神经网络模型
@step
def build_dnn_model(input_shape: int) -> Sequential:
    model = Sequential()
    model.add(Dense(16, activation="relu", input_shape=(input_shape,)))
    model.add(Dense(8, activation="relu"))
    model.add(Dense(3, activation="softmax"))

    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )
    return model

# 步骤4:训练模型
@step
def train_dnn_model(
    model: Sequential, X_train: np.ndarray, y_train: np.ndarray
) -> Sequential:
    model.fit(
        X_train, y_train,
        epochs=50,
        batch_size=4,
        validation_split=0.1,
        verbose=1
    )
    return model

# 步骤5:评估模型
@step
def test_dnn_model(
    model: Sequential, X_test: np.ndarray, y_test: np.ndarray
) -> dict:
    loss, acc = model.evaluate(X_test, y_test, verbose=0)
    result = {"test_loss": loss, "test_accuracy": acc}
    print(f"测试损失:{loss:.4f},测试准确率:{acc:.4f}")
    return result

# 定义深度学习流水线
@pipeline
def iris_dnn_pipeline():
    df = load_iris_data()
    X_train, X_test, y_train, y_test = preprocess_data(df)
    model = build_dnn_model(input_shape=X_train.shape[1])
    trained_model = train_dnn_model(model, X_train, y_train)
    metrics = test_dnn_model(trained_model, X_test, y_test)

if __name__ == "__main__":
    iris_dnn_pipeline()

5.3 代码说明

  1. ZenML 可以无缝对接 TensorFlow、PyTorch 等深度学习框架,自动序列化、存储、加载模型。
  2. 预处理步骤包含标准化、独热编码、数据集划分,符合深度学习数据处理规范。
  3. 模型训练过程中的日志、指标、结构都会被 ZenML 记录,便于对比不同超参数效果。
  4. 流水线结构清晰,可直接用于生产环境,替换数据集即可快速迁移到其他项目。

六、ZenML 流水线高级用法

6.1 流水线配置化运行

支持通过外部参数控制流水线行为,适配不同环境、不同数据集、不同超参数,实现一次编写、多次灵活运行。

@pipeline
def configurable_pipeline(
    epochs: int = 50,
    test_size: float = 0.2,
    model_type: str = "random_forest"
):
    # 内部根据参数选择不同模型或逻辑
    pass

6.2 查看历史运行记录

from zenml.client import Client

client = Client()

# 获取所有流水线
pipelines = client.list_pipelines()
for p in pipelines:
    print(p.name)

# 获取某条流水线的所有运行记录
runs = client.get_pipeline("iris_classification_pipeline").runs
for r in runs:
    print(f"运行时间:{r.created},状态:{r.status}")

6.3 流水线缓存机制

ZenML 默认开启缓存,未修改的步骤会直接使用上一次运行结果,大幅提升调试速度:

@step(enable_cache=False)  # 关闭当前步骤缓存
def dynamic_step():
    pass

七、实际项目应用价值与代码价值

在实际机器学习项目中,传统开发模式常面临代码混乱、不可复现、环境迁移困难、实验记录丢失等问题。使用 ZenML 可以将整个工作流结构化,每个步骤独立可维护,所有实验自动记录,方便回溯最优模型。

在团队协作中,统一的流水线规范可以降低沟通成本,新成员可快速理解工作流程;在部署阶段,本地编写的流水线无需大量修改即可运行在云端服务器、K8s 集群等环境,实现从开发到生产的平滑迁移。

上述鸢尾花分类、深度学习流水线代码,可直接作为项目模板,替换数据集、调整模型结构、修改评估指标,即可应用于图像分类、表格数据预测、NLP 任务等多种场景,真正实现一套流水线适配多任务、多环境。

八、相关资源

  • Pypi地址:https://pypi.org/project/zenml
  • Github地址:https://github.com/zenml-io/zenml
  • 官方文档地址:https://docs.zenml.io

关注我,每天分享一个实用的Python自动化工具。

Python 实用工具:Activeloop 从入门到实战,轻松管理向量数据与大模型数据集

一、Activeloop 库简介

Activeloop 是一款专注于数据集管理、向量存储与大模型数据流水线的 Python 工具库,核心用于构建、版本控制、流式加载与查询 AI 数据集,尤其适配大语言模型、计算机视觉场景。其通过统一数据格式实现跨设备、跨框架的数据共享,底层依托高效序列化与云端存储能力,支持本地/云端无缝切换。该库采用 MIT License 开源,优点是易用、轻量、兼容主流 AI 框架,缺点是在超大规模离线数据集上性能略逊于专用分布式框架。

二、Activeloop 安装与环境准备

在使用 Activeloop 之前,我们需要完成库的安装与基础环境配置,支持 Python 3.8 及以上版本,可直接通过 pip 完成安装。

2.1 基础安装

打开命令行工具,执行以下安装命令:

pip install deeplake

Activeloop 的核心功能封装在 deeplake 包中,这是官方推荐的安装方式,安装过程会自动依赖 numpy、pandas 等基础数据处理库。

2.2 验证安装

安装完成后,可通过简单代码验证是否安装成功:

import deeplake

# 打印库版本
print("Activeloop (deeplake) 版本:", deeplake.__version__)

运行代码后,若正常输出版本号,说明环境配置完成,可进入后续功能使用。

三、Activeloop 核心功能与基础使用

Activeloop 核心围绕数据集创建、数据写入、数据读取、数据查询、向量存储五大功能展开,面向 AI 开发者屏蔽底层存储细节,专注数据本身。

3.1 创建本地数据集

数据集是 Activeloop 的核心载体,支持文本、图片、向量、标签等多种数据类型,创建方式简洁直观。

import deeplake
import numpy as np

# 创建本地数据集,路径为当前目录下的 my_first_dataset
ds = deeplake.dataset("./my_first_dataset")

# 定义数据集结构:文本数据、向量数据、标签
ds.create_tensor("text", htype="text")
ds.create_tensor("embedding", htype="embedding")
ds.create_tensor("label", htype="class_label")

print("数据集创建完成,结构如下:")
print(ds.tensors)

代码说明:

  • deeplake.dataset() 用于创建/加载数据集,传入本地路径则生成本地数据集;
  • create_tensor() 定义数据列,htype 指定数据类型,适配 AI 场景常用格式;
  • 执行后会在指定路径生成数据集文件夹,包含数据与元信息。

3.2 向数据集写入数据

创建完数据集结构后,可批量写入文本、向量、标签等数据,模拟大模型训练或检索场景的原始数据。

# 构造模拟数据
text_list = [
    "Python 是一门优雅易用的编程语言",
    "Activeloop 适合管理大模型数据集",
    "向量数据库是 RAG 系统的核心组件",
    "深度学习需要高质量标注数据"
]

# 生成模拟 768 维嵌入向量(适配大模型通用维度)
embedding_list = [np.random.rand(768) for _ in range(4)]
label_list = [0, 1, 1, 0]

# 批量写入数据集
with ds:
    ds.text.extend(text_list)
    ds.embedding.extend(embedding_list)
    ds.label.extend(label_list)

print("数据写入成功,数据集样本数:", len(ds))

代码说明:

  • 使用 with ds: 上下文管理器保证数据写入原子性,避免中途出错导致数据损坏;
  • extend() 用于批量添加数据,适配列表、numpy 数组等格式;
  • 写入后可通过 len(ds) 查看数据集总样本数。

3.3 读取数据集数据

Activeloop 支持索引读取、遍历读取、条件读取,操作方式与列表、DataFrame 高度相似,降低学习成本。

# 按索引读取单条数据
sample = ds[0]
print("第一条数据文本:", sample.text.data())
print("第一条数据向量形状:", sample.embedding.shape)
print("第一条数据标签:", sample.label.data())

# 遍历所有数据
print("\n===== 遍历全部数据 =====")
for i, sample in enumerate(ds):
    print(f"样本 {i}:")
    print(f"文本:{sample.text.data()}")
    print(f"标签:{sample.label.data()}\n")

代码说明:

  • 直接通过索引 ds[i] 获取第 i 条样本;
  • data() 方法提取原始数据,避免返回封装对象;
  • 遍历方式与 Python 列表一致,无需复杂语法。

3.4 条件查询数据

Activeloop 内置轻量查询引擎,支持按标签、数值等条件筛选数据,满足 AI 数据预处理需求。

# 查询标签为 1 的所有样本
filtered_ds = ds[ds.label == 1]

print("标签为 1 的样本数量:", len(filtered_ds))
for sample in filtered_ds:
    print("筛选文本:", sample.text.data())

代码说明:

  • 支持 ==!=>< 等常规比较运算符;
  • 筛选后返回新的数据集视图,不占用额外内存;
  • 适合大模型训练前的数据过滤与采样。

四、Activeloop 向量存储与 RAG 场景实战

向量存储是 Activeloop 的核心亮点,可直接作为轻量级向量数据库使用,适配 RAG(检索增强生成)场景,无需额外部署复杂数据库。

4.1 向量数据写入与检索

# 重新创建向量专用数据集
rag_ds = deeplake.dataset("./rag_embedding_dataset")
rag_ds.create_tensor("content", htype="text")
rag_ds.create_tensor("vector", htype="embedding")

# 写入文档与对应向量
documents = [
    "豆包是字节跳动自研的人工智能助手",
    "Activeloop 可用于构建 RAG 系统的向量库",
    "Python 广泛应用于 AI 与数据科学领域",
    "RAG 系统通过检索提升大模型回答准确性"
]

# 生成模拟向量
vectors = [np.random.rand(128) for _ in range(4)]

with rag_ds:
    rag_ds.content.extend(documents)
    rag_ds.vector.extend(vectors)

# 模拟查询向量并计算相似度
query_vector = np.random.rand(128)
scores = []

for sample in rag_ds:
    vec = sample.vector.data()
    # 余弦相似度简化计算
    similarity = np.dot(query_vector, vec) / (np.linalg.norm(query_vector) * np.linalg.norm(vec))
    scores.append(similarity)

# 获取最相似的文档
best_idx = np.argmax(scores)
print("\n最匹配的文档:", rag_ds[best_idx].content.data())
print("匹配相似度:", round(scores[best_idx], 4))

代码说明:

  • 该示例完整模拟 RAG 系统中文档入库→向量存储→相似度检索流程;
  • 无需依赖 Pinecone、Chroma 等外部向量库,单机即可运行;
  • 适合个人开发者、小型项目快速搭建检索系统。

4.2 与大模型嵌入接口结合

Activeloop 可无缝对接 OpenAI、文心一言、豆包等大模型的嵌入接口,实现真实向量生成与存储。

# 模拟调用大模型生成嵌入向量(可替换为真实 API)
def mock_get_embedding(text: str) -> np.ndarray:
    return np.random.rand(128)

# 构建真实场景数据集
qa_ds = deeplake.dataset("./qa_dataset")
qa_ds.create_tensor("question", htype="text")
qa_ds.create_tensor("answer", htype="text")
qa_ds.create_tensor("q_embedding", htype="embedding")

qa_pairs = [
    {"q": "Activeloop 是什么", "a": "Activeloop 是 Python 数据集与向量管理库"},
    {"q": "如何安装 deeplake", "a": "使用 pip install deeplake 安装"},
    {"q": "RAG 全称是什么", "a": "RAG 全称是 Retrieval Augmented Generation"}
]

with qa_ds:
    for pair in qa_pairs:
        q_emb = mock_get_embedding(pair["q"])
        qa_ds.question.append(pair["q"])
        qa_ds.answer.append(pair["a"])
        qa_ds.q_embedding.append(q_emb)

print("问答数据集构建完成,样本数:", len(qa_ds))

代码说明:

  • mock_get_embedding 替换为真实模型接口,即可生成工业级向量库;
  • 数据集同时存储问题、答案、向量,形成完整 RAG 数据链路。

五、Activeloop 云端数据集使用

Activeloop 支持云端存储,实现多设备、多开发者共享数据集,无需手动传输文件。

5.1 登录与云端数据集创建

首先在命令行登录 Activeloop 账号:

deeplake login

按照提示输入用户名与密码,登录成功后即可创建云端数据集。

# 创建云端数据集(需登录)
# cloud_ds = deeplake.dataset("hub://用户名/my_cloud_dataset")

# 后续读写操作与本地数据集完全一致
# with cloud_ds:
#     cloud_ds.text.extend(["云端数据测试"])

代码说明:

  • 路径以 hub:// 开头表示云端数据集;
  • 读写 API 与本地完全一致,实现本地/云端无缝切换;
  • 适合团队协作、跨设备开发。

六、与 PyTorch/TensorFlow 框架对接

Activeloop 原生支持深度学习框架,可直接转换为框架可读取的数据集,简化训练数据加载流程。

# 转换为 PyTorch DataLoader
from torch.utils.data import DataLoader

# 构建适合训练的数据集
train_ds = deeplake.dataset("./train_dataset")
train_ds.create_tensor("image", htype="image")
train_ds.create_tensor("target", htype="class_label")

# 模拟图像数据
for i in range(10):
    train_ds.image.append(np.random.rand(28, 28, 3))
    train_ds.target.append(np.random.randint(0, 2))

# 转换为 PyTorch 数据集
pytorch_ds = train_ds.pytorch(batch_size=2, shuffle=True)
dataloader = DataLoader(pytorch_ds, batch_size=None)

# 读取训练批次
for batch in dataloader:
    print("图像批次形状:", batch["image"].shape)
    print("标签批次:", batch["target"])
    break

代码说明:

  • 通过 pytorch() 方法直接生成适配框架的数据格式;
  • 支持 shuffle、batch_size、num_workers 等训练参数;
  • 大幅减少数据预处理与格式转换代码量。

七、实际应用案例:简易智能问答系统

结合前面所有知识点,构建一个可直接运行的轻量级智能问答系统,完整体现 Activeloop 的实用价值。

import deeplake
import numpy as np

# 1. 构建知识库数据集
kb_ds = deeplake.dataset("./knowledge_base")
kb_ds.create_tensor("question", htype="text")
kb_ds.create_tensor("answer", htype="text")
kb_ds.create_tensor("embed", htype="embedding")

# 知识库内容
knowledge = [
    {"q": "Python 有哪些常用数据结构", "a": "列表、字典、元组、集合、堆、队列等"},
    {"q": "Activeloop 能做什么", "a": "管理数据集、存储向量、构建 RAG、对接深度学习框架"},
    {"q": "如何读取 deeplake 数据", "a": "通过索引 ds[i] 或遍历读取,使用 data() 获取原始数据"},
    {"q": "deeplake 支持云端吗", "a": "支持,使用 hub:// 路径创建云端数据集"}
]

# 写入数据
with kb_ds:
    for item in knowledge:
        embed = np.random.rand(128)
        kb_ds.question.append(item["q"])
        kb_ds.answer.append(item["a"])
        kb_ds.embed.append(embed)

# 2. 定义检索函数
def search_answer(user_query: str) -> str:
    query_embed = np.random.rand(128)
    max_sim = -1
    best_ans = ""

    for sample in kb_ds:
        sim = np.dot(query_embed, sample.embed.data()) / (
            np.linalg.norm(query_embed) * np.linalg.norm(sample.embed.data())
        )
        if sim > max_sim:
            max_sim = sim
            best_ans = sample.answer.data()
    return best_ans

# 3. 模拟用户提问
if __name__ == "__main__":
    user_query = "如何读取 deeplake 里面的数据"
    answer = search_answer(user_query)

    print("用户问题:", user_query)
    print("系统回答:", answer)

该案例可直接部署在本地,作为小型客服机器人、文档问答助手使用,代码简洁、依赖少、启动快,充分体现 Activeloop 在 AI 小项目中的高效性。

相关资源

  • Pypi地址:https://pypi.org/project/deeplake/
  • Github地址:https://github.com/activeloopai/deeplake
  • 官方文档地址:https://docs.activeloop.ai/

关注我,每天分享一个实用的Python自动化工具。

Python实用数据处理库petl:轻量级表格数据操作完全指南

一、petl库基础认知

petl全称为Python ETL,是一款专注于轻量级数据提取、转换、加载的开源Python库,核心面向表格型数据处理场景,无需依赖 heavy 框架即可完成数据清洗、筛选、合并、格式转换等操作。其工作原理是通过惰性加载处理数据,仅在需要输出时才执行计算,大幅降低内存占用。该库采用MIT开源许可,优点是轻量简洁、上手门槛低、适合中小数据集处理,缺点是不适合超大规模分布式数据运算,性能弱于Pandas等专业数据分析库。

二、petl库安装方法

petl的安装流程十分简便,支持pip快速安装,无需配置复杂环境,适合Python初学者直接使用。
打开命令行工具,执行以下安装命令:

pip install petl

若需要加速安装,可使用国内镜像源:

pip install petl -i https://pypi.tuna.tsinghua.edu.cn/simple

安装完成后,在Python脚本中直接导入即可验证是否安装成功:

import petl
print(petl.__version__)

执行后输出版本号,即代表安装完成,可以正常使用。

三、petl核心数据结构与基础操作

petl核心围绕表格数据展开,其数据结构可以是列表、元组、CSV文件、Excel文件、数据库查询结果等,所有操作均以行和列为基本单位,语法贴近自然语言,极易理解。

3.1 创建基础表格数据

petl可以直接从Python原生数据结构创建数据表,这是入门的第一步。

import petl as etl

# 手动创建表格数据,第一行为表头,后续为数据行
table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 查看数据前3行
print(etl.head(table, 3))

代码说明:

  1. 导入petl库并简写为etl,方便后续调用;
  2. 使用列表嵌套结构定义表格,第一行是字段名,后续每行是一条数据;
  3. head()方法用于查看指定行数的数据,默认查看前5行,此处指定查看前3行,方便快速预览数据结构。

3.2 数据筛选与条件查询

数据筛选是数据处理中最常用的功能,petl提供select、selecteq、selectne等方法实现条件过滤。

import petl as etl

table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 筛选年龄大于24岁的数据
age_filter = etl.select(table, lambda rec: rec.age > 24)
print('年龄大于24岁的数据:')
print(age_filter)

# 筛选城市为上海的数据
city_filter = etl.selecteq(table, 'city', '上海')
print('\n城市为上海的数据:')
print(city_filter)

代码说明:

  1. select()方法支持自定义lambda函数,可实现复杂条件筛选;
  2. selecteq()是等值筛选方法,直接指定字段和对应值即可,语法更简洁;
  3. 所有操作返回新的表格对象,不会修改原始数据,保证数据安全。

3.3 字段新增与修改

处理数据时经常需要新增计算字段或修改现有字段,petl的addfield()和convert()方法可快速实现。

import petl as etl

table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 新增年龄分组字段
table_add = etl.addfield(table, 'age_group', lambda rec: '青年' if rec.age < 25 else '壮年')
print('新增年龄分组字段后:')
print(table_add)

# 修改年龄字段,所有年龄+1
table_convert = etl.convert(table_add, 'age', lambda v: v + 1)
print('\n修改年龄字段后:')
print(table_convert)

代码说明:

  1. addfield()用于新增字段,可通过lambda函数根据现有数据计算新字段值;
  2. convert()用于修改指定字段的值,支持批量转换、格式调整等操作;
  3. 操作支持链式调用,可连续对数据进行处理。

3.4 数据排序与去重

数据排序和去重是数据清洗的必备环节,petl提供sort()、distinct()方法快速处理。

import petl as etl

table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['张三', 22, '北京'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 数据去重
table_distinct = etl.distinct(table)
print('去重后的数据:')
print(table_distinct)

# 按年龄降序排序
table_sort = etl.sort(table_distinct, key='age', reverse=True)
print('\n按年龄降序排序后:')
print(table_sort)

代码说明:

  1. distinct()方法会自动去除完全重复的数据行;
  2. sort()方法通过key指定排序字段,reverse=True表示降序,False为升序;
  3. 处理后的数据结构保持不变,可直接用于后续操作。

四、petl文件读写操作

petl最大的优势之一是支持多种文件格式的读写,包括CSV、Excel、JSON、HTML等,无需依赖其他库即可完成文件数据处理。

4.1 CSV文件读写

CSV是最常用的表格数据格式,petl对CSV的支持极为完善,读写速度快且稳定。

import petl as etl

# 定义测试数据
table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 将数据写入CSV文件
etl.tocsv(table, 'user_data.csv', encoding='utf-8')
print('数据已写入CSV文件')

# 从CSV文件读取数据
read_csv = etl.fromcsv('user_data.csv', encoding='utf-8')
print('\n读取CSV文件数据:')
print(etl.head(read_csv))

代码说明:

  1. tocsv()方法将petl表格数据写入CSV文件,需指定编码格式避免中文乱码;
  2. fromcsv()方法读取CSV文件并转换为petl可处理的表格对象;
  3. 支持指定分隔符、是否包含表头等参数,适配不同格式的CSV文件。

4.2 Excel文件读写

petl读写Excel文件需要依赖openpyxl或xlrd库,需提前安装依赖:

pip install openpyxl xlrd

安装完成后执行读写代码:

import petl as etl

table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 写入Excel文件
etl.toxlsx(table, 'user_data.xlsx', sheet='用户信息')
print('数据已写入Excel文件')

# 读取Excel文件
read_excel = etl.fromxlsx('user_data.xlsx', sheet='用户信息')
print('\n读取Excel文件数据:')
print(etl.head(read_excel))

代码说明:

  1. toxlsx()和fromxlsx()分别实现Excel文件的写入和读取;
  2. 支持指定sheet工作表名称,适合处理多工作表的Excel文件;
  3. 读取后的数据格式与手动创建的表格数据一致,可直接使用所有petl操作方法。

4.3 JSON与HTML格式数据处理

除了常规表格文件,petl还支持JSON和HTML格式的数据转换,适合数据展示和接口对接。

import petl as etl

table = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海'],
    ['王五', 28, '深圳']
]

# 转换为JSON格式
json_data = etl.tojson(table)
print('JSON格式数据:')
print(json_data)

# 转换为HTML表格
html_table = etl.tohtml(table)
print('\nHTML表格代码:')
print(html_table)

代码说明:

  1. tojson()将表格数据转换为JSON格式,适合接口数据返回;
  2. tohtml()将数据转换为HTML表格代码,可直接嵌入网页展示;
  3. 转换过程自动保留字段名和数据对应关系,无需手动格式化。

五、petl数据合并与关联操作

在实际数据处理中,经常需要将多个数据源合并、关联,petl提供cat、join、lookup等方法实现多表操作,功能媲美数据库关联查询。

5.1 多表数据合并

cat()方法用于将结构相同的多个表格合并为一个表格,适合数据拼接场景。

import petl as etl

# 定义两个结构相同的表格
table1 = [
    ['name', 'age', 'city'],
    ['张三', 22, '北京'],
    ['李四', 25, '上海']
]

table2 = [
    ['name', 'age', 'city'],
    ['王五', 28, '深圳'],
    ['赵六', 24, '广州']
]

# 合并两个表格
merge_table = etl.cat(table1, table2)
print('合并后的数据:')
print(merge_table)

代码说明:

  1. cat()方法可接收多个表格参数,一次性合并多个数据源;
  2. 要求表格的字段结构完全一致,否则会出现数据错位;
  3. 合并后保留所有数据行,不进行去重操作,如需去重可配合distinct()使用。

5.2 表关联查询(join)

join()方法实现类似数据库的左连接、内连接,适合多表关联分析。

import petl as etl

# 用户基础信息表
user_table = [
    ['user_id', 'name'],
    [1, '张三'],
    [2, '李四'],
    [3, '王五']
]

# 用户订单表
order_table = [
    ['user_id', 'order_no', 'amount'],
    [1, 'ORDER001', 99],
    [2, 'ORDER002', 199],
    [1, 'ORDER003', 299]
]

# 内连接:只保留匹配的数据
inner_join = etl.join(user_table, order_table, key='user_id')
print('内连接结果:')
print(inner_join)

# 左连接:保留左表所有数据
left_join = etl.leftjoin(user_table, order_table, key='user_id')
print('\n左连接结果:')
print(left_join)

代码说明:

  1. key参数指定关联字段,通常为ID类唯一标识;
  2. 内连接只保留两张表都存在的匹配数据;
  3. 左连接保留左表所有数据,右表无匹配数据时填充为None;
  4. 无需编写复杂SQL,纯Python语法即可实现数据库级别的关联查询。

六、petl与数据库交互

petl支持直接连接MySQL、SQLite、PostgreSQL等数据库,实现数据的提取与写入,是ETL流程的核心功能。

6.1 SQLite数据库操作

SQLite为轻量级文件数据库,无需单独安装服务,适合演示。

import petl as etl
import sqlite3

# 连接SQLite数据库
conn = sqlite3.connect('test.db')

# 创建测试表
conn.execute('''CREATE TABLE IF NOT EXISTS user 
             (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)''')

# 定义数据并写入数据库
table = [
    ['id', 'name', 'age'],
    [1, '张三', 22],
    [2, '李四', 25],
    [3, '王五', 28]
]
etl.todb(table, conn, 'user')
print('数据已写入SQLite数据库')

# 从数据库读取数据
db_table = etl.fromdb(conn, 'SELECT * FROM user')
print('\n从数据库读取的数据:')
print(db_table)

conn.close()

代码说明:

  1. fromdb()从数据库执行SQL查询并返回petl表格数据;
  2. todb()将petl数据写入数据库表;
  3. 支持事务操作、批量写入,效率高于原生Python数据库操作。

6.2 MySQL数据库操作

连接MySQL需要安装pymysql库:

pip install pymysql

连接代码:

import petl as etl
import pymysql

# 连接MySQL数据库
conn = pymysql.connect(
    host='localhost',
    user='root',
    password='123456',
    database='test_db',
    charset='utf8'
)

# 从MySQL查询数据
mysql_table = etl.fromdb(conn, 'SELECT name, age FROM user LIMIT 5')
print('MySQL数据:')
print(mysql_table)

conn.close()

代码说明:

  1. 连接参数需根据实际MySQL环境修改;
  2. 支持复杂SQL查询,结果直接转换为petl可处理的表格;
  3. 可实现数据从数据库提取→清洗→转换→写回数据库的完整ETL流程。

七、petl实际业务案例

用户数据清洗与分析为例,模拟真实业务场景,整合petl所有核心操作。

7.1 业务需求

  1. 读取CSV格式的原始用户数据;
  2. 去除重复数据、空值数据;
  3. 筛选年龄在20-30岁之间的用户;
  4. 新增年龄段分组字段;
  5. 按城市统计用户数量;
  6. 将处理结果保存为Excel文件。

7.2 完整实现代码

import petl as etl

# 1. 读取原始CSV数据
raw_data = etl.fromcsv('raw_user.csv', encoding='utf-8')

# 2. 去除重复数据和空值数据
data_no_dup = etl.distinct(raw_data)
data_no_null = etl.rejectmissing(data_no_dup)

# 3. 筛选年龄20-30岁的用户
data_filter = etl.select(data_no_null, lambda rec: 20 <= int(rec.age) <= 30)

# 4. 新增年龄段分组字段
data_add_field = etl.addfield(data_filter, 'age_group', 
                             lambda rec: '20-25岁' if 20 <= int(rec.age) <=25 else '26-30岁')

# 5. 按城市统计用户数量
city_stats = etl.aggregate(data_add_field, 'city', count=('name', len))

# 6. 保存处理后的数据和统计结果
etl.toxlsx(data_add_field, 'clean_user_data.xlsx', sheet='清洗后数据')
etl.toxlsx(city_stats, 'city_user_stats.xlsx', sheet='城市统计')

print('数据处理完成,结果已保存为Excel文件')

代码说明:

  1. rejectmissing()方法自动剔除包含空值的数据行,保证数据质量;
  2. aggregate()实现分组统计功能,类似SQL的GROUP BY,可统计数量、求和、平均值等;
  3. 整个流程采用链式处理逻辑,代码简洁易读,无需创建中间变量;
  4. 从文件读取到最终输出,全程使用petl完成,适合自动化数据处理脚本。

相关资源

  • Pypi地址:https://pypi.org/project/petl
  • Github地址:https://github.com/petl-developers/petl
  • 官方文档地址:https://petl.readthedocs.io/

关注我,每天分享一个实用的Python自动化工具。

Python 数据建模神器:dbt 从入门到实战,轻松搞定数据仓库开发

一、dbt 核心介绍

dbt(data build tool)是专注于数据仓库建模的 Python 工具,核心作用是让数据工程师、分析师用 SQL 完成数据转换、测试、文档化,无需编写复杂调度脚本。其原理是基于 SQL 编译生成可执行模型,依赖数据仓库引擎执行计算,采用 Apache 2.0 开源协议。优点是上手快、协作友好、自带测试与文档,缺点是依赖数据仓库、不负责数据抽取与加载。

二、dbt 安装与环境初始化

2.1 安装 dbt

dbt 可通过 pip 直接安装,不同数据仓库对应不同适配器,主流适配 BigQuery、Snowflake、Redshift、Databricks、PostgreSQL 等,这里以通用安装和最常用的 PostgreSQL 适配器为例。

打开命令行执行安装命令:

# 安装 dbt 核心及 PostgreSQL 适配器
pip install dbt-core dbt-postgres

安装完成后验证版本:

dbt --version

出现版本信息即代表安装成功,包含 core 版本和对应适配器版本。

2.2 初始化 dbt 项目

安装完成后,创建专属 dbt 项目,命令会自动生成标准目录结构,这是 dbt 规范开发的基础。

# 创建名为 dbt_demo 的项目
dbt init dbt_demo

执行后进入项目目录:

cd dbt_demo

2.3 配置数据仓库连接

dbt 核心配置文件为 profiles.yml,默认在用户目录 .dbt 文件夹下,用于配置数据库连接信息。

以 PostgreSQL 为例,配置内容如下:

dbt_demo:
  target: dev
  outputs:
    dev:
      type: postgres
      host: localhost
      user: postgres
      password: 你的密码
      port: 5432
      dbname: postgres
      schema: dbt_demo
      threads: 4

配置完成后,测试连接是否正常:

dbt debug

显示 All checks passed! 代表连接成功。

三、dbt 标准目录结构

dbt 有固定的目录规范,便于团队协作和模型管理,初始化后的目录结构如下:

dbt_demo/
├── analyses/          # 存放分析查询 SQL
├── dbt_project.yml    # 项目核心配置文件
├── macros/            # 自定义 Jinja2 宏
├── models/            # 核心数据模型目录
├── seeds/             # 存放静态 CSV 数据
├── snapshots/         # 快照数据,记录历史状态
├── tests/             # 自定义数据测试
└── README.md

其中 models 是核心目录,所有业务数据模型都在此编写,dbt_project.yml 用于配置模型、权限、变量等。

四、dbt 基础使用:从简单模型开始

4.1 编写基础模型

模型是 dbt 的核心,本质就是带逻辑的 SQL 文件,存放在 models 目录下。

models 目录下创建 user_orders.sql,编写简单模型,关联用户表和订单表:

-- models/user_orders.sql
{{ config(materialized='table') }}

SELECT
    u.user_id,
    u.username,
    u.create_time AS user_create_time,
    o.order_id,
    o.order_amount,
    o.order_time
FROM
    public.users u
LEFT JOIN
    public.orders o
ON
    u.user_id = o.user_id

代码说明:

  • {{ config(materialized='table') }} 表示模型构建为表,还支持 view、incremental、ephemeral 等类型
  • 模型本质是标准 SQL,dbt 会自动编译并在数据库中生成对应表或视图

4.2 运行 dbt 模型

编写完成后,执行命令运行模型:

dbt run

执行成功后,会在配置的 schema 下生成 user_orders 表,数据自动关联计算完成。

4.3 为模型添加文档描述

dbt 支持直接在 SQL 中添加注释,自动生成文档,无需手动维护。

优化后的模型:

-- models/user_orders.sql
{{ config(materialized='table', tags=['user', 'order']) }}

/*
 * 模型名称: user_orders
 * 功能: 用户与订单关联宽表,用于用户消费分析
 */

SELECT
    u.user_id AS user_id,        -- 用户ID,主键
    u.username AS username,      -- 用户名
    u.create_time AS user_create_time,  -- 用户注册时间
    o.order_id AS order_id,      -- 订单ID
    o.order_amount AS order_amount,  -- 订单金额
    o.order_time AS order_time   -- 下单时间
FROM
    public.users u
LEFT JOIN
    public.orders o
ON
    u.user_id = o.user_id

4.4 生成并查看文档

执行命令生成文档:

dbt docs generate
dbt docs serve

执行后会启动本地服务,浏览器访问可查看完整的模型血缘关系、字段说明、模型依赖,非常适合团队协作。

五、dbt 进阶使用:数据测试与增量模型

5.1 内置数据测试

数据质量是数仓核心,dbt 内置丰富测试,在 models 下创建 schema.yml 配置测试规则:

version: 2

models:
  - name: user_orders
    columns:
      - name: user_id
        tests:
          - not_null
          - unique
      - name: order_amount
        tests:
          - not_null

配置后执行测试命令:

dbt test

dbt 会自动校验字段是否为空、是否唯一,快速发现数据问题。

5.2 自定义测试

除内置测试外,还可编写自定义 SQL 测试,在 tests 目录下创建 test_order_amount_positive.sql

-- 测试订单金额必须大于0
SELECT
    *
FROM
    {{ ref('user_orders') }}
WHERE
    order_amount <= 0

执行 dbt test 时会自动运行该测试,若有负金额则报错。

5.3 增量模型(核心进阶功能)

全量构建在大数据量下效率极低,dbt 提供增量模型,只新增或更新数据。

创建增量订单模型 incremental_orders.sql

-- models/incremental_orders.sql
{{ config(
    materialized='incremental',
    unique_key='order_id',
    incremental_strategy='merge'
) }}

SELECT
    order_id,
    user_id,
    order_amount,
    order_time
FROM
    public.orders

{% if is_incremental() %}
    -- 增量逻辑:只查询比当前模型最大时间大的数据
    WHERE order_time > (SELECT MAX(order_time) FROM {{ this }})
{% endif %}

代码说明:

  • incremental 代表增量模型
  • unique_key 是增量合并的唯一键
  • is_incremental() 宏用于区分首次运行与增量运行
    首次执行全量构建,后续执行只增量同步新数据,大幅提升运行效率。

六、dbt 宏与变量:提升代码复用性

6.1 自定义宏

宏类似 Python 函数,可复用 SQL 逻辑,在 macros 目录下创建 get_date_macro.sql

{% macro get_today() %}
    CURRENT_DATE
{% endmacro %}

在模型中直接调用:

SELECT
    order_id,
    order_time,
    {{ get_today() }} AS stat_date
FROM
    {{ ref('incremental_orders') }}

6.2 使用项目变量

dbt_project.yml 中定义变量:

vars:
  start_date: '2025-01-01'

模型中使用变量:

SELECT
    *
FROM
    {{ ref('user_orders') }}
WHERE
    order_time >= '{{ var("start_date") }}'

七、dbt 种子数据:静态数据导入

dbt 支持将本地 CSV 文件导入数据库,称为种子数据,适合字典表、静态映射表。

seeds 目录下创建 user_level.csv

user_id,level
1,VIP
2,普通用户
3,VIP

执行导入命令:

dbt seed

导入后可在模型中直接引用:

SELECT
    u.*,
    l.level
FROM
    {{ ref('user_orders') }} u
LEFT JOIN
    {{ ref('user_level') }} l
ON
    u.user_id = l.user_id

八、完整实战案例:用户消费统计宽表

结合前面所有知识点,构建一个完整的数仓模型,用于业务分析、报表展示。

8.1 需求说明

构建用户日消费统计表,包含用户ID、用户名、消费日期、消费次数、消费总金额、会员等级,支持增量更新、数据测试、自动文档。

8.2 模型代码

-- models/user_daily_stat.sql
{{ config(
    materialized='incremental',
    unique_key='concat(user_id, stat_date)',
    tags=['report', 'user', 'stat'],
    incremental_strategy='merge'
) }}

WITH order_daily AS (
    SELECT
        user_id,
        DATE(order_time) AS stat_date,
        COUNT(order_id) AS order_count,
        SUM(order_amount) AS total_amount
    FROM
        {{ ref('incremental_orders') }}
    GROUP BY
        user_id, DATE(order_time)
),

user_info AS (
    SELECT
        user_id,
        username
    FROM
        public.users
),

user_level AS (
    SELECT
        user_id,
        level
    FROM
        {{ ref('user_level') }}
)

SELECT
    u.user_id,
    u.username,
    l.level,
    o.stat_date,
    o.order_count,
    o.total_amount
FROM
    user_info u
JOIN
    order_daily o ON u.user_id = o.user_id
LEFT JOIN
    user_level l ON u.user_id = l.user_id

{% if is_incremental() %}
    WHERE o.stat_date > (SELECT MAX(stat_date) FROM {{ this }})
{% endif %}

8.3 测试配置

version: 2

models:
  - name: user_daily_stat
    columns:
      - name: user_id
        tests:
          - not_null
      - name: stat_date
        tests:
          - not_null
      - name: total_amount
        tests:
          - not_null

8.4 完整执行流程

# 运行所有模型
dbt run

# 执行所有测试
dbt test

# 生成文档
dbt docs generate
dbt docs serve

执行完成后,即可得到稳定、可复用、可追溯、带质量保障的用户消费日统计报表。

九、dbt 常用命令总结

  • dbt init:初始化项目
  • dbt debug:检查配置与连接
  • dbt run:运行所有模型
  • dbt run --select 模型名:运行单个模型
  • dbt test:执行所有测试
  • dbt seed:导入种子数据
  • dbt docs generate:生成文档
  • dbt docs serve:启动文档服务

相关资源

  • Pypi地址:https://pypi.org/project/dbt-core/
  • Github地址:https://github.com/dbt-labs/dbt-core
  • 官方文档地址:https://docs.getdbt.com/

关注我,每天分享一个实用的Python自动化工具。

Python 数据管道神器:Kedro 从入门到实战,轻松构建可复用数据工程

一、Kedro 库简介

Kedro 是面向生产级数据工程与数据科学的 Python 框架,专注标准化、可复用、可维护的数据管道构建,基于模块化与配置驱动思想,将数据处理流程拆分为节点与管道,支持版本管理、测试、文档自动化。优点是工程化规范、协作友好、适合复杂项目;缺点是轻量场景略显繁琐。采用 Apache 2.0 开源许可。

二、Kedro 安装与环境准备

2.1 环境要求

Kedro 支持 Python 3.8 及以上版本,兼容 Windows、macOS、Linux 系统,可与 pandas、numpy、scikit-learn、PySpark 无缝集成,适合单机与分布式数据工程。

2.2 安装命令

打开终端或命令提示符,执行以下命令完成安装:

pip install kedro

如需使用可视化工具,可安装扩展:

pip install kedro-viz

安装完成后,验证版本:

kedro --version

出现版本号即表示安装成功。

三、Kedro 核心概念与工作流程

3.1 核心概念

  • 项目(Project):Kedro 工程的根目录,统一管理代码、数据、配置、文档。
  • 节点(Node):最小执行单元,对应一个 Python 函数,负责单一数据处理逻辑。
  • 管道(Pipeline):多个节点按依赖关系组合而成的执行流程,自动解析执行顺序。
  • 目录(Catalog):数据入口配置文件,统一管理数据读取与写入,支持多种格式。
  • 参数(Parameters):集中管理配置参数,便于修改与环境切换。
  • 运行(Run):执行整个或部分数据管道,自动处理依赖与数据流转。

3.2 工作原理

Kedro 通过声明式编程定义数据处理逻辑,不直接硬编码读写路径与执行顺序,而是通过 YAML 配置文件声明数据与参数,通过函数定义节点逻辑,自动构建依赖图并按拓扑顺序执行,保证流程可复现、可测试、可扩展。

四、Kedro 项目创建与目录结构

4.1 创建新项目

在终端进入工作目录,执行命令创建 Kedro 项目:

kedro new

按提示输入项目名称,例如 kedro_demo,等待项目生成。

4.2 标准目录结构

kedro_demo/
├── conf/                # 配置文件(catalog、parameters)
│   ├── base/
│   └── local/
├── data/                # 数据目录(原始、中间、模型、输出)
│   ├── 01_raw/
│   ├── 02_intermediate/
│   ├── 03_primary/
│   ├── 04_feature/
│   ├── 05_model_input/
│   ├── 06_models/
│   ├── 07_model_output/
│   └── 08_reporting/
├── docs/                # 项目文档
├── kedro_demo/          # 主源码目录
│   ├── __init__.py
│   ├── __main__.py
│   ├── pipeline_registry.py  # 管道注册
│   └── pipelines/        # 管道实现
├── logs/                # 运行日志
├── tests/               # 单元测试
├── .gitignore
├── pyproject.toml
└── README.md

该结构遵循数据工程最佳实践,从原始数据到报告输出分层管理,避免混乱。

五、基础使用:从零构建第一个数据管道

5.1 编写数据处理函数

进入 kedro_demo/pipelines,创建 demo_pipeline 文件夹,新增 nodes.py

# -*- coding: utf-8 -*-
import pandas as pd

def load_raw_data(file_path: str) -> pd.DataFrame:
    """
    读取原始CSV数据
    """
    df = pd.read_csv(file_path)
    return df

def clean_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    数据清洗:去重、缺失值填充
    """
    df = df.drop_duplicates()
    df = df.fillna(0)
    return df

def calculate_stats(df: pd.DataFrame) -> pd.DataFrame:
    """
    简单统计计算:新增均值列
    """
    df['mean_value'] = df.select_dtypes(include='number').mean(axis=1)
    return df

5.2 构建管道

同目录创建 pipeline.py

# -*- coding: utf-8 -*-
from kedro.pipeline import Pipeline, node
from .nodes import load_raw_data, clean_data, calculate_stats

def create_pipeline(**kwargs) -> Pipeline:
    return Pipeline(
        [
            node(
                func=load_raw_data,
                inputs="params:raw_data_path",
                outputs="raw_data",
                name="load_raw_data_node",
            ),
            node(
                func=clean_data,
                inputs="raw_data",
                outputs="cleaned_data",
                name="clean_data_node",
            ),
            node(
                func=calculate_stats,
                inputs="cleaned_data",
                outputs="stats_data",
                name="calculate_stats_node",
            ),
        ]
    )

Kedro 会根据 inputsoutputs 自动确定执行顺序。

5.3 配置数据目录

conf/base/catalog.yml 中添加:

raw_data:
  type: pandas.CSVDataSet
  filepath: data/01_raw/input.csv

cleaned_data:
  type: pandas.CSVDataSet
  filepath: data/02_intermediate/cleaned.csv

stats_data:
  type: pandas.CSVDataSet
  filepath: data/03_primary/stats.csv

5.4 配置参数

conf/base/parameters.yml 中添加:

raw_data_path: "data/01_raw/input.csv"

5.5 注册管道

打开 pipeline_registry.py,注册管道:

from kedro.framework.pipeline import Pipeline
from kedro_demo.pipelines.demo_pipeline import create_pipeline as demo_pipeline

def register_pipelines() -> dict[str, Pipeline]:
    pipelines = {
        "__default__": demo_pipeline(),
        "demo": demo_pipeline(),
    }
    return pipelines

5.6 准备测试数据

data/01_raw 下创建 input.csv

id,value1,value2
1,10,20
2,,30
3,40,
1,10,20

5.7 运行管道

在项目根目录执行:

kedro run

运行成功后,可在对应目录看到输出文件。

六、进阶使用:参数化、可视化与多环境

6.1 使用动态参数

修改 parameters.yml

raw_data_path: "data/01_raw/input.csv"
fill_value: 0
drop_duplicates: True

更新 nodes.py

def clean_data(df: pd.DataFrame, fill_value: int, drop_duplicates: bool) -> pd.DataFrame:
    if drop_duplicates:
        df = df.drop_duplicates()
    df = df.fillna(fill_value)
    return df

修改管道节点:

node(
    func=clean_data,
    inputs=dict(df="raw_data", fill_value="params:fill_value", drop_duplicates="params:drop_duplicates"),
    outputs="cleaned_data",
    name="clean_data_node",
),

再次运行即可使用新参数。

6.2 管道可视化

执行命令启动可视化服务:

kedro viz run

浏览器自动打开界面,可查看节点依赖、运行状态、数据流向,支持交互式查看。

6.3 多环境切换

Kedro 支持 baselocalprod 等多环境配置,只需在 conf/ 下新建环境文件夹,覆盖对应配置即可,运行时指定环境:

kedro run --env=prod

七、机器学习实战:基于 Kedro 的分类模型 pipeline

7.1 编写机器学习节点

新建 ml_pipeline 文件夹,nodes.py

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def split_data(df: pd.DataFrame, test_size: float, random_state: int):
    X = df.drop('target', axis=1)
    y = df['target']
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    return X_train, X_test, y_train, y_test

def train_model(X_train: pd.DataFrame, y_train: pd.Series):
    model = LogisticRegression(max_iter=1000)
    model.fit(X_train, y_train)
    return model

def evaluate_model(model, X_test: pd.DataFrame, y_test: pd.Series):
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    return {"accuracy": acc}

7.2 构建机器学习管道

pipeline.py

from kedro.pipeline import Pipeline, node
from .nodes import split_data, train_model, evaluate_model

def create_ml_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                split_data,
                inputs=["stats_data", "params:test_size", "params:random_state"],
                outputs=["X_train", "X_test", "y_train", "y_test"],
                name="split_data",
            ),
            node(
                train_model,
                inputs=["X_train", "y_train"],
                outputs="model",
                name="train_model",
            ),
            node(
                evaluate_model,
                inputs=["model", "X_test", "y_test"],
                outputs="metrics",
                name="evaluate_model",
            ),
        ]
    )

7.3 配置与运行

catalog.yml 配置模型与指标输出:

model:
  type: pickle.PickleDataSet
  filepath: data/06_models/model.pkl

metrics:
  type: json.JSONDataSet
  filepath: data/07_model_output/metrics.json

在参数中添加:

test_size: 0.2
random_state: 42

注册管道后运行:

kedro run

可得到训练好的模型与评估结果。

八、Kedro 优势与适用场景

Kedro 解决了数据科学项目代码混乱、难以复现、协作成本高的问题,通过强制工程化规范,让脚本式代码升级为可维护、可测试、可部署的生产级项目。适合团队协作、长期维护、需要上线部署的数据管道与机器学习项目,尤其在数据清洗、特征工程、模型训练、批量预测场景优势明显。

它将数据科学家从路径管理、依赖混乱、环境不一致中解放出来,专注算法与逻辑本身,同时让数据工程与生产环境对接更平滑,支持直接对接 Airflow、Prefect、Kubeflow 等调度工具。

相关资源

  • Pypi地址:https://pypi.org/project/kedro/
  • Github地址:https://github.com/kedro-org/kedro
  • 官方文档地址:https://docs.kedro.org/

关注我,每天分享一个实用的Python自动化工具。

Python 任务编排神器:Luigi 库从入门到实战教程

一、Luigi 库概述

Luigi 是 Spotify 开源的 Python 任务编排与工作流管理库,专注于解决复杂批量任务依赖、执行与监控问题,核心原理是通过定义任务依赖关系自动构建执行拓扑图,按依赖顺序调度任务,支持任务失败重试、断点续跑。优点是依赖管理清晰、适配大数据与批处理场景、代码侵入低、易集成;缺点是无原生分布式调度、WebUI 功能简洁。采用 Apache License 2.0 开源协议,可商用与修改。

二、Luigi 库安装与基础环境配置

2.1 安装 Luigi

Luigi 支持 Python 3.6 及以上版本,使用 pip 即可快速安装,打开命令行执行以下命令:

pip install luigi

安装完成后可通过查看版本验证是否安装成功:

luigi --version

若输出对应版本号,说明安装正常。

2.2 核心基础概念

在使用 Luigi 前,需要先掌握几个核心概念:

  1. Task(任务):所有工作流的最小单元,继承 luigi.Task 类,需重写 requires()run()output() 三个核心方法。
  2. requires():定义当前任务依赖的前置任务,无依赖则可不写或返回空。
  3. run():任务的核心逻辑,编写具体执行代码。
  4. output():定义任务执行完成后的输出目标,通常为文件、数据库标识等,用于判断任务是否已完成。
  5. Target(目标):任务输出的抽象载体,常用 LocalTarget(本地文件)、HiveTargetPostgresTarget 等。
  6. 工作流:多个 Task 通过依赖关系串联形成的完整执行流程。

三、Luigi 基础使用与代码示例

3.1 最简单的单机任务

先从无依赖的基础任务入手,创建一个生成文本文件的任务,直观感受 Luigi 的执行逻辑。

创建文件 luigi_demo_01.py,代码如下:

import luigi

# 定义基础任务,继承 luigi.Task
class CreateFile(luigi.Task):
    # 定义任务输出文件
    def output(self):
        # LocalTarget 表示本地文件目标
        return luigi.LocalTarget('hello_luigi.txt')

    # 任务执行逻辑
    def run(self):
        # self.output().open() 获取输出文件句柄
        with self.output().open('w') as f:
            f.write('Hello Luigi! 这是第一个 Luigi 任务\n')
            f.write('任务执行成功!')

if __name__ == '__main__':
    # 命令行方式启动任务
    luigi.run()

代码说明

  1. 自定义 CreateFile 任务继承 luigi.Task,是 Luigi 任务的标准写法。
  2. output() 方法指定输出为本地文件 hello_luigi.txt,Luigi 会通过该文件判断任务是否完成。
  3. run() 方法内编写业务逻辑,向目标文件写入文本内容。
  4. luigi.run() 允许以命令行参数方式启动任务。

执行命令

python luigi_demo_01.py CreateFile --local-scheduler
  • CreateFile:指定要执行的任务类名。
  • --local-scheduler:使用本地调度器,适合单机测试。

执行成功后,目录下会生成 hello_luigi.txt 文件,且再次执行相同命令时,Luigi 会检测到文件已存在,直接判定任务完成,不再重复执行,这就是 Luigi 的幂等性核心特性。

3.2 带依赖的多任务工作流

实际场景中任务通常存在依赖关系,例如先创建数据文件,再读取文件处理数据。下面实现两个任务的依赖串联。

创建文件 luigi_demo_02.py

import luigi

# 任务1:生成原始数据文件
class GenerateData(luigi.Task):
    def output(self):
        return luigi.LocalTarget('data.txt')

    def run(self):
        with self.output().open('w') as f:
            # 写入 1-10 的数字
            for i in range(1, 11):
                f.write(f'{i}\n')

# 任务2:依赖 GenerateData,计算数字总和
class CalculateSum(luigi.Task):
    # 定义依赖任务
    def requires(self):
        return GenerateData()

    def output(self):
        return luigi.LocalTarget('sum_result.txt')

    def run(self):
        # 读取依赖任务的输出文件
        with self.input().open('r') as f:
            lines = f.readlines()
            # 转换为整数并求和
            total = sum(int(line.strip()) for line in lines if line.strip())

        # 写入计算结果
        with self.output().open('w') as f:
            f.write(f'1到10的总和为:{total}')

if __name__ == '__main__':
    luigi.run()

代码说明

  1. GenerateData 任务生成包含 1-10 数字的 data.txt
  2. CalculateSum 任务通过 requires() 依赖 GenerateData,执行前会先自动运行前置任务。
  3. self.input() 可直接获取依赖任务的输出 Target,无需手动指定路径,解耦且安全。
  4. 任务执行完成后生成 sum_result.txt,存储计算结果。

执行命令

python luigi_demo_02.py CalculateSum --local-scheduler

执行流程:先运行 GenerateData 生成数据文件,再运行 CalculateSum 计算总和,若 data.txt 已存在,则跳过前置任务,直接执行计算任务。

3.3 带参数的动态任务

固定任务无法满足多变需求,Luigi 支持通过 luigi.Parameter() 定义参数,实现任务动态化。

创建文件 luigi_demo_03.py

import luigi

# 带参数的生成数据任务
class GenerateNumData(luigi.Task):
    # 定义参数:数字上限
    max_num = luigi.IntParameter(default=10)
    # 定义参数:输出文件名
    filename = luigi.Parameter(default='num_data.txt')

    def output(self):
        return luigi.LocalTarget(self.filename)

    def run(self):
        with self.output().open('w') as f:
            for i in range(1, self.max_num + 1):
                f.write(f'{i}\n')

# 带参数的求和任务
class CalculateDynamicSum(luigi.Task):
    max_num = luigi.IntParameter(default=10)
    filename = luigi.Parameter(default='num_data.txt')

    def requires(self):
        # 向依赖任务传递参数
        return GenerateNumData(max_num=self.max_num, filename=self.filename)

    def output(self):
        return luigi.LocalTarget('dynamic_sum_result.txt')

    def run(self):
        with self.input().open('r') as f:
            lines = f.readlines()
            total = sum(int(line.strip()) for line in lines if line.strip())

        with self.output().open('w') as f:
            f.write(f'1到{self.max_num}的总和为:{total}')

if __name__ == '__main__':
    luigi.run()

代码说明

  1. 使用 luigi.IntParameter 定义整数参数,luigi.Parameter 定义字符串参数,支持默认值。
  2. 子任务可通过 requires() 向父任务传递参数,保持参数一致性。
  3. 输出路径、数据范围均可通过参数动态调整,提升任务复用性。

执行命令(指定参数)

python luigi_demo_03.py CalculateDynamicSum --max-num 20 --filename my_data.txt --local-scheduler

命令中 --max-num 20 对应任务中的 max_num 参数,会生成 1-20 的数据并计算总和。

3.4 任务失败重试与断点续跑

Luigi 内置任务失败重试机制,无需手动编写异常处理,只需在任务中配置重试次数。

示例代码(添加重试配置):

import luigi
import random

class UnstableTask(luigi.Task):
    # 配置重试次数
    retry_count = luigi.IntParameter(default=3)

    def output(self):
        return luigi.LocalTarget('retry_test.txt')

    def run(self):
        # 模拟随机失败
        if random.random() < 0.7:
            raise Exception('任务随机失败,触发重试')
        with self.output().open('w') as f:
            f.write('任务重试成功!')

if __name__ == '__main__':
    luigi.run()

执行命令

python luigi_demo_04.py UnstableTask --local-scheduler

Luigi 会自动捕获任务异常,按 retry_count 配置重试,直到执行成功或耗尽重试次数,适合网络请求、数据库操作等不稳定场景。

四、Luigi 进阶使用:Web 监控与多任务调度

4.1 启动 Luigi 中央调度器与 WebUI

单机调度器仅适合测试,生产环境推荐使用 Luigi 中央调度器,自带 WebUI 可实时查看任务状态、依赖图、执行日志。

1. 启动中央调度器

luigid --port 8082

默认端口 8082,启动后访问:http://localhost:8082 即可打开 Web 监控界面。

2. 提交任务到中央调度器
执行任务时去掉 --local-scheduler,自动连接中央调度器:

python luigi_demo_02.py CalculateSum

在 WebUI 中可查看任务执行进度、失败任务、依赖关系,支持手动终止任务。

4.2 多任务并行执行

Luigi 支持多进程并行执行无依赖的任务,提升执行效率,通过命令行参数指定并行进程数:

python luigi_multi_task.py MainTask --workers 4

--workers 4 表示使用 4 个进程并行执行,无依赖的任务会同时运行,有依赖的任务按顺序执行。

4.3 封装为可复用任务模块

实际项目中会将任务按功能拆分到不同模块,标准目录结构如下:

luigi_project/
├── tasks/
│   ├── __init__.py
│   ├── data_task.py    # 数据生成、清洗任务
│   ├── compute_task.py # 计算、分析任务
│   └── output_task.py  # 结果输出任务
├── config/
│   └── luigi.cfg       # Luigi 配置文件
└── main.py             # 任务入口

luigi.cfg 可配置默认调度器、重试次数、日志路径等,简化命令行参数:

[core]
default-scheduler-host = localhost
default-scheduler-port = 8082
retry-attempts = 3

五、真实场景实战案例:数据清洗与统计分析工作流

5.1 案例需求

模拟企业日常数据处理流程,完成以下任务:

  1. 生成原始 CSV 数据(包含姓名、年龄、城市、销售额)。
  2. 清洗数据:去除空值、过滤异常年龄、标准化城市名称。
  3. 统计分析:按城市分组计算总销售额、平均年龄。
  4. 输出统计结果到文本文件。

5.2 完整代码实现

创建文件 data_workflow.py

import luigi
import csv
import os

# 任务1:生成原始 CSV 数据
class GenerateRawData(luigi.Task):
    def output(self):
        return luigi.LocalTarget('raw_data.csv')

    def run(self):
        # 模拟业务数据
        data = [
            ['姓名', '年龄', '城市', '销售额'],
            ['张三', 25, '北京', 5000],
            ['李四', 32, '上海', 8000],
            ['王五', '', '广州', 6000],
            ['赵六', 40, '深圳', 12000],
            ['钱七', 150, '北京', 3000],
            ['孙八', 28, '上海', 9000],
            ['周九', None, '广州', 4500]
        ]
        with self.output().open('w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerows(data)

# 任务2:清洗数据
class CleanData(luigi.Task):
    def requires(self):
        return GenerateRawData()

    def output(self):
        return luigi.LocalTarget('cleaned_data.csv')

    def run(self):
        with self.input().open('r', encoding='utf-8') as f:
            reader = csv.reader(f)
            header = next(reader)
            cleaned_data = [header]

            for row in reader:
                # 跳过空值行
                if not all(row):
                    continue
                name, age, city, sales = row
                # 过滤异常年龄
                try:
                    age = int(age)
                    sales = int(sales)
                except:
                    continue
                if 18 <= age <= 60:
                    cleaned_data.append([name, age, city, sales])

        with self.output().open('w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerows(cleaned_data)

# 任务3:按城市统计销售额
class CitySalesStat(luigi.Task):
    def requires(self):
        return CleanData()

    def output(self):
        return luigi.LocalTarget('city_sales_report.txt')

    def run(self):
        city_stat = {}
        with self.input().open('r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                city = row['城市']
                sales = int(row['销售额'])
                age = int(row['年龄'])

                if city not in city_stat:
                    city_stat[city] = {'total_sales': 0, 'total_age': 0, 'count': 0}
                city_stat[city]['total_sales'] += sales
                city_stat[city]['total_age'] += age
                city_stat[city]['count'] += 1

        # 生成报告
        with self.output().open('w', encoding='utf-8') as f:
            f.write('城市销售统计报告\n')
            f.write('='*30 + '\n')
            for city, data in city_stat.items():
                avg_age = data['total_age'] / data['count']
                f.write(f'城市:{city}\n')
                f.write(f'总销售额:{data["total_sales"]}\n')
                f.write(f'平均年龄:{avg_age:.2f}\n')
                f.write('-'*30 + '\n')

if __name__ == '__main__':
    luigi.run()

5.3 案例执行与效果

执行命令

python data_workflow.py CitySalesStat --local-scheduler

执行流程:

  1. 生成 raw_data.csv 原始数据。
  2. 清洗空值、异常年龄,生成 cleaned_data.csv
  3. 按城市统计数据,生成 city_sales_report.txt 报告。

该案例完整还原了企业数据处理流程,体现了 Luigi 依赖管理、断点续跑、任务复用的核心价值,可直接扩展对接数据库、Hive、Spark 等大数据组件。

相关资源

  • Pypi地址:https://pypi.org/project/luigi/
  • Github地址:https://github.com/spotify/luigi
  • 官方文档地址:https://luigi.readthedocs.io/

关注我,每天分享一个实用的Python自动化工具。

Python 任务队列利器:rq 从入门到实战完全指南

一、rq 库概述

rq 全称 Redis Queue,是一款基于 Redis 开发的轻量级 Python 任务队列库,专注于处理异步任务与后台任务,核心原理是将任务函数与参数存入 Redis,由独立工作进程异步拉取执行,规避同步任务阻塞主程序的问题。该库采用 MIT 开源许可,优点是轻量简洁、易上手、无额外依赖、适配小型到中型项目;缺点是不支持复杂任务调度、集群能力较弱,更适合轻量异步场景。

二、rq 环境安装与基础配置

2.1 环境依赖准备

rq 强依赖 Redis 数据库,使用前需先完成 Redis 安装与启动,Windows、macOS、Linux 均有对应安装方式,安装后通过 redis-server 启动服务,默认端口 6379,无密码时本地可直接连接。

2.2 rq 库安装

rq 仅需通过 pip 即可安装,命令简洁无复杂配置,适合新手快速部署:

pip install rq

安装完成后,可通过导入 rq 验证是否成功,无报错则说明安装正常:

import rq
print(rq.__version__)

2.3 Redis 连接配置

rq 默认连接本地 Redis,如需自定义主机、端口、密码、数据库,可创建 Redis 连接对象,适配不同部署环境:

from redis import Redis
from rq import Queue

# 默认本地连接
redis_conn = Redis(host='localhost', port=6379, db=0)

# 带密码的远程连接
# redis_conn = Redis(host='xxx.xxx.xxx.xxx', port=6379, password='your_password', db=0)

# 初始化任务队列
task_queue = Queue(connection=redis_conn)

这段代码的作用是建立 Python 程序与 Redis 的通信通道,所有任务都会通过该连接存入 Redis 队列,是 rq 运行的基础配置。

三、rq 核心使用方式与基础代码示例

3.1 定义可被执行的任务函数

rq 的任务本质是普通 Python 函数,需满足可被序列化无复杂闭包的条件,先定义简单任务用于测试:

# tasks.py 任务文件,单独存放便于管理
import time

def simple_task(name):
    """简单异步任务:模拟耗时操作"""
    time.sleep(2)
    return f"Hello {name}, 异步任务执行完成!"

def calculate_sum(a, b):
    """计算两数之和的任务"""
    time.sleep(1)
    return f"{a} + {b} = {a + b}"

将任务单独放在 tasks.py 文件,是因为 rq 工作进程需要通过模块路径找到函数,分散存放会导致任务无法执行。

3.2 向队列添加异步任务

创建主程序文件,将任务函数加入队列,实现异步提交,不阻塞主程序运行:

# main.py 主程序文件
from redis import Redis
from rq import Queue
from tasks import simple_task, calculate_sum

# 连接 Redis
redis_conn = Redis(host='localhost', port=6379, db=0)
queue = Queue(connection=redis_conn)

# 提交任务到队列,非阻塞执行
job1 = queue.enqueue(simple_task, "Python开发者")
job2 = queue.enqueue(calculate_sum, 10, 20)

# 输出任务ID,用于后续查询
print(f"任务1 ID: {job1.id}")
print(f"任务2 ID: {job2.id}")
print("主程序继续执行,无需等待任务完成")

代码说明:queue.enqueue() 是核心提交方法,第一个参数为任务函数,后续为函数参数,调用后立即返回任务对象,主程序不会等待任务执行,实现异步解耦。

3.3 启动 rq 工作进程执行任务

任务提交到 Redis 后,需要启动工作进程消费队列任务,打开新的命令行窗口,进入项目目录,执行:

rq worker

执行后工作进程会持续监听 Redis 队列,一旦有新任务就立即执行,输出任务执行日志,执行完成后返回结果。

3.4 查看任务执行状态与结果

rq 提供丰富的任务状态查询方法,可在主程序中获取任务是否完成、结果、失败原因:

# result_check.py 任务结果查询
from redis import Redis
from rq import Queue
from rq.job import Job

redis_conn = Redis(host='localhost', port=6379, db=0)
queue = Queue(connection=redis_conn)

# 通过任务ID获取任务
job = Job.fetch('你的任务ID', connection=redis_conn)

# 查询任务状态
print(f"任务是否执行完成: {job.is_finished}")
print(f"任务是否执行失败: {job.is_failed}")
print(f"任务执行结果: {job.result}")
print(f"任务执行状态: {job.get_status()}")

任务状态分为 queued(排队中)、started(执行中)、finished(已完成)、failed(执行失败),可根据状态做后续业务处理。

四、rq 进阶功能使用

4.1 多队列管理

rq 支持创建多个队列,分类处理不同类型任务,避免任务阻塞:

# multi_queue.py
from redis import Redis
from rq import Queue
from tasks import simple_task, calculate_sum

redis_conn = Redis(host='localhost', port=6379, db=0)

# 创建不同优先级/类型的队列
high_queue = Queue('high', connection=redis_conn)
low_queue = Queue('low', connection=redis_conn)

# 向指定队列提交任务
high_queue.enqueue(calculate_sum, 100, 200)
low_queue.enqueue(simple_task, "普通用户")

启动工作进程时可指定监听队列:

rq worker high low

4.2 任务延迟执行

rq 支持设置任务延迟执行时间,单位为秒,满足定时异步任务需求:

# 延迟5秒执行任务
job = queue.enqueue(simple_task, "延迟任务", delay=5)

代码说明:delay 参数指定任务提交后,等待指定秒数再被工作进程执行,适用于延迟通知、延迟审核等场景。

4.3 任务失败重试与异常处理

为任务设置失败重试次数、重试间隔,提升任务执行稳定性:

# retry_task.py 带重试的任务
from rq import Retry

# 最多重试3次,每次间隔2秒
retry_strategy = Retry(max=3, interval=2)
job = queue.enqueue(simple_task, "重试任务", retry=retry_strategy)

同时可在任务函数中捕获异常,记录失败原因:

def error_task():
    try:
        # 可能出错的逻辑
        1 / 0
    except Exception as e:
        print(f"任务执行异常: {str(e)}")
        raise  # 抛出异常让rq标记任务失败

4.4 清空队列与删除任务

运维场景中可清空队列、删除指定任务,避免无效任务堆积:

# 清空当前队列所有任务
queue.empty()

# 删除指定任务
job.delete()

五、rq 实际业务场景案例

5.1 案例一:异步发送邮件

实际项目中,发送邮件是耗时操作,用 rq 异步处理可提升接口响应速度:

# email_task.py
import time
import smtplib
from email.mime.text import MIMEText

def send_async_email(to_email, content):
    """异步发送邮件任务"""
    try:
        # 模拟邮件发送(实际项目替换为真实邮件配置)
        time.sleep(3)
        msg = MIMEText(content, 'plain', 'utf-8')
        msg['From'] = '[email protected]'
        msg['To'] = to_email
        msg['Subject'] = '异步邮件通知'

        # 模拟发送成功
        print(f"邮件已发送至: {to_email}")
        return True
    except Exception as e:
        print(f"邮件发送失败: {str(e)}")
        return False

提交邮件任务:

# 提交异步邮件任务,不阻塞主程序
queue.enqueue(send_async_email, "[email protected]", "您的订单已支付成功")

5.2 案例二:异步生成报表

大数据量报表生成耗时较长,通过 rq 后台执行,生成完成后通知用户:

# report_task.py
import time
import pandas as pd

def generate_excel_report(data_list, save_path):
    """异步生成Excel报表"""
    time.sleep(5)  # 模拟数据处理耗时
    df = pd.DataFrame(data_list)
    df.to_excel(save_path, index=False)
    return f"报表生成完成,保存路径: {save_path}"

提交报表任务:

data = [{'name': '张三', 'score': 90}, {'name': '李四', 'score': 85}]
queue.enqueue(generate_excel_report, data, "./report.xlsx")

5.3 完整项目运行流程

  1. 启动 Redis 服务:redis-server
  2. 定义任务函数到 tasks.py
  3. 主程序提交任务到队列
  4. 启动工作进程:rq worker
  5. 查看任务执行状态与结果
  6. 业务逻辑根据任务结果做后续处理

该流程可直接应用于 Web 项目、自动化脚本、数据分析工具中,解决同步任务阻塞问题。

六、相关资源

  • Pypi地址:https://pypi.org/project/rq/
  • Github地址:https://github.com/rq/rq
  • 官方文档地址:https://python-rq.org/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:joblib 高效序列化与并行计算详解

一、joblib 库概述

joblib 是 Python 生态中轻量且高效的工具库,核心用于对象序列化、模型持久化、并行计算,尤其适配大数据对象与机器学习场景。其原理是优化 pickle 序列化逻辑,支持大数组分块存储,借助多进程实现并行加速。优点是轻量无依赖、读写速度快、内存占用低、并行接口简洁;缺点是不适合跨语言使用,复杂自定义对象兼容性有限。该库采用 BSD 开源许可证,可自由商用与修改。

二、joblib 安装方法

joblib 安装无需复杂环境配置,支持 pip 与 conda 两种安装方式,适配所有主流 Python 版本与操作系统。

1. pip 安装(推荐)

打开命令行工具,执行以下命令即可完成安装:

pip install joblib

2. conda 安装

若使用 Anaconda 或 Miniconda 环境,可执行:

conda install -c anaconda joblib

安装完成后,在 Python 脚本中直接导入即可使用,无额外配置步骤:

import joblib
print(joblib.__version__)

执行后输出版本号,即代表安装成功。

三、joblib 核心功能与代码实例

joblib 核心功能分为两大模块:对象持久化(dump/load)并行计算(Parallel/delayed),同时提供内存缓存、压缩存储等辅助功能,覆盖日常开发与机器学习全场景。

3.1 基础对象持久化:dump 与 load

序列化与反序列化是 joblib 最基础的功能,替代原生 pickle 模块,针对 numpy 数组、pandas 数据框、机器学习模型做了深度优化,读写速度远超 pickle,且支持大文件分块存储。

3.1.1 基础数据类型存储与读取

演示存储列表、字典、数值等基础数据类型:

import joblib

# 定义测试数据
data_list = [1, 2, 3, 4, 5]
data_dict = {"name": "joblib教程", "version": 1.3, "function": ["序列化", "并行计算"]}
number = 100

# 序列化保存数据
joblib.dump(data_list, "data_list.pkl")
joblib.dump(data_dict, "data_dict.pkl")
joblib.dump(number, "number.pkl")

# 反序列化读取数据
load_list = joblib.load("data_list.pkl")
load_dict = joblib.load("data_dict.pkl")
load_number = joblib.load("number.pkl")

print("读取列表:", load_list)
print("读取字典:", load_dict)
print("读取数值:", load_number)

代码说明joblib.dump(对象, 保存路径) 用于将 Python 对象写入文件,joblib.load(文件路径) 用于读取文件还原对象,操作逻辑与 pickle 一致,但底层优化更适合大数据对象。

3.1.2 压缩存储

joblib 支持直接存储压缩文件,节省磁盘空间,支持 gzip、bz2、xz 三种压缩格式,只需在文件名后缀标注即可:

import joblib
import numpy as np

# 生成大型 numpy 数组
big_array = np.random.rand(10000, 1000)

# 压缩存储为 gzip 格式
joblib.dump(big_array, "big_array.gz", compress=("gzip", 3))
# compress 参数可指定压缩等级,范围 0-9,数值越大压缩率越高,速度越慢

# 读取压缩文件
load_array = joblib.load("big_array.gz")
print("压缩后数组形状:", load_array.shape)
print("数组占用内存:", load_array.nbytes / 1024 / 1024, "MB")

代码说明:大数据对象直接存储会占用大量磁盘空间,使用压缩存储可减少 50%-90% 空间,joblib 读取时会自动解压,无需手动处理。

3.1.3 多对象合并存储

joblib 支持一次性存储多个对象,读取时按顺序还原,适合批量保存相关数据:

import joblib

# 定义多个对象
obj1 = [1, 2, 3]
obj2 = {"a": 1, "b": 2}
obj3 = "joblib多对象存储"

# 合并存储
joblib.dump([obj1, obj2, obj3], "multi_obj.pkl")

# 批量读取
load_obj1, load_obj2, load_obj3 = joblib.load("multi_obj.pkl")

print(load_obj1)
print(load_obj2)
print(load_obj3)

代码说明:将多个对象放入列表中统一存储,读取时按存储顺序解包,简化多文件管理逻辑。

3.2 机器学习模型持久化

joblib 诞生初衷就是为了解决机器学习模型存储问题,是 scikit-learn 官方推荐的模型保存工具,完美适配决策树、随机森林、逻辑回归、SVM 等模型。

import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# 加载数据集并划分训练集测试集
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)

# 训练随机森林模型
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)

# 保存训练好的模型
joblib.dump(model, "iris_rf_model.pkl")

# 加载模型并预测
load_model = joblib.load("iris_rf_model.pkl")
predict = load_model.predict(X_test)

print("预测结果:", predict[:5])
print("模型准确率:", load_model.score(X_test, y_test))

代码说明:机器学习模型训练耗时较长,使用 joblib 保存后,下次直接加载即可预测,无需重新训练,极大提升开发效率。

3.3 并行计算:Parallel 与 delayed

原生 Python 多进程代码繁琐,joblib 封装了 Paralleldelayed 装饰器,一行代码实现多进程并行,大幅提升循环任务执行速度。

3.3.1 基础并行任务

import time
import joblib
from joblib import Parallel, delayed

# 定义单任务函数
def task_func(x):
    time.sleep(1)
    return x * x

# 串行执行
start = time.time()
serial_result = [task_func(i) for i in range(8)]
print("串行耗时:", time.time() - start, "秒")

# 并行执行(4进程)
start = time.time()
parallel_result = Parallel(n_jobs=4)(delayed(task_func)(i) for i in range(8))
print("并行耗时:", time.time() - start, "秒")
print("并行结果:", parallel_result)

代码说明n_jobs 指定进程数,设置为 -1 表示使用 CPU 全部核心,delayed 用于包装需要并行执行的函数,并行执行时间随进程数增加显著缩短。

3.3.2 带参数的并行任务

from joblib import Parallel, delayed

def calc_func(a, b, power):
    return (a + b) ** power

# 多参数并行执行
result = Parallel(n_jobs=2)(delayed(calc_func)(i, i+1, 2) for i in range(5))
print("多参数并行结果:", result)

代码说明delayed 可传递任意数量参数,适配复杂业务函数,无需修改函数本身逻辑。

3.3.3 并行进度显示

处理大量任务时,可通过 verbose 参数显示执行进度:

from joblib import Parallel, delayed
import time

def long_task(x):
    time.sleep(0.5)
    return x

# 显示进度
result = Parallel(n_jobs=3, verbose=10)(delayed(long_task)(i) for i in range(20))

代码说明verbose 数值越大,进度输出越详细,方便监控长时间并行任务的执行状态。

3.4 内存缓存:Memory

joblib 提供内存缓存功能,缓存函数执行结果,重复调用时直接读取缓存,避免重复计算,适合耗时较长的函数。

import joblib
import time

# 创建缓存目录
memory = joblib.Memory(location="cache_dir", verbose=0)

# 装饰器缓存函数
@memory.cache
def slow_calculate(n):
    time.sleep(2)
    return sum(range(n+1))

# 第一次执行:计算并缓存
start = time.time()
print(slow_calculate(10000))
print("第一次执行耗时:", time.time() - start, "秒")

# 第二次执行:直接读取缓存
start = time.time()
print(slow_calculate(10000))
print("第二次执行耗时:", time.time() - start, "秒")

代码说明:函数参数不变时,直接返回缓存结果,参数变化时重新计算,自动管理缓存文件,无需手动清理。

四、实际综合案例:机器学习模型训练与批量预测

结合 joblib 序列化、并行计算功能,实现完整的机器学习模型保存、加载、批量预测流程。

import joblib
import numpy as np
from sklearn.svm import SVC
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed

# 1. 加载数据并训练模型
digits = load_digits()
X, y = digits.data, digits.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)

model = SVC()
model.fit(X_train, y_train)

# 2. 保存模型
joblib.dump(model, "digits_svc_model.pkl")
print("模型保存成功")

# 3. 加载模型
load_model = joblib.load("digits_svc_model.pkl")

# 4. 并行批量预测
def predict_single(idx):
    sample = X_test[idx:idx+1]
    return load_model.predict(sample)[0]

# 并行预测前100个样本
predict_result = Parallel(n_jobs=-1)(delayed(predict_single)(i) for i in range(100))
print("批量预测结果:", predict_result[:10])

# 5. 保存预测结果
joblib.dump(predict_result, "predict_result.pkl")
print("预测结果保存成功")

案例说明:该案例覆盖 joblib 三大核心功能,模型持久化避免重复训练,并行预测提升推理速度,结果序列化方便后续分析,是工业级项目常用开发模式。

相关资源

  • Pypi地址:https://pypi.org/project/joblib/
  • Github地址:https://github.com/joblib/joblib
  • 官方文档地址:https://joblib.readthedocs.io/

关注我,每天分享一个实用的Python自动化工具。