import numpy as np
import torch

# 1. Add heavyweight (data) validation helper.
# 2. Add construction helpers
# 3. Make indentation consistent
# 4. Replace asserts with descriptive errors.

##
### Validation helpers.
##


def _validate_matrix(shape, data, row_indices, column_indices, offsets):
    # Data should be [nnz, block_size, block_size]
    if data.dim() == 1:
        data = torch.reshape(data, [data.numel(), 1, 1])

    # Blocks should be square.
    if data.shape[-2] != data.shape[-1]:
        raise ValueError(
            "Expected square blocking in data. "
            f"Got block shape {[data.shape[-2], data.shape[-1]]}")

    # Flatten batch dimensions on data - original shape preserved
    # in the shape argument.
    block_size = data.shape[-1]
    data = data.view([-1, block_size, block_size])
    if data.dim() != 3:
        raise ValueError(
            "Expected 3D shape for data (nnz, block, block). "
            f"Got {data.dim()}D shape.")

    block_size = data.shape[1]
    if shape[-2] % block_size != 0 or shape[-1] % block_size != 0:
        raise ValueError(
            "Matrix shape must be divisible by blocking. "
            f"Got shape {shape} with "
            f"{[block_size, block_size]} blocking.")
    if np.prod(shape) < data.numel():
        raise ValueError(
            "Invalid matrix. Number of nonzeros exceeds matrix capacity "
            f"({data.numel()} v. {np.prod(shape)})")
    if row_indices.dim() != 1:
        raise ValueError(
            f"Expected 1D row_indices. Got {row_indices.dim()}D row_indices.")
    if column_indices.dim() != 1:
        raise ValueError(
            f"Expected 1D column_indices. Got {column_indices.dim()}D column_indices.")
    if offsets.dim() != 1:
        raise ValueError(
            f"Expected 1D offsets. Got {offsets.dim()}D offsets.")

    if row_indices.numel() != data.shape[0]:
        raise ValueError(
            "Expected 1 index per nonzero block. "
            f"Got {row_indices.numel()} row_indices for {data.shape[0]} blocks")
    if column_indices.numel() != data.shape[0]:
        raise ValueError(
            "Expected 1 index per nonzero block. "
            f"Got {column_indices.numel()} column_indices for {data.shape[0]} blocks")

    block_rows = np.prod(shape[:-1]) / block_size
    if offsets.numel() != block_rows + 1:
        raise ValueError(
            "Expected one offset per block row plus one. "
            f"Got {offsets.numel()} offsets with {block_rows} block rows.")

    is_cuda = (data.is_cuda and
               row_indices.is_cuda and
               column_indices.is_cuda and
               offsets.is_cuda)
    is_cpu = (not data.is_cuda and
              not row_indices.is_cuda and
              not column_indices.is_cuda and
              not offsets.is_cuda)
    if not (is_cuda or is_cpu):
        raise ValueError(
            "Expected data & meta-data on common device. "
            f"Got data on {data.device}, row_indices on {row_indices.device} "
            f"column_indices on {column_indices.device} and "
            f"offsets on {offsets.device}.")

    if data.dtype != torch.float16:
        raise ValueError(
            f"Expected float16 data. Got {data.dtype} data.")
    if row_indices.dtype != torch.int16:
        raise ValueError(
            f"Expected int16 row_indices. Got {row_indices.dtype} row_indices.")
    if column_indices.dtype != torch.int16:
        raise ValueError(
            f"Expected int16 column_indices. Got {column_indices.dtype} column_indices.")
    if offsets.dtype != torch.int32:
        raise ValueError(
            f"Expected int32 offsets. Got {offsets.dtype} offsets.")
    return data


def _transpose(size, data, row_indices, column_indices, offsets):
    block_columns = size[1] // data.shape[1]

    # Sort row indices by column indices to get the transposed matrix's
    # column indices.
    gather_indices = column_indices.argsort()
    column_indices_t = row_indices.gather(0, gather_indices)
    block_offsets_t = gather_indices.int()

    # NOTE: Histogram is not implemented for any integer type on CPU. Do
    # the histogram in 32-bit float, which can exactly represent 16-bit
    # integers.
    column_indices_float = column_indices.float()

    zero = torch.zeros((1,), dtype=torch.int32, device=data.device)
    nnz_per_column = column_indices_float.histc(block_columns, 0, block_columns)
    nnz_per_column = nnz_per_column.int()
    offsets_t = torch.cat([zero, nnz_per_column.cumsum(0, dtype=torch.int32)])
    return column_indices_t, offsets_t, block_offsets_t
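
# Worked example (illustrative, not part of the original module): for a 4x4
# matrix with 2x2 blocking and nonzero blocks at block coordinates (0, 0),
# (0, 1) and (1, 1), stored row-major, the BCSR meta-data is
#     row_indices    = [0, 0, 1]
#     column_indices = [0, 1, 1]
#     offsets        = [0, 2, 3]
# and _transpose() above produces
#     column_indices_t = [0, 0, 1]  (original row of each block, sorted by column)
#     offsets_t        = [0, 1, 3]  (one nonzero block in column 0, two in column 1)
#     block_offsets_t  = [0, 1, 2]  (index of each transposed block in `data`)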


class Matrix(torch.nn.Module):
    """A matrix stored in sparse format.

    Underlying format is block compressed sparse row (BCSR). See the
    illustrative example at the bottom of this file for how a small BCSR
    matrix is constructed.

    TODO(tgale): Make this mirror torch.Tensor API as much as possible.
    """

    def __init__(self,
                 size,
                 data,
                 row_indices,
                 column_indices,
                 offsets,
                 column_indices_t=None,
                 offsets_t=None,
                 block_offsets_t=None):
        super().__init__()
        self._size = size
        self._data = data
        self._row_indices = row_indices
        self._column_indices = column_indices
        self._offsets = offsets

        # Produce the transpose meta-data if it is not passed in.
        if ((column_indices_t is None) or (offsets_t is None) or
            (block_offsets_t is None)):
            column_indices_t, offsets_t, block_offsets_t = _transpose(
                size, data, row_indices, column_indices, offsets)
        self._column_indices_t = column_indices_t
        self._offsets_t = offsets_t
        self._block_offsets_t = block_offsets_t

        self._transposed = False
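
        # With int16 block indices, the largest addressable block coordinate
        # is np.iinfo(np.int16).max = 32767, so each matrix dimension is
        # limited to 32767 * block_size elements.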
        # Validate that our metadata will not overflow.
        max_dim = np.iinfo(np.int16).max * self.blocking
        if column_indices.dtype == torch.int16:
            if size[0] > max_dim or size[1] > max_dim:
                raise ValueError(
                    f"Sparse matrix with shape {size} exceeds representable "
                    "size with 16-bit indices.")

    def validate(self):
        _validate_matrix(self._size,
                         self._data,
                         self._row_indices,
                         self._column_indices,
                         self._offsets)

        # TODO(tgale): Add heavyweight data validation.

    def to(self, device):
        # TODO(tgale): Handle type conversions here. We
        # need to set the appropriate meta-data type for
        # the given floating-point type.
        self._data = self._data.to(device)
        self._row_indices = self._row_indices.to(device)
        self._column_indices = self._column_indices.to(device)
        self._offsets = self._offsets.to(device)
        self._column_indices_t = self._column_indices_t.to(device)
        self._offsets_t = self._offsets_t.to(device)
        self._block_offsets_t = self._block_offsets_t.to(device)
        return self

    def cuda(self):
        return self.to(torch.cuda.current_device())

    def clone(self):
        return Matrix(
            self.size(),
            self.data.clone(),
            self.row_indices.clone(),
            self.column_indices.clone(),
            self.offsets.clone(),
            self.column_indices_t.clone(),
            self.offsets_t.clone(),
            self.block_offsets_t.clone())

    def t(self):
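        # Transposition is lazy: the returned Matrix shares data and
        # meta-data with self; only the _transposed flag and the logical
        # size are flipped.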
        if self.dim() != 2:
            raise ValueError(
                "t() expects a tensor with <= 2 dimensions, "
                f"but self is {self.dim()}D.")
        out = Matrix(self.size(),
                     self.data,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        out._transposed = not self._transposed
        out._size = torch.Size((self._size[1], self._size[0]))
        return out

    def contiguous(self):
        raise ValueError("Not yet implemented.")

    def is_contiguous(self):
        return not self._transposed

    @property
    def is_cuda(self):
        return self._data.is_cuda

    @property
    def device(self):
        return self._data.device

    def size(self):
        return self._size

    @property
    def shape(self):
        return self.size()

    def dim(self):
        return len(self._size)

    @property
    def data(self):
        return self._data

    @property
    def row_indices(self):
        return self._row_indices

    @property
    def column_indices(self):
        return self._column_indices

    @property
    def offsets(self):
        return self._offsets

    @property
    def offsets_t(self):
        return self._offsets_t

    @property
    def column_indices_t(self):
        return self._column_indices_t

    @property
    def block_offsets_t(self):
        return self._block_offsets_t

    @property
    def dtype(self):
        return self.data.dtype

    @property
    def nnz(self):
        return self.data.numel()

    @property
    def blocking(self):
        return self.data.shape[1]

    @property
    def requires_grad(self):
        return self.data.requires_grad

    def requires_grad_(self, x):
        self.data.requires_grad_(x)
        return self

    def view(self, *shape):
        if not self.is_contiguous():
            raise ValueError(
                "view() is only supported on contiguous (non-transposed) matrices.")
        if shape[-1] != self.size()[-1]:
            raise ValueError(
                "Can't change view on compressed dimension. "
                f"{self.size()[-1]} v. {shape[-1]}.")
        if np.prod(shape) != np.prod(self.size()):
            raise ValueError(
                "Mismatch in numel of Matrix and new shape. "
                f"{np.prod(self.size())} v. {np.prod(shape)}")
        return Matrix(shape,
                      self.data,
                      self.row_indices,
                      self.column_indices,
                      self.offsets,
                      self.column_indices_t,
                      self.offsets_t,
                      self.block_offsets_t)

    @property
    def grad(self):
        # TODO(tgale): Make sure this mirrors torch.Tensor
        # behavior in the case where we ask for the gradient
        # of a non-contiguous tensor.
        size = self.size()
        if not self.is_contiguous():
            size = torch.Size((size[1], size[0]))
        out = Matrix(size,
                     self.data.grad,
                     self.row_indices,
                     self.column_indices,
                     self.offsets,
                     self.column_indices_t,
                     self.offsets_t,
                     self.block_offsets_t)
        return out if self.is_contiguous() else out.t()
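

# The snippet below is an illustrative sketch, not part of the original
# module: it relies only on the constructor and properties defined above and
# builds the same 4x4 matrix with 2x2 blocking used in the _transpose()
# example to exercise construction, validation and the lazy transpose.
if __name__ == "__main__":
    data = torch.randn(3, 2, 2).half()
    row_indices = torch.tensor([0, 0, 1], dtype=torch.int16)
    column_indices = torch.tensor([0, 1, 1], dtype=torch.int16)
    offsets = torch.tensor([0, 2, 3], dtype=torch.int32)

    mtx = Matrix((4, 4), data, row_indices, column_indices, offsets)
    mtx.validate()
    print(mtx.shape, mtx.nnz, mtx.blocking)        # (4, 4) 12 2
    print(mtx.t().shape, mtx.t().is_contiguous())  # torch.Size([4, 4]) False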