# Keras AdamW

Keras/TF implementation of AdamW, SGDW, NadamW, and Warm Restarts, based on the paper *Decoupled Weight Decay Regularization* - plus Learning Rate Multipliers.
## Features

- **Weight decay fix**: decouples the L2 penalty from the adaptive gradient update; with a plain L2 penalty, decay is coupled with `lr`, complicating search
- **Warm restarts (WR)**: cosine annealing learning rate schedule
- **LR multipliers**: per-layer learning rate multipliers; when training new layers on top of pretrained ones, a single global `lr` is prone to overfitting
## Installation

`pip install keras-adamw` or clone repository.

If using `tensorflow.keras` imports, set `import os; os.environ["TF_KERAS"] = '1'`.
## Weight decay

`AdamW(model=model)`

Three methods to set `weight_decays = {<weight matrix name>: <weight decay value>}`:
1. **Automatically**: just pass in `model` (`AdamW(model=model)`), and decays will be extracted automatically. Loss-based penalties (`l1`, `l2`, `l1_l2`) will be zeroed by default, but can be kept via `zero_penalties=False` (NOT recommended, see Use guidelines).
2. **Use `keras_adamw.utils`**:

   ```python
   Dense(.., kernel_regularizer=l2(0))  # set weight decays in layers as usual, but to ZERO
   wd_dict = get_weight_decays(model)
   # print(wd_dict) to see returned matrix names, note their order
   # specify values as (l1, l2) tuples, both for l1_l2 decay
   ordered_values = [(0, 1e-3), (1e-4, 2e-4), ..]
   weight_decays = fill_dict_in_order(wd_dict, ordered_values)
   ```
3. **Fill manually**:

   ```python
   model.layers[1].kernel.name  # get name of kernel weight matrix of layer indexed 1
   weight_decays.update({'conv1d_0/kernel:0': (1e-4, 0)})  # example
   ```
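What the decoupled decay amounts to can be sketched without Keras: a plain L2 penalty feeds the decay term through Adam's adaptive gradient scaling, while AdamW subtracts it from the weights directly. A minimal NumPy illustration of a single simplified update step (hypothetical helper name; not the library's internals, and bias correction is omitted for brevity):

```python
import numpy as np

def adamw_step(w, grad, m, v, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-7, weight_decay=1e-2, eta_t=1.0):
    # standard Adam moment updates on the *unregularized* gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad**2
    w_new = w - lr * m / (np.sqrt(v) + eps)
    # decoupled decay: subtracted directly (scaled by schedule eta_t),
    # NOT passed through the sqrt(v) adaptive scaling
    w_new = w_new - eta_t * weight_decay * w
    return w_new, m, v

w = np.array([1.0, -2.0])
grad = np.array([0.1, 0.1])
w1, m, v = adamw_step(w, grad, m=np.zeros(2), v=np.zeros(2))
```

Note that the decay pulls each weight toward zero in proportion to its own magnitude, independently of the gradient history.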
## Warm restarts

`AdamW(.., use_cosine_annealing=True, total_iterations=200)` - refer to Use guidelines below.
## LR multipliers

`AdamW(.., lr_multipliers=lr_multipliers)` - to get `{<layer name>: <multiplier value>}`:

1. (a) Name every layer to be modified (recommended), e.g. `Dense(.., name='dense_1')` - OR (b) get every layer name, noting which to modify:
   `[print(idx, layer.name) for idx, layer in enumerate(model.layers)]`
2. (a) `lr_multipliers = {'conv1d_0': 0.1}` - target layer by full name - OR (b) `lr_multipliers = {'conv1d': 0.1}` - target all layers w/ name substring `'conv1d'`
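The full-name vs. substring targeting above can be sketched as a simple lookup (a hypothetical helper for illustration; the library resolves names internally):

```python
def resolve_multiplier(layer_name, lr_multipliers, default=1.0):
    # exact full-name match takes precedence over substring match
    if layer_name in lr_multipliers:
        return lr_multipliers[layer_name]
    for key, mult in lr_multipliers.items():
        if key in layer_name:
            return mult
    return default

lr = 1e-3
lr_multipliers = {'conv1d': 0.1, 'dense_1': 0.5}
effective = {name: lr * resolve_multiplier(name, lr_multipliers)
             for name in ['conv1d_0', 'conv1d_1', 'dense_1', 'lstm_1']}
# 'conv1d_0' and 'conv1d_1' match the substring; 'lstm_1' keeps the base lr
```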
## Example

```python
import numpy as np
from keras.layers import Input, Dense, LSTM
from keras.models import Model
from keras.regularizers import l1, l2, l1_l2
from keras_adamw import AdamW

ipt = Input(shape=(120, 4))
x = LSTM(60, activation='relu', name='lstm_1',
         kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
out = Dense(1, activation='sigmoid', kernel_regularizer=l1_l2(1e-4, 2e-4))(x)
model = Model(ipt, out)
```
```python
lr_multipliers = {'lstm_1': 0.5}

optimizer = AdamW(lr=1e-4, model=model, lr_multipliers=lr_multipliers,
                  use_cosine_annealing=True, total_iterations=24)
model.compile(optimizer, loss='binary_crossentropy')
```
```python
for epoch in range(3):
    for iteration in range(24):
        x = np.random.rand(10, 120, 4)  # dummy data
        y = np.random.randint(0, 2, (10, 1))  # dummy labels
        loss = model.train_on_batch(x, y)
        print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
    print("EPOCH {} COMPLETED\n".format(epoch + 1))
```

(Full example + plot code, and explanation of `lr_t` vs. `lr`: example.py)
## Use guidelines

### Weight decay

- **Set L2 penalties to ZERO** for weights regularized via `weight_decays` - else the purpose of the 'fix' is largely defeated, and weights will be over-decayed --*My recommendation*
- `lambda = lambda_norm * sqrt(1/total_iterations)` --> can be changed; the intent is to scale λ to decouple it from other hyperparams - including (but not limited to) # of epochs & batch size --*Authors* (Appendix, pg. 1) (A-1)
- `total_iterations_wd` --> set to normalize over all epochs (or another interval `!= total_iterations`) instead of per-WR when using WR; may sometimes yield better results --*My note*
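The normalization above means the *tuned* quantity is `lambda_norm`, not the raw per-step decay: changing how long you train rescales the effective decay automatically. A quick sketch of the formula (the `lambda_norm=0.025` value is an arbitrary example, not a recommendation):

```python
import math

def scaled_decay(lambda_norm, total_iterations):
    # lambda = lambda_norm * sqrt(1/total_iterations):
    # longer training -> smaller per-step decay, same overall regularization
    return lambda_norm * math.sqrt(1 / total_iterations)

short = scaled_decay(0.025, 100)  # fewer weight updates
long = scaled_decay(0.025, 400)   # 4x the updates -> half the per-step decay
```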
### Warm restarts

- Done automatically with `autorestart=True`, which is the default if `use_cosine_annealing=True`; internally sets `t_cur = 0` after `total_iterations` iterations
- Manually: set `t_cur = -1` to restart the schedule multiplier (see Example). Can be done at compilation or during training. A value other than `-1` is also valid, and will start `eta_t` at another point on the cosine curve. Details in A-2,3
- `t_cur` should be set at `iter == total_iterations - 2`; explanation here
- Set `total_iterations` to the # of expected weight updates for the given restart --*Authors* (A-1,2)
- `eta_min=0, eta_max=1` are tunable hyperparameters; e.g., an exponential schedule can be used for `eta_max`. If unsure, the defaults were shown to work well in the paper --*Authors*
```python
# 'total_iterations' general purpose example
def get_total_iterations(restart_idx, num_epochs, iterations_per_epoch):
    return num_epochs[restart_idx] * iterations_per_epoch[restart_idx]

get_total_iterations(0, num_epochs=[1, 3, 5, 8], iterations_per_epoch=[240, 120, 60, 30])
```

### Learning rate multipliers
- Use to keep the learning rate of pretrained layers low relative to newly added layers; the more thoroughly a layer is pretrained, the lower its `lr` should be --*My recommendation*