Tackling Skewed Data Challenges in Decentralized Machine Learning


This work addresses the critical issue of skewed (non-IID) data in decentralized machine learning: it examines communication bottlenecks and data skewness, and proposes approaches for effective decentralized learning over skewed datasets.


Uploaded on Jul 22, 2024




Presentation Transcript


  1. The Non-IID Data Quagmire of Decentralized Machine Learning ICML 2020 Kevin Hsieh, Amar Phanishayee, Onur Mutlu, Phillip Gibbons

  2. ML Training with Decentralized Data Geo-Distributed Learning Federated Learning Data Sovereignty and Privacy 2

  3. Major Challenges in Decentralized ML Geo-Distributed Learning Federated Learning Challenge 1: Communication Bottlenecks Solutions: Federated Averaging, Gaia, Deep Gradient Compression 3
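The communication-side solutions listed above reduce how often sites synchronize. As one concrete instance, FederatedAveraging [AISTATS 17] has each partition train locally for several steps and then averages the models, weighted by partition size, communicating once per round instead of once per step. A minimal NumPy sketch, with a toy least-squares objective standing in for the DNN workloads studied here (the loss, learning rate, and step counts are illustrative assumptions, not the paper's setup):

```python
import numpy as np

def local_sgd(weights, data, lr=0.1, steps=5):
    """A few local gradient steps on one partition (toy least-squares loss)."""
    w = weights.copy()
    X, y = data
    for _ in range(steps):
        grad = 2 * X.T @ (X @ w - y) / len(y)  # gradient of mean squared error
        w -= lr * grad
    return w

def federated_averaging(global_w, partitions, rounds=10):
    """FederatedAveraging sketch: each round, every partition trains locally,
    then models are averaged weighted by partition size."""
    sizes = np.array([len(y) for _, y in partitions], dtype=float)
    for _ in range(rounds):
        local_models = [local_sgd(global_w, part) for part in partitions]
        global_w = sum(s * w for s, w in zip(sizes, local_models)) / sizes.sum()
    return global_w
```

When partitions are IID this averaging converges to the shared optimum; the question this work raises is what happens when each partition sees a different label distribution.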

  4. Major Challenges in Decentralized ML Geo-Distributed Learning Federated Learning Challenge 2: Data are often highly skewed (non-iid data) Solutions: Understudied! Is it a real problem? 4

  5. Our Work in a Nutshell Real-World Dataset Experimental Study Proposed Solution 5

  6. Geographical mammal images from Flickr 736K pictures in 42 mammal classes Highly skewed labels among geographic regions Real-World Dataset 6

  7. Skewed data labels are a fundamental and pervasive problem The problem is even worse for DNNs with batch normalization Experimental Study The degree of skew determines the difficulty of the problem 7

  8. Replace batch normalization with group normalization SkewScout: communication-efficient decentralized learning over arbitrarily skewed data Proposed Solution 8

  9. Real-World Dataset 9

  10. Flickr-Mammal Dataset: 736K pictures with labels and geographic information. 42 mammal classes from Open Images and ImageNet; 40,000 images per class; images cleaned with PNAS [Liu et al., 18]; reverse geocoding to country, subcontinent, and continent. https://doi.org/10.5281/zenodo.3676081

  11. Top-3 Mammals in Each Continent Each top-3 mammal takes 44-92% share of global images 11

  12. Label Distribution Across Continents [stacked bar chart: each mammal class's label share across Africa, Americas, Asia, Europe, and Oceania] Vast majority of mammals are dominated by 2-3 continents. The labels are even more skewed among subcontinents. 12

  13. Experimental Study 13

  14. Scope of Experimental Study Decentralized Learning Algorithms Skewness of Data Label Partitions ML Application Gaia [NSDI 17] FederatedAveraging [AISTATS 17] DeepGradientCompression [ICLR 18] 2-5 Partitions -- more partitions are worse Image Classification (with various DNNs and datasets) Face recognition

  15. Results: GoogLeNet over CIFAR-10. BSP (Bulk Synchronous Parallel); FederatedAveraging (20X faster than BSP); Gaia (20X faster than BSP); DeepGradientCompression (30X faster than BSP). [bar chart: Top-1 validation accuracy on shuffled vs. skewed data; the decentralized algorithms lose 12% to 69% accuracy under skew] All decentralized learning algorithms lose significant accuracy. Tight synchronization (BSP) is accurate but too slow. 15

  16. Similar Results across the Board: skewed data is a pervasive and fundamental problem, and even BSP loses accuracy for DNNs with Batch Normalization layers. [bar charts: Top-1 validation accuracy, shuffled vs. skewed data, under BSP, Gaia, FederatedAveraging, and DeepGradientCompression, for image classification on CIFAR-10 (AlexNet, LeNet, ResNet20), Flickr-Mammal, and ImageNet (GoogLeNet, ResNet10), and for face recognition (trained on CASIA, tested on LFW)]

  17. Degree of Skew is a Key Factor [bar chart: CIFAR-10 with GN-LeNet; Top-1 validation accuracy for BSP, Gaia, FederatedAveraging, and DeepGradientCompression at 20%, 40%, 60%, and 80% skewed data; accuracy losses grow with the degree of skew, from about -0.5% to -8.5%] Degree of skew can determine the difficulty of the problem. 17

  18. Batch Normalization Problem and Solution 18

  19. Background: Batch Normalization [Ioffe & Szegedy, 2015] Prev Layer -> W -> BN -> Next Layer. At training time, normalize each channel to a standard normal distribution (μ = 0, σ = 1) within each minibatch; at test time, normalize with the estimated global μ and σ. Batch normalization enables larger learning rates and avoids sharp local minima (generalizes better).
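The training/test asymmetry described above can be sketched in a few lines of NumPy; the `(N, C)` activation layout, the momentum value, and the running-statistics bookkeeping below are common conventions, not details taken from the slides:

```python
import numpy as np

def batch_norm(x, gamma, beta, running, momentum=0.1, eps=1e-5, training=True):
    """Batch normalization sketch for (N, C) activations.

    Training: normalize with the current minibatch's per-channel mean/variance
    and update the running (global) estimates. Test: normalize with the
    running estimates instead."""
    if training:
        mu = x.mean(axis=0)   # per-channel minibatch mean
        var = x.var(axis=0)   # per-channel minibatch variance
        running["mean"] = (1 - momentum) * running["mean"] + momentum * mu
        running["var"] = (1 - momentum) * running["var"] + momentum * var
    else:
        mu, var = running["mean"], running["var"]
    x_hat = (x - mu) / np.sqrt(var + eps)  # standardize to mean 0, std 1
    return gamma * x_hat + beta
```

The skew problem follows directly from this structure: the running μ and σ are a single global estimate, but each partition's minibatches are drawn from a different label distribution.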

  20. Batch Normalization with Skewed Data [line chart: per-channel Minibatch Mean Divergence, ||Mean1 - Mean2|| / AVG(Mean1, Mean2), for shuffled vs. skewed data; CIFAR-10 with BN-LeNet, 2 partitions; divergence is far higher with skewed data] Minibatch μ and σ vary significantly among partitions. Global μ and σ do not work for all partitions. 20
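The divergence metric on this slide can be computed directly from two partitions' activations. This is a per-channel reading of the formula; the exact norm used in the paper and the epsilon guard below are my assumptions:

```python
import numpy as np

def mean_divergence(act1, act2, eps=1e-8):
    """Minibatch mean divergence between two partitions' activations.

    act1, act2: (N, C) minibatch activations from each partition.
    Returns a per-channel value of ||mean1 - mean2|| / avg(mean1, mean2),
    using absolute values so the ratio is well defined."""
    m1, m2 = act1.mean(axis=0), act2.mean(axis=0)
    return np.abs(m1 - m2) / ((np.abs(m1) + np.abs(m2)) / 2 + eps)
```

With shuffled data the two partitions' minibatch means nearly coincide and the divergence stays near zero; with skewed partitions it grows large, which is the slide's point.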

  21. Solution: Use Group Normalization [Wu and He, ECCV 18] [diagram: batch normalization computes statistics over the minibatch (N) and spatial dims (H, W) per channel; group normalization computes them over spatial dims and channel groups within each sample] Designed for small minibatches; we apply it as a solution for skewed data. 21
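A minimal NumPy sketch of group normalization for `(N, C, H, W)` activations. Because the statistics are computed within each sample, they never depend on the (possibly skewed) minibatch composition; the group count and tensor layout below follow the usual convention, not anything specific to these slides:

```python
import numpy as np

def group_norm(x, num_groups, gamma=1.0, beta=0.0, eps=1e-5):
    """Group normalization sketch: normalize each sample's channel groups
    over (channels-in-group, H, W), independently of the rest of the batch."""
    n, c, h, w = x.shape
    g = x.reshape(n, num_groups, c // num_groups, h, w)
    mu = g.mean(axis=(2, 3, 4), keepdims=True)   # per-sample, per-group mean
    var = g.var(axis=(2, 3, 4), keepdims=True)   # per-sample, per-group variance
    g_hat = (g - mu) / np.sqrt(var + eps)
    return gamma * g_hat.reshape(n, c, h, w) + beta
```

Unlike the batch-norm sketch earlier, there are no running statistics to disagree across partitions, which is why it sidesteps the skew problem.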

  22. Results with Group Normalization [bar chart: validation accuracy on shuffled vs. skewed data for BSP, Gaia, FederatedAveraging, and DeepGradientCompression, with BatchNorm vs. GroupNorm; GroupNorm substantially shrinks the skew-induced accuracy losses] GroupNorm recovers the accuracy loss for BSP and reduces accuracy losses for decentralized algorithms. 22

  23. SkewScout: Decentralized learning over arbitrarily skewed data 23

  24. Overview of SkewScout. Recall that the degree of data skew determines difficulty. SkewScout adapts communication to the skew-induced accuracy loss: Model Travelling -> Accuracy Loss Estimation -> Communication Control. Minimize communication when accuracy loss is acceptable. Works with different decentralized learning algorithms.
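The three components on this slide can be sketched as a simple feedback loop. Everything below is a hypothetical rendering, not the paper's mechanism: `evaluate`, `comm_knob`, the loss budget, and the multiplicative adjustment rule are illustrative stand-ins:

```python
def skewscout_step(local_model, travelled_model, eval_data, evaluate,
                   comm_knob, loss_budget=0.02, step=0.5):
    """One hedged control step in the spirit of SkewScout.

    Model travelling: a model trained on a remote partition is evaluated on
    local data; the accuracy gap versus the local model serves as a proxy for
    the skew-induced accuracy loss. Communication control: a larger comm_knob
    means the underlying algorithm (Gaia, FederatedAveraging, ...) is allowed
    to communicate less aggressively."""
    local_acc = evaluate(local_model, eval_data)
    travelled_acc = evaluate(travelled_model, eval_data)
    est_loss = local_acc - travelled_acc   # estimated skew-induced accuracy loss
    if est_loss > loss_budget:
        comm_knob *= (1 - step)   # too much divergence: communicate more
    else:
        comm_knob *= (1 + step)   # loss acceptable: save communication
    return est_loss, comm_knob
```

The point of the loop is exactly the slide's claim: communication is spent only when the estimated accuracy loss exceeds what is acceptable, so the mechanism adapts automatically to the degree of skew.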

  25. Evaluation of SkewScout. All data points achieve the same validation accuracy. [bar charts: communication saving over BSP (times) for SkewScout vs. Oracle at 20%, 60%, and 100% skewed data, on CIFAR-10 with AlexNet and with GoogLeNet; savings range from about 9.6X to 51.8X] Significant saving over BSP; only within 1.5X more than Oracle. 25

  26. Key Takeaways Flickr-Mammal dataset: Highly skewed label distribution in the real world Skewed data is a pervasive problem Batch normalization is particularly problematic SkewScout: adapts decentralized learning over arbitrarily skewed data Group normalization is a good alternative to batch normalization 26
