python.control-flow ======================= dynamic_shape_if_guard ^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`torch.dynamic-shape <torch.dynamic-shape>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class DynamicShapeIfGuard(torch.nn.Module): """ `if` statement with backed dynamic shape predicate will be specialized into one particular branch and generate a guard. However, export will fail if the dimension is marked as dynamic shape from higher level API. """ def forward(self, x): if x.shape[0] == 3: return x.cos() return x.sin() Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3, 2, 2]): # sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1) sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2) eq = sym_size_int_2 == 2; sym_size_int_2 = None scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 2'); scalar_tensor_default = None eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None eq_2 = sym_size_int == 3; sym_size_int = None scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_2 = None cos_default: f32[3, 2, 2] = torch.ops.aten.cos.default(arg0_1); arg0_1 = None return (cos_default,) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cos_default'], inputs_to_parameters={}, 
inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Symbol to range: {} list_unpack ^^^^^^^^^^^ .. note:: Tags: :doc:`python.control-flow <python.control-flow>`, :doc:`python.data-structure <python.data-structure>` Support Level: SUPPORTED Original source code: .. code-block:: python from typing import List import torch def list_unpack(args: List[torch.Tensor]): """ Lists are treated as static construct, therefore unpacking should be erased after tracing. """ x, *y = args return x + y[0] Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3, 2], arg1_1: i64[], arg2_1: i64[]): # sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1) eq = sym_size_int_1 == 2; sym_size_int_1 = None scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None eq_1 = sym_size_int == 3; sym_size_int = None scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None return (add_tensor,) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1', 'arg2_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Symbol to range: {} static_for_loop ^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`python.control-flow <python.control-flow>` Support Level: SUPPORTED Original source code: .. 
code-block:: python import torch class StaticForLoop(torch.nn.Module): """ A for loop with constant number of iterations should be unrolled in the exported graph. """ def __init__(self): super().__init__() def forward(self, x): ret = [] for i in range(10): # constant ret.append(i + x) return ret Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3, 2]): # sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1) eq = sym_size_int_1 == 2; sym_size_int_1 = None scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default = None eq_1 = sym_size_int == 3; sym_size_int = None scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_1 = None add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 0) add_tensor_1: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 1) add_tensor_2: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 2) add_tensor_3: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 3) add_tensor_4: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 4) add_tensor_5: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 5) add_tensor_6: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 6) add_tensor_7: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 7) add_tensor_8: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 8) add_tensor_9: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, 9); arg0_1 = None return (add_tensor, add_tensor_1, add_tensor_2, add_tensor_3, add_tensor_4, add_tensor_5, add_tensor_6, add_tensor_7, add_tensor_8, add_tensor_9) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], 
user_outputs=['add_tensor', 'add_tensor_1', 'add_tensor_2', 'add_tensor_3', 'add_tensor_4', 'add_tensor_5', 'add_tensor_6', 'add_tensor_7', 'add_tensor_8', 'add_tensor_9'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Symbol to range: {} static_if ^^^^^^^^^ .. note:: Tags: :doc:`python.control-flow <python.control-flow>` Support Level: SUPPORTED Original source code: .. code-block:: python import torch class StaticIf(torch.nn.Module): """ `if` statement with static predicate value should be traced through with the taken branch. """ def __init__(self): super().__init__() def forward(self, x): if len(x.shape) == 3: return x + torch.ones(1, 1, 1) return x Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3, 2, 2]): # sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0) sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1) sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2) eq = sym_size_int_2 == 2; sym_size_int_2 = None scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq); eq = None _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 2'); scalar_tensor_default = None eq_1 = sym_size_int_1 == 2; sym_size_int_1 = None scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1); eq_1 = None _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2'); scalar_tensor_default_1 = None eq_2 = sym_size_int == 3; sym_size_int = None scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2); eq_2 = None _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default_2 = None full_default: f32[1, 1, 1] = torch.ops.aten.full.default([1, 1, 1], 1, dtype = torch.float32, layout = torch.strided, device = 
device(type='cpu'), pin_memory = False) add_tensor: f32[3, 2, 2] = torch.ops.aten.add.Tensor(arg0_1, full_default); arg0_1 = full_default = None return (add_tensor,) Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None) Symbol to range: {}