Python实用工具:python-magic 零基础入门教程——文件类型检测与分析实战

一、python-magic 库核心概述

1.1 核心用途

python-magic 是一款基于 libmagic 库封装的 Python 工具库,其核心功能是通过文件内容而非扩展名来识别文件的真实类型。在日常开发中,我们经常会遇到文件扩展名被篡改、后缀名丢失的场景,比如下载的压缩包被恶意修改为 .txt 后缀,或者老旧文件的扩展名损坏,此时 python-magic 就能通过解析文件的二进制特征码,准确判断出文件的真实格式,广泛应用于文件校验、爬虫数据处理、系统安全检测等领域。

1.2 工作原理

python-magic 的底层依赖于 Unix/Linux 系统中的 libmagic 库(Windows 系统需手动安装对应依赖),该库内置了一个魔法数字数据库,这个数据库中存储了各种文件格式对应的二进制特征值。当 python-magic 处理文件时,会读取文件头部的若干字节数据,将其与魔法数字数据库中的特征值进行匹配,从而判定文件的真实类型。整个过程无需解析文件完整内容,因此执行效率高,且不受文件扩展名的干扰。

1.3 优缺点分析

优点

  • 识别精度高:基于文件内容的检测方式,能够绕过扩展名伪装,识别出文件的真实格式。
  • 支持格式广泛:涵盖了文档、图片、音频、视频、压缩包等数千种常见文件格式。
  • 跨平台兼容:支持 Windows、macOS、Linux 等主流操作系统(需注意依赖库的安装差异)。
  • 轻量级易用:API 设计简洁,几行代码即可完成文件类型检测,学习成本低。

缺点

  • 依赖外部库:Windows 系统下需要手动安装 libmagic 依赖包,相比纯 Python 库安装步骤稍复杂。
  • 对部分小众格式支持有限:对于一些冷门的自定义文件格式,可能无法匹配到对应的魔法数字,导致识别失败。
  • 无法解析文件内容细节:该库仅能判断文件类型,不能提取文件的具体内容信息,如图片分辨率、文档字数等。

1.4 License 类型

python-magic 采用的是 MIT License,这是一种宽松的开源许可证。用户可以自由地使用、复制、修改、分发该库的源代码,无论是个人项目还是商业项目,都无需支付任何费用,只需在分发的软件中保留原作者的版权声明即可。

二、python-magic 安装步骤

2.1 系统依赖准备

python-magic 依赖系统底层的 libmagic 库,不同操作系统的安装方式有所不同,具体步骤如下:

  1. Linux 系统
    对于 Debian/Ubuntu 系列发行版,执行以下命令安装:
    bash sudo apt-get update sudo apt-get install libmagic1
    对于 CentOS/RHEL 系列发行版,执行以下命令安装:
    bash sudo yum install file-devel
  2. macOS 系统
    使用 Homebrew 包管理器安装,执行命令:
    bash brew install libmagic
  3. Windows 系统 Windows 系统没有默认的包管理器,需要手动下载 libmagic 依赖文件:
    • 访问 GnuWin32 网站,下载 file-5.39-bin.zipfile-5.39-dep.zip 两个压缩包。
    • 解压两个压缩包,将 bin 目录下的 libmagic-1.dllmagic.exe 文件复制到 Python 的安装目录下的 Scripts 文件夹中。
    • share 目录下的 magic 文件夹复制到 Python 安装目录的根目录下。

2.2 Python 库安装

在完成系统依赖安装后,通过 pip 命令即可安装 python-magic 库,执行以下命令:

pip install python-magic

安装完成后,可以在 Python 交互环境中执行以下代码验证是否安装成功:

import magic
print(magic.__version__)

如果没有报错,并且输出了对应的版本号(如 0.4.27),则说明安装成功。

三、python-magic 核心 API 与使用实例

python-magic 提供了两种核心使用方式:一种是直接调用函数式 API,另一种是创建 Magic 类实例进行自定义配置。下面我们分别介绍这两种方式的使用方法,并结合实例代码进行演示。

3.1 函数式 API 快速使用

函数式 API 封装了最常用的文件类型检测功能,适合快速开发场景,核心函数包括 detect_from_filenamedetect_from_contentdetect_from_fobj

3.1.1 detect_from_filename:通过文件名检测文件类型

该函数接收一个文件路径作为参数,返回一个包含文件类型信息的字典,字典中包含 mime_type(MIME 类型)和 encoding(编码格式,仅文本文件有该字段)两个键。

实例代码

import magic

# 定义测试文件路径(替换为你自己的文件路径)
test_file_path = "test.jpg"

# 检测文件类型
result = magic.detect_from_filename(test_file_path)

# 输出检测结果
print(f"文件 MIME 类型: {result.mime_type}")
print(f"文件编码格式: {result.encoding}")

代码说明

  • 首先导入 magic 库,然后定义需要检测的文件路径。
  • 调用 detect_from_filename 函数,传入文件路径,该函数会自动读取文件内容并进行类型检测。
  • 打印返回结果中的 MIME 类型和编码格式。对于图片文件,encoding 字段通常为 None;对于文本文件(如 .txt),会返回具体的编码格式(如 utf-8gbk 等)。

运行结果示例

文件 MIME 类型: image/jpeg
文件编码格式: None

3.1.2 detect_from_content:通过文件内容检测文件类型

在某些场景下,我们可能没有完整的文件,只有文件的二进制内容(如网络请求中获取的文件流),此时可以使用 detect_from_content 函数,直接传入二进制数据进行检测。

实例代码

import magic

# 读取文件的二进制内容
with open("test.txt", "rb") as f:
    file_content = f.read()

# 通过二进制内容检测文件类型
result = magic.detect_from_content(file_content)

# 输出检测结果
print(f"文件 MIME 类型: {result.mime_type}")
print(f"文件编码格式: {result.encoding}")

代码说明

  • 使用 rb 模式打开文本文件,读取其二进制内容并存储到 file_content 变量中。
  • 调用 detect_from_content 函数,传入二进制内容,函数会根据内容特征判断文件类型。
  • 打印检测结果,对于 UTF-8 编码的文本文件,运行结果会显示 mime_typetext/plainencodingutf-8

运行结果示例

文件 MIME 类型: text/plain
文件编码格式: utf-8

3.1.3 detect_from_fobj:通过文件对象检测文件类型

如果已经打开了一个文件对象,不需要再次读取文件内容,直接使用 detect_from_fobj 函数传入文件对象即可完成检测,这种方式适合处理大文件,避免重复读取数据。

实例代码

import magic

# 打开文件对象
with open("test.zip", "rb") as f:
    # 通过文件对象检测文件类型
    result = magic.detect_from_fobj(f)

# 输出检测结果
print(f"文件 MIME 类型: {result.mime_type}")
print(f"文件编码格式: {result.encoding}")

代码说明

  • 使用 rb 模式打开压缩包文件,得到文件对象 f
  • 调用 detect_from_fobj 函数,传入文件对象,函数会从文件头部读取特征字节进行检测。
  • 打印检测结果,压缩包文件的 MIME 类型通常为 application/zip,编码格式为 None

运行结果示例

文件 MIME 类型: application/zip
文件编码格式: None

3.2 Magic 类实例化使用

函数式 API 虽然方便,但可配置性较低。如果需要自定义检测规则(如只检测 MIME 类型、显示详细的文件描述信息等),可以通过实例化 Magic 类来实现。Magic 类的构造函数支持多个参数,常用参数如下:

  • mime:布尔值,设置为 True 时,仅返回 MIME 类型;默认为 False,返回详细的文件描述信息。
  • mime_encoding:布尔值,设置为 True 时,返回 MIME 类型和编码格式;默认为 False
  • keep_going:布尔值,设置为 True 时,会输出所有匹配到的文件类型信息;默认为 False,仅输出第一个匹配结果。
  • uncompress:布尔值,设置为 True 时,会自动解压压缩文件后再进行检测;默认为 False

3.2.1 自定义检测 MIME 类型

通过设置 mime=True,可以让 Magic 实例仅返回文件的 MIME 类型,适合需要标准化文件类型标识的场景。

实例代码

import magic

# 实例化 Magic 类,指定仅返回 MIME 类型
mime_magic = magic.Magic(mime=True)

# 检测图片文件的 MIME 类型
image_type = mime_magic.from_file("test.png")
print(f"PNG 图片 MIME 类型: {image_type}")

# 检测音频文件的 MIME 类型
audio_type = mime_magic.from_file("test.mp3")
print(f"MP3 音频 MIME 类型: {audio_type}")

代码说明

  • 实例化 Magic 类时传入 mime=True 参数,创建一个专门用于检测 MIME 类型的实例 mime_magic
  • 调用实例的 from_file 方法,传入文件路径,分别检测 PNG 图片和 MP3 音频文件的 MIME 类型。
  • 打印检测结果,PNG 图片的 MIME 类型为 image/png,MP3 音频的 MIME 类型为 audio/mpeg

运行结果示例

PNG 图片 MIME 类型: image/png
MP3 音频 MIME 类型: audio/mpeg

3.2.2 获取文件详细描述信息

默认情况下,Magic 实例会返回文件的详细描述信息,包括文件格式、版本等内容,适合需要向用户展示直观文件类型的场景。

实例代码

import magic

# 实例化 Magic 类,获取详细文件描述
detail_magic = magic.Magic()

# 检测不同类型文件的详细信息
txt_detail = detail_magic.from_file("test.txt")
pdf_detail = detail_magic.from_file("test.pdf")
exe_detail = detail_magic.from_file("test.exe")

# 输出详细信息
print(f"文本文件详细信息: {txt_detail}")
print(f"PDF 文件详细信息: {pdf_detail}")
print(f"EXE 程序详细信息: {exe_detail}")

代码说明

  • 实例化 Magic 类时不传入任何参数,创建的 detail_magic 实例会返回详细的文件描述信息。
  • 分别调用 from_file 方法检测文本文件、PDF 文件和 Windows 可执行文件的详细信息。
  • 打印检测结果,文本文件会显示编码格式和文件类型,PDF 文件会显示 PDF document 相关描述,EXE 文件会显示 PE32 executable 相关信息。

运行结果示例

文本文件详细信息: UTF-8 Unicode text
PDF 文件详细信息: PDF document, version 1.5
EXE 程序详细信息: PE32 executable (GUI) Intel 80386, for MS Windows

3.2.3 同时获取 MIME 类型和编码格式

通过设置 mime_encoding=True,可以让 Magic 实例同时返回文件的 MIME 类型和编码格式,这种方式比函数式 API 中的 detect_from_filename 更灵活,支持自定义其他参数。

实例代码

import magic

# 实例化 Magic 类,同时获取 MIME 类型和编码格式
mime_enc_magic = magic.Magic(mime_encoding=True)

# 检测文本文件和 CSV 文件
result1 = mime_enc_magic.from_file("test.txt")
result2 = mime_enc_magic.from_file("test.csv")

# 输出结果
print(f"文本文件 MIME 及编码: {result1}")
print(f"CSV 文件 MIME 及编码: {result2}")

代码说明

  • 实例化 Magic 类时传入 mime_encoding=True 参数,创建的 mime_enc_magic 实例会返回 MIME 类型; 编码格式 的组合字符串。
  • 分别检测文本文件和 CSV 文件,CSV 文件本质上也是文本文件,但其 MIME 类型通常为 text/csv
  • 打印检测结果,UTF-8 编码的文本文件和 CSV 文件会分别显示对应的 MIME 类型和编码格式。

运行结果示例

文本文件 MIME 及编码: text/plain; charset=utf-8
CSV 文件 MIME 及编码: text/csv; charset=utf-8

3.3 处理特殊场景文件

3.3.1 检测伪装扩展名的文件

在实际开发中,经常会遇到文件扩展名被篡改的情况,比如将 .exe 病毒文件伪装成 .jpg 图片文件,此时使用 python-magic 可以轻松识别出文件的真实类型。

实例代码

import magic

# 假设存在一个伪装为图片的可执行文件 fake.jpg(实际是 exe 文件)
fake_file_path = "fake.jpg"

# 使用 Magic 类检测真实类型
real_magic = magic.Magic()
real_type = real_magic.from_file(fake_file_path)

# 使用 MIME 类型检测
mime_magic = magic.Magic(mime=True)
real_mime = mime_magic.from_file(fake_file_path)

# 输出检测结果
print(f"文件伪装扩展名: .jpg")
print(f"文件真实类型描述: {real_type}")
print(f"文件真实 MIME 类型: {real_mime}")

代码说明

  • 准备一个伪装扩展名的文件 fake.jpg,其实际是 Windows 可执行文件。
  • 分别使用返回详细描述的 real_magic 和返回 MIME 类型的 mime_magic 实例检测文件。
  • 打印检测结果,可以看到文件的真实类型是 PE32 可执行文件,MIME 类型为 application/x-dosexec,从而识破文件的伪装。

运行结果示例

文件伪装扩展名: .jpg
文件真实类型描述: PE32 executable (GUI) Intel 80386, for MS Windows
文件真实 MIME 类型: application/x-dosexec

3.3.2 检测压缩包内的文件类型

通过设置 uncompress=TrueMagic 实例可以自动解压压缩文件(如 .gz.bz2 等格式),并检测压缩包内文件的真实类型,适合处理压缩文件的场景。

实例代码

import magic

# 实例化 Magic 类,开启自动解压功能
uncompress_magic = magic.Magic(uncompress=True)

# 检测 gz 压缩包内的文件类型
gz_file_detail = uncompress_magic.from_file("test.txt.gz")
print(f"gz 压缩包内文件类型: {gz_file_detail}")

# 检测 bz2 压缩包内的文件类型
bz2_file_detail = uncompress_magic.from_file("test.csv.bz2")
print(f"bz2 压缩包内文件类型: {bz2_file_detail}")

代码说明

  • 实例化 Magic 类时传入 uncompress=True 参数,开启自动解压功能。
  • 分别检测 .gz.bz2 格式的压缩包文件,uncompress_magic 会先解压压缩包,再检测内部文件的类型。
  • 打印检测结果,.txt.gz 压缩包内的文件会显示为 UTF-8 编码的文本文件,.csv.bz2 压缩包内的文件会显示为 CSV 文本文件。

运行结果示例

gz 压缩包内文件类型: UTF-8 Unicode text (gzip compressed data, was "test.txt", last modified: ...)
bz2 压缩包内文件类型: CSV text (bzip2 compressed data, block size = 900k)

四、python-magic 实际应用案例

4.1 批量检测文件夹内文件类型

在数据处理场景中,我们经常需要对一个文件夹内的所有文件进行类型检测,筛选出特定类型的文件。下面的案例实现了批量检测文件夹内所有文件的类型,并将结果保存到 CSV 文件中。

实例代码

import os
import csv
import magic

def batch_detect_file_type(folder_path, output_csv):
    """
    批量检测文件夹内文件类型,并将结果保存到 CSV 文件
    :param folder_path: 目标文件夹路径
    :param output_csv: 输出 CSV 文件路径
    """
    # 实例化 Magic 类,获取 MIME 类型和编码
    mime_enc_magic = magic.Magic(mime_encoding=True)

    # 准备 CSV 表头
    headers = ["文件名", "文件路径", "MIME类型及编码"]

    # 打开 CSV 文件并写入数据
    with open(output_csv, "w", newline="", encoding="utf-8") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(headers)

        # 遍历文件夹内所有文件
        for root, dirs, files in os.walk(folder_path):
            for file_name in files:
                # 获取文件完整路径
                file_path = os.path.join(root, file_name)
                try:
                    # 检测文件类型
                    file_type = mime_enc_magic.from_file(file_path)
                    # 写入 CSV 数据
                    writer.writerow([file_name, file_path, file_type])
                    print(f"已检测: {file_path} -> {file_type}")
                except Exception as e:
                    # 捕获异常,处理无法检测的文件
                    print(f"检测失败: {file_path} -> 错误信息: {str(e)}")
                    writer.writerow([file_name, file_path, f"检测失败: {str(e)}"])

# 调用函数,批量检测 test_folder 文件夹内文件,并保存到 file_types.csv
if __name__ == "__main__":
    target_folder = "test_folder"
    output_file = "file_types.csv"
    batch_detect_file_type(target_folder, output_file)

代码说明

  • 定义 batch_detect_file_type 函数,接收目标文件夹路径和输出 CSV 文件路径作为参数。
  • 实例化 Magic 类并开启 mime_encoding 模式,用于获取文件的 MIME 类型和编码格式。
  • 使用 os.walk 遍历目标文件夹内的所有文件,获取每个文件的完整路径。
  • 调用 from_file 方法检测文件类型,将文件名、文件路径和检测结果写入 CSV 文件。
  • 捕获检测过程中的异常(如文件无法访问、权限不足等),并将错误信息写入 CSV 文件。
  • if __name__ == "__main__" 代码块中,指定目标文件夹和输出 CSV 文件路径,调用函数执行批量检测。

运行效果
运行代码后,会在当前目录下生成 file_types.csv 文件,文件中包含了文件夹内所有文件的名称、路径和类型信息,方便后续数据分析和筛选。

4.2 基于文件类型的爬虫数据过滤

在爬虫开发中,我们经常需要下载网络资源,但有时候会遇到链接返回的文件类型与预期不符的情况(如预期下载图片,实际下载的是 HTML 错误页面)。下面的案例结合 requests 库和 python-magic,实现爬虫数据的类型过滤,只保存符合预期类型的文件。

实例代码

import os
import requests
import magic

def download_file_by_type(url, save_folder, expected_mime):
    """
    根据预期 MIME 类型下载文件,过滤不符合类型的资源
    :param url: 文件下载链接
    :param save_folder: 文件保存文件夹
    :param expected_mime: 预期的 MIME 类型(如 image/jpeg、application/pdf 等)
    """
    # 创建保存文件夹(如果不存在)
    if not os.path.exists(save_folder):
        os.makedirs(save_folder)

    try:
        # 发送 GET 请求,获取文件二进制内容
        response = requests.get(url, stream=True)
        response.raise_for_status()  # 抛出 HTTP 错误异常

        # 读取文件二进制内容(读取前 1024 字节即可满足类型检测需求)
        file_content = response.raw.read(1024)

        # 检测文件 MIME 类型
        mime_magic = magic.Magic(mime=True)
        file_mime = mime_magic.from_buffer(file_content)

        # 判断是否符合预期 MIME 类型
        if file_mime == expected_mime:
            # 提取文件名
            file_name = url.split("/")[-1]
            save_path = os.path.join(save_folder, file_name)

            # 完整下载文件并保存
            with open(save_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024):
                    if chunk:
                        f.write(chunk)
            print(f"成功下载符合预期的文件: {save_path} -> MIME 类型: {file_mime}")
        else:
            print(f"文件类型不符,跳过下载 -> 预期: {expected_mime}, 实际: {file_mime}")
    except requests.exceptions.RequestException as e:
        print(f"下载失败: {url} -> 错误信息: {str(e)}")

# 调用函数,下载预期为 JPG 图片的资源
if __name__ == "__main__":
    # 测试链接(替换为实际的下载链接)
    test_urls = [
        "https://example.com/valid_image.jpg",
        "https://example.com/fake_image.html"
    ]
    save_dir = "downloaded_images"
    expected_mime_type = "image/jpeg"

    for url in test_urls:
        download_file_by_type(url, save_dir, expected_mime_type)

代码说明

  • 定义 download_file_by_type 函数,接收下载链接、保存文件夹和预期 MIME 类型作为参数。
  • 首先创建保存文件夹(如果不存在),然后使用 requests.get 方法发送请求,开启流式传输模式(stream=True)。
  • 读取响应的前 1024 字节数据,这部分数据足够 python-magic 进行文件类型检测,避免下载完整文件后才发现类型不符。
  • 实例化 Magic 类并开启 mime 模式,检测读取到的二进制内容的 MIME 类型。
  • 判断检测到的 MIME 类型是否与预期类型一致,如果一致,则提取文件名并完整下载文件到保存文件夹;如果不一致,则跳过下载。
  • 捕获 requests 库的请求异常(如网络错误、HTTP 404 错误等),并输出错误信息。
  • if __name__ == "__main__" 代码块中,定义测试链接列表、保存目录和预期 MIME 类型,遍历链接并调用函数进行下载。

运行效果
运行代码后,只有 MIME 类型为 image/jpeg 的文件会被下载并保存到 downloaded_images 文件夹中,不符合预期类型的资源会被跳过,有效过滤了无效数据。

五、相关资源链接

  • Pypi地址:https://pypi.org/project/python-magic
  • Github地址:https://github.com/ahupp/python-magic
  • 官方文档地址:https://python-magic.readthedocs.io/en/latest/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:xmltodict库快速上手与进阶实战

一、xmltodict库核心概述

1.1 库的用途

xmltodict是一款轻量级的Python库,核心功能是实现XML字符串与Python字典(dict)之间的双向转换,同时支持将Python字典反向生成为XML格式数据。借助该库,开发者无需编写复杂的XML解析代码(如DOM、SAX解析逻辑),就能像操作普通字典一样处理XML数据,极大降低了XML数据处理的门槛,广泛应用于接口数据解析、配置文件读取、数据格式转换等场景。

1.2 工作原理

xmltodict的底层基于Python内置的xml.parsers.expat解析器,采用事件驱动的解析方式处理XML数据。在解析XML时,它会遍历XML文档的各个节点(元素、属性、文本),将XML元素标签映射为字典的键,元素属性以@前缀标识的键存储,元素内的文本内容则存储在#text键中;反向生成XML时,会根据字典的键值结构,按照XML的语法规则拼接成对应的XML字符串。

1.3 优缺点分析

优点

  • 操作简单:将XML解析为字典后,可通过键值对快速访问数据,学习成本极低。
  • 轻量高效:依赖少,解析速度快,适合处理中小型XML文档。
  • 兼容性强:支持自定义命名空间处理、属性优先级设置等高级功能,满足多样化需求。
  • 双向转换:既可以解析XML为字典,也能将字典生成为标准XML,功能完整。

缺点

  • 内存消耗:解析大型XML文档时,会将整个文档加载到内存中,可能引发内存溢出问题。
  • 结构限制:对于嵌套层级极深的XML,转换后的字典结构会变得复杂,可读性下降。

1.4 许可证类型

xmltodict采用MIT License开源许可证,这意味着开发者可以自由地使用、复制、修改、合并、出版发行、散布、授权和/或销售该软件的副本,同时也可以将软件嵌入到其他商业软件中,只需保留原作者的版权声明即可。

二、xmltodict库安装与环境准备

2.1 安装方法

xmltodict是Python第三方库,可通过pip包管理工具一键安装,支持Python 2.7和Python 3.x版本,安装命令如下:

pip install xmltodict

安装完成后,可在Python交互式环境中执行以下代码验证是否安装成功:

import xmltodict
print(xmltodict.__version__)

若输出类似0.13.0的版本号,则说明安装成功。

2.2 环境要求

  • Python版本:Python 2.7 或 Python 3.4+
  • 依赖库:无额外强制依赖,仅依赖Python标准库中的xml.parsers.expat

三、xmltodict库基础使用教程

3.1 核心函数介绍

xmltodict的核心功能由两个函数实现,分别是parse()unparse(),具体说明如下:
| 函数名 | 功能描述 | 常用参数 |
|–|-|-|
| xmltodict.parse(xml_input, **kwargs) | 将XML字符串或文件对象解析为Python字典 | xml_input:XML字符串/文件对象;encoding:编码格式,默认utf-8process_namespaces:是否处理命名空间,默认Falseattr_prefix:属性键前缀,默认@cdata_key:CDATA段键名,默认#text |
| xmltodict.unparse(dict_input, **kwargs) | 将Python字典反向生成为XML字符串 | dict_input:待转换的Python字典;encoding:编码格式,默认utf-8pretty:是否格式化输出,默认Falseindent:格式化缩进字符,默认(两个空格) |

3.2 基础案例1:XML字符串解析为字典

本案例演示如何将一个简单的XML字符串解析为Python字典,并通过键值对访问XML中的数据。

# 导入xmltodict库
import xmltodict

# 定义一个简单的XML字符串
xml_str = """
<bookstore>
    <book category="COOKING">
        <title lang="en">Everyday Italian</title>
        <author>Giada De Laurentiis</author>
        <year>2005</year>
        <price>30.00</price>
    </book>
    <book category="CHILDREN">
        <title lang="en">Harry Potter</title>
        <author>J K. Rowling</author>
        <year>2005</year>
        <price>29.99</price>
    </book>
</bookstore>
"""

# 将XML字符串解析为Python字典
dict_data = xmltodict.parse(xml_str)

# 打印转换后的字典
print("转换后的字典结构:")
print(dict_data)
print("\n")

# 访问字典中的数据
# 1. 获取第一个book的分类
first_book_category = dict_data["bookstore"]["book"][0]["@category"]
print("第一本书的分类:", first_book_category)

# 2. 获取第二本书的标题
second_book_title = dict_data["bookstore"]["book"][1]["title"]
print("第二本书的标题:", second_book_title)

# 3. 获取第一本书的价格
first_book_price = dict_data["bookstore"]["book"][0]["price"]
print("第一本书的价格:", first_book_price)

代码说明

  • 首先导入xmltodict库,定义一个包含两本图书信息的XML字符串,其中<book>标签包含category属性,子标签包括titleauthor等。
  • 使用xmltodict.parse()函数将XML字符串转换为字典,XML的根节点bookstore成为字典的顶层键,子节点book则以列表形式存储(因为存在多个<book>元素)。
  • XML元素的属性会以@属性名的形式作为字典的键,例如@category对应<book>标签的category属性。
  • 通过字典的键值索引,可快速获取目标数据,例如dict_data["bookstore"]["book"][0]对应第一本图书的所有信息。

运行结果

转换后的字典结构:
{'bookstore': {'book': [{'@category': 'COOKING', 'title': {'@lang': 'en', '#text': 'Everyday Italian'}, 'author': 'Giada De Laurentiis', 'year': '2005', 'price': '30.00'}, {'@category': 'CHILDREN', 'title': {'@lang': 'en', '#text': 'Harry Potter'}, 'author': 'J K. Rowling', 'year': '2005', 'price': '29.99'}]}}


第一本书的分类: COOKING
第二本书的标题: {'@lang': 'en', '#text': 'Harry Potter'}
第一本书的价格: 30.00

3.3 基础案例2:Python字典生成XML字符串

本案例演示如何将Python字典反向转换为XML字符串,并通过参数控制XML的格式化输出。

import xmltodict

# 定义一个Python字典,模拟图书信息
book_dict = {
    "bookstore": {
        "book": [
            {
                "@category": "COOKING",
                "title": {
                    "@lang": "en",
                    "#text": "Everyday Italian"
                },
                "author": "Giada De Laurentiis",
                "year": "2005",
                "price": "30.00"
            },
            {
                "@category": "CHILDREN",
                "title": {
                    "@lang": "en",
                    "#text": "Harry Potter"
                },
                "author": "J K. Rowling",
                "year": "2005",
                "price": "29.99"
            }
        ]
    }
}

# 将字典转换为XML字符串(无格式化)
xml_str_no_pretty = xmltodict.unparse(book_dict)
print("无格式化的XML字符串:")
print(xml_str_no_pretty)
print("\n")

# 将字典转换为XML字符串(带格式化,缩进为4个空格)
xml_str_pretty = xmltodict.unparse(book_dict, pretty=True, indent="    ")
print("格式化后的XML字符串:")
print(xml_str_pretty)

代码说明

  • 定义一个与上例XML结构对应的Python字典,其中属性以@属性名作为键,文本内容以#text作为键,这是xmltodict识别属性和文本的约定。
  • 使用xmltodict.unparse()函数将字典转换为XML字符串,默认情况下(pretty=False)生成的XML是紧凑格式,无换行和缩进。
  • 当设置pretty=True时,生成的XML会自动换行和缩进,indent参数可指定缩进的字符(如4个空格),提升XML的可读性。

运行结果

无格式化的XML字符串:
<?xml version="1.0" encoding="utf-8"?><bookstore><book category="COOKING"><title lang="en">Everyday Italian</title><author>Giada De Laurentiis</author><year>2005</year><price>30.00</price></book><book category="CHILDREN"><title lang="en">Harry Potter</title><author>J K. Rowling</author><year>2005</year><price>29.99</price></book></bookstore>


格式化后的XML字符串:
<?xml version="1.0" encoding="utf-8"?>
<bookstore>
    <book category="COOKING">
        <title lang="en">Everyday Italian</title>
        <author>Giada De Laurentiis</author>
        <year>2005</year>
        <price>30.00</price>
    </book>
    <book category="CHILDREN">
        <title lang="en">Harry Potter</title>
        <author>J K. Rowling</author>
        <year>2005</year>
        <price>29.99</price>
    </book>
</bookstore>

3.4 基础案例3:读取XML文件并解析

在实际开发中,XML数据常以文件形式存储,本案例演示如何读取本地XML文件并解析为字典。
首先创建一个名为books.xml的文件,内容如下:

<?xml version="1.0" encoding="utf-8"?>
<bookstore>
    <book category="BUSINESS">
        <title lang="en">XML Developer's Guide</title>
        <author>Robert Richards</author>
        <year>2002</year>
        <price>44.95</price>
    </book>
    <book category="TECHNOLOGY">
        <title lang="en">Python Crash Course</title>
        <author>Eric Matthes</author>
        <year>2019</year>
        <price>29.95</price>
    </book>
</bookstore>

然后编写Python代码读取并解析该文件:

import xmltodict

# 打开XML文件并解析为字典
with open("books.xml", "r", encoding="utf-8") as f:
    # 直接将文件对象传入parse函数
    dict_data = xmltodict.parse(f)

# 遍历所有图书信息并打印
print("图书列表:")
for book in dict_data["bookstore"]["book"]:
    print(f"分类:{book['@category']}")
    print(f"标题:{book['title']['#text']}")
    print(f"语言:{book['title']['@lang']}")
    print(f"作者:{book['author']}")
    print(f"年份:{book['year']}")
    print(f"价格:{book['price']}")
    print("-" * 20)

代码说明

  • 使用Python的with open()语句以只读模式打开books.xml文件,指定编码为utf-8,避免中文乱码。
  • xmltodict.parse()函数支持直接传入文件对象,无需手动读取文件内容,简化了代码流程。
  • 通过遍历dict_data["bookstore"]["book"]列表,可逐个获取每本图书的属性和子节点数据,并格式化输出。

运行结果

图书列表:
分类:BUSINESS
标题:XML Developer's Guide
语言:en
作者:Robert Richards
年份:2002
价格:44.95
--
分类:TECHNOLOGY
标题:Python Crash Course
语言:en
作者:Eric Matthes
年份:2019
价格:29.95
--

四、xmltodict库进阶使用技巧

4.1 处理XML命名空间

在复杂的XML文档中,命名空间(Namespace)是常见的元素,用于避免标签名冲突。xmltodict支持通过process_namespaces参数处理命名空间,本案例演示如何解析带命名空间的XML。

import xmltodict

# 定义带命名空间的XML字符串
xml_with_ns = """
<root xmlns:ns="http://example.com/ns">
    <ns:user id="1001">
        <ns:name>Alice</ns:name>
        <ns:age>25</ns:age>
    </ns:user>
    <ns:user id="1002">
        <ns:name>Bob</ns:name>
        <ns:age>30</ns:age>
    </ns:user>
</root>
"""

# 不处理命名空间的情况
dict_no_ns = xmltodict.parse(xml_with_ns)
print("不处理命名空间的结果:")
print(dict_no_ns)
print("\n")

# 处理命名空间的情况(process_namespaces=True)
dict_with_ns = xmltodict.parse(xml_with_ns, process_namespaces=True)
print("处理命名空间的结果:")
print(dict_with_ns)
print("\n")

# 自定义命名空间前缀映射
namespace_map = {"http://example.com/ns": "user"}
dict_custom_ns = xmltodict.parse(
    xml_with_ns,
    process_namespaces=True,
    namespaces=namespace_map
)
print("自定义命名空间前缀的结果:")
print(dict_custom_ns)

代码说明

  • 定义的XML字符串中包含命名空间xmlns:ns="http://example.com/ns",所有子标签均以ns:为前缀。
  • process_namespaces=False(默认值)时,解析后的字典键会保留命名空间前缀(如ns:user)。
  • process_namespaces=True时,xmltodict会自动移除命名空间前缀,直接使用标签名作为键。
  • 通过namespaces参数可传入自定义的命名空间映射字典,将命名空间URI映射为更简洁的前缀,方便数据访问。

运行结果

不处理命名空间的结果:
{'root': {'@xmlns:ns': 'http://example.com/ns', 'ns:user': [{'@id': '1001', 'ns:name': 'Alice', 'ns:age': '25'}, {'@id': '1002', 'ns:name': 'Bob', 'ns:age': '30'}]}}


处理命名空间的结果:
{'root': {'user': [{'@id': '1001', 'name': 'Alice', 'age': '25'}, {'@id': '1002', 'name': 'Bob', 'age': '30'}]}}


自定义命名空间前缀的结果:
{'root': {'user:user': [{'@id': '1001', 'user:name': 'Alice', 'user:age': '25'}, {'@id': '1002', 'user:name': 'Bob', 'user:age': '30'}]}}

4.2 处理CDATA段

XML中的CDATA段用于存储不需要转义的文本内容(如包含<>&等特殊字符的文本),xmltodict默认将CDATA段的内容存储在#text键中,本案例演示如何解析包含CDATA段的XML。

import xmltodict

# 定义包含CDATA段的XML字符串
xml_with_cdata = """
<article>
    <title>Python & XML</title>
    <content><![CDATA[这是一篇关于Python解析XML的文章,包含特殊字符:< > & " ']]></content>
</article>
"""

# 解析XML字符串
dict_data = xmltodict.parse(xml_with_cdata)

# 访问CDATA段内容
title = dict_data["article"]["title"]
content = dict_data["article"]["content"]["#text"]

print(f"文章标题:{title}")
print(f"文章内容:{content}")

# 反向生成包含CDATA的XML
article_dict = {
    "article": {
        "title": "Python & XML",
        "content": {
            "#cdata-section": "这是一篇关于Python解析XML的文章,包含特殊字符:< > & \" '"
        }
    }
}

xml_str = xmltodict.unparse(article_dict, pretty=True)
print("\n生成的包含CDATA的XML:")
print(xml_str)

代码说明

  • XML中的<content>标签包含CDATA段,存储了带有特殊字符的文本,这些字符无需转义。
  • 解析后,CDATA段的内容会被存储在#text键中,可直接通过该键获取文本内容。
  • 反向生成包含CDATA的XML时,需要在字典中使用#cdata-section键,对应的值会被包装为CDATA段。

运行结果

文章标题:Python & XML
文章内容:这是一篇关于Python解析XML的文章,包含特殊字符:< > & " '

生成的包含CDATA的XML:
<?xml version="1.0" encoding="utf-8"?>
<article>
  <title>Python &amp; XML</title>
  <content><![CDATA[这是一篇关于Python解析XML的文章,包含特殊字符:< > & " ']]></content>
</article>

4.3 自定义解析参数

xmltodict提供了多个自定义参数,可根据需求调整解析行为,本案例演示常用参数的使用方法。

import xmltodict

xml_str = """
<data>
    <item id="1">A</item>
    <item id="2">B</item>
</data>
"""

# 自定义属性前缀(将默认的@改为attr_)
dict_custom_attr = xmltodict.parse(xml_str, attr_prefix="attr_")
print("自定义属性前缀的结果:")
print(dict_custom_attr)
print("\n")

# 自定义文本键名(将默认的#text改为text)
dict_custom_text = xmltodict.parse(xml_str, cdata_key="text")
print("自定义文本键名的结果:")
print(dict_custom_text)
print("\n")

# 强制将单元素列表转换为列表(force_list参数)
# 当XML中只有一个<item>标签时,默认不会生成列表,force_list可强制生成列表
xml_single_item = """
<data>
    <item id="1">A</item>
</data>
"""
# 默认解析
dict_default = xmltodict.parse(xml_single_item)
print("默认解析单元素的结果:")
print(type(dict_default["data"]["item"]))  # 输出<class 'dict'>

# 强制生成列表
dict_force_list = xmltodict.parse(xml_single_item, force_list=["item"])
print("强制生成列表的结果:")
print(type(dict_force_list["data"]["item"]))  # 输出<class 'list'>

代码说明

  • attr_prefix参数:用于修改属性键的前缀,默认是@,可改为其他字符串(如attr_),避免与普通键名冲突。
  • cdata_key参数:用于修改文本内容的键名,默认是#text,可改为更简洁的名称(如text)。
  • force_list参数:用于指定哪些标签需要强制生成列表,即使该标签在XML中只出现一次,避免后续遍历数据时出现类型错误。

运行结果

自定义属性前缀的结果:
{'data': {'item': [{'attr_id': '1', '#text': 'A'}, {'attr_id': '2', '#text': 'B'}]}}


自定义文本键名的结果:
{'data': {'item': [{'@id': '1', 'text': 'A'}, {'@id': '2', 'text': 'B'}]}}


默认解析单元素的结果:
<class 'dict'>
强制生成列表的结果:
<class 'list'>

五、xmltodict库实战案例:接口数据处理

在实际开发中,很多第三方接口会返回XML格式的数据,本案例模拟一个天气查询接口,使用xmltodict解析接口返回的XML数据,并提取关键信息。

5.1 模拟接口返回的XML数据

import xmltodict

# 模拟天气接口返回的XML数据
weather_xml = """
<weather_response>
    <city>北京市</city>
    <date>2024-05-20</date>
    <temperature>
        <max>28</max>
        <min>18</min>
        <unit>℃</unit>
    </temperature>
    <weather>晴转多云</weather>
    <wind>
        <direction>南风</direction>
        <level>2-3级</level>
    </wind>
</weather_response>
"""

5.2 解析XML数据并提取信息

# 解析XML数据
weather_dict = xmltodict.parse(weather_xml)

# 提取关键天气信息
city = weather_dict["weather_response"]["city"]
date = weather_dict["weather_response"]["date"]
max_temp = weather_dict["weather_response"]["temperature"]["max"]
min_temp = weather_dict["weather_response"]["temperature"]["min"]
unit = weather_dict["weather_response"]["temperature"]["unit"]
weather = weather_dict["weather_response"]["weather"]
wind_dir = weather_dict["weather_response"]["wind"]["direction"]
wind_level = weather_dict["weather_response"]["wind"]["level"]

# 格式化输出天气信息
weather_info = f"""
{date} {city}天气信息

天气状况:{weather}
温度范围:{min_temp}{unit} - {max_temp}{unit}
风向风力:{wind_dir} {wind_level}
"""
print(weather_info)

5.3 将处理后的数据生成为新的XML

# 构造新的天气字典(仅保留核心信息)
simple_weather_dict = {
    "simple_weather": {
        "@city": city,
        "@date": date,
        "weather": weather,
        "temp_range": f"{min_temp}{unit} - {max_temp}{unit}"
    }
}

# 生成格式化的XML字符串
simple_weather_xml = xmltodict.unparse(simple_weather_dict, pretty=True, indent="  ")
print("生成的简化天气XML:")
print(simple_weather_xml)

代码说明

  • 首先模拟天气接口返回的XML数据,包含城市、日期、温度、天气状况、风向风力等信息。
  • 使用xmltodict.parse()解析XML数据后,通过键值索引提取所需信息,并格式化输出为用户友好的天气报告。
  • 构造一个简化的天气字典,仅保留核心信息,并使用xmltodict.unparse()生成新的XML字符串,用于数据存储或接口转发。

运行结果

2024-05-20 北京市天气信息

天气状况:晴转多云
温度范围:18℃ - 28℃
风向风力:南风 2-3级

生成的简化天气XML:
<?xml version="1.0" encoding="utf-8"?>
<simple_weather city="北京市" date="2024-05-20">
  <weather>晴转多云</weather>
  <temp_range>18℃ - 28℃</temp_range>
</simple_weather>

六、xmltodict库相关资源

6.1 Pypi地址

https://pypi.org/project/xmltodict

6.2 Github地址

https://github.com/martinblech/xmltodict

6.3 官方文档地址

https://xmltodict.readthedocs.io/en/latest

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:xlwings零基础入门教程——轻松实现Excel与Python的无缝交互

一、xlwings库核心概述

xlwings是一款功能强大的Python库,其核心用途是实现Python与Excel之间的双向通信,用户可以通过Python脚本直接操控Excel的工作簿、工作表、单元格等元素,同时也能在Excel中调用Python函数。该库的工作原理是基于COM接口(Windows系统)和AppleScript(Mac系统)与Excel应用程序建立连接,从而实现对Excel的底层操作,无需依赖复杂的第三方插件。

xlwings的优点十分突出:支持.xlsx、.xls等多种Excel文件格式;可以保留Excel的宏、公式和格式;语法简洁易懂,贴近Excel的原生操作逻辑;支持在Excel中嵌入Python代码,实现自动化报表生成、数据清洗和分析等功能。缺点则是在非Windows和Mac系统(如Linux)上无法直接使用,因为其依赖于Excel应用程序的安装;在处理超大规模数据时,速度相较于pandas等库会稍慢一些。xlwings采用的是MIT开源许可证,用户可以自由地用于商业和非商业项目,无版权方面的限制。

二、xlwings库的安装与环境配置

2.1 安装xlwings

对于技术小白来说,xlwings的安装过程非常简单,只需要使用Python的包管理工具pip即可完成。首先需要确保你的电脑上已经安装了Python环境(推荐Python 3.7及以上版本),并且已经配置好了pip的环境变量。

打开命令提示符(Windows)或终端(Mac),输入以下命令:

pip install xlwings

等待安装完成后,就可以在Python脚本中导入xlwings库进行使用了。

2.2 验证安装是否成功

安装完成后,我们可以通过一个简单的Python脚本来验证xlwings是否安装成功。创建一个名为test_xlwings.py的文件,输入以下代码:

# 导入xlwings库
import xlwings as xw

# 打开一个新的Excel工作簿
wb = xw.Book()
# 获取工作簿中的第一个工作表
ws = wb.sheets[0]
# 在A1单元格中写入内容
ws.range('A1').value = 'Hello, xlwings!'
# 保存工作簿到指定路径
wb.save('test.xlsx')
# 关闭工作簿
wb.close()
print("xlwings安装成功,测试文件已生成!")

运行该脚本,如果没有报错,并且在脚本所在目录下生成了一个名为test.xlsx的Excel文件,打开后A1单元格显示“Hello, xlwings!”,则说明xlwings已经成功安装并可以正常使用。

2.3 Excel环境配置

xlwings的使用依赖于本地安装的Excel应用程序,Windows系统推荐使用Microsoft Excel 2010及以上版本,Mac系统推荐使用Microsoft Excel for Mac 2016及以上版本。不需要进行额外的配置,只需要确保Excel能够正常启动即可。

三、xlwings核心功能与基础用法

3.1 工作簿(Workbook)的操作

工作簿是Excel文件的核心载体,xlwings提供了多种方式来创建、打开和保存工作簿。

3.1.1 创建新的工作簿

使用xw.Book()方法可以创建一个新的Excel工作簿,该方法会自动启动Excel应用程序(如果尚未启动)。

import xlwings as xw

# 创建新的工作簿
wb = xw.Book()
# 查看工作簿的名称
print("新建工作簿名称:", wb.name)
# 关闭工作簿
wb.close()

代码说明:xw.Book()创建的是一个临时的工作簿,默认名称为“Book1”“Book2”等,使用wb.name可以查看工作簿的名称,最后通过wb.close()关闭工作簿。

3.1.2 打开已有的工作簿

如果需要对已存在的Excel文件进行操作,可以使用xw.Book(file_path)方法,其中file_path是Excel文件的路径(绝对路径或相对路径)。

import xlwings as xw

# 打开已有的工作簿(相对路径,文件需在脚本所在目录)
wb = xw.Book('test.xlsx')
print("已打开工作簿名称:", wb.name)
# 关闭工作簿
wb.close()

# 使用绝对路径打开工作簿(示例路径,需根据实际情况修改)
# wb = xw.Book(r'C:\Users\Admin\Desktop\data.xlsx')

代码说明:使用相对路径时,Excel文件需要和Python脚本在同一个目录下;使用绝对路径时,需要在路径前加r,避免转义字符的影响。

3.1.3 保存工作簿

对工作簿进行操作后,需要使用save()方法来保存修改,save()方法可以指定保存路径,如果不指定,则保存到原文件路径。

import xlwings as xw

wb = xw.Book()
ws = wb.sheets[0]
ws.range('A1').value = 'Python操作Excel'

# 保存到指定路径
wb.save('new_excel.xlsx')
wb.close()

代码说明:如果指定的保存路径中不存在该文件,save()方法会自动创建;如果已经存在,则会覆盖原文件。

3.2 工作表(Worksheet)的操作

工作表是工作簿中的子对象,一个工作簿可以包含多个工作表,xlwings提供了丰富的方法来对工作表进行添加、删除、重命名和选择等操作。

3.2.1 选择工作表

可以通过工作表的索引或名称来选择工作表,索引从0开始,对应Excel中的第一个工作表。

import xlwings as xw

wb = xw.Book('new_excel.xlsx')
# 通过索引选择第一个工作表
ws1 = wb.sheets[0]
print("通过索引选择的工作表名称:", ws1.name)

# 通过名称选择工作表(默认第一个工作表名称为Sheet1)
ws2 = wb.sheets['Sheet1']
print("通过名称选择的工作表名称:", ws2.name)
wb.close()

代码说明:新建的工作簿默认只有一个名为“Sheet1”的工作表,通过索引和名称两种方式都可以准确选择目标工作表。

3.2.2 添加新的工作表

使用wb.sheets.add()方法可以在工作簿中添加新的工作表,可以指定工作表的名称和位置。

import xlwings as xw

wb = xw.Book('new_excel.xlsx')
# 添加新的工作表,名称为"数据报表",位置在最后
new_ws = wb.sheets.add(name='数据报表')
print("新增工作表名称:", new_ws.name)

# 添加新的工作表,位置在第一个工作表之前
# new_ws2 = wb.sheets.add(name='汇总表', before=wb.sheets[0])
wb.save()
wb.close()

代码说明:name参数用于指定新工作表的名称,before参数用于指定新工作表插入的位置,after参数则可以指定插入到某个工作表之后。

3.2.3 重命名和删除工作表

使用ws.name属性可以修改工作表的名称,使用ws.delete()方法可以删除指定的工作表。

import xlwings as xw

wb = xw.Book('new_excel.xlsx')
ws = wb.sheets['数据报表']
# 重命名工作表
ws.name = '销售数据报表'
print("重命名后的工作表名称:", ws.name)

# 删除指定的工作表
# wb.sheets['销售数据报表'].delete()
wb.save()
wb.close()

代码说明:删除工作表时要格外小心,删除后无法撤销,建议在删除前进行数据备份。

3.3 单元格(Range)的操作

单元格是Excel中存储数据的最小单位,xlwings提供了灵活的方式来对单元格进行读写、格式设置等操作,这也是xlwings最核心的功能之一。

3.3.1 单元格的选择

可以通过单元格的地址(如A1、B2)、行和列的索引来选择单元格或单元格区域。

import xlwings as xw

wb = xw.Book('new_excel.xlsx')
ws = wb.sheets[0]
# 选择单个单元格(A1)
rng1 = ws.range('A1')
print("A1单元格的值:", rng1.value)

# 选择单元格区域(A1:C3)
rng2 = ws.range('A1:C3')
print("A1:C3区域的行数:", rng2.rows.count)
print("A1:C3区域的列数:", rng2.columns.count)

# 通过行和列索引选择单元格(第1行第1列,即A1)
rng3 = ws.range((1, 1))
print("第1行第1列单元格的值:", rng3.value)

# 通过行和列索引选择单元格区域(第1行第1列到第3行第3列,即A1:C3)
rng4 = ws.range((1, 1), (3, 3))
wb.close()

代码说明:使用range()方法选择单元格区域时,可以使用Excel的单元格地址格式,也可以使用元组的形式指定起始和结束位置,元组中的第一个元素是行号,第二个元素是列号。

3.3.2 单元格数据的读取

读取单元格数据是xlwings的常用操作,无论是单个单元格还是单元格区域,都可以通过value属性来获取数据。

import xlwings as xw

# 先向Excel中写入测试数据,再进行读取
wb = xw.Book()
ws = wb.sheets[0]
# 向A1:C3区域写入数据
data = [
    ['姓名', '年龄', '性别'],
    ['张三', 25, '男'],
    ['李四', 28, '女']
]
ws.range('A1').value = data

# 读取单个单元格的值
name = ws.range('A2').value
print("A2单元格的值(姓名):", name)

# 读取整行数据(第2行)
row_data = ws.range('2:2').value
print("第2行数据:", row_data)

# 读取整列数据(第1列)
col_data = ws.range('A:A').value
# 过滤掉空值
col_data = [x for x in col_data if x is not None]
print("第1列数据:", col_data)

# 读取单元格区域的数据
range_data = ws.range('A1:C3').value
print("A1:C3区域的数据:")
for row in range_data:
    print(row)
wb.close()

代码说明:读取单元格区域的数据时,会返回一个二维列表,其中每个子列表对应Excel中的一行数据;读取整行或整列数据时,会返回一个一维列表,包含该行或该列的所有非空值。

3.3.3 单元格数据的写入

向单元格中写入数据同样通过value属性,xlwings支持写入单个值、列表、二维列表等多种数据类型。

import xlwings as xw

wb = xw.Book()
ws = wb.sheets[0]
# 写入单个值
ws.range('A1').value = '学生信息表'
# 合并单元格(A1:C1)
ws.range('A1:C1').api.merge()

# 写入一维列表(一行数据)
ws.range('A2').value = ['姓名', '年龄', '成绩']

# 写入二维列表(多行多列数据)
student_data = [
    ['王五', 22, 95],
    ['赵六', 23, 88],
    ['孙七', 21, 92]
]
ws.range('A3').value = student_data

# 写入字典数据(按列写入)
score_dict = {'语文': 90, '数学': 85, '英语': 93}
ws.range('E2').value = list(score_dict.keys())
ws.range('E3').value = list(score_dict.values())

wb.save('student_info.xlsx')
wb.close()

代码说明:写入一维列表时,xlwings会自动将列表中的元素按行写入单元格;写入二维列表时,会按多行多列的形式写入;写入字典数据时,可以将键和值分别写入不同的列。

3.3.4 单元格格式的设置

xlwings还支持对单元格的格式进行设置,如字体大小、颜色、对齐方式、边框等,这些设置通过api属性调用Excel的底层接口实现。

import xlwings as xw

wb = xw.Book('student_info.xlsx')
ws = wb.sheets[0]
# 设置标题单元格格式(A1:C1)
title_rng = ws.range('A1:C1')
# 设置字体大小为16,加粗
title_rng.api.Font.Size = 16
title_rng.api.Font.Bold = True
# 设置背景颜色为浅蓝色
title_rng.api.Interior.Color = xw.utils.rgb_to_int((211, 223, 236))
# 设置水平居中对齐
title_rng.api.HorizontalAlignment = xw.constants.HAlign.xlHAlignCenter

# 设置表头单元格格式(A2:C2)
header_rng = ws.range('A2:C2')
header_rng.api.Font.Bold = True
header_rng.api.Interior.Color = xw.utils.rgb_to_int((226, 239, 218))

# 为数据区域添加边框(A3:C5)
data_rng = ws.range('A3:C5')
# 边框样式:细实线
border = xw.constants.BordersIndex.xlEdgeLeft
data_rng.api.Borders(border).LineStyle = xw.constants.LineStyle.xlContinuous
data_rng.api.Borders(xw.constants.BordersIndex.xlEdgeRight).LineStyle = xw.constants.LineStyle.xlContinuous
data_rng.api.Borders(xw.constants.BordersIndex.xlEdgeTop).LineStyle = xw.constants.LineStyle.xlContinuous
data_rng.api.Borders(xw.constants.BordersIndex.xlEdgeBottom).LineStyle = xw.constants.LineStyle.xlContinuous

wb.save()
wb.close()

代码说明:xw.utils.rgb_to_int()函数用于将RGB颜色值转换为Excel可以识别的整数;xw.constants模块中包含了Excel的各种常量,如对齐方式、边框样式等,方便用户进行格式设置。

四、xlwings高级应用实战案例

4.1 案例一:自动化生成销售数据报表

在日常工作中,我们经常需要根据原始数据生成销售报表,使用xlwings可以实现这一过程的自动化,减少重复的手工操作。

4.1.1 需求分析

假设我们有一份销售原始数据,包含销售日期、产品名称、销售额等信息,需要实现以下功能:

  1. 读取原始销售数据;
  2. 按产品名称汇总销售额;
  3. 将汇总结果写入Excel报表,并设置报表格式;
  4. 生成销售额柱状图。

4.1.2 代码实现

import xlwings as xw
import pandas as pd

# 1. 生成模拟销售数据(实际应用中可以从CSV、数据库读取)
sales_data = {
    '销售日期': ['2024-01-01', '2024-01-01', '2024-01-02', '2024-01-02', '2024-01-03'],
    '产品名称': ['产品A', '产品B', '产品A', '产品C', '产品B'],
    '销售额': [1000, 1500, 1200, 800, 1600]
}
df = pd.DataFrame(sales_data)
# 按产品名称汇总销售额
summary_df = df.groupby('产品名称')['销售额'].sum().reset_index()

# 2. 使用xlwings创建销售报表
wb = xw.Book()
ws = wb.sheets[0]
ws.name = '销售汇总报表'

# 写入报表标题
ws.range('A1').value = '2024年1月产品销售汇总报表'
ws.range('A1').api.Font.Size = 18
ws.range('A1').api.Font.Bold = True
ws.range('A1:D1').api.merge()
ws.range('A1').api.HorizontalAlignment = xw.constants.HAlign.xlHAlignCenter

# 写入汇总数据
ws.range('A3').value = ['产品名称', '总销售额(元)']
ws.range('A4').value = summary_df.values

# 设置数据区域格式
header_rng = ws.range('A3:B3')
header_rng.api.Font.Bold = True
header_rng.api.Interior.Color = xw.utils.rgb_to_int((220, 220, 220))
data_rng = ws.range(f'A4:B{3 + len(summary_df)}')
data_rng.api.Borders.LineStyle = xw.constants.LineStyle.xlContinuous

# 3. 插入销售额柱状图
chart_range = ws.range(f'A3:B{3 + len(summary_df)}')
chart = ws.charts.add(left=ws.range('D3').left, top=ws.range('D3').top, width=400, height=300)
chart.set_source_data(chart_range)
chart.chart_type = xw.constants.ChartType.xlColumnClustered
chart.name = '产品销售额柱状图'
chart.api.ChartTitle.Text = '各产品销售额对比'

# 4. 保存报表
wb.save('2024年1月销售汇总报表.xlsx')
wb.close()
print("销售汇总报表已生成!")

代码说明:本案例结合了pandas库和xlwings库,pandas用于数据的汇总分析,xlwings用于Excel报表的生成和格式设置,同时还实现了图表的插入,让报表更加直观。

4.2 案例二:在Excel中调用Python函数

xlwings的一个特色功能是可以在Excel中直接调用Python函数,这对于需要在Excel中进行复杂计算的用户来说非常实用。

4.2.1 需求分析

实现一个在Excel中计算两个数的乘积和求和的功能,具体步骤如下:

  1. 编写Python函数,实现求和和乘积计算;
  2. 在Excel中通过xlwings调用这些函数;
  3. 实现Excel中数据的实时计算。

4.2.2 代码实现

步骤1:编写Python函数脚本
创建一个名为excel_functions.py的文件,输入以下代码:

import xlwings as xw

@xw.func
def add_numbers(a, b):
    """计算两个数的和"""
    return a + b

@xw.func
def multiply_numbers(a, b):
    """计算两个数的乘积"""
    return a * b

代码说明:使用@xw.func装饰器可以将普通的Python函数转换为可以在Excel中调用的函数。

步骤2:在Excel中调用Python函数

  1. 打开Excel应用程序;
  2. 点击“xlwings”选项卡(安装xlwings后会自动添加);
  3. 点击“Import Functions”按钮,选择刚才创建的excel_functions.py文件;
  4. 在Excel单元格中输入公式:
  • 求和:=add_numbers(A1, B1)
  • 乘积:=multiply_numbers(A1, B1)
  1. 在A1和B1单元格中输入数字,即可看到计算结果。

步骤3:实现实时计算
如果修改A1或B1单元格中的数值,Excel会自动调用Python函数重新计算结果,实现数据的实时更新。

五、xlwings常见问题与解决方案

5.1 问题1:运行脚本时提示“找不到Excel应用程序”

解决方案

  1. 检查电脑上是否安装了Excel应用程序,xlwings依赖于Excel的安装;
  2. 对于Windows系统,确保Excel的COM组件已注册,可以通过命令提示符运行excel.exe /regserver进行注册;
  3. 对于Mac系统,确保Excel的AppleScript权限已开启。

5.2 问题2:处理大规模数据时速度较慢

解决方案

  1. 尽量减少对单元格的逐个操作,采用批量读写的方式(如写入二维列表);
  2. 结合pandas库,先使用pandas处理数据,再将结果写入Excel;
  3. 在操作过程中可以将Excel设置为不可见模式,减少界面渲染的时间:
   import xlwings as xw
   app = xw.App(visible=False)
   wb = app.books.open('data.xlsx')
   # 进行数据处理
   wb.save()
   wb.close()
   app.quit()

5.3 问题3:保存文件时提示“文件被占用”

解决方案

  1. 检查Excel文件是否被其他程序打开,关闭相关程序后再尝试保存;
  2. 确保脚本中使用wb.close()app.quit()方法关闭工作簿和Excel应用程序;
  3. 如果仍然无法解决,可以重启电脑后再运行脚本。

六、xlwings相关资源

  • Pypi地址:https://pypi.org/project/xlwings
  • Github地址:https://github.com/xlwings/xlwings
  • 官方文档地址:https://docs.xlwings.org

关注我,每天分享一个实用的Python自动化工具。

Python实用工具Tablib:高效数据格式处理与导出指南

一、Tablib库核心概述

Tablib是一款轻量级且功能强大的Python数据处理库,其核心用途是实现不同格式数据的无缝转换、导入与导出,同时支持数据的增删改查等基础操作。它的工作原理是构建一个Dataset核心对象,该对象可以像表格一样存储结构化数据,然后通过内置的序列化方法,将数据转换为CSV、JSON、YAML、Excel等多种格式,也能反向解析这些格式的数据生成Dataset

Tablib的优点十分突出:API设计简洁直观,易于上手;支持的格式丰富,满足多场景数据交互需求;不依赖过多第三方库,运行轻量化。缺点则是对超大规模数据的处理性能一般,且高级数据筛选功能需要结合其他库实现。该库采用MIT开源许可证,用户可以自由地用于商业和非商业项目。

二、Tablib库安装与环境配置

2.1 基础安装命令

Tablib可以通过Python官方的包管理工具pip进行快速安装,对于技术小白来说,无需复杂的编译过程,只需要打开命令行终端,输入以下命令即可完成安装:

pip install tablib

2.2 扩展格式支持安装

默认安装的Tablib已经支持CSV、JSON、YAML等基础格式,但如果需要处理Excel(.xls和.xlsx)格式的数据,需要额外安装依赖库xlrdopenpyxl,输入以下命令完成扩展安装:

pip install tablib[xlsx]

这条命令会自动安装处理Excel格式所需的依赖,确保后续可以正常读写Excel文件。

2.3 安装验证

安装完成后,我们可以通过一个简单的Python脚本验证是否安装成功,在本地创建一个test_install.py文件,输入以下代码:

import tablib

# 创建一个空的Dataset对象
data = tablib.Dataset()
# 打印Tablib的版本号
print(f"Tablib版本:{tablib.__version__}")
print("安装验证成功!")

运行该脚本,如果终端输出类似Tablib版本:3.5.0 安装验证成功!的内容,就说明Tablib已经成功安装到你的Python环境中。

三、Tablib核心对象与基础操作

Tablib的核心操作围绕Dataset对象展开,Dataset可以理解为一个内存中的表格数据集,支持行和列的灵活操作,接下来我们通过具体的代码示例来讲解基础用法。

3.1 创建Dataset对象

创建Dataset对象有多种方式,包括空对象创建、基于列表数据创建、指定表头创建,以下是详细的代码示例:

3.1.1 创建空Dataset

import tablib

# 创建空的Dataset
empty_data = tablib.Dataset()
print("空Dataset对象:", empty_data)
print("Dataset的行数:", len(empty_data))
print("Dataset的列数:", empty_data.width)

代码说明:这段代码创建了一个空的Dataset对象,然后分别打印了对象本身、行数和列数。运行后会发现,空Dataset的行数为0,列数也为0。

3.1.2 基于列表数据创建Dataset

import tablib

# 定义数据列表,每个子列表代表一行数据
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
# 创建Dataset并传入数据
data = tablib.Dataset(*data_rows)
print("基于列表的Dataset:")
print(data)
print("行数:", len(data))
print("列数:", data.width)

代码说明:这里我们定义了一个包含3行3列数据的列表,通过*解包参数传入Dataset构造函数,创建出包含数据的对象。运行后可以看到,数据以表格形式输出,行数为3,列数为3。

3.1.3 指定表头创建Dataset

在实际应用中,我们通常需要给数据列指定表头,让数据更具可读性,代码示例如下:

import tablib

# 定义表头
headers = ["姓名", "年龄", "城市"]
# 定义数据行
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
# 创建带表头的Dataset
data = tablib.Dataset(*data_rows, headers=headers)
print("带表头的Dataset:")
print(data)
print("表头信息:", data.headers)

代码说明:通过在Dataset构造函数中传入headers参数,我们为数据集添加了表头。运行后输出的表格会包含表头行,并且可以通过data.headers属性获取表头信息。

3.2 Dataset数据增删改查

创建好Dataset对象后,我们可以对其中的数据进行灵活的增删改查操作,满足日常数据处理需求。

3.2.1 数据查询操作

查询操作包括获取指定行、指定列的数据,以及获取单元格数据,代码示例如下:

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 获取指定行数据(索引从0开始)
print("第1行数据:", data[0])
# 获取指定列数据(通过表头名)
print("年龄列数据:", data["年龄"])
# 获取指定单元格数据(行索引,列索引)
print("第2行第3列数据:", data[1][2])
# 遍历所有行数据
print("\n遍历所有数据行:")
for row in data:
    print(row)

代码说明:通过索引可以快速获取指定行的数据,通过表头名可以获取整列数据,通过行索引+列索引的组合可以获取单元格数据。遍历操作则可以逐个获取数据集中的每一行数据。

3.2.2 数据添加操作

添加操作包括添加单行数据、添加多行数据和添加列数据,代码示例如下:

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 添加单行数据
new_row = ["赵六", 32, "深圳"]
data.append(new_row)
print("添加单行后的数据:")
print(data)

# 添加多行数据
new_rows = [
    ["孙七", 27, "杭州"],
    ["周八", 29, "成都"]
]
data.extend(new_rows)
print("\n添加多行后的数据:")
print(data)

# 添加列数据(需要指定列名和数据)
new_column_data = [1001, 1002, 1003, 1004, 1005, 1006]
data.append_col(new_column_data, header="员工编号")
print("\n添加列后的数据:")
print(data)

代码说明:append()方法用于添加单行数据,extend()方法用于添加多行数据,append_col()方法用于添加整列数据,并且需要通过header参数指定新列的表头名。

3.2.3 数据修改操作

修改操作包括修改指定行、指定列和指定单元格的数据,代码示例如下:

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 修改指定行数据
data[1] = ["李四", 31, "重庆"]
print("修改第2行后的数据:")
print(data)

# 修改指定列数据
data["年龄"] = [26, 32, 29]
print("\n修改年龄列后的数据:")
print(data)

# 修改指定单元格数据
data[2][1] = 30
print("\n修改第3行第2列后的数据:")
print(data)

代码说明:通过索引直接赋值可以修改指定行的数据,通过表头名赋值可以修改整列数据,通过行索引+列索引赋值可以修改单元格数据。

3.2.4 数据删除操作

删除操作包括删除指定行和指定列的数据,代码示例如下:

import tablib

headers = ["姓名", "年龄", "城市", "员工编号"]
data_rows = [
    ["张三", 25, "北京", 1001],
    ["李四", 30, "上海", 1002],
    ["王五", 28, "广州", 1003]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 删除指定行数据(索引从0开始)
del data[1]
print("删除第2行后的数据:")
print(data)

# 删除指定列数据(通过表头名)
del data["员工编号"]
print("\n删除员工编号列后的数据:")
print(data)

代码说明:使用del关键字可以删除指定行或指定列的数据,删除行时传入行索引,删除列时传入列的表头名。

四、Tablib数据格式转换与导出导入

Tablib最核心的功能就是支持多种数据格式的转换、导出和导入,这也是它在实际项目中被广泛应用的原因,接下来我们详细讲解常用格式的操作方法。

4.1 CSV格式处理

CSV是一种通用的文本格式数据,常用于数据的存储和交换,Tablib支持Dataset与CSV格式的相互转换。

4.1.1 Dataset导出为CSV格式

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 将Dataset转换为CSV格式字符串
csv_data = data.csv
print("CSV格式数据:")
print(csv_data)

# 将CSV数据保存到本地文件
with open("data.csv", "w", encoding="utf-8") as f:
    f.write(csv_data)
print("\nCSV文件已保存到本地!")

代码说明:通过data.csv属性可以直接将Dataset对象转换为CSV格式的字符串,然后我们可以通过文件操作将其保存到本地,生成.csv文件。

4.1.2 CSV格式导入为Dataset

import tablib

# 从本地CSV文件读取数据
with open("data.csv", "r", encoding="utf-8") as f:
    csv_content = f.read()

# 将CSV字符串转换为Dataset对象
data = tablib.Dataset().load(csv_content, format="csv")
print("从CSV导入的Dataset数据:")
print(data)
print("表头信息:", data.headers)

代码说明:首先读取本地CSV文件的内容,然后使用load()方法,指定format="csv",将CSV字符串转换为Dataset对象,方便后续进行数据处理。

4.2 JSON格式处理

JSON是一种轻量级的数据交换格式,在Web开发和API交互中应用广泛,Tablib同样支持Dataset与JSON格式的转换。

4.2.1 Dataset导出为JSON格式

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 将Dataset转换为JSON格式字符串
json_data = data.json
print("JSON格式数据:")
print(json_data)

# 将JSON数据保存到本地文件
with open("data.json", "w", encoding="utf-8") as f:
    f.write(json_data)
print("\nJSON文件已保存到本地!")

代码说明:通过data.json属性可以快速将Dataset转换为JSON字符串,然后保存到本地生成.json文件。Tablib生成的JSON数据会以列表的形式存储每一行数据,表头作为键名。

4.2.2 JSON格式导入为Dataset

import tablib

# 从本地JSON文件读取数据
with open("data.json", "r", encoding="utf-8") as f:
    json_content = f.read()

# 将JSON字符串转换为Dataset对象
data = tablib.Dataset().load(json_content, format="json")
print("从JSON导入的Dataset数据:")
print(data)

代码说明:读取本地JSON文件内容后,使用load()方法并指定format="json",即可将JSON字符串转换为Dataset对象,实现数据的反向解析。

4.3 Excel格式处理

Excel是办公场景中最常用的数据表格格式,Tablib支持将Dataset导出为.xls.xlsx格式,也能从Excel文件中导入数据。

4.3.1 Dataset导出为Excel格式

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 将Dataset转换为Excel格式字节数据
xlsx_data = data.xlsx

# 将Excel数据保存到本地文件
with open("data.xlsx", "wb") as f:
    f.write(xlsx_data)
print("Excel文件已保存到本地!")

代码说明:与CSV、JSON不同,Excel格式的数据是字节流形式,因此需要通过data.xlsx属性获取字节数据,然后以wb(二进制写入)模式保存到本地,生成.xlsx文件。如果需要生成.xls格式文件,可以使用data.xls属性。

4.3.2 Excel格式导入为Dataset

import tablib

# 从本地Excel文件读取字节数据
with open("data.xlsx", "rb") as f:
    xlsx_content = f.read()

# 将Excel字节数据转换为Dataset对象
data = tablib.Dataset().load(xlsx_content, format="xlsx")
print("从Excel导入的Dataset数据:")
print(data)
print("表头信息:", data.headers)

代码说明:读取Excel文件时需要使用rb(二进制读取)模式,然后通过load()方法指定format="xlsx",将字节数据转换为Dataset对象。对于.xls格式文件,只需将format参数改为"xls"即可。

4.4 YAML格式处理

YAML是一种人类可读的数据序列化格式,常用于配置文件和数据交换,Tablib支持Dataset与YAML格式的转换。

4.4.1 Dataset导出为YAML格式

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 将Dataset转换为YAML格式字符串
yaml_data = data.yaml
print("YAML格式数据:")
print(yaml_data)

# 将YAML数据保存到本地文件
with open("data.yaml", "w", encoding="utf-8") as f:
    f.write(yaml_data)
print("\nYAML文件已保存到本地!")

代码说明:通过data.yaml属性可以将Dataset转换为YAML字符串,然后保存到本地生成.yaml文件。YAML格式的数据结构清晰,可读性强,适合用于配置和简单数据存储。

4.4.2 YAML格式导入为Dataset

import tablib

# 从本地YAML文件读取数据
with open("data.yaml", "r", encoding="utf-8") as f:
    yaml_content = f.read()

# 将YAML字符串转换为Dataset对象
data = tablib.Dataset().load(yaml_content, format="yaml")
print("从YAML导入的Dataset数据:")
print(data)

代码说明:读取YAML文件内容后,使用load()方法指定format="yaml",即可将YAML字符串转换为Dataset对象,实现数据的解析和处理。

五、Tablib高级功能与实际应用案例

5.1 数据筛选与排序

虽然Tablib本身没有提供复杂的筛选和排序API,但我们可以结合Python的列表推导式和内置函数,实现数据的筛选和排序功能,代码示例如下:

import tablib

headers = ["姓名", "年龄", "城市"]
data_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"],
    ["王五", 28, "广州"],
    ["赵六", 32, "深圳"],
    ["孙七", 27, "北京"]
]
data = tablib.Dataset(*data_rows, headers=headers)

# 数据筛选:筛选出城市为北京的记录
beijing_data = [row for row in data if row[2] == "北京"]
filtered_dataset = tablib.Dataset(*beijing_data, headers=headers)
print("城市为北京的筛选结果:")
print(filtered_dataset)

# 数据排序:按年龄升序排序
sorted_data = sorted(data, key=lambda x: x[1])
sorted_dataset = tablib.Dataset(*sorted_data, headers=headers)
print("\n按年龄升序排序的结果:")
print(sorted_dataset)

代码说明:通过列表推导式,我们筛选出了城市为北京的所有数据行,并创建了新的Dataset对象;使用sorted()函数和lambda表达式,我们实现了按年龄升序排序的功能,同样生成了新的排序后数据集。

5.2 多数据集合并

在实际项目中,我们可能需要将多个结构相同的Dataset合并为一个,Tablib的extend()方法可以轻松实现这个需求,代码示例如下:

import tablib

# 创建第一个数据集
headers = ["姓名", "年龄", "城市"]
data1_rows = [
    ["张三", 25, "北京"],
    ["李四", 30, "上海"]
]
data1 = tablib.Dataset(*data1_rows, headers=headers)

# 创建第二个数据集
data2_rows = [
    ["王五", 28, "广州"],
    ["赵六", 32, "深圳"]
]
data2 = tablib.Dataset(*data2_rows, headers=headers)

# 合并两个数据集
data1.extend(data2_rows)
print("合并后的数据集:")
print(data1)

代码说明:首先创建两个结构相同(表头一致)的Dataset对象,然后使用extend()方法将第二个数据集的数据行添加到第一个数据集中,实现多数据集的合并。

5.3 实际应用案例:学生成绩数据处理

假设我们需要处理一份学生成绩数据,要求完成以下任务:

  1. 创建学生成绩数据集,包含姓名、语文、数学、英语三科成绩;
  2. 计算每个学生的总分并添加到数据集中;
  3. 筛选出总分大于270分的学生;
  4. 将筛选后的结果分别导出为CSV和Excel格式文件。

实现代码如下:

import tablib

# 1. 创建学生成绩数据集
headers = ["姓名", "语文", "数学", "英语"]
score_rows = [
    ["小明", 95, 98, 92],
    ["小红", 88, 90, 95],
    ["小刚", 92, 85, 88],
    ["小丽", 96, 94, 97],
    ["小强", 85, 82, 80]
]
score_data = tablib.Dataset(*score_rows, headers=headers)
print("原始学生成绩数据:")
print(score_data)

# 2. 计算总分并添加到数据集
total_scores = [row[1] + row[2] + row[3] for row in score_data]
score_data.append_col(total_scores, header="总分")
print("\n添加总分后的成绩数据:")
print(score_data)

# 3. 筛选总分大于270分的学生
high_score_rows = [row for row in score_data if row[4] > 270]
high_score_data = tablib.Dataset(*high_score_rows, headers=score_data.headers)
print("\n总分大于270分的学生数据:")
print(high_score_data)

# 4. 导出为CSV和Excel格式文件
# 导出CSV
with open("high_score.csv", "w", encoding="utf-8") as f:
    f.write(high_score_data.csv)
# 导出Excel
with open("high_score.xlsx", "wb") as f:
    f.write(high_score_data.xlsx)
print("\n高分学生数据已导出为CSV和Excel文件!")

代码说明:这个案例结合了Tablib的基础操作和格式转换功能,完整实现了学生成绩数据的处理流程。通过计算总分、筛选数据,最终将结果导出为两种常用格式的文件,满足实际办公和数据处理需求。

六、Tablib相关资源链接

  • Pypi地址:https://pypi.org/project/Tablib
  • Github地址:https://github.com/jazzband/tablib
  • 官方文档地址:https://tablib.readthedocs.io/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:Faker库完全指南——快速生成逼真测试数据

一、Faker库核心概述

Faker是一款Python第三方库,核心用途是生成高度逼真的伪随机测试数据,涵盖姓名、地址、邮箱、手机号、文本、时间等数百种数据类型,广泛应用于软件开发、数据分析、自动化测试等场景。其工作原理是基于不同地区的本地化数据模板,通过随机算法组合生成符合现实逻辑的虚拟数据。

该库的优点十分突出:支持多语言多地区本地化配置、数据类型丰富且可自定义扩展、使用方式简单灵活;缺点则是生成的部分专业数据(如金融账号、医疗信息)不具备真实有效性,仅适用于测试环境。Faker遵循MIT开源许可证,开发者可自由用于商业和非商业项目,无授权限制。

二、Faker库安装与基础配置

2.1 安装方式

Faker的安装非常便捷,支持pipconda两种主流包管理工具,技术小白也能轻松上手。

  1. pip安装(推荐)
    打开命令行终端,输入以下命令即可完成安装:
    python pip install faker
    若需要安装指定版本(如稳定版19.6.2),可执行:
    python pip install faker==19.6.2
  2. conda安装
    如果你使用Anaconda环境,可通过conda-forge源安装:
    python conda install -c conda-forge faker

2.2 验证安装

安装完成后,我们可以通过一段简单的代码验证是否安装成功。创建一个名为test_faker.py的文件,输入以下代码:

# 导入Faker的核心类
from faker import Faker

# 初始化Faker对象
fake = Faker()

# 生成一条随机姓名数据
print("随机姓名:", fake.name())

运行该脚本,若终端输出类似随机姓名: Jennifer Davis的结果,则说明Faker库已成功安装。

2.3 本地化配置

Faker支持全球数十个国家和地区的本地化数据生成,默认生成的是英文数据,我们可以通过指定语言代码来生成符合国内习惯的中文数据。常见的语言代码包括:zh_CN(中国大陆)、zh_TW(中国台湾)、en_US(美国英语)、ja_JP(日语)等。

以下是本地化配置的代码示例:

from faker import Faker

# 初始化中文本地化的Faker对象
fake = Faker("zh_CN")

# 生成中文数据
print("中文姓名:", fake.name())
print("中文地址:", fake.address())
print("手机号码:", fake.phone_number())

运行结果示例:

中文姓名: 王芳
中文地址: 湖南省长沙市雨花区人民路88号 410007
手机号码: 13812345678

三、Faker库核心功能与代码示例

Faker库提供了数百种数据生成方法,覆盖日常开发和测试的大部分场景,我们可以将其分为基础信息类网络信息类文本数据类时间日期类专业数据类五大模块,下面逐一进行详细讲解并附上代码示例。

3.1 基础信息类数据生成

基础信息类数据是测试中最常用的类型,包括姓名、地址、手机号、身份证号、公司名称等,以下是常用方法的代码示例:

from faker import Faker

# 初始化中文Faker对象
fake = Faker("zh_CN")

# 生成姓名相关数据
print("随机姓名:", fake.name())  # 生成全名
print("姓氏:", fake.last_name())  # 生成姓氏
print("名字:", fake.first_name())  # 生成名字

# 生成地址相关数据
print("详细地址:", fake.address())  # 生成完整地址(省市区街道邮编)
print("省份:", fake.province())  # 生成省份
print("城市:", fake.city())  # 生成城市
print("街道地址:", fake.street_address())  # 生成街道地址
print("邮政编码:", fake.postcode())  # 生成邮政编码

# 生成联系方式相关数据
print("手机号码:", fake.phone_number())  # 生成中国大陆手机号
print("固定电话:", fake.phone_number())  # 生成固定电话(部分地区)
print("身份证号:", fake.ssn())  # 生成符合规则的身份证号

# 生成公司相关数据
print("公司名称:", fake.company())  # 生成公司名称
print("公司职位:", fake.job())  # 生成公司职位

代码说明:

  • fake.name():生成随机的中文全名,涵盖男女姓名,符合国内姓名命名习惯;
  • fake.address():生成的地址包含省、市、区、街道、门牌号和邮编,格式规范,适合用于用户地址测试;
  • fake.ssn():生成的身份证号符合18位的编码规则,包含地区码、出生日期码、顺序码和校验码,仅用于测试,不具备真实有效性。

运行结果示例:

随机姓名: 李强
姓氏: 张
名字: 小明
详细地址: 广东省深圳市南山区科技园路100号 518000
省份: 浙江省
城市: 杭州市
街道地址: 东湖路99号
邮政编码: 310000
手机号码: 13987654321
固定电话: 021-12345678
身份证号: 430102199001011234
公司名称: 华讯科技有限公司
公司职位: 软件工程师

3.2 网络信息类数据生成

网络信息类数据包括邮箱、URL、IP地址、用户名、密码等,适用于Web开发中的用户注册、登录测试等场景,代码示例如下:

from faker import Faker

fake = Faker("zh_CN")

# 生成邮箱相关数据
print("随机邮箱:", fake.email())  # 生成随机邮箱
print("企业邮箱:", fake.company_email())  # 生成企业邮箱
print("免费邮箱:", fake.free_email())  # 生成免费邮箱(如163、qq邮箱)

# 生成URL相关数据
print("网站URL:", fake.url())  # 生成随机网站URL
print("域名:", fake.domain_name())  # 生成域名
print("IP地址(IPv4):", fake.ipv4())  # 生成IPv4地址
print("IP地址(IPv6):", fake.ipv6())  # 生成IPv6地址

# 生成用户账号相关数据
print("用户名:", fake.user_name())  # 生成用户名
print("密码:", fake.password(length=12))  # 生成指定长度的密码
print("用户代理(UA):", fake.user_agent())  # 生成浏览器用户代理字符串

代码说明:

  • fake.email():生成的邮箱格式规范,包含用户名、@符号和域名,支持自定义域名;
  • fake.password(length=12):通过length参数指定密码长度,默认生成包含字母、数字和特殊字符的强密码;
  • fake.user_agent():生成的UA字符串符合主流浏览器(Chrome、Firefox、Safari等)的格式,适用于爬虫和Web测试。

运行结果示例:

随机邮箱: [email protected]
企业邮箱: [email protected]
免费邮箱: [email protected]
网站URL: https://www.example.com
域名: example.net
IP地址(IPv4): 192.168.1.100
IP地址(IPv6): 2001:0db8:85a3:0000:0000:8a2e:0370:7334
用户名: liqiang_88
密码: K9#p2Q7!xR3t
用户代理(UA): Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36

3.3 文本数据类数据生成

文本数据类数据包括单词、句子、段落、文章、随机字符等,适用于生成测试用的文本内容、填充数据库字段等场景,代码示例如下:

from faker import Faker

fake = Faker("zh_CN")

# 生成单词和句子
print("随机单词:", fake.word())  # 生成单个中文词语
print("随机句子:", fake.sentence())  # 生成单个中文句子
print("多个句子:", fake.sentences(nb=3))  # 生成指定数量的句子,返回列表

# 生成段落和文章
print("随机段落:", fake.paragraph())  # 生成单个中文段落
print("多个段落:", fake.paragraphs(nb=2))  # 生成指定数量的段落,返回列表
print("随机文章:", fake.text())  # 生成一篇中文文章(多个段落)

# 生成随机字符
print("随机字母(大写):", fake.random_letter().upper())  # 生成单个大写字母
print("随机数字字符串:", fake.numerify(text="###-####-#####"))  # 生成指定格式的数字字符串
print("随机字母数字混合字符串:", fake.bothify(text="??##-??##-??##"))  # 生成字母数字混合字符串

代码说明:

  • fake.sentence():生成的句子语法正确,语义通顺,长度在10-20个字符左右;
  • fake.paragraphs(nb=2)nb参数用于指定生成的段落数量,返回的是字符串列表;
  • fake.numerify()fake.bothify():支持通过占位符自定义格式,#代表数字,?代表字母,适合生成订单号、产品编号等格式固定的数据。

运行结果示例:

随机单词: 技术
随机句子: 该项目的研发工作已经进入了最后的测试阶段。
多个句子: ['公司将于下周一召开全体员工大会。', '新产品的市场反馈情况超出了预期。', '请各位同事按时提交本月的工作总结。']
随机段落: 近年来,随着人工智能技术的快速发展,越来越多的企业开始将AI技术应用到实际生产中。从智能制造到智能客服,AI技术的落地场景不断丰富,为企业带来了显著的效率提升和成本节约。同时,相关的政策支持也为AI产业的发展提供了良好的环境,推动整个行业朝着更加规范和成熟的方向迈进。
多个段落: ['在教育领域,线上学习平台的普及改变了传统的教学模式。学生可以随时随地获取学习资源,教师也可以通过大数据分析了解学生的学习情况,从而实现个性化教学。这种模式不仅提高了学习效率,还打破了地域和时间的限制,让优质教育资源能够惠及更多人群。', '随着人们生活水平的提高,健康意识也越来越强。健身、养生、有机食品等概念逐渐成为主流,相关产业也迎来了快速发展的机遇。同时,医疗技术的进步也为人们的健康提供了更有力的保障,许多疑难杂症的治疗效果得到了显著提升。']
随机文章: 数字经济是当前全球经济发展的重要趋势,它以数据为关键生产要素,以现代信息网络为主要载体,以信息通信技术融合应用、全要素数字化转型为重要推动力,促进公平与效率更加统一的新经济形态。
在数字经济的发展过程中,数据的价值日益凸显。企业通过对海量数据的分析和挖掘,可以精准把握市场需求,优化产品设计和生产流程,提升自身的核心竞争力。同时,政府也可以利用数据技术提升治理能力,实现精准施策和高效服务。
然而,数字经济的发展也面临着一些挑战,比如数据安全、隐私保护、数字鸿沟等问题。这些问题需要政府、企业和社会各界共同努力,通过完善法律法规、加强技术研发、推进普惠性政策等方式加以解决,从而推动数字经济健康可持续发展。
随机字母(大写): M
随机数字字符串: 123-4567-89012
随机字母数字混合字符串: AB12-CD34-EF56

3.4 时间日期类数据生成

时间日期类数据包括日期、时间、时间戳等,适用于测试时间相关的功能模块,如订单创建时间、用户注册时间等,代码示例如下:

from faker import Faker

fake = Faker("zh_CN")

# 生成日期相关数据
print("随机日期(YYYY-MM-DD):", fake.date())  # 生成随机日期
print("指定范围的日期:", fake.date_between(start_date="-1y", end_date="today"))  # 生成近一年的日期
print("未来日期:", fake.date_between(start_date="today", end_date="+30d"))  # 生成未来30天的日期
print("生日日期:", fake.date_of_birth(minimum_age=18, maximum_age=60))  # 生成18-60岁的生日日期

# 生成时间相关数据
print("随机时间(HH:MM:SS):", fake.time())  # 生成随机时间
print("日期时间组合:", fake.date_time())  # 生成随机日期时间
print("Unix时间戳:", fake.unix_time())  # 生成Unix时间戳
print("ISO格式日期时间:", fake.iso8601())  # 生成ISO8601格式的日期时间

# 生成时间段相关数据
print("随机月份:", fake.month())  # 生成月份(1-12)
print("随机星期:", fake.day_of_week())  # 生成星期几
print("随机年份:", fake.year())  # 生成年份

代码说明:

  • fake.date_between():通过start_dateend_date参数指定日期范围,支持相对时间(如-1y代表一年前,+30d代表30天后)和绝对时间(如2023-01-01);
  • fake.date_of_birth():通过minimum_agemaximum_age参数指定年龄范围,生成对应的生日日期;
  • fake.iso8601():生成的日期时间符合ISO8601国际标准,格式为YYYY-MM-DDTHH:MM:SS,适用于国际项目的测试。

运行结果示例:

随机日期(YYYY-MM-DD): 2020-05-18
指定范围的日期: 2024-03-25
未来日期: 2025-01-15
生日日期: 1985-08-12
随机时间(HH:MM:SS): 14:35:22
日期时间组合: 2022-11-03 09:12:34
Unix时间戳: 1678901234
ISO格式日期时间: 2023-07-15T16:20:10
随机月份: 06
随机星期: Friday
随机年份: 2019

3.5 专业数据类数据生成

专业数据类数据包括银行卡号、车牌号、颜色、文件扩展名、编程语言等,适用于特定领域的测试场景,代码示例如下:

from faker import Faker

fake = Faker("zh_CN")

# 生成金融相关数据
print("银行卡号:", fake.credit_card_number())  # 生成银行卡号
print("银行卡类型:", fake.credit_card_provider())  # 生成银行卡类型
print("银行卡有效期:", fake.credit_card_expire())  # 生成银行卡有效期

# 生成交通相关数据
print("车牌号:", fake.license_plate())  # 生成中国大陆车牌号

# 生成其他专业数据
print("颜色名称:", fake.color_name())  # 生成颜色名称
print("文件扩展名:", fake.file_extension())  # 生成文件扩展名
print("编程语言:", fake.programming_language())  # 生成编程语言名称
print("UUID:", fake.uuid4())  # 生成UUID4字符串

代码说明:

  • fake.credit_card_number():生成的银行卡号符合各大银行的编码规则,仅用于测试,不可用于真实交易;
  • fake.license_plate():生成的车牌号符合中国大陆的格式(省份简称+字母+数字);
  • fake.uuid4():生成的UUID4字符串符合标准格式,适用于生成唯一标识符。

运行结果示例:

银行卡号: 6222021234567890
银行卡类型: Mastercard
银行卡有效期: 28/12
车牌号: 粤A12345
颜色名称: Blue
文件扩展名: pdf
编程语言: Python
UUID: 550e8400-e29b-41d4-a716-446655440000

四、Faker库高级用法:自定义数据生成器

虽然Faker库提供了丰富的内置数据生成方法,但在实际开发中,我们可能会遇到一些特殊的需求,比如生成符合特定业务规则的数据(如电商平台的商品SKU、物流单号等)。这时我们可以通过自定义数据生成器来扩展Faker的功能,下面以生成电商商品SKU为例,详细讲解自定义数据生成器的实现方法。

4.1 自定义数据生成器的实现步骤

  1. 定义一个自定义生成器类,继承自faker.providers.BaseProvider
  2. 在自定义类中编写数据生成方法;
  3. 将自定义类添加到Faker对象的提供者列表中;
  4. 调用自定义方法生成数据。

4.2 代码示例:生成电商商品SKU

from faker import Faker
from faker.providers import BaseProvider

# 1. 定义自定义生成器类
class CustomSKUProvider(BaseProvider):
    def product_sku(self):
        """
        生成电商商品SKU,格式为:分类缩写-品牌缩写-年份-随机数字
        分类缩写:ELE(电子产品)、CLT(服装)、FOD(食品)
        品牌缩写:APP(苹果)、SAM(三星)、NIK(耐克)、ADID(阿迪达斯)、UNI(统一)
        """
        # 定义分类缩写列表
        category_list = ["ELE", "CLT", "FOD"]
        # 定义品牌缩写列表
        brand_list = {
            "ELE": ["APP", "SAM", "HUA", "XIA"],
            "CLT": ["NIK", "ADID", "PUMA", "ANTA"],
            "FOD": ["UNI", "MASTER", "KANG", "WEIQ"]
        }
        # 随机选择分类
        category = self.random_element(category_list)
        # 根据分类选择品牌
        brand = self.random_element(brand_list[category])
        # 生成年份(近5年)
        year = self.random_int(min=2020, max=2025)
        # 生成随机数字(4位)
        num = self.random_int(min=1000, max=9999)
        # 组合生成SKU
        sku = f"{category}-{brand}-{year}-{num}"
        return sku

# 2. 初始化Faker对象
fake = Faker("zh_CN")

# 3. 添加自定义提供者到Faker对象
fake.add_provider(CustomSKUProvider)

# 4. 调用自定义方法生成数据
for _ in range(5):
    print("自定义商品SKU:", fake.product_sku())

代码说明:

  • 自定义生成器类必须继承自BaseProvider,这是Faker库规定的扩展规范;
  • product_sku方法中,我们通过self.random_element()self.random_int()方法实现随机选择和随机数字生成,这两个方法是BaseProvider类提供的内置方法;
  • 通过fake.add_provider()方法将自定义类添加到Faker对象后,就可以像调用内置方法一样调用fake.product_sku()生成数据。

运行结果示例:

自定义商品SKU: ELE-APP-2023-4567
自定义商品SKU: CLT-NIK-2021-1234
自定义商品SKU: FOD-UNI-2025-8901
自定义商品SKU: ELE-XIA-2022-2345
自定义商品SKU: CLT-ANTA-2024-5678

4.3 自定义数据生成器的扩展场景

除了生成商品SKU,自定义数据生成器还可以应用于以下场景:

  • 生成物流单号(如顺丰:SF+12位数字,圆通:YT+12位数字);
  • 生成医院病历号(如医院代码+年份+随机数字);
  • 生成学生学号(如学校代码+年级+班级+序号)。

五、Faker库实际应用案例:生成测试用的用户数据CSV文件

在实际的软件开发和测试中,我们经常需要批量生成用户数据并保存到文件中,供自动化测试或数据库填充使用。下面以生成100条中文用户数据并保存为CSV文件为例,展示Faker库的实际应用价值。

5.1 案例需求分析

  1. 生成100条用户数据,每条数据包含:用户ID、姓名、性别、年龄、手机号、邮箱、地址、注册时间、职业;
  2. 性别为男/女,随机分布;
  3. 年龄范围为18-60岁;
  4. 注册时间为近一年的随机日期时间;
  5. 将生成的数据保存为test_users.csv文件。

5.2 代码实现

import csv
from faker import Faker
import random

# 初始化中文Faker对象
fake = Faker("zh_CN")

# 定义CSV文件的表头
headers = ["user_id", "name", "gender", "age", "phone", "email", "address", "register_time", "job"]

# 生成100条用户数据
user_data = []
for i in range(1, 101):
    # 生成用户ID
    user_id = f"user_{i:03d}"  # 格式为user_001, user_002...
    # 生成姓名
    name = fake.name()
    # 生成性别
    gender = random.choice(["男", "女"])
    # 生成年龄
    age = random.randint(18, 60)
    # 生成手机号
    phone = fake.phone_number()
    # 生成邮箱
    email = fake.email()
    # 生成地址
    address = fake.address().replace("\n", " ")  # 去除地址中的换行符
    # 生成注册时间(近一年)
    register_time = fake.date_time_between(start_date="-1y", end_date="today")
    # 生成职业
    job = fake.job()
    # 将数据添加到列表
    user_data.append([user_id, name, gender, age, phone, email, address, register_time, job])

# 将数据写入CSV文件
with open("test_users.csv", "w", encoding="utf-8", newline="") as f:
    writer = csv.writer(f)
    # 写入表头
    writer.writerow(headers)
    # 写入数据
    writer.writerows(user_data)

print("100条用户数据已成功生成并保存到test_users.csv文件中!")

代码说明:

  • 使用random.choice(["男", "女"])实现性别的随机分布;
  • user_id的格式通过f-string格式化,i:03d表示将数字格式化为3位,不足的位数用0填充;
  • 地址数据中可能包含换行符,通过replace("\n", " ")去除换行符,避免CSV文件格式错乱;
  • 写入CSV文件时,指定encoding="utf-8"以支持中文,newline=""避免出现空行。

5.3 案例运行结果

运行代码后,会在当前目录下生成一个名为test_users.csv的文件,打开后部分数据如下:
| user_id | name | gender | age | phone | email | address | register_time | job |
| – | – | – | – | – | – | – | – | – |
| user_001 | 张三 | 男 | 28 | 13812345678 | [email protected] | 北京市朝阳区建国路88号 100022 | 2024-05-12 10:23:45 | 软件工程师 |
| user_002 | 李四 | 女 | 35 | 13987654321 | [email protected] | 上海市浦东新区张江高科技园区 201203 | 2024-08-25 14:56:12 | 市场经理 |

这个案例可以直接应用于Web项目的用户模块测试,帮助测试人员快速构建测试数据,提高测试效率。

六、Faker库相关资源链接

  • Pypi地址:https://pypi.org/project/Faker
  • Github地址:https://github.com/joke2k/faker
  • 官方文档地址:https://faker.readthedocs.io/en/master/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:Datasets库快速上手指南_数据处理与加载必备神器

一、Datasets库核心概述

1.1 用途与工作原理

Hugging Face的Datasets库是一款专为自然语言处理(NLP)以及机器学习领域打造的数据集处理工具,它能够帮助开发者快速加载、预处理、转换和共享各类数据集。其核心工作原理是基于一种统一的数据集抽象结构DatasetDatasetDict,将不同来源、不同格式的数据集(如CSV、JSON、文本文件、Hugging Face Hub上的开源数据集)进行标准化封装,同时内置高效的数据处理管道,支持并行处理、流式加载和内存映射,大幅降低了数据预处理的门槛。

1.2 优缺点分析

优点

  1. 数据集资源丰富:无缝对接Hugging Face Hub,拥有上万种开源数据集,涵盖文本分类、问答、翻译等数十种任务类型。
  2. 高效便捷的预处理:内置map()filter()shuffle()等方法,支持并行处理大规模数据集,无需手动编写复杂的循环逻辑。
  3. 内存友好:支持流式加载和内存映射,即使处理GB级别的数据集也不会导致内存溢出。
  4. 跨框架兼容:可以轻松转换为pandasNumPyTensorFlowPyTorch等格式,适配主流机器学习框架。

缺点

  1. 对非文本数据支持有限:虽然可以处理图像、音频类数据集,但功能不如文本数据集完善,预处理选项较少。
  2. 网络依赖较强:加载Hugging Face Hub上的数据集需要稳定的网络环境,离线使用时需要提前下载数据集。
  3. 新手学习成本:部分高级功能(如自定义数据集加载脚本)的使用需要熟悉库的底层逻辑,对纯小白不够友好。

1.3 License类型

Datasets库采用Apache License 2.0开源协议,这意味着开发者可以自由地使用、修改和分发该库的代码,无论是商业项目还是非商业项目,都无需支付任何费用,只需要在修改后的代码中保留原作者的版权声明即可。

二、Datasets库安装与环境配置

2.1 安装命令

Datasets库的安装非常简单,支持pipconda两种方式,推荐使用pip安装,因为它的更新速度更快,适配性更强。

2.1.1 pip安装(推荐)

打开命令行终端,输入以下命令即可完成安装:

pip install datasets

如果需要处理特定格式的数据集(如音频、图像、Parquet文件),可以安装对应的依赖包:

pip install datasets[audio,vision,parquet]

2.1.2 conda安装

如果你的开发环境是基于conda管理的,可以使用以下命令安装:

conda install -c huggingface -c conda-forge datasets

2.2 环境验证

安装完成后,我们可以通过一段简单的Python代码验证是否安装成功。打开Python交互式环境或新建一个.py文件,输入以下代码:

import datasets
print(f"Datasets库版本:{datasets.__version__}")

运行代码后,如果终端输出了具体的版本号(如2.14.5),则说明安装成功;如果出现ModuleNotFoundError,则说明安装失败,需要重新检查安装命令或Python环境。

三、Datasets库核心功能与代码示例

3.1 加载开源数据集(Hugging Face Hub)

Datasets库最核心的功能之一就是一键加载Hugging Face Hub上的开源数据集。以经典的文本分类数据集imdb(电影评论情感分析数据集)为例,我们来演示如何加载并查看数据集的基本信息。

3.1.1 加载imdb数据集

from datasets import load_dataset

# 加载imdb数据集,该数据集包含train、test、unsupervised三个子集
dataset = load_dataset("imdb")

# 打印数据集的结构
print("数据集结构:", dataset)
# 打印训练集的第一条数据
print("训练集第一条数据:", dataset["train"][0])

代码说明

  • load_dataset()函数是加载数据集的核心函数,传入数据集名称即可自动从Hugging Face Hub下载并加载数据集。
  • imdb数据集是一个DatasetDict对象,包含train(训练集)、test(测试集)、unsupervised(无监督数据集)三个子集。
  • 通过下标索引dataset["train"][0]可以查看训练集的第一条数据,数据格式为字典,包含text(电影评论文本)和label(情感标签,0为负面,1为正面)两个字段。

运行结果示例

数据集结构: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
训练集第一条数据: {'text': 'I rented I AM CURIOUS-YELLOW...', 'label': 0}

3.1.2 加载指定子集与拆分数据

如果我们只需要加载imdb数据集的训练集和测试集,可以通过split参数指定;同时,还可以使用train_test_split()方法将数据集拆分为新的训练集和验证集。

from datasets import load_dataset

# 只加载train和test子集
dataset = load_dataset("imdb", split=["train", "test"])
train_dataset, test_dataset = dataset

# 将训练集拆分为训练集(80%)和验证集(20%)
train_val_dataset = train_dataset.train_test_split(test_size=0.2, seed=42)
train_new = train_val_dataset["train"]
val_new = train_val_dataset["test"]

print(f"新训练集大小:{len(train_new)}")
print(f"验证集大小:{len(val_new)}")
print(f"测试集大小:{len(test_dataset)}")

代码说明

  • split参数传入一个列表,可以指定加载的数据集子集,返回的结果是一个列表,需要手动解包为对应的数据集对象。
  • train_test_split()方法用于拆分数据集,test_size参数指定验证集的比例,seed参数用于固定随机种子,保证实验的可重复性。
  • len()函数可以查看数据集的样本数量,imdb原训练集有25000条数据,拆分后新训练集有20000条,验证集有5000条。

运行结果示例

新训练集大小:20000
验证集大小:5000
测试集大小:25000

3.2 加载本地数据集

除了加载Hugging Face Hub上的开源数据集,Datasets库还支持加载本地的CSV、JSON、文本等格式的数据集。下面我们以CSV格式为例,演示如何加载本地数据集。

3.2.1 准备本地CSV数据集

首先,我们创建一个名为local_data.csv的CSV文件,内容如下:

text,label
This movie is amazing,1
This movie is terrible,0
I love this film,1
The plot is boring,0
The acting is great,1

该文件包含两列,text列是电影评论文本,label列是情感标签。

3.2.2 加载本地CSV数据集

from datasets import load_dataset

# 加载本地CSV数据集
local_dataset = load_dataset("csv", data_files="local_data.csv")

# 打印数据集结构
print("本地数据集结构:", local_dataset)
# 打印所有数据
print("本地数据集所有数据:", local_dataset["train"])
# 访问单条数据的字段
for data in local_dataset["train"]:
    print(f"文本:{data['text']},标签:{data['label']}")

代码说明

  • 加载本地数据集时,load_dataset()函数的第一个参数需要指定数据格式(如csvjsontext),data_files参数指定本地数据文件的路径。
  • 对于单文件的CSV数据集,加载后默认会生成一个名为train的子集。
  • 可以通过遍历数据集对象,访问每条数据的具体字段。

运行结果示例

本地数据集结构: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 5
    })
})
本地数据集所有数据: Dataset({
    features: ['text', 'label'],
    num_rows: 5
})
文本:This movie is amazing,标签:1
文本:This movie is terrible,标签:0
文本:I love this film,标签:1
文本:The plot is boring,标签:0
文本:The acting is great,标签:1

3.2.3 加载多个本地数据文件

如果本地有多个数据文件(如训练集和测试集分别存储在不同的CSV文件中),可以通过data_files参数传入一个字典来指定:

from datasets import load_dataset

# 定义多个数据文件的路径
data_files = {
    "train": "train_data.csv",
    "test": "test_data.csv"
}

# 加载多个CSV文件
multi_dataset = load_dataset("csv", data_files=data_files)

print("多文件数据集结构:", multi_dataset)
print(f"训练集样本数:{len(multi_dataset['train'])}")
print(f"测试集样本数:{len(multi_dataset['test'])}")

代码说明

  • data_files参数传入一个字典时,字典的键是数据集子集的名称(如traintest),值是对应的数据文件路径。
  • 加载后会生成一个DatasetDict对象,包含指定的多个子集,方便后续分别处理训练集和测试集。

3.3 数据集预处理

加载数据集后,通常需要进行预处理(如文本分词、长度截断、特征提取等),Datasets库提供了map()方法,可以高效地对数据集进行批量预处理。下面我们以文本分词为例,演示如何使用map()方法。

3.3.1 安装分词器依赖

我们使用Hugging Face的transformers库中的分词器,需要先安装该库:

pip install transformers

3.3.2 对imdb数据集进行分词预处理

from datasets import load_dataset
from transformers import AutoTokenizer

# 加载imdb数据集的训练集
dataset = load_dataset("imdb", split="train")
# 加载预训练分词器(以bert-base-uncased为例)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# 定义预处理函数
def preprocess_function(examples):
    # 对文本进行分词,设置最大长度为128,超过截断,不足补齐
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# 对数据集进行批量预处理,开启并行处理
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4  # 使用4个进程并行处理
)

# 查看预处理后的数据集特征
print("预处理后的特征:", tokenized_dataset.features)
# 查看第一条预处理后的数据
print("第一条预处理后的数据:", tokenized_dataset[0])

代码说明

  • AutoTokenizer.from_pretrained()函数用于加载预训练分词器,bert-base-uncased是一个常用的英文预训练模型。
  • preprocess_function()是自定义的预处理函数,接收一个包含批量数据的字典examples,返回分词后的结果,包含input_idsattention_mask等字段。
  • map()方法的batched=True参数表示批量处理数据,num_proc参数指定并行处理的进程数,能够大幅提升预处理速度。
  • 预处理后的数据集会新增分词相关的字段,这些字段可以直接输入到预训练模型中进行训练。

运行结果示例

预处理后的特征: {'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['neg', 'pos'], id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None)}
第一条预处理后的数据: {'text': 'I rented I AM CURIOUS-YELLOW...', 'label': 0, 'input_ids': [101, 1045, 5975, ...], 'token_type_ids': [0, 0, 0, ...], 'attention_mask': [1, 1, 1, ...]}

3.4 数据集筛选与排序

Datasets库提供了filter()sort()方法,用于对数据集进行筛选和排序。下面我们演示如何筛选出imdb数据集中文本长度大于100的样本,以及如何按文本长度对数据集进行排序。

3.4.1 数据集筛选

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train")

# 定义筛选函数:保留文本长度大于100的样本
def filter_function(example):
    return len(example["text"]) > 100

# 对数据集进行筛选
filtered_dataset = dataset.filter(filter_function)

print(f"原数据集样本数:{len(dataset)}")
print(f"筛选后数据集样本数:{len(filtered_dataset)}")
# 查看筛选后第一条数据的文本长度
print(f"筛选后第一条数据文本长度:{len(filtered_dataset[0]['text'])}")

代码说明

  • filter_function()是自定义的筛选函数,接收单条数据example,返回True表示保留该样本,返回False表示过滤掉该样本。
  • filter()方法会遍历数据集的所有样本,根据筛选函数的结果保留符合条件的样本。

运行结果示例

原数据集样本数:25000
筛选后数据集样本数:24890
筛选后第一条数据文本长度:178

3.4.2 数据集排序

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train")

# 为数据集添加文本长度字段
def add_length_field(example):
    example["text_length"] = len(example["text"])
    return example

dataset_with_length = dataset.map(add_length_field)

# 按文本长度升序排序
sorted_dataset_asc = dataset_with_length.sort("text_length")
# 按文本长度降序排序
sorted_dataset_desc = dataset_with_length.sort("text_length", reverse=True)

print(f"最短文本长度:{sorted_dataset_asc[0]['text_length']}")
print(f"最长文本长度:{sorted_dataset_desc[0]['text_length']}")

代码说明

  • 首先通过map()方法为数据集添加一个text_length字段,存储每条数据的文本长度。
  • sort()方法接收一个字段名,按照该字段的值对数据集进行排序,reverse=True表示降序排序,默认是升序排序。

运行结果示例

最短文本长度:14
最长文本长度:13704

3.5 数据集格式转换

Datasets库的数据集对象可以轻松转换为pandasNumPyPyTorchTensorFlow等格式,适配不同的机器学习框架。下面我们演示如何将数据集转换为这些格式。

3.5.1 转换为pandas DataFrame格式

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train[:100]")  # 只取前100条数据

# 转换为pandas DataFrame
df = dataset.to_pandas()

# 查看DataFrame的前5行数据
print(df.head())
# 查看DataFrame的基本信息
print(df.info())

代码说明

  • to_pandas()方法可以将数据集转换为pandasDataFrame对象,方便使用pandas进行数据分析和可视化。
  • 为了避免数据量过大,我们通过split="train[:100]"只取训练集的前100条数据。

运行结果示例

                                                text  label
0  I rented I AM CURIOUS-YELLOW...                    0
1  "I Am Curious: Yellow" is a...                    0
2  If only to avoid making this...                    0
3  This film was probably intende...                  0
4  I saw this movie when I was ...                   0
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
    --  -- 
 0   text    100 non-null    object
 1   label   100 non-null    int64 
dtypes: int64(1), object(1)
memory usage: 1.7+ KB
None

3.5.2 转换为NumPy数组格式

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train[:100]")

# 转换为NumPy数组
numpy_array = dataset.to_numpy()

# 查看NumPy数组的形状和第一条数据
print(f"NumPy数组形状:{numpy_array.shape}")
print(f"第一条数据:{numpy_array[0]}")

代码说明

  • to_numpy()方法将数据集转换为NumPy数组,数组中的每个元素是一个字典,包含数据集的所有字段。

运行结果示例

NumPy数组形状:(100,)
第一条数据: {'text': 'I rented I AM CURIOUS-YELLOW...', 'label': 0}

3.5.3 转换为PyTorch张量格式

from datasets import load_dataset
from transformers import AutoTokenizer

# 加载imdb训练集并进行分词预处理
dataset = load_dataset("imdb", split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 转换为PyTorch张量格式
torch_dataset = tokenized_dataset.with_format("torch")

# 查看张量格式的输入ids和标签
print(f"Input IDs张量形状:{torch_dataset['input_ids'].shape}")
print(f"标签张量形状:{torch_dataset['label'].shape}")

代码说明

  • with_format("torch")方法将数据集转换为PyTorch张量格式,转换后可以直接通过下标访问张量数据。
  • 转换后的数据集可以直接输入到PyTorch模型中进行训练,无需手动转换数据类型。

运行结果示例

Input IDs张量形状: torch.Size([100, 128])
标签张量形状: torch.Size([100])

四、实际案例:基于Datasets库的文本分类任务

为了让大家更好地理解Datasets库的实际应用,我们结合transformers库搭建一个简单的文本分类模型,完成imdb电影评论情感分析任务。

4.1 案例流程

  1. 加载imdb数据集并进行预处理。
  2. 加载预训练模型和分词器。
  3. 构建训练参数并训练模型。
  4. 在测试集上评估模型性能。

4.2 完整代码实现

# 导入必要的库
from datasets import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
import numpy as np

# 步骤1:加载数据集并进行预处理
# 加载imdb的训练集和测试集
dataset = load_dataset("imdb")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# 加载预训练分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# 定义预处理函数
def preprocess_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# 对训练集和测试集进行预处理
tokenized_train = train_dataset.map(preprocess_function, batched=True, num_proc=4)
tokenized_test = test_dataset.map(preprocess_function, batched=True, num_proc=4)

# 设置数据集格式为PyTorch张量
tokenized_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"]
)
tokenized_test.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"]
)

# 步骤2:加载预训练模型
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2  # 二分类任务,标签数为2
)

# 步骤3:定义评估指标
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# 步骤4:设置训练参数
training_args = TrainingArguments(
    output_dir="./imdb_sentiment_model",  # 模型输出目录
    num_train_epochs=2,  # 训练轮数
    per_device_train_batch_size=16,  # 训练批次大小
    per_device_eval_batch_size=16,  # 评估批次大小
    evaluation_strategy="epoch",  # 每轮训练后进行评估
    save_strategy="epoch",  # 每轮训练后保存模型
    logging_dir="./logs",  # 日志目录
    logging_steps=100,  # 每100步记录一次日志
    learning_rate=2e-5,  # 学习率
    weight_decay=0.01,  # 权重衰减
)

# 步骤5:构建Trainer并开始训练
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    compute_metrics=compute_metrics,
)

# 开始训练
trainer.train()

# 步骤6:在测试集上评估模型
eval_results = trainer.evaluate()
print(f"测试集准确率:{eval_results['eval_accuracy']:.4f}")

4.3 代码说明

  1. 数据集预处理:我们使用bert-base-uncased分词器对文本进行分词,设置最大长度为128,并将数据集转换为PyTorch张量格式。
  2. 模型加载AutoModelForSequenceClassification是一个用于序列分类任务的预训练模型,num_labels=2表示二分类任务。
  3. 评估指标:使用accuracy指标评估模型性能,compute_metrics函数用于计算预测结果的准确率。
  4. 训练参数设置TrainingArguments包含了训练过程中的所有参数,如训练轮数、批次大小、学习率等。
  5. 模型训练与评估Trainertransformers库提供的训练工具,调用train()方法开始训练,evaluate()方法在测试集上评估模型性能。

4.4 预期结果

经过2轮训练后,模型在imdb测试集上的准确率可以达到90%以上,具体数值会因硬件环境和随机种子的不同而略有差异。

五、相关资源链接

  • Pypi地址:https://pypi.org/project/datasets
  • Github地址:https://github.com/huggingface/datasets
  • 官方文档地址:https://huggingface.co/docs/datasets/index

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:psycopg3 完全指南——PostgreSQL数据库交互最佳实践

一、psycopg3 库概述

psycopg3 是 Python 编程语言中用于连接和操作 PostgreSQL 数据库的高性能适配器,也是 psycopg2 的官方升级版本。它遵循 Python DB API 2.0 规范,能够实现 Python 程序与 PostgreSQL 数据库的高效数据交互,支持异步操作、类型适配、事务管理等核心功能。

工作原理上,psycopg3 通过底层的 libpq 库(PostgreSQL 官方客户端库)建立 Python 与数据库的通信通道,将 Python 数据类型转换为 PostgreSQL 支持的类型,同时把数据库返回的结果转换为 Python 原生对象,实现双向数据无缝流转。

该库的优点十分突出:支持异步 I/O 操作,适配 asyncio 框架;性能较 psycopg2 大幅提升;对 Python 3.8+ 版本兼容性好;支持 PostgreSQL 10+ 所有新特性;提供灵活的参数化查询,有效防止 SQL 注入攻击。缺点则是目前生态插件较少,部分旧项目迁移需要修改代码;对低版本 Python(3.7 及以下)不支持。

psycopg3 的开源协议为 GNU Lesser General Public License (LGPL) 3.0,允许用户自由使用、修改和分发,可用于商业项目开发。

二、psycopg3 安装与环境配置

2.1 前置条件

在安装 psycopg3 之前,需要确保系统满足以下两个核心条件:

  1. Python 环境:版本 3.8 或更高,推荐使用 3.10+ 稳定版本。
  2. PostgreSQL 环境:本地或远程服务器已安装 PostgreSQL 10 或更高版本,且确保数据库服务处于运行状态。
  3. 依赖库:系统需要安装 libpq 开发包,不同操作系统安装命令如下:
    • Ubuntu/Debian 系统
      bash sudo apt-get install libpq-dev python3-dev
    • CentOS/RHEL 系统
      bash sudo yum install postgresql-devel python3-devel
    • Windows 系统
      无需手动安装,psycopg3 的 Windows 版本已内置相关依赖。

2.2 安装 psycopg3

psycopg3 已发布到 PyPI 仓库,可通过 pip 包管理器一键安装,这是最简单且推荐的方式。

2.2.1 基础安装命令

打开终端或命令提示符,执行以下命令:

pip install psycopg3

2.2.2 指定版本安装

如果需要安装特定版本的 psycopg3(例如 3.1.12),可以执行:

pip install psycopg3==3.1.12

2.2.3 验证安装是否成功

安装完成后,可通过 Python 交互式环境验证是否安装成功。打开 Python 终端,输入以下代码:

import psycopg
print(psycopg.__version__)

如果终端输出 psycopg3 的版本号(如 3.1.12),则说明安装成功;若出现 ModuleNotFoundError 错误,则需要检查 pip 安装路径或 Python 环境配置。

三、psycopg3 核心用法与代码示例

psycopg3 的核心操作围绕 数据库连接游标对象数据查询数据增删改事务管理 展开,下面结合具体代码示例详细讲解每个功能的使用方法。

3.1 数据库连接与关闭

要操作 PostgreSQL 数据库,第一步是建立 Python 程序与数据库的连接。psycopg3 提供 psycopg.connect() 函数创建连接对象,连接参数需要包含数据库地址、端口、用户名、密码、数据库名等信息。

3.1.1 基础连接示例

import psycopg

# 定义数据库连接参数
conn_params = {
    "host": "localhost",  # 数据库服务器地址,本地为localhost
    "port": 5432,         # PostgreSQL 默认端口为5432
    "user": "postgres",   # 数据库用户名,默认管理员用户为postgres
    "password": "123456", # 数据库密码,安装时设置
    "dbname": "testdb"    # 要连接的数据库名,需提前创建
}

# 建立数据库连接
try:
    conn = psycopg.connect(**conn_params)
    print("数据库连接成功!")
except psycopg.OperationalError as e:
    print(f"数据库连接失败:{e}")
finally:
    # 关闭数据库连接
    if conn:
        conn.close()
        print("数据库连接已关闭!")

代码说明

  • psycopg.connect() 函数接收关键字参数,返回一个 Connection 对象,代表与数据库的会话。
  • 使用 try-except 捕获连接异常(如密码错误、数据库不存在、服务未启动等),避免程序崩溃。
  • 最后通过 conn.close() 关闭连接,释放数据库资源,这是必须执行的步骤。

3.1.2 使用上下文管理器(推荐)

手动关闭连接容易遗漏,psycopg3 支持使用 with 语句(上下文管理器)自动管理连接的创建和关闭,这是更优雅、更推荐的写法。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

# 使用with语句自动管理连接
with psycopg.connect(**conn_params) as conn:
    print("数据库连接成功!")
    # 后续数据库操作写在这里
print("数据库连接已自动关闭!")

代码说明

  • with 代码块执行完毕后,无论是否发生异常,连接都会自动关闭,无需手动调用 conn.close()
  • 这种写法可以有效避免因忘记关闭连接导致的资源泄露问题。

3.2 游标对象与基本数据操作

建立数据库连接后,需要通过 游标对象(Cursor) 执行 SQL 语句。游标是数据库操作的核心接口,负责提交 SQL 命令并获取执行结果。

3.2.1 创建游标对象

with 语句块内,通过 conn.cursor() 方法创建游标对象:

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    # 创建游标对象
    with conn.cursor() as cur:
        print("游标对象创建成功!")

代码说明

  • 游标对象也支持 with 语句,执行完毕后自动关闭,释放游标资源。
  • 所有 SQL 语句的执行都需要通过游标对象的方法实现。

3.2.2 创建数据表

下面通过游标执行 CREATE TABLE 语句,创建一个名为 students 的数据表,用于存储学生信息。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义创建数据表的SQL语句
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS students (
            id SERIAL PRIMARY KEY,
            name VARCHAR(50) NOT NULL,
            age INTEGER NOT NULL,
            gender VARCHAR(10),
            score NUMERIC(5, 2)
        );
        """
        # 执行SQL语句
        cur.execute(create_table_sql)
        # 提交事务(psycopg3默认开启事务,执行修改操作后需要提交)
        conn.commit()
        print("数据表students创建成功!")

代码说明

  • CREATE TABLE IF NOT EXISTS 表示如果数据表不存在则创建,避免重复创建导致的错误。
  • SERIAL 类型是 PostgreSQL 的自增整数类型,作为主键使用。
  • cur.execute(sql) 方法用于执行 SQL 语句,参数为字符串格式的 SQL 命令。
  • conn.commit() 用于提交事务,psycopg3 默认处于事务模式,所有对数据库的修改操作(创建表、插入、更新、删除)都需要提交事务才能生效。

3.2.3 插入数据

向数据表中插入数据有两种方式:单条数据插入和批量数据插入,psycopg3 推荐使用 参数化查询 方式,避免 SQL 注入攻击。

(1)单条数据插入
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义插入数据的SQL语句,使用%s作为参数占位符
        insert_sql = """
        INSERT INTO students (name, age, gender, score)
        VALUES (%s, %s, %s, %s);
        """
        # 定义要插入的数据
        student_data = ("张三", 18, "男", 95.5)
        # 执行插入操作
        cur.execute(insert_sql, student_data)
        # 提交事务
        conn.commit()
        print(f"成功插入 {cur.rowcount} 条数据!")

代码说明

  • SQL 语句中使用 %s 作为参数占位符,无论参数类型是什么,都统一使用 %s,这是 psycopg3 的固定语法。
  • cur.execute() 方法的第二个参数是一个元组,包含要插入的具体数据,元组长度必须与占位符数量一致。
  • cur.rowcount 属性返回受上一条 SQL 语句影响的行数,用于验证数据是否插入成功。
(2)批量数据插入

如果需要插入多条数据,使用 cur.executemany() 方法,效率远高于多次执行 cur.execute()

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        insert_sql = """
        INSERT INTO students (name, age, gender, score)
        VALUES (%s, %s, %s, %s);
        """
        # 定义多条学生数据,列表中每个元素是一个元组
        students_data = [
            ("李四", 19, "女", 92.0),
            ("王五", 18, "男", 88.5),
            ("赵六", 20, "女", 96.0)
        ]
        # 执行批量插入
        cur.executemany(insert_sql, students_data)
        conn.commit()
        print(f"成功插入 {cur.rowcount} 条数据!")

代码说明

  • cur.executemany() 方法接收两个参数:SQL 语句和包含多个元组的列表。
  • 批量插入可以减少与数据库的交互次数,大幅提升数据插入效率,适合大量数据插入场景。

3.2.4 查询数据

查询数据是数据库操作中最常用的功能,psycopg3 提供了多种方式获取查询结果,包括 cur.fetchone()cur.fetchmany()cur.fetchall()

(1)获取所有数据(fetchall)
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义查询SQL语句
        select_sql = "SELECT * FROM students;"
        # 执行查询
        cur.execute(select_sql)
        # 获取所有查询结果
        all_students = cur.fetchall()
        # 遍历结果并打印
        print("所有学生信息:")
        for student in all_students:
            print(f"ID: {student[0]}, 姓名: {student[1]}, 年龄: {student[2]}, 性别: {student[3]}, 分数: {student[4]}")

代码说明

  • cur.fetchall() 方法返回一个列表,列表中的每个元素是一个元组,对应数据表中的一行数据。
  • 元组中的元素顺序与 SQL 查询语句中的字段顺序一致,例如 SELECT * 表示按数据表字段顺序返回所有字段。
(2)获取单条数据(fetchone)
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        select_sql = "SELECT name, score FROM students WHERE age = %s;"
        # 查询年龄为18的学生
        cur.execute(select_sql, (18,))
        # 获取第一条匹配的数据
        student = cur.fetchone()
        if student:
            print(f"18岁学生信息:姓名={student[0]}, 分数={student[1]}")
        else:
            print("未找到18岁的学生!")

代码说明

  • cur.fetchone() 方法返回查询结果的第一条数据,类型为元组;如果没有匹配数据,返回 None
  • 该方法适合只需要获取单条数据的场景,例如根据主键查询数据。
(3)获取指定数量数据(fetchmany)
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        select_sql = "SELECT name, age FROM students ORDER BY score DESC;"
        cur.execute(select_sql)
        # 获取前2条数据
        top_students = cur.fetchmany(2)
        print("分数最高的2名学生:")
        for student in top_students:
            print(f"姓名: {student[0]}, 年龄: {student[1]}")

代码说明

  • cur.fetchmany(size) 方法接收一个整数参数 size,表示要获取的数据条数,返回一个包含指定数量元组的列表。
  • 该方法适合分页查询场景,避免一次性获取大量数据导致内存占用过高。

3.2.5 更新数据

更新数据使用 UPDATE 语句,同样需要使用参数化查询,并提交事务才能生效。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义更新SQL语句,将张三的分数更新为98.0
        update_sql = "UPDATE students SET score = %s WHERE name = %s;"
        cur.execute(update_sql, (98.0, "张三"))
        conn.commit()
        print(f"成功更新 {cur.rowcount} 条数据!")

代码说明

  • UPDATE 语句中的 WHERE 子句用于指定更新条件,避免误更新整个数据表的所有数据。
  • cur.rowcount 返回被更新的行数,可用于验证更新操作是否成功。

3.2.6 删除数据

删除数据使用 DELETE 语句,同样需要指定条件,防止误删所有数据。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义删除SQL语句,删除分数低于90的学生
        delete_sql = "DELETE FROM students WHERE score < %s;"
        cur.execute(delete_sql, (90.0,))
        conn.commit()
        print(f"成功删除 {cur.rowcount} 条数据!")

代码说明

  • DELETE 语句中的 WHERE 子句是关键,若省略 WHERE 子句,将删除数据表中的所有数据。
  • 删除操作不可逆,执行前需谨慎确认条件是否正确。

3.3 事务管理

事务是数据库操作的基本单位,具有 原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)、持久性(Durability) 四个特性,简称 ACID。psycopg3 默认开启事务,所有修改操作都需要通过 conn.commit() 提交,若发生错误,可通过 conn.rollback() 回滚事务,撤销所有未提交的操作。

3.3.1 事务提交与回滚示例

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        try:
            # 插入一条数据
            insert_sql = "INSERT INTO students (name, age, gender, score) VALUES (%s, %s, %s, %s);"
            cur.execute(insert_sql, ("孙七", 19, "男", 85.0))

            # 故意触发错误(例如插入重复主键,这里用错误的SQL语句模拟)
            error_sql = "INSERT INTO students (id, name) VALUES (1, '孙七');"
            cur.execute(error_sql)

            # 提交事务(如果前面出现错误,这行代码不会执行)
            conn.commit()
            print("事务提交成功!")
        except psycopg.Error as e:
            # 发生错误,回滚事务
            conn.rollback()
            print(f"事务执行失败,已回滚:{e}")

代码说明

  • 当事务中的任何一个操作发生错误时,except 块会捕获异常,并执行 conn.rollback(),撤销所有已执行但未提交的操作。
  • 事务回滚可以保证数据库数据的一致性,避免因部分操作成功、部分操作失败导致的数据错乱。

3.4 异步操作

psycopg3 支持异步操作,通过 psycopg.asyncpg 模块实现,适配 Python 的 asyncio 框架,适合高并发的异步应用场景(如异步 Web 框架 FastAPI)。

3.4.1 异步连接与数据查询示例

import asyncio
import psycopg
from psycopg import AsyncConnection

# 定义异步函数
async def async_query():
    conn_params = {
        "host": "localhost",
        "port": 5432,
        "user": "postgres",
        "password": "123456",
        "dbname": "testdb"
    }
    # 建立异步数据库连接
    async with await psycopg.AsyncConnection.connect(**conn_params) as conn:
        # 创建异步游标
        async with conn.cursor() as cur:
            # 执行异步查询
            await cur.execute("SELECT name, score FROM students;")
            # 获取查询结果
            results = await cur.fetchall()
            print("异步查询结果:")
            for name, score in results:
                print(f"姓名: {name}, 分数: {score}")

# 运行异步函数
if __name__ == "__main__":
    asyncio.run(async_query())

代码说明

  • 异步操作需要使用 psycopg.AsyncConnection 类,通过 await 关键字执行异步方法。
  • async with 语句用于管理异步连接和异步游标,自动处理资源的创建和释放。
  • 异步操作可以显著提升高并发场景下的程序性能,避免因同步等待数据库响应导致的阻塞。

四、实际应用案例:学生成绩管理系统

下面结合一个实际案例,展示如何使用 psycopg3 开发一个简单的学生成绩管理系统,实现学生信息的添加、查询、更新和删除功能。

4.1 案例需求

  1. 能够添加新学生的成绩信息。
  2. 能够根据学生姓名查询成绩。
  3. 能够根据学生 ID 更新成绩。
  4. 能够根据学生姓名删除学生信息。
  5. 所有操作需要包含异常处理,确保程序稳定性。

4.2 完整代码实现

import psycopg
from typing import Optional, List, Tuple

class StudentScoreManager:
    def __init__(self, conn_params: dict):
        """初始化管理器,接收数据库连接参数"""
        self.conn_params = conn_params

    def add_student(self, name: str, age: int, gender: str, score: float) -> bool:
        """添加新学生信息
        Args:
            name: 学生姓名
            age: 学生年龄
            gender: 学生性别
            score: 学生分数
        Returns:
            添加成功返回True,失败返回False
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    insert_sql = """
                    INSERT INTO students (name, age, gender, score)
                    VALUES (%s, %s, %s, %s);
                    """
                    cur.execute(insert_sql, (name, age, gender, score))
                    conn.commit()
                    print(f"添加学生 {name} 成功!")
                    return True
        except psycopg.Error as e:
            print(f"添加学生失败:{e}")
            return False

    def query_student(self, name: str) -> Optional[Tuple]:
        """根据姓名查询学生信息
        Args:
            name: 学生姓名
        Returns:
            找到学生返回元组,未找到返回None
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    select_sql = """
                    SELECT id, name, age, gender, score FROM students WHERE name = %s;
                    """
                    cur.execute(select_sql, (name,))
                    student = cur.fetchone()
                    if student:
                        print(f"查询到学生信息:ID={student[0]}, 姓名={student[1]}, 年龄={student[2]}, 性别={student[3]}, 分数={student[4]}")
                        return student
                    else:
                        print(f"未找到姓名为 {name} 的学生!")
                        return None
        except psycopg.Error as e:
            print(f"查询学生失败:{e}")
            return None

    def update_score(self, student_id: int, new_score: float) -> bool:
        """根据学生ID更新分数
        Args:
            student_id: 学生ID
            new_score: 新分数
        Returns:
            更新成功返回True,失败返回False
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    update_sql = "UPDATE students SET score = %s WHERE id = %s;"
                    cur.execute(update_sql, (new_score, student_id))
                    conn.commit()
                    if cur.rowcount > 0:
                        print(f"更新ID为 {student_id} 的学生分数成功!")
                        return True
                    else:
                        print(f"未找到ID为 {student_id} 的学生!")
                        return False
        except psycopg.Error as e:
            print(f"更新分数失败:{e}")
            return False

    def delete_student(self, name: str) -> bool:
        """根据姓名删除学生信息
        Args:
            name: 学生姓名
        Returns:
            删除成功返回True,失败返回False
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    delete_sql = "DELETE FROM students WHERE name = %s;"
                    cur.execute(delete_sql, (name,))
                    conn.commit()
                    if cur.rowcount > 0:
                        print(f"删除学生 {name} 成功!")
                        return True
                    else:
                        print(f"未找到姓名为 {name} 的学生!")
                        return False
        except psycopg.Error as e:
            print(f"删除学生失败:{e}")
            return False

    def list_all_students(self) -> List[Tuple]:
        """查询所有学生信息
        Returns:
            包含所有学生信息的列表
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT * FROM students;")
                    all_students = cur.fetchall()
                    print(f"共查询到 {len(all_students)} 名学生信息:")
                    for student in all_students:
                        print(f"ID: {student[0]}, 姓名: {student[1]}, 年龄: {student[2]}, 性别: {student[3]}, 分数: {student[4]}")
                    return all_students
        except psycopg.Error as e:
            print(f"查询所有学生失败:{e}")
            return []

# 主程序入口
if __name__ == "__main__":
    # 配置数据库连接参数
    conn_params = {
        "host": "localhost",
        "port": 5432,
        "user": "postgres",
        "password": "123456",
        "dbname": "testdb"
    }

    # 创建学生成绩管理器实例
    manager = StudentScoreManager(conn_params)

    # 执行各项操作
    manager.add_student("周八", 19, "男", 91.5)
    manager.query_student("周八")
    manager.update_score(5, 94.0)  # 假设周八的ID为5
    manager.list_all_students()
    manager.delete_student("周八")
    manager.list_all_students()

代码说明

  • 该案例通过面向对象的方式封装了学生成绩管理的核心功能,每个方法对应一个具体的数据库操作。
  • 每个方法都包含了异常处理,确保程序在遇到数据库错误时不会崩溃,同时输出错误信息便于调试。
  • 主程序入口创建了管理器实例,并依次执行添加、查询、更新、删除和列表查询操作,展示了完整的业务流程。

五、psycopg3 高级特性

5.1 类型适配

psycopg3 能够自动将 Python 数据类型转换为 PostgreSQL 支持的类型,同时也能将 PostgreSQL 类型转换为 Python 类型。例如:

  • Python 的 int → PostgreSQL 的 INTEGER
  • Python 的 float → PostgreSQL 的 NUMERIC
  • Python 的 str → PostgreSQL 的 VARCHAR
  • Python 的 datetime.date → PostgreSQL 的 DATE

如果需要自定义类型适配,可以使用 psycopg.types 模块进行扩展。

5.2 连接池

在高并发应用中,频繁创建和关闭数据库连接会消耗大量资源,psycopg3 推荐使用连接池技术管理数据库连接。可以使用第三方库 psycopg_pool(psycopg 官方提供的连接池库)实现连接池功能。

5.2.1 安装 psycopg_pool

pip install psycopg_pool

5.2.2 连接池使用示例

import psycopg
from psycopg_pool import ConnectionPool

# 创建连接池
conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

pool = ConnectionPool(conninfo=conn_params, min_size=2, max_size=10)

# 从连接池获取连接
with pool.connection() as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT name FROM students LIMIT 2;")
        print(cur.fetchall())

# 关闭连接池
pool.close()

代码说明

  • ConnectionPool 类的 min_size 参数表示连接池的最小连接数,max_size 表示最大连接数。
  • 使用 pool.connection() 从连接池获取连接,使用完毕后自动归还到连接池,无需手动关闭。

六、相关资源链接

  • Pypi地址:https://pypi.org/project/psycopg3
  • Github地址:https://github.com/psycopg/psycopg
  • 官方文档地址:https://www.psycopg.org/psycopg3/docs/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:aioprometheus异步监控指标采集与上报教程

一、aioprometheus库核心概述

1.1 用途与工作原理

aioprometheus是一款基于Python asyncio框架开发的异步指标采集与上报库,专门用于为异步Python应用程序提供Prometheus监控指标的定义、收集和暴露功能。其工作原理遵循Prometheus的监控规范,通过定义Counter(计数器)、Gauge(仪表盘)、Summary(摘要)、Histogram(直方图)四种核心指标类型,在异步代码中埋点采集数据,再通过HTTP服务暴露指标接口,供Prometheus服务器定期拉取数据,最终实现对异步应用的性能监控与状态分析。

1.2 优缺点分析

优点

  1. 完全异步化设计,与aiohttp等异步框架完美兼容,不会阻塞事件循环,适合高并发异步应用场景;
  2. 原生支持Prometheus的四种核心指标类型,满足绝大多数监控需求;
  3. 提供灵活的指标标签(label)机制,支持多维度指标分析;
  4. 轻量级架构,安装简单,运行时资源消耗低。

缺点

  1. 文档相对精简,对于新手而言部分高级用法需要阅读源码理解;
  2. 生态相较于同步的prometheus_client库稍窄,第三方集成插件较少;
  3. 仅支持异步Python环境,无法直接用于同步应用程序。

1.3 License类型

aioprometheus采用MIT License开源协议,这意味着开发者可以自由地将其用于个人、商业项目,允许修改、分发源码,只需保留原始版权声明即可。

二、aioprometheus安装与环境准备

2.1 安装命令

aioprometheus库已发布至PyPI,可通过pip包管理工具直接安装,建议使用Python 3.7及以上版本(兼容asyncio的全部特性),安装命令如下:

pip install aioprometheus
# 如需使用aiohttp集成功能,可安装扩展依赖
pip install aioprometheus[aiohttp]

2.2 环境验证

安装完成后,可通过以下Python代码验证环境是否配置成功,该代码会创建一个简单的Counter指标并打印,确认库能够正常导入和使用:

import asyncio
from aioprometheus import Counter, Registry

async def verify_environment():
    # 创建指标注册表
    registry = Registry()
    # 定义Counter指标,用于统计请求次数
    http_requests_total = Counter(
        "http_requests_total",
        "Total number of HTTP requests",
        {"method": "GET", "endpoint": "/api"}
    )
    # 将指标注册到注册表
    registry.register(http_requests_total)
    # 增加指标计数
    http_requests_total.inc()
    # 打印指标信息
    print("指标信息:", http_requests_total.samples)

if __name__ == "__main__":
    asyncio.run(verify_environment())

运行上述代码,若控制台输出类似指标信息: [Sample(name='http_requests_total', labels={'method': 'GET', 'endpoint': '/api'}, value=1.0)]的内容,则说明aioprometheus已成功安装并可正常使用。

三、aioprometheus核心指标类型与使用方法

Prometheus的监控体系基于四种核心指标类型,aioprometheus对这四种类型进行了完整的异步封装,下面我们逐一讲解每种指标的定义、使用场景和代码示例。

3.1 Counter(计数器)

适用场景:用于统计只增不减的数值,例如请求次数、错误发生次数、任务完成数量等。Counter的核心操作是inc(),用于将指标值增加指定数值(默认增加1)。

代码示例:统计异步Web服务的GET请求次数

import asyncio
from aioprometheus import Counter, Registry, render
from aiohttp import web

# 1. 创建全局注册表
registry = Registry()

# 2. 定义Counter指标
# 参数说明:指标名、指标帮助信息、默认标签
http_requests_total = Counter(
    "http_requests_total",
    "Total count of HTTP requests by method and endpoint",
    labelnames=["method", "endpoint"]
)

# 3. 将指标注册到注册表
registry.register(http_requests_total)

# 4. 定义异步请求处理函数
async def handle_api_request(request):
    # 获取请求方法和路径
    method = request.method
    endpoint = request.path
    # 增加Counter计数,传入标签值
    http_requests_total.inc({"method": method, "endpoint": endpoint})
    return web.json_response({"status": "success", "data": "Hello, aioprometheus!"})

# 5. 定义指标暴露接口,供Prometheus拉取
async def metrics_handler(request):
    # 生成Prometheus格式的指标数据
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 6. 创建aiohttp应用并配置路由
async def create_app():
    app = web.Application()
    # 业务接口
    app.add_routes([web.get("/api/hello", handle_api_request)])
    # 指标暴露接口
    app.add_routes([web.get("/metrics", metrics_handler)])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8080)

代码说明

  • 首先创建Registry对象,用于管理所有指标,这是aioprometheus的核心管理组件;
  • 定义Counter指标时,通过labelnames参数指定标签维度,支持后续按请求方法和接口路径统计;
  • 在请求处理函数handle_api_request中,调用inc()方法增加计数,并传入当前请求的标签值;
  • 配置/metrics接口,通过render函数将注册表中的指标数据转换为Prometheus可识别的格式;
  • 运行程序后,访问http://localhost:8080/api/hello触发请求计数,再访问http://localhost:8080/metrics即可查看指标数据。

3.2 Gauge(仪表盘)

适用场景:用于统计可增可减的数值,例如当前内存使用量、活跃连接数、队列长度等。Gauge支持inc()(增加)、dec()(减少)、set()(直接设置值)、set_to_current_time()(设置为当前时间戳)等操作。

代码示例:监控异步任务队列的长度

import asyncio
import random
from aioprometheus import Gauge, Registry, render
from aiohttp import web

# 创建注册表和Gauge指标
registry = Registry()
task_queue_length = Gauge(
    "task_queue_length",
    "Current number of tasks in the async queue",
    labelnames=["queue_name"]
)
registry.register(task_queue_length)

# 模拟异步任务队列
task_queue = asyncio.Queue()

# 生产任务:向队列中添加随机数量的任务
async def task_producer():
    while True:
        # 随机生成1-5个任务
        task_count = random.randint(1, 5)
        for _ in range(task_count):
            await task_queue.put(f"task_{asyncio.get_event_loop().time()}")
        # 更新Gauge指标:设置为当前队列长度
        queue_len = task_queue.qsize()
        task_queue_length.set({"queue_name": "user_task_queue"}, queue_len)
        print(f"Added {task_count} tasks, current queue length: {queue_len}")
        await asyncio.sleep(3)

# 消费任务:从队列中取出任务并处理
async def task_consumer():
    while True:
        if not task_queue.empty():
            task = await task_queue.get()
            # 模拟任务处理耗时
            await asyncio.sleep(1)
            print(f"Processed task: {task}")
            task_queue.task_done()
            # 更新Gauge指标:处理完任务后更新队列长度
            queue_len = task_queue.qsize()
            task_queue_length.set({"queue_name": "user_task_queue"}, queue_len)
        else:
            await asyncio.sleep(1)

# 指标暴露接口
async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 创建aiohttp应用
async def create_app():
    app = web.Application()
    app.add_routes([web.get("/metrics", metrics_handler)])
    # 启动生产者和消费者协程
    asyncio.create_task(task_producer())
    asyncio.create_task(task_consumer())
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8081)

代码说明

  • Gauge指标task_queue_length用于监控任务队列的实时长度,通过labelnames区分不同队列;
  • task_producer协程负责生产任务,每次添加任务后调用set()方法更新指标值为当前队列长度;
  • task_consumer协程负责消费任务,处理完任务后同样更新指标值;
  • 运行程序后,访问http://localhost:8081/metrics可查看队列长度的实时变化,该指标会随着任务的生产和消费动态增减。

3.3 Summary(摘要)

适用场景:用于统计数值的分布情况,例如请求响应时间的平均值、中位数、95分位数等。Summary通过observe()方法记录数值,自动计算并存储指定分位数的统计结果。

代码示例:统计异步接口的响应时间

import asyncio
import time
from aioprometheus import Summary, Registry, render
from aiohttp import web

# 创建注册表和Summary指标
registry = Registry()
request_duration_seconds = Summary(
    "request_duration_seconds",
    "Summary of request processing duration in seconds",
    labelnames=["method", "endpoint"],
    # 指定需要统计的分位数和误差范围
    quantiles={0.5: 0.05, 0.95: 0.01, 0.99: 0.001}
)
registry.register(request_duration_seconds)

# 定义响应时间统计中间件
async def timing_middleware(app, handler):
    async def middleware_handler(request):
        # 记录请求开始时间
        start_time = time.perf_counter()
        try:
            # 执行原始请求处理函数
            response = await handler(request)
            return response
        finally:
            # 计算请求耗时
            duration = time.perf_counter() - start_time
            # 记录耗时到Summary指标
            request_duration_seconds.observe(
                {"method": request.method, "endpoint": request.path},
                duration
            )
    return middleware_handler

# 模拟耗时的业务接口
async def slow_api_handler(request):
    # 模拟接口处理耗时:0.1-0.5秒
    delay = asyncio.sleep(random.uniform(0.1, 0.5))
    await delay
    return web.json_response({"status": "success", "message": "Slow API response"})

# 普通业务接口
async def fast_api_handler(request):
    return web.json_response({"status": "success", "message": "Fast API response"})

# 指标暴露接口
async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 创建aiohttp应用
async def create_app():
    app = web.Application(middlewares=[timing_middleware])
    app.add_routes([
        web.get("/api/slow", slow_api_handler),
        web.get("/api/fast", fast_api_handler),
        web.get("/metrics", metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8082)

代码说明

  • 定义Summary指标时,通过quantiles参数指定需要统计的分位数,例如0.5代表中位数,0.95代表95分位数,后面的数值为允许的误差范围;
  • 实现timing_middleware中间件,在请求处理前后记录时间,计算耗时并通过observe()方法传入Summary指标;
  • 分别定义慢接口和快接口,模拟不同的响应耗时;
  • 运行程序后,多次访问/api/slow/api/fast接口,再访问/metrics即可查看响应时间的统计数据,包括平均值(_sum_count计算得出)、各分位数的数值。

3.4 Histogram(直方图)

适用场景:与Summary类似,用于统计数值的分布情况,但Histogram会将数值划分到不同的区间(bucket)中,统计每个区间的数值数量,适合用于绘制分布直方图。Histogram通过observe()方法记录数值,自动统计各区间的计数。

代码示例:统计异步任务的执行耗时分布

import asyncio
import random
import time
from aioprometheus import Histogram, Registry, render
from aiohttp import web

# 创建注册表和Histogram指标
registry = Registry()
task_execution_duration_seconds = Histogram(
    "task_execution_duration_seconds",
    "Histogram of task execution duration in seconds",
    labelnames=["task_type"],
    # 定义bucket区间:0.1, 0.2, 0.5, 1.0, +inf
    buckets=[0.1, 0.2, 0.5, 1.0]
)
registry.register(task_execution_duration_seconds)

# 模拟不同类型的异步任务
async def execute_task(task_type):
    # 根据任务类型设置不同的耗时范围
    if task_type == "light":
        duration = random.uniform(0.05, 0.15)
    elif task_type == "medium":
        duration = random.uniform(0.15, 0.3)
    else: # heavy
        duration = random.uniform(0.3, 0.8)
    # 模拟任务执行
    await asyncio.sleep(duration)
    # 记录任务耗时到Histogram指标
    task_execution_duration_seconds.observe({"task_type": task_type}, duration)
    return f"{task_type}_task_completed"

# 任务调度接口:接收任务类型参数并执行任务
async def task_scheduler_handler(request):
    task_type = request.query.get("task_type", "light")
    if task_type not in ["light", "medium", "heavy"]:
        return web.json_response({"error": "Invalid task type"}, status=400)
    result = await execute_task(task_type)
    return web.json_response({"status": "success", "result": result})

# 指标暴露接口
async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 创建aiohttp应用
async def create_app():
    app = web.Application()
    app.add_routes([
        web.get("/api/task", task_scheduler_handler),
        web.get("/metrics", metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8083)

代码说明

  • 定义Histogram指标时,通过buckets参数指定区间边界,指标会自动统计落入每个区间的数值数量,最后一个区间默认为+inf
  • execute_task函数模拟不同类型任务的执行耗时,根据任务类型设置不同的耗时范围,并通过observe()方法记录耗时;
  • 访问/api/task?task_type=light等接口触发任务执行,多次调用后访问/metrics,可查看各任务类型的耗时分布情况,包括每个bucket的计数、总和(_sum)和总次数(_count)。

四、aioprometheus高级用法:自定义指标标签与多注册表管理

4.1 动态标签与标签值替换

在实际应用中,指标标签的值往往需要动态获取,例如根据用户ID、请求IP等信息区分指标维度。aioprometheus支持在运行时动态传入标签值,下面以用户登录次数统计为例展示动态标签的用法:

import asyncio
from aioprometheus import Counter, Registry, render
from aiohttp import web

registry = Registry()
user_login_total = Counter(
    "user_login_total",
    "Total number of user logins by user type and platform",
    labelnames=["user_type", "platform"]
)
registry.register(user_login_total)

# 模拟用户登录接口,接收用户类型和平台参数
async def login_handler(request):
    data = await request.json()
    user_type = data.get("user_type", "guest") # 可选值:guest, member, admin
    platform = data.get("platform", "web") # 可选值:web, ios, android
    # 动态传入标签值
    user_login_total.inc({"user_type": user_type, "platform": platform})
    return web.json_response({"status": "success", "message": "Login successful"})

async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

async def create_app():
    app = web.Application()
    app.add_routes([
        web.post("/api/login", login_handler),
        web.get("/metrics", metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8084)

代码说明

  • 指标user_login_total定义了user_typeplatform两个标签,用于区分不同类型用户和不同登录平台的登录次数;
  • login_handler接口中,从请求参数中动态获取user_typeplatform的值,并传入inc()方法;
  • 通过Postman等工具发送POST请求到/api/login,请求体携带{"user_type": "member", "platform": "ios"}等参数,即可按标签维度统计登录次数。

4.2 多注册表管理

在大型应用中,不同模块可能需要独立的指标管理,此时可以使用多注册表机制,将不同模块的指标注册到不同的Registry对象中,再统一暴露或分别暴露指标接口。

import asyncio
from aioprometheus import Counter, Registry, render
from aiohttp import web

# 为用户模块创建注册表
user_registry = Registry()
user_register_total = Counter(
    "user_register_total",
    "Total number of user registrations",
    labelnames=["channel"]
)
user_registry.register(user_register_total)

# 为订单模块创建注册表
order_registry = Registry()
order_create_total = Counter(
    "order_create_total",
    "Total number of order creations",
    labelnames=["order_type"]
)
order_registry.register(order_create_total)

# 用户注册接口
async def user_register_handler(request):
    data = await request.json()
    channel = data.get("channel", "direct")
    user_register_total.inc({"channel": channel})
    return web.json_response({"status": "success", "message": "User registered"})

# 订单创建接口
async def order_create_handler(request):
    data = await request.json()
    order_type = data.get("order_type", "normal")
    order_create_total.inc({"order_type": order_type})
    return web.json_response({"status": "success", "message": "Order created"})

# 暴露用户模块指标
async def user_metrics_handler(request):
    content, http_headers = render(user_registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 暴露订单模块指标
async def order_metrics_handler(request):
    content, http_headers = render(order_registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 暴露所有模块指标(合并注册表)
async def all_metrics_handler(request):
    # 合并多个注册表的指标
    combined_registry = Registry()
    for metric in user_registry.collect():
        combined_registry.register(metric)
    for metric in order_registry.collect():
        combined_registry.register(metric)
    content, http_headers = render(combined_registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

async def create_app():
    app = web.Application()
    app.add_routes([
        web.post("/api/user/register", user_register_handler),
        web.post("/api/order/create", order_create_handler),
        web.get("/metrics/user", user_metrics_handler),
        web.get("/metrics/order", order_metrics_handler),
        web.get("/metrics/all", all_metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8085)

代码说明

  • 分别为用户模块和订单模块创建独立的Registry对象,注册各自的指标;
  • 提供/metrics/user/metrics/order接口,分别暴露对应模块的指标;
  • 实现/metrics/all接口,通过遍历两个注册表的指标并注册到新的Registry对象中,实现指标合并暴露,满足Prometheus一次性拉取所有指标的需求。

五、aioprometheus与Prometheus服务器集成实战

5.1 Prometheus服务器配置

要实现对aioprometheus暴露指标的监控,需要配置Prometheus服务器定期拉取指标数据。首先下载并安装Prometheus(下载地址:https://prometheus.io/download/),然后修改prometheus.yml配置文件:

global:
  scrape_interval: 15s # 每15秒拉取一次指标

scrape_configs:
  - job_name: 'aioprometheus_demo'
    static_configs:
      - targets: ['localhost:8080', 'localhost:8081', 'localhost:8082', 'localhost:8083', 'localhost:8084', 'localhost:8085']
    metrics_path: '/metrics'

配置说明

  • scrape_interval设置为15秒,表示Prometheus每15秒拉取一次指标;
  • job_name为任务名称,可自定义;
  • targets指定需要拉取指标的服务地址列表,即我们之前运行的各个aioprometheus示例服务;
  • metrics_path指定指标接口路径,默认为/metrics

5.2 启动Prometheus并查看指标

启动Prometheus服务器:

./prometheus --config.file=prometheus.yml

访问http://localhost:9090进入Prometheus Web界面,在查询框中输入指标名(如http_requests_totaltask_queue_length等),即可查看指标的实时数据和变化趋势。

5.3 可视化监控面板(Grafana集成)

为了更直观地展示指标数据,可将Prometheus作为数据源集成到Grafana中:

  1. 安装并启动Grafana(下载地址:https://grafana.com/grafana/download);
  2. 访问http://localhost:3000,使用默认账号密码(admin/admin)登录;
  3. 进入Configuration > Data Sources,添加Prometheus数据源,设置URL为http://localhost:9090
  4. 进入Dashboards > Import,导入官方或自定义的Dashboard模板,即可实现指标的可视化监控。

六、相关资源链接

  • Pypi地址:https://pypi.org/project/aioprometheus
  • Github地址:https://github.com/claws/aioprometheus
  • 官方文档地址:https://aioprometheus.readthedocs.io/en/latest/

关注我,每天分享一个实用的Python自动化工具。

Python文件管理神器:filedepot库从入门到精通

一、filedepot库核心概述

filedepot是一款轻量级的Python文件存储抽象层库,核心用途是为开发者提供统一的文件操作接口,支持本地文件系统、AWS S3、Google Cloud Storage等多种存储后端,无需修改业务代码即可切换存储方式。其工作原理是通过定义FileStorage抽象基类,不同存储后端实现该类的方法,实现接口与实现的解耦。

该库的优点是轻量无冗余依赖、接口简洁易用、扩展性强;缺点是高级功能(如断点续传)需自行实现,部分云存储后端的支持需额外安装依赖。License类型为MIT License,允许自由使用、修改和分发,适合商业和开源项目。

二、filedepot库安装与环境配置

2.1 基础安装命令

filedepot库已发布至PyPI,可直接使用pip包管理器安装,命令如下:

pip install filedepot

该命令会自动安装filedepot的核心依赖,满足本地文件存储的基本使用需求。

2.2 云存储后端依赖安装

如果需要使用AWS S3、Google Cloud Storage等云存储后端,需要安装对应的依赖包,具体命令如下:

  • AWS S3后端依赖安装
pip install filedepot[s3]
  • Google Cloud Storage后端依赖安装
pip install filedepot[gcs]

安装完成后,可通过pip show filedepot命令验证安装是否成功,查看库的版本和安装路径。

三、filedepot核心API与基础使用

3.1 核心类与接口介绍

filedepot的核心是FileStorage抽象基类,该类定义了文件存储的通用操作接口,所有存储后端都需要实现以下核心方法:
| 方法名 | 功能描述 |
|–|-|
| save(file_obj, filename=None, **kwargs) | 保存文件对象到存储后端,返回文件唯一标识 |
| open(file_id, mode='rb') | 根据文件ID打开文件,返回文件对象 |
| delete(file_id) | 根据文件ID删除文件 |
| exists(file_id) | 判断指定ID的文件是否存在 |

在实际使用中,我们不需要直接实例化FileStorage类,而是使用其具体的实现类,例如本地存储的LocalFileStorage

3.2 本地文件存储基础操作

3.2.1 初始化本地存储

首先导入LocalFileStorage类,指定存储目录并初始化存储对象。代码示例如下:

from depot.io.local import LocalFileStorage

# 初始化本地文件存储,指定存储目录为./my_files
storage = LocalFileStorage('./my_files')

执行上述代码后,filedepot会自动创建./my_files目录(如果不存在),后续所有文件都会保存在该目录下。

3.2.2 保存文件到本地存储

保存文件的方式有两种:一种是保存已有的文件对象,另一种是保存字节数据。下面分别展示两种方式的代码示例。

方式一:保存文件对象

# 打开本地的test.txt文件,以二进制模式读取
with open('test.txt', 'rb') as f:
    # 保存文件到存储后端,filename指定保存的文件名
    file_id = storage.save(f, filename='test_save.txt')

# 打印文件ID,该ID是文件的唯一标识
print(f"文件保存成功,文件ID: {file_id}")

代码说明:storage.save()方法接收文件对象作为参数,filename参数可选,如果不指定,filedepot会自动生成一个随机的文件名。执行后会在./my_files目录下生成test_save.txt文件,并返回一个唯一的文件ID。

方式二:保存字节数据

from io import BytesIO

# 创建字节流对象,写入测试数据
data = b"Hello, filedepot! This is a test data."
file_obj = BytesIO(data)

# 保存字节流数据到存储后端
file_id = storage.save(file_obj, filename='byte_data.txt')
print(f"字节数据保存成功,文件ID: {file_id}")

代码说明:通过BytesIO将字节数据转换为文件对象,再传递给save()方法,实现无需创建本地临时文件即可保存数据的需求。

3.2.3 读取存储的文件

通过文件ID可以读取已保存的文件内容,代码示例如下:

# 根据文件ID打开文件
with storage.open(file_id) as f:
    content = f.read()
    print(f"文件内容: {content.decode('utf-8')}")

代码说明:storage.open()方法返回一个文件对象,使用read()方法可以读取文件内容,由于读取的是二进制数据,需要使用decode('utf-8')转换为字符串。

3.2.4 判断文件是否存在

使用exists()方法可以判断指定ID的文件是否存在,代码示例如下:

# 检查文件是否存在
if storage.exists(file_id):
    print(f"文件ID {file_id} 对应的文件存在")
else:
    print(f"文件ID {file_id} 对应的文件不存在")

3.2.5 删除存储的文件

使用delete()方法可以删除指定ID的文件,代码示例如下:

# 删除文件
storage.delete(file_id)
print(f"文件ID {file_id} 对应的文件已删除")

# 再次检查文件是否存在
if not storage.exists(file_id):
    print(f"文件删除成功")

3.3 AWS S3云存储操作

3.3.1 初始化S3存储

要使用AWS S3后端,需要先配置AWS的访问密钥和存储桶信息。代码示例如下:

from depot.io.awss3 import S3Storage

# 初始化S3存储
s3_storage = S3Storage(
    access_key_id='YOUR_AWS_ACCESS_KEY',
    secret_access_key='YOUR_AWS_SECRET_KEY',
    bucket='your-bucket-name',
    region_name='us-east-1'
)

代码说明:access_key_idsecret_access_key是AWS账号的访问密钥,bucket是S3存储桶名称,region_name是存储桶所在的区域。需要注意的是,必须确保AWS账号拥有该存储桶的读写权限。

3.3.2 S3存储的文件操作

S3存储的文件操作接口与本地存储完全一致,无需修改业务逻辑即可实现存储后端的切换。代码示例如下:

# 保存文件到S3
with open('test_s3.txt', 'rb') as f:
    s3_file_id = s3_storage.save(f, filename='s3_test_file.txt')

print(f"S3文件保存成功,文件ID: {s3_file_id}")

# 读取S3中的文件
with s3_storage.open(s3_file_id) as f:
    s3_content = f.read()
    print(f"S3文件内容: {s3_content.decode('utf-8')}")

# 删除S3中的文件
s3_storage.delete(s3_file_id)
print(f"S3文件ID {s3_file_id} 对应的文件已删除")

代码说明:无论是本地存储还是S3存储,都使用相同的save()open()delete()方法,体现了filedepot接口统一的优势。

四、filedepot高级功能与实战案例

4.1 文件元数据管理

filedepot在保存文件时,会自动记录文件的元数据信息,例如文件名、文件大小、上传时间等。可以通过get_file_metadata()方法获取元数据,代码示例如下:

# 保存文件并获取元数据
with open('test_metadata.txt', 'rb') as f:
    meta_file_id = storage.save(f, filename='test_metadata.txt')

# 获取文件元数据
metadata = storage.get_file_metadata(meta_file_id)
print("文件元数据信息:")
for key, value in metadata.items():
    print(f"{key}: {value}")

执行上述代码后,会输出类似以下的元数据信息:

文件元数据信息:
filename: test_metadata.txt
content_type: text/plain
content_length: 24
last_modified: 2026-01-08T12:00:00Z

代码说明:元数据中包含了文件的名称、类型、大小和最后修改时间等信息,方便开发者进行文件管理和统计。

4.2 文件存储异常处理

在实际开发中,文件操作可能会出现各种异常,例如权限不足、存储目录不存在、网络故障等。filedepot提供了统一的异常类,方便开发者进行异常捕获和处理。代码示例如下:

from depot.exceptions import FileStorageError, FileNotFoundError

try:
    # 尝试打开不存在的文件ID
    with storage.open('invalid_file_id') as f:
        content = f.read()
except FileNotFoundError as e:
    print(f"错误: 指定的文件不存在 - {e}")
except FileStorageError as e:
    print(f"文件存储错误: {e}")
except Exception as e:
    print(f"其他错误: {e}")

代码说明:FileNotFoundError用于捕获文件不存在的异常,FileStorageError是filedepot的基础异常类,用于捕获所有存储相关的异常。合理的异常处理可以提高程序的健壮性。

4.3 实战案例:多后端文件存储切换工具

本案例将实现一个简单的文件存储工具,支持在本地存储和S3存储之间无缝切换,无需修改核心业务代码。

4.3.1 工具类实现

from depot.io.local import LocalFileStorage
from depot.io.awss3 import S3Storage
from depot.exceptions import FileStorageError

class MultiBackendFileManager:
    def __init__(self, backend_type='local', **kwargs):
        """
        初始化多后端文件管理器
        :param backend_type: 存储后端类型,可选 'local' 或 's3'
        :param kwargs: 存储后端的配置参数
        """
        self.backend_type = backend_type
        self.storage = self._init_storage(** kwargs)

    def _init_storage(self, **kwargs):
        """初始化存储后端"""
        if self.backend_type == 'local':
            # 本地存储需要传入存储目录参数
            return LocalFileStorage(kwargs.get('storage_dir', './default_files'))
        elif self.backend_type == 's3':
            # S3存储需要传入AWS配置参数
            return S3Storage(
                access_key_id=kwargs.get('access_key_id'),
                secret_access_key=kwargs.get('secret_access_key'),
                bucket=kwargs.get('bucket'),
                region_name=kwargs.get('region_name', 'us-east-1')
            )
        else:
            raise ValueError(f"不支持的存储后端类型: {self.backend_type}")

    def save_file(self, file_obj, filename=None):
        """保存文件"""
        try:
            file_id = self.storage.save(file_obj, filename=filename)
            return file_id
        except FileStorageError as e:
            print(f"保存文件失败: {e}")
            return None

    def read_file(self, file_id):
        """读取文件"""
        try:
            with self.storage.open(file_id) as f:
                return f.read()
        except FileNotFoundError:
            print(f"文件ID {file_id} 不存在")
            return None
        except FileStorageError as e:
            print(f"读取文件失败: {e}")
            return None

    def delete_file(self, file_id):
        """删除文件"""
        try:
            self.storage.delete(file_id)
            return True
        except FileNotFoundError:
            print(f"文件ID {file_id} 不存在")
            return False
        except FileStorageError as e:
            print(f"删除文件失败: {e}")
            return False

4.3.2 工具类使用示例

示例1:使用本地存储后端

# 初始化本地存储管理器
local_manager = MultiBackendFileManager(
    backend_type='local',
    storage_dir='./my_local_files'
)

# 保存文件
with open('example.txt', 'w', encoding='utf-8') as f:
    f.write("这是一个多后端文件存储工具的测试文件")

with open('example.txt', 'rb') as f:
    file_id = local_manager.save_file(f, filename='local_example.txt')

if file_id:
    print(f"本地存储文件ID: {file_id}")
    # 读取文件
    content = local_manager.read_file(file_id)
    if content:
        print(f"本地文件内容: {content.decode('utf-8')}")
    # 删除文件
    if local_manager.delete_file(file_id):
        print("本地文件删除成功")

示例2:使用S3存储后端

# 初始化S3存储管理器
s3_manager = MultiBackendFileManager(
    backend_type='s3',
    access_key_id='YOUR_AWS_ACCESS_KEY',
    secret_access_key='YOUR_AWS_SECRET_KEY',
    bucket='your-bucket-name',
    region_name='us-east-1'
)

# 保存字节数据到S3
from io import BytesIO
data = b"Hello, S3 Storage! From MultiBackendFileManager."
file_obj = BytesIO(data)
s3_file_id = s3_manager.save_file(file_obj, filename='s3_example.txt')

if s3_file_id:
    print(f"S3存储文件ID: {s3_file_id}")
    # 读取文件
    s3_content = s3_manager.read_file(s3_file_id)
    if s3_content:
        print(f"S3文件内容: {s3_content.decode('utf-8')}")
    # 删除文件
    if s3_manager.delete_file(s3_file_id):
        print("S3文件删除成功")

代码说明:该工具类通过封装filedepot的不同存储后端,实现了统一的文件操作接口。开发者只需修改backend_type和对应的配置参数,即可在本地存储和S3存储之间切换,无需修改保存、读取、删除等核心业务逻辑。

4.4 自定义存储后端扩展

filedepot的扩展性很强,如果需要支持其他存储后端(如FTP、阿里云OSS等),可以通过继承FileStorage抽象基类并实现其方法来完成。下面以自定义一个简单的内存存储后端为例,展示扩展方法。

4.4.1 自定义内存存储后端实现

from depot.io.interfaces import FileStorage, FileStorageError
from depot.io.utils import FileIntent
from io import BytesIO
import uuid

class MemoryFileStorage(FileStorage):
    def __init__(self):
        # 使用字典存储文件,key为文件ID,value为文件内容和元数据
        self._files = {}

    def save(self, file_obj, filename=None, content_type=None, **kwargs):
        # 生成唯一的文件ID
        file_id = str(uuid.uuid4())
        # 读取文件内容
        file_content = file_obj.read()
        # 处理文件名
        if filename is None:
            filename = f"memory_file_{file_id}"
        # 处理内容类型
        if content_type is None:
            content_type = "application/octet-stream"
        # 存储文件内容和元数据
        self._files[file_id] = {
            'content': file_content,
            'filename': filename,
            'content_type': content_type,
            'content_length': len(file_content)
        }
        return file_id

    def open(self, file_id, mode='rb'):
        if mode != 'rb':
            raise FileStorageError("内存存储仅支持二进制读取模式")
        if file_id not in self._files:
            raise FileNotFoundError(f"文件ID {file_id} 不存在")
        # 返回字节流对象
        return BytesIO(self._files[file_id]['content'])

    def delete(self, file_id):
        if file_id not in self._files:
            raise FileNotFoundError(f"文件ID {file_id} 不存在")
        del self._files[file_id]

    def exists(self, file_id):
        return file_id in self._files

    def get_file_metadata(self, file_id):
        if file_id not in self._files:
            raise FileNotFoundError(f"文件ID {file_id} 不存在")
        return self._files[file_id].copy()

4.4.2 自定义内存存储后端使用示例

# 初始化内存存储
memory_storage = MemoryFileStorage()

# 保存文件
data = b"这是自定义内存存储的测试数据"
file_obj = BytesIO(data)
file_id = memory_storage.save(file_obj, filename='memory_test.txt')

print(f"内存存储文件ID: {file_id}")

# 读取文件
with memory_storage.open(file_id) as f:
    content = f.read()
    print(f"内存文件内容: {content.decode('utf-8')}")

# 获取元数据
metadata = memory_storage.get_file_metadata(file_id)
print(f"文件元数据: {metadata}")

# 删除文件
memory_storage.delete(file_id)
if not memory_storage.exists(file_id):
    print("内存文件删除成功")

代码说明:通过继承FileStorage抽象基类,实现save()open()delete()exists()等核心方法,即可自定义存储后端。该内存存储后端将文件内容保存在内存字典中,适用于临时文件存储的场景。

五、filedepot库常见问题与解决方案

5.1 问题1:保存大文件时内存占用过高

问题描述:当使用save()方法保存大文件时,会一次性读取文件内容到内存,导致内存占用过高。
解决方案:使用文件流分块读取的方式,避免一次性加载整个文件到内存。代码示例如下:

def save_large_file(storage, file_path, chunk_size=1024*1024):
    """
    分块保存大文件
    :param storage: 文件存储对象
    :param file_path: 大文件路径
    :param chunk_size: 分块大小,默认1MB
    :return: 文件ID
    """
    with open(file_path, 'rb') as f:
        # 创建临时文件对象
        temp_file = BytesIO()
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            temp_file.write(chunk)
        # 将文件指针移到开头
        temp_file.seek(0)
        return storage.save(temp_file, filename=file_path.split('/')[-1])

# 使用示例
large_file_id = save_large_file(storage, 'large_file.zip')
print(f"大文件保存成功,文件ID: {large_file_id}")

5.2 问题2:云存储后端连接超时

问题描述:使用S3等云存储后端时,经常出现连接超时的情况。
解决方案:增加连接超时参数,设置重试机制。以S3存储为例,代码示例如下:

from botocore.config import Config

# 配置S3连接参数,设置超时和重试次数
config = Config(
    connect_timeout=30,
    read_timeout=30,
    retries={'max_attempts': 3}
)

# 初始化S3存储,传入配置参数
s3_storage = S3Storage(
    access_key_id='YOUR_AWS_ACCESS_KEY',
    secret_access_key='YOUR_AWS_SECRET_KEY',
    bucket='your-bucket-name',
    region_name='us-east-1',
    config=config
)

代码说明:通过botocore.config.Config设置连接超时、读取超时和重试次数,提高云存储连接的稳定性。

5.3 问题3:文件ID管理困难

问题描述:filedepot自动生成的文件ID是随机字符串,不利于业务系统中的文件管理。
解决方案:自定义文件ID生成规则,例如使用业务标识+时间戳的方式。代码示例如下:

import time

def custom_file_id_generator(prefix='file_'):
    """自定义文件ID生成器"""
    timestamp = int(time.time() * 1000)
    return f"{prefix}{timestamp}"

# 使用自定义文件ID保存文件
with open('test.txt', 'rb') as f:
    # 生成自定义文件ID
    custom_file_id = custom_file_id_generator(prefix='myapp_')
    # 使用_file_id参数指定自定义文件ID
    storage.save(f, filename='test.txt', _file_id=custom_file_id)

print(f"自定义文件ID: {custom_file_id}")

代码说明:通过save()方法的_file_id参数,可以指定自定义的文件ID,方便业务系统根据文件ID关联业务数据。

六、filedepot库相关资源

  • PyPI地址:https://pypi.org/project/filedepot
  • Github地址:https://github.com/xxxxx/xxxxxx
  • 官方文档地址:https://www.xxxxx.com/xxxxxx

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:ODMantic入门到精通——异步MongoDB数据建模与操作指南

一、ODMantic库核心概述

ODMantic是一款专为Python异步生态设计的对象文档映射器(ODM),主要用于简化MongoDB数据库的操作流程。其核心工作原理是将Python类与MongoDB集合进行映射,类的实例对应集合中的文档,开发者无需编写原生MongoDB查询语句,通过操作Python对象即可完成数据的增删改查。

该库的优点十分突出:完全支持异步操作,可无缝对接asyncioFastAPI等异步框架;API设计简洁直观,与Python开发者熟悉的SQLAlchemy等ORM工具风格相近,学习成本低;内置数据校验功能,基于pydantic实现字段类型和约束的校验,保障数据一致性。缺点则是生态相较于老牌ODM工具mongoengine更小众,部分高级查询功能的支持度有待提升;仅适用于MongoDB,通用性较弱。

ODMantic的开源协议为MIT License,这意味着开发者可以自由地用于商业和非商业项目,修改和分发源码也不受过多限制。

二、ODMantic环境安装与配置

2.1 安装前提条件

在安装ODMantic之前,需要确保本地环境满足以下要求:

  1. Python版本≥3.7(推荐3.9及以上版本,兼容性更好);
  2. 已安装并运行MongoDB数据库(本地或远程实例均可,推荐版本≥4.0);
  3. 网络环境正常,能够通过pip下载相关依赖包。

2.2 安装命令

ODMantic的安装非常简单,直接使用pip工具在命令行中执行以下命令即可:

pip install odmantic

该命令会自动安装ODMantic及其核心依赖,包括pydantic(数据校验)、motor(异步MongoDB驱动)等,无需手动单独安装。

2.3 基础配置:连接MongoDB

使用ODMantic的第一步是建立与MongoDB数据库的异步连接。我们需要借助odmantic提供的AIOEngine类,该类是与数据库交互的核心入口,负责管理连接和执行操作。

以下是基础的连接代码示例:

from odmantic import AIOEngine
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

# 定义异步连接函数
async def connect_to_mongodb():
    # 1. 创建MongoDB异步客户端
    # 本地数据库地址:mongodb://localhost:27017/
    # 若为远程数据库,替换为对应的连接字符串,例如包含用户名和密码的地址:
    # mongodb://username:password@remote_host:27017/
    client = AsyncIOMotorClient("mongodb://localhost:27017/")
    # 2. 初始化AIOEngine,指定要操作的数据库名称
    # 这里指定数据库名为"odmantic_demo",若不存在会自动创建
    engine = AIOEngine(client=client, database="odmantic_demo")
    print("成功连接到MongoDB数据库!")
    return engine

# 运行异步函数
if __name__ == "__main__":
    engine = asyncio.run(connect_to_mongodb())

代码说明

  • AsyncIOMotorClientmotor库提供的异步MongoDB客户端,负责与数据库建立TCP连接;
  • AIOEngine是ODMantic的核心引擎,接收客户端实例和数据库名称作为参数,后续所有的数据操作都需要通过该引擎对象完成;
  • asyncio.run()用于运行异步函数,在实际的异步项目(如FastAPI)中,可直接通过await调用connect_to_mongodb()

三、ODMantic核心功能与代码示例

3.1 定义数据模型(Model)

ODMantic的数据模型基于Python类实现,继承自odmantic.Model,类中的字段对应MongoDB文档的键。字段类型通过pydantic的类型注解指定,同时支持设置默认值、必填约束等属性。

3.1.1 基础模型定义

我们以一个“用户(User)”模型为例,演示如何定义基础的数据模型:

from odmantic import Model
from typing import Optional
from datetime import datetime

class User(Model):
    # 字段1:用户名,字符串类型,必填
    username: str
    # 字段2:年龄,整数类型,必填
    age: int
    # 字段3:邮箱,可选字符串类型,默认值为None
    email: Optional[str] = None
    # 字段4:注册时间,datetime类型,默认值为当前时间
    register_time: datetime = datetime.now()

    # 可选配置:指定对应的MongoDB集合名称
    # 若不指定,默认集合名为类名的小写复数形式(此处为"users")
    class Config:
        collection = "user_collection"

代码说明

  • 模型类必须继承odmantic.Model,这是ODMantic识别模型的标志;
  • 字段的类型注解支持Python原生类型(strintdatetime等)和typing模块中的类型(Optional表示可选字段);
  • 通过Config类可以自定义模型的配置,例如collection参数指定模型对应的MongoDB集合名称,若不指定,ODMantic会自动将类名转为小写复数形式作为集合名;
  • 字段可以设置默认值,如register_time默认值为当前时间,email默认值为None

3.1.2 模型字段的常用约束

基于pydantic的特性,ODMantic支持为字段添加各种约束条件,例如字符串长度、数值范围等,确保存入数据库的数据符合预期。以下是带约束的模型示例:

from odmantic import Model, Field
from typing import Optional
from datetime import datetime

class UserWithConstraints(Model):
    # 用户名:字符串类型,长度在3-20之间,必填
    username: str = Field(min_length=3, max_length=20)
    # 年龄:整数类型,范围在0-120之间,必填
    age: int = Field(ge=0, le=120)
    # 邮箱:可选字符串类型,必须符合邮箱格式
    email: Optional[str] = Field(None, regex=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")
    # 注册时间:默认当前时间
    register_time: datetime = datetime.now()

    class Config:
        collection = "constraint_users"

代码说明

  • Field类用于为字段添加额外约束,参数min_lengthmax_length限制字符串长度,ge(大于等于)和le(小于等于)限制数值范围;
  • regex参数用于指定字符串的正则表达式验证规则,这里用于校验邮箱格式的合法性;
  • 当创建模型实例时,如果字段值不符合约束条件,会直接抛出ValidationError异常,避免非法数据存入数据库。

3.2 数据的增删改查(CRUD)操作

CRUD是数据库操作的核心,ODMantic通过简洁的API实现异步的增删改查功能,所有操作都需要通过之前初始化的AIOEngine对象完成。

在进行后续操作前,我们先统一初始化引擎,并定义一个通用的异步运行函数,方便执行异步代码:

from odmantic import AIOEngine, Model, Field
from typing import Optional, List
from datetime import datetime
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

# 初始化引擎
async def get_engine():
    client = AsyncIOMotorClient("mongodb://localhost:27017/")
    return AIOEngine(client=client, database="odmantic_demo")

# 定义User模型
class User(Model):
    username: str = Field(min_length=3, max_length=20)
    age: int = Field(ge=0, le=120)
    email: Optional[str] = None
    register_time: datetime = datetime.now()

    class Config:
        collection = "users"

3.2.1 新增数据(Create)

新增数据即向MongoDB集合中插入文档,使用engine.save()方法,该方法接收一个模型实例作为参数,异步将数据插入数据库,并返回插入后的实例(包含自动生成的id字段)。

代码示例

# 新增单个用户
async def create_single_user(engine: AIOEngine):
    # 创建User模型实例
    user = User(username="zhangsan", age=25, email="[email protected]")
    # 保存到数据库
    saved_user = await engine.save(user)
    print("新增用户成功:")
    print(f"用户ID: {saved_user.id}")
    print(f"用户名: {saved_user.username}")
    print(f"年龄: {saved_user.age}")
    print(f"邮箱: {saved_user.email}")
    return saved_user

# 新增多个用户
async def create_multiple_users(engine: AIOEngine):
    # 创建多个用户实例
    user1 = User(username="lisi", age=30)
    user2 = User(username="wangwu", age=22, email="[email protected]")
    # 批量保存
    saved_users = await engine.save_all([user1, user2])
    print("\n批量新增用户成功,共新增{}个用户:".format(len(saved_users)))
    for u in saved_users:
        print(f"ID: {u.id}, 用户名: {u.username}")

# 执行新增操作
async def run_create_demo():
    engine = await get_engine()
    await create_single_user(engine)
    await create_multiple_users(engine)

if __name__ == "__main__":
    asyncio.run(run_create_demo())

代码说明

  • engine.save()用于插入单个文档,插入后会自动为实例添加id属性(对应MongoDB文档的_id字段,类型为ObjectId);
  • engine.save_all()用于批量插入多个文档,接收一个模型实例列表作为参数,返回插入后的实例列表;
  • 插入操作是异步的,必须使用await关键字调用。

3.2.2 查询数据(Read)

ODMantic提供了多种查询方式,包括查询单个文档、查询多个文档、条件查询、排序、分页等,满足不同的查询需求。

(1)查询单个文档

使用engine.find_one()方法,可根据条件查询单个文档,若未找到则返回None

代码示例

# 根据用户名查询单个用户
async def find_user_by_username(engine: AIOEngine, username: str):
    # 构造查询条件:User.username == username
    user = await engine.find_one(User, User.username == username)
    if user:
        print(f"\n查询到用户:")
        print(f"ID: {user.id}, 用户名: {user.username}, 年龄: {user.age}, 邮箱: {user.email}")
    else:
        print(f"\n未找到用户名为{username}的用户")
    return user

# 执行查询操作
async def run_find_one_demo():
    engine = await get_engine()
    await find_user_by_username(engine, "zhangsan")
    await find_user_by_username(engine, "zhaoliu")

if __name__ == "__main__":
    asyncio.run(run_find_one_demo())

代码说明

  • engine.find_one()的第一个参数是模型类,第二个参数是查询条件,条件的写法为模型类.字段名 == 目标值
  • 若存在多个符合条件的文档,该方法只会返回第一个匹配的文档。
(2)查询多个文档

使用engine.find()方法,可查询符合条件的所有文档,返回一个异步迭代器,可通过async for遍历,或通过list()转换为列表。

代码示例

# 查询所有用户
async def find_all_users(engine: AIOEngine):
    users = await engine.find(User)
    print("\n所有用户列表:")
    for user in users:
        print(f"ID: {user.id}, 用户名: {user.username}, 年龄: {user.age}")

# 条件查询:年龄大于23的用户
async def find_users_by_age(engine: AIOEngine, min_age: int):
    users = await engine.find(User, User.age > min_age)
    print(f"\n年龄大于{min_age}的用户列表:")
    for user in users:
        print(f"ID: {user.id}, 用户名: {user.username}, 年龄: {user.age}")

# 执行多文档查询
async def run_find_demo():
    engine = await get_engine()
    await find_all_users(engine)
    await find_users_by_age(engine, 23)

if __name__ == "__main__":
    asyncio.run(run_find_demo())

代码说明

  • engine.find()的参数与find_one()一致,第一个参数为模型类,第二个参数为可选的查询条件;
  • 支持的查询运算符包括:==(等于)、!=(不等于)、>(大于)、<(小于)、>=(大于等于)、<=(小于等于)、in_(包含在列表中)等,例如User.username.in_(["zhangsan", "lisi"])
(3)排序与分页

在查询时,可通过sort()方法对结果进行排序,通过skip()limit()方法实现分页功能。

代码示例

# 排序查询:按年龄降序排列
async def find_users_sorted(engine: AIOEngine):
    users = await engine.find(User).sort(User.age, -1)
    print("\n按年龄降序排列的用户列表:")
    for user in users:
        print(f"用户名: {user.username}, 年龄: {user.age}")

# 分页查询:每页2条,查询第2页
async def find_users_paginated(engine: AIOEngine, page: int, page_size: int):
    skip_count = (page - 1) * page_size
    users = await engine.find(User).skip(skip_count).limit(page_size)
    print(f"\n第{page}页用户列表(每页{page_size}条):")
    for user in users:
        print(f"用户名: {user.username}, 年龄: {user.age}")

# 执行排序和分页查询
async def run_sort_paginate_demo():
    engine = await get_engine()
    await find_users_sorted(engine)
    await find_users_paginated(engine, page=2, page_size=2)

if __name__ == "__main__":
    asyncio.run(run_sort_paginate_demo())

代码说明

  • sort()方法的第一个参数是排序字段,第二个参数为1(升序)或-1(降序);
  • skip(n)表示跳过前n条数据,limit(m)表示最多返回m条数据,两者结合即可实现分页。

3.2.3 更新数据(Update)

更新数据有两种方式:一种是先查询出模型实例,修改实例的字段值后调用engine.save()方法;另一种是使用engine.update()方法直接执行更新操作。

(1)基于实例的更新

代码示例

# 更新用户信息:修改邮箱和年龄
async def update_user_by_instance(engine: AIOEngine, username: str):
    # 1. 查询用户
    user = await engine.find_one(User, User.username == username)
    if not user:
        print(f"未找到用户{username},更新失败")
        return
    # 2. 修改字段值
    user.age = 26
    user.email = "[email protected]"
    # 3. 保存更新
    updated_user = await engine.save(user)
    print(f"\n用户{username}更新成功:")
    print(f"年龄: {updated_user.age}, 邮箱: {updated_user.email}")

# 执行更新操作
async def run_update_instance_demo():
    engine = await get_engine()
    await update_user_by_instance(engine, "zhangsan")

if __name__ == "__main__":
    asyncio.run(run_update_instance_demo())

代码说明

  • 基于实例的更新步骤为“查询-修改-保存”,适用于需要先获取当前数据再进行修改的场景;
  • engine.save()方法会自动识别实例是否已存在(通过id字段),若存在则执行更新操作,若不存在则执行插入操作。
(2)基于查询条件的批量更新

代码示例

from odmantic import UpdateQuery

# 批量更新:将年龄小于25的用户年龄加1
async def batch_update_users(engine: AIOEngine):
    # 1. 构造更新查询
    update_query = UpdateQuery({User.age: User.age + 1})
    # 2. 执行批量更新
    update_result = await engine.update(
        User,
        User.age < 25,
        update_query
    )
    print(f"\n批量更新成功,共更新{update_result.modified_count}条记录")

# 执行批量更新
async def run_batch_update_demo():
    engine = await get_engine()
    await batch_update_users(engine)

if __name__ == "__main__":
    asyncio.run(run_batch_update_demo())

代码说明

  • 批量更新需要使用UpdateQuery类构造更新内容,支持字段的自增、自减等操作;
  • engine.update()的参数依次为:模型类、查询条件、更新查询对象,返回的结果对象包含modified_count属性,表示实际更新的记录数。

3.2.4 删除数据(Delete)

删除数据同样有两种方式:删除单个实例和批量删除符合条件的文档。

(1)删除单个实例

代码示例

# 删除指定用户
async def delete_user_by_instance(engine: AIOEngine, username: str):
    # 1. 查询用户
    user = await engine.find_one(User, User.username == username)
    if not user:
        print(f"未找到用户{username},删除失败")
        return
    # 2. 删除用户
    await engine.delete(user)
    print(f"\n用户{username}删除成功")

# 执行删除操作
async def run_delete_instance_demo():
    engine = await get_engine()
    await delete_user_by_instance(engine, "lisi")

if __name__ == "__main__":
    asyncio.run(run_delete_instance_demo())

代码说明

  • engine.delete()方法接收一个模型实例作为参数,根据实例的id字段删除对应的文档。
(2)批量删除文档

代码示例

# 批量删除:删除邮箱为None的用户
async def batch_delete_users(engine: AIOEngine):
    delete_result = await engine.delete(User, User.email == None)
    print(f"\n批量删除成功,共删除{delete_result.deleted_count}条记录")

# 执行批量删除
async def run_batch_delete_demo():
    engine = await get_engine()
    await batch_delete_users(engine)

if __name__ == "__main__":
    asyncio.run(run_batch_delete_demo())

代码说明

  • engine.delete()方法若传入模型类和查询条件,则会批量删除符合条件的所有文档;
  • 返回的结果对象包含deleted_count属性,表示实际删除的记录数。

3.3 模型关联(一对一、一对多)

在实际应用中,数据之间往往存在关联关系,例如“用户”和“文章”的一对多关系(一个用户可以发布多篇文章)。ODMantic支持通过Reference字段实现模型之间的关联。

3.3.1 定义关联模型

我们以“用户(User)”和“文章(Article)”为例,演示一对多关联的实现:

from odmantic import Model, Field, Reference
from typing import Optional, List
from datetime import datetime

# 定义User模型
class User(Model):
    username: str = Field(min_length=3, max_length=20)
    age: int = Field(ge=0, le=120)

    class Config:
        collection = "users"

# 定义Article模型,与User模型关联
class Article(Model):
    title: str = Field(min_length=1, max_length=100)
    content: str
    publish_time: datetime = datetime.now()
    # 关联到User模型,表示文章的作者
    author: Reference[User]

    class Config:
        collection = "articles"

代码说明

  • Reference[User]表示author字段是一个指向User模型的引用,存储的是User实例的id
  • 这种定义方式实现了从ArticleUser的单向关联,若需要双向关联,可在User模型中添加articles: List[Article] = []字段,并结合反向查询实现。

3.3.2 创建关联数据

代码示例

# 创建用户并发布文章
async def create_user_and_articles(engine: AIOEngine):
    # 1. 创建用户
    user = User(username="zhangsan", age=25)
    saved_user = await engine.save(user)
    print(f"创建用户成功:{saved_user.username}")

    # 2. 创建两篇文章,关联到该用户
    article1 = Article(title="Python异步编程入门", content="异步编程是Python的重要特性...", author=saved_user)
    article2 = Article(title="ODMantic使用指南", content="ODMantic是一款优秀的异步ODM工具...", author=saved_user)
    saved_articles = await engine.save_all([article1, article2])
    print(f"创建文章成功,共发布{len(saved_articles)}篇文章")

# 执行关联数据创建
async def run_relation_create_demo():
    engine = await get_engine()
    await create_user_and_articles(engine)

if __name__ == "__main__":
    asyncio.run(run_relation_create_demo())

代码说明

  • 创建关联数据时,直接将User实例赋值给Articleauthor字段即可,ODMantic会自动处理引用关系,存储Userid

3.3.3 查询关联数据

查询关联数据有两种方式:从子模型查询父模型(通过author字段查询用户信息),以及从父模型查询子模型(通过用户查询其发布的所有文章)。

代码示例

# 从文章查询作者信息
async def find_article_author(engine: AIOEngine, article_title: str):
    article = await engine.find_one(Article, Article.title == article_title)
    if not article:
        print(f"未找到标题为{article_title}的文章")
        return
    # 直接访问article.author即可获取关联的用户实例
    author = article.author
    print(f"\n文章《{article.title}》的作者信息:")
    print(f"用户名: {author.username}, 年龄: {author.age}")

# 从用户查询发布的所有文章
async def find_user_articles(engine: AIOEngine, username: str):
    user = await engine.find_one(User, User.username == username)
    if not user:
        print(f"未找到用户{username}")
        return
    # 查询该用户发布的所有文章
    articles = await engine.find(Article, Article.author == user)
    print(f"\n用户{username}发布的文章列表:")
    for article in articles:
        print(f"标题: {article.title}, 发布时间: {article.publish_time}")

# 执行关联数据查询
async def run_relation_find_demo():
    engine = await get_engine()
    await find_article_author(engine, "Python异步编程入门")
    await find_user_articles(engine, "zhangsan")

if __name__ == "__main__":
    asyncio.run(run_relation_find_demo())

代码说明

  • 从子模型查询父模型时,直接访问关联字段即可(如article.author),ODMantic会自动根据存储的id查询对应的父模型实例;
  • 从父模型查询子模型时,构造查询条件为Article.author == user,即可获取该用户发布的所有文章。

四、ODMantic与FastAPI框架集成实战

FastAPI是一款高性能的异步Web框架,与ODMantic的异步特性完美契合。本节将演示如何在FastAPI项目中集成ODMantic,实现一个简单的用户管理API。

4.1 项目目录结构

odmantic_fastapi_demo/
├── main.py               # 项目入口文件,包含API路由
└── models.py             # 数据模型定义

4.2 定义数据模型(models.py)

from odmantic import Model, Field
from datetime import datetime
from typing import Optional

class User(Model):
    username: str = Field(min_length=3, max_length=20)
    age: int = Field(ge=0, le=120)
    email: Optional[str] = None
    register_time: datetime = datetime.now()

    class Config:
        collection = "users"

4.3 实现FastAPI API(main.py)

from fastapi import FastAPI, HTTPException
from odmantic import AIOEngine, ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models import User
from typing import List
import asyncio

# 初始化FastAPI应用
app = FastAPI(title="ODMantic + FastAPI 示例", version="1.0")

# 全局引擎对象
engine: AIOEngine = None

# 启动时初始化数据库连接
@app.on_event("startup")
async def startup_db():
    global engine
    client = AsyncIOMotorClient("mongodb://localhost:27017/")
    engine = AIOEngine(client=client, database="odmantic_fastapi_demo")
    print("数据库连接成功!")

# 关闭时断开数据库连接
@app.on_event("shutdown")
async def shutdown_db():
    global engine
    if engine:
        engine.client.close()
        print("数据库连接已关闭!")

# API路由:创建用户
@app.post("/users/", response_model=User, summary="创建新用户")
async def create_user(user: User):
    saved_user = await engine.save(user)
    return saved_user

# API路由:获取单个用户
@app.get("/users/{user_id}", response_model=User, summary="根据ID获取用户")
async def get_user(user_id: str):
    try:
        # 将字符串ID转换为ObjectId
        obj_id = ObjectId(user_id)
    except:
        raise HTTPException(status_code=400, detail="无效的用户ID格式")
    user = await engine.find_one(User, User.id == obj_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    return user

# API路由:获取所有用户
@app.get("/users/", response_model=List[User], summary="获取所有用户")
async def get_all_users():
    users = await engine.find(User)
    return users

# API路由:更新用户
@app.put("/users/{user_id}", response_model=User, summary="更新用户信息")
async def update_user(user_id: str, user_update: User):
    try:
        obj_id = ObjectId(user_id)
    except:
        raise HTTPException(status_code=400, detail="无效的用户ID格式")
    user = await engine.find_one(User, User.id == obj_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    # 更新字段
    user.username = user_update.username
    user.age = user_update.age
    user.email = user_update.email
    updated_user = await engine.save(user)
    return updated_user

# API路由:删除用户
@app.delete("/users/{user_id}", summary="删除用户")
async def delete_user(user_id: str):
    try:
        obj_id = ObjectId(user_id)
    except:
        raise HTTPException(status_code=400, detail="无效的用户ID格式")
    user = await engine.find_one(User, User.id == obj_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    await engine.delete(user)
    return {"message": "用户删除成功"}

# 运行项目
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

代码说明

  • 通过@app.on_event("startup")@app.on_event("shutdown")钩子函数,在FastAPI应用启动时初始化数据库连接,关闭时断开连接;
  • 所有API接口均为异步函数,通过engine对象完成数据操作;
  • 使用response_model指定接口的返回数据模型,FastAPI会自动进行数据校验和文档生成;
  • 运行项目后,可访问http://localhost:8000/docs查看自动生成的API文档,并进行接口测试。

4.4 启动和测试项目

  1. 启动项目:在命令行中执行python main.py,FastAPI应用会在8000端口启动;
  2. 测试接口:打开浏览器访问http://localhost:8000/docs,可以看到自动生成的Swagger文档,点击对应的接口即可进行测试,例如:
  • 点击/users/POST接口,填写用户信息后点击“Execute”,即可创建新用户;
  • 点击/users/GET接口,可获取所有用户的列表;
  • 点击/users/{user_id}GET接口,输入用户ID,可获取指定用户的信息。

五、相关资源链接

  • Pypi地址:https://pypi.org/project/ODMantic
  • Github地址:https://github.com/art049/odmantic
  • 官方文档地址:https://art049.github.io/odmantic/

关注我,每天分享一个实用的Python自动化工具。