mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix implement error in diffusers
commit 8f72b6f8fb

@@ -141,7 +141,25 @@ def _is_grad_tensor(obj) -> bool:
     return False
 
 
+def _has_grad_tensor(obj) -> bool:
+    if isinstance(obj, tuple) or isinstance(obj, list):
+        for x in obj:
+            if _has_grad_tensor(x):
+                return True
+        return False
+    elif isinstance(obj, dict):
+        for x in obj.values():
+            if _has_grad_tensor(x):
+                return True
+        return False
+    else:
+        return _is_grad_tensor(obj)
+
+
 def _get_grad_args(*args):
     # if there is no grad tensors, do nothing
     if not _has_grad_tensor(args):
         return args, None
     # returns the identical args if there is a grad tensor
     for obj in args:
         if _is_grad_tensor(obj):
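
For context, `_has_grad_tensor` only reports whether any leaf of a nested tuple/list/dict is a tensor that carries gradients. A minimal standalone sketch of the same traversal, assuming `_is_grad_tensor(obj)` reduces to `torch.is_tensor(obj) and obj.requires_grad` (an assumption, since that helper's body is not shown in this hunk):

import torch

def _is_grad_tensor(obj) -> bool:
    # assumed behaviour of the helper named in the hunk above
    return torch.is_tensor(obj) and obj.requires_grad

def _has_grad_tensor(obj) -> bool:
    # recurse into containers; any grad-requiring tensor makes the result True
    if isinstance(obj, (tuple, list)):
        return any(_has_grad_tensor(x) for x in obj)
    if isinstance(obj, dict):
        return any(_has_grad_tensor(x) for x in obj.values())
    return _is_grad_tensor(obj)

x = torch.randn(2, requires_grad=True)
print(_has_grad_tensor({"weights": [x], "step": 3}))     # True
print(_has_grad_tensor((torch.zeros(2), "no grads")))    # False
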
@@ -7,27 +7,22 @@
 #
 # thanks!
 
 
-import os
 import math
+import os
+
+import numpy as np
 import torch
 import torch.nn as nn
-import numpy as np
 from einops import repeat
 
 from ldm.util import instantiate_from_config
 
 
 def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
     if schedule == "linear":
-        betas = (
-            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
-        )
+        betas = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)
 
     elif schedule == "cosine":
-        timesteps = (
-            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
-        )
+        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s)
         alphas = timesteps / (1 + cosine_s) * np.pi / 2
         alphas = torch.cos(alphas).pow(2)
         alphas = alphas / alphas[0]
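
Both forms of the `linear` branch are the same computation: betas are interpolated linearly in square-root space between `linear_start` and `linear_end`, then squared. A quick sketch confirming the reformatted expression is numerically identical (all names are local to this example):

import torch

linear_start, linear_end, n_timestep = 1e-4, 2e-2, 1000

# original multi-line formatting
betas_old = (
    torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
# condensed formatting from the diff
betas_new = (torch.linspace(linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64)**2)

assert torch.equal(betas_old, betas_new)
print(betas_new[0].item(), betas_new[-1].item())    # ~1e-4 ... ~2e-2
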
@@ -37,7 +32,7 @@ def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2,
     elif schedule == "sqrt_linear":
         betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
     elif schedule == "sqrt":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
     return betas.numpy()
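
The betas returned here feed the usual DDPM quantities downstream; a short sketch of that standard relation (not part of this diff, just for orientation):

import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)    # simple stand-in; make_beta_schedule above builds these more carefully
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)      # cumulative alpha-bar used by the samplers
print(alphas_cumprod[0], alphas_cumprod[-1])
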
@@ -48,7 +43,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
     elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)
     else:
         raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
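
Both discretization branches map `num_ddim_timesteps` indices onto the full DDPM range: `uniform` strides evenly, while `quad` spaces the steps quadratically so more of them land near t = 0. A standalone sketch of the two expressions from this hunk:

import numpy as np

num_ddpm_timesteps, num_ddim_timesteps = 1000, 10

c = num_ddpm_timesteps // num_ddim_timesteps
uniform = np.asarray(list(range(0, num_ddpm_timesteps, c)))
quad = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps))**2).astype(int)

print(uniform)    # [  0 100 200 ... 900]
print(quad)       # small gaps early, large gaps late
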
@@ -110,21 +105,26 @@ def checkpoint(func, inputs, params, flag):
     :param flag: if False, disable gradient checkpointing.
     """
     if flag:
-        args = tuple(inputs) + tuple(params)
-        return CheckpointFunction.apply(func, len(inputs), *args)
+        from torch.utils.checkpoint import checkpoint as torch_checkpoint
+        return torch_checkpoint(func, *inputs)
+        # args = tuple(inputs) + tuple(params)
+        # return CheckpointFunction.apply(func, len(inputs), *args)
     else:
         return func(*inputs)
 
 
 class CheckpointFunction(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, run_function, length, *args):
         ctx.run_function = run_function
         ctx.input_tensors = list(args[:length])
         ctx.input_params = list(args[length:])
-        ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
-                                   "dtype": torch.get_autocast_gpu_dtype(),
-                                   "cache_enabled": torch.is_autocast_cache_enabled()}
+        ctx.gpu_autocast_kwargs = {
+            "enabled": torch.is_autocast_enabled(),
+            "dtype": torch.get_autocast_gpu_dtype(),
+            "cache_enabled": torch.is_autocast_cache_enabled()
+        }
         with torch.no_grad():
             output_tensors = ctx.run_function(*ctx.input_tensors)
         return output_tensors
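
This hunk carries the actual hotfix: inside `checkpoint(...)`, the call is now routed through PyTorch's built-in `torch.utils.checkpoint.checkpoint` instead of the hand-rolled `CheckpointFunction` (which stays in the file but is no longer reached when `flag` is set). A minimal sketch of what the new code path does; the module and shapes are made up for illustration:

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# intermediate activations inside the checkpointed call are recomputed
# during backward instead of being kept alive, trading compute for memory
y = torch_checkpoint(layer, x)
y.sum().backward()
print(layer.weight.grad.shape)    # torch.Size([16, 16])
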
@@ -162,9 +162,8 @@ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
     """
     if not repeat_only:
         half = dim // 2
-        freqs = torch.exp(
-            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
-        ).to(device=timesteps.device)
+        freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
+                          half).to(device=timesteps.device)
         args = timesteps[:, None].float() * freqs[None]
         embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
         if dim % 2:
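
The reflowed `freqs` line is unchanged math: the frequencies follow a geometric progression exp(-log(max_period) * i / half) for i in [0, half), i.e. the standard sinusoidal timestep embedding. A self-contained sketch (even `dim` assumed; the odd-dim zero-padding and `repeat_only` branches are omitted):

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 500]), dim=128)
print(emb.shape)    # torch.Size([3, 128])
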
@@ -211,14 +210,17 @@ def normalization(channels):
 
 # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
 class SiLU(nn.Module):
+
     def forward(self, x):
         return x * torch.sigmoid(x)
 
 
 class GroupNorm32(nn.GroupNorm):
+
     def forward(self, x):
         return super().forward(x.float()).type(x.dtype)
 
 
 def conv_nd(dims, *args, **kwargs):
     """
     Create a 1D, 2D, or 3D convolution module.
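
The only changes here are blank lines after the class statements; behaviour is untouched. `GroupNorm32` still runs the normalization in float32 and casts the result back to the input dtype, which keeps GroupNorm numerically stable under fp16 autocast. A quick sketch:

import torch
import torch.nn as nn

class GroupNorm32(nn.GroupNorm):

    def forward(self, x):
        # normalize in float32, return in the caller's dtype
        return super().forward(x.float()).type(x.dtype)

x = torch.randn(2, 32, 8, 8, dtype=torch.float16)
y = GroupNorm32(32, 32)(x)
print(y.dtype)    # torch.float16
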
@@ -268,4 +270,4 @@ class HybridConditioner(nn.Module):
 def noise_like(shape, device, repeat=False):
     repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
     noise = lambda: torch.randn(shape, device=device)
-    return repeat_noise() if repeat else noise()
+    return repeat_noise() if repeat else noise()
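
`noise_like` is functionally unchanged (the final line appears to differ only in trailing whitespace/newline). With `repeat=True`, one noise sample is drawn and tiled across the batch dimension; a tiny sketch:

import torch

shape, device = (4, 3, 8, 8), "cpu"
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))

n = repeat_noise()
print(n.shape)                    # torch.Size([4, 3, 8, 8])
print(torch.equal(n[0], n[1]))    # True: every batch element shares the same noise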