Brief Review — Population Based Training of Neural Networks
PBT, Adaptive Hyperparameter Tuning Using Explore & Exploit
4 min read · Oct 10, 2022
Population Based Training of Neural Networks
PBT, by DeepMind
2017 arXiv v2, Over 500 Citations (Sik-Ho Tsang @ Medium)
Hyperparameter Tuning, Deep Reinforcement Learning, Neural Machine Translation, GAN
- Population Based Training (PBT) is proposed to adaptively tune the hyperparameters using explore and exploit.
Outline
- Conventional Sequential & Parallel Random/Grid Search Approaches
- Population Based Training (PBT)
- Results
1. Conventional Sequential & Parallel Random/Grid Search Approaches
1.1. Sequential Optimization
- Sequential optimization requires multiple training runs to be completed (potentially with early stopping), after which new hyperparameters are selected and the model is retrained from scratch with the new hyperparameters.
- The hyperparameters tuned in the 3 major evaluation tasks are:
- For reinforcement learning tasks, the hyperparameters are learning rate, entropy cost, and unroll length for UNREAL on DeepMind Lab; learning rate, entropy cost, and intrinsic reward cost for FuN on Atari; and learning rate only for A3C on StarCraft II.
- For machine translation tasks, hyperparameters include learning rate, attention Dropout, layer Dropout, and ReLU Dropout rates.
- For GAN tasks, hyperparameters include discriminator’s learning rate and the generator’s learning rate.
1.2. Parallel Random/Grid Search Approaches
- Parallel random/grid search of hyperparameters trains multiple models in parallel with different weight initializations and hyperparameters, with the expectation that one of the models will turn out best.
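To make the contrast with PBT concrete, here is a minimal sketch of parallel random search; the hyperparameter names, sampling ranges, and the train_and_evaluate function are illustrative assumptions, not the paper's exact setup:

```python
import random

def sample_hyperparams():
    # Illustrative search space; the actual ranges are task-dependent.
    return {
        "learning_rate": 10 ** random.uniform(-5, -2),
        "entropy_cost": 10 ** random.uniform(-4, -1),
    }

def random_search(num_workers, train_and_evaluate):
    """Train num_workers models independently and keep the best one."""
    results = []
    for _ in range(num_workers):
        hparams = sample_hyperparams()
        score = train_and_evaluate(hparams)  # one full training run per worker
        results.append((score, hparams))
    return max(results, key=lambda r: r[0])
```

Each worker keeps its initial hyperparameters for the whole run; only the final comparison decides which model is used.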
2. Population Based Training (PBT)
- Population Based Training (PBT) starts like parallel search, randomly sampling hyperparameters and weight initializations.
- However, each training run asynchronously evaluates its performance periodically. If a model in the population is under-performing, it will exploit the rest of the population by replacing itself with a better performing model, and it will explore new hyperparameters by modifying the better model’s hyperparameters, before training is continued.
- This process allows hyperparameters to be optimized online, and the computational resources to be focused on the hyperparameter and weight space that has most chance of producing good results.
- The result is a hyperparameter tuning method that, while very simple, results in faster learning, uses fewer computational resources, and often finds better solutions.
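As a rough illustration, the asynchronous PBT loop for one population member can be sketched as below; the Worker fields and the exploit/explore helper signatures are assumptions made for this sketch, with the concrete strategies described in Sections 2.1 and 2.2:

```python
from dataclasses import dataclass

@dataclass
class Worker:
    weights: dict
    hyperparams: dict
    score: float = float("-inf")

def pbt_step(worker, population, train, evaluate, exploit, explore):
    """One asynchronous PBT iteration for a single member of the population."""
    train(worker)                      # train for a fixed number of steps
    worker.score = evaluate(worker)    # periodic performance evaluation
    better = exploit(worker, population)
    if better is not None:             # under-performing: copy a better model
        worker.weights = dict(better.weights)
        worker.hyperparams = explore(dict(better.hyperparams))
    return worker
```

Each worker repeats this step until training ends, so hyperparameters are adapted online rather than fixed at the start of the run.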
2.1. Exploit
- T-test selection: where another agent is uniformly sampled from the population and the last 10 episodic rewards of the two are compared using Welch's t-test; if the sampled agent performs significantly better, its weights and hyperparameters are copied to replace the current agent.
- Truncation selection: where all agents in the population are ranked by episodic reward. If the current agent is in the bottom 20% of the population, we sample another agent uniformly from the top 20% of the population, and copy its weights and hyperparameters.
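A minimal sketch of truncation selection, assuming each worker carries a score, weights, and hyperparams as in the loop above; the 20% fraction follows the paper:

```python
import random

def truncation_exploit(worker, population, frac=0.2):
    """If worker is in the bottom 20%, return a worker sampled from the top 20%."""
    ranked = sorted(population, key=lambda w: w.score)  # ascending by reward
    cutoff = max(1, int(len(ranked) * frac))
    bottom, top = ranked[:cutoff], ranked[-cutoff:]
    if any(w is worker for w in bottom):
        return random.choice(top)  # caller copies its weights and hyperparameters
    return None  # worker is good enough; keep training as-is
```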
2.2. Explore
- Perturb: where each hyperparameter is independently perturbed at random by a factor of 1.2 or 0.8.
- Resample: where each hyperparameter is resampled, with some probability, from the original prior distribution. (Minimal sketches of both strategies are given after this list.)
- Different tasks use different Exploit & Explore strategies. (For details, please feel free to read the paper directly.)
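Minimal sketches of the two explore strategies; the 1.2/0.8 factors follow the paper, while the priors dictionary of sampling functions and the resample probability are illustrative assumptions:

```python
import random

def perturb(hyperparams):
    """Multiply each hyperparameter independently by 1.2 or 0.8."""
    return {k: v * random.choice([1.2, 0.8]) for k, v in hyperparams.items()}

def resample(hyperparams, priors, prob=0.25):
    """With some probability, redraw each hyperparameter from its original prior.

    priors maps each hyperparameter name to a zero-argument sampling function.
    """
    return {k: (priors[k]() if random.random() < prob else v)
            for k, v in hyperparams.items()}
```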
3. Results
- Reinforcement Learning: On all three domains (DeepMind Lab, Atari, and StarCraft II), PBT increases the final performance of the agents when trained for the same number of steps, compared to the very strong baseline of performing random search with the same number of workers.
- Machine Translation: Transformer is used. PBT is able to automatically discover adaptations of various Dropout rates throughout training, and a learning rate schedule that remarkably resembles that of the hand-tuned baseline.
- GAN: PBT improves DCGAN with higher Inception Score.
- PBT outperforms the baselines by a large margin on these different tasks.
PBT helps to optimize hyperparameters online during parallel training of multiple models.
Reference
[2017 arXiv v2] [PBT] Population Based Training of Neural Networks