Explore the Read Parquet Best Practice Using Polars#
Keywords: Read, Parquet, Python, Polars, AWS, S3, Lambda
目标#
Parquet 作为主流大数据存储格式, 它的 Columnar data format 和 Row Group 的两个特性可以大大提升读性能. 这里对这两个性能做一个简单的介绍:
Columnar data format: 在磁盘中一列中的数据都是顺序存储的, 所以你如果只需要数据集中的指定列, 你可以只读取这一列的数据, 而不需要读取其他列的数据. 这样可以大大减少磁盘 IO. 此外还有为数据类型相同的一列数据进行压缩等特点这里我们不展开说了, 因为它们主要是为了节省空间, 而不是节约读性能.
Row group: 在磁盘中数据被按照行分为了 Row Group (RG), 每一个 RG 在操作系统的磁盘上就是一个 Page. 而一个 RG 中包含了每个 Column 的统计信息, 例如最大和最小值. 这样如果你读数据的时候需要根据值进行 filter, 那么也可以跳过很多的 Row Group 从而大大减少了磁盘 IO.
本文档对 Parquet 的读操作进行了详细的测试, 希望能探索出最佳的读取 Parquet 文件的方法. 本文的探索主要针对位于本地磁盘上的 Parquet 文件, 和位于 AWS S3 上的 Parquet.
实验设计#
我们创建了一个测试数据集. 由于在大数据领域, 一个 Parquet 文件通常推荐保持在 100 MB ~ 1GB 之间, 一个 Row Group 通常在 64MB 左右. 所以我们创建了如下数据集:
有 10M 行, 10 列.
每一列都是 1 - 1M 的随机整数.
数据分为 10 个 Row Group, 每个行组有 1M 行数据.
数据经过 snappy 压缩后约 500MB.
我们希望读取 col 1 和 col2 (col 3 - 10 不要), 并只返回 col 1 的值在 1 ~ 1000 之间的结果.
这个文件我们一份放在了本地磁盘上, 一份放在了 AWS S3 上, 用于模拟生产环境中通过网络 IO 的应用场景.
本地电脑室 MacBook M1 Pro, 32 G 内存, 10 核 CPU. 对本地文件的读取用 Mac 来进行.
为了减少网络的影响, 我们使用了 Lambda Function 来进行 S3 中的数据的读取. 因为 Lambda Function 会被部署在和 S3 同一个 Region 的网络中, 并且通信走的是内网, 所以网络速度非常稳定, 减少了测量误差, 也更接近生产环境的情况.
我们使用了 polars, 和 pandas + pyarrow 两套常见的数据分析工具来进行测试.
学到的经验#
对于本地文件, polars.scan_parquet > polars.read_parquet > pandas.read_parquet > pandas.read_parquet + use row group filter. 速度级都在 0.1 秒以内. 其中 scan_parquet 达到了 0.02 秒. 并且 scan_parquet 跟其他几个相比存在数量级的差距. 这是因为 scan_parquet 使用了 lazy load 的策略.
s3_client.get_object 是单线程, 在个人电脑上大约是 5-10MB / S, 而 s3_client.download_file 是多线程多个 CPU 一起下载, 速度要快很多.
s3path.open, smart_open, s3fs.open 这些都是用的单线程
先将文件用 s3_client.download_file 下载到本地, 再进行读取的速度会比直接从 S3 读取要快很多, 因为目前没有多线程创建多个 buffer 来读取 S3 的方法.
AWS Lambda 读取 S3 的速度大约是 75MB - 100MB / S, 比个人电脑快多了.
如果用 AWS Lambda 来写 ETL 程序, 可以先将文件用 download_file 下载到 /tmp, 然后再进行读取. Lambda 默认有 500 MB 磁盘大小的限制, 你可以将这个限制提高到最多 10G, 完全足够用了.
代码#
用于探索的 Python 脚本.
1# -*- coding: utf-8 -*-
2
3"""
4此脚本用于探索读取 Parquet 文件的最佳策略. 其中包含了创建测试数据集, 以及用不同的方式读取,
5测量时间和内存消耗的代码.
6
7Conclusion:
8
9- For local parquet file:
10- 对于本地文件 polars.scan_parquet > polars.read_parquet > pandas.read_parquet
11 > pandas.read_parquet + use row group filter. 并且 scan_parquet 跟其他几个相比
12 存在数量级的差距. 这是因为 scan_parquet
13 使用了 lazy load 的策略.
14- s3_client.get_object 是单线程, 在个人电脑上大约是 5-10MB / S, 而 s3_client.download_file
15 是多线程多个 CPU 一起下载, 大约几个核心速度就是原来的几倍. 但是多线程 download 时每个线程
16 下载的最小 chunk 是 8MB, 也就是说你的文件要大于 8MB * CPU 核心数才能跑满带宽.
17- s3path.open, smart_open, s3fs.open 这些都是用的单线程, 和 s3_client.get_object 一样.
18- 如果先将文件用 s3_client.download_file 下载到本地, 再进行读取的速度会比直接从 S3 读取要快很多,
19 因为目前没有多线程创建多个 buffer 来读取 S3 的方法. 如果你用 s3_client.download_fileobj
20 方法, 你还能直接将数据写入 buffer 中从而避免了将文件写入磁盘的过程, 速度会更快.
21- AWS Lambda 读取 S3 的速度大约是 75MB - 100MB / S, 比个人电脑快多了.
22- AWS Lambda (10GB 内存) 从 S3 上下载 500MB 文件 (多线程) 的速度大约是 6.5 秒.
23- 用上面的技巧用 AWS Lambda 读 500MB parquet file 并用 filter 的速度大约是 6.5 秒,
24 也就是说 parse 数据以及 filter 的速度跟 IO 相比可以忽略不计.
25"""
26
27import typing as T
28import os
29import io
30import math
31import dataclasses
32from functools import cached_property
33
34# 注意, 一般 pl 和 pandas 在 Lambda 上只能 2 选 1. 两个都安装很容易超过 250MB 依赖的限制.
35import numpy as np
36import pandas as pd
37import polars as pl
38from pathlib_mate import Path
39from s3pathlib import S3Path, context
40from boto_session_manager import BotoSesManager
41from fixa.timer import DateTimeTimer
42
43dir_here = Path.dir_here(__file__)
44IS_LAMBDA = "AWS_LAMBDA_FUNCTION_NAME" in os.environ
45
46
47@dataclasses.dataclass
48class Config:
49 """
50 :param aws_profile: the aws profile you want to use
51 :param n_col: number of columns in the test dataframe
52 :param n_row: number of rows in the test dataframe
53 :param row_group_size: number of rows in each parquet row group
54 :param n_row_group: number of row groups in the parquet file
55 """
56
57 n_col: int = dataclasses.field()
58 n_row: int = dataclasses.field()
59 row_group_size: int = dataclasses.field()
60 fname: str = dataclasses.field()
61 aws_profile: T.Optional[str] = dataclasses.field(default=None)
62
63 @property
64 def n_row_group(self):
65 return int(math.ceil(self.n_row / self.row_group_size))
66
67 @cached_property
68 def bsm(self):
69 return BotoSesManager(profile_name=self.aws_profile)
70
71 @property
72 def s3dir_root(self) -> S3Path:
73 return S3Path(
74 f"s3://{self.bsm.aws_account_id}-{self.bsm.aws_region}-data"
75 "/projects/explore_the_read_parquet_best_practice_using_polars/"
76 ).to_dir()
77
78 @property
79 def s3path(self) -> S3Path:
80 return self.s3dir_root.joinpath(f"{self.fname}.snappy.parquet")
81
82 @property
83 def path(self) -> Path:
84 if IS_LAMBDA:
85 return Path("/tmp").joinpath(f"{self.fname}.snappy.parquet")
86 else:
87 return dir_here.joinpath(f"{self.fname}.snappy.parquet")
88
89 def show(self):
90 print("--- Project settings")
91 print(f"aws_profile = {self.aws_profile}")
92 print(f"n_col = {self.n_col}")
93 print(f"n_row = {self.n_row}")
94 print(f"row_group_size = {self.row_group_size}")
95 print(f"n_row_group = {self.n_row_group}")
96 print(f"file size = {self.fname}")
97 print(f"preview s3 file at: {self.s3path.console_url}")
98 print(f"preview local file at: file://{self.path}")
99
100
101def timeit(n: int, func):
102 """
103 Measure a callable function's average execution time. The function must
104 return a number (elapsed seconds).
105 """
106 lst = list()
107 for _ in range(n):
108 lst.append(func())
109 lst.sort()
110 if len(lst) >= 3:
111 lst = lst[1:-1] # ignore the highest and the lowest value
112 elapse = "%.6f" % (sum(lst) / len(lst),)
113 print(f"{n} times average elapse = {elapse}")
114
115
116def create_test_data():
117 df = pl.from_numpy(
118 np.random.randint(
119 1,
120 1000000,
121 (config.n_row, config.n_col),
122 ),
123 schema=[f"col_{i}" for i in range(1, 1 + config.n_col)],
124 )
125 with config.path.open("wb") as f:
126 df.write_parquet(f, compression="snappy", row_group_size=config.row_group_size)
127 print(f"file size = {config.path.size_in_text}")
128 config.s3path.upload_file(config.path, overwrite=True)
129
130
131def _polars_read_parquet(f):
132 df = pl.read_parquet(
133 f,
134 columns=["col_1", "col_2"],
135 ).filter(pl.col("col_1") <= 1000)
136 _ = df.shape
137 # print(df.shape)
138
139
140def _polars_scan_parquet(uri: str):
141 df = (
142 pl.scan_parquet(
143 uri,
144 )
145 .select(
146 "col_1",
147 "col_2",
148 )
149 .filter(pl.col("col_1") <= 1000)
150 .collect()
151 )
152 _ = df.shape
153 # print(df.shape)
154
155
156def _pandas_read_parquet_then_filter(f):
157 df = pd.read_parquet(f, columns=["col_1", "col_2"])
158 df = df[df["col_1"] <= 1000]
159 # print(df.shape)
160
161
162def _pandas_read_parquet_use_filter_while_reading(f):
163 df = pd.read_parquet(
164 f,
165 columns=["col_1", "col_2"],
166 filters=[
167 ("col_1", "<=", 1000),
168 ],
169 )
170 # print(df.shape)
171
172
173def download_s3_file():
174 with DateTimeTimer(
175 # display=True, # show info
176 display=False, # mute
177 ) as timer:
178 config.path.remove_if_exists()
179 config.bsm.s3_client.download_file(
180 config.s3path.bucket, config.s3path.key, str(config.path)
181 )
182 return timer.elapsed
183
184
185def polars_read_parquet():
186 with DateTimeTimer(
187 # display=True, # show info
188 display=False, # mute
189 ) as timer:
190 _polars_read_parquet(str(config.path))
191 return timer.elapsed
192
193
194def polars_scan_parquet():
195 with DateTimeTimer(
196 # display=True, # show info
197 display=False, # mute
198 ) as timer:
199 _polars_scan_parquet(str(config.path))
200 return timer.elapsed
201
202
203def pandas_read_parquet_then_filter():
204 with DateTimeTimer(
205 # display=True, # show info
206 display=False, # mute
207 ) as timer:
208 _pandas_read_parquet_then_filter(str(config.path))
209 return timer.elapsed
210
211
212def pandas_read_parquet_use_filter_while_reading():
213 with DateTimeTimer(
214 # display=True, # show info
215 display=False, # mute
216 ) as timer:
217 _pandas_read_parquet_use_filter_while_reading(str(config.path))
218 return timer.elapsed
219
220
221def polars_scan_parquet_from_s3_download_fileobj():
222 with DateTimeTimer(
223 # display=True, # show info
224 display=False, # mute
225 ) as timer:
226 buffer = io.BytesIO()
227 config.bsm.s3_client.download_fileobj(
228 config.s3path.bucket,
229 config.s3path.key,
230 buffer,
231 )
232 f = io.BytesIO(buffer.getvalue())
233 df = (
234 pl.read_parquet(
235 f,
236 )
237 .select(
238 "col_1",
239 "col_2",
240 )
241 .filter(pl.col("col_1") <= 1000)
242 )
243 _ = df.shape
244 # print(df.shape)
245 return timer.elapsed
246
247
248def pandas_scan_parquet_from_s3_download_fileobj():
249 with DateTimeTimer(
250 # display=True, # show info
251 display=False, # mute
252 ) as timer:
253 buffer = io.BytesIO()
254 config.bsm.s3_client.download_fileobj(
255 config.s3path.bucket,
256 config.s3path.key,
257 buffer,
258 )
259 f = io.BytesIO(buffer.getvalue())
260 df = pd.read_parquet(
261 f,
262 columns=["col_1", "col_2"],
263 filters=[
264 ("col_1", "<=", 1000),
265 ],
266 )
267 _ = df.shape
268 # print(df.shape)
269 return timer.elapsed
270
271
272# --- measure
273def measure_download_s3_file():
274 print("--- measure_download_s3_file")
275 timeit(3, download_s3_file)
276
277
278def measure_polars_read_parquet_from_local():
279 print("--- measure_polars_read_parquet_from_local")
280 timeit(10, polars_read_parquet)
281
282
283def measure_polars_scan_parquet_from_local():
284 print("--- measure_polars_scan_parquet_from_local")
285 timeit(10, polars_scan_parquet)
286
287
288def measure_pandas_read_parquet_then_filter():
289 print("--- measure_pandas_read_parquet_then_filter")
290 timeit(10, pandas_read_parquet_then_filter)
291
292
293def measure_pandas_read_parquet_use_filter_while_reading():
294 print("--- measure_pandas_read_parquet_use_filter_while_reading")
295 timeit(10, pandas_read_parquet_use_filter_while_reading)
296
297
298def measure_polars_scan_parquet_from_s3_download_fileobj():
299 print("--- measure_polars_scan_parquet_from_s3_download_fileobj")
300 timeit(3, polars_scan_parquet_from_s3_download_fileobj)
301
302
303def measure_pandas_scan_parquet_from_s3_download_fileobj():
304 print("--- measure_pandas_scan_parquet_from_s3_download_fileobj")
305 timeit(3, pandas_scan_parquet_from_s3_download_fileobj)
306
307
308# ------------------------------------------------------------------------------
309# measure benchmark
310# ------------------------------------------------------------------------------
311if IS_LAMBDA:
312 aws_profile = None
313else:
314 aws_profile = "awshsh_app_dev_us_east_1"
315config_5mb = Config(
316 aws_profile=aws_profile,
317 n_col=10,
318 n_row=100000,
319 row_group_size=10000,
320 fname="5MB",
321)
322
323config_50mb = Config(
324 aws_profile=aws_profile,
325 n_col=10,
326 n_row=1000000,
327 row_group_size=100000,
328 fname="50MB",
329)
330
331config_500mb = Config(
332 aws_profile=aws_profile,
333 n_col=10,
334 n_row=10000000,
335 row_group_size=1000000,
336 fname="500MB",
337)
338
339# config = config_5mb
340# config = config_50mb
341config = config_500mb
342
343context.attach_boto_session(config.bsm.boto_ses)
344
345config.show()
346print("--- Benchmark result")
347
348
349def lambda_handler(event=None, context=None):
350 # create_test_data()
351 # measure_download_s3_file()
352 measure_polars_read_parquet_from_local()
353 measure_polars_scan_parquet_from_local()
354 # measure_pandas_read_parquet_then_filter()
355 # measure_pandas_read_parquet_use_filter_while_reading()
356 # measure_polars_scan_parquet_from_s3_download_fileobj()
357 # measure_pandas_scan_parquet_from_s3_download_fileobj()
358
359lambda_handler()
依赖列表.