P#linear-algebra

Definition

Tensor contraction

Tensor contraction is an operation that combines tensors by multiplying matching entries and summing over selected pairs of axes.

In implementation-oriented form, let A be a tensor of shape (d_1, ..., d_m) and B a tensor of shape (e_1, ..., e_n).

Choose axis lists

lhs_axes = (a_1, ..., a_k) on A and rhs_axes = (b_1, ..., b_k) on B,

with matching dimensions

d_(a_i) = e_(b_i) for i = 1, ..., k.

The contraction of A and B over these axes is the tensor obtained by summing over the matched coordinates while keeping all non-contracted axes.

This generalises familiar operations such as the dot product, matrix multiplication, and the trace.

A self-contraction can be visualised as summing over a diagonal. The diagram below shows the main diagonal of a tensor, which is the set of entries selected by the condition that the paired indices are equal (i = j for a matrix) before summation.

Index form

If I denotes the free indices of A, J denotes the free indices of B, and k_1, ..., k_p are the contracted indices, then the result is given by

C[I, J] = Σ_{k_1, ..., k_p} A[I, k_1, ..., k_p] · B[k_1, ..., k_p, J].

The precise position of the free indices depends on the chosen axis lists, but the idea is always the same:

  • free axes remain in the output;
  • contracted axes are summed away.
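As a concrete sketch of the index form, the following plain-Python snippet (not the document's pseudocode language) contracts axis 1 of A with axis 0 of B — ordinary matrix multiplication:

```python
# C[i, j] = sum_k A[i][k] * B[k][j]: the free axes i and j remain in
# the output, while the contracted axis k is summed away.
A = [[1, 2],
     [3, 4]]          # shape (2, 2)
B = [[5, 6, 7],
     [8, 9, 10]]      # shape (2, 3)

rows, inner, cols = len(A), len(B), len(B[0])
C = [[sum(A[i][k] * B[k][j] for k in range(inner)) for j in range(cols)]
     for i in range(rows)]
# C has shape (2, 3): both free axes survive, k is gone.
```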

Shape

Output shape

A practical convention is:

  1. keep all non-contracted axes of the left tensor in their original order;
  2. then keep all non-contracted axes of the right tensor in their original order.

If lhs_axes and rhs_axes describe the contracted axes, then

out_shape = drop(lhs.shape, lhs_axes) ++ drop(rhs.shape, rhs_axes),

where drop removes the listed axes from a shape and ++ concatenates the remaining dimensions.

This convention makes binary contraction easy to implement and covers the usual linear-algebra cases.
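A minimal Python sketch of this convention (the function name is illustrative, not part of any particular library):

```python
def output_shape(lhs_shape, rhs_shape, lhs_axes, rhs_axes):
    # Keep the non-contracted axes of the left tensor in their original
    # order, then the non-contracted axes of the right tensor.
    keep_lhs = [d for a, d in enumerate(lhs_shape) if a not in lhs_axes]
    keep_rhs = [d for a, d in enumerate(rhs_shape) if a not in rhs_axes]
    return keep_lhs + keep_rhs

# Matrix multiplication: contract axis 1 of a (2, 3) tensor
# with axis 0 of a (3, 4) tensor.
print(output_shape((2, 3), (3, 4), [1], [0]))
```

A fully contracted pair, such as a dot product of two vectors, yields the empty shape, i.e. a scalar.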

Implementation

Generic contraction

A simple implementation does not need a special matrix-multiplication kernel. It only needs:

  • the shapes of the left and right tensors;
  • the selected contraction axes;
  • strides or another way to map multi-indices to storage offsets.

A direct algorithm is:

Tensor contract(
    Tensor lhs,
    Tensor rhs,
    usize[] lhs_axes,
    usize[] rhs_axes
) {
    usize[] lhs_free_axes = free_axes(lhs.shape, lhs_axes);
    usize[] rhs_free_axes = free_axes(rhs.shape, rhs_axes);
    usize[] out_shape = dims_at(lhs.shape, lhs_free_axes) ++ dims_at(rhs.shape, rhs_free_axes);

    Tensor out = allocate(out_shape);

    for each output index out_idx {
        usize[] out_coords = decode(out_idx, out_shape);

        Offset lhs_base = project_free_axes(out_coords, lhs_free_axes, lhs.strides);
        Offset rhs_base = project_free_axes(out_coords, rhs_free_axes, rhs.strides);

        Value acc = 0;
        for each contracted coordinate k {
            Offset lhs_offset = lhs_base + project_contracted_axes(k, lhs_axes, lhs.strides);
            Offset rhs_offset = rhs_base + project_contracted_axes(k, rhs_axes, rhs.strides);
            acc += lhs[lhs_offset] * rhs[rhs_offset];
        }

        out[out_idx] = acc;
    }

    return out;
}

This is not the fastest implementation, but it is general and correct. Faster backends can recognise special cases such as matrix multiplication and replace the generic loop nest with an optimised kernel.
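The same algorithm can be sketched in runnable Python over flat row-major storage. This is a direct translation of the pseudocode above, with strides computed explicitly; tensors are passed as flat lists plus a shape:

```python
from itertools import product

def strides_of(shape):
    # Row-major strides: the last axis varies fastest.
    strides, acc = [0] * len(shape), 1
    for a in range(len(shape) - 1, -1, -1):
        strides[a] = acc
        acc *= shape[a]
    return strides

def contract(lhs, lhs_shape, rhs, rhs_shape, lhs_axes, rhs_axes):
    lhs_strides, rhs_strides = strides_of(lhs_shape), strides_of(rhs_shape)
    lhs_free = [a for a in range(len(lhs_shape)) if a not in lhs_axes]
    rhs_free = [a for a in range(len(rhs_shape)) if a not in rhs_axes]
    out_shape = [lhs_shape[a] for a in lhs_free] + [rhs_shape[a] for a in rhs_free]
    contracted_dims = [lhs_shape[a] for a in lhs_axes]

    out = []
    for coords in product(*(range(d) for d in out_shape)):
        # Split the output coordinates into left-free and right-free parts
        # and project them onto storage offsets.
        lhs_base = sum(c * lhs_strides[a] for c, a in zip(coords[:len(lhs_free)], lhs_free))
        rhs_base = sum(c * rhs_strides[a] for c, a in zip(coords[len(lhs_free):], rhs_free))
        acc = 0
        for k in product(*(range(d) for d in contracted_dims)):
            lhs_off = lhs_base + sum(c * lhs_strides[a] for c, a in zip(k, lhs_axes))
            rhs_off = rhs_base + sum(c * rhs_strides[a] for c, a in zip(k, rhs_axes))
            acc += lhs[lhs_off] * rhs[rhs_off]
        out.append(acc)
    return out, out_shape
```

For example, contracting axis 1 of a (2, 2) tensor with axis 0 of a (2, 3) tensor reproduces matrix multiplication, and contracting two vectors over their only axes yields a scalar (empty output shape).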

Conditions

Well-formedness

A binary contraction is well formed only if:

  • lhs_axes.len = rhs_axes.len;
  • every listed axis is in range;
  • no axis is repeated on either side;
  • paired contracted dimensions match.

These checks are best performed before execution, ideally when the tensor expression is constructed.
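A sketch of such a validation pass in Python (the function name and error messages are illustrative):

```python
def check_contraction(lhs_shape, rhs_shape, lhs_axes, rhs_axes):
    # Raise ValueError if the requested contraction is malformed.
    if len(lhs_axes) != len(rhs_axes):
        raise ValueError("axis lists must have equal length")
    for axes, shape in ((lhs_axes, lhs_shape), (rhs_axes, rhs_shape)):
        if any(a < 0 or a >= len(shape) for a in axes):
            raise ValueError("contraction axis out of range")
        if len(set(axes)) != len(axes):
            raise ValueError("axis repeated in contraction list")
    for a, b in zip(lhs_axes, rhs_axes):
        if lhs_shape[a] != rhs_shape[b]:
            raise ValueError("paired contracted dimensions do not match")
```

Running this at expression-construction time turns shape bugs into immediate errors instead of silent out-of-bounds reads during execution.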

Examples

Common special cases

The following operations are all contractions.

Dot product:

a · b = Σ_i a_i b_i (contract axis 0 of a with axis 0 of b).

Matrix multiplication:

(AB)_ij = Σ_k A_ik B_kj (contract axis 1 of A with axis 0 of B).

Matrix trace:

tr(A) = Σ_i A_ii (a self-contraction pairing axes 0 and 1 of A).

In each case, multiplication is followed by summation over matched indices.
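All three special cases can be written as the same multiply-then-sum pattern in a few lines of plain Python:

```python
# Dot product: contract the only axis of each vector.
a, b = [1, 2, 3], [4, 5, 6]
dot = sum(a[i] * b[i] for i in range(3))

# Matrix multiplication: contract axis 1 of A with axis 0 of B.
A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
AB = [[sum(A[i][k] * B[k][j] for k in range(2)) for j in range(2)]
      for i in range(2)]

# Matrix trace: a self-contraction pairing the two axes of A.
trace = sum(A[i][i] for i in range(2))
```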