torch.dynamic-shape
=======================

cond_branch_class_method
^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.cond`, :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    from functorch.experimental.control_flow import cond

    class MySubModule(torch.nn.Module):
        def foo(self, x):
            return x.cos()

        def forward(self, x):
            return self.foo(x)

    class CondBranchClassMethod(torch.nn.Module):
        """
        The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules:
          - both branches must take the same args, which must also match the branch args passed to cond.
          - both branches must return a single tensor
          - returned tensor must have the same tensor metadata, e.g. shape and dtype
          - branch function can be a free function, nested function, lambda, or class method
          - branch function can not have closure variables
          - no inplace mutations on inputs or global variables

        This example demonstrates using a class method in cond().

        NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
        """

        def __init__(self):
            super().__init__()
            self.subm = MySubModule()

        def bar(self, x):
            return x.sin()

        def forward(self, x):
            return cond(x.shape[0] <= 2, self.subm.forward, self.bar, [x])

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                eq = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default = None
                submodule_0 = self.submodule_0
                submodule_1 = self.submodule_1
                cond: f32[3] = torch.ops.higher_order.cond(False, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
                return (cond,)

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[3]):
                    cos_default: f32[3] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                    return cos_default

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[3]):
                    sin_default: f32[3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                    return sin_default

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
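A rough usage sketch, not part of the original example: it assumes `torch.export.export` is available (PyTorch 2.1+) and reuses the `CondBranchClassMethod` definition above.

.. code-block:: python

    import torch
    from torch.export import export  # assumption: PyTorch 2.1+ entry point

    m = CondBranchClassMethod()
    x = torch.randn(3)

    # Trace the module into an ExportedProgram; the graph should resemble the
    # Result block above, with the predicate x.shape[0] <= 2 folded to False.
    ep = export(m, (x,))
    print(ep)

    # With batch size 3 the False branch (self.bar, i.e. sin) is taken.
    print(torch.allclose(ep.module()(x), x.sin()))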
""" def true_fn(x): def inner_true_fn(y): return x + y return inner_true_fn(x) def false_fn(x): def inner_false_fn(y): return x - y return inner_false_fn(x) return cond(x.shape[0] < 10, true_fn, false_fn, [x]) Result: .. code-block:: ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3]): # sym_size_int - torch.ops.aten.sym_size.int(arg0_1, 0) eq - sym_size_int -- 3; sym_size_int - None scalar_tensor_default: f32[] - torch.ops.aten.scalar_tensor.default(eq); eq - None _assert_async_msg - torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3'); scalar_tensor_default - None submodule_0 - self.submodule_0 submodule_1 - self.submodule_1 cond: f32[3] - torch.ops.higher_order.cond(True, submodule_0, submodule_1, [arg0_1]); submodule_0 - submodule_1 - arg0_1 - None return (cond,) class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3]): add_tensor: f32[3] - torch.ops.aten.add.Tensor(arg0_1, arg0_1); arg0_1 - None return add_tensor class GraphModule(torch.nn.Module): def forward(self, arg0_1: f32[3]): sub_tensor: f32[3] - torch.ops.aten.sub.Tensor(arg0_1, arg0_1); arg0_1 - None return sub_tensor Graph Signature: ExportGraphSignature(parameters-[], buffers-[], user_inputs-['arg0_1'], user_outputs-['cond'], inputs_to_parameters-{}, inputs_to_buffers-{}, buffers_to_mutate-{}, backward_signature-None, assertion_dep_token-None) Symbol to range: {} cond_branch_nonlocal_variables ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. note:: Tags: :doc:`torch.cond `, :doc:`torch.dynamic-shape ` Support Level: SUPPORTED Original source code: .. code-block:: python import torch from functorch.experimental.control_flow import cond def cond_branch_nonlocal_variables(x): """ The branch functions (`true_fn` and `false_fn`) passed to cond() must follow these rules: - both branches must take the same args, which must also match the branch args passed to cond. - both branches must return a single tensor - returned tensor must have the same tensor metadata, e.g. shape and dtype - branch function can be free function, nested function, lambda, class methods - branch function can not have closure variables - no inplace mutations on inputs or global variables This example demonstrates how to rewrite code to avoid capturing closure variables in branch functions. The code below will not work because capturing closure variables is not supported. ``` my_tensor_var - x + 100 my_primitive_var - 3.14 def true_fn(y): nonlocal my_tensor_var, my_primitive_var return y + my_tensor_var + my_primitive_var def false_fn(y): nonlocal my_tensor_var, my_primitive_var return y - my_tensor_var - my_primitive_var return cond(x.shape[0] > 5, true_fn, false_fn, [x]) ``` NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized. """ my_tensor_var - x + 100 my_primitive_var - 3.14 def true_fn(x, y, z): return x + y + z def false_fn(x, y, z): return x - y - z return cond( x.shape[0] > 5, true_fn, false_fn, [x, my_tensor_var, torch.tensor(my_primitive_var)], ) Result: .. 
cond_operands
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.cond`, :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    from torch._export import dynamic_dim
    from functorch.experimental.control_flow import cond

    x = torch.randn(3, 2)
    y = torch.ones(2)
    dynamic_constraint = dynamic_dim(x, 0)

    def cond_operands(x, y):
        """
        The operands passed to cond() must be:
          - a list of tensors
          - matching the arguments of `true_fn` and `false_fn`

        NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
        """

        def true_fn(x, y):
            return x + y

        def false_fn(x, y):
            return x - y

        return cond(x.shape[0] > 2, true_fn, false_fn, [x, y])

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                sym_size_int: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
                eq = sym_size_int_2 == 2;  sym_size_int_2 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default_1 = None
                sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 0)
                gt: Sym(s0 > 2) = sym_size > 2;  sym_size = None
                submodule_0 = self.submodule_0
                submodule_1 = self.submodule_1
                cond: f32[s0, 2] = torch.ops.higher_order.cond(gt, submodule_0, submodule_1, [arg0_1, arg1_1]);  gt = submodule_0 = submodule_1 = arg0_1 = arg1_1 = None
                return (cond,)

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                    add_tensor: f32[s0, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return add_tensor

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[s0, 2], arg1_1: f32[2]):
                    sub_tensor: f32[s0, 2] = torch.ops.aten.sub.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return sub_tensor

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {s0: RangeConstraint(min_val=2, max_val=oo)}
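Because the predicate `x.shape[0] > 2` stays symbolic (`s0 > 2` in the graph above), the branch taken follows the runtime batch size. A quick eager sanity check (a sketch reusing the `cond_operands` definition above):

.. code-block:: python

    import torch

    y = torch.ones(2)

    big = torch.randn(3, 2)    # shape[0] > 2  -> true_fn  -> x + y
    small = torch.randn(2, 2)  # shape[0] <= 2 -> false_fn -> x - y

    print(torch.allclose(cond_operands(big, y), big + y))
    print(torch.allclose(cond_operands(small, y), small - y))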
cond_predicate
^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.cond`, :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    from functorch.experimental.control_flow import cond

    def cond_predicate(x):
        """
        The conditional statement (aka predicate) passed to cond() must be one of the following:
          - torch.Tensor with a single element
          - boolean expression

        NOTE: If the `pred` is a test on a dim with batch size < 2, it will be specialized.
        """

        pred = x.dim() > 2 and x.shape[2] > 10
        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[6, 4, 3]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2)
                eq = sym_size_int_2 == 3;  sym_size_int_2 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 3');  scalar_tensor_default = None
                eq_1 = sym_size_int_1 == 4;  sym_size_int_1 = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 4');  scalar_tensor_default_1 = None
                eq_2 = sym_size_int == 6;  sym_size_int = None
                scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2);  eq_2 = None
                _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 6');  scalar_tensor_default_2 = None
                submodule_0 = self.submodule_0
                submodule_1 = self.submodule_1
                cond: f32[6, 4, 3] = torch.ops.higher_order.cond(False, submodule_0, submodule_1, [arg0_1]);  submodule_0 = submodule_1 = arg0_1 = None
                return (cond,)

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[6, 4, 3]):
                    cos_default: f32[6, 4, 3] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                    return cos_default

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[6, 4, 3]):
                    sin_default: f32[6, 4, 3] = torch.ops.aten.sin.default(arg0_1);  arg0_1 = None
                    return sin_default

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cond'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
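The predicate can be a Python boolean expression, as above, or a single-element tensor. A minimal eager sketch of both forms (reusing `cond_predicate` from above; not part of the original example):

.. code-block:: python

    import torch
    from functorch.experimental.control_flow import cond

    x = torch.randn(6, 4, 3)

    # Boolean-expression predicate: x.shape[2] > 10 is False here, so sin() is taken.
    out1 = cond_predicate(x)

    # A single-element tensor predicate is also accepted by cond().
    pred = torch.tensor(x.shape[2] > 10)
    out2 = cond(pred, lambda t: t.cos(), lambda t: t.sin(), [x])

    print(torch.allclose(out1, x.sin()), torch.allclose(out2, x.sin()))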
dynamic_shape_constructor
^^^^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    def dynamic_shape_constructor(x):
        """
        Tensor constructors should be captured with dynamic shape inputs rather
        than being baked in with static shape.
        """
        return torch.ones(x.shape[0] * 2)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1);  arg0_1 = None
                eq = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
                full_default: f32[6] = torch.ops.aten.full.default([6], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
                return (full_default,)

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['full_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
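Eagerly, the constructor's size argument simply follows the input batch size, which is what the export is expected to capture symbolically when the dimension is marked dynamic (a quick sketch):

.. code-block:: python

    import torch

    print(dynamic_shape_constructor(torch.randn(3, 2)).shape)  # torch.Size([6])
    print(dynamic_shape_constructor(torch.randn(5, 2)).shape)  # torch.Size([10])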
dynamic_shape_if_guard
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.control-flow`, :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    class DynamicShapeIfGuard(torch.nn.Module):
        """
        An `if` statement with a backed dynamic shape predicate will be specialized into
        one particular branch and generate a guard. However, export will fail if the
        dimension is marked as dynamic shape from a higher level API.
        """

        def forward(self, x):
            if x.shape[0] == 3:
                return x.cos()

            return x.sin()

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2, 2]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                sym_size_int_2 = torch.ops.aten.sym_size.int(arg0_1, 2)
                eq = sym_size_int_2 == 2;  sym_size_int_2 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[2] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default_1 = None
                eq_2 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2);  eq_2 = None
                _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_2 = None
                cos_default: f32[3, 2, 2] = torch.ops.aten.cos.default(arg0_1);  arg0_1 = None
                return (cos_default,)

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['cos_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}

dynamic_shape_map
^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.map`, :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    from functorch.experimental.control_flow import map

    def dynamic_shape_map(xs, y):
        """
        functorch map() maps a function over the first tensor dimension.
        """

        def body(x, y):
            return x + y

        return map(body, xs, y)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                sym_size_int_2 = torch.ops.aten.sym_size.int(arg1_1, 0)
                eq = sym_size_int_2 == 2;  sym_size_int_2 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg1_1.shape[0] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default_1 = None
                eq_2 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_2: f32[] = torch.ops.aten.scalar_tensor.default(eq_2);  eq_2 = None
                _assert_async_msg_2 = torch.ops.aten._assert_async.msg(scalar_tensor_default_2, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_2 = None
                submodule_0 = self.submodule_0
                map_impl = torch.ops.map_impl(submodule_0, 1, arg0_1, arg1_1);  submodule_0 = arg0_1 = arg1_1 = None
                getitem: f32[3, 2] = map_impl[0];  map_impl = None
                return (getitem,)

            class GraphModule(torch.nn.Module):
                def forward(self, arg0_1: f32[3, 2], arg1_1: f32[2]):
                    add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg1_1);  arg0_1 = arg1_1 = None
                    return [add_tensor]

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['getitem'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
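A minimal eager sketch of what `map` computes here, assuming `map` from `functorch.experimental.control_flow` also runs eagerly in your PyTorch version: `body` is applied to each slice of `xs` along dim 0, with `y` passed to every call.

.. code-block:: python

    import torch

    xs = torch.randn(3, 2)
    y = torch.ones(2)

    out = dynamic_shape_map(xs, y)

    # Equivalent to stacking body(xs[i], y) over the first dimension.
    expected = torch.stack([xs[i] + y for i in range(xs.shape[0])])
    print(torch.allclose(out, expected))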
dynamic_shape_round
^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.builtin`, :doc:`torch.dynamic-shape`

    Support Level: NOT_SUPPORTED_YET

Original source code:

.. code-block:: python

    import torch

    from torch._export import dynamic_dim

    x = torch.ones(3, 2)
    dynamic_constraint = dynamic_dim(x, 0)

    def dynamic_shape_round(x):
        """
        Calling round on dynamic shapes is not supported.
        """
        return x[: round(x.shape[0] / 2)]

Result:

.. code-block::

    Unsupported: Calling round() on symbolic value is not supported. You can use floor() to implement this functionality
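Following the error message's suggestion, the slice bound can be computed without `round()`, for example with floor division, which keeps the expression on symbolic ints. This is a sketch of one possible rewrite (`dynamic_shape_floor` is a hypothetical helper, not part of exportdb); note it is not numerically identical to `round()` for odd sizes, since `round(3 / 2)` is 2 while `3 // 2` is 1.

.. code-block:: python

    import torch

    def dynamic_shape_floor(x):
        # Hypothetical rewrite: integer floor division instead of round().
        return x[: x.shape[0] // 2]

    print(dynamic_shape_floor(torch.ones(3, 2)).shape)  # torch.Size([1, 2])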
dynamic_shape_slicing
^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    def dynamic_shape_slicing(x):
        """
        Slices with dynamic shape arguments should be captured into the graph
        rather than being baked in.
        """
        return x[: x.shape[0] - 2, x.shape[1] - 1 :: 2]

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                eq = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
                slice_tensor: f32[1, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 1);  arg0_1 = None
                slice_tensor_1: f32[1, 1] = torch.ops.aten.slice.Tensor(slice_tensor, 1, 1, 9223372036854775807, 2);  slice_tensor = None
                return (slice_tensor_1,)

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['slice_tensor_1'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}

dynamic_shape_view
^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    def dynamic_shape_view(x):
        """
        Dynamic shapes should be propagated to view arguments instead of being
        baked into the exported graph.
        """
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[10, 10]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                eq = sym_size_int_1 == 10;  sym_size_int_1 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 10');  scalar_tensor_default = None
                eq_1 = sym_size_int == 10;  sym_size_int = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 10');  scalar_tensor_default_1 = None
                view_default: f32[10, 2, 5] = torch.ops.aten.view.default(arg0_1, [10, 2, 5]);  arg0_1 = None
                permute_default: f32[10, 5, 2] = torch.ops.aten.permute.default(view_default, [0, 2, 1]);  view_default = None
                return (permute_default,)

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['permute_default'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}
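A quick eager check of the shape propagation (a sketch, not part of exportdb): the last dimension is split into (2, 5) and the last two dimensions are then swapped, so only the trailing dimension needs to be 10.

.. code-block:: python

    import torch

    x = torch.randn(10, 10)
    print(dynamic_shape_view(x).shape)   # torch.Size([10, 5, 2])

    x = torch.randn(7, 10)               # leading dim can vary eagerly
    print(dynamic_shape_view(x).shape)   # torch.Size([7, 5, 2])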
list_contains
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`python.assert`, :doc:`python.data-structure`, :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    def list_contains(x):
        """
        List containment relation can be checked on a dynamic shape or constants.
        """
        assert x.size(-1) in [6, 2]
        assert x.size(0) not in [4, 5, 6]
        assert "monkey" not in ["cow", "pig"]
        return x + x

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                eq = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
                add_tensor: f32[3, 2] = torch.ops.aten.add.Tensor(arg0_1, arg0_1);  arg0_1 = None
                return (add_tensor,)

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {}

scalar_output
^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.dynamic-shape`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch

    from torch._export import dynamic_dim

    x = torch.ones(3, 2)
    dynamic_constraint = dynamic_dim(x, 1)

    def scalar_output(x):
        """
        Returning scalar values from the graph is supported, in addition to Tensor
        outputs. Symbolic shapes are captured and rank is specialized.
        """
        return x.shape[1] + 1

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, s0]):
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 1)
                eq = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default = None
                sym_size: Sym(s0) = torch.ops.aten.sym_size.int(arg0_1, 1);  arg0_1 = None
                add: Sym(s0 + 1) = sym_size + 1;  sym_size = None
                return (add,)

    Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1'], user_outputs=['add'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
    Symbol to range: {s0: RangeConstraint(min_val=2, max_val=oo)}
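Eagerly, the function just returns a Python int derived from the input shape; under export the same value is carried as the symbolic expression `s0 + 1` shown above. A quick sketch:

.. code-block:: python

    import torch

    print(scalar_output(torch.ones(3, 2)))  # 3, i.e. shape[1] + 1
    print(scalar_output(torch.ones(3, 5)))  # 6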