JAX-Fluids 2.0: Towards HPC for differentiable CFD of compressible two-phase flows
Article in a journal
Area Computational Fluid Dynamics
Author(s) Deniz A. Bezgin, Aaron B. Buhendwa, Nikolaus A. Adams
Published in Computer Physics Communications
Year 2025
Abstract In our effort to facilitate machine learning-assisted computational fluid dynamics (CFD), we introduce the second iteration of JAX-Fluids. JAX-Fluids is a Python-based fully-differentiable CFD solver designed for compressible single- and two-phase flows. In this work, the first version is extended to incorporate high-performance computing (HPC) capabilities. We introduce a parallelization strategy utilizing JAX primitive operations that scales efficiently on GPU (up to 512 NVIDIA A100 graphics cards) and TPU (up to 1024 TPU v3 cores) HPC systems. We further demonstrate stable parallel computation of automatic differentiation gradients across extended integration trajectories. The new code version offers enhanced two-phase flow modeling capabilities. In particular, a five-equation diffuse-interface model is incorporated which complements the level-set sharp-interface model. Additional algorithmic improvements include positivity-preserving limiters for increased robustness, support for stretched Cartesian meshes, refactored I/O handling, comprehensive post-processing routines, and an updated list of state-of-the-art high-order numerical discretization schemes. We verify newly added numerical models by showcasing simulation results for single- and two-phase flows, including turbulent boundary layer and channel flows, air-helium shock bubble interactions, and air-water shock drop interactions.
PROGRAM SUMMARY
Program Title: JAX-Fluids
CPC Library link to program files: https://doi.org/10.17632/pzvkwn5s6p.2
Developer's repository link: https://github.com/tumaer/JAXFLUIDS
Licensing provisions: GPLv3
Programming language: Python, JAX
Supplementary material: Source code, example scripts, videos
Journal reference of previous version: D.A. Bezgin, A.B. Buhendwa, N.A. Adams, JAX-Fluids: A fully-differentiable high-order computational fluid dynamics solver for compressible two-phase flows, Computer Physics Communications 282 (2022) 108527.
Does the new version supersede the previous version?: Yes
Reasons for the new version: New features and updates of the CFD solver
Summary of revisions:
• JAX primitives-based parallelization for GPU and TPU clusters
• Automatic differentiation through distributed simulations
• Diffuse-interface model for two-phase flows
• Positivity-preserving interpolation and flux limiters
• Support for stretched Cartesian meshes
• Extended list of numerical discretization schemes
• Performance improvements
• Revised I/O handling
Nature of problem: The compressible Navier-Stokes equations describe continuum-scale fluid flows which may exhibit complex phenomena such as shock waves, material interfaces, and turbulence. The accurate numerical solution of fluid flows is computationally expensive and, therefore, requires high-performance computing (HPC) architectures. To this end, machine learning (ML), in particular differentiable programming, is continuously being explored as a tool to accelerate conventional computational fluid dynamics (CFD). With the second iteration of JAX-Fluids, we provide a comprehensive differentiable CFD code that scales efficiently on HPC systems, seamlessly integrates ML models, and accurately simulates complex flow physics with high-order low-dissipative numerical methods.
Solution method: JAX-Fluids is a finite-volume solver which uses high-order low-dissipative shock capturing schemes in combination with approximate Riemann solvers. Two-phase flows can be simulated using the sharp-interface level-set method or the diffuse-interface five-equation model. The code is written in Python and builds on the JAX library. The JAX backend allows the computation of automatic differentiation gradients. We use a homogeneous domain decomposition ansatz to implement the parallelization. An object-oriented programming style and a modular design philosophy allow exchanging numerical schemes and integrating custom subroutines.
Additional comments including restrictions and unusual features: JAX-Fluids runs on CPUs, GPUs, and TPUs in single- and multi-device settings. JAX-Fluids requires open-source third-party Python libraries which are automatically installed. The solver has been tested on Linux and macOS operating systems.
AD Tools JAX
BibTeX
@ARTICLE{
Bezgin2025JFT,
title = "{JAX-Fluids} 2.0: Towards {HPC} for differentiable {CFD} of compressible two-phase
flows",
journal = "Computer Physics Communications",
volume = "308",
pages = "109433",
year = "2025",
issn = "0010-4655",
doi = "10.1016/j.cpc.2024.109433",
author = "Deniz A. Bezgin and Aaron B. Buhendwa and Nikolaus A. Adams",
keywords = "Computational fluid dynamics, Machine learning, Differential programming,
High-performance computing, JAX, Navier-Stokes equations, Turbulence, Level-set, Diffuse-interface,
Two-phase flows",
abstract = "In our effort to facilitate machine learning-assisted computational fluid dynamics
(CFD), we introduce the second iteration of JAX-Fluids. JAX-Fluids is a Python-based
fully-differentiable CFD solver designed for compressible single- and two-phase flows. In this work,
the first version is extended to incorporate high-performance computing (HPC) capabilities. We
introduce a parallelization strategy utilizing JAX primitive operations that scales efficiently on
GPU (up to 512 NVIDIA A100 graphics cards) and TPU (up to 1024 TPU v3 cores) HPC systems. We further
demonstrate stable parallel computation of automatic differentiation gradients across extended
integration trajectories. The new code version offers enhanced two-phase flow modeling capabilities.
In particular, a five-equation diffuse-interface model is incorporated which complements the
level-set sharp-interface model. Additional algorithmic improvements include positivity-preserving
limiters for increased robustness, support for stretched Cartesian meshes, refactored I/O handling,
comprehensive post-processing routines, and an updated list of state-of-the-art high-order numerical
discretization schemes. We verify newly added numerical models by showcasing simulation results for
single- and two-phase flows, including turbulent boundary layer and channel flows, air-helium shock
bubble interactions, and air-water shock drop interactions. PROGRAM SUMMARY Program Title:
JAX-Fluids CPC Library link to program files: https://doi.org/10.17632/pzvkwn5s6p.2 Developer's
repository link: https://github.com/tumaer/JAXFLUIDS Licensing provisions: GPLv3 Programming
language: Python, JAX Supplementary material: Source code, example scripts, videos Journal reference
of previous version: D.A. Bezgin, A.B. Buhendwa, N.A. Adams, JAX-Fluids: A fully-differentiable
high-order computational fluid dynamics solver for compressible two-phase flows, Computer Physics
Communications 282 (2022) 108527. Does the new version supersede the previous version?: Yes Reasons
for the new version: New features and updates of the CFD solver Summary of revisions: • JAX
primitives-based parallelization for GPU and TPU clusters • Automatic differentiation through
distributed simulations • Diffuse-interface model for two-phase flows • Positivity-preserving
interpolation and flux limiters • Support for stretched Cartesian meshes • Extended list of
numerical discretization schemes • Performance improvements • Revised I/O handling Nature of
problem: The compressible Navier-Stokes equations describe continuum-scale fluid flows which may
exhibit complex phenomena such as shock waves, material interfaces, and turbulence. The accurate
numerical solution of fluid flows is computationally expensive and, therefore, requires
high-performance computing (HPC) architectures. To this end, machine learning (ML), in particular
differentiable programming, is continuously being explored as a tool to accelerate conventional
computational fluid dynamics (CFD). With the second iteration of JAX-Fluids, we provide a
comprehensive differentiable CFD code that scales efficiently on HPC systems, seamlessly integrates
ML models, and accurately simulates complex flow physics with high-order low-dissipative numerical
methods. Solution method: JAX-Fluids is a finite-volume solver which uses high-order low-dissipative
shock capturing schemes in combination with approximate Riemann solvers. Two-phase flows can be
simulated using the sharp-interface level-set method or the diffuse-interface five-equation model.
The code is written in Python and builds on the JAX library. The JAX backend allows the computation
of automatic differentiation gradients. We use a homogeneous domain decomposition ansatz to implement
the parallelization. An object-oriented programming style and a modular design philosophy allow
exchanging numerical schemes and integrating custom subroutines. Additional comments including
restrictions and unusual features: JAX-Fluids runs on CPUs, GPUs, and TPUs in single- and
multi-device settings. JAX-Fluids requires open-source third-party Python libraries which are
automatically installed. The solver has been tested on Linux and macOS operating systems.",
ad_area = "Computational Fluid Dynamics",
ad_tools = "JAX"
}