| def batch_broadcast(a, x): | |
| """Broadcasts a over all dimensions of x, except the batch dimension, which must match.""" | |
| if len(a.shape) != 1: | |
| a = a.squeeze() | |
| if len(a.shape) != 1: | |
| raise ValueError( | |
| f"Don't know how to batch-broadcast tensor `a` with more than one effective dimension (shape {a.shape})" | |
| ) | |
| if a.shape[0] != x.shape[0] and a.shape[0] != 1: | |
| raise ValueError( | |
| f"Don't know how to batch-broadcast shape {a.shape} over {x.shape} as the batch dimension is not matching") | |
| out = a.view((x.shape[0], *(1 for _ in range(len(x.shape)-1)))) | |
| return out | |