Shortcuts

Source code for bundled_program.config

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import get_args, List, Optional, Sequence, Union

import torch
from torch.utils._pytree import tree_flatten

from typing_extensions import TypeAlias

"""
The data types currently supported for element to be bundled. It should be
consistent with the types in bundled_program.schema.Value.
"""
ConfigValue: TypeAlias = Union[
    torch.Tensor,
    int,
    bool,
    float,
]

"""
The data type of the input for method single execution.
"""
MethodInputType: TypeAlias = Sequence[ConfigValue]

"""
The data type of the output for method single execution.
"""
MethodOutputType: TypeAlias = Sequence[torch.Tensor]

"""
All supported types for input/expected output of MethodTestCase.

Namedtuple is also supported and listed implicity since it is a subclass of tuple.
"""

# pyre-ignore
DataContainer: TypeAlias = Union[list, tuple, dict]


class MethodTestCase:
    """Test case with inputs and expected outputs for one method execution.

    The expected_outputs are optional and only required if the user wants to
    verify model outputs after execution.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with the specific
                inference method for a single execution. The (possibly nested)
                container is flattened and stored in ``self.inputs``.

                It is worth mentioning that, although both bundled program and
                ET runtime apis support setting inputs other than torch.Tensor
                type, only the inputs of torch.Tensor type will actually be
                updated in the method; the rest of the inputs will just be
                sanity-checked against the default value in the method.

            expected_outputs: Expected outputs for the given inputs, used for
                verification. It can be None if the user only wants to use the
                test case for profiling, in which case ``self.expected_outputs``
                is an empty list.

        Raises:
            AssertionError: If any flattened element of ``inputs`` or
                ``expected_outputs`` is None or is not one of the types in
                ``ConfigValue``.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Misaligned data type between MethodTestCase attribute and sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Misaligned data type between MethodTestCase attribute and sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every leaf has a legal type.

        Args:
            unflatten_data: Possibly nested container of values to flatten.

        Returns:
            flatten_data: The leaves produced by ``tree_flatten``, each of which
                is an instance of one of the ``ConfigValue`` types.

        Raises:
            AssertionError: If a leaf is None or is not a ``ConfigValue`` type.
        """
        flatten_data, _ = tree_flatten(unflatten_data)
        # Loop-invariant: the tuple of accepted leaf types.
        supported_types = get_args(ConfigValue)

        for data in flatten_data:
            # Explicit raises instead of `assert`: assert statements are
            # stripped under `python -O`, which would silently skip validation.
            # AssertionError is kept so existing callers catching it still work.
            #
            # The None check must come first. None is not in ConfigValue, so
            # checking the type first would make this message unreachable.
            if data is None:
                raise AssertionError(
                    "The input {} should not be in null type.\n".format(data)
                )
            if not isinstance(data, supported_types):
                raise AssertionError(
                    "The type of input {} with type {} is not supported.\n".format(
                        data, type(data)
                    )
                )

        return flatten_data


@dataclass
class MethodTestSuite:
    """All test info related to verifying a method.

    Attributes:
        method_name: Name of the method to be verified.
        test_cases: All test cases for verifying the method.
    """

    method_name: str
    test_cases: Sequence[MethodTestCase]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources