# FedSWA-main

**Repository Path**: yiming1606/FedSWA-main

## Basic Information

- **Project Name**: FedSWA-main
- **Description**: No description available
- **Primary Language**: Unknown
- **License**: Not specified
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2026-02-17
- **Last Updated**: 2026-02-25

## Categories & Tags

**Categories**: Uncategorized
**Tags**: None

## README

## Improving Generalization in Federated Learning with Highly Heterogeneous Data

This repository contains the implementation of **FedSWA** and **FedMoSWA**, federated learning algorithms designed to improve generalization in the presence of highly heterogeneous client data. The work is based on the paper:

> *Improving Generalization in Federated Learning with Highly Heterogeneous Data via Momentum-Based Stochastic Controlled Weight Averaging (FedSWA & FedMoSWA)*.

Both algorithms build on **SCAFFOLD** and **SAM-type methods**, aiming to find **flat global minima** that yield better test performance than FedAvg, FedSAM, and related approaches.

---

A single RTX 4090 or RTX 2080 Ti is sufficient for training!
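To make "highly heterogeneous client data" concrete, the sketch below partitions a labeled dataset across clients with a Dirichlet prior, the standard non-IID benchmark setup that this repository's `--alpha_value` flag controls. The function name `dirichlet_partition` and its arguments are illustrative, not part of this repository's API:

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients using a Dirichlet prior.

    Smaller alpha -> more skewed (non-IID) client label distributions;
    larger alpha -> closer to an IID split.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_idx = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.where(labels == c)[0])
        # Per-class client proportions drawn from Dirichlet(alpha, ..., alpha).
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for cid, part in enumerate(np.split(idx, cuts)):
            client_idx[cid].extend(part.tolist())
    return client_idx

# Toy example: 10 classes x 100 samples split across 5 clients.
labels = np.repeat(np.arange(10), 100)
parts = dirichlet_partition(labels, num_clients=5, alpha=0.1)
assert sum(len(p) for p in parts) == len(labels)
```

With `alpha=0.1` (as in the commands below) most clients end up dominated by a few classes, which is the regime the paper targets.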
## 🛠 Environment Setup

Create a conda environment with Python 3.8:

```
conda create -n fedswa python=3.8 -y
conda activate fedswa
```

### Requirements

* Python 3.8
* torch==2.4.1
* torchvision==0.19.1
* numpy
* ray==1.0.0
* tensorboardX==2.6.2.2
* peft==0.13.2
* transformers==4.46.3

You can install the dependencies with:

```bash
pip install -r requirements.txt
```

If downloads are slow, switch pip to a mirror:

```
mkdir -p ~/.config/pip
cat > ~/.config/pip/pip.conf <<'EOF'
[global]
index-url = https://pypi.tuna.tsinghua.edu.cn/simple
timeout = 120
EOF
```

---

Model weight download links:

* vit-base: https://huggingface.co/Junkang2/vit/tree/main
* swin_transformer: https://huggingface.co/Junkang2/swin_transformer/tree/main

## Dataset

Dataset download link:

* Tiny-ImageNet: https://huggingface.co/datasets/Junkang2/Tiny-ImageNet/upload/main

The code supports multiple datasets:

* **CIFAR-10 / CIFAR-100**
* **Tiny-ImageNet**
* **EMNIST**
* **MNIST**

Data will be automatically downloaded to the `./data` directory unless specified via `--datapath`.

---

## Usage

To run training with **FedSWA**, **FedMoSWA** (`MoFedSWA`), or **FedAvg** on **CIFAR100** using a **ResNet-18** backbone with Batch Normalization:

```bash
python main_FedSWA.py --alg FedSWA --lr 1e-1 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 301 --extname FedMuon --lr_decay 2 --gamma 0.85 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50

python main_FedSWA.py --alg MoFedSWA --lr 1e-1 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 301 --extname FedMuon --lr_decay 2 --gamma 0.5 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50

python main_FedSWA.py --alg FedAvg --lr 1e-1 --data_name CIFAR100 --alpha_value 0.1 --alpha 10 --epoch 301 --extname FedMuon --lr_decay 2 --gamma 0.5 --CNN resnet18 --E 5 --batch_size 50 --gpu 0 --p 1 --num_gpus_per 0.1 --normalization BN --selection 0.1 --print 0 --pre 1 --num_workers 100 --preprint 10 --rho 0.01 --pix 32 --lora 0 --K 50
```

**Note:** FedSWA's learning rate is typically set to twice FedAvg's (lr = 0.1 × 2).

---

## Key Arguments

* `--alg` : Algorithm to use (e.g., `FedAvg`, `FedSWA`, `FedMoSWA`, `SCAFFOLD+`).
* `--lr` : Initial learning rate.
* `--lr_decay` : Decay rate for learning rate scheduling.
* `--epoch` : Total number of training epochs.
* `--E` : Local epochs per communication round.
* `--batch_size` : Training batch size.
* `--alpha_value` : Dirichlet distribution parameter (controls data heterogeneity).
* `--alpha` : Momentum/variance-reduction hyperparameter.
* `--gamma` : Server momentum coefficient.
* `--selection` : Fraction of clients selected each round.
* `--CNN` : Model architecture (`resnet18`, `vgg11`, `lenet5`, etc.).
* `--normalization` : Normalization layer (`BN` for BatchNorm, `GN` for GroupNorm).
* `--gpu` : GPU index(es) to use.
* `--num_gpus_per` : Fraction of GPU resources allocated per client.

---

## Algorithm Overview

* **FedSWA**: Incorporates **stochastic weight averaging** and cyclical learning-rate schedules to find flatter global minima, outperforming FedAvg and FedSAM on heterogeneous data.
* **FedMoSWA**: Builds on FedSWA with **momentum-based variance reduction**, aligning local and global updates more effectively than SCAFFOLD.

---

## Results

Experiments on **CIFAR-10/100** and **Tiny-ImageNet** show that:

* **FedSWA** achieves better generalization than FedAvg and FedSAM.
* **FedMoSWA** further reduces client drift and improves optimization stability, especially under high heterogeneity.
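The stochastic-weight-averaging rule at the core of FedSWA can be sketched as a running average of model snapshots. The NumPy sketch below is only an illustration of that averaging rule (`swa_update` is a hypothetical helper, not this repository's API):

```python
import numpy as np

def swa_update(swa_weights, new_weights, n_averaged):
    """SWA running-average rule: w_swa <- (n * w_swa + w_new) / (n + 1)."""
    return {k: (n_averaged * swa_weights[k] + new_weights[k]) / (n_averaged + 1)
            for k in swa_weights}

# Average three per-round snapshots of a single-parameter "model".
snapshots = [{"w": np.array([1.0])}, {"w": np.array([3.0])}, {"w": np.array([5.0])}]
swa = snapshots[0]
for n, snap in enumerate(snapshots[1:], start=1):
    swa = swa_update(swa, snap, n)
print(swa["w"])  # mean of 1, 3, 5 -> [3.]
```

Averaging iterates collected along the training trajectory is what biases the solution toward flat regions of the loss landscape; FedSWA applies this idea to the global model across communication rounds.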
---

## Citation

If you use this code, please cite the paper:

```
@inproceedings{liu2025fedswa,
  title={Improving Generalization in Federated Learning with Highly Heterogeneous Data via Momentum-Based Stochastic Controlled Weight Averaging},
  author={Liu, Junkang and Liu, Yuanyuan and Shang, Fanhua and Liu, Hongying and Liu, Jin and Feng, Wei},
  booktitle={Proceedings of the 42nd International Conference on Machine Learning},
  year={2025}
}
```

---