PALM: Pushing Adaptive Learning Rate Mechanisms for Continual Test-Time Adaptation

The University of Texas at Dallas, Richardson, TX
AAAI 2025 (Oral)

Our Framework for Continual Test-Time Adaptation


PALM framework: At time step t, an input batch xk is processed by the model, parameterized by θt. The KL divergence between the softmax predictions and a uniform distribution is backpropagated, and the layers whose gradient norm is ≤ η are selected; the gradient magnitude quantifies the model's prediction uncertainty. The parameter sensitivities of these selected layers, serving as an indicator of domain shift, are then computed to adjust their learning rates. Finally, with the optimization objective, the model is updated using the adjusted per-parameter learning rates.
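The layer-selection step above can be sketched in a minimal, self-contained form. The snippet below uses a hypothetical two-layer numpy MLP as a stand-in for a pre-trained model: it backpropagates KL(softmax ‖ uniform) by hand and keeps the layers whose gradient norm falls at or below a threshold η. The network sizes, the threshold value, and the manual-backprop setup are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy 2-layer MLP (hypothetical stand-in for a pre-trained source model).
W1 = rng.normal(0, 0.5, (8, 16))   # layer 1 weights
W2 = rng.normal(0, 0.5, (16, 4))   # layer 2 weights (4 classes)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def kl_to_uniform_grads(x):
    """Forward pass, then backprop the mean KL(softmax || uniform) to each layer."""
    h_pre = x @ W1
    h = np.maximum(h_pre, 0.0)                 # ReLU
    logits = h @ W2
    p = softmax(logits)
    K = p.shape[1]
    kl = (p * np.log(K * p + 1e-12)).sum(axis=1)        # per-sample KL
    # dKL/dlogits = p * (log(K p) - KL), from the softmax Jacobian.
    dlogits = p * (np.log(K * p + 1e-12) - kl[:, None])
    gW2 = h.T @ dlogits / len(x)
    dh = (dlogits @ W2.T) * (h_pre > 0)
    gW1 = x.T @ dh / len(x)
    return {"W1": gW1, "W2": gW2}

x = rng.normal(size=(32, 8))                   # one unlabeled test batch
grads = kl_to_uniform_grads(x)
norms = {name: float(np.linalg.norm(g)) for name, g in grads.items()}
eta = 0.05                                     # hypothetical threshold; a tuned hyperparameter
selected = [name for name, n in norms.items() if n <= eta]
print(norms, selected)
```

In a real PyTorch model, the same thing is one `loss.backward()` followed by a per-layer check of `param.grad.norm()`; the hand-derived gradient here just keeps the sketch dependency-free.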

Abstract

Real-world vision models in dynamic environments face rapid shifts in domain distributions, leading to degraded recognition performance. Using only unlabeled test data, continual test-time adaptation (CTTA) directly adjusts a pre-trained source discriminative model to these changing domains. A highly effective class of CTTA methods applies layer-wise adaptive learning rates to selectively adapt pre-trained layers. However, such methods suffer from poor estimation of domain shift and from inaccuracies arising from pseudo-labels. This work overcomes these limitations by identifying layers for adaptation via quantifying model prediction uncertainty, without relying on pseudo-labels. We use the magnitude of gradients as a metric, calculated by backpropagating the KL divergence between the softmax output and a uniform distribution, to select layers for further adaptation. Then, for the parameters belonging exclusively to these selected layers, with the remaining ones frozen, we evaluate their sensitivity to approximate the domain shift and adjust their learning rates accordingly. We conduct extensive image classification experiments on CIFAR-10C, CIFAR-100C, and ImageNet-C, demonstrating the superior efficacy of our method compared to prior approaches.

CTTA Experimental Results

At each time step t, a task consisting of K batches drawn from a certain data distribution arrives. At t=0, the model is initialized with the source weights θs. It is then gradually adapted online to each incoming batch xk of the current task, updating the model parameters to θt. The source dataset is unavailable during this adaptation due to privacy and storage constraints. For the experiments, each dataset contains 15 corruption styles as tasks (e.g., Gaussian noise, shot noise, . . .) at 5 severity levels indicating corruption strength. The model is evaluated on the CIFAR-10C, CIFAR-100C, and ImageNet-C datasets; the mean errors (%) are shown below.
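The online adaptation loop, together with the sensitivity-based learning-rate adjustment described above, can be sketched as follows. This is a toy illustration on a single parameter tensor: the sensitivity estimate |θ · ∇θ| (a first-order estimate of the loss change if the parameter were zeroed), the EMA smoothing factor, and the normalization of learning rates are all assumptions for illustration; consult the paper for the exact update rule.

```python
import numpy as np

rng = np.random.default_rng(1)

# One parameter tensor from a selected layer (toy stand-in).
theta = rng.normal(size=(16, 4))
base_lr = 1e-3
ema_sens = np.zeros_like(theta)    # smoothed sensitivity estimate
beta = 0.9                         # hypothetical EMA factor

def adapt_step(theta, grad, ema_sens):
    # Parameter sensitivity: first-order estimate of the change in loss
    # if this parameter were removed, |theta * grad|.
    sens = np.abs(theta * grad)
    ema_sens = beta * ema_sens + (1 - beta) * sens
    # Scale per-parameter learning rates by relative sensitivity.
    # (The direction and normalization of this scaling are illustrative
    # assumptions, not the paper's exact rule.)
    lr = base_lr * ema_sens / (ema_sens.max() + 1e-12)
    theta = theta - lr * grad
    return theta, ema_sens

for _ in range(5):                          # K batches arriving online
    grad = rng.normal(size=theta.shape)     # stand-in for a real test-time gradient
    theta, ema_sens = adapt_step(theta, grad=grad, ema_sens=ema_sens)
```

The key point the sketch captures is that adaptation is strictly online: each batch is seen once, the source data is never revisited, and the learning rate of every parameter is modulated by its accumulated sensitivity.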

Ablation Results

BibTeX

If our work is of interest to you, consider citing it.
@inproceedings{maharana2024palm,
  title={PALM: Pushing Adaptive Learning Rate Mechanisms for Continual Test-Time Adaptation},
  author={Maharana, Sarthak Kumar and Zhang, Baoming and Guo, Yunhui},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2025}
}