diff --git a/colossalai/kernel/cpu_adam_loader.py b/colossalai/kernel/cpu_adam_loader.py
index 0df6bd49b..4763f40ab 100644
--- a/colossalai/kernel/cpu_adam_loader.py
+++ b/colossalai/kernel/cpu_adam_loader.py
@@ -12,37 +12,6 @@ class CPUAdamLoader(BaseKernelLoader):
     Usage:
         # init
         cpu_adam = CPUAdamLoader().load()
-        cpu_adam_op = cpu_adam.CPUAdamOptimizer(
-            alpha, beta1, beta2, epsilon, weight_decay, adamw_mode,
-        )
-        ...
-        # optim step
-        cpu_adam_op.step(
-            step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
-            params, grads, exp_avg, exp_avg_sq, loss_scale,
-        )
-
-    Args:
-        func CPUAdamOptimizer:
-            alpha (float): learning rate. Default to 1e-3.
-            beta1 (float): coefficients used for computing running averages of gradient. Default to 0.9.
-            beta2 (float): coefficients used for computing running averages of its square. Default to 0.99.
-            epsilon (float): term added to the denominator to improve numerical stability. Default to 1e-8.
-            weight_decay (float): weight decay (L2 penalty). Default to 0.
-            adamw_mode (bool): whether to use the adamw. Default to True.
-        func step:
-            step (int): current step.
-            lr (float): learning rate.
-            beta1 (float): coefficients used for computing running averages of gradient.
-            beta2 (float): coefficients used for computing running averages of its square.
-            epsilon (float): term added to the denominator to improve numerical stability.
-            weight_decay (float): weight decay (L2 penalty).
-            bias_correction (bool): whether to use bias correction.
-            params (torch.Tensor): parameter.
-            grads (torch.Tensor): gradient.
-            exp_avg (torch.Tensor): exp average.
-            exp_avg_sq (torch.Tensor): exp average square.
-            loss_scale (float): loss scale value.
     """

     def __init__(self):
@@ -57,7 +26,7 @@ class CPUAdamLoader(BaseKernelLoader):
     def fetch_kernel(self):
         if platform.machine() == "x86_64":
             kernel = self._extension_map["x86"]().fetch()
-        elif platform.machine() in ["aarch64", "aarch64_be", "armv8b", "armv8l"]:
+        elif platform.machine() == "aarch64":
             kernel = self._extension_map["arm"]().fetch()
         else:
             raise Exception("not supported")
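
The first hunk drops the detailed usage docstring for the loaded kernel. For reference, here is a minimal sketch of the usage that docstring described; the default values and argument order are taken directly from the removed text, while the tensor shapes, gradient values, and the `loss_scale=-1` choice are illustrative assumptions, not part of the patch.

```python
import torch

from colossalai.kernel.cpu_adam_loader import CPUAdamLoader

# Load the platform-specific CPU Adam extension (x86 or ARM).
cpu_adam = CPUAdamLoader().load()

# Constructor arguments per the removed docstring:
# alpha=1e-3, beta1=0.9, beta2=0.99, epsilon=1e-8, weight_decay=0, adamw_mode=True.
cpu_adam_op = cpu_adam.CPUAdamOptimizer(1e-3, 0.9, 0.99, 1e-8, 0.0, True)

# Illustrative optimizer state for one flat parameter tensor (placeholder sizes).
params = torch.zeros(1024)
grads = torch.randn(1024)
exp_avg = torch.zeros(1024)
exp_avg_sq = torch.zeros(1024)

# One optimizer step, following the argument order documented in the removed docstring:
# step, lr, beta1, beta2, epsilon, weight_decay, bias_correction,
# params, grads, exp_avg, exp_avg_sq, loss_scale.
# loss_scale=-1 is an assumption here, used as a "no loss scaling" placeholder.
cpu_adam_op.step(1, 1e-3, 0.9, 0.99, 1e-8, 0.0, True, params, grads, exp_avg, exp_avg_sq, -1)
```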