torch.escape-hatch
======================
assume_constant_result
^^^^^^^^^^^^^^^^^^^^^^

.. note::

    Tags: :doc:`torch.escape-hatch <torch.escape-hatch>`

    Support Level: SUPPORTED

Original source code:

.. code-block:: python

    import torch
    import torch._dynamo as torchdynamo


    class AssumeConstantResult(torch.nn.Module):
        """
        Applying the `assume_constant_result` decorator to burn in non-traceable code as a constant.
        """

        def __init__(self):
            super().__init__()

        @torchdynamo.assume_constant_result
        def get_item(self, y):
            return y.int().item()

        def forward(self, x, y):
            return x[: self.get_item(y)]

Result:

.. code-block::

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, arg0_1: f32[3, 2], arg1_1: i64[]):
                #
                sym_size_int = torch.ops.aten.sym_size.int(arg0_1, 0)
                sym_size_int_1 = torch.ops.aten.sym_size.int(arg0_1, 1)
                eq = sym_size_int_1 == 2;  sym_size_int_1 = None
                scalar_tensor_default: f32[] = torch.ops.aten.scalar_tensor.default(eq);  eq = None
                _assert_async_msg = torch.ops.aten._assert_async.msg(scalar_tensor_default, 'Input arg0_1.shape[1] is specialized at 2');  scalar_tensor_default = None
                eq_1 = sym_size_int == 3;  sym_size_int = None
                scalar_tensor_default_1: f32[] = torch.ops.aten.scalar_tensor.default(eq_1);  eq_1 = None
                _assert_async_msg_1 = torch.ops.aten._assert_async.msg(scalar_tensor_default_1, 'Input arg0_1.shape[0] is specialized at 3');  scalar_tensor_default_1 = None
                slice_tensor: f32[3, 2] = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 4);  arg0_1 = None
                return (slice_tensor,)

        Graph Signature: ExportGraphSignature(parameters=[], buffers=[], user_inputs=['arg0_1', 'arg1_1'], user_outputs=['slice_tensor'], inputs_to_parameters={}, inputs_to_buffers={}, buffers_to_mutate={}, backward_signature=None, assertion_dep_token=None)
        Symbol to range: {}