Explore the Read Parquet Best Practice Using Polars#

Keywords: Read, Parquet, Python, Polars, AWS, S3, Lambda

目标#

Parquet 作为主流大数据存储格式, 它的 Columnar data format 和 Row Group 的两个特性可以大大提升读性能. 这里对这两个性能做一个简单的介绍:

  • Columnar data format: 在磁盘中一列中的数据都是顺序存储的, 所以你如果只需要数据集中的指定列, 你可以只读取这一列的数据, 而不需要读取其他列的数据. 这样可以大大减少磁盘 IO. 此外还有为数据类型相同的一列数据进行压缩等特点这里我们不展开说了, 因为它们主要是为了节省空间, 而不是节约读性能.

  • Row group: 在磁盘中数据被按照行分为了 Row Group (RG), 每一个 RG 在操作系统的磁盘上就是一个 Page. 而一个 RG 中包含了每个 Column 的统计信息, 例如最大和最小值. 这样如果你读数据的时候需要根据值进行 filter, 那么也可以跳过很多的 Row Group 从而大大减少了磁盘 IO.

本文档对 Parquet 的读操作进行了详细的测试, 希望能探索出最佳的读取 Parquet 文件的方法. 本文的探索主要针对位于本地磁盘上的 Parquet 文件, 和位于 AWS S3 上的 Parquet.

实验设计#

我们创建了一个测试数据集. 由于在大数据领域, 一个 Parquet 文件通常推荐保持在 100 MB ~ 1GB 之间, 一个 Row Group 通常在 64MB 左右. 所以我们创建了如下数据集:

  • 有 10M 行, 10 列.

  • 每一列都是 1 - 1M 的随机整数.

  • 数据分为 10 个 Row Group, 每个行组有 1M 行数据.

  • 数据经过 snappy 压缩后约 500MB.

我们希望读取 col 1 和 col2 (col 3 - 10 不要), 并只返回 col 1 的值在 1 ~ 1000 之间的结果.

这个文件我们一份放在了本地磁盘上, 一份放在了 AWS S3 上, 用于模拟生产环境中通过网络 IO 的应用场景.

  • 本地电脑室 MacBook M1 Pro, 32 G 内存, 10 核 CPU. 对本地文件的读取用 Mac 来进行.

  • 为了减少网络的影响, 我们使用了 Lambda Function 来进行 S3 中的数据的读取. 因为 Lambda Function 会被部署在和 S3 同一个 Region 的网络中, 并且通信走的是内网, 所以网络速度非常稳定, 减少了测量误差, 也更接近生产环境的情况.

我们使用了 polars, 和 pandas + pyarrow 两套常见的数据分析工具来进行测试.

学到的经验#

  • 对于本地文件, polars.scan_parquet > polars.read_parquet > pandas.read_parquet > pandas.read_parquet + use row group filter. 速度级都在 0.1 秒以内. 其中 scan_parquet 达到了 0.02 秒. 并且 scan_parquet 跟其他几个相比存在数量级的差距. 这是因为 scan_parquet 使用了 lazy load 的策略.

  • s3_client.get_object 是单线程, 在个人电脑上大约是 5-10MB / S, 而 s3_client.download_file 是多线程多个 CPU 一起下载, 速度要快很多.

  • s3path.open, smart_open, s3fs.open 这些都是用的单线程

  • 先将文件用 s3_client.download_file 下载到本地, 再进行读取的速度会比直接从 S3 读取要快很多, 因为目前没有多线程创建多个 buffer 来读取 S3 的方法.

  • AWS Lambda 读取 S3 的速度大约是 75MB - 100MB / S, 比个人电脑快多了.

  • 如果用 AWS Lambda 来写 ETL 程序, 可以先将文件用 download_file 下载到 /tmp, 然后再进行读取. Lambda 默认有 500 MB 磁盘大小的限制, 你可以将这个限制提高到最多 10G, 完全足够用了.

代码#

用于探索的 Python 脚本.

  1# -*- coding: utf-8 -*-
  2
  3"""
  4此脚本用于探索读取 Parquet 文件的最佳策略. 其中包含了创建测试数据集, 以及用不同的方式读取,
  5测量时间和内存消耗的代码.
  6
  7Conclusion:
  8
  9- For local parquet file:
 10- 对于本地文件 polars.scan_parquet > polars.read_parquet > pandas.read_parquet
 11    > pandas.read_parquet + use row group filter. 并且 scan_parquet 跟其他几个相比
 12    存在数量级的差距. 这是因为 scan_parquet
 13        使用了 lazy load 的策略.
 14- s3_client.get_object 是单线程, 在个人电脑上大约是 5-10MB / S, 而 s3_client.download_file
 15    是多线程多个 CPU 一起下载, 大约几个核心速度就是原来的几倍. 但是多线程 download 时每个线程
 16    下载的最小 chunk 是 8MB, 也就是说你的文件要大于 8MB * CPU 核心数才能跑满带宽.
 17- s3path.open, smart_open, s3fs.open 这些都是用的单线程, 和 s3_client.get_object 一样.
 18- 如果先将文件用 s3_client.download_file 下载到本地, 再进行读取的速度会比直接从 S3 读取要快很多,
 19    因为目前没有多线程创建多个 buffer 来读取 S3 的方法. 如果你用 s3_client.download_fileobj
 20    方法, 你还能直接将数据写入 buffer 中从而避免了将文件写入磁盘的过程, 速度会更快.
 21- AWS Lambda 读取 S3 的速度大约是 75MB - 100MB / S, 比个人电脑快多了.
 22- AWS Lambda (10GB 内存) 从 S3 上下载 500MB 文件 (多线程) 的速度大约是 6.5 秒.
 23- 用上面的技巧用 AWS Lambda 读 500MB parquet file 并用 filter 的速度大约是 6.5 秒,
 24    也就是说 parse 数据以及 filter 的速度跟 IO 相比可以忽略不计.
 25"""
 26
 27import typing as T
 28import os
 29import io
 30import math
 31import dataclasses
 32from functools import cached_property
 33
 34# 注意, 一般 pl 和 pandas 在 Lambda 上只能 2 选 1. 两个都安装很容易超过 250MB 依赖的限制.
 35import numpy as np
 36import pandas as pd
 37import polars as pl
 38from pathlib_mate import Path
 39from s3pathlib import S3Path, context
 40from boto_session_manager import BotoSesManager
 41from fixa.timer import DateTimeTimer
 42
 43dir_here = Path.dir_here(__file__)
 44IS_LAMBDA = "AWS_LAMBDA_FUNCTION_NAME" in os.environ
 45
 46
 47@dataclasses.dataclass
 48class Config:
 49    """
 50    :param aws_profile: the aws profile you want to use
 51    :param n_col: number of columns in the test dataframe
 52    :param n_row: number of rows in the test dataframe
 53    :param row_group_size: number of rows in each parquet row group
 54    :param n_row_group: number of row groups in the parquet file
 55    """
 56
 57    n_col: int = dataclasses.field()
 58    n_row: int = dataclasses.field()
 59    row_group_size: int = dataclasses.field()
 60    fname: str = dataclasses.field()
 61    aws_profile: T.Optional[str] = dataclasses.field(default=None)
 62
 63    @property
 64    def n_row_group(self):
 65        return int(math.ceil(self.n_row / self.row_group_size))
 66
 67    @cached_property
 68    def bsm(self):
 69        return BotoSesManager(profile_name=self.aws_profile)
 70
 71    @property
 72    def s3dir_root(self) -> S3Path:
 73        return S3Path(
 74            f"s3://{self.bsm.aws_account_id}-{self.bsm.aws_region}-data"
 75            "/projects/explore_the_read_parquet_best_practice_using_polars/"
 76        ).to_dir()
 77
 78    @property
 79    def s3path(self) -> S3Path:
 80        return self.s3dir_root.joinpath(f"{self.fname}.snappy.parquet")
 81
 82    @property
 83    def path(self) -> Path:
 84        if IS_LAMBDA:
 85            return Path("/tmp").joinpath(f"{self.fname}.snappy.parquet")
 86        else:
 87            return dir_here.joinpath(f"{self.fname}.snappy.parquet")
 88
 89    def show(self):
 90        print("--- Project settings")
 91        print(f"aws_profile = {self.aws_profile}")
 92        print(f"n_col = {self.n_col}")
 93        print(f"n_row = {self.n_row}")
 94        print(f"row_group_size = {self.row_group_size}")
 95        print(f"n_row_group = {self.n_row_group}")
 96        print(f"file size = {self.fname}")
 97        print(f"preview s3 file at: {self.s3path.console_url}")
 98        print(f"preview local file at: file://{self.path}")
 99
100
101def timeit(n: int, func):
102    """
103    Measure a callable function's average execution time. The function must
104    return a number (elapsed seconds).
105    """
106    lst = list()
107    for _ in range(n):
108        lst.append(func())
109    lst.sort()
110    if len(lst) >= 3:
111        lst = lst[1:-1]  # ignore the highest and the lowest value
112    elapse = "%.6f" % (sum(lst) / len(lst),)
113    print(f"{n} times average elapse = {elapse}")
114
115
116def create_test_data():
117    df = pl.from_numpy(
118        np.random.randint(
119            1,
120            1000000,
121            (config.n_row, config.n_col),
122        ),
123        schema=[f"col_{i}" for i in range(1, 1 + config.n_col)],
124    )
125    with config.path.open("wb") as f:
126        df.write_parquet(f, compression="snappy", row_group_size=config.row_group_size)
127    print(f"file size = {config.path.size_in_text}")
128    config.s3path.upload_file(config.path, overwrite=True)
129
130
131def _polars_read_parquet(f):
132    df = pl.read_parquet(
133        f,
134        columns=["col_1", "col_2"],
135    ).filter(pl.col("col_1") <= 1000)
136    _ = df.shape
137    # print(df.shape)
138
139
140def _polars_scan_parquet(uri: str):
141    df = (
142        pl.scan_parquet(
143            uri,
144        )
145        .select(
146            "col_1",
147            "col_2",
148        )
149        .filter(pl.col("col_1") <= 1000)
150        .collect()
151    )
152    _ = df.shape
153    # print(df.shape)
154
155
156def _pandas_read_parquet_then_filter(f):
157    df = pd.read_parquet(f, columns=["col_1", "col_2"])
158    df = df[df["col_1"] <= 1000]
159    # print(df.shape)
160
161
162def _pandas_read_parquet_use_filter_while_reading(f):
163    df = pd.read_parquet(
164        f,
165        columns=["col_1", "col_2"],
166        filters=[
167            ("col_1", "<=", 1000),
168        ],
169    )
170    # print(df.shape)
171
172
173def download_s3_file():
174    with DateTimeTimer(
175        # display=True, # show info
176        display=False,  # mute
177    ) as timer:
178        config.path.remove_if_exists()
179        config.bsm.s3_client.download_file(
180            config.s3path.bucket, config.s3path.key, str(config.path)
181        )
182    return timer.elapsed
183
184
185def polars_read_parquet():
186    with DateTimeTimer(
187        # display=True, # show info
188        display=False,  # mute
189    ) as timer:
190        _polars_read_parquet(str(config.path))
191    return timer.elapsed
192
193
194def polars_scan_parquet():
195    with DateTimeTimer(
196        # display=True, # show info
197        display=False,  # mute
198    ) as timer:
199        _polars_scan_parquet(str(config.path))
200    return timer.elapsed
201
202
203def pandas_read_parquet_then_filter():
204    with DateTimeTimer(
205        # display=True, # show info
206        display=False,  # mute
207    ) as timer:
208        _pandas_read_parquet_then_filter(str(config.path))
209    return timer.elapsed
210
211
212def pandas_read_parquet_use_filter_while_reading():
213    with DateTimeTimer(
214        # display=True, # show info
215        display=False,  # mute
216    ) as timer:
217        _pandas_read_parquet_use_filter_while_reading(str(config.path))
218    return timer.elapsed
219
220
221def polars_scan_parquet_from_s3_download_fileobj():
222    with DateTimeTimer(
223        # display=True, # show info
224        display=False,  # mute
225    ) as timer:
226        buffer = io.BytesIO()
227        config.bsm.s3_client.download_fileobj(
228            config.s3path.bucket,
229            config.s3path.key,
230            buffer,
231        )
232        f = io.BytesIO(buffer.getvalue())
233        df = (
234            pl.read_parquet(
235                f,
236            )
237            .select(
238                "col_1",
239                "col_2",
240            )
241            .filter(pl.col("col_1") <= 1000)
242        )
243        _ = df.shape
244        # print(df.shape)
245    return timer.elapsed
246
247
248def pandas_scan_parquet_from_s3_download_fileobj():
249    with DateTimeTimer(
250        # display=True, # show info
251        display=False,  # mute
252    ) as timer:
253        buffer = io.BytesIO()
254        config.bsm.s3_client.download_fileobj(
255            config.s3path.bucket,
256            config.s3path.key,
257            buffer,
258        )
259        f = io.BytesIO(buffer.getvalue())
260        df = pd.read_parquet(
261            f,
262            columns=["col_1", "col_2"],
263            filters=[
264                ("col_1", "<=", 1000),
265            ],
266        )
267        _ = df.shape
268        # print(df.shape)
269    return timer.elapsed
270
271
272# --- measure
273def measure_download_s3_file():
274    print("--- measure_download_s3_file")
275    timeit(3, download_s3_file)
276
277
278def measure_polars_read_parquet_from_local():
279    print("--- measure_polars_read_parquet_from_local")
280    timeit(10, polars_read_parquet)
281
282
283def measure_polars_scan_parquet_from_local():
284    print("--- measure_polars_scan_parquet_from_local")
285    timeit(10, polars_scan_parquet)
286
287
288def measure_pandas_read_parquet_then_filter():
289    print("--- measure_pandas_read_parquet_then_filter")
290    timeit(10, pandas_read_parquet_then_filter)
291
292
293def measure_pandas_read_parquet_use_filter_while_reading():
294    print("--- measure_pandas_read_parquet_use_filter_while_reading")
295    timeit(10, pandas_read_parquet_use_filter_while_reading)
296
297
298def measure_polars_scan_parquet_from_s3_download_fileobj():
299    print("--- measure_polars_scan_parquet_from_s3_download_fileobj")
300    timeit(3, polars_scan_parquet_from_s3_download_fileobj)
301
302
303def measure_pandas_scan_parquet_from_s3_download_fileobj():
304    print("--- measure_pandas_scan_parquet_from_s3_download_fileobj")
305    timeit(3, pandas_scan_parquet_from_s3_download_fileobj)
306
307
308# ------------------------------------------------------------------------------
309# measure benchmark
310# ------------------------------------------------------------------------------
311if IS_LAMBDA:
312    aws_profile = None
313else:
314    aws_profile = "awshsh_app_dev_us_east_1"
315config_5mb = Config(
316    aws_profile=aws_profile,
317    n_col=10,
318    n_row=100000,
319    row_group_size=10000,
320    fname="5MB",
321)
322
323config_50mb = Config(
324    aws_profile=aws_profile,
325    n_col=10,
326    n_row=1000000,
327    row_group_size=100000,
328    fname="50MB",
329)
330
331config_500mb = Config(
332    aws_profile=aws_profile,
333    n_col=10,
334    n_row=10000000,
335    row_group_size=1000000,
336    fname="500MB",
337)
338
339# config = config_5mb
340# config = config_50mb
341config = config_500mb
342
343context.attach_boto_session(config.bsm.boto_ses)
344
345config.show()
346print("--- Benchmark result")
347
348
349def lambda_handler(event=None, context=None):
350    # create_test_data()
351    # measure_download_s3_file()
352    measure_polars_read_parquet_from_local()
353    measure_polars_scan_parquet_from_local()
354    # measure_pandas_read_parquet_then_filter()
355    # measure_pandas_read_parquet_use_filter_while_reading()
356    # measure_polars_scan_parquet_from_s3_download_fileobj()
357    # measure_pandas_scan_parquet_from_s3_download_fileobj()
358
359lambda_handler()

依赖列表.