# Source code for pennylane.workflow.get_compile_pipeline

# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a function for getting the compile pipeline of a given QNode."""

from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, ParamSpec

from pennylane.transforms.core import CompilePipeline
from pennylane.workflow import construct_execution_config, marker
from pennylane.workflow._setup_transform_program import _setup_transform_program

if TYPE_CHECKING:
    from collections.abc import Callable

    from pennylane.devices.execution_config import ExecutionConfig
    from pennylane.workflow import QNode

P = ParamSpec("P")


def _find_level(program: CompilePipeline, level: str) -> int:
    """Return the index of the marker transform whose name equals ``level``.

    Args:
        program (CompilePipeline): the pipeline to scan, in order.
        level (str): the marker name to look for.

    Returns:
        int: position in ``program`` of the first marker named ``level``.

    Raises:
        ValueError: if no marker named ``level`` is present. The message lists
            the marker names that were encountered during the scan.
    """
    found_levels = []
    for position, transform in enumerate(program):
        is_marker = transform.tape_transform == marker.tape_transform
        if not is_marker:
            continue

        # The marker name is taken from the first positional argument when there
        # is one, and from the ``level`` keyword argument otherwise.
        if transform.args:
            name = transform.args[0]
        else:
            name = transform.kwargs["level"]
        found_levels.append(name)

        # Stop at the first match; later entries are never inspected.
        if name == level:
            return position

    raise ValueError(
        f"level {level} not found in compile pipeline. "
        "Builtin options are 'top', 'user', 'device', and 'gradient'."
        f" Custom levels are {found_levels}."
    )


def _resolve_level(
    level: str | int | slice,
    full_pipeline: CompilePipeline,
    num_user: int,
    config: ExecutionConfig,
) -> slice:
    """Translate a ``level`` specification into a slice.

    Args:
        level (str, int, slice): a builtin level name (``"top"``, ``"user"``,
            ``"gradient"``, ``"device"``), a marker name, a number of transforms,
            or a slice.
        full_pipeline (CompilePipeline): searched for a marker when ``level`` is
            a string that is not one of the builtin names.
        num_user (int): stop index used for the ``"user"`` level.
        config (ExecutionConfig): only ``config.gradient_method`` is read, and
            only for the ``"gradient"`` level.

    Returns:
        slice: the resolved slice. A ``level`` that is neither a string nor an
        integer is returned unchanged.
    """
    # Builtin names are tested before the generic string case, so they are never
    # looked up as markers.
    if level == "top":
        return slice(0, 0)

    if level == "user":
        return slice(0, num_user)

    if level == "gradient":
        # One extra transform is included when the gradient method exposes an
        # ``expand_transform`` attribute.
        has_expand = hasattr(config.gradient_method, "expand_transform")
        return slice(0, num_user + int(has_expand))

    if level == "device":
        # Open-ended stop: user + gradient + device + final.
        return slice(0, None)

    if isinstance(level, str):
        return slice(0, _find_level(full_pipeline, level))

    if isinstance(level, int):
        return slice(0, level)

    return level


def get_compile_pipeline(
    qnode: QNode,
    level: str | int | slice = "device",
) -> Callable[P, CompilePipeline]:
    """Extract a compile pipeline at a designated level.

    Args:
        qnode (QNode): The QNode to get the compile pipeline for.
        level (str, int, slice): An indication of what transforms to use from the full compile pipeline.

            - ``"top"``: Returns an empty compile pipeline.
            - ``"user"``: Retrieves a compile pipeline containing manually applied user transformations.
            - ``"gradient"``: Retrieves a compile pipeline that includes user transformations and any relevant gradient transformations.
            - ``"device"``: Retrieves the entire compile pipeline (user + gradient + device) that is used for execution.
            - ``str``: Can also accept a string corresponding to the name of a marker that was manually added to the compile pipeline.
            - ``int``: Can also accept an integer, corresponding to a number of transforms in the program. ``level=0`` corresponds to the start of the program.
            - ``slice``: Can also accept a ``slice`` object to select an arbitrary subset of the compile pipeline.

    Returns:
        Callable: A function that accepts the same arguments as ``qnode`` and returns the
        :class:`CompilePipeline` corresponding to the requested level.

    Raises:
        ValueError: If ``level`` is not an integer, a slice or a string.
        ValueError: When the returned function is called, if ``level`` is a string that is neither
            a builtin level nor the name of a marker found in the compile pipeline.

    **Example:**

    Consider this simple circuit,

    .. code-block:: python

        dev = qml.device("default.qubit")

        @qml.transforms.merge_rotations
        @qml.transforms.cancel_inverses
        @qml.qnode(dev)
        def circuit():
            qml.RX(1, wires=0)
            qml.H(0)
            qml.H(0)
            qml.RX(1, wires=0)
            return qml.expval(qml.Z(0))

    We can retrieve the compile pipeline used during execution with,

    >>> get_compile_pipeline(circuit)() # or level="device"
    CompilePipeline(cancel_inverses, merge_rotations, defer_measurements, decompose, device_resolve_dynamic_wires, validate_device_wires, validate_measurements, _conditional_broadcast_expand, no_sampling)

    or use the ``level`` argument to inspect specific stages of the pipeline.

    >>> get_compile_pipeline(circuit, level="user")()
    CompilePipeline(cancel_inverses, merge_rotations)

    .. details::
        :title: Usage Details

        Consider the circuit below which is loaded with user applied transforms,
        a checkpoint marker and uses the parameter-shift gradient method,

        .. code-block:: python

            dev = qml.device("default.qubit")

            @qml.metric_tensor
            @qml.transforms.merge_rotations
            @qml.marker("checkpoint")
            @qml.transforms.cancel_inverses
            @qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"shifts": np.pi / 4})
            def circuit(x):
                qml.RX(x, wires=0)
                qml.H(0)
                qml.H(0)
                qml.RX(x, wires=0)
                return qml.expval(qml.Z(0))

        By default, without specifying a ``level`` we will get the full compile pipeline that is used
        during execution on this device. Note that this can also be retrieved by manually specifying
        ``level="device"``,

        >>> get_compile_pipeline(circuit)(3.14)
        CompilePipeline(cancel_inverses, marker, merge_rotations, _expand_metric_tensor, metric_tensor, _expand_transform_param_shift, defer_measurements, decompose, device_resolve_dynamic_wires, validate_device_wires, validate_measurements, _conditional_broadcast_expand)

        As can be seen above, this not only includes the transforms we manually applied, but also a
        set of transforms used by the device in order to execute the circuit.

        The ``"user"`` level will retrieve the portion of the compile pipeline that was manually
        applied by the user to the qnode,

        >>> get_compile_pipeline(circuit, level="user")(3.14)
        CompilePipeline(cancel_inverses, marker, merge_rotations, _expand_metric_tensor, metric_tensor)

        The ``"gradient"`` level builds on top of this to then add any relevant gradient transforms,

        >>> get_compile_pipeline(circuit, level="gradient")(3.14)
        CompilePipeline(cancel_inverses, marker, merge_rotations, _expand_metric_tensor, metric_tensor, _expand_transform_param_shift)

        which in this case is ``_expand_transform_param_shift``, a transform that expands all
        trainable operations to a state where the parameter shift transform can operate on them.

        We can use ``qml.marker`` to further subdivide our compile pipeline into stages,

        >>> get_compile_pipeline(circuit, level="checkpoint")(3.14)
        CompilePipeline(cancel_inverses)

        If ``"top"`` or ``0`` are specified, an empty compile pipeline will be returned,

        >>> get_compile_pipeline(circuit, level=0)(3.14)
        CompilePipeline()

        >>> get_compile_pipeline(circuit, level="top")(3.14)
        CompilePipeline()

        Integer levels correspond to the number of transforms to retrieve from the compile pipeline,

        >>> get_compile_pipeline(circuit, level=3)(3.14)
        CompilePipeline(cancel_inverses, marker, merge_rotations)

        Slice levels enable you to extract a specific range of transformations in the compile
        pipeline. For example, we can retrieve the second to fourth transform by using a slice,

        >>> get_compile_pipeline(circuit, level=slice(1,4))(3.14)
        CompilePipeline(marker, merge_rotations, _expand_metric_tensor)

    """
    # Validate eagerly so that an unsupported ``level`` fails when the helper is
    # built rather than when the returned function is eventually called.
    if not isinstance(level, (int, slice, str)):
        raise ValueError(
            f"'level={level}' of type '{type(level)}' is not supported. Please provide an integer, slice or a string as input."
        )

    @wraps(qnode)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> CompilePipeline:
        # The execution config is resolved with the call arguments, so the
        # pipeline is built per call.
        resolved_config = construct_execution_config(qnode, resolve=True)(*args, **kwargs)

        # Start from a fresh pipeline so the QNode's own pipeline object is not
        # the target of the in-place additions below.
        full_compile_pipeline = CompilePipeline()
        full_compile_pipeline += qnode.compile_pipeline

        # NOTE: User transforms that contain an informative transform bypass gradient + device transforms
        if not qnode.compile_pipeline.is_informative:
            outer_pipeline, inner_pipeline = _setup_transform_program(qnode.device, resolved_config)
            full_compile_pipeline += outer_pipeline + inner_pipeline

        num_user = len(qnode.compile_pipeline)
        level_slice: slice = _resolve_level(level, full_compile_pipeline, num_user, resolved_config)
        return full_compile_pipeline[level_slice]

    return wrapper