# py-meanflow
**Repository Path**: zxmh/py-meanflow
## Basic Information
- **Project Name**: py-meanflow
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: main
- **Homepage**: None
- **GVP Project**: No
## Statistics
- **Stars**: 0
- **Forks**: 0
- **Created**: 2025-09-15
- **Last Updated**: 2025-09-15
## Categories & Tags
**Categories**: Uncategorized
**Tags**: None
## README
# Mean Flows: PyTorch + GPU Implementation
This is a PyTorch+GPU re-implementation for the CIFAR-10 experiments in [Mean Flows for One-step Generative Modeling](https://arxiv.org/abs/2505.13447). The original experiments were done in JAX+TPU.
## Installation
This repo was tested in PyTorch 2.7.1 and uses `torch.compile`. Compilation may depend on PyTorch versions.
```
conda env create -f environment.yml
conda activate meanflow
```
## Demo
Run `demo.ipynb` for a demo of 1-step generation and FID evaluation. This demo should produce an FID below 2.9.
## Training
Run the script `cifar10_v1.sh` to train from scratch with 8 GPUs.
This is an improved configuration that approaches ~2.9 FID at 16,000 epochs (800k iterations with batch size 128×8). Training takes 0.21 s/iter on 8× H200 GPUs. The checkpoint in `demo.ipynb` (~2.80 FID) was trained with this script.
The original configuration used in the paper is in `cifar10_v0.sh`.
## Note on JVP
Users may be unfamiliar with the JVP (Jacobian-vector product) operation, on which MeanFlow is based. While JVP is straightforward to use in JAX, implementing it correctly in PyTorch is worth a closer look.
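As a minimal illustration (the function here is purely for demonstration), `torch.func.jvp` evaluates a function and its directional derivative in a single pass:

```python
import torch

def f(x):
    return x ** 2

# Evaluate f(3) and the directional derivative f'(3) * 1 together.
primal = torch.tensor(3.0)
tangent = torch.tensor(1.0)
value, dfdx = torch.func.jvp(f, (primal,), (tangent,))
# value == 9.0, dfdx == 6.0 (since f'(x) = 2x)
```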
#### DDP
The op `torch.func.jvp` does not support a DDP (`DistributedDataParallel`) wrapper. In your code, you may need to replace `model` with `model.module` so that `torch.func.jvp` can run. However, doing so may bypass the gradient synchronization normally handled by DDP, **with no error reported**.
In our code, we handle this with `synchronize_gradients(model)`, together with a sanity check `gradient_sanity_check`.
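The repo's `synchronize_gradients` is not reproduced in this README; a common pattern for manually averaging gradients across ranks looks roughly like the sketch below (an illustration under that assumption, not the repo's exact implementation):

```python
import torch
import torch.distributed as dist

def synchronize_gradients(model):
    """Manually all-reduce (average) gradients across ranks.

    Useful when calling into model.module for torch.func.jvp has
    bypassed DDP's automatic gradient synchronization.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-process run: nothing to synchronize
    world_size = dist.get_world_size()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
```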
#### Compilation
Both the memory usage and the speed of JVP can benefit greatly from compilation, in both JAX and PyTorch. In our code, this is done by:
```
compiled_train_step = torch.compile(
    train_step,
    disable=not args.compile,
)
```
where `train_step` is:
```
def train_step(model_without_ddp, *args, **kwargs):
    loss = model_without_ddp.forward_with_loss(*args, **kwargs)
    loss.backward(create_graph=False)
    return loss
```
Optionally, we also put `update_ema()` into `train_step` for compilation.
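`update_ema()` is the usual exponential-moving-average parameter update; a typical implementation is sketched below (the decay value and signature are assumptions, not necessarily the repo's):

```python
import torch

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """In-place EMA update: ema = decay * ema + (1 - decay) * online."""
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        # lerp_ moves ema_p toward p by a factor of (1 - decay)
        ema_p.lerp_(p, 1.0 - decay)
```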
#### Alternative to Compilation
If you don't want to compile (for example, because some of your ops are unsupported), we recommend computing `dudt` with `torch.func.jvp` under `torch.no_grad()`:
```
u_pred = u_func(z, t, r)
with torch.no_grad():
    _, dudt = torch.func.jvp(u_func, (z, t, r), (v, dtdt, drdt))
```
The function prediction `u_pred` is computed separately. In this way, computing `dudt` does not introduce substantial additional memory usage, and its time cost is roughly equivalent to a forward and backward pass. If you want both calls to `u_func` to share the same dropout masks, consider backing up the RNG states with `cpu_rng_state = torch.get_rng_state()` and `cuda_rng_state = torch.cuda.get_rng_state()` before one call of `u_func`, and restoring them with `torch.set_rng_state(cpu_rng_state)` and `torch.cuda.set_rng_state(cuda_rng_state)` before the other.
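For example, the RNG backup-and-restore trick makes two stochastic calls sample identical dropout masks (a CPU-only sketch; a real `u_func` replaces the bare dropout layer here):

```python
import torch

drop = torch.nn.Dropout(p=0.5)  # training mode by default
x = torch.ones(1000)

# Back up RNG states before the first call...
cpu_rng_state = torch.get_rng_state()
if torch.cuda.is_available():
    cuda_rng_state = torch.cuda.get_rng_state()

out1 = drop(x)

# ...and restore them so the second call samples the same masks.
torch.set_rng_state(cpu_rng_state)
if torch.cuda.is_available():
    torch.cuda.set_rng_state(cuda_rng_state)

out2 = drop(x)
# out1 and out2 now use identical dropout masks
```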
## References
This repo is based on the following repos:
* [Flow Matching repo](https://github.com/facebookresearch/flow_matching)
* [EDM repo](https://github.com/NVlabs/edm)
See also:
* [Our MeanFlow JAX repo](https://github.com/Gsunshine/meanflow) with ImageNet experiments.
* [A third-party MeanFlow PyTorch repo](https://github.com/zhuyu-cs/MeanFlow) with reproduced ImageNet results.