Spaces:
Running
Running
| # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. | |
| import torch | |
| import numbers | |
| from peft import LoraConfig | |
def get_loraconfig(transformer, rank=128, alpha=128, init_lora_weights="gaussian"):
    """Build a peft ``LoraConfig`` targeting selected Linear layers.

    Every ``torch.nn.Linear`` reported by ``transformer.named_modules()`` is
    selected when its qualified name contains ``"blocks"`` and contains
    neither ``"face"`` nor ``"modulation"``.

    Args:
        transformer: object exposing ``named_modules()`` yielding
            ``(name, module)`` pairs.
        rank: forwarded to ``LoraConfig`` as ``r``.
        alpha: forwarded to ``LoraConfig`` as ``lora_alpha``.
        init_lora_weights: forwarded to ``LoraConfig`` unchanged.

    Returns:
        The ``LoraConfig`` whose ``target_modules`` is the list of selected
        module names, in ``named_modules()`` order.
    """
    excluded_tags = ("face", "modulation")
    lora_targets = [
        module_name
        for module_name, submodule in transformer.named_modules()
        if "blocks" in module_name
        and not any(tag in module_name for tag in excluded_tags)
        and isinstance(submodule, torch.nn.Linear)
    ]
    return LoraConfig(
        r=rank,
        lora_alpha=alpha,
        init_lora_weights=init_lora_weights,
        target_modules=lora_targets,
    )
class TensorList(object):
    """A list of tensors exposing a small tensor-like API.

    The wrapped tensors must share ``ndim``, ``dtype`` and ``device`` but may
    differ in shape. The list itself acts as an extra leading axis: ``size(0)``
    is the number of tensors and ``ndim()`` is one more than the ``ndim`` of
    the wrapped tensors. Arithmetic operators are applied per wrapped tensor
    and return a new ``TensorList``.

    Note that ``dtype``, ``device`` and ``ndim`` are methods here and must be
    called (``tl.dtype()``).
    """

    def __init__(self, tensors):
        """
        tensors: a list of torch.Tensor objects. No need to have uniform shape.
        """
        assert isinstance(tensors, (list, tuple))
        assert all(isinstance(u, torch.Tensor) for u in tensors)
        # Exactly one distinct ndim / dtype / device across the wrapped
        # tensors; this also rejects an empty list (zero distinct values).
        assert len(set([u.ndim for u in tensors])) == 1
        assert len(set([u.dtype for u in tensors])) == 1
        assert len(set([u.device for u in tensors])) == 1
        self.tensors = tensors

    def to(self, *args, **kwargs):
        """Return a new TensorList with ``Tensor.to`` applied to each tensor."""
        return TensorList([u.to(*args, **kwargs) for u in self.tensors])

    def size(self, dim):
        """Return the number of wrapped tensors; only ``dim == 0`` is supported."""
        assert dim == 0, 'only support get the 0th size'
        return len(self.tensors)

    def pow(self, *args, **kwargs):
        """Return a new TensorList with ``Tensor.pow`` applied to each tensor."""
        return TensorList([u.pow(*args, **kwargs) for u in self.tensors])

    def squeeze(self, dim):
        """Squeeze axis ``dim`` of every wrapped tensor.

        ``dim`` is counted with the list axis as axis 0, so a positive ``dim``
        is shifted down by one before being passed to each tensor. Negative
        values are passed through unchanged. ``dim == 0`` is rejected.
        """
        assert dim != 0
        if dim > 0:
            dim -= 1
        return TensorList([u.squeeze(dim) for u in self.tensors])

    def type(self, *args, **kwargs):
        """Return a new TensorList with ``Tensor.type`` applied to each tensor."""
        return TensorList([u.type(*args, **kwargs) for u in self.tensors])

    def type_as(self, other):
        """Cast every wrapped tensor to the type of ``other``.

        ``other`` may be a ``torch.Tensor`` or another ``TensorList``; for a
        ``TensorList`` the dtype of its first wrapped tensor is used.
        """
        assert isinstance(other, (torch.Tensor, TensorList))
        if isinstance(other, torch.Tensor):
            return TensorList([u.type_as(other) for u in self.tensors])
        # Read the dtype from the wrapped tensor directly: ``other.dtype`` is
        # a bound method on TensorList, not a torch.dtype, and passing it to
        # ``Tensor.type`` fails.
        target_dtype = other.tensors[0].dtype
        return TensorList([u.type(target_dtype) for u in self.tensors])

    def dtype(self):
        """Return the dtype of the first wrapped tensor."""
        return self.tensors[0].dtype

    def device(self):
        """Return the device of the first wrapped tensor."""
        return self.tensors[0].device

    def ndim(self):
        """Return one plus the ``ndim`` of the first wrapped tensor."""
        return 1 + self.tensors[0].ndim

    def __getitem__(self, index):
        """Index the underlying list of tensors."""
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def __add__(self, other):
        return self._apply(other, lambda u, v: u + v)

    def __radd__(self, other):
        return self._apply(other, lambda u, v: v + u)

    def __sub__(self, other):
        return self._apply(other, lambda u, v: u - v)

    def __rsub__(self, other):
        return self._apply(other, lambda u, v: v - u)

    def __mul__(self, other):
        return self._apply(other, lambda u, v: u * v)

    def __rmul__(self, other):
        return self._apply(other, lambda u, v: v * u)

    def __floordiv__(self, other):
        return self._apply(other, lambda u, v: u // v)

    def __truediv__(self, other):
        return self._apply(other, lambda u, v: u / v)

    def __rfloordiv__(self, other):
        return self._apply(other, lambda u, v: v // u)

    def __rtruediv__(self, other):
        return self._apply(other, lambda u, v: v / u)

    def __pow__(self, other):
        return self._apply(other, lambda u, v: u ** v)

    def __rpow__(self, other):
        return self._apply(other, lambda u, v: v ** u)

    def __neg__(self):
        return TensorList([-u for u in self.tensors])

    def __iter__(self):
        for tensor in self.tensors:
            yield tensor

    def __repr__(self):
        return 'TensorList: \n' + repr(self.tensors)

    def _apply(self, other, op):
        """Apply binary ``op(u, v)`` to each wrapped tensor ``u``.

        * If ``other`` is a list, tuple, TensorList, or a tensor with more
          than one element or more than one dimension, it must have the same
          length as this list and is paired item by item (a tensor is
          iterated along its first dimension).
        * If ``other`` is a number, or a tensor with exactly one element and
          at most one dimension, it is used as ``v`` for every wrapped tensor.

        Raises:
            TypeError: for any other operand.
        """
        if isinstance(other, (list, tuple, TensorList)) or (
            isinstance(other, torch.Tensor) and (
                other.numel() > 1 or other.ndim > 1
            )
        ):
            assert len(other) == len(self.tensors)
            return TensorList([op(u, v) for u, v in zip(self.tensors, other)])
        elif isinstance(other, numbers.Number) or (
            isinstance(other, torch.Tensor) and (
                other.numel() == 1 and other.ndim <= 1
            )
        ):
            return TensorList([op(u, other) for u in self.tensors])
        else:
            # NOTE: the message always names '*', whichever operator was used.
            raise TypeError(
                f'unsupported operand for *: "TensorList" and "{type(other)}"'
            )