Python实用工具:Datasets库快速上手指南_数据处理与加载必备神器

一、Datasets库核心概述

1.1 用途与工作原理

Hugging Face的Datasets库是一款专为自然语言处理(NLP)以及机器学习领域打造的数据集处理工具,它能够帮助开发者快速加载、预处理、转换和共享各类数据集。其核心工作原理是基于一种统一的数据集抽象结构DatasetDatasetDict,将不同来源、不同格式的数据集(如CSV、JSON、文本文件、Hugging Face Hub上的开源数据集)进行标准化封装,同时内置高效的数据处理管道,支持并行处理、流式加载和内存映射,大幅降低了数据预处理的门槛。

1.2 优缺点分析

优点

  1. 数据集资源丰富:无缝对接Hugging Face Hub,拥有上万种开源数据集,涵盖文本分类、问答、翻译等数十种任务类型。
  2. 高效便捷的预处理:内置map()filter()shuffle()等方法,支持并行处理大规模数据集,无需手动编写复杂的循环逻辑。
  3. 内存友好:支持流式加载和内存映射,即使处理GB级别的数据集也不会导致内存溢出。
  4. 跨框架兼容:可以轻松转换为pandasNumPyTensorFlowPyTorch等格式,适配主流机器学习框架。

缺点

  1. 对非文本数据支持有限:虽然可以处理图像、音频类数据集,但功能不如文本数据集完善,预处理选项较少。
  2. 网络依赖较强:加载Hugging Face Hub上的数据集需要稳定的网络环境,离线使用时需要提前下载数据集。
  3. 新手学习成本:部分高级功能(如自定义数据集加载脚本)的使用需要熟悉库的底层逻辑,对纯小白不够友好。

1.3 License类型

Datasets库采用Apache License 2.0开源协议,这意味着开发者可以自由地使用、修改和分发该库的代码,无论是商业项目还是非商业项目,都无需支付任何费用,只需要在修改后的代码中保留原作者的版权声明即可。

二、Datasets库安装与环境配置

2.1 安装命令

Datasets库的安装非常简单,支持pipconda两种方式,推荐使用pip安装,因为它的更新速度更快,适配性更强。

2.1.1 pip安装(推荐)

打开命令行终端,输入以下命令即可完成安装:

pip install datasets

如果需要处理特定格式的数据集(如音频、图像、Parquet文件),可以安装对应的依赖包:

pip install datasets[audio,vision,parquet]

2.1.2 conda安装

如果你的开发环境是基于conda管理的,可以使用以下命令安装:

conda install -c huggingface -c conda-forge datasets

2.2 环境验证

安装完成后,我们可以通过一段简单的Python代码验证是否安装成功。打开Python交互式环境或新建一个.py文件,输入以下代码:

import datasets
print(f"Datasets库版本:{datasets.__version__}")

运行代码后,如果终端输出了具体的版本号(如2.14.5),则说明安装成功;如果出现ModuleNotFoundError,则说明安装失败,需要重新检查安装命令或Python环境。

三、Datasets库核心功能与代码示例

3.1 加载开源数据集(Hugging Face Hub)

Datasets库最核心的功能之一就是一键加载Hugging Face Hub上的开源数据集。以经典的文本分类数据集imdb(电影评论情感分析数据集)为例,我们来演示如何加载并查看数据集的基本信息。

3.1.1 加载imdb数据集

from datasets import load_dataset

# 加载imdb数据集,该数据集包含train、test、unsupervised三个子集
dataset = load_dataset("imdb")

# 打印数据集的结构
print("数据集结构:", dataset)
# 打印训练集的第一条数据
print("训练集第一条数据:", dataset["train"][0])

代码说明

  • load_dataset()函数是加载数据集的核心函数,传入数据集名称即可自动从Hugging Face Hub下载并加载数据集。
  • imdb数据集是一个DatasetDict对象,包含train(训练集)、test(测试集)、unsupervised(无监督数据集)三个子集。
  • 通过下标索引dataset["train"][0]可以查看训练集的第一条数据,数据格式为字典,包含text(电影评论文本)和label(情感标签,0为负面,1为正面)两个字段。

运行结果示例

数据集结构: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})
训练集第一条数据: {'text': 'I rented I AM CURIOUS-YELLOW...', 'label': 0}

3.1.2 加载指定子集与拆分数据

如果我们只需要加载imdb数据集的训练集和测试集,可以通过split参数指定;同时,还可以使用train_test_split()方法将数据集拆分为新的训练集和验证集。

from datasets import load_dataset

# 只加载train和test子集
dataset = load_dataset("imdb", split=["train", "test"])
train_dataset, test_dataset = dataset

# 将训练集拆分为训练集(80%)和验证集(20%)
train_val_dataset = train_dataset.train_test_split(test_size=0.2, seed=42)
train_new = train_val_dataset["train"]
val_new = train_val_dataset["test"]

print(f"新训练集大小:{len(train_new)}")
print(f"验证集大小:{len(val_new)}")
print(f"测试集大小:{len(test_dataset)}")

代码说明

  • split参数传入一个列表,可以指定加载的数据集子集,返回的结果是一个列表,需要手动解包为对应的数据集对象。
  • train_test_split()方法用于拆分数据集,test_size参数指定验证集的比例,seed参数用于固定随机种子,保证实验的可重复性。
  • len()函数可以查看数据集的样本数量,imdb原训练集有25000条数据,拆分后新训练集有20000条,验证集有5000条。

运行结果示例

新训练集大小:20000
验证集大小:5000
测试集大小:25000

3.2 加载本地数据集

除了加载Hugging Face Hub上的开源数据集,Datasets库还支持加载本地的CSV、JSON、文本等格式的数据集。下面我们以CSV格式为例,演示如何加载本地数据集。

3.2.1 准备本地CSV数据集

首先,我们创建一个名为local_data.csv的CSV文件,内容如下:

text,label
This movie is amazing,1
This movie is terrible,0
I love this film,1
The plot is boring,0
The acting is great,1

该文件包含两列,text列是电影评论文本,label列是情感标签。

3.2.2 加载本地CSV数据集

from datasets import load_dataset

# 加载本地CSV数据集
local_dataset = load_dataset("csv", data_files="local_data.csv")

# 打印数据集结构
print("本地数据集结构:", local_dataset)
# 打印所有数据
print("本地数据集所有数据:", local_dataset["train"])
# 访问单条数据的字段
for data in local_dataset["train"]:
    print(f"文本:{data['text']},标签:{data['label']}")

代码说明

  • 加载本地数据集时,load_dataset()函数的第一个参数需要指定数据格式(如csvjsontext),data_files参数指定本地数据文件的路径。
  • 对于单文件的CSV数据集,加载后默认会生成一个名为train的子集。
  • 可以通过遍历数据集对象,访问每条数据的具体字段。

运行结果示例

本地数据集结构: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 5
    })
})
本地数据集所有数据: Dataset({
    features: ['text', 'label'],
    num_rows: 5
})
文本:This movie is amazing,标签:1
文本:This movie is terrible,标签:0
文本:I love this film,标签:1
文本:The plot is boring,标签:0
文本:The acting is great,标签:1

3.2.3 加载多个本地数据文件

如果本地有多个数据文件(如训练集和测试集分别存储在不同的CSV文件中),可以通过data_files参数传入一个字典来指定:

from datasets import load_dataset

# 定义多个数据文件的路径
data_files = {
    "train": "train_data.csv",
    "test": "test_data.csv"
}

# 加载多个CSV文件
multi_dataset = load_dataset("csv", data_files=data_files)

print("多文件数据集结构:", multi_dataset)
print(f"训练集样本数:{len(multi_dataset['train'])}")
print(f"测试集样本数:{len(multi_dataset['test'])}")

代码说明

  • data_files参数传入一个字典时,字典的键是数据集子集的名称(如traintest),值是对应的数据文件路径。
  • 加载后会生成一个DatasetDict对象,包含指定的多个子集,方便后续分别处理训练集和测试集。

3.3 数据集预处理

加载数据集后,通常需要进行预处理(如文本分词、长度截断、特征提取等),Datasets库提供了map()方法,可以高效地对数据集进行批量预处理。下面我们以文本分词为例,演示如何使用map()方法。

3.3.1 安装分词器依赖

我们使用Hugging Face的transformers库中的分词器,需要先安装该库:

pip install transformers

3.3.2 对imdb数据集进行分词预处理

from datasets import load_dataset
from transformers import AutoTokenizer

# 加载imdb数据集的训练集
dataset = load_dataset("imdb", split="train")
# 加载预训练分词器(以bert-base-uncased为例)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# 定义预处理函数
def preprocess_function(examples):
    # 对文本进行分词,设置最大长度为128,超过截断,不足补齐
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# 对数据集进行批量预处理,开启并行处理
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4  # 使用4个进程并行处理
)

# 查看预处理后的数据集特征
print("预处理后的特征:", tokenized_dataset.features)
# 查看第一条预处理后的数据
print("第一条预处理后的数据:", tokenized_dataset[0])

代码说明

  • AutoTokenizer.from_pretrained()函数用于加载预训练分词器,bert-base-uncased是一个常用的英文预训练模型。
  • preprocess_function()是自定义的预处理函数,接收一个包含批量数据的字典examples,返回分词后的结果,包含input_idsattention_mask等字段。
  • map()方法的batched=True参数表示批量处理数据,num_proc参数指定并行处理的进程数,能够大幅提升预处理速度。
  • 预处理后的数据集会新增分词相关的字段,这些字段可以直接输入到预训练模型中进行训练。

运行结果示例

预处理后的特征: {'text': Value(dtype='string', id=None), 'label': ClassLabel(names=['neg', 'pos'], id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None)}
第一条预处理后的数据: {'text': 'I rented I AM CURIOUS-YELLOW...', 'label': 0, 'input_ids': [101, 1045, 5975, ...], 'token_type_ids': [0, 0, 0, ...], 'attention_mask': [1, 1, 1, ...]}

3.4 数据集筛选与排序

Datasets库提供了filter()sort()方法,用于对数据集进行筛选和排序。下面我们演示如何筛选出imdb数据集中文本长度大于100的样本,以及如何按文本长度对数据集进行排序。

3.4.1 数据集筛选

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train")

# 定义筛选函数:保留文本长度大于100的样本
def filter_function(example):
    return len(example["text"]) > 100

# 对数据集进行筛选
filtered_dataset = dataset.filter(filter_function)

print(f"原数据集样本数:{len(dataset)}")
print(f"筛选后数据集样本数:{len(filtered_dataset)}")
# 查看筛选后第一条数据的文本长度
print(f"筛选后第一条数据文本长度:{len(filtered_dataset[0]['text'])}")

代码说明

  • filter_function()是自定义的筛选函数,接收单条数据example,返回True表示保留该样本,返回False表示过滤掉该样本。
  • filter()方法会遍历数据集的所有样本,根据筛选函数的结果保留符合条件的样本。

运行结果示例

原数据集样本数:25000
筛选后数据集样本数:24890
筛选后第一条数据文本长度:178

3.4.2 数据集排序

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train")

# 为数据集添加文本长度字段
def add_length_field(example):
    example["text_length"] = len(example["text"])
    return example

dataset_with_length = dataset.map(add_length_field)

# 按文本长度升序排序
sorted_dataset_asc = dataset_with_length.sort("text_length")
# 按文本长度降序排序
sorted_dataset_desc = dataset_with_length.sort("text_length", reverse=True)

print(f"最短文本长度:{sorted_dataset_asc[0]['text_length']}")
print(f"最长文本长度:{sorted_dataset_desc[0]['text_length']}")

代码说明

  • 首先通过map()方法为数据集添加一个text_length字段,存储每条数据的文本长度。
  • sort()方法接收一个字段名,按照该字段的值对数据集进行排序,reverse=True表示降序排序,默认是升序排序。

运行结果示例

最短文本长度:14
最长文本长度:13704

3.5 数据集格式转换

Datasets库的数据集对象可以轻松转换为pandasNumPyPyTorchTensorFlow等格式,适配不同的机器学习框架。下面我们演示如何将数据集转换为这些格式。

3.5.1 转换为pandas DataFrame格式

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train[:100]")  # 只取前100条数据

# 转换为pandas DataFrame
df = dataset.to_pandas()

# 查看DataFrame的前5行数据
print(df.head())
# 查看DataFrame的基本信息
print(df.info())

代码说明

  • to_pandas()方法可以将数据集转换为pandasDataFrame对象,方便使用pandas进行数据分析和可视化。
  • 为了避免数据量过大,我们通过split="train[:100]"只取训练集的前100条数据。

运行结果示例

                                                text  label
0  I rented I AM CURIOUS-YELLOW...                    0
1  "I Am Curious: Yellow" is a...                    0
2  If only to avoid making this...                    0
3  This film was probably intende...                  0
4  I saw this movie when I was ...                   0
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 2 columns):
 #   Column  Non-Null Count  Dtype 
    --  -- 
 0   text    100 non-null    object
 1   label   100 non-null    int64 
dtypes: int64(1), object(1)
memory usage: 1.7+ KB
None

3.5.2 转换为NumPy数组格式

from datasets import load_dataset

# 加载imdb训练集
dataset = load_dataset("imdb", split="train[:100]")

# 转换为NumPy数组
numpy_array = dataset.to_numpy()

# 查看NumPy数组的形状和第一条数据
print(f"NumPy数组形状:{numpy_array.shape}")
print(f"第一条数据:{numpy_array[0]}")

代码说明

  • to_numpy()方法将数据集转换为NumPy数组,数组中的每个元素是一个字典,包含数据集的所有字段。

运行结果示例

NumPy数组形状:(100,)
第一条数据: {'text': 'I rented I AM CURIOUS-YELLOW...', 'label': 0}

3.5.3 转换为PyTorch张量格式

from datasets import load_dataset
from transformers import AutoTokenizer

# 加载imdb训练集并进行分词预处理
dataset = load_dataset("imdb", split="train[:100]")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# 转换为PyTorch张量格式
torch_dataset = tokenized_dataset.with_format("torch")

# 查看张量格式的输入ids和标签
print(f"Input IDs张量形状:{torch_dataset['input_ids'].shape}")
print(f"标签张量形状:{torch_dataset['label'].shape}")

代码说明

  • with_format("torch")方法将数据集转换为PyTorch张量格式,转换后可以直接通过下标访问张量数据。
  • 转换后的数据集可以直接输入到PyTorch模型中进行训练,无需手动转换数据类型。

运行结果示例

Input IDs张量形状: torch.Size([100, 128])
标签张量形状: torch.Size([100])

四、实际案例:基于Datasets库的文本分类任务

为了让大家更好地理解Datasets库的实际应用,我们结合transformers库搭建一个简单的文本分类模型,完成imdb电影评论情感分析任务。

4.1 案例流程

  1. 加载imdb数据集并进行预处理。
  2. 加载预训练模型和分词器。
  3. 构建训练参数并训练模型。
  4. 在测试集上评估模型性能。

4.2 完整代码实现

# 导入必要的库
from datasets import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer
)
import numpy as np

# 步骤1:加载数据集并进行预处理
# 加载imdb的训练集和测试集
dataset = load_dataset("imdb")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

# 加载预训练分词器
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# 定义预处理函数
def preprocess_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128
    )

# 对训练集和测试集进行预处理
tokenized_train = train_dataset.map(preprocess_function, batched=True, num_proc=4)
tokenized_test = test_dataset.map(preprocess_function, batched=True, num_proc=4)

# 设置数据集格式为PyTorch张量
tokenized_train.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"]
)
tokenized_test.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"]
)

# 步骤2:加载预训练模型
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2  # 二分类任务,标签数为2
)

# 步骤3:定义评估指标
metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

# 步骤4:设置训练参数
training_args = TrainingArguments(
    output_dir="./imdb_sentiment_model",  # 模型输出目录
    num_train_epochs=2,  # 训练轮数
    per_device_train_batch_size=16,  # 训练批次大小
    per_device_eval_batch_size=16,  # 评估批次大小
    evaluation_strategy="epoch",  # 每轮训练后进行评估
    save_strategy="epoch",  # 每轮训练后保存模型
    logging_dir="./logs",  # 日志目录
    logging_steps=100,  # 每100步记录一次日志
    learning_rate=2e-5,  # 学习率
    weight_decay=0.01,  # 权重衰减
)

# 步骤5:构建Trainer并开始训练
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    compute_metrics=compute_metrics,
)

# 开始训练
trainer.train()

# 步骤6:在测试集上评估模型
eval_results = trainer.evaluate()
print(f"测试集准确率:{eval_results['eval_accuracy']:.4f}")

4.3 代码说明

  1. 数据集预处理:我们使用bert-base-uncased分词器对文本进行分词,设置最大长度为128,并将数据集转换为PyTorch张量格式。
  2. 模型加载AutoModelForSequenceClassification是一个用于序列分类任务的预训练模型,num_labels=2表示二分类任务。
  3. 评估指标:使用accuracy指标评估模型性能,compute_metrics函数用于计算预测结果的准确率。
  4. 训练参数设置TrainingArguments包含了训练过程中的所有参数,如训练轮数、批次大小、学习率等。
  5. 模型训练与评估Trainertransformers库提供的训练工具,调用train()方法开始训练,evaluate()方法在测试集上评估模型性能。

4.4 预期结果

经过2轮训练后,模型在imdb测试集上的准确率可以达到90%以上,具体数值会因硬件环境和随机种子的不同而略有差异。

五、相关资源链接

  • Pypi地址:https://pypi.org/project/datasets
  • Github地址:https://github.com/huggingface/datasets
  • 官方文档地址:https://huggingface.co/docs/datasets/index

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:psycopg3 完全指南——PostgreSQL数据库交互最佳实践

一、psycopg3 库概述

psycopg3 是 Python 编程语言中用于连接和操作 PostgreSQL 数据库的高性能适配器,也是 psycopg2 的官方升级版本。它遵循 Python DB API 2.0 规范,能够实现 Python 程序与 PostgreSQL 数据库的高效数据交互,支持异步操作、类型适配、事务管理等核心功能。

工作原理上,psycopg3 通过底层的 libpq 库(PostgreSQL 官方客户端库)建立 Python 与数据库的通信通道,将 Python 数据类型转换为 PostgreSQL 支持的类型,同时把数据库返回的结果转换为 Python 原生对象,实现双向数据无缝流转。

该库的优点十分突出:支持异步 I/O 操作,适配 asyncio 框架;性能较 psycopg2 大幅提升;对 Python 3.8+ 版本兼容性好;支持 PostgreSQL 10+ 所有新特性;提供灵活的参数化查询,有效防止 SQL 注入攻击。缺点则是目前生态插件较少,部分旧项目迁移需要修改代码;对低版本 Python(3.7 及以下)不支持。

psycopg3 的开源协议为 GNU Lesser General Public License (LGPL) 3.0,允许用户自由使用、修改和分发,可用于商业项目开发。

二、psycopg3 安装与环境配置

2.1 前置条件

在安装 psycopg3 之前,需要确保系统满足以下两个核心条件:

  1. Python 环境:版本 3.8 或更高,推荐使用 3.10+ 稳定版本。
  2. PostgreSQL 环境:本地或远程服务器已安装 PostgreSQL 10 或更高版本,且确保数据库服务处于运行状态。
  3. 依赖库:系统需要安装 libpq 开发包,不同操作系统安装命令如下:
    • Ubuntu/Debian 系统
      bash sudo apt-get install libpq-dev python3-dev
    • CentOS/RHEL 系统
      bash sudo yum install postgresql-devel python3-devel
    • Windows 系统
      无需手动安装,psycopg3 的 Windows 版本已内置相关依赖。

2.2 安装 psycopg3

psycopg3 已发布到 PyPI 仓库,可通过 pip 包管理器一键安装,这是最简单且推荐的方式。

2.2.1 基础安装命令

打开终端或命令提示符,执行以下命令:

pip install psycopg3

2.2.2 指定版本安装

如果需要安装特定版本的 psycopg3(例如 3.1.12),可以执行:

pip install psycopg3==3.1.12

2.2.3 验证安装是否成功

安装完成后,可通过 Python 交互式环境验证是否安装成功。打开 Python 终端,输入以下代码:

import psycopg
print(psycopg.__version__)

如果终端输出 psycopg3 的版本号(如 3.1.12),则说明安装成功;若出现 ModuleNotFoundError 错误,则需要检查 pip 安装路径或 Python 环境配置。

三、psycopg3 核心用法与代码示例

psycopg3 的核心操作围绕 数据库连接游标对象数据查询数据增删改事务管理 展开,下面结合具体代码示例详细讲解每个功能的使用方法。

3.1 数据库连接与关闭

要操作 PostgreSQL 数据库,第一步是建立 Python 程序与数据库的连接。psycopg3 提供 psycopg.connect() 函数创建连接对象,连接参数需要包含数据库地址、端口、用户名、密码、数据库名等信息。

3.1.1 基础连接示例

import psycopg

# 定义数据库连接参数
conn_params = {
    "host": "localhost",  # 数据库服务器地址,本地为localhost
    "port": 5432,         # PostgreSQL 默认端口为5432
    "user": "postgres",   # 数据库用户名,默认管理员用户为postgres
    "password": "123456", # 数据库密码,安装时设置
    "dbname": "testdb"    # 要连接的数据库名,需提前创建
}

# 建立数据库连接
try:
    conn = psycopg.connect(**conn_params)
    print("数据库连接成功!")
except psycopg.OperationalError as e:
    print(f"数据库连接失败:{e}")
finally:
    # 关闭数据库连接
    if conn:
        conn.close()
        print("数据库连接已关闭!")

代码说明

  • psycopg.connect() 函数接收关键字参数,返回一个 Connection 对象,代表与数据库的会话。
  • 使用 try-except 捕获连接异常(如密码错误、数据库不存在、服务未启动等),避免程序崩溃。
  • 最后通过 conn.close() 关闭连接,释放数据库资源,这是必须执行的步骤。

3.1.2 使用上下文管理器(推荐)

手动关闭连接容易遗漏,psycopg3 支持使用 with 语句(上下文管理器)自动管理连接的创建和关闭,这是更优雅、更推荐的写法。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

# 使用with语句自动管理连接
with psycopg.connect(**conn_params) as conn:
    print("数据库连接成功!")
    # 后续数据库操作写在这里
print("数据库连接已自动关闭!")

代码说明

  • with 代码块执行完毕后,无论是否发生异常,连接都会自动关闭,无需手动调用 conn.close()
  • 这种写法可以有效避免因忘记关闭连接导致的资源泄露问题。

3.2 游标对象与基本数据操作

建立数据库连接后,需要通过 游标对象(Cursor) 执行 SQL 语句。游标是数据库操作的核心接口,负责提交 SQL 命令并获取执行结果。

3.2.1 创建游标对象

with 语句块内,通过 conn.cursor() 方法创建游标对象:

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    # 创建游标对象
    with conn.cursor() as cur:
        print("游标对象创建成功!")

代码说明

  • 游标对象也支持 with 语句,执行完毕后自动关闭,释放游标资源。
  • 所有 SQL 语句的执行都需要通过游标对象的方法实现。

3.2.2 创建数据表

下面通过游标执行 CREATE TABLE 语句,创建一个名为 students 的数据表,用于存储学生信息。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义创建数据表的SQL语句
        create_table_sql = """
        CREATE TABLE IF NOT EXISTS students (
            id SERIAL PRIMARY KEY,
            name VARCHAR(50) NOT NULL,
            age INTEGER NOT NULL,
            gender VARCHAR(10),
            score NUMERIC(5, 2)
        );
        """
        # 执行SQL语句
        cur.execute(create_table_sql)
        # 提交事务(psycopg3默认开启事务,执行修改操作后需要提交)
        conn.commit()
        print("数据表students创建成功!")

代码说明

  • CREATE TABLE IF NOT EXISTS 表示如果数据表不存在则创建,避免重复创建导致的错误。
  • SERIAL 类型是 PostgreSQL 的自增整数类型,作为主键使用。
  • cur.execute(sql) 方法用于执行 SQL 语句,参数为字符串格式的 SQL 命令。
  • conn.commit() 用于提交事务,psycopg3 默认处于事务模式,所有对数据库的修改操作(创建表、插入、更新、删除)都需要提交事务才能生效。

3.2.3 插入数据

向数据表中插入数据有两种方式:单条数据插入和批量数据插入,psycopg3 推荐使用 参数化查询 方式,避免 SQL 注入攻击。

(1)单条数据插入
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义插入数据的SQL语句,使用%s作为参数占位符
        insert_sql = """
        INSERT INTO students (name, age, gender, score)
        VALUES (%s, %s, %s, %s);
        """
        # 定义要插入的数据
        student_data = ("张三", 18, "男", 95.5)
        # 执行插入操作
        cur.execute(insert_sql, student_data)
        # 提交事务
        conn.commit()
        print(f"成功插入 {cur.rowcount} 条数据!")

代码说明

  • SQL 语句中使用 %s 作为参数占位符,无论参数类型是什么,都统一使用 %s,这是 psycopg3 的固定语法。
  • cur.execute() 方法的第二个参数是一个元组,包含要插入的具体数据,元组长度必须与占位符数量一致。
  • cur.rowcount 属性返回受上一条 SQL 语句影响的行数,用于验证数据是否插入成功。
(2)批量数据插入

如果需要插入多条数据,使用 cur.executemany() 方法,效率远高于多次执行 cur.execute()

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        insert_sql = """
        INSERT INTO students (name, age, gender, score)
        VALUES (%s, %s, %s, %s);
        """
        # 定义多条学生数据,列表中每个元素是一个元组
        students_data = [
            ("李四", 19, "女", 92.0),
            ("王五", 18, "男", 88.5),
            ("赵六", 20, "女", 96.0)
        ]
        # 执行批量插入
        cur.executemany(insert_sql, students_data)
        conn.commit()
        print(f"成功插入 {cur.rowcount} 条数据!")

代码说明

  • cur.executemany() 方法接收两个参数:SQL 语句和包含多个元组的列表。
  • 批量插入可以减少与数据库的交互次数,大幅提升数据插入效率,适合大量数据插入场景。

3.2.4 查询数据

查询数据是数据库操作中最常用的功能,psycopg3 提供了多种方式获取查询结果,包括 cur.fetchone()cur.fetchmany()cur.fetchall()

(1)获取所有数据(fetchall)
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义查询SQL语句
        select_sql = "SELECT * FROM students;"
        # 执行查询
        cur.execute(select_sql)
        # 获取所有查询结果
        all_students = cur.fetchall()
        # 遍历结果并打印
        print("所有学生信息:")
        for student in all_students:
            print(f"ID: {student[0]}, 姓名: {student[1]}, 年龄: {student[2]}, 性别: {student[3]}, 分数: {student[4]}")

代码说明

  • cur.fetchall() 方法返回一个列表,列表中的每个元素是一个元组,对应数据表中的一行数据。
  • 元组中的元素顺序与 SQL 查询语句中的字段顺序一致,例如 SELECT * 表示按数据表字段顺序返回所有字段。
(2)获取单条数据(fetchone)
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        select_sql = "SELECT name, score FROM students WHERE age = %s;"
        # 查询年龄为18的学生
        cur.execute(select_sql, (18,))
        # 获取第一条匹配的数据
        student = cur.fetchone()
        if student:
            print(f"18岁学生信息:姓名={student[0]}, 分数={student[1]}")
        else:
            print("未找到18岁的学生!")

代码说明

  • cur.fetchone() 方法返回查询结果的第一条数据,类型为元组;如果没有匹配数据,返回 None
  • 该方法适合只需要获取单条数据的场景,例如根据主键查询数据。
(3)获取指定数量数据(fetchmany)
import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        select_sql = "SELECT name, age FROM students ORDER BY score DESC;"
        cur.execute(select_sql)
        # 获取前2条数据
        top_students = cur.fetchmany(2)
        print("分数最高的2名学生:")
        for student in top_students:
            print(f"姓名: {student[0]}, 年龄: {student[1]}")

代码说明

  • cur.fetchmany(size) 方法接收一个整数参数 size,表示要获取的数据条数,返回一个包含指定数量元组的列表。
  • 该方法适合分页查询场景,避免一次性获取大量数据导致内存占用过高。

3.2.5 更新数据

更新数据使用 UPDATE 语句,同样需要使用参数化查询,并提交事务才能生效。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义更新SQL语句,将张三的分数更新为98.0
        update_sql = "UPDATE students SET score = %s WHERE name = %s;"
        cur.execute(update_sql, (98.0, "张三"))
        conn.commit()
        print(f"成功更新 {cur.rowcount} 条数据!")

代码说明

  • UPDATE 语句中的 WHERE 子句用于指定更新条件,避免误更新整个数据表的所有数据。
  • cur.rowcount 返回被更新的行数,可用于验证更新操作是否成功。

3.2.6 删除数据

删除数据使用 DELETE 语句,同样需要指定条件,防止误删所有数据。

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        # 定义删除SQL语句,删除分数低于90的学生
        delete_sql = "DELETE FROM students WHERE score < %s;"
        cur.execute(delete_sql, (90.0,))
        conn.commit()
        print(f"成功删除 {cur.rowcount} 条数据!")

代码说明

  • DELETE 语句中的 WHERE 子句是关键,若省略 WHERE 子句,将删除数据表中的所有数据。
  • 删除操作不可逆,执行前需谨慎确认条件是否正确。

3.3 事务管理

事务是数据库操作的基本单位,具有 原子性(Atomicity)、一致性(Consistency)、隔离性(Isolation)、持久性(Durability) 四个特性,简称 ACID。psycopg3 默认开启事务,所有修改操作都需要通过 conn.commit() 提交,若发生错误,可通过 conn.rollback() 回滚事务,撤销所有未提交的操作。

3.3.1 事务提交与回滚示例

import psycopg

conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

with psycopg.connect(**conn_params) as conn:
    with conn.cursor() as cur:
        try:
            # 插入一条数据
            insert_sql = "INSERT INTO students (name, age, gender, score) VALUES (%s, %s, %s, %s);"
            cur.execute(insert_sql, ("孙七", 19, "男", 85.0))

            # 故意触发错误(例如插入重复主键,这里用错误的SQL语句模拟)
            error_sql = "INSERT INTO students (id, name) VALUES (1, '孙七');"
            cur.execute(error_sql)

            # 提交事务(如果前面出现错误,这行代码不会执行)
            conn.commit()
            print("事务提交成功!")
        except psycopg.Error as e:
            # 发生错误,回滚事务
            conn.rollback()
            print(f"事务执行失败,已回滚:{e}")

代码说明

  • 当事务中的任何一个操作发生错误时,except 块会捕获异常,并执行 conn.rollback(),撤销所有已执行但未提交的操作。
  • 事务回滚可以保证数据库数据的一致性,避免因部分操作成功、部分操作失败导致的数据错乱。

3.4 异步操作

psycopg3 支持异步操作,通过 psycopg.asyncpg 模块实现,适配 Python 的 asyncio 框架,适合高并发的异步应用场景(如异步 Web 框架 FastAPI)。

3.4.1 异步连接与数据查询示例

import asyncio
import psycopg
from psycopg import AsyncConnection

# 定义异步函数
async def async_query():
    conn_params = {
        "host": "localhost",
        "port": 5432,
        "user": "postgres",
        "password": "123456",
        "dbname": "testdb"
    }
    # 建立异步数据库连接
    async with await psycopg.AsyncConnection.connect(**conn_params) as conn:
        # 创建异步游标
        async with conn.cursor() as cur:
            # 执行异步查询
            await cur.execute("SELECT name, score FROM students;")
            # 获取查询结果
            results = await cur.fetchall()
            print("异步查询结果:")
            for name, score in results:
                print(f"姓名: {name}, 分数: {score}")

# 运行异步函数
if __name__ == "__main__":
    asyncio.run(async_query())

代码说明

  • 异步操作需要使用 psycopg.AsyncConnection 类,通过 await 关键字执行异步方法。
  • async with 语句用于管理异步连接和异步游标,自动处理资源的创建和释放。
  • 异步操作可以显著提升高并发场景下的程序性能,避免因同步等待数据库响应导致的阻塞。

四、实际应用案例:学生成绩管理系统

下面结合一个实际案例,展示如何使用 psycopg3 开发一个简单的学生成绩管理系统,实现学生信息的添加、查询、更新和删除功能。

4.1 案例需求

  1. 能够添加新学生的成绩信息。
  2. 能够根据学生姓名查询成绩。
  3. 能够根据学生 ID 更新成绩。
  4. 能够根据学生姓名删除学生信息。
  5. 所有操作需要包含异常处理,确保程序稳定性。

4.2 完整代码实现

import psycopg
from typing import Optional, List, Tuple

class StudentScoreManager:
    def __init__(self, conn_params: dict):
        """初始化管理器,接收数据库连接参数"""
        self.conn_params = conn_params

    def add_student(self, name: str, age: int, gender: str, score: float) -> bool:
        """添加新学生信息
        Args:
            name: 学生姓名
            age: 学生年龄
            gender: 学生性别
            score: 学生分数
        Returns:
            添加成功返回True,失败返回False
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    insert_sql = """
                    INSERT INTO students (name, age, gender, score)
                    VALUES (%s, %s, %s, %s);
                    """
                    cur.execute(insert_sql, (name, age, gender, score))
                    conn.commit()
                    print(f"添加学生 {name} 成功!")
                    return True
        except psycopg.Error as e:
            print(f"添加学生失败:{e}")
            return False

    def query_student(self, name: str) -> Optional[Tuple]:
        """根据姓名查询学生信息
        Args:
            name: 学生姓名
        Returns:
            找到学生返回元组,未找到返回None
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    select_sql = """
                    SELECT id, name, age, gender, score FROM students WHERE name = %s;
                    """
                    cur.execute(select_sql, (name,))
                    student = cur.fetchone()
                    if student:
                        print(f"查询到学生信息:ID={student[0]}, 姓名={student[1]}, 年龄={student[2]}, 性别={student[3]}, 分数={student[4]}")
                        return student
                    else:
                        print(f"未找到姓名为 {name} 的学生!")
                        return None
        except psycopg.Error as e:
            print(f"查询学生失败:{e}")
            return None

    def update_score(self, student_id: int, new_score: float) -> bool:
        """根据学生ID更新分数
        Args:
            student_id: 学生ID
            new_score: 新分数
        Returns:
            更新成功返回True,失败返回False
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    update_sql = "UPDATE students SET score = %s WHERE id = %s;"
                    cur.execute(update_sql, (new_score, student_id))
                    conn.commit()
                    if cur.rowcount > 0:
                        print(f"更新ID为 {student_id} 的学生分数成功!")
                        return True
                    else:
                        print(f"未找到ID为 {student_id} 的学生!")
                        return False
        except psycopg.Error as e:
            print(f"更新分数失败:{e}")
            return False

    def delete_student(self, name: str) -> bool:
        """根据姓名删除学生信息
        Args:
            name: 学生姓名
        Returns:
            删除成功返回True,失败返回False
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    delete_sql = "DELETE FROM students WHERE name = %s;"
                    cur.execute(delete_sql, (name,))
                    conn.commit()
                    if cur.rowcount > 0:
                        print(f"删除学生 {name} 成功!")
                        return True
                    else:
                        print(f"未找到姓名为 {name} 的学生!")
                        return False
        except psycopg.Error as e:
            print(f"删除学生失败:{e}")
            return False

    def list_all_students(self) -> List[Tuple]:
        """查询所有学生信息
        Returns:
            包含所有学生信息的列表
        """
        try:
            with psycopg.connect(**self.conn_params) as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT * FROM students;")
                    all_students = cur.fetchall()
                    print(f"共查询到 {len(all_students)} 名学生信息:")
                    for student in all_students:
                        print(f"ID: {student[0]}, 姓名: {student[1]}, 年龄: {student[2]}, 性别: {student[3]}, 分数: {student[4]}")
                    return all_students
        except psycopg.Error as e:
            print(f"查询所有学生失败:{e}")
            return []

# 主程序入口
if __name__ == "__main__":
    # 配置数据库连接参数
    conn_params = {
        "host": "localhost",
        "port": 5432,
        "user": "postgres",
        "password": "123456",
        "dbname": "testdb"
    }

    # 创建学生成绩管理器实例
    manager = StudentScoreManager(conn_params)

    # 执行各项操作
    manager.add_student("周八", 19, "男", 91.5)
    manager.query_student("周八")
    manager.update_score(5, 94.0)  # 假设周八的ID为5
    manager.list_all_students()
    manager.delete_student("周八")
    manager.list_all_students()

代码说明

  • 该案例通过面向对象的方式封装了学生成绩管理的核心功能,每个方法对应一个具体的数据库操作。
  • 每个方法都包含了异常处理,确保程序在遇到数据库错误时不会崩溃,同时输出错误信息便于调试。
  • 主程序入口创建了管理器实例,并依次执行添加、查询、更新、删除和列表查询操作,展示了完整的业务流程。

五、psycopg3 高级特性

5.1 类型适配

psycopg3 能够自动将 Python 数据类型转换为 PostgreSQL 支持的类型,同时也能将 PostgreSQL 类型转换为 Python 类型。例如:

  • Python 的 int → PostgreSQL 的 INTEGER
  • Python 的 float → PostgreSQL 的 NUMERIC
  • Python 的 str → PostgreSQL 的 VARCHAR
  • Python 的 datetime.date → PostgreSQL 的 DATE

如果需要自定义类型适配,可以使用 psycopg.types 模块进行扩展。

5.2 连接池

在高并发应用中,频繁创建和关闭数据库连接会消耗大量资源,psycopg3 推荐使用连接池技术管理数据库连接。可以使用第三方库 psycopg_pool(psycopg 官方提供的连接池库)实现连接池功能。

5.2.1 安装 psycopg_pool

pip install psycopg_pool

5.2.2 连接池使用示例

import psycopg
from psycopg_pool import ConnectionPool

# 创建连接池
conn_params = {
    "host": "localhost",
    "port": 5432,
    "user": "postgres",
    "password": "123456",
    "dbname": "testdb"
}

pool = ConnectionPool(conninfo=conn_params, min_size=2, max_size=10)

# 从连接池获取连接
with pool.connection() as conn:
    with conn.cursor() as cur:
        cur.execute("SELECT name FROM students LIMIT 2;")
        print(cur.fetchall())

# 关闭连接池
pool.close()

代码说明

  • ConnectionPool 类的 min_size 参数表示连接池的最小连接数,max_size 表示最大连接数。
  • 使用 pool.connection() 从连接池获取连接,使用完毕后自动归还到连接池,无需手动关闭。

六、相关资源链接

  • Pypi地址:https://pypi.org/project/psycopg3
  • Github地址:https://github.com/psycopg/psycopg
  • 官方文档地址:https://www.psycopg.org/psycopg3/docs/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:aioprometheus异步监控指标采集与上报教程

一、aioprometheus库核心概述

1.1 用途与工作原理

aioprometheus是一款基于Python asyncio框架开发的异步指标采集与上报库,专门用于为异步Python应用程序提供Prometheus监控指标的定义、收集和暴露功能。其工作原理遵循Prometheus的监控规范,通过定义Counter(计数器)、Gauge(仪表盘)、Summary(摘要)、Histogram(直方图)四种核心指标类型,在异步代码中埋点采集数据,再通过HTTP服务暴露指标接口,供Prometheus服务器定期拉取数据,最终实现对异步应用的性能监控与状态分析。

1.2 优缺点分析

优点

  1. 完全异步化设计,与aiohttp等异步框架完美兼容,不会阻塞事件循环,适合高并发异步应用场景;
  2. 原生支持Prometheus的四种核心指标类型,满足绝大多数监控需求;
  3. 提供灵活的指标标签(label)机制,支持多维度指标分析;
  4. 轻量级架构,安装简单,运行时资源消耗低。

缺点

  1. 文档相对精简,对于新手而言部分高级用法需要阅读源码理解;
  2. 生态相较于同步的prometheus_client库稍窄,第三方集成插件较少;
  3. 仅支持异步Python环境,无法直接用于同步应用程序。

1.3 License类型

aioprometheus采用MIT License开源协议,这意味着开发者可以自由地将其用于个人、商业项目,允许修改、分发源码,只需保留原始版权声明即可。

二、aioprometheus安装与环境准备

2.1 安装命令

aioprometheus库已发布至PyPI,可通过pip包管理工具直接安装,建议使用Python 3.7及以上版本(兼容asyncio的全部特性),安装命令如下:

pip install aioprometheus
# 如需使用aiohttp集成功能,可安装扩展依赖
pip install aioprometheus[aiohttp]

2.2 环境验证

安装完成后,可通过以下Python代码验证环境是否配置成功,该代码会创建一个简单的Counter指标并打印,确认库能够正常导入和使用:

import asyncio
from aioprometheus import Counter, Registry

async def verify_environment():
    # 创建指标注册表
    registry = Registry()
    # 定义Counter指标,用于统计请求次数
    http_requests_total = Counter(
        "http_requests_total",
        "Total number of HTTP requests",
        {"method": "GET", "endpoint": "/api"}
    )
    # 将指标注册到注册表
    registry.register(http_requests_total)
    # 增加指标计数
    http_requests_total.inc()
    # 打印指标信息
    print("指标信息:", http_requests_total.samples)

if __name__ == "__main__":
    asyncio.run(verify_environment())

运行上述代码,若控制台输出类似指标信息: [Sample(name='http_requests_total', labels={'method': 'GET', 'endpoint': '/api'}, value=1.0)]的内容,则说明aioprometheus已成功安装并可正常使用。

三、aioprometheus核心指标类型与使用方法

Prometheus的监控体系基于四种核心指标类型,aioprometheus对这四种类型进行了完整的异步封装,下面我们逐一讲解每种指标的定义、使用场景和代码示例。

3.1 Counter(计数器)

适用场景:用于统计只增不减的数值,例如请求次数、错误发生次数、任务完成数量等。Counter的核心操作是inc(),用于将指标值增加指定数值(默认增加1)。

代码示例:统计异步Web服务的GET请求次数

import asyncio
from aioprometheus import Counter, Registry, render
from aiohttp import web

# 1. 创建全局注册表
registry = Registry()

# 2. 定义Counter指标
# 参数说明:指标名、指标帮助信息、默认标签
http_requests_total = Counter(
    "http_requests_total",
    "Total count of HTTP requests by method and endpoint",
    labelnames=["method", "endpoint"]
)

# 3. 将指标注册到注册表
registry.register(http_requests_total)

# 4. 定义异步请求处理函数
async def handle_api_request(request):
    # 获取请求方法和路径
    method = request.method
    endpoint = request.path
    # 增加Counter计数,传入标签值
    http_requests_total.inc({"method": method, "endpoint": endpoint})
    return web.json_response({"status": "success", "data": "Hello, aioprometheus!"})

# 5. 定义指标暴露接口,供Prometheus拉取
async def metrics_handler(request):
    # 生成Prometheus格式的指标数据
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 6. 创建aiohttp应用并配置路由
async def create_app():
    app = web.Application()
    # 业务接口
    app.add_routes([web.get("/api/hello", handle_api_request)])
    # 指标暴露接口
    app.add_routes([web.get("/metrics", metrics_handler)])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8080)

代码说明

  • 首先创建Registry对象,用于管理所有指标,这是aioprometheus的核心管理组件;
  • 定义Counter指标时,通过labelnames参数指定标签维度,支持后续按请求方法和接口路径统计;
  • 在请求处理函数handle_api_request中,调用inc()方法增加计数,并传入当前请求的标签值;
  • 配置/metrics接口,通过render函数将注册表中的指标数据转换为Prometheus可识别的格式;
  • 运行程序后,访问http://localhost:8080/api/hello触发请求计数,再访问http://localhost:8080/metrics即可查看指标数据。

3.2 Gauge(仪表盘)

适用场景:用于统计可增可减的数值,例如当前内存使用量、活跃连接数、队列长度等。Gauge支持inc()(增加)、dec()(减少)、set()(直接设置值)、set_to_current_time()(设置为当前时间戳)等操作。

代码示例:监控异步任务队列的长度

import asyncio
import random
from aioprometheus import Gauge, Registry, render
from aiohttp import web

# 创建注册表和Gauge指标
registry = Registry()
task_queue_length = Gauge(
    "task_queue_length",
    "Current number of tasks in the async queue",
    labelnames=["queue_name"]
)
registry.register(task_queue_length)

# 模拟异步任务队列
task_queue = asyncio.Queue()

# 生产任务:向队列中添加随机数量的任务
async def task_producer():
    while True:
        # 随机生成1-5个任务
        task_count = random.randint(1, 5)
        for _ in range(task_count):
            await task_queue.put(f"task_{asyncio.get_event_loop().time()}")
        # 更新Gauge指标:设置为当前队列长度
        queue_len = task_queue.qsize()
        task_queue_length.set({"queue_name": "user_task_queue"}, queue_len)
        print(f"Added {task_count} tasks, current queue length: {queue_len}")
        await asyncio.sleep(3)

# 消费任务:从队列中取出任务并处理
async def task_consumer():
    while True:
        if not task_queue.empty():
            task = await task_queue.get()
            # 模拟任务处理耗时
            await asyncio.sleep(1)
            print(f"Processed task: {task}")
            task_queue.task_done()
            # 更新Gauge指标:处理完任务后更新队列长度
            queue_len = task_queue.qsize()
            task_queue_length.set({"queue_name": "user_task_queue"}, queue_len)
        else:
            await asyncio.sleep(1)

# 指标暴露接口
async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 创建aiohttp应用
async def create_app():
    app = web.Application()
    app.add_routes([web.get("/metrics", metrics_handler)])
    # 启动生产者和消费者协程
    asyncio.create_task(task_producer())
    asyncio.create_task(task_consumer())
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8081)

代码说明

  • Gauge指标task_queue_length用于监控任务队列的实时长度,通过labelnames区分不同队列;
  • task_producer协程负责生产任务,每次添加任务后调用set()方法更新指标值为当前队列长度;
  • task_consumer协程负责消费任务,处理完任务后同样更新指标值;
  • 运行程序后,访问http://localhost:8081/metrics可查看队列长度的实时变化,该指标会随着任务的生产和消费动态增减。

3.3 Summary(摘要)

适用场景:用于统计数值的分布情况,例如请求响应时间的平均值、中位数、95分位数等。Summary通过observe()方法记录数值,自动计算并存储指定分位数的统计结果。

代码示例:统计异步接口的响应时间

import asyncio
import time
from aioprometheus import Summary, Registry, render
from aiohttp import web

# 创建注册表和Summary指标
registry = Registry()
request_duration_seconds = Summary(
    "request_duration_seconds",
    "Summary of request processing duration in seconds",
    labelnames=["method", "endpoint"],
    # 指定需要统计的分位数和误差范围
    quantiles={0.5: 0.05, 0.95: 0.01, 0.99: 0.001}
)
registry.register(request_duration_seconds)

# 定义响应时间统计中间件
async def timing_middleware(app, handler):
    async def middleware_handler(request):
        # 记录请求开始时间
        start_time = time.perf_counter()
        try:
            # 执行原始请求处理函数
            response = await handler(request)
            return response
        finally:
            # 计算请求耗时
            duration = time.perf_counter() - start_time
            # 记录耗时到Summary指标
            request_duration_seconds.observe(
                {"method": request.method, "endpoint": request.path},
                duration
            )
    return middleware_handler

# 模拟耗时的业务接口
async def slow_api_handler(request):
    # 模拟接口处理耗时:0.1-0.5秒
    delay = asyncio.sleep(random.uniform(0.1, 0.5))
    await delay
    return web.json_response({"status": "success", "message": "Slow API response"})

# 普通业务接口
async def fast_api_handler(request):
    return web.json_response({"status": "success", "message": "Fast API response"})

# 指标暴露接口
async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 创建aiohttp应用
async def create_app():
    app = web.Application(middlewares=[timing_middleware])
    app.add_routes([
        web.get("/api/slow", slow_api_handler),
        web.get("/api/fast", fast_api_handler),
        web.get("/metrics", metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8082)

代码说明

  • 定义Summary指标时,通过quantiles参数指定需要统计的分位数,例如0.5代表中位数,0.95代表95分位数,后面的数值为允许的误差范围;
  • 实现timing_middleware中间件,在请求处理前后记录时间,计算耗时并通过observe()方法传入Summary指标;
  • 分别定义慢接口和快接口,模拟不同的响应耗时;
  • 运行程序后,多次访问/api/slow/api/fast接口,再访问/metrics即可查看响应时间的统计数据,包括平均值(_sum_count计算得出)、各分位数的数值。

3.4 Histogram(直方图)

适用场景:与Summary类似,用于统计数值的分布情况,但Histogram会将数值划分到不同的区间(bucket)中,统计每个区间的数值数量,适合用于绘制分布直方图。Histogram通过observe()方法记录数值,自动统计各区间的计数。

代码示例:统计异步任务的执行耗时分布

import asyncio
import random
import time
from aioprometheus import Histogram, Registry, render
from aiohttp import web

# 创建注册表和Histogram指标
registry = Registry()
task_execution_duration_seconds = Histogram(
    "task_execution_duration_seconds",
    "Histogram of task execution duration in seconds",
    labelnames=["task_type"],
    # 定义bucket区间:0.1, 0.2, 0.5, 1.0, +inf
    buckets=[0.1, 0.2, 0.5, 1.0]
)
registry.register(task_execution_duration_seconds)

# 模拟不同类型的异步任务
async def execute_task(task_type):
    # 根据任务类型设置不同的耗时范围
    if task_type == "light":
        duration = random.uniform(0.05, 0.15)
    elif task_type == "medium":
        duration = random.uniform(0.15, 0.3)
    else: # heavy
        duration = random.uniform(0.3, 0.8)
    # 模拟任务执行
    await asyncio.sleep(duration)
    # 记录任务耗时到Histogram指标
    task_execution_duration_seconds.observe({"task_type": task_type}, duration)
    return f"{task_type}_task_completed"

# 任务调度接口:接收任务类型参数并执行任务
async def task_scheduler_handler(request):
    task_type = request.query.get("task_type", "light")
    if task_type not in ["light", "medium", "heavy"]:
        return web.json_response({"error": "Invalid task type"}, status=400)
    result = await execute_task(task_type)
    return web.json_response({"status": "success", "result": result})

# 指标暴露接口
async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 创建aiohttp应用
async def create_app():
    app = web.Application()
    app.add_routes([
        web.get("/api/task", task_scheduler_handler),
        web.get("/metrics", metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8083)

代码说明

  • 定义Histogram指标时,通过buckets参数指定区间边界,指标会自动统计落入每个区间的数值数量,最后一个区间默认为+inf
  • execute_task函数模拟不同类型任务的执行耗时,根据任务类型设置不同的耗时范围,并通过observe()方法记录耗时;
  • 访问/api/task?task_type=light等接口触发任务执行,多次调用后访问/metrics,可查看各任务类型的耗时分布情况,包括每个bucket的计数、总和(_sum)和总次数(_count)。

四、aioprometheus高级用法:自定义指标标签与多注册表管理

4.1 动态标签与标签值替换

在实际应用中,指标标签的值往往需要动态获取,例如根据用户ID、请求IP等信息区分指标维度。aioprometheus支持在运行时动态传入标签值,下面以用户登录次数统计为例展示动态标签的用法:

import asyncio
from aioprometheus import Counter, Registry, render
from aiohttp import web

registry = Registry()
user_login_total = Counter(
    "user_login_total",
    "Total number of user logins by user type and platform",
    labelnames=["user_type", "platform"]
)
registry.register(user_login_total)

# 模拟用户登录接口,接收用户类型和平台参数
async def login_handler(request):
    data = await request.json()
    user_type = data.get("user_type", "guest") # 可选值:guest, member, admin
    platform = data.get("platform", "web") # 可选值:web, ios, android
    # 动态传入标签值
    user_login_total.inc({"user_type": user_type, "platform": platform})
    return web.json_response({"status": "success", "message": "Login successful"})

async def metrics_handler(request):
    content, http_headers = render(registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

async def create_app():
    app = web.Application()
    app.add_routes([
        web.post("/api/login", login_handler),
        web.get("/metrics", metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8084)

代码说明

  • 指标user_login_total定义了user_typeplatform两个标签,用于区分不同类型用户和不同登录平台的登录次数;
  • login_handler接口中,从请求参数中动态获取user_typeplatform的值,并传入inc()方法;
  • 通过Postman等工具发送POST请求到/api/login,请求体携带{"user_type": "member", "platform": "ios"}等参数,即可按标签维度统计登录次数。

4.2 多注册表管理

在大型应用中,不同模块可能需要独立的指标管理,此时可以使用多注册表机制,将不同模块的指标注册到不同的Registry对象中,再统一暴露或分别暴露指标接口。

import asyncio
from aioprometheus import Counter, Registry, render
from aiohttp import web

# 为用户模块创建注册表
user_registry = Registry()
user_register_total = Counter(
    "user_register_total",
    "Total number of user registrations",
    labelnames=["channel"]
)
user_registry.register(user_register_total)

# 为订单模块创建注册表
order_registry = Registry()
order_create_total = Counter(
    "order_create_total",
    "Total number of order creations",
    labelnames=["order_type"]
)
order_registry.register(order_create_total)

# 用户注册接口
async def user_register_handler(request):
    data = await request.json()
    channel = data.get("channel", "direct")
    user_register_total.inc({"channel": channel})
    return web.json_response({"status": "success", "message": "User registered"})

# 订单创建接口
async def order_create_handler(request):
    data = await request.json()
    order_type = data.get("order_type", "normal")
    order_create_total.inc({"order_type": order_type})
    return web.json_response({"status": "success", "message": "Order created"})

# 暴露用户模块指标
async def user_metrics_handler(request):
    content, http_headers = render(user_registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 暴露订单模块指标
async def order_metrics_handler(request):
    content, http_headers = render(order_registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

# 暴露所有模块指标(合并注册表)
async def all_metrics_handler(request):
    # 合并多个注册表的指标
    combined_registry = Registry()
    for metric in user_registry.collect():
        combined_registry.register(metric)
    for metric in order_registry.collect():
        combined_registry.register(metric)
    content, http_headers = render(combined_registry, request.headers.get('accept'))
    return web.Response(body=content, headers=http_headers)

async def create_app():
    app = web.Application()
    app.add_routes([
        web.post("/api/user/register", user_register_handler),
        web.post("/api/order/create", order_create_handler),
        web.get("/metrics/user", user_metrics_handler),
        web.get("/metrics/order", order_metrics_handler),
        web.get("/metrics/all", all_metrics_handler)
    ])
    return app

if __name__ == "__main__":
    app = asyncio.run(create_app())
    web.run_app(app, host="0.0.0.0", port=8085)

代码说明

  • 分别为用户模块和订单模块创建独立的Registry对象,注册各自的指标;
  • 提供/metrics/user/metrics/order接口,分别暴露对应模块的指标;
  • 实现/metrics/all接口,通过遍历两个注册表的指标并注册到新的Registry对象中,实现指标合并暴露,满足Prometheus一次性拉取所有指标的需求。

五、aioprometheus与Prometheus服务器集成实战

5.1 Prometheus服务器配置

要实现对aioprometheus暴露指标的监控,需要配置Prometheus服务器定期拉取指标数据。首先下载并安装Prometheus(下载地址:https://prometheus.io/download/),然后修改prometheus.yml配置文件:

global:
  scrape_interval: 15s # 每15秒拉取一次指标

scrape_configs:
  - job_name: 'aioprometheus_demo'
    static_configs:
      - targets: ['localhost:8080', 'localhost:8081', 'localhost:8082', 'localhost:8083', 'localhost:8084', 'localhost:8085']
    metrics_path: '/metrics'

配置说明

  • scrape_interval设置为15秒,表示Prometheus每15秒拉取一次指标;
  • job_name为任务名称,可自定义;
  • targets指定需要拉取指标的服务地址列表,即我们之前运行的各个aioprometheus示例服务;
  • metrics_path指定指标接口路径,默认为/metrics

5.2 启动Prometheus并查看指标

启动Prometheus服务器:

./prometheus --config.file=prometheus.yml

访问http://localhost:9090进入Prometheus Web界面,在查询框中输入指标名(如http_requests_totaltask_queue_length等),即可查看指标的实时数据和变化趋势。

5.3 可视化监控面板(Grafana集成)

为了更直观地展示指标数据,可将Prometheus作为数据源集成到Grafana中:

  1. 安装并启动Grafana(下载地址:https://grafana.com/grafana/download);
  2. 访问http://localhost:3000,使用默认账号密码(admin/admin)登录;
  3. 进入Configuration > Data Sources,添加Prometheus数据源,设置URL为http://localhost:9090
  4. 进入Dashboards > Import,导入官方或自定义的Dashboard模板,即可实现指标的可视化监控。

六、相关资源链接

  • Pypi地址:https://pypi.org/project/aioprometheus
  • Github地址:https://github.com/claws/aioprometheus
  • 官方文档地址:https://aioprometheus.readthedocs.io/en/latest/

关注我,每天分享一个实用的Python自动化工具。

Python文件管理神器:filedepot库从入门到精通

一、filedepot库核心概述

filedepot是一款轻量级的Python文件存储抽象层库,核心用途是为开发者提供统一的文件操作接口,支持本地文件系统、AWS S3、Google Cloud Storage等多种存储后端,无需修改业务代码即可切换存储方式。其工作原理是通过定义FileStorage抽象基类,不同存储后端实现该类的方法,实现接口与实现的解耦。

该库的优点是轻量无冗余依赖、接口简洁易用、扩展性强;缺点是高级功能(如断点续传)需自行实现,部分云存储后端的支持需额外安装依赖。License类型为MIT License,允许自由使用、修改和分发,适合商业和开源项目。

二、filedepot库安装与环境配置

2.1 基础安装命令

filedepot库已发布至PyPI,可直接使用pip包管理器安装,命令如下:

pip install filedepot

该命令会自动安装filedepot的核心依赖,满足本地文件存储的基本使用需求。

2.2 云存储后端依赖安装

如果需要使用AWS S3、Google Cloud Storage等云存储后端,需要安装对应的依赖包,具体命令如下:

  • AWS S3后端依赖安装
pip install filedepot[s3]
  • Google Cloud Storage后端依赖安装
pip install filedepot[gcs]

安装完成后,可通过pip show filedepot命令验证安装是否成功,查看库的版本和安装路径。

三、filedepot核心API与基础使用

3.1 核心类与接口介绍

filedepot的核心是FileStorage抽象基类,该类定义了文件存储的通用操作接口,所有存储后端都需要实现以下核心方法:
| 方法名 | 功能描述 |
|–|-|
| save(file_obj, filename=None, **kwargs) | 保存文件对象到存储后端,返回文件唯一标识 |
| open(file_id, mode='rb') | 根据文件ID打开文件,返回文件对象 |
| delete(file_id) | 根据文件ID删除文件 |
| exists(file_id) | 判断指定ID的文件是否存在 |

在实际使用中,我们不需要直接实例化FileStorage类,而是使用其具体的实现类,例如本地存储的LocalFileStorage

3.2 本地文件存储基础操作

3.2.1 初始化本地存储

首先导入LocalFileStorage类,指定存储目录并初始化存储对象。代码示例如下:

from depot.io.local import LocalFileStorage

# 初始化本地文件存储,指定存储目录为./my_files
storage = LocalFileStorage('./my_files')

执行上述代码后,filedepot会自动创建./my_files目录(如果不存在),后续所有文件都会保存在该目录下。

3.2.2 保存文件到本地存储

保存文件的方式有两种:一种是保存已有的文件对象,另一种是保存字节数据。下面分别展示两种方式的代码示例。

方式一:保存文件对象

# 打开本地的test.txt文件,以二进制模式读取
with open('test.txt', 'rb') as f:
    # 保存文件到存储后端,filename指定保存的文件名
    file_id = storage.save(f, filename='test_save.txt')

# 打印文件ID,该ID是文件的唯一标识
print(f"文件保存成功,文件ID: {file_id}")

代码说明:storage.save()方法接收文件对象作为参数,filename参数可选,如果不指定,filedepot会自动生成一个随机的文件名。执行后会在./my_files目录下生成test_save.txt文件,并返回一个唯一的文件ID。

方式二:保存字节数据

from io import BytesIO

# 创建字节流对象,写入测试数据
data = b"Hello, filedepot! This is a test data."
file_obj = BytesIO(data)

# 保存字节流数据到存储后端
file_id = storage.save(file_obj, filename='byte_data.txt')
print(f"字节数据保存成功,文件ID: {file_id}")

代码说明:通过BytesIO将字节数据转换为文件对象,再传递给save()方法,实现无需创建本地临时文件即可保存数据的需求。

3.2.3 读取存储的文件

通过文件ID可以读取已保存的文件内容,代码示例如下:

# 根据文件ID打开文件
with storage.open(file_id) as f:
    content = f.read()
    print(f"文件内容: {content.decode('utf-8')}")

代码说明:storage.open()方法返回一个文件对象,使用read()方法可以读取文件内容,由于读取的是二进制数据,需要使用decode('utf-8')转换为字符串。

3.2.4 判断文件是否存在

使用exists()方法可以判断指定ID的文件是否存在,代码示例如下:

# 检查文件是否存在
if storage.exists(file_id):
    print(f"文件ID {file_id} 对应的文件存在")
else:
    print(f"文件ID {file_id} 对应的文件不存在")

3.2.5 删除存储的文件

使用delete()方法可以删除指定ID的文件,代码示例如下:

# 删除文件
storage.delete(file_id)
print(f"文件ID {file_id} 对应的文件已删除")

# 再次检查文件是否存在
if not storage.exists(file_id):
    print(f"文件删除成功")

3.3 AWS S3云存储操作

3.3.1 初始化S3存储

要使用AWS S3后端,需要先配置AWS的访问密钥和存储桶信息。代码示例如下:

from depot.io.awss3 import S3Storage

# 初始化S3存储
s3_storage = S3Storage(
    access_key_id='YOUR_AWS_ACCESS_KEY',
    secret_access_key='YOUR_AWS_SECRET_KEY',
    bucket='your-bucket-name',
    region_name='us-east-1'
)

代码说明:access_key_idsecret_access_key是AWS账号的访问密钥,bucket是S3存储桶名称,region_name是存储桶所在的区域。需要注意的是,必须确保AWS账号拥有该存储桶的读写权限。

3.3.2 S3存储的文件操作

S3存储的文件操作接口与本地存储完全一致,无需修改业务逻辑即可实现存储后端的切换。代码示例如下:

# 保存文件到S3
with open('test_s3.txt', 'rb') as f:
    s3_file_id = s3_storage.save(f, filename='s3_test_file.txt')

print(f"S3文件保存成功,文件ID: {s3_file_id}")

# 读取S3中的文件
with s3_storage.open(s3_file_id) as f:
    s3_content = f.read()
    print(f"S3文件内容: {s3_content.decode('utf-8')}")

# 删除S3中的文件
s3_storage.delete(s3_file_id)
print(f"S3文件ID {s3_file_id} 对应的文件已删除")

代码说明:无论是本地存储还是S3存储,都使用相同的save()open()delete()方法,体现了filedepot接口统一的优势。

四、filedepot高级功能与实战案例

4.1 文件元数据管理

filedepot在保存文件时,会自动记录文件的元数据信息,例如文件名、文件大小、上传时间等。可以通过get_file_metadata()方法获取元数据,代码示例如下:

# 保存文件并获取元数据
with open('test_metadata.txt', 'rb') as f:
    meta_file_id = storage.save(f, filename='test_metadata.txt')

# 获取文件元数据
metadata = storage.get_file_metadata(meta_file_id)
print("文件元数据信息:")
for key, value in metadata.items():
    print(f"{key}: {value}")

执行上述代码后,会输出类似以下的元数据信息:

文件元数据信息:
filename: test_metadata.txt
content_type: text/plain
content_length: 24
last_modified: 2026-01-08T12:00:00Z

代码说明:元数据中包含了文件的名称、类型、大小和最后修改时间等信息,方便开发者进行文件管理和统计。

4.2 文件存储异常处理

在实际开发中,文件操作可能会出现各种异常,例如权限不足、存储目录不存在、网络故障等。filedepot提供了统一的异常类,方便开发者进行异常捕获和处理。代码示例如下:

from depot.exceptions import FileStorageError, FileNotFoundError

try:
    # 尝试打开不存在的文件ID
    with storage.open('invalid_file_id') as f:
        content = f.read()
except FileNotFoundError as e:
    print(f"错误: 指定的文件不存在 - {e}")
except FileStorageError as e:
    print(f"文件存储错误: {e}")
except Exception as e:
    print(f"其他错误: {e}")

代码说明:FileNotFoundError用于捕获文件不存在的异常,FileStorageError是filedepot的基础异常类,用于捕获所有存储相关的异常。合理的异常处理可以提高程序的健壮性。

4.3 实战案例:多后端文件存储切换工具

本案例将实现一个简单的文件存储工具,支持在本地存储和S3存储之间无缝切换,无需修改核心业务代码。

4.3.1 工具类实现

from depot.io.local import LocalFileStorage
from depot.io.awss3 import S3Storage
from depot.exceptions import FileStorageError

class MultiBackendFileManager:
    def __init__(self, backend_type='local', **kwargs):
        """
        初始化多后端文件管理器
        :param backend_type: 存储后端类型,可选 'local' 或 's3'
        :param kwargs: 存储后端的配置参数
        """
        self.backend_type = backend_type
        self.storage = self._init_storage(** kwargs)

    def _init_storage(self, **kwargs):
        """初始化存储后端"""
        if self.backend_type == 'local':
            # 本地存储需要传入存储目录参数
            return LocalFileStorage(kwargs.get('storage_dir', './default_files'))
        elif self.backend_type == 's3':
            # S3存储需要传入AWS配置参数
            return S3Storage(
                access_key_id=kwargs.get('access_key_id'),
                secret_access_key=kwargs.get('secret_access_key'),
                bucket=kwargs.get('bucket'),
                region_name=kwargs.get('region_name', 'us-east-1')
            )
        else:
            raise ValueError(f"不支持的存储后端类型: {self.backend_type}")

    def save_file(self, file_obj, filename=None):
        """保存文件"""
        try:
            file_id = self.storage.save(file_obj, filename=filename)
            return file_id
        except FileStorageError as e:
            print(f"保存文件失败: {e}")
            return None

    def read_file(self, file_id):
        """读取文件"""
        try:
            with self.storage.open(file_id) as f:
                return f.read()
        except FileNotFoundError:
            print(f"文件ID {file_id} 不存在")
            return None
        except FileStorageError as e:
            print(f"读取文件失败: {e}")
            return None

    def delete_file(self, file_id):
        """删除文件"""
        try:
            self.storage.delete(file_id)
            return True
        except FileNotFoundError:
            print(f"文件ID {file_id} 不存在")
            return False
        except FileStorageError as e:
            print(f"删除文件失败: {e}")
            return False

4.3.2 工具类使用示例

示例1:使用本地存储后端

# 初始化本地存储管理器
local_manager = MultiBackendFileManager(
    backend_type='local',
    storage_dir='./my_local_files'
)

# 保存文件
with open('example.txt', 'w', encoding='utf-8') as f:
    f.write("这是一个多后端文件存储工具的测试文件")

with open('example.txt', 'rb') as f:
    file_id = local_manager.save_file(f, filename='local_example.txt')

if file_id:
    print(f"本地存储文件ID: {file_id}")
    # 读取文件
    content = local_manager.read_file(file_id)
    if content:
        print(f"本地文件内容: {content.decode('utf-8')}")
    # 删除文件
    if local_manager.delete_file(file_id):
        print("本地文件删除成功")

示例2:使用S3存储后端

# 初始化S3存储管理器
s3_manager = MultiBackendFileManager(
    backend_type='s3',
    access_key_id='YOUR_AWS_ACCESS_KEY',
    secret_access_key='YOUR_AWS_SECRET_KEY',
    bucket='your-bucket-name',
    region_name='us-east-1'
)

# 保存字节数据到S3
from io import BytesIO
data = b"Hello, S3 Storage! From MultiBackendFileManager."
file_obj = BytesIO(data)
s3_file_id = s3_manager.save_file(file_obj, filename='s3_example.txt')

if s3_file_id:
    print(f"S3存储文件ID: {s3_file_id}")
    # 读取文件
    s3_content = s3_manager.read_file(s3_file_id)
    if s3_content:
        print(f"S3文件内容: {s3_content.decode('utf-8')}")
    # 删除文件
    if s3_manager.delete_file(s3_file_id):
        print("S3文件删除成功")

代码说明:该工具类通过封装filedepot的不同存储后端,实现了统一的文件操作接口。开发者只需修改backend_type和对应的配置参数,即可在本地存储和S3存储之间切换,无需修改保存、读取、删除等核心业务逻辑。

4.4 自定义存储后端扩展

filedepot的扩展性很强,如果需要支持其他存储后端(如FTP、阿里云OSS等),可以通过继承FileStorage抽象基类并实现其方法来完成。下面以自定义一个简单的内存存储后端为例,展示扩展方法。

4.4.1 自定义内存存储后端实现

from depot.io.interfaces import FileStorage, FileStorageError
from depot.io.utils import FileIntent
from io import BytesIO
import uuid

class MemoryFileStorage(FileStorage):
    def __init__(self):
        # 使用字典存储文件,key为文件ID,value为文件内容和元数据
        self._files = {}

    def save(self, file_obj, filename=None, content_type=None, **kwargs):
        # 生成唯一的文件ID
        file_id = str(uuid.uuid4())
        # 读取文件内容
        file_content = file_obj.read()
        # 处理文件名
        if filename is None:
            filename = f"memory_file_{file_id}"
        # 处理内容类型
        if content_type is None:
            content_type = "application/octet-stream"
        # 存储文件内容和元数据
        self._files[file_id] = {
            'content': file_content,
            'filename': filename,
            'content_type': content_type,
            'content_length': len(file_content)
        }
        return file_id

    def open(self, file_id, mode='rb'):
        if mode != 'rb':
            raise FileStorageError("内存存储仅支持二进制读取模式")
        if file_id not in self._files:
            raise FileNotFoundError(f"文件ID {file_id} 不存在")
        # 返回字节流对象
        return BytesIO(self._files[file_id]['content'])

    def delete(self, file_id):
        if file_id not in self._files:
            raise FileNotFoundError(f"文件ID {file_id} 不存在")
        del self._files[file_id]

    def exists(self, file_id):
        return file_id in self._files

    def get_file_metadata(self, file_id):
        if file_id not in self._files:
            raise FileNotFoundError(f"文件ID {file_id} 不存在")
        return self._files[file_id].copy()

4.4.2 自定义内存存储后端使用示例

# 初始化内存存储
memory_storage = MemoryFileStorage()

# 保存文件
data = b"这是自定义内存存储的测试数据"
file_obj = BytesIO(data)
file_id = memory_storage.save(file_obj, filename='memory_test.txt')

print(f"内存存储文件ID: {file_id}")

# 读取文件
with memory_storage.open(file_id) as f:
    content = f.read()
    print(f"内存文件内容: {content.decode('utf-8')}")

# 获取元数据
metadata = memory_storage.get_file_metadata(file_id)
print(f"文件元数据: {metadata}")

# 删除文件
memory_storage.delete(file_id)
if not memory_storage.exists(file_id):
    print("内存文件删除成功")

代码说明:通过继承FileStorage抽象基类,实现save()open()delete()exists()等核心方法,即可自定义存储后端。该内存存储后端将文件内容保存在内存字典中,适用于临时文件存储的场景。

五、filedepot库常见问题与解决方案

5.1 问题1:保存大文件时内存占用过高

问题描述:当使用save()方法保存大文件时,会一次性读取文件内容到内存,导致内存占用过高。
解决方案:使用文件流分块读取的方式,避免一次性加载整个文件到内存。代码示例如下:

def save_large_file(storage, file_path, chunk_size=1024*1024):
    """
    分块保存大文件
    :param storage: 文件存储对象
    :param file_path: 大文件路径
    :param chunk_size: 分块大小,默认1MB
    :return: 文件ID
    """
    with open(file_path, 'rb') as f:
        # 创建临时文件对象
        temp_file = BytesIO()
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                break
            temp_file.write(chunk)
        # 将文件指针移到开头
        temp_file.seek(0)
        return storage.save(temp_file, filename=file_path.split('/')[-1])

# 使用示例
large_file_id = save_large_file(storage, 'large_file.zip')
print(f"大文件保存成功,文件ID: {large_file_id}")

5.2 问题2:云存储后端连接超时

问题描述:使用S3等云存储后端时,经常出现连接超时的情况。
解决方案:增加连接超时参数,设置重试机制。以S3存储为例,代码示例如下:

from botocore.config import Config

# 配置S3连接参数,设置超时和重试次数
config = Config(
    connect_timeout=30,
    read_timeout=30,
    retries={'max_attempts': 3}
)

# 初始化S3存储,传入配置参数
s3_storage = S3Storage(
    access_key_id='YOUR_AWS_ACCESS_KEY',
    secret_access_key='YOUR_AWS_SECRET_KEY',
    bucket='your-bucket-name',
    region_name='us-east-1',
    config=config
)

代码说明:通过botocore.config.Config设置连接超时、读取超时和重试次数,提高云存储连接的稳定性。

5.3 问题3:文件ID管理困难

问题描述:filedepot自动生成的文件ID是随机字符串,不利于业务系统中的文件管理。
解决方案:自定义文件ID生成规则,例如使用业务标识+时间戳的方式。代码示例如下:

import time

def custom_file_id_generator(prefix='file_'):
    """自定义文件ID生成器"""
    timestamp = int(time.time() * 1000)
    return f"{prefix}{timestamp}"

# 使用自定义文件ID保存文件
with open('test.txt', 'rb') as f:
    # 生成自定义文件ID
    custom_file_id = custom_file_id_generator(prefix='myapp_')
    # 使用_file_id参数指定自定义文件ID
    storage.save(f, filename='test.txt', _file_id=custom_file_id)

print(f"自定义文件ID: {custom_file_id}")

代码说明:通过save()方法的_file_id参数,可以指定自定义的文件ID,方便业务系统根据文件ID关联业务数据。

六、filedepot库相关资源

  • PyPI地址:https://pypi.org/project/filedepot
  • Github地址:https://github.com/xxxxx/xxxxxx
  • 官方文档地址:https://www.xxxxx.com/xxxxxx

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:ODMantic入门到精通——异步MongoDB数据建模与操作指南

一、ODMantic库核心概述

ODMantic是一款专为Python异步生态设计的对象文档映射器(ODM),主要用于简化MongoDB数据库的操作流程。其核心工作原理是将Python类与MongoDB集合进行映射,类的实例对应集合中的文档,开发者无需编写原生MongoDB查询语句,通过操作Python对象即可完成数据的增删改查。

该库的优点十分突出:完全支持异步操作,可无缝对接asyncioFastAPI等异步框架;API设计简洁直观,与Python开发者熟悉的SQLAlchemy等ORM工具风格相近,学习成本低;内置数据校验功能,基于pydantic实现字段类型和约束的校验,保障数据一致性。缺点则是生态相较于老牌ODM工具mongoengine更小众,部分高级查询功能的支持度有待提升;仅适用于MongoDB,通用性较弱。

ODMantic的开源协议为MIT License,这意味着开发者可以自由地用于商业和非商业项目,修改和分发源码也不受过多限制。

二、ODMantic环境安装与配置

2.1 安装前提条件

在安装ODMantic之前,需要确保本地环境满足以下要求:

  1. Python版本≥3.7(推荐3.9及以上版本,兼容性更好);
  2. 已安装并运行MongoDB数据库(本地或远程实例均可,推荐版本≥4.0);
  3. 网络环境正常,能够通过pip下载相关依赖包。

2.2 安装命令

ODMantic的安装非常简单,直接使用pip工具在命令行中执行以下命令即可:

pip install odmantic

该命令会自动安装ODMantic及其核心依赖,包括pydantic(数据校验)、motor(异步MongoDB驱动)等,无需手动单独安装。

2.3 基础配置:连接MongoDB

使用ODMantic的第一步是建立与MongoDB数据库的异步连接。我们需要借助odmantic提供的AIOEngine类,该类是与数据库交互的核心入口,负责管理连接和执行操作。

以下是基础的连接代码示例:

from odmantic import AIOEngine
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

# 定义异步连接函数
async def connect_to_mongodb():
    # 1. 创建MongoDB异步客户端
    # 本地数据库地址:mongodb://localhost:27017/
    # 若为远程数据库,替换为对应的连接字符串,例如包含用户名和密码的地址:
    # mongodb://username:password@remote_host:27017/
    client = AsyncIOMotorClient("mongodb://localhost:27017/")
    # 2. 初始化AIOEngine,指定要操作的数据库名称
    # 这里指定数据库名为"odmantic_demo",若不存在会自动创建
    engine = AIOEngine(client=client, database="odmantic_demo")
    print("成功连接到MongoDB数据库!")
    return engine

# 运行异步函数
if __name__ == "__main__":
    engine = asyncio.run(connect_to_mongodb())

代码说明

  • AsyncIOMotorClientmotor库提供的异步MongoDB客户端,负责与数据库建立TCP连接;
  • AIOEngine是ODMantic的核心引擎,接收客户端实例和数据库名称作为参数,后续所有的数据操作都需要通过该引擎对象完成;
  • asyncio.run()用于运行异步函数,在实际的异步项目(如FastAPI)中,可直接通过await调用connect_to_mongodb()

三、ODMantic核心功能与代码示例

3.1 定义数据模型(Model)

ODMantic的数据模型基于Python类实现,继承自odmantic.Model,类中的字段对应MongoDB文档的键。字段类型通过pydantic的类型注解指定,同时支持设置默认值、必填约束等属性。

3.1.1 基础模型定义

我们以一个“用户(User)”模型为例,演示如何定义基础的数据模型:

from odmantic import Model
from typing import Optional
from datetime import datetime

class User(Model):
    # 字段1:用户名,字符串类型,必填
    username: str
    # 字段2:年龄,整数类型,必填
    age: int
    # 字段3:邮箱,可选字符串类型,默认值为None
    email: Optional[str] = None
    # 字段4:注册时间,datetime类型,默认值为当前时间
    register_time: datetime = datetime.now()

    # 可选配置:指定对应的MongoDB集合名称
    # 若不指定,默认集合名为类名的小写复数形式(此处为"users")
    class Config:
        collection = "user_collection"

代码说明

  • 模型类必须继承odmantic.Model,这是ODMantic识别模型的标志;
  • 字段的类型注解支持Python原生类型(strintdatetime等)和typing模块中的类型(Optional表示可选字段);
  • 通过Config类可以自定义模型的配置,例如collection参数指定模型对应的MongoDB集合名称,若不指定,ODMantic会自动将类名转为小写复数形式作为集合名;
  • 字段可以设置默认值,如register_time默认值为当前时间,email默认值为None

3.1.2 模型字段的常用约束

基于pydantic的特性,ODMantic支持为字段添加各种约束条件,例如字符串长度、数值范围等,确保存入数据库的数据符合预期。以下是带约束的模型示例:

from odmantic import Model, Field
from typing import Optional
from datetime import datetime

class UserWithConstraints(Model):
    # 用户名:字符串类型,长度在3-20之间,必填
    username: str = Field(min_length=3, max_length=20)
    # 年龄:整数类型,范围在0-120之间,必填
    age: int = Field(ge=0, le=120)
    # 邮箱:可选字符串类型,必须符合邮箱格式
    email: Optional[str] = Field(None, regex=r"^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$")
    # 注册时间:默认当前时间
    register_time: datetime = datetime.now()

    class Config:
        collection = "constraint_users"

代码说明

  • Field类用于为字段添加额外约束,参数min_lengthmax_length限制字符串长度,ge(大于等于)和le(小于等于)限制数值范围;
  • regex参数用于指定字符串的正则表达式验证规则,这里用于校验邮箱格式的合法性;
  • 当创建模型实例时,如果字段值不符合约束条件,会直接抛出ValidationError异常,避免非法数据存入数据库。

3.2 数据的增删改查(CRUD)操作

CRUD是数据库操作的核心,ODMantic通过简洁的API实现异步的增删改查功能,所有操作都需要通过之前初始化的AIOEngine对象完成。

在进行后续操作前,我们先统一初始化引擎,并定义一个通用的异步运行函数,方便执行异步代码:

from odmantic import AIOEngine, Model, Field
from typing import Optional, List
from datetime import datetime
import asyncio
from motor.motor_asyncio import AsyncIOMotorClient

# 初始化引擎
async def get_engine():
    client = AsyncIOMotorClient("mongodb://localhost:27017/")
    return AIOEngine(client=client, database="odmantic_demo")

# 定义User模型
class User(Model):
    username: str = Field(min_length=3, max_length=20)
    age: int = Field(ge=0, le=120)
    email: Optional[str] = None
    register_time: datetime = datetime.now()

    class Config:
        collection = "users"

3.2.1 新增数据(Create)

新增数据即向MongoDB集合中插入文档,使用engine.save()方法,该方法接收一个模型实例作为参数,异步将数据插入数据库,并返回插入后的实例(包含自动生成的id字段)。

代码示例

# 新增单个用户
async def create_single_user(engine: AIOEngine):
    # 创建User模型实例
    user = User(username="zhangsan", age=25, email="[email protected]")
    # 保存到数据库
    saved_user = await engine.save(user)
    print("新增用户成功:")
    print(f"用户ID: {saved_user.id}")
    print(f"用户名: {saved_user.username}")
    print(f"年龄: {saved_user.age}")
    print(f"邮箱: {saved_user.email}")
    return saved_user

# 新增多个用户
async def create_multiple_users(engine: AIOEngine):
    # 创建多个用户实例
    user1 = User(username="lisi", age=30)
    user2 = User(username="wangwu", age=22, email="[email protected]")
    # 批量保存
    saved_users = await engine.save_all([user1, user2])
    print("\n批量新增用户成功,共新增{}个用户:".format(len(saved_users)))
    for u in saved_users:
        print(f"ID: {u.id}, 用户名: {u.username}")

# 执行新增操作
async def run_create_demo():
    engine = await get_engine()
    await create_single_user(engine)
    await create_multiple_users(engine)

if __name__ == "__main__":
    asyncio.run(run_create_demo())

代码说明

  • engine.save()用于插入单个文档,插入后会自动为实例添加id属性(对应MongoDB文档的_id字段,类型为ObjectId);
  • engine.save_all()用于批量插入多个文档,接收一个模型实例列表作为参数,返回插入后的实例列表;
  • 插入操作是异步的,必须使用await关键字调用。

3.2.2 查询数据(Read)

ODMantic提供了多种查询方式,包括查询单个文档、查询多个文档、条件查询、排序、分页等,满足不同的查询需求。

(1)查询单个文档

使用engine.find_one()方法,可根据条件查询单个文档,若未找到则返回None

代码示例

# 根据用户名查询单个用户
async def find_user_by_username(engine: AIOEngine, username: str):
    # 构造查询条件:User.username == username
    user = await engine.find_one(User, User.username == username)
    if user:
        print(f"\n查询到用户:")
        print(f"ID: {user.id}, 用户名: {user.username}, 年龄: {user.age}, 邮箱: {user.email}")
    else:
        print(f"\n未找到用户名为{username}的用户")
    return user

# 执行查询操作
async def run_find_one_demo():
    engine = await get_engine()
    await find_user_by_username(engine, "zhangsan")
    await find_user_by_username(engine, "zhaoliu")

if __name__ == "__main__":
    asyncio.run(run_find_one_demo())

代码说明

  • engine.find_one()的第一个参数是模型类,第二个参数是查询条件,条件的写法为模型类.字段名 == 目标值
  • 若存在多个符合条件的文档,该方法只会返回第一个匹配的文档。
(2)查询多个文档

使用engine.find()方法,可查询符合条件的所有文档,返回一个异步迭代器,可通过async for遍历,或通过list()转换为列表。

代码示例

# 查询所有用户
async def find_all_users(engine: AIOEngine):
    users = await engine.find(User)
    print("\n所有用户列表:")
    for user in users:
        print(f"ID: {user.id}, 用户名: {user.username}, 年龄: {user.age}")

# 条件查询:年龄大于23的用户
async def find_users_by_age(engine: AIOEngine, min_age: int):
    users = await engine.find(User, User.age > min_age)
    print(f"\n年龄大于{min_age}的用户列表:")
    for user in users:
        print(f"ID: {user.id}, 用户名: {user.username}, 年龄: {user.age}")

# 执行多文档查询
async def run_find_demo():
    engine = await get_engine()
    await find_all_users(engine)
    await find_users_by_age(engine, 23)

if __name__ == "__main__":
    asyncio.run(run_find_demo())

代码说明

  • engine.find()的参数与find_one()一致,第一个参数为模型类,第二个参数为可选的查询条件;
  • 支持的查询运算符包括:==(等于)、!=(不等于)、>(大于)、<(小于)、>=(大于等于)、<=(小于等于)、in_(包含在列表中)等,例如User.username.in_(["zhangsan", "lisi"])
(3)排序与分页

在查询时,可通过sort()方法对结果进行排序,通过skip()limit()方法实现分页功能。

代码示例

# 排序查询:按年龄降序排列
async def find_users_sorted(engine: AIOEngine):
    users = await engine.find(User).sort(User.age, -1)
    print("\n按年龄降序排列的用户列表:")
    for user in users:
        print(f"用户名: {user.username}, 年龄: {user.age}")

# 分页查询:每页2条,查询第2页
async def find_users_paginated(engine: AIOEngine, page: int, page_size: int):
    skip_count = (page - 1) * page_size
    users = await engine.find(User).skip(skip_count).limit(page_size)
    print(f"\n第{page}页用户列表(每页{page_size}条):")
    for user in users:
        print(f"用户名: {user.username}, 年龄: {user.age}")

# 执行排序和分页查询
async def run_sort_paginate_demo():
    engine = await get_engine()
    await find_users_sorted(engine)
    await find_users_paginated(engine, page=2, page_size=2)

if __name__ == "__main__":
    asyncio.run(run_sort_paginate_demo())

代码说明

  • sort()方法的第一个参数是排序字段,第二个参数为1(升序)或-1(降序);
  • skip(n)表示跳过前n条数据,limit(m)表示最多返回m条数据,两者结合即可实现分页。

3.2.3 更新数据(Update)

更新数据有两种方式:一种是先查询出模型实例,修改实例的字段值后调用engine.save()方法;另一种是使用engine.update()方法直接执行更新操作。

(1)基于实例的更新

代码示例

# 更新用户信息:修改邮箱和年龄
async def update_user_by_instance(engine: AIOEngine, username: str):
    # 1. 查询用户
    user = await engine.find_one(User, User.username == username)
    if not user:
        print(f"未找到用户{username},更新失败")
        return
    # 2. 修改字段值
    user.age = 26
    user.email = "[email protected]"
    # 3. 保存更新
    updated_user = await engine.save(user)
    print(f"\n用户{username}更新成功:")
    print(f"年龄: {updated_user.age}, 邮箱: {updated_user.email}")

# 执行更新操作
async def run_update_instance_demo():
    engine = await get_engine()
    await update_user_by_instance(engine, "zhangsan")

if __name__ == "__main__":
    asyncio.run(run_update_instance_demo())

代码说明

  • 基于实例的更新步骤为“查询-修改-保存”,适用于需要先获取当前数据再进行修改的场景;
  • engine.save()方法会自动识别实例是否已存在(通过id字段),若存在则执行更新操作,若不存在则执行插入操作。
(2)基于查询条件的批量更新

代码示例

from odmantic import UpdateQuery

# 批量更新:将年龄小于25的用户年龄加1
async def batch_update_users(engine: AIOEngine):
    # 1. 构造更新查询
    update_query = UpdateQuery({User.age: User.age + 1})
    # 2. 执行批量更新
    update_result = await engine.update(
        User,
        User.age < 25,
        update_query
    )
    print(f"\n批量更新成功,共更新{update_result.modified_count}条记录")

# 执行批量更新
async def run_batch_update_demo():
    engine = await get_engine()
    await batch_update_users(engine)

if __name__ == "__main__":
    asyncio.run(run_batch_update_demo())

代码说明

  • 批量更新需要使用UpdateQuery类构造更新内容,支持字段的自增、自减等操作;
  • engine.update()的参数依次为:模型类、查询条件、更新查询对象,返回的结果对象包含modified_count属性,表示实际更新的记录数。

3.2.4 删除数据(Delete)

删除数据同样有两种方式:删除单个实例和批量删除符合条件的文档。

(1)删除单个实例

代码示例

# 删除指定用户
async def delete_user_by_instance(engine: AIOEngine, username: str):
    # 1. 查询用户
    user = await engine.find_one(User, User.username == username)
    if not user:
        print(f"未找到用户{username},删除失败")
        return
    # 2. 删除用户
    await engine.delete(user)
    print(f"\n用户{username}删除成功")

# 执行删除操作
async def run_delete_instance_demo():
    engine = await get_engine()
    await delete_user_by_instance(engine, "lisi")

if __name__ == "__main__":
    asyncio.run(run_delete_instance_demo())

代码说明

  • engine.delete()方法接收一个模型实例作为参数,根据实例的id字段删除对应的文档。
(2)批量删除文档

代码示例

# 批量删除:删除邮箱为None的用户
async def batch_delete_users(engine: AIOEngine):
    delete_result = await engine.delete(User, User.email == None)
    print(f"\n批量删除成功,共删除{delete_result.deleted_count}条记录")

# 执行批量删除
async def run_batch_delete_demo():
    engine = await get_engine()
    await batch_delete_users(engine)

if __name__ == "__main__":
    asyncio.run(run_batch_delete_demo())

代码说明

  • engine.delete()方法若传入模型类和查询条件,则会批量删除符合条件的所有文档;
  • 返回的结果对象包含deleted_count属性,表示实际删除的记录数。

3.3 模型关联(一对一、一对多)

在实际应用中,数据之间往往存在关联关系,例如“用户”和“文章”的一对多关系(一个用户可以发布多篇文章)。ODMantic支持通过Reference字段实现模型之间的关联。

3.3.1 定义关联模型

我们以“用户(User)”和“文章(Article)”为例,演示一对多关联的实现:

from odmantic import Model, Field, Reference
from typing import Optional, List
from datetime import datetime

# 定义User模型
class User(Model):
    username: str = Field(min_length=3, max_length=20)
    age: int = Field(ge=0, le=120)

    class Config:
        collection = "users"

# 定义Article模型,与User模型关联
class Article(Model):
    title: str = Field(min_length=1, max_length=100)
    content: str
    publish_time: datetime = datetime.now()
    # 关联到User模型,表示文章的作者
    author: Reference[User]

    class Config:
        collection = "articles"

代码说明

  • Reference[User]表示author字段是一个指向User模型的引用,存储的是User实例的id
  • 这种定义方式实现了从ArticleUser的单向关联,若需要双向关联,可在User模型中添加articles: List[Article] = []字段,并结合反向查询实现。

3.3.2 创建关联数据

代码示例

# 创建用户并发布文章
async def create_user_and_articles(engine: AIOEngine):
    # 1. 创建用户
    user = User(username="zhangsan", age=25)
    saved_user = await engine.save(user)
    print(f"创建用户成功:{saved_user.username}")

    # 2. 创建两篇文章,关联到该用户
    article1 = Article(title="Python异步编程入门", content="异步编程是Python的重要特性...", author=saved_user)
    article2 = Article(title="ODMantic使用指南", content="ODMantic是一款优秀的异步ODM工具...", author=saved_user)
    saved_articles = await engine.save_all([article1, article2])
    print(f"创建文章成功,共发布{len(saved_articles)}篇文章")

# 执行关联数据创建
async def run_relation_create_demo():
    engine = await get_engine()
    await create_user_and_articles(engine)

if __name__ == "__main__":
    asyncio.run(run_relation_create_demo())

代码说明

  • 创建关联数据时,直接将User实例赋值给Articleauthor字段即可,ODMantic会自动处理引用关系,存储Userid

3.3.3 查询关联数据

查询关联数据有两种方式:从子模型查询父模型(通过author字段查询用户信息),以及从父模型查询子模型(通过用户查询其发布的所有文章)。

代码示例

# 从文章查询作者信息
async def find_article_author(engine: AIOEngine, article_title: str):
    article = await engine.find_one(Article, Article.title == article_title)
    if not article:
        print(f"未找到标题为{article_title}的文章")
        return
    # 直接访问article.author即可获取关联的用户实例
    author = article.author
    print(f"\n文章《{article.title}》的作者信息:")
    print(f"用户名: {author.username}, 年龄: {author.age}")

# 从用户查询发布的所有文章
async def find_user_articles(engine: AIOEngine, username: str):
    user = await engine.find_one(User, User.username == username)
    if not user:
        print(f"未找到用户{username}")
        return
    # 查询该用户发布的所有文章
    articles = await engine.find(Article, Article.author == user)
    print(f"\n用户{username}发布的文章列表:")
    for article in articles:
        print(f"标题: {article.title}, 发布时间: {article.publish_time}")

# 执行关联数据查询
async def run_relation_find_demo():
    engine = await get_engine()
    await find_article_author(engine, "Python异步编程入门")
    await find_user_articles(engine, "zhangsan")

if __name__ == "__main__":
    asyncio.run(run_relation_find_demo())

代码说明

  • 从子模型查询父模型时,直接访问关联字段即可(如article.author),ODMantic会自动根据存储的id查询对应的父模型实例;
  • 从父模型查询子模型时,构造查询条件为Article.author == user,即可获取该用户发布的所有文章。

四、ODMantic与FastAPI框架集成实战

FastAPI是一款高性能的异步Web框架,与ODMantic的异步特性完美契合。本节将演示如何在FastAPI项目中集成ODMantic,实现一个简单的用户管理API。

4.1 项目目录结构

odmantic_fastapi_demo/
├── main.py               # 项目入口文件,包含API路由
└── models.py             # 数据模型定义

4.2 定义数据模型(models.py)

from odmantic import Model, Field
from datetime import datetime
from typing import Optional

class User(Model):
    username: str = Field(min_length=3, max_length=20)
    age: int = Field(ge=0, le=120)
    email: Optional[str] = None
    register_time: datetime = datetime.now()

    class Config:
        collection = "users"

4.3 实现FastAPI API(main.py)

from fastapi import FastAPI, HTTPException
from odmantic import AIOEngine, ObjectId
from motor.motor_asyncio import AsyncIOMotorClient
from models import User
from typing import List
import asyncio

# 初始化FastAPI应用
app = FastAPI(title="ODMantic + FastAPI 示例", version="1.0")

# 全局引擎对象
engine: AIOEngine = None

# 启动时初始化数据库连接
@app.on_event("startup")
async def startup_db():
    global engine
    client = AsyncIOMotorClient("mongodb://localhost:27017/")
    engine = AIOEngine(client=client, database="odmantic_fastapi_demo")
    print("数据库连接成功!")

# 关闭时断开数据库连接
@app.on_event("shutdown")
async def shutdown_db():
    global engine
    if engine:
        engine.client.close()
        print("数据库连接已关闭!")

# API路由:创建用户
@app.post("/users/", response_model=User, summary="创建新用户")
async def create_user(user: User):
    saved_user = await engine.save(user)
    return saved_user

# API路由:获取单个用户
@app.get("/users/{user_id}", response_model=User, summary="根据ID获取用户")
async def get_user(user_id: str):
    try:
        # 将字符串ID转换为ObjectId
        obj_id = ObjectId(user_id)
    except:
        raise HTTPException(status_code=400, detail="无效的用户ID格式")
    user = await engine.find_one(User, User.id == obj_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    return user

# API路由:获取所有用户
@app.get("/users/", response_model=List[User], summary="获取所有用户")
async def get_all_users():
    users = await engine.find(User)
    return users

# API路由:更新用户
@app.put("/users/{user_id}", response_model=User, summary="更新用户信息")
async def update_user(user_id: str, user_update: User):
    try:
        obj_id = ObjectId(user_id)
    except:
        raise HTTPException(status_code=400, detail="无效的用户ID格式")
    user = await engine.find_one(User, User.id == obj_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    # 更新字段
    user.username = user_update.username
    user.age = user_update.age
    user.email = user_update.email
    updated_user = await engine.save(user)
    return updated_user

# API路由:删除用户
@app.delete("/users/{user_id}", summary="删除用户")
async def delete_user(user_id: str):
    try:
        obj_id = ObjectId(user_id)
    except:
        raise HTTPException(status_code=400, detail="无效的用户ID格式")
    user = await engine.find_one(User, User.id == obj_id)
    if not user:
        raise HTTPException(status_code=404, detail="用户不存在")
    await engine.delete(user)
    return {"message": "用户删除成功"}

# 运行项目
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

代码说明

  • 通过@app.on_event("startup")@app.on_event("shutdown")钩子函数,在FastAPI应用启动时初始化数据库连接,关闭时断开连接;
  • 所有API接口均为异步函数,通过engine对象完成数据操作;
  • 使用response_model指定接口的返回数据模型,FastAPI会自动进行数据校验和文档生成;
  • 运行项目后,可访问http://localhost:8000/docs查看自动生成的API文档,并进行接口测试。

4.4 启动和测试项目

  1. 启动项目:在命令行中执行python main.py,FastAPI应用会在8000端口启动;
  2. 测试接口:打开浏览器访问http://localhost:8000/docs,可以看到自动生成的Swagger文档,点击对应的接口即可进行测试,例如:
  • 点击/users/POST接口,填写用户信息后点击“Execute”,即可创建新用户;
  • 点击/users/GET接口,可获取所有用户的列表;
  • 点击/users/{user_id}GET接口,输入用户ID,可获取指定用户的信息。

五、相关资源链接

  • Pypi地址:https://pypi.org/project/ODMantic
  • Github地址:https://github.com/art049/odmantic
  • 官方文档地址:https://art049.github.io/odmantic/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:HappyBase 入门到精通——高效操作HBase数据库指南

一、HappyBase 库核心概述

HappyBase 是一款专为 Python 开发者打造的HBase 数据库交互库,其核心用途是简化 Python 程序与 HBase 分布式数据库的连接、数据读写及管理操作。工作原理上,HappyBase 基于 HBase 的 Thrift 接口实现通信,通过封装复杂的 Thrift 协议调用逻辑,提供简洁直观的 Python 风格 API,让开发者无需深入理解 Thrift 细节即可高效操作 HBase。

该库的优点十分突出:API 设计简洁易懂,贴近 Python 开发者使用习惯;支持连接池管理,能有效提升高并发场景下的连接复用率;兼容 HBase 主流版本,具备良好的通用性。缺点则集中在对 HBase 高级特性(如事务、复杂过滤器)的支持有限,且依赖 Thrift 服务的稳定运行,Thrift 服务的性能瓶颈会直接影响 HappyBase 的操作效率。

HappyBase 的开源协议为 MIT License,这意味着开发者可以自由地用于商业和非商业项目,无需承担开源协议带来的额外约束。

二、HappyBase 安装与环境准备

2.1 前置条件

在安装和使用 HappyBase 之前,必须确保以下环境准备到位:

  1. HBase 集群部署完成:HBase 是分布式数据库,需提前搭建好单节点或集群环境,且保证 HBase 服务正常运行。
  2. Thrift 服务启动:HappyBase 依赖 HBase 的 Thrift 接口,因此需要启动 HBase Thrift 服务。启动命令如下(在 HBase 安装目录的 bin 文件夹下执行):
    bash hbase thrift start
    若需要后台运行 Thrift 服务,可添加 -b 参数指定绑定地址,配合 nohup 命令实现:
    bash nohup hbase thrift start -b 0.0.0.0 > thrift.log 2>&1 &
  3. Python 环境:推荐使用 Python 3.6 及以上版本,确保 pip 包管理工具可用。

2.2 安装 HappyBase

HappyBase 可通过 pip 工具一键安装,这是最简单且推荐的方式。打开命令行终端,执行以下命令:

pip install happybase

若需要安装指定版本的 HappyBase(例如兼容特定 HBase 版本的 1.2.0 版本),可指定版本号:

pip install happybase==1.2.0

安装完成后,可在 Python 环境中执行以下代码验证是否安装成功:

import happybase
print(happybase.__version__)

若终端输出 HappyBase 的版本号(如 1.2.0),则说明安装成功。

三、HappyBase 核心 API 用法与代码实例

HappyBase 的核心操作围绕连接 HBase表操作数据读写三大模块展开,下面结合具体代码实例详细讲解每个模块的使用方法。

3.1 连接 HBase 数据库

连接 HBase 是使用 HappyBase 的第一步,主要通过 happybase.Connection() 方法创建连接对象。该方法支持多个参数,常用参数说明如下:

  • host:HBase Thrift 服务的主机地址,默认值为 localhost
  • port:HBase Thrift 服务的端口号,默认值为 9090
  • timeout:连接超时时间,单位为毫秒,默认无超时限制。
  • autoconnect:是否自动建立连接,默认值为 True

3.1.1 基础连接示例

import happybase

# 创建连接对象
conn = happybase.Connection(
    host='localhost',  # 替换为你的 HBase Thrift 服务地址
    port=9090,
    timeout=10000
)

# 查看当前 HBase 中的所有表名
tables = conn.tables()
print("HBase 中已存在的表:", tables)

# 关闭连接
conn.close()

代码说明

  • 首先导入 happybase 库,然后通过 Connection() 方法指定 HBase Thrift 服务的 hostport,创建连接对象 conn
  • conn.tables() 方法会返回 HBase 中所有表的名称列表,返回结果为字节串格式(如 [b'test_table'])。
  • 操作完成后,需调用 conn.close() 关闭连接,释放资源。

3.1.2 使用连接池管理连接

在高并发场景下,频繁创建和关闭连接会消耗大量系统资源,HappyBase 提供了连接池功能来解决这个问题。通过 happybase.ConnectionPool() 可以创建连接池,实现连接的复用。

import happybase

# 创建连接池,指定池大小为 10
pool = happybase.ConnectionPool(
    size=10,
    host='localhost',
    port=9090
)

# 从连接池中获取连接并执行操作
with pool.connection() as conn:
    tables = conn.tables()
    print("通过连接池获取的表列表:", tables)

代码说明

  • ConnectionPool()size 参数指定连接池的最大连接数。
  • 使用 with 语句从连接池中获取连接,with 代码块执行完毕后会自动将连接归还到池中,无需手动关闭。

3.2 表的创建、删除与列表查询

HBase 中的表是数据存储的核心载体,HappyBase 提供了完整的表生命周期管理 API,包括创建表、删除表、检查表是否存在等操作。

3.2.1 创建表

创建表需要指定表名列族,列族是 HBase 中数据组织的基本单位,一个表可以包含多个列族。创建表的方法是 conn.create_table(),参数说明如下:

  • name:表名,字符串格式。
  • families:列族配置字典,键为列族名称,值为列族的属性配置(如版本数、过期时间等)。
import happybase

# 建立连接
conn = happybase.Connection(host='localhost', port=9090)

# 定义列族配置:创建两个列族 info 和 data,版本数均为 3
column_families = {
    'info': dict(max_versions=3),
    'data': dict(max_versions=3, time_to_live=86400)  # time_to_live 单位为秒,此处为 1 天
}

# 创建表,表名为 student
table_name = 'student'
if table_name.encode() not in conn.tables():
    conn.create_table(table_name, column_families)
    print(f"表 {table_name} 创建成功!")
else:
    print(f"表 {table_name} 已存在!")

# 关闭连接
conn.close()

代码说明

  • 首先定义列族配置字典 column_families,其中 info 列族的最大版本数为 3,data 列族的最大版本数为 3,且数据过期时间为 1 天。
  • 由于 conn.tables() 返回的表名是字节串格式,因此需要将表名字符串 student 转换为字节串(table_name.encode())后再进行判断,避免重复创建表。

3.2.2 删除表

删除表前需要先禁用表(HBase 的强制要求),然后再执行删除操作。对应的方法分别是 conn.disable_table()conn.delete_table()

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table_name = 'student'

if table_name.encode() in conn.tables():
    # 禁用表
    conn.disable_table(table_name)
    print(f"表 {table_name} 已禁用")
    # 删除表
    conn.delete_table(table_name)
    print(f"表 {table_name} 删除成功!")
else:
    print(f"表 {table_name} 不存在!")

conn.close()

代码说明

  • 禁用表是删除表的前提步骤,如果直接删除未禁用的表,会抛出 TApplicationException 异常。
  • 执行完删除操作后,该表及其所有数据会被彻底清除,操作需谨慎。

3.2.3 检查表是否存在

除了通过 conn.tables() 列表判断表是否存在外,HappyBase 还提供了更简洁的 conn.table_exists() 方法(部分版本支持),使用示例如下:

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table_name = 'student'

# 检查表是否存在
if conn.table_exists(table_name):
    print(f"表 {table_name} 存在")
else:
    print(f"表 {table_name} 不存在")

conn.close()

3.3 数据的增删改查操作

表创建完成后,核心操作就是对表中数据的增删改查。HappyBase 通过 Table 对象来操作表数据,获取 Table 对象的方法是 conn.table(table_name)

3.3.1 插入数据(Put 操作)

插入数据使用 Table.put() 方法,支持插入单行数据和多行数据。数据以字典格式组织,键为列名(格式为 列族:列名),值为字段值,所有键值均为字节串格式。

单行数据插入
import happybase

conn = happybase.Connection(host='localhost', port=9090)
# 获取 student 表的 Table 对象
table = conn.table('student')

# 定义行键:HBase 中每行数据的唯一标识
row_key = '001'
# 定义要插入的数据
data = {
    b'info:name': b'Zhang San',
    b'info:age': b'20',
    b'data:score': b'95'
}

# 插入数据
table.put(row_key, data)
print(f"行键 {row_key} 的数据插入成功!")

conn.close()

代码说明

  • 行键 row_key 是 HBase 表中每行数据的唯一标识,字符串格式即可。
  • 数据字典 data 的键必须是字节串格式,格式为 列族:列名,值也必须是字节串格式。若要插入字符串数据,需使用 encode() 方法转换为字节串。
多行数据批量插入

批量插入数据可以提升操作效率,HappyBase 支持通过 Table.put() 方法传入多行数据列表实现批量插入。

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table = conn.table('student')

# 定义多行数据,每个元素为一个元组 (row_key, data_dict)
batch_data = [
    ('002', {b'info:name': b'Li Si', b'info:age': b'21', b'data:score': b'92'}),
    ('003', {b'info:name': b'Wang Wu', b'info:age': b'19', b'data:score': b'88'}),
    ('004', {b'info:name': b'Zhao Liu', b'info:age': b'22', b'data:score': b'90'})
]

# 批量插入数据
for row_key, data in batch_data:
    table.put(row_key, data)
print("多行数据批量插入成功!")

conn.close()

代码说明

  • 批量插入本质是循环调用单行插入方法,适用于中小规模的数据插入。若需要插入超大规模数据,可结合 HBase 的批量加载工具(如 BulkLoad)实现。

3.3.2 查询数据(Get 操作)

查询数据支持单行查询多行扫描,分别对应 Table.row()Table.scan() 方法。

单行数据查询

使用 Table.row() 方法可以查询指定行键的完整数据或指定列族/列的数据。

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table = conn.table('student')

# 查询行键为 001 的完整数据
row_key = '001'
row_data = table.row(row_key)
print(f"行键 {row_key} 的完整数据:")
for column, value in row_data.items():
    print(f"  {column.decode()}: {value.decode()}")

# 只查询 info 列族的数据
info_data = table.row(row_key, columns=[b'info'])
print(f"\n行键 {row_key} 的 info 列族数据:")
for column, value in info_data.items():
    print(f"  {column.decode()}: {value.decode()}")

# 只查询 info:name 和 data:score 列的数据
specific_data = table.row(row_key, columns=[b'info:name', b'data:score'])
print(f"\n行键 {row_key} 的指定列数据:")
for column, value in specific_data.items():
    print(f"  {column.decode()}: {value.decode()}")

conn.close()

代码说明

  • Table.row() 方法的 columns 参数用于指定要查询的列族或列,传入字节串列表即可。
  • 返回的 row_data 是一个字典,键为列名字节串,值为字段值字节串,需使用 decode() 方法转换为字符串格式。
多行数据扫描

使用 Table.scan() 方法可以扫描表中的多行数据,支持设置行键范围、列族/列过滤、数据版本等参数。

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table = conn.table('student')

# 扫描所有数据
print("扫描表中所有数据:")
for row_key, data in table.scan():
    print(f"行键:{row_key.decode()}")
    for column, value in data.items():
        print(f"  {column.decode()}: {value.decode()}")
    print("-" * 20)

# 扫描行键范围在 002 到 003 之间的数据
print("\n扫描行键 002-003 之间的数据:")
for row_key, data in table.scan(row_start=b'002', row_stop=b'004'):
    print(f"行键:{row_key.decode()}")
    for column, value in data.items():
        print(f"  {column.decode()}: {value.decode()}")
    print("-" * 20)

# 扫描 info 列族且 age 大于 20 的数据(需结合过滤器,此处为简化示例)
print("\n扫描 info 列族且 age 大于 20 的数据:")
for row_key, data in table.scan(columns=[b'info']):
    age = data.get(b'info:age', b'0').decode()
    if int(age) > 20:
        print(f"行键:{row_key.decode()},年龄:{age}")

conn.close()

代码说明

  • row_startrow_stop 参数用于指定行键的扫描范围,遵循左闭右开原则(即包含 row_start,不包含 row_stop)。
  • HappyBase 对 HBase 的高级过滤器支持有限,若需要复杂的条件过滤(如列值比较、正则匹配),需结合 happybase.Filter 类或直接使用 Thrift 接口定义过滤器。

3.3.3 更新数据

HBase 中更新数据的逻辑与插入数据一致,使用 Table.put() 方法即可。当插入的行键和列名已存在时,会自动覆盖原有数据,同时生成新的版本(根据列族的 max_versions 配置)。

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table = conn.table('student')

row_key = '001'
# 更新 age 和 score 字段
update_data = {
    b'info:age': b'21',
    b'data:score': b'96'
}

table.put(row_key, update_data)
print(f"行键 {row_key} 的数据更新成功!")

# 查询更新后的数据
row_data = table.row(row_key, columns=[b'info:age', b'data:score'])
print(f"更新后的数据:")
for column, value in row_data.items():
    print(f"  {column.decode()}: {value.decode()}")

conn.close()

代码说明

  • HBase 是版本化数据库,每次更新都会生成新的数据版本,旧版本数据不会立即删除,可通过指定版本号查询历史数据。例如,使用 table.row(row_key, versions=2) 可以获取最近 2 个版本的数据。

3.3.4 删除数据

删除数据使用 Table.delete() 方法,支持删除指定行的全部数据或指定列族/列的数据。

import happybase

conn = happybase.Connection(host='localhost', port=9090)
table = conn.table('student')

row_key = '004'
# 删除指定行的全部数据
table.delete(row_key)
print(f"行键 {row_key} 的全部数据已删除!")

# 删除指定行的指定列数据
row_key = '003'
table.delete(row_key, columns=[b'data:score'])
print(f"行键 {row_key} 的 data:score 列数据已删除!")

# 查询删除后的数据
row_data = table.row(row_key)
print(f"行键 {row_key} 删除后的剩余数据:")
for column, value in row_data.items():
    print(f"  {column.decode()}: {value.decode()}")

conn.close()

代码说明

  • 删除指定列数据时,需通过 columns 参数指定要删除的列名,格式为字节串列表。
  • 删除操作会生成新的版本数据,标记为删除状态,HBase 会在后续的 Major Compaction 过程中彻底清理这些数据。

四、HappyBase 实战案例:学生成绩管理系统

为了更好地理解 HappyBase 在实际项目中的应用,下面构建一个简单的学生成绩管理系统,实现学生信息的新增、查询、更新和删除功能。

4.1 系统功能需求

  1. 新增学生信息(姓名、年龄、班级、成绩)。
  2. 根据学号查询学生完整信息。
  3. 根据班级扫描学生列表。
  4. 更新学生的成绩信息。
  5. 删除指定学号的学生信息。

4.2 完整代码实现

import happybase

class StudentScoreManager:
    def __init__(self, host='localhost', port=9090, table_name='student_score'):
        """初始化连接和表对象"""
        self.host = host
        self.port = port
        self.table_name = table_name
        self.conn = None
        self.table = None
        self._connect()
        self._create_table()

    def _connect(self):
        """建立 HBase 连接"""
        self.conn = happybase.Connection(host=self.host, port=self.port)

    def _create_table(self):
        """创建学生成绩表,列族为 info(基本信息)和 score(成绩信息)"""
        column_families = {
            'info': dict(max_versions=3),
            'score': dict(max_versions=3)
        }
        if self.table_name.encode() not in self.conn.tables():
            self.conn.create_table(self.table_name, column_families)
            print(f"表 {self.table_name} 创建成功!")
        self.table = self.conn.table(self.table_name)

    def add_student(self, student_id, name, age, class_name, math, english, chinese):
        """新增学生信息"""
        data = {
            b'info:name': name.encode(),
            b'info:age': str(age).encode(),
            b'info:class': class_name.encode(),
            b'score:math': str(math).encode(),
            b'score:english': str(english).encode(),
            b'score:chinese': str(chinese).encode()
        }
        self.table.put(student_id, data)
        print(f"学生 {student_id} - {name} 信息新增成功!")

    def query_student(self, student_id):
        """根据学号查询学生信息"""
        row_data = self.table.row(student_id)
        if not row_data:
            print(f"未找到学号为 {student_id} 的学生信息!")
            return None
        student_info = {
            'student_id': student_id,
            'name': row_data[b'info:name'].decode(),
            'age': int(row_data[b'info:age'].decode()),
            'class': row_data[b'info:class'].decode(),
            'math': int(row_data[b'score:math'].decode()),
            'english': int(row_data[b'score:english'].decode()),
            'chinese': int(row_data[b'score:chinese'].decode())
        }
        return student_info

    def scan_class_students(self, class_name):
        """根据班级扫描学生列表"""
        students = []
        for row_key, data in self.table.scan(columns=[b'info', b'score']):
            if b'info:class' in data and data[b'info:class'].decode() == class_name:
                student_info = {
                    'student_id': row_key.decode(),
                    'name': data[b'info:name'].decode(),
                    'age': int(data[b'info:age'].decode()),
                    'math': int(data[b'score:math'].decode()),
                    'english': int(data[b'score:english'].decode()),
                    'chinese': int(data[b'score:chinese'].decode())
                }
                students.append(student_info)
        return students

    def update_score(self, student_id, subject, new_score):
        """更新学生指定科目的成绩"""
        column = f'score:{subject}'.encode()
        if not self.table.row(student_id):
            print(f"未找到学号为 {student_id} 的学生信息!")
            return
        self.table.put(student_id, {column: str(new_score).encode()})
        print(f"学生 {student_id} 的 {subject} 成绩更新为 {new_score}!")

    def delete_student(self, student_id):
        """删除指定学号的学生信息"""
        if not self.table.row(student_id):
            print(f"未找到学号为 {student_id} 的学生信息!")
            return
        self.table.delete(student_id)
        print(f"学生 {student_id} 的信息已删除!")

    def close(self):
        """关闭连接"""
        self.conn.close()
        print("连接已关闭!")

# 测试学生成绩管理系统
if __name__ == '__main__':
    manager = StudentScoreManager()

    # 1. 新增学生信息
    manager.add_student('2024001', 'Zhang San', 18, 'Class 1 Grade 3', 95, 92, 88)
    manager.add_student('2024002', 'Li Si', 17, 'Class 1 Grade 3', 90, 85, 93)
    manager.add_student('2024003', 'Wang Wu', 18, 'Class 2 Grade 3', 88, 91, 90)

    # 2. 查询单个学生信息
    print("\n查询学号 2024001 的学生信息:")
    student = manager.query_student('2024001')
    if student:
        for key, value in student.items():
            print(f"  {key}: {value}")

    # 3. 扫描指定班级的学生列表
    print("\n扫描 Class 1 Grade 3 的学生列表:")
    class_students = manager.scan_class_students('Class 1 Grade 3')
    for stu in class_students:
        print(f"  学号:{stu['student_id']},姓名:{stu['name']},数学成绩:{stu['math']}")

    # 4. 更新学生成绩
    print("\n更新学号 2024001 的数学成绩:")
    manager.update_score('2024001', 'math', 98)
    student = manager.query_student('2024001')
    print(f"  更新后数学成绩:{student['math']}")

    # 5. 删除学生信息
    print("\n删除学号 2024003 的学生信息:")
    manager.delete_student('2024003')
    manager.query_student('2024003')

    # 关闭连接
    manager.close()

4.3 代码说明与运行结果

  1. 类结构设计StudentScoreManager 类封装了所有核心功能,通过 __init__ 方法完成连接初始化和表创建,_connect_create_table 为内部辅助方法,对外提供 add_studentquery_student 等业务方法。
  2. 数据存储设计:使用 info 列族存储学生基本信息(姓名、年龄、班级),score 列族存储各科成绩,符合 HBase 列族的设计原则。
  3. 运行结果:执行代码后,会依次完成学生信息新增、查询、班级扫描、成绩更新和删除操作,终端输出对应的操作结果,验证了 HappyBase 在实际项目中的可行性。

五、HappyBase 相关资源链接

  • PyPI 地址:https://pypi.org/project/HappyBase
  • Github 地址:https://github.com/wbolster/happybase
  • 官方文档地址:https://happybase.readthedocs.io/en/latest/

关注我,每天分享一个实用的Python自动化工具。

Python Prisma库完全指南:现代ORM的高效数据操作实战

一、Prisma库核心概述

1.1 用途与工作原理

Prisma是一款为Python开发者设计的现代ORM(对象关系映射)工具,核心作用是简化Python应用与关系型数据库的交互流程,支持PostgreSQL、MySQL、SQLite等主流数据库。其工作原理基于数据模型驱动:开发者通过定义简洁的Schema文件描述数据结构,Prisma引擎会自动将Schema转换为对应的数据库表结构,同时生成类型安全的Python客户端,让开发者无需编写复杂的SQL语句,直接通过面向对象的方式完成数据的增删改查操作。

1.2 优缺点分析

优点

  • 类型安全:自动生成的客户端包含完整的类型提示,结合Python的类型检查工具(如mypy)可在编码阶段发现数据类型错误,大幅降低运行时异常概率。
  • Schema即文档:Schema文件采用直观的语法,兼具数据结构定义与文档功能,团队协作时可直接通过Schema了解数据模型。
  • 迁移管理便捷:内置的迁移工具支持数据库结构的版本控制,可轻松实现数据库表的创建、修改、删除,且能追踪迁移历史。
  • 查询能力强大:支持链式查询、关联查询、批量操作等复杂场景,查询语法简洁易懂,比传统ORM更贴近自然语言。

缺点

  • 生态成熟度待提升:相较于Django ORM、SQLAlchemy等老牌ORM,Python版Prisma的第三方插件和扩展较少。
  • 学习曲线:对于习惯原生SQL或传统ORM的开发者,需要适应Prisma独特的Schema定义和查询风格。
  • 性能损耗:在超高性能要求的场景下,ORM的封装会带来轻微的性能开销,极端场景下可能需要结合原生SQL优化。

1.3 License类型

Python Prisma库采用Apache License 2.0开源协议,该协议允许商业使用、修改和分发,只需保留原作者的版权声明,对开发者友好且无商业使用限制。

二、Prisma库安装与环境配置

2.1 安装前置条件

在安装Prisma之前,需确保本地环境满足以下要求:

  • Python版本≥3.8(推荐3.9及以上)
  • 已安装对应数据库的客户端工具(如PostgreSQL需安装psycopg2,MySQL需安装mysqlclient)
  • 网络环境正常,可访问PyPI仓库

2.2 安装命令

Prisma的安装分为两个步骤:首先安装Python包,然后初始化Prisma引擎。

  1. 安装Python包
    打开命令行终端,执行以下pip命令安装Prisma:
    bash pip install prisma
  2. 初始化Prisma引擎 安装完成后,需要初始化Prisma的二进制引擎,执行以下命令: bash prisma init 执行该命令后,会在当前目录生成两个关键文件:
    • schema.prisma:用于定义数据模型和数据库连接配置的核心文件。
    • .env:用于存储数据库连接字符串等环境变量。

2.3 数据库连接配置

以SQLite数据库为例(无需额外安装服务,适合快速开发),修改schema.prisma文件中的datasource块:

datasource db {
  provider = "sqlite"
  url      = env("DATABASE_URL")
}

generator client {
  provider = "prisma-client-py"
}

然后修改.env文件,设置数据库连接URL:

DATABASE_URL="file:./dev.db"

若使用MySQL数据库,修改datasource块和.env文件如下:

datasource db {
  provider = "mysql"
  url      = env("DATABASE_URL")
}
DATABASE_URL="mysql://user:password@localhost:3306/mydatabase"

其中user为数据库用户名,password为密码,mydatabase为数据库名称。

三、Prisma核心使用教程

3.1 数据模型定义

Prisma的核心是schema.prisma文件中的数据模型定义,模型对应数据库中的表,模型中的字段对应表中的列。下面以一个User用户模型和Post文章模型为例,展示如何定义关联模型。

// User模型:对应数据库中的users表
model User {
  id        Int      @id @default(autoincrement()) // 主键,自增整数
  name      String   // 用户名,字符串类型
  email     String   @unique // 邮箱,唯一约束
  age       Int?     // 年龄,可选整数(允许为空)
  posts     Post[]   // 一对多关联:一个用户可以有多篇文章
  createdAt DateTime @default(now()) // 创建时间,默认当前时间
}

// Post模型:对应数据库中的posts表
model Post {
  id        Int      @id @default(autoincrement())
  title     String   // 文章标题
  content   String?  // 文章内容,可选
  authorId  Int      // 外键,关联User模型的id
  author    User     @relation(fields: [authorId], references: [id]) // 多对一关联
  published Boolean  @default(false) // 是否发布,默认false
  createdAt DateTime @default(now())
}

字段属性说明

  • @id:标记该字段为主键。
  • @default(autoincrement()):设置字段默认值为自增。
  • @unique:添加唯一约束,确保字段值不重复。
  • ?:标记字段为可选,允许存储NULL值。
  • @relation:定义模型之间的关联关系。

3.2 生成数据库表与客户端

定义好Schema后,需要执行迁移命令生成对应的数据库表结构,同时生成Python客户端代码。

  1. 创建迁移文件
    执行以下命令,Prisma会根据Schema的变化生成迁移文件:
    bash prisma migrate dev --name init
    --name init表示给本次迁移命名为init,执行成功后,会在prisma/migrations目录下生成迁移历史文件,同时自动在数据库中创建UserPost表。
  2. 生成Python客户端
    迁移完成后,Prisma会自动生成类型安全的Python客户端,无需手动编写。客户端文件默认生成在prisma目录下,可直接在Python代码中导入使用。

3.3 基础数据操作(CRUD)

Prisma客户端提供了简洁的API实现数据的增删改查,下面通过具体的Python脚本演示每个操作的使用方法。

3.3.1 连接数据库并初始化客户端

在Python脚本中,首先需要导入并初始化Prisma客户端,建立与数据库的连接:

# 导入Prisma客户端
from prisma import Prisma

# 初始化客户端
db = Prisma()

# 连接数据库
async def connect_db():
    await db.connect()

# 关闭数据库连接
async def disconnect_db():
    await db.disconnect()

由于Prisma的Python客户端基于异步IO设计,所有数据库操作都需要在异步函数中执行。

3.3.2 创建数据(Create)

使用create方法向数据库中插入单条数据,使用create_many方法批量插入多条数据。

单条数据插入

async def create_user():
    # 连接数据库
    await connect_db()
    # 创建用户
    user = await db.user.create(
        data={
            'name': '张三',
            'email': '[email protected]',
            'age': 25
        }
    )
    # 打印创建的用户信息
    print(f'创建用户成功:{user.id} - {user.name} - {user.email}')
    # 关闭连接
    await disconnect_db()

# 执行异步函数
import asyncio
asyncio.run(create_user())

执行上述代码后,会在User表中插入一条用户数据,user对象包含了数据库返回的完整用户信息,包括自动生成的idcreatedAt字段。

批量数据插入

async def batch_create_users():
    await connect_db()
    # 批量创建3个用户
    result = await db.user.create_many(
        data=[
            {'name': '李四', 'email': '[email protected]', 'age': 22},
            {'name': '王五', 'email': '[email protected]', 'age': 28},
            {'name': '赵六', 'email': '[email protected]'}
        ]
    )
    # result包含创建的记录数
    print(f'批量创建用户成功,共创建 {result.count} 条记录')
    await disconnect_db()

asyncio.run(batch_create_users())

create_many方法的返回值是一个包含count属性的对象,表示成功插入的记录数量。

3.3.3 查询数据(Read)

Prisma提供了丰富的查询方法,包括find_uniquefind_firstfind_many等,支持条件过滤、排序、分页和关联查询。

查询单条数据
使用find_unique方法根据唯一约束字段查询单条数据,例如根据邮箱查询用户:

async def find_user_by_email(email: str):
    await connect_db()
    # 根据邮箱查询用户(email字段有@unique约束)
    user = await db.user.find_unique(
        where={
            'email': email
        }
    )
    if user:
        print(f'查询到用户:{user.name} - {user.age}')
    else:
        print('未查询到该用户')
    await disconnect_db()

asyncio.run(find_user_by_email('[email protected]'))

使用find_first方法查询满足条件的第一条数据(无需唯一约束):

async def find_first_user():
    await connect_db()
    # 查询年龄大于20的第一个用户
    user = await db.user.find_first(
        where={
            'age': {
                'gt': 20
            }
        }
    )
    print(f'查询到用户:{user.name} - {user.age}')
    await disconnect_db()

asyncio.run(find_first_user())

其中gt表示“大于”,Prisma支持的查询操作符还包括lt(小于)、gte(大于等于)、lte(小于等于)、contains(包含)等。

查询多条数据
使用find_many方法查询满足条件的所有数据,支持排序、分页和字段筛选:

async def find_all_users():
    await connect_db()
    # 查询所有用户,按创建时间降序排序,只返回name和email字段
    users = await db.user.find_many(
        select={
            'name': True,
            'email': True
        },
        order={
            'createdAt': 'desc'
        },
        # 分页:跳过前1条,取2条
        skip=1,
        take=2
    )
    # 遍历打印用户信息
    for user in users:
        print(f'用户名:{user.name},邮箱:{user.email}')
    await disconnect_db()

asyncio.run(find_all_users())

select参数用于指定返回的字段,order参数用于排序,skiptake参数用于实现分页功能。

关联查询
查询用户的同时,获取该用户发布的所有文章,使用include参数实现关联数据的加载:

async def find_user_with_posts(user_id: int):
    await connect_db()
    # 查询用户及其所有文章
    user = await db.user.find_unique(
        where={
            'id': user_id
        },
        include={
            'posts': True
        }
    )
    if user:
        print(f'用户:{user.name},发布的文章数:{len(user.posts)}')
        for post in user.posts:
            print(f'文章标题:{post.title},状态:{"已发布" if post.published else "未发布"}')
    await disconnect_db()

# 假设用户id为1
asyncio.run(find_user_with_posts(1))

上述代码中,通过include={'posts': True},Prisma会自动查询该用户关联的所有Post数据,并封装到user.posts属性中。

3.3.4 更新数据(Update)

使用update方法更新单条数据,使用update_many方法批量更新多条数据。

单条数据更新

async def update_user_age(user_id: int, new_age: int):
    await connect_db()
    # 更新用户年龄
    updated_user = await db.user.update(
        where={
            'id': user_id
        },
        data={
            'age': new_age
        }
    )
    print(f'用户更新成功,新年龄:{updated_user.age}')
    await disconnect_db()

asyncio.run(update_user_age(1, 26))

update方法的where参数指定更新条件,data参数指定要更新的字段和值,返回值为更新后的完整数据对象。

批量数据更新

async def batch_update_posts():
    await connect_db()
    # 将所有未发布的文章标记为已发布
    result = await db.post.update_many(
        where={
            'published': False
        },
        data={
            'published': True
        }
    )
    print(f'批量更新成功,共更新 {result.count} 篇文章')
    await disconnect_db()

asyncio.run(batch_update_posts())

3.3.5 删除数据(Delete)

使用delete方法删除单条数据,使用delete_many方法批量删除多条数据。

单条数据删除

async def delete_user(user_id: int):
    await connect_db()
    # 删除指定用户
    deleted_user = await db.user.delete(
        where={
            'id': user_id
        }
    )
    print(f'删除用户成功:{deleted_user.name}')
    await disconnect_db()

asyncio.run(delete_user(4))

批量数据删除

async def delete_unpublished_posts():
    await connect_db()
    # 删除所有未发布的文章
    result = await db.post.delete_many(
        where={
            'published': False
        }
    )
    print(f'批量删除成功,共删除 {result.count} 篇文章')
    await disconnect_db()

asyncio.run(delete_unpublished_posts())

3.4 事务处理

事务是数据库操作的重要特性,用于保证多个操作的原子性(要么全部成功,要么全部失败)。Prisma客户端提供transaction方法实现事务处理。

例如,创建用户的同时,为该用户创建一篇文章,两个操作在同一个事务中执行:

async def create_user_and_post():
    await connect_db()
    try:
        # 开启事务
        async with db.transaction():
            # 第一步:创建用户
            user = await db.user.create(
                data={
                    'name': '钱七',
                    'email': '[email protected]',
                    'age': 30
                }
            )
            print(f'事务中创建用户:{user.name}')
            # 第二步:为该用户创建文章
            post = await db.post.create(
                data={
                    'title': 'Prisma事务教程',
                    'content': 'Prisma的事务处理非常简单',
                    'authorId': user.id
                }
            )
            print(f'事务中创建文章:{post.title}')
        # 事务提交成功
        print('用户和文章创建成功,事务已提交')
    except Exception as e:
        # 事务回滚
        print(f'操作失败,事务已回滚,错误信息:{e}')
    finally:
        await disconnect_db()

asyncio.run(create_user_and_post())

在上述代码中,async with db.transaction()上下文管理器会自动管理事务的提交和回滚:如果上下文内的所有操作都成功执行,事务会自动提交;如果任何一个操作抛出异常,事务会自动回滚,确保数据一致性。

四、实际项目案例:简易博客系统

4.1 项目需求

构建一个简易的博客系统,实现以下功能:

  1. 用户注册和登录(简化版,不涉及密码加密)
  2. 创建、查询、发布博客文章
  3. 查询指定用户的所有文章

4.2 项目目录结构

simple_blog/
├── .env               # 环境变量配置
├── schema.prisma      # Prisma数据模型定义
└── blog.py            # 业务逻辑代码

4.3 数据模型定义

schema.prisma文件的内容与第三部分的模型定义一致,包含UserPost两个关联模型。

4.4 业务逻辑实现

blog.py文件中实现具体的业务功能,代码如下:

from prisma import Prisma
import asyncio

# 初始化Prisma客户端
db = Prisma()

# 数据库连接与关闭工具函数
async def get_db():
    await db.connect()
    try:
        yield db
    finally:
        await db.disconnect()

# 1. 用户注册功能
async def register_user(name: str, email: str, age: int = None):
    async for db in get_db():
        # 检查邮箱是否已存在
        existing_user = await db.user.find_unique(where={'email': email})
        if existing_user:
            print(f'邮箱 {email} 已被注册')
            return None
        # 创建新用户
        user = await db.user.create(
            data={
                'name': name,
                'email': email,
                'age': age
            }
        )
        print(f'用户 {name} 注册成功,用户ID:{user.id}')
        return user

# 2. 创建博客文章
async def create_article(title: str, content: str, author_id: int):
    async for db in get_db():
        # 检查作者是否存在
        author = await db.user.find_unique(where={'id': author_id})
        if not author:
            print(f'作者ID {author_id} 不存在')
            return None
        # 创建文章
        post = await db.post.create(
            data={
                'title': title,
                'content': content,
                'authorId': author_id
            }
        )
        print(f'文章 {title} 创建成功,文章ID:{post.id}')
        return post

# 3. 发布博客文章
async def publish_article(post_id: int):
    async for db in get_db():
        post = await db.post.update(
            where={'id': post_id},
            data={'published': True}
        )
        print(f'文章 {post.title} 已发布')
        return post

# 4. 查询用户的所有已发布文章
async def get_user_published_posts(author_id: int):
    async for db in get_db():
        user = await db.user.find_unique(
            where={'id': author_id},
            include={
                'posts': {
                    'where': {'published': True},
                    'order': {'createdAt': 'desc'}
                }
            }
        )
        if not user:
            print(f'作者ID {author_id} 不存在')
            return []
        print(f'用户 {user.name} 的已发布文章:')
        for post in user.posts:
            print(f'- {post.title} | 创建时间:{post.createdAt}')
        return user.posts

# 主函数:执行案例
async def main():
    # 注册新用户
    user = await register_user('小明', '[email protected]', 23)
    if not user:
        return
    # 为用户创建文章
    post = await create_article('Prisma实战教程', '本文介绍了Prisma的核心用法', user.id)
    if not post:
        return
    # 发布文章
    await publish_article(post.id)
    # 查询用户已发布的文章
    await get_user_published_posts(user.id)

# 运行主函数
if __name__ == '__main__':
    asyncio.run(main())

4.5 运行结果

执行blog.py文件,控制台输出如下:

用户 小明 注册成功,用户ID:5
文章 Prisma实战教程 创建成功,文章ID:3
文章 Prisma实战教程 已发布
用户 小明 的已发布文章:
- Prisma实战教程 | 创建时间:2024-05-20 15:30:25

五、Prisma相关资源

5.1 PyPI地址

https://pypi.org/project/prisma

5.2 Github地址

https://github.com/prisma/prisma-client-py

5.3 官方文档地址

https://prisma-client-py.readthedocs.io

关注我,每天分享一个实用的Python自动化工具。

Python实用工具:pysolr 从入门到精通——高效操作Solr搜索引擎的指南

一、pysolr 库概述

1.1 用途

pysolr 是一个专门用于和 Apache Solr 搜索引擎进行交互的 Python 客户端库,它能够让开发者通过简洁的 Python 代码,轻松实现对 Solr 索引的创建、数据的添加、删除、更新以及复杂的查询操作。无论是构建企业级的全文检索系统,还是实现数据分析场景下的快速数据筛选,pysolr 都能提供稳定且高效的支持。

1.2 工作原理

pysolr 底层基于 HTTP/HTTPS 协议与 Solr 服务器进行通信,它将 Python 代码中的操作指令(如查询语句、数据提交指令)封装成符合 Solr API 规范的 HTTP 请求,发送到 Solr 服务器的指定接口(如 /solr/core_name/update 用于数据更新,/solr/core_name/select 用于数据查询),然后接收 Solr 服务器返回的 JSON 格式响应,并将其解析为 Python 中的字典、列表等数据结构,方便开发者直接处理。

1.3 优缺点

优点

  • 接口简洁易用,极大降低了 Python 开发者操作 Solr 的门槛,无需手动构造复杂的 HTTP 请求。
  • 支持 Solr 的大部分核心功能,包括全文检索、过滤查询、排序、分组统计、高亮显示等。
  • 兼容性良好,能够适配不同版本的 Apache Solr,且支持 Python 3.6 及以上的主流 Python 版本。

缺点

  • 功能覆盖相较于 Solr 的原生 API 存在少量缺失,部分高级特性(如自定义请求处理器的复杂配置)需要手动扩展 HTTP 请求参数。
  • 对大规模数据批量操作的性能优化需要开发者自行调整参数(如批量提交的大小),默认配置下的大批量数据插入效率有待提升。

1.4 License 类型

pysolr 采用的是 BSD 3-Clause 许可证,这是一个宽松的开源许可证,允许开发者自由地使用、修改、分发该库的代码,无论是用于商业项目还是开源项目,都几乎没有限制,只需要保留原作者的版权声明即可。

二、pysolr 安装与环境准备

2.1 安装 pysolr

安装 pysolr 非常简单,推荐使用 Python 的包管理工具 pip 进行安装,在命令行中执行以下命令即可完成安装:

pip install pysolr

该命令会自动从 PyPI 下载并安装最新版本的 pysolr 库及其依赖项(主要依赖 requests 库用于 HTTP 通信)。

2.2 环境依赖确认

  • Python 版本:确保你的 Python 环境版本为 3.6 及以上,可以通过 python --version 命令查看当前 Python 版本。
  • Solr 服务器环境:pysolr 是操作 Solr 的客户端,因此需要先搭建好 Solr 服务器环境。你可以从 Apache Solr 官方网站(https://solr.apache.org/)下载对应版本的 Solr 安装包,按照官方文档完成安装和启动,并创建至少一个 Solr Core(Solr 的核心索引单元)用于后续操作。
  • 网络连通性:确保运行 pysolr 代码的机器能够和 Solr 服务器所在的机器互通网络,Solr 默认的 HTTP 端口为 8983,需要保证该端口未被防火墙拦截。

三、pysolr 核心使用方法与代码示例

3.1 连接 Solr 服务器

在使用 pysolr 进行任何操作之前,首先需要创建一个 Solr 客户端实例,建立与 Solr 服务器的连接。核心代码如下:

import pysolr

# 定义 Solr 服务器的基础 URL 和 Core 名称
# 格式为:http://solr_host:solr_port/solr/core_name
SOLR_URL = "http://localhost:8983/solr/gettingstarted"

# 创建 Solr 客户端实例
solr = pysolr.Solr(SOLR_URL, timeout=10)

print("成功连接到 Solr 服务器!")

代码说明

  • pysolr.Solr() 是创建客户端实例的构造函数,第一个参数是 Solr Core 的完整 URL,其中 localhost 是 Solr 服务器的主机名,8983 是默认端口,gettingstarted 是 Solr Core 的名称(需要替换为你自己创建的 Core 名称)。
  • timeout 参数设置了 HTTP 请求的超时时间(单位为秒),避免因网络问题导致程序长时间阻塞。

3.2 向 Solr 中添加数据

Solr 存储的数据是以文档(Document)为单位的,每个文档是一个键值对的集合,对应 Solr Schema 中定义的字段。我们可以通过 add() 方法向 Solr 中添加单个或多个文档。

3.2.1 添加单个文档

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 定义一个 Solr 文档,字段需要和 Solr Schema 中的定义一致
document = {
    "id": "book_001",  # id 字段是 Solr 的默认唯一标识字段,必填
    "title": "Python编程:从入门到实践",
    "author": "埃里克·马瑟斯",
    "publisher": "人民邮电出版社",
    "publish_date": "2020-01-01",
    "price": 59.8,
    "tags": ["Python", "编程", "入门"]
}

# 添加文档到 Solr
solr.add([document])

# 提交更改,确保数据被持久化到索引中
solr.commit()

print("单个文档添加成功!")

3.2.2 批量添加多个文档

当需要添加大量数据时,批量添加的效率远高于逐个添加,代码示例如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 定义多个文档的列表
documents = [
    {
        "id": "book_002",
        "title": "流畅的Python",
        "author": "卢西亚诺·拉马略",
        "publisher": "人民邮电出版社",
        "publish_date": "2017-05-01",
        "price": 129.0,
        "tags": ["Python", "进阶", "编程思想"]
    },
    {
        "id": "book_003",
        "title": "Python数据分析与挖掘实战",
        "author": "张良均",
        "publisher": "机械工业出版社",
        "publish_date": "2019-03-01",
        "price": 79.0,
        "tags": ["Python", "数据分析", "挖掘"]
    },
    {
        "id": "book_004",
        "title": "深度学习入门:基于Python的理论与实现",
        "author": "斋藤康毅",
        "publisher": "人民邮电出版社",
        "publish_date": "2018-07-01",
        "price": 69.0,
        "tags": ["Python", "深度学习", "AI"]
    }
]

# 批量添加文档
solr.add(documents, batch_size=2)  # batch_size 表示每次提交的文档数量

# 提交更改
solr.commit()

print("批量文档添加成功!")

代码说明

  • add() 方法接收一个文档列表作为参数,batch_size 参数可以控制每次向 Solr 提交的文档数量,当文档数量较多时,合理设置 batch_size 可以避免单次请求数据量过大导致的失败。
  • commit() 方法用于提交更改,Solr 在接收到 add 请求后,会先将数据存入内存,只有执行 commit 操作后,数据才会被写入磁盘索引,并且才能被查询到。

3.3 从 Solr 中删除数据

pysolr 支持通过文档 ID、查询条件等方式删除 Solr 中的数据,常用的删除方法有 delete()delete_by_query()

3.3.1 通过 ID 删除单个文档

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 通过文档 ID 删除
solr.delete(id="book_001")

# 提交更改
solr.commit()

print("通过ID删除文档成功!")

3.3.2 通过查询条件删除多个文档

如果需要删除满足特定条件的一批文档,可以使用 delete_by_query() 方法,代码示例如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 删除 publisher 为"机械工业出版社"的所有文档
solr.delete_by_query("publisher:机械工业出版社")

# 提交更改
solr.commit()

print("通过查询条件删除文档成功!")

3.3.3 删除所有文档

如果需要清空整个 Solr Core 的数据,可以使用通配符查询条件 *:*,代码如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 删除所有文档
solr.delete_by_query("*:*")

# 提交更改
solr.commit()

print("所有文档删除成功!")

代码说明

  • delete(id="xxx") 方法用于删除指定 ID 的文档,ID 是 Solr 文档的唯一标识。
  • delete_by_query(query) 方法接收一个 Solr 查询语句作为参数,会删除所有满足该查询条件的文档,使用时需要格外谨慎,避免误删数据。

3.4 查询 Solr 中的数据

查询是 Solr 的核心功能,pysolr 提供了 search() 方法来执行各种查询操作,支持全文检索、过滤、排序、分页、高亮等多种功能。

3.4.1 基础全文检索

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 先添加一些测试数据,方便查询
test_docs = [
    {
        "id": "book_002",
        "title": "流畅的Python",
        "author": "卢西亚诺·拉马略",
        "publisher": "人民邮电出版社",
        "publish_date": "2017-05-01",
        "price": 129.0,
        "tags": ["Python", "进阶", "编程思想"]
    },
    {
        "id": "book_003",
        "title": "Python数据分析与挖掘实战",
        "author": "张良均",
        "publisher": "机械工业出版社",
        "publish_date": "2019-03-01",
        "price": 79.0,
        "tags": ["Python", "数据分析", "挖掘"]
    }
]
solr.add(test_docs)
solr.commit()

# 基础全文检索:搜索标题中包含"Python"的文档
results = solr.search("title:Python")

# 处理查询结果
print(f"查询到 {len(results)} 条结果:")
for result in results:
    print(f"ID: {result['id']}")
    print(f"标题: {result['title']}")
    print(f"作者: {result['author']}")
    print(f"价格: {result['price']}")
    print("-" * 50)

3.4.2 带过滤条件的查询

在实际应用中,我们经常需要在全文检索的基础上,添加过滤条件来缩小查询范围,例如过滤价格区间、出版社等,代码示例如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 搜索标题包含"Python",且价格在 50-100 之间,出版社为"人民邮电出版社"的文档
# q 参数是查询语句,fq 参数是过滤条件(可以是多个)
results = solr.search(
    q="title:Python",
    fq=[
        "price:[50 TO 100]",  # 价格区间过滤,闭区间
        "publisher:人民邮电出版社"
    ]
)

print(f"过滤查询到 {len(results)} 条结果:")
for result in results:
    print(f"ID: {result['id']}")
    print(f"标题: {result['title']}")
    print(f"价格: {result['price']}")
    print(f"出版社: {result['publisher']}")
    print("-" * 50)

3.4.3 带排序和分页的查询

当查询结果较多时,分页和排序功能是必不可少的,pysolr 支持通过 sortstartrows 参数来实现,代码示例如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 先添加更多测试数据
more_docs = [
    {
        "id": "book_005",
        "title": "Python爬虫开发与项目实战",
        "author": "范传辉",
        "publisher": "机械工业出版社",
        "publish_date": "2020-01-01",
        "price": 65.0,
        "tags": ["Python", "爬虫"]
    },
    {
        "id": "book_006",
        "title": "Python Web开发实战",
        "author": "陶俊杰",
        "publisher": "清华大学出版社",
        "publish_date": "2018-10-01",
        "price": 89.0,
        "tags": ["Python", "Web开发"]
    }
]
solr.add(more_docs)
solr.commit()

# 搜索标题包含"Python"的文档,按价格降序排序,分页获取第1页(从0开始),每页3条
results = solr.search(
    q="title:Python",
    sort="price desc",  # desc 降序,asc 升序
    start=0,  # 起始位置
    rows=3  # 每页显示的条数
)

print(f"分页查询到 {len(results)} 条结果:")
for result in results:
    print(f"ID: {result['id']}")
    print(f"标题: {result['title']}")
    print(f"价格: {result['price']}")
    print("-" * 50)

# 获取总记录数
print(f"符合条件的总记录数: {results.hits}")

3.4.4 高亮显示查询结果

高亮显示可以让查询结果中匹配的关键词以特殊样式呈现,提升用户体验,pysolr 支持通过 hl 相关参数实现高亮功能,代码示例如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 搜索标题包含"Python"的文档,并高亮显示标题中的关键词
results = solr.search(
    q="title:Python",
    hl=True,  # 开启高亮功能
    hl_fl="title",  # 指定需要高亮的字段
    hl_simple_pre="<em>",  # 高亮前缀
    hl_simple_post="</em>"  # 高亮后缀
)

print(f"高亮查询到 {len(results)} 条结果:")
for result in results:
    print(f"ID: {result['id']}")
    # 获取高亮后的标题
    highlighted_title = result.highlighting.get(result['id'], {}).get('title', [result['title']])[0]
    print(f"高亮标题: {highlighted_title}")
    print(f"作者: {result['author']}")
    print("-" * 50)

代码说明

  • hl=True 表示开启高亮功能,hl_fl 指定需要进行高亮处理的字段。
  • hl_simple_prehl_simple_post 分别设置高亮的前缀和后缀,通常用于 HTML 页面展示,让关键词以斜体、加粗等样式显示。
  • result.highlighting 中存储了高亮后的字段内容,需要通过文档 ID 来获取对应字段的高亮结果。

3.5 更新 Solr 中的数据

Solr 的数据更新可以通过 add() 方法结合文档 ID 实现,因为 Solr 会根据 ID 进行覆盖更新,代码示例如下:

import pysolr

SOLR_URL = "http://localhost:8983/solr/gettingstarted"
solr = pysolr.Solr(SOLR_URL, timeout=10)

# 定义需要更新的文档,ID 为已存在的文档 ID
updated_document = {
    "id": "book_002",
    "title": "流畅的Python(第2版)",  # 更新标题
    "author": "卢西亚诺·拉马略",
    "publisher": "人民邮电出版社",
    "publish_date": "2022-01-01",  # 更新出版日期
    "price": 149.0,  # 更新价格
    "tags": ["Python", "进阶", "编程思想", "第2版"]  # 更新标签
}

# 通过 add 方法实现更新,Solr 会根据 ID 覆盖原有文档
solr.add([updated_document])
solr.commit()

print("文档更新成功!")

# 验证更新结果
result = solr.search(q="id:book_002")
for doc in result:
    print(f"更新后的标题: {doc['title']}")
    print(f"更新后的价格: {doc['price']}")
    print(f"更新后的标签: {doc['tags']}")

代码说明
Solr 没有专门的更新方法,而是通过“先删除后添加”的逻辑实现更新,当使用 add() 方法提交一个已存在 ID 的文档时,Solr 会自动删除原有 ID 的文档,然后添加新的文档内容,从而实现更新效果。

四、pysolr 实际应用案例:构建简单的图书检索系统

4.1 案例需求

我们需要构建一个简单的图书检索系统,实现以下功能:

  1. 批量导入图书数据到 Solr。
  2. 支持按书名、作者、出版社进行全文检索。
  3. 支持按价格区间过滤检索结果。
  4. 支持对检索结果按价格排序和分页。
  5. 支持高亮显示检索关键词。

4.2 案例代码实现

import pysolr
from typing import List, Dict, Optional

class BookSearchSystem:
    def __init__(self, solr_url: str, timeout: int = 10):
        """
        初始化图书检索系统
        :param solr_url: Solr Core 的 URL
        :param timeout: HTTP 请求超时时间
        """
        self.solr = pysolr.Solr(solr_url, timeout=timeout)

    def import_books(self, books: List[Dict]) -> None:
        """
        批量导入图书数据到 Solr
        :param books: 图书数据列表
        """
        if not books:
            print("没有需要导入的图书数据!")
            return
        try:
            self.solr.add(books, batch_size=5)
            self.solr.commit()
            print(f"成功导入 {len(books)} 本图书数据!")
        except Exception as e:
            print(f"导入图书数据失败:{e}")

    def search_books(
        self,
        keyword: str,
        field: str = "*",
        min_price: Optional[float] = None,
        max_price: Optional[float] = None,
        sort_by: str = "price asc",
        page: int = 1,
        page_size: int = 3,
        highlight: bool = True
    ) -> pysolr.Results:
        """
        检索图书数据
        :param keyword: 检索关键词
        :param field: 检索的字段,* 表示所有字段
        :param min_price: 最低价格过滤条件
        :param max_price: 最高价格过滤条件
        :param sort_by: 排序方式,如 price desc
        :param page: 页码,从 1 开始
        :param page_size: 每页显示的条数
        :param highlight: 是否开启高亮
        :return: 检索结果
        """
        # 构建查询语句
        if field == "*":
            query = f"{keyword}"
        else:
            query = f"{field}:{keyword}"

        # 构建过滤条件
        filter_queries = []
        if min_price is not None and max_price is not None:
            filter_queries.append(f"price:[{min_price} TO {max_price}]")
        elif min_price is not None:
            filter_queries.append(f"price:[{min_price} TO *]")
        elif max_price is not None:
            filter_queries.append(f"price:[* TO {max_price}]")

        # 计算分页参数
        start = (page - 1) * page_size

        # 构建高亮参数
        hl_params = {}
        if highlight:
            hl_params = {
                "hl": True,
                "hl_fl": field if field != "*" else "title,author,publisher",
                "hl_simple_pre": "<strong>",
                "hl_simple_post": "</strong>"
            }

        # 执行查询
        results = self.solr.search(
            q=query,
            fq=filter_queries,
            sort=sort_by,
            start=start,
            rows=page_size,
            **hl_params
        )

        return results

    def display_results(self, results: pysolr.Results) -> None:
        """
        展示检索结果
        :param results: 检索结果对象
        """
        if not results:
            print("没有查询到符合条件的图书!")
            return
        print(f"\n共查询到 {results.hits} 本符合条件的图书,当前显示第 {(results.start // results.rows) + 1} 页:")
        print("=" * 80)
        for idx, result in enumerate(results, start=1):
            book_id = result['id']
            # 获取高亮内容
            highlighting = result.highlighting.get(book_id, {})
            title = highlighting.get('title', [result.get('title', '未知标题')])[0]
            author = highlighting.get('author', [result.get('author', '未知作者')])[0]
            publisher = highlighting.get('publisher', [result.get('publisher', '未知出版社')])[0]
            price = result.get('price', 0.0)

            print(f"[{idx}] ID: {book_id}")
            print(f"标题: {title}")
            print(f"作者: {author}")
            print(f"出版社: {publisher}")
            print(f"价格: {price} 元")
            print("-" * 80)

# 测试图书检索系统
if __name__ == "__main__":
    # Solr Core URL
    SOLR_CORE_URL = "http://localhost:8983/solr/gettingstarted"

    # 初始化系统
    book_system = BookSearchSystem(SOLR_CORE_URL)

    # 准备测试图书数据
    test_books = [
        {"id": "b1001", "title": "Python编程:从入门到实践", "author": "埃里克·马瑟斯", "publisher": "人民邮电出版社", "price": 59.8, "tags": ["Python", "入门"]},
        {"id": "b1002", "title": "流畅的Python", "author": "卢西亚诺·拉马略", "publisher": "人民邮电出版社", "price": 129.0, "tags": ["Python", "进阶"]},
        {"id": "b1003", "title": "Python数据分析与挖掘实战", "author": "张良均", "publisher": "机械工业出版社", "price": 79.0, "tags": ["Python", "数据分析"]},
        {"id": "b1004", "title": "深度学习入门:基于Python的理论与实现", "author": "斋藤康毅", "publisher": "人民邮电出版社", "price": 69.0, "tags": ["Python", "AI"]},
        {"id": "b1005", "title": "Python爬虫开发与项目实战", "author": "范传辉", "publisher": "机械工业出版社", "price": 65.0, "tags": ["Python", "爬虫"]},
        {"id": "b1006", "title": "Java编程思想", "author": "布鲁斯·埃克尔", "publisher": "机械工业出版社", "price": 109.0, "tags": ["Java", "进阶"]},
        {"id": "b1007", "title": "Python Web开发实战", "author": "陶俊杰", "publisher": "清华大学出版社", "price": 89.0, "tags": ["Python", "Web"]},
        {"id": "b1008", "title": "数据结构与算法分析:Python语言描述", "author": "马克·艾伦·维斯", "publisher": "机械工业出版社", "price": 75.0, "tags": ["Python", "算法"]}
    ]

    # 批量导入图书数据
    book_system.import_books(test_books)

    # 测试检索功能:搜索标题包含"Python",价格在 50-100 之间的图书,按价格降序排序,第1页,每页3条
    search_results = book_system.search_books(
        keyword="Python",
        field="title",
        min_price=50.0,
        max_price=100.0,
        sort_by="price desc",
        page=1,
        page_size=3,
        highlight=True
    )

    # 展示检索结果
    book_system.display_results(search_results)

    # 测试检索功能:搜索作者包含"张良均"的图书
    print("\n\n===== 按作者检索 =====")
    author_results = book_system.search_books(keyword="张良均", field="author")
    book_system.display_results(author_results)

4.3 案例运行说明

  1. 运行该代码前,需要确保 Solr 服务器已启动,且对应的 Core 已创建。
  2. 代码中定义了 BookSearchSystem 类,封装了图书数据的导入和检索功能,便于复用和维护。
  3. 测试部分首先初始化系统,然后导入测试图书数据,接着执行两次检索操作,分别按标题和作者检索,并展示结果。
  4. 检索结果中,匹配的关键词会被 <strong> 标签包裹,在 HTML 页面中展示时会呈现为加粗样式。

五、pysolr 相关资源链接

  • PyPI 地址:https://pypi.org/project/pysolr
  • Github 地址:https://github.com/django-haystack/pysolr
  • 官方文档地址:https://pysolr.readthedocs.io/en/latest/

关注我,每天分享一个实用的Python自动化工具。

Python实用工具Piccolo详解:轻量级ORM的高效使用指南

一、Piccolo库核心概述

Piccolo是一款专为Python开发者设计的轻量级异步ORM(对象关系映射)框架,主要用于简化数据库的操作流程,支持PostgreSQL、SQLite等主流数据库,同时兼容同步与异步编程模式。其工作原理是将Python类映射为数据库表,通过面向对象的语法替代原生SQL语句,降低数据库操作的复杂度。

该库的优点突出:异步特性适配高并发场景,语法简洁易上手,支持自动生成迁移文件,且体积小巧、无过多依赖;缺点则是生态相较于Django ORM、SQLAlchemy更小众,部分高级功能有待完善。Piccolo采用MIT开源许可证,允许开发者自由使用、修改和分发,无商业使用限制。

二、Piccolo库安装与环境配置

2.1 安装命令

Piccolo支持通过pip包管理器一键安装,无论是同步环境还是异步环境,安装命令一致。打开终端,输入以下命令:

pip install piccolo

安装完成后,可通过以下命令验证安装是否成功:

python -m piccolo --version

若终端输出Piccolo的版本号,说明安装成功。

2.2 支持的Python与数据库版本

  • Python版本:推荐Python 3.8及以上版本,确保异步特性和语法兼容性。
  • 数据库版本:
  • SQLite:3.20.0及以上版本(无需额外配置,开箱即用);
  • PostgreSQL:10.0及以上版本(需提前安装并启动数据库服务)。

三、Piccolo核心功能与基础用法

3.1 定义数据库表模型

Piccolo的核心是通过Table类定义数据库表结构,每个类属性对应表中的一个字段。我们以创建一个“用户信息表”为例,演示模型定义的方法。

3.1.1 基础模型定义代码

from piccolo.table import Table
from piccolo.columns import Varchar, Int, Boolean, Timestamp

class User(Table):
    """
    用户信息表模型
    字段说明:
    - username: 用户名,字符串类型,长度50,非空且唯一
    - age: 年龄,整数类型
    - is_active: 是否激活,布尔类型,默认值为True
    - create_time: 创建时间,时间戳类型,默认自动填充当前时间
    """
    username = Varchar(length=50, null=False, unique=True)
    age = Int(null=True)
    is_active = Boolean(default=True)
    create_time = Timestamp(default=lambda: datetime.now())

代码说明

  1. 导入Table基类和需要的字段类型(VarcharInt等);
  2. 定义User类并继承Table,该类会被映射为数据库中的user表;
  3. 每个类属性对应表字段,通过参数指定字段约束(如null=False表示非空,unique=True表示唯一)。

3.1.2 字段类型与常用约束

Piccolo提供了丰富的字段类型,满足不同业务需求,常见字段及约束如下:
| 字段类型 | 作用 | 常用约束 |
|-||-|
| Varchar | 字符串类型 | length(长度)、null(是否允许空)、unique(是否唯一) |
| Int | 整数类型 | default(默认值)、choices(可选值列表) |
| Boolean | 布尔类型 | default(默认值) |
| Timestamp | 时间戳类型 | default(默认值,支持lambda函数) |
| ForeignKey | 外键类型 | references(关联的表模型) |

3.2 生成数据库迁移文件

在Piccolo中,模型定义完成后,需要生成迁移文件来创建对应的数据库表。迁移文件是数据库结构变更的记录,确保不同环境下的数据库结构一致。

3.2.1 初始化迁移环境

首先,在项目根目录下执行以下命令,初始化Piccolo的配置文件和迁移目录:

piccolo project new my_project

执行完成后,项目会生成piccolo_conf.py配置文件和migrations目录。

3.2.2 配置数据库连接

打开piccolo_conf.py文件,修改数据库连接配置。以SQLite为例:

from piccolo.conf.apps import AppConfig
from piccolo.engine.sqlite import SQLiteEngine

# SQLite数据库连接配置
DB = SQLiteEngine(path="my_database.db")

# 注册包含表模型的应用
APP_CONFIG = AppConfig(
    app_name="my_app",
    migrations_folder_path="my_app/migrations",
    table_classes=["my_app.tables.User"],
)

若使用PostgreSQL,配置如下:

from piccolo.engine.postgres import PostgresEngine

DB = PostgresEngine(
    config={
        "database": "my_db",
        "user": "postgres",
        "password": "123456",
        "host": "localhost",
        "port": 5432,
    }
)

3.2.3 创建并应用迁移文件

  1. 生成迁移文件:执行以下命令,Piccolo会自动检测模型变化并生成迁移文件。
piccolo migrations new my_app --auto
  1. 应用迁移文件:将迁移文件中的变更同步到数据库,创建user表。
piccolo migrations forwards my_app

执行成功后,数据库中会生成对应的user表结构。

3.3 数据的增删改查操作

Piccolo支持同步和异步两种数据操作方式,以下分别演示两种模式下的增删改查(CRUD)操作。

3.3.1 同步操作示例

from datetime import datetime
from my_app.tables import User

# 1. 新增数据(Create)
def add_user():
    # 方式一:通过类实例化并保存
    user1 = User(
        username="alice",
        age=25,
        is_active=True,
        create_time=datetime.now()
    )
    user1.save()  # 保存到数据库

    # 方式二:使用create方法直接创建
    User.create(username="bob", age=30, is_active=False)

# 2. 查询数据(Read)
def query_users():
    # 查询所有用户
    all_users = User.objects().all()
    for user in all_users:
        print(f"用户名:{user.username},年龄:{user.age}")

    # 条件查询:查询年龄大于25的激活用户
    active_users = User.objects().where(
        (User.age > 25) & (User.is_active == True)
    )
    print(f"年龄大于25的激活用户数量:{active_users.count()}")

    # 查询单个用户:根据用户名查询
    user = User.objects().get(User.username == "alice")
    print(f"Alice的年龄:{user.age}")

# 3. 更新数据(Update)
def update_user():
    # 修改单个用户的年龄
    user = User.objects().get(User.username == "bob")
    user.age = 31
    user.save()

    # 批量更新:将所有激活用户的年龄加1
    User.objects().where(User.is_active == True).update({User.age: User.age + 1})

# 4. 删除数据(Delete)
def delete_user():
    # 删除单个用户
    user = User.objects().get(User.username == "bob")
    user.remove()

    # 批量删除:删除年龄小于20的用户
    User.objects().where(User.age < 20).remove()

# 执行操作
if __name__ == "__main__":
    add_user()
    query_users()
    update_user()
    delete_user()

代码说明

  • 新增数据:支持实例化对象后save()和直接调用create()两种方式;
  • 查询数据:使用objects()获取查询集,通过where()添加条件,get()查询单条数据,count()统计数量;
  • 更新数据:支持单条数据修改后save()和批量update()
  • 删除数据:支持单条数据remove()和批量删除。

3.3.2 异步操作示例

Piccolo的异步特性基于asyncio实现,适合高并发场景,异步操作的语法与同步操作类似,只需使用async/await关键字。

import asyncio
from datetime import datetime
from my_app.tables import User

# 异步新增数据
async def async_add_user():
    user1 = User(username="charlie", age=28)
    await user1.save()  # 异步保存
    await User.create(username="david", age=22, is_active=False)

# 异步查询数据
async def async_query_users():
    all_users = await User.objects().all()
    for user in all_users:
        print(f"异步查询 - 用户名:{user.username},年龄:{user.age}")

    # 异步条件查询
    active_users = await User.objects().where(User.is_active == True).run()
    print(f"异步查询 - 激活用户数量:{len(active_users)}")

# 异步更新数据
async def async_update_user():
    await User.objects().where(User.username == "charlie").update({User.age: 29})

# 异步删除数据
async def async_delete_user():
    await User.objects().where(User.username == "david").remove()

# 执行异步操作
async def main():
    await async_add_user()
    await async_query_users()
    await async_update_user()
    await async_delete_user()

if __name__ == "__main__":
    asyncio.run(main())

代码说明

  • 异步操作需在async函数中执行,通过await调用Piccolo的异步方法;
  • run()方法用于执行异步查询集,获取结果列表;
  • 最后通过asyncio.run()启动异步事件循环。

四、Piccolo高级功能与应用

4.1 表关联(外键)操作

在实际项目中,表与表之间通常存在关联关系,如“用户表”和“订单表”的一对多关系。以下演示如何通过Piccolo定义外键关联并进行关联查询。

4.1.1 定义关联表模型

from piccolo.table import Table
from piccolo.columns import Varchar, Int, ForeignKey, Decimal
from my_app.tables import User

class Order(Table):
    """
    订单表模型
    外键关联User表,一个用户可以有多个订单
    """
    order_no = Varchar(length=30, unique=True, null=False)  # 订单编号
    amount = Decimal(precision=10, scale=2)  # 订单金额
    user = ForeignKey(references=User)  # 外键关联用户表

# 生成并应用迁移文件,创建order表
# 命令:piccolo migrations new my_app --auto && piccolo migrations forwards my_app

4.1.2 关联查询操作

from my_app.tables import User, Order

# 同步关联查询:查询某个用户的所有订单
def query_user_orders():
    user = User.objects().get(User.username == "alice")
    # 通过外键反向查询用户的订单
    orders = Order.objects().where(Order.user == user)
    for order in orders:
        print(f"用户{user.username}的订单:{order.order_no},金额:{order.amount}")

# 异步关联查询
async def async_query_user_orders():
    user = await User.objects().get(User.username == "alice")
    orders = await Order.objects().where(Order.user == user).run()
    for order in orders:
        print(f"异步查询 - 用户{user.username}的订单:{order.order_no}")

# 执行查询
query_user_orders()
asyncio.run(async_query_user_orders())

代码说明

  • 外键通过ForeignKey字段定义,references参数指定关联的表模型;
  • 关联查询时,可通过外键字段作为条件,查询关联表的数据。

4.2 数据筛选与排序

Piccolo提供了丰富的筛选和排序方法,满足复杂的查询需求。

from my_app.tables import User

# 数据筛选:多条件组合、模糊查询
def filter_users():
    # 模糊查询:用户名包含"li"的用户
    users = User.objects().where(User.username.like("%li%"))

    # 范围查询:年龄在20-30之间的用户
    users = User.objects().where((User.age >= 20) & (User.age <= 30))

    # 排序:按年龄降序排列
    sorted_users = User.objects().order_by(User.age, ascending=False)
    for user in sorted_users:
        print(f"用户名:{user.username},年龄:{user.age}")

filter_users()

代码说明

  • like()方法用于模糊查询,%表示通配符;
  • order_by()方法用于排序,ascending=False表示降序。

4.3 数据库事务操作

事务可以确保一系列数据库操作要么全部成功,要么全部失败,保证数据一致性。Piccolo支持同步和异步事务。

4.3.1 同步事务示例

from piccolo.utils.transaction import transaction
from my_app.tables import User

@transaction()
def transaction_demo():
    # 事务内的操作
    User.create(username="eva", age=24)
    User.create(username="frank", age=26)
    # 若执行过程中抛出异常,事务会回滚
    # 例如:raise Exception("模拟异常,事务回滚")

# 执行事务
transaction_demo()

4.3.2 异步事务示例

from piccolo.utils.transaction import async_transaction

@async_transaction()
async def async_transaction_demo():
    await User.create(username="grace", age=27)
    await User.create(username="henry", age=29)

asyncio.run(async_transaction_demo())

代码说明

  • 同步事务使用@transaction()装饰器,异步事务使用@async_transaction()装饰器;
  • 事务内的所有操作会被包裹,若出现异常则自动回滚。

五、实际项目案例:用户管理系统

我们以一个简单的用户管理系统为例,整合Piccolo的核心功能,实现用户的注册、查询、更新和删除功能。

5.1 项目目录结构

my_user_system/
├── my_app/
│   ├── __init__.py
│   ├── tables.py       # 表模型定义
│   └── operations.py   # 业务逻辑操作
├── piccolo_conf.py     # Piccolo配置文件
└── main.py             # 程序入口

5.2 代码实现

5.2.1 tables.py(表模型)

from piccolo.table import Table
from piccolo.columns import Varchar, Int, Boolean, Timestamp
from datetime import datetime

class User(Table):
    username = Varchar(length=50, null=False, unique=True)
    age = Int(null=True)
    is_active = Boolean(default=True)
    create_time = Timestamp(default=lambda: datetime.now())

5.2.2 operations.py(业务逻辑)

import asyncio
from my_app.tables import User

# 同步业务操作
class SyncUserOperations:
    @staticmethod
    def register_user(username, age):
        """用户注册"""
        if User.objects().where(User.username == username).exists():
            print(f"用户名{username}已存在")
            return False
        User.create(username=username, age=age)
        print(f"用户{username}注册成功")
        return True

    @staticmethod
    def get_user(username):
        """查询用户"""
        try:
            user = User.objects().get(User.username == username)
            return {
                "username": user.username,
                "age": user.age,
                "is_active": user.is_active,
                "create_time": user.create_time
            }
        except Exception:
            return None

    @staticmethod
    def update_user_age(username, new_age):
        """更新用户年龄"""
        user = User.objects().get(User.username == username)
        if not user:
            return False
        user.age = new_age
        user.save()
        return True

# 异步业务操作
class AsyncUserOperations:
    @staticmethod
    async def register_user(username, age):
        if await User.objects().where(User.username == username).exists():
            print(f"用户名{username}已存在")
            return False
        await User.create(username=username, age=age)
        print(f"用户{username}注册成功")
        return True

    @staticmethod
    async def get_user(username):
        try:
            user = await User.objects().get(User.username == username)
            return {
                "username": user.username,
                "age": user.age,
                "is_active": user.is_active,
                "create_time": user.create_time
            }
        except Exception:
            return None

5.2.3 main.py(程序入口)

import asyncio
from my_app.operations import SyncUserOperations, AsyncUserOperations

# 同步操作演示
def sync_demo():
    SyncUserOperations.register_user("user1", 22)
    SyncUserOperations.register_user("user1", 23)  # 重复注册
    user = SyncUserOperations.get_user("user1")
    print(f"查询用户:{user}")
    SyncUserOperations.update_user_age("user1", 24)
    user = SyncUserOperations.get_user("user1")
    print(f"更新后用户信息:{user}")

# 异步操作演示
async def async_demo():
    await AsyncUserOperations.register_user("user2", 25)
    user = await AsyncUserOperations.get_user("user2")
    print(f"异步查询用户:{user}")

if __name__ == "__main__":
    sync_demo()
    asyncio.run(async_demo())

5.3 运行项目

  1. 配置piccolo_conf.py文件,设置数据库连接;
  2. 生成并应用迁移文件:
piccolo migrations new my_app --auto
piccolo migrations forwards my_app
  1. 运行main.py
python main.py

终端会输出用户注册、查询和更新的结果,验证功能的正确性。

六、Piccolo相关资源链接

  • Pypi地址:https://pypi.org/project/piccolos
  • Github地址:https://github.com/xxxxx/xxxxxx
  • 官方文档地址:https://www.xxxxx.com/xxxxxx

关注我,每天分享一个实用的Python自动化工具。

Python Kafka 开发利器:confluent-kafka-python 从入门到实战

一、confluent-kafka-python 核心概述

1.1 库的用途

confluent-kafka-python 是 Confluent 公司推出的 Kafka Python 客户端,基于高性能的 librdkafka C 库封装而成,主要用于在 Python 程序中实现与 Apache Kafka 集群的高效交互,支持生产者(Producer)向 Kafka 发送消息、消费者(Consumer)从 Kafka 订阅并消费消息,同时兼容 Kafka 的各种高级特性,广泛应用于实时数据管道、日志收集、消息队列解耦等场景。

1.2 工作原理

该库的底层依赖 librdkafka,这是一个工业级的 Kafka 客户端库,提供了可靠的消息传输机制。在 Python 层面,confluent-kafka-python 对 librdkafka 的 API 进行了轻量级封装,实现了生产者的消息分区策略、批量发送、消息确认,以及消费者的群组协调、自动提交偏移量、消息回溯等核心功能。其工作流程遵循 Kafka 的标准模型:生产者将消息发送到指定 Topic,Kafka 集群存储消息,消费者订阅 Topic 并拉取消息进行处理。

1.3 优缺点分析

优点

  • 性能优异:基于 C 语言的 librdkafka,吞吐量和延迟表现远超纯 Python 实现的 Kafka 客户端(如 kafka-python)。
  • 功能全面:支持 Kafka 的所有核心特性,包括事务消息、压缩算法、SSL 加密、SASL 认证、自定义分区器等。
  • 稳定性高:经过大规模生产环境验证,适合高并发、高可用的场景。
  • 配置灵活:提供丰富的配置参数,可针对生产者和消费者进行精细化调优。

缺点

  • 安装依赖:需要系统中安装 librdkafka 库,Windows 平台安装相对复杂。
  • 学习曲线:部分高级配置参数(如分区策略、偏移量管理)需要对 Kafka 原理有一定理解。
  • 跨平台兼容:在一些小众操作系统上可能存在编译问题,需要手动调整编译参数。

1.4 License 类型

confluent-kafka-python 采用 Apache License 2.0 开源协议,允许用户自由使用、修改和分发代码,可用于商业项目,只需保留原作者的版权声明。

二、confluent-kafka-python 安装与环境准备

2.1 系统依赖安装

由于 confluent-kafka-python 依赖 librdkafka,在安装 Python 包之前需要先安装系统级的 librdkafka 库。

2.1.1 Linux 系统(Ubuntu/Debian)

sudo apt-get update
sudo apt-get install librdkafka-dev

2.1.2 Linux 系统(CentOS/RHEL)

sudo yum install librdkafka-devel

2.1.3 macOS 系统

使用 Homebrew 安装:

brew install librdkafka

2.1.4 Windows 系统

Windows 平台安装相对复杂,推荐两种方式:

  1. 使用预编译的二进制包:从 librdkafka 官网 下载预编译的 Windows 版本,解压后将库文件路径添加到系统环境变量 PATH 中。
  2. 使用 WSL(Windows Subsystem for Linux):在 WSL 中安装 Linux 版本的依赖,然后在 WSL 中运行 Python 程序。

2.2 Python 包安装

系统依赖安装完成后,使用 pip 安装 confluent-kafka-python:

pip install confluent-kafka

验证安装是否成功:

import confluent_kafka
print(confluent_kafka.__version__)

运行上述代码,如果输出库的版本号(如 2.2.0),则说明安装成功。

三、核心功能实战:生产者与消费者

3.1 Kafka 环境准备

在进行代码实战前,需要确保有一个可用的 Kafka 集群。如果是本地测试,可以使用 Docker 快速启动单节点 Kafka 和 ZooKeeper:

# 启动 ZooKeeper
docker run -d --name zookeeper -p 2181:2181 confluentinc/cp-zookeeper:7.4.0 \
  ZOOKEEPER_CLIENT_PORT=2181 \
  ZOOKEEPER_TICK_TIME=2000

# 启动 Kafka
docker run -d --name kafka -p 9092:9092 --link zookeeper:zookeeper confluentinc/cp-kafka:7.4.0 \
  KAFKA_BROKER_ID=1 \
  KAFKA_ZOOKEEPER_CONNECT=zookeeper:2181 \
  KAFKA_LISTENER_SECURITY_PROTOCOL_MAP=PLAINTEXT:PLAINTEXT \
  KAFKA_ADVERTISED_LISTENERS=PLAINTEXT://localhost:9092 \
  KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR=1

上述命令启动了一个单节点 Kafka 集群,监听本地 9092 端口。

3.2 生产者(Producer)实战

生产者的核心功能是向 Kafka Topic 发送消息。confluent-kafka-python 提供了 Producer 类,支持同步发送、异步发送、批量发送等多种模式。

3.2.1 基础异步生产者

异步发送是 Kafka 生产者的默认模式,特点是无需等待消息发送结果,通过回调函数处理发送成功或失败的通知,效率更高。

代码示例

from confluent_kafka import Producer
import json
import time

# 1. 配置生产者参数
producer_config = {
    "bootstrap.servers": "localhost:9092",  # Kafka 集群地址
    "client.id": "python-producer-demo",    # 客户端标识
    "acks": "1",                            # 消息确认级别:1 表示 leader 确认即可
    "retries": 3,                           # 发送失败重试次数
    "linger.ms": 5,                         # 批量发送延迟时间(毫秒)
    "compression.type": "gzip"              # 消息压缩算法
}

# 2. 初始化生产者
producer = Producer(producer_config)

# 3. 定义发送结果回调函数
def delivery_report(err, msg):
    """
    消息发送结果回调函数
    :param err: 发送失败时的错误信息,成功时为 None
    :param msg: 发送成功的消息元数据
    """
    if err is not None:
        print(f"消息发送失败: {err}")
    else:
        print(f"消息发送成功 -> Topic: {msg.topic()}, Partition: {msg.partition()}, Offset: {msg.offset()}")

# 4. 发送消息
topic = "test_topic"  # 目标 Topic

# 循环发送 10 条测试消息
for i in range(10):
    # 构造消息内容
    message_data = {
        "id": i,
        "content": f"Hello Kafka from Python - {i}",
        "timestamp": time.time()
    }
    # 将字典转换为 JSON 字符串
    message_value = json.dumps(message_data).encode("utf-8")

    # 发送消息:key 用于分区路由,value 为消息内容
    producer.produce(
        topic=topic,
        key=str(i).encode("utf-8"),
        value=message_value,
        on_delivery=delivery_report
    )

    # 触发消息发送(异步模式下需要定期调用 poll 处理事件)
    producer.poll(0)

    # 模拟业务延迟
    time.sleep(0.5)

# 5. 等待所有待发送消息完成
producer.flush()
print("所有消息发送完成!")

代码说明

  • 配置参数bootstrap.servers 指定 Kafka 集群地址,acks 设置消息确认级别(0=无确认,1=leader 确认,all=所有副本确认),retries 设置重试次数,linger.ms 控制批量发送的延迟时间,compression.type 启用 gzip 压缩以减少网络传输量。
  • 回调函数delivery_report 函数用于处理消息发送结果,当消息成功发送或失败时会被调用。
  • produce 方法:用于发送消息,key 会影响消息的分区策略(相同 key 的消息会被发送到同一个分区),value 为消息的二进制内容。
  • poll 方法:异步模式下必须定期调用 poll 方法,处理 Kafka 的事件(如回调函数执行),参数 0 表示非阻塞。
  • flush 方法:等待所有待发送的消息完成发送,确保程序退出前消息不会丢失。

3.2.2 同步生产者

同步发送模式下,程序会阻塞直到收到 Kafka 的确认响应,适合对消息发送结果有强依赖的场景。

代码示例

from confluent_kafka import Producer, KafkaError
import json

# 配置生产者参数
producer_config = {
    "bootstrap.servers": "localhost:9092",
    "acks": "all",  # 最高级别确认,确保消息可靠性
    "retries": 5
}

producer = Producer(producer_config)
topic = "test_topic"

def send_message_sync(topic, key, value):
    """
    同步发送消息
    """
    try:
        # 发送消息并等待结果
        producer.produce(topic, key=key, value=value)
        # 阻塞直到消息发送完成
        producer.flush()
        print("消息同步发送成功")
    except KafkaError as e:
        print(f"消息同步发送失败: {e}")

# 构造消息
message_value = json.dumps({"data": "Sync Message from Python"}).encode("utf-8")
send_message_sync(topic, b"sync_key", message_value)

代码说明

  • 同步发送的核心是调用 flush 方法,该方法会阻塞直到所有待发送消息处理完成。
  • 通过捕获 KafkaError 异常,可以处理发送过程中的错误。

3.3 消费者(Consumer)实战

消费者的核心功能是订阅 Kafka Topic 并拉取消息进行处理。confluent-kafka-python 提供了 Consumer 类,支持消费者群组、自动提交偏移量、手动提交偏移量等功能。

3.3.1 基础消费者(自动提交偏移量)

自动提交偏移量是消费者的默认模式,Kafka 会定期自动将消费者的偏移量提交到集群,简化开发流程。

代码示例

from confluent_kafka import Consumer, KafkaError
import json

# 1. 配置消费者参数
consumer_config = {
    "bootstrap.servers": "localhost:9092",
    "group.id": "python-consumer-group",  # 消费者群组 ID
    "auto.offset.reset": "earliest",      # 当没有初始偏移量时,从最早的消息开始消费
    "enable.auto.commit": True,           # 启用自动提交偏移量
    "auto.commit.interval.ms": 5000       # 自动提交间隔时间(毫秒)
}

# 2. 初始化消费者
consumer = Consumer(consumer_config)

# 3. 订阅 Topic
topic = "test_topic"
consumer.subscribe([topic])
print(f"消费者已订阅 Topic: {topic}")

# 4. 消费消息
try:
    while True:
        # 拉取消息,超时时间设置为 1 秒
        msg = consumer.poll(timeout=1.0)

        # 如果没有消息,继续循环
        if msg is None:
            continue

        # 处理错误
        if msg.error():
            # 处理分区 EOF 事件
            if msg.error().code() == KafkaError._PARTITION_EOF:
                print(f"已到达分区末尾 -> Topic: {msg.topic()}, Partition: {msg.partition()}, Offset: {msg.offset()}")
            else:
                print(f"消费消息出错: {msg.error()}")
            continue

        # 处理正常消息
        key = msg.key().decode("utf-8") if msg.key() else None
        value = json.loads(msg.value().decode("utf-8"))
        print(f"消费到消息 -> Key: {key}, Value: {value}, Topic: {msg.topic()}, Partition: {msg.partition()}, Offset: {msg.offset()}")

except KeyboardInterrupt:
    print("用户中断消费")
finally:
    # 关闭消费者,提交最后一次偏移量
    consumer.close()
    print("消费者已关闭")

代码说明

  • 配置参数group.id 指定消费者群组 ID,同一群组的消费者会负载均衡消费 Topic 的分区;auto.offset.reset 设置当消费者没有初始偏移量时的策略(earliest 从最早消息开始,latest 从最新消息开始);enable.auto.commit 启用自动提交,auto.commit.interval.ms 设置自动提交的间隔时间。
  • subscribe 方法:订阅一个或多个 Topic,支持正则表达式(如 subscribe(["test_*"]))。
  • poll 方法:拉取消息,timeout 参数设置超时时间(毫秒),超时后返回 None。
  • 消息处理:通过 msg.key()msg.value() 获取消息的键和值,需要进行解码;msg.error() 用于判断消息是否有错误,KafkaError._PARTITION_EOF 表示到达分区末尾。

3.3.2 高级消费者(手动提交偏移量)

手动提交偏移量可以更精确地控制消息的消费进度,确保消息被成功处理后再提交偏移量,避免消息丢失。适合对数据一致性要求高的场景(如金融交易、订单处理)。

代码示例

from confluent_kafka import Consumer, KafkaError, TopicPartition
import json

# 1. 配置消费者参数(关闭自动提交)
consumer_config = {
    "bootstrap.servers": "localhost:9092",
    "group.id": "python-manual-commit-group",
    "auto.offset.reset": "earliest",
    "enable.auto.commit": False  # 关闭自动提交
}

# 2. 初始化消费者
consumer = Consumer(consumer_config)

# 3. 订阅 Topic
topic = "test_topic"
consumer.subscribe([topic])
print(f"手动提交消费者已订阅 Topic: {topic}")

# 4. 消费消息
try:
    while True:
        msg = consumer.poll(timeout=1.0)
        if msg is None:
            continue

        if msg.error():
            if msg.error().code() == KafkaError._PARTITION_EOF:
                print(f"分区末尾 -> {msg.topic()}-{msg.partition()}:{msg.offset()}")
            else:
                print(f"消费错误: {msg.error()}")
            continue

        # 处理消息
        key = msg.key().decode("utf-8") if msg.key() else None
        value = json.loads(msg.value().decode("utf-8"))
        print(f"消费到消息 -> Key: {key}, Value: {value}")

        # 模拟业务处理(如写入数据库、调用 API)
        # 假设这里的业务逻辑执行成功
        print("业务逻辑处理成功,准备提交偏移量")

        # 5. 手动提交偏移量
        # 方式 1:提交当前消费的消息偏移量
        consumer.commit(msg)
        print(f"偏移量提交成功 -> Topic: {msg.topic()}, Partition: {msg.partition()}, Offset: {msg.offset() + 1}")

        # 方式 2:提交指定分区的偏移量(批量提交)
        # partitions = [TopicPartition(topic, msg.partition(), msg.offset() + 1)]
        # consumer.commit(partitions=partitions)

except KeyboardInterrupt:
    print("用户中断消费")
finally:
    consumer.close()
    print("消费者已关闭")

代码说明

  • 关闭自动提交:将 enable.auto.commit 设置为 False,禁用自动提交功能。
  • 手动提交方式
  1. consumer.commit(msg):提交当前消费的消息的偏移量,Kafka 会记录该消费者群组在对应分区的偏移量为 msg.offset() + 1(下一次从该偏移量开始消费)。
  2. consumer.commit(partitions=partitions):批量提交多个分区的偏移量,适合批量处理消息的场景。
  • 业务一致性:手动提交偏移量的核心优势是可以确保消息被成功处理后再提交,避免因程序崩溃导致的消息丢失。例如,在将消息写入数据库并确认写入成功后,再提交偏移量。

3.4 消费者群组与分区分配

Kafka 的消费者群组机制可以实现消息的负载均衡,当多个消费者属于同一个 group.id 时,Kafka 会将 Topic 的分区均匀分配给群组内的消费者。

示例场景
假设 test_topic 有 3 个分区,启动 2 个消费者属于同一个群组,则分区分配可能为:消费者 1 分配 2 个分区,消费者 2 分配 1 个分区。当新增一个消费者时,Kafka 会触发分区再平衡,将分区重新分配为每个消费者 1 个分区。

代码验证
启动多个上述的消费者实例(保持 group.id 相同),然后通过生产者发送消息,可以看到不同消费者消费不同分区的消息。

四、高级特性实战

4.1 事务消息

事务消息可以确保生产者发送的多条消息原子性地提交到 Kafka,同时确保消费者只消费已提交的事务消息,适合需要跨多个 Topic 或分区发送消息的场景(如分布式事务)。

代码示例

from confluent_kafka import Producer, Consumer, KafkaError, KafkaException
import json

# 生产者配置(启用事务)
producer_config = {
    "bootstrap.servers": "localhost:9092",
    "client.id": "transactional-producer",
    "acks": "all",
    "transactional.id": "test-transaction-id"  # 事务 ID,确保生产者故障恢复后的幂等性
}

# 初始化生产者并初始化事务
producer = Producer(producer_config)
producer.init_transactions()

try:
    # 开始事务
    producer.begin_transaction()

    # 发送多条消息到不同 Topic
    topic1 = "topic_tx_1"
    topic2 = "topic_tx_2"

    # 发送第一条消息
    producer.produce(topic1, value=b"Transaction Message 1", on_delivery=delivery_report)
    producer.poll(0)

    # 发送第二条消息
    producer.produce(topic2, value=b"Transaction Message 2", on_delivery=delivery_report)
    producer.poll(0)

    # 提交事务
    producer.commit_transaction()
    print("事务提交成功")

except KafkaException as e:
    print(f"事务执行失败,开始回滚: {e}")
    # 回滚事务
    producer.abort_transaction()

finally:
    producer.flush()

代码说明

  • 事务配置:通过 transactional.id 启用事务功能,同一个 transactional.id 的生产者可以确保故障恢复后的幂等性。
  • 事务流程init_transactions 初始化事务,begin_transaction 开始事务,commit_transaction 提交事务,abort_transaction 回滚事务。
  • 消费者事务隔离:消费者可以通过设置 isolation.level 参数控制是否消费未提交的事务消息,read_committed 表示只消费已提交的消息,read_uncommitted 表示消费所有消息。

4.2 SSL 加密与 SASL 认证

在生产环境中,Kafka 集群通常需要启用 SSL 加密和 SASL 认证,以确保数据传输的安全性和访问控制。

生产者配置示例(SASL/PLAIN 认证 + SSL 加密)

producer_config = {
    "bootstrap.servers": "kafka-cluster:9093",
    "security.protocol": "SASL_SSL",
    "sasl.mechanism": "PLAIN",
    "sasl.username": "kafka_user",
    "sasl.password": "kafka_password",
    "ssl.ca.location": "/path/to/ca.pem",  # CA 证书路径
    "ssl.certificate.location": "/path/to/client-cert.pem",  # 客户端证书路径
    "ssl.key.location": "/path/to/client-key.pem"  # 客户端私钥路径
}

producer = Producer(producer_config)

代码说明

  • security.protocol 设置为 SASL_SSL,表示启用 SASL 认证和 SSL 加密。
  • sasl.mechanism 指定 SASL 机制(如 PLAIN、SCRAM-SHA-256)。
  • ssl.ca.location 指定 CA 证书路径,用于验证 Kafka 服务端证书。
  • ssl.certificate.locationssl.key.location 指定客户端证书和私钥,用于双向认证。

五、实际业务案例:实时日志收集系统

5.1 案例背景

某电商平台需要构建一个实时日志收集系统,将用户行为日志(如浏览、点击、下单)从各个业务服务器收集到 Kafka,然后由下游的数据分析系统消费并处理这些日志。

5.2 系统架构

  1. 生产者端:业务服务器上的 Python 脚本收集用户行为日志,发送到 Kafka Topic user_behavior_topic
  2. Kafka 集群:存储用户行为日志,提供高吞吐量和高可用性。
  3. 消费者端:数据分析系统的 Python 脚本消费 user_behavior_topic 的日志,进行实时统计和存储。

5.3 生产者代码实现

from confluent_kafka import Producer
import json
import time
import random

# 生产者配置
producer_config = {
    "bootstrap.servers": "localhost:9092",
    "acks": "1",
    "retries": 3,
    "linger.ms": 10,
    "compression.type": "lz4"
}

producer = Producer(producer_config)

# 回调函数
def delivery_report(err, msg):
    if err:
        print(f"日志发送失败: {err}")
    else:
        print(f"日志发送成功 -> Topic: {msg.topic()}, Offset: {msg.offset()}")

# 模拟用户行为日志
def generate_user_behavior_log():
    user_ids = [f"user_{i}" for i in range(1000)]
    behaviors = ["view", "click", "add_cart", "purchase"]
    products = [f"product_{i}" for i in range(100)]

    return {
        "user_id": random.choice(user_ids),
        "behavior": random.choice(behaviors),
        "product_id": random.choice(products),
        "timestamp": int(time.time() * 1000),
        "ip": f"{random.randint(1, 255)}.{random.randint(1, 255)}.{random.randint(1, 255)}.{random.randint(1, 255)}"
    }

# 发送日志到 Kafka
topic = "user_behavior_topic"

try:
    while True:
        # 生成一条用户行为日志
        log_data = generate_user_behavior_log()
        log_value = json.dumps(log_data).encode("utf-8")

        # 发送日志
        producer.produce(topic, value=log_value, on_delivery=delivery_report)
        producer.poll(0)

        # 模拟每秒生成 10 条日志
        time.sleep(0.1)
except KeyboardInterrupt:
    producer.flush()
    print("日志生产者已停止")

5.4 消费者代码实现

from confluent_kafka import Consumer, KafkaError
import json
import pandas as pd
from collections import defaultdict

# 消费者配置
consumer_config = {
    "bootstrap.servers": "localhost:9092",
    "group.id": "user_behavior_consumer_group",
    "auto.offset.reset": "earliest",
    "enable.auto.commit": False
}

consumer = Consumer(consumer_config)
consumer.subscribe(["user_behavior_topic"])

# 统计用户行为次数
behavior_count = defaultdict(int)
# 批量处理消息的阈值
BATCH_SIZE = 100
batch_messages = []

try:
    while True:
        msg = consumer.poll(timeout=1.0)
        if msg is None:
            continue

        if msg.error():
            if msg.error().code() != KafkaError._PARTITION_EOF:
                print(f"消费错误: {msg.error()}")
            continue

        # 解析日志消息
        log_data = json.loads(msg.value().decode("utf-8"))
        batch_messages.append(log_data)

        # 当批量消息达到阈值时,进行统计处理
        if len(batch_messages) >= BATCH_SIZE:
            # 转换为 DataFrame 进行分析
            df = pd.DataFrame(batch_messages)
            # 统计每种行为的次数
            behavior_stats = df["behavior"].value_counts()
            # 更新全局统计结果
            for behavior, count in behavior_stats.items():
                behavior_count[behavior] += count

            print("=" * 50)
            print("用户行为统计结果:")
            for behavior, count in behavior_count.items():
                print(f"{behavior}: {count}")
            print("=" * 50)

            # 提交偏移量
            consumer.commit(msg)
            # 清空批量消息列表
            batch_messages = []

except KeyboardInterrupt:
    print("用户中断消费")
finally:
    # 处理剩余的消息
    if batch_messages:
        df = pd.DataFrame(batch_messages)
        behavior_stats = df["behavior"].value_counts()
        for behavior, count in behavior_stats.items():
            behavior_count[behavior] += count
        print("最终统计结果:")
        for behavior, count in behavior_count.items():
            print(f"{behavior}: {count}")
    consumer.close()

5.5 案例总结

该案例利用 confluent-kafka-python 的高性能特性,实现了大规模日志的实时收集和处理。生产者端通过批量发送和压缩提高了发送效率,消费者端通过批量处理和手动提交偏移量确保了数据处理的准确性和效率。同时,该系统具有良好的扩展性,新增业务服务器只需部署生产者脚本,新增数据分析任务只需新增消费者群组。

六、相关资源链接

  • Pypi地址:https://pypi.org/project/confluent-kafka
  • Github地址:https://github.com/confluentinc/confluent-kafka-python
  • 官方文档地址:https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html

关注我,每天分享一个实用的Python自动化工具。