Publication: JAX-Fluids 2.0: Towards HPC for differentiable CFD of compressible two-phase flows


- Article in a journal -
 

Area
Computational Fluid Dynamics

Author(s)
Deniz A. Bezgin , Aaron B. Buhendwa , Nikolaus A. Adams

Published in
Computer Physics Communications

Year
2025

Abstract
In our effort to facilitate machine learning-assisted computational fluid dynamics (CFD), we introduce the second iteration of JAX-Fluids. JAX-Fluids is a Python-based fully-differentiable CFD solver designed for compressible single- and two-phase flows. In this work, the first version is extended to incorporate high-performance computing (HPC) capabilities. We introduce a parallelization strategy utilizing JAX primitive operations that scales efficiently on GPU (up to 512 NVIDIA A100 graphics cards) and TPU (up to 1024 TPU v3 cores) HPC systems. We further demonstrate stable parallel computation of automatic differentiation gradients across extended integration trajectories. The new code version offers enhanced two-phase flow modeling capabilities. In particular, a five-equation diffuse-interface model is incorporated which complements the level-set sharp-interface model. Additional algorithmic improvements include positivity-preserving limiters for increased robustness, support for stretched Cartesian meshes, refactored I/O handling, comprehensive post-processing routines, and an updated list of state-of-the-art high-order numerical discretization schemes. We verify newly added numerical models by showcasing simulation results for single- and two-phase flows, including turbulent boundary layer and channel flows, air-helium shock bubble interactions, and air-water shock drop interactions.

PROGRAM SUMMARY
Program Title: JAX-Fluids
CPC Library link to program files: https://doi.org/10.17632/pzvkwn5s6p.2
Developer's repository link: https://github.com/tumaer/JAXFLUIDS
Licensing provisions: GPLv3
Programming language: Python, JAX
Supplementary material: Source code, example scripts, videos
Journal reference of previous version: D.A. Bezgin, A.B. Buhendwa, N.A. Adams, JAX-Fluids: A fully-differentiable high-order computational fluid dynamics solver for compressible two-phase flows, Computer Physics Communications 282 (2022) 108527.
Does the new version supersede the previous version?: Yes
Reasons for the new version: New features and updates of the CFD solver
Summary of revisions:
• JAX primitives-based parallelization for GPU and TPU clusters
• Automatic differentiation through distributed simulations
• Diffuse-interface model for two-phase flows
• Positivity-preserving interpolation and flux limiters
• Support for stretched Cartesian meshes
• Extended list of numerical discretization schemes
• Performance improvements
• Revised I/O handling
Nature of problem: The compressible Navier-Stokes equations describe continuum-scale fluid flows which may exhibit complex phenomena such as shock waves, material interfaces, and turbulence. The accurate numerical solution of fluid flows is computationally expensive and, therefore, requires high-performance computing (HPC) architectures. To this end, machine learning (ML), in particular differentiable programming, is continuously being explored as a tool to accelerate conventional computational fluid dynamics (CFD). With the second iteration of JAX-Fluids, we provide a comprehensive differentiable CFD code that scales efficiently on HPC systems, seamlessly integrates ML models, and accurately simulates complex flow physics with high-order low-dissipative numerical methods.
Solution method: JAX-Fluids is a finite-volume solver which uses high-order low-dissipative shock-capturing schemes in combination with approximate Riemann solvers. Two-phase flows can be simulated using the sharp-interface level-set method or the diffuse-interface five-equation model. The code is written in Python and builds on the JAX library. The JAX backend allows the computation of automatic differentiation gradients. We use a homogeneous domain decomposition ansatz to implement the parallelization. An object-oriented programming style and a modular design philosophy allow exchanging numerical schemes and integrating custom subroutines.
Additional comments including restrictions and unusual features: JAX-Fluids runs on CPUs, GPUs, and TPUs in single- and multi-device settings. JAX-Fluids requires open-source third-party Python libraries which are automatically installed. The solver has been tested on Linux and macOS operating systems.
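The multi-device setting and the domain decomposition named under Solution method can be illustrated with a small sketch. It assumes a 1D periodic grid, a first-order upwind scheme for linear advection with Courant number `c` in (0, 1], and a one-cell halo; the function names and the `XLA_FLAGS` setting that emulates four devices on a CPU are illustrative assumptions, not JAX-Fluids code. Each device owns an equal block of cells and receives its left ghost cell from its neighbour through the JAX collective `jax.lax.ppermute` inside `jax.pmap`.

```python
import os

# Assumption for illustration: emulate 4 devices on a CPU-only machine.
# This only takes effect if set before JAX initializes its backends.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

import jax
import jax.numpy as jnp

AXIS = "devices"  # name of the mapped device axis used by the collective


def local_upwind_step(u_local, c, n_dev):
    """Advance one subdomain by one upwind step, u_i <- u_i - c (u_i - u_{i-1})."""
    # Ring permutation: device i sends its right-most cell to device i + 1,
    # so every device receives the left ghost cell it needs (periodic domain).
    perm = [(i, (i + 1) % n_dev) for i in range(n_dev)]
    left_ghost = jax.lax.ppermute(u_local[-1:], axis_name=AXIS, perm=perm)
    u_ext = jnp.concatenate([left_ghost, u_local])  # one-cell halo on the left
    return u_local - c * (u_ext[1:] - u_ext[:-1])


def run_decomposed(u0, c, n_steps):
    """Split u0 into equal blocks, one per device, and time-step them in parallel."""
    n_dev = jax.local_device_count()
    if u0.shape[0] % n_dev != 0:
        raise ValueError("number of cells must be divisible by the device count")
    u = u0.reshape(n_dev, -1)  # homogeneous decomposition: equal cells per device
    step = jax.pmap(lambda v: local_upwind_step(v, c, n_dev), axis_name=AXIS)
    for _ in range(n_steps):
        u = step(u)
    return u.reshape(-1)


def run_single_device(u0, c, n_steps):
    """Reference solution on one device; jnp.roll supplies the periodic neighbour."""
    u = u0
    for _ in range(n_steps):
        u = u - c * (u - jnp.roll(u, 1))
    return u
```

For any `u0` whose length is divisible by the device count, `run_decomposed(u0, 0.5, 20)` and `run_single_device(u0, 0.5, 20)` perform the same arithmetic and should agree to floating-point round-off. High-order stencils need wider halos, exchanged in every spatial direction; the single-cell, one-directional exchange here is only the minimal case.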

AD Tools
JAX
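Since JAX is the AD tool here, a minimal sketch of differentiating through a time-integration trajectory may help. It assumes a 1D linear advection problem, a first-order upwind finite-volume update, periodic boundaries and a mean-squared final-state loss; all names are hypothetical and this is not JAX-Fluids' API. The time loop is written with `jax.lax.scan` so that `jax.grad` can differentiate the whole trajectory with respect to the Courant number `c`.

```python
import jax
import jax.numpy as jnp

# Double precision makes the finite-difference check below meaningful.
jax.config.update("jax_enable_x64", True)


def upwind_step(u, c):
    """First-order upwind update for u_t + a u_x = 0 (a > 0), periodic boundaries."""
    return u - c * (u - jnp.roll(u, 1))


def simulate(u0, c, n_steps):
    """Integrate n_steps time steps; lax.scan keeps the loop traceable and differentiable."""
    def body(u, _):
        return upwind_step(u, c), None

    u_final, _ = jax.lax.scan(body, u0, None, length=n_steps)
    return u_final


def loss(c, u0, u_target, n_steps):
    """Mean-squared mismatch between the simulated final state and a target state."""
    return jnp.mean((simulate(u0, c, n_steps) - u_target) ** 2)


# Gradient of the loss with respect to its first argument, the Courant number c.
grad_loss = jax.grad(loss)


def central_difference(c, u0, u_target, n_steps, h=1e-6):
    """Finite-difference estimate of d(loss)/dc, used to check the AD gradient."""
    return (loss(c + h, u0, u_target, n_steps) - loss(c - h, u0, u_target, n_steps)) / (2 * h)
```

For instance, with a Gaussian pulse `u0` on 64 cells and `u_target = simulate(u0, 0.4, 100)`, `grad_loss(0.3, u0, u_target, 100)` should match `central_difference(0.3, u0, u_target, 100)` to several significant digits. For much longer trajectories, wrapping the step in `jax.checkpoint` is a common way to trade recomputation for memory in reverse-mode AD.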

BibTeX
@ARTICLE{
         Bezgin2025JFT,
       title = "{JAX-Fluids} 2.0: Towards {HPC} for differentiable {CFD} of compressible two-phase
         flows",
       journal = "Computer Physics Communications",
       volume = "308",
       pages = "109433",
       year = "2025",
       issn = "0010-4655",
       doi = "10.1016/j.cpc.2024.109433",
       author = "Deniz A. Bezgin and Aaron B. Buhendwa and Nikolaus A. Adams",
       keywords = "Computational fluid dynamics, Machine learning, Differential programming,
         High-performance computing, JAX, Navier-Stokes equations, Turbulence, Level-set, Diffuse-interface,
         Two-phase flows",
       abstract = "In our effort to facilitate machine learning-assisted computational fluid dynamics
         (CFD), we introduce the second iteration of JAX-Fluids. JAX-Fluids is a Python-based
         fully-differentiable CFD solver designed for compressible single- and two-phase flows. In this work,
         the first version is extended to incorporate high-performance computing (HPC) capabilities. We
         introduce a parallelization strategy utilizing JAX primitive operations that scales efficiently on
         GPU (up to 512 NVIDIA A100 graphics cards) and TPU (up to 1024 TPU v3 cores) HPC systems. We further
         demonstrate stable parallel computation of automatic differentiation gradients across extended
         integration trajectories. The new code version offers enhanced two-phase flow modeling capabilities.
         In particular, a five-equation diffuse-interface model is incorporated which complements the
         level-set sharp-interface model. Additional algorithmic improvements include positivity-preserving
         limiters for increased robustness, support for stretched Cartesian meshes, refactored I/O handling,
         comprehensive post-processing routines, and an updated list of state-of-the-art high-order numerical
         discretization schemes. We verify newly added numerical models by showcasing simulation results for
         single- and two-phase flows, including turbulent boundary layer and channel flows, air-helium shock
         bubble interactions, and air-water shock drop interactions. PROGRAM SUMMARY Program Title:
         JAX-Fluids CPC Library link to program files: https://doi.org/10.17632/pzvkwn5s6p.2 Developer's
         repository link: https://github.com/tumaer/JAXFLUIDS Licensing provisions: GPLv3 Programming
         language: Python, JAX Supplementary material: Source code, example scripts, videos Journal reference
         of previous version: D.A. Bezgin, A.B. Buhendwa, N.A. Adams, JAX-Fluids: A fully-differentiable
         high-order computational fluid dynamics solver for compressible two-phase flows, Computer Physics
         Communications 282 (2022) 108527. Does the new version supersede the previous version?: Yes Reasons
         for the new version: New features and updates of the CFD solver Summary of revisions:•JAX
         primitives-based parallelization for GPU and TPU clusters•Automatic differentiation through
         distributed simulations•Diffuse-interface model for two-phase flows•Positivity-preserving
         interpolation and flux limiters•Support for stretched Cartesian meshes•Extended list of
         numerical discretization schemes•Performance improvements•Revised I/O handling Nature of
         problem: The compressible Navier-Stokes equations describe continuum-scale fluid flows which may
         exhibit complex phenomena such as shock waves, material interfaces, and turbulence. The accurate
         numerical solution of fluid flows is computationally expensive and, therefore, requires
         high-performance computing (HPC) architectures. To this end, machine learning (ML), in particular
         differentiable programming, is continuously being explored as a tool to accelerate conventional
         computational fluid dynamics (CFD). With the second iteration of JAX-Fluids, we provide a
         comprehensive differentiable CFD code that scales efficiently on HPC systems, seamlessly integrates
         ML models, and accurately simulates complex flow physics with high-order low-dissipative numerical
         methods. Solution method: JAX-Fluids is a finite-volume solver which uses high-order low-dissipative
         shock capturing schemes in combination with approximate Riemann solvers. Two-phase flows can be
         simulated using the sharp-interface level-set method or the diffuse-interface five-equation model.
         The code is written in Python and builds on the JAX library. The JAX backend allows the computation
         of automatic differentiation gradients. We use a homogeneous domain decomposition ansatz to implement
         the parallelization. An object-oriented programming style and a modular design philosophy allow
         exchanging numerical schemes and integrating custom subroutines. Additional comments including
         restrictions and unusual features: JAX-Fluids runs on CPUs, GPUs, and TPUs in single- and
         multi-device settings. JAX-Fluids requires open-source third-party Python libraries which are
         automatically installed. The solver has been tested on Linux and macOS operating systems.",
       ad_area = "Computational Fluid Dynamics",
       ad_tools = "JAX"
}



Contact:
autodiff.org