cast_if_src_dtype¶
Module Name: cast_if_src_dtype¶
Description¶
cast_if_src_dtype is a utility function that checks the data type (dtype) of a given tensor. If the tensor's dtype matches the provided source dtype (src_dtype), the function casts the tensor to the target dtype (tgt_dtype). In either case, it returns the resulting tensor together with a boolean flag indicating whether a cast was performed.
This function provides a convenient way to convert torch tensors of one specific data type while leaving tensors of any other data type untouched.
Function Signature in PyTorch¶
def cast_if_src_dtype(
    tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
):
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated
Parameters¶
Parameter | Type | Description |
---|---|---|
tensor | torch.Tensor | The tensor whose data type is to be checked and potentially updated. |
src_dtype | torch.dtype | The source data type that should trigger the casting operation. |
tgt_dtype | torch.dtype | The target data type that the tensor will be cast into if the source data type matches its data type. |
Functionality and Use¶
Functionality: cast_if_src_dtype takes in three parameters: a tensor, a source data type, and a target data type. If the data type of the tensor equals the source data type, the function casts this tensor to the target data type. The function then returns both the potentially modified tensor and a flag indicating whether the cast was performed.
Usage: This utility function is used when certain operations or functions require inputs of a specific data type. A common scenario is when tensors with floating-point data types need to be converted to integers or vice versa.
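One way the returned flag can be put to use is to cast a tensor temporarily and restore its original data type afterwards. The sketch below is only an illustration under stated assumptions: the helper name softmax_in_float32 and the choice of torch.bfloat16 are hypothetical, and the function is redefined locally (mirroring the signature above) so that the snippet runs without zeta installed.

```python
import torch


def cast_if_src_dtype(tensor, src_dtype, tgt_dtype):
    # Local copy of the documented function, so the sketch is self-contained
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated


def softmax_in_float32(x):
    # Hypothetical helper: run softmax in float32, then restore bfloat16 if needed
    x, was_cast = cast_if_src_dtype(x, torch.bfloat16, torch.float32)
    out = torch.softmax(x, dim=-1)
    if was_cast:
        # Only cast back when the input really was bfloat16
        out = out.to(torch.bfloat16)
    return out


half_out = softmax_in_float32(torch.randn(2, 4).to(torch.bfloat16))
full_out = softmax_in_float32(torch.randn(2, 4))
print(half_out.dtype)  # torch.bfloat16
print(full_out.dtype)  # torch.float32
```

Because the flag records whether a cast took place, the caller does not need to remember the input's data type separately.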
Usage Examples¶
Below are some examples of how the function could be used:
Example 1¶
import torch
from zeta.utils import cast_if_src_dtype
# Given: a float tensor
tensor = torch.tensor([1.0, 2.0, 3.0])
# We want to convert it to integer type tensor if its data type is float32
tensor, updated = cast_if_src_dtype(tensor, torch.float32, torch.int32)
print(tensor) # tensor([1, 2, 3], dtype=torch.int32)
print(updated) # True
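A float-to-integer cast like the one in Example 1 discards the fractional part (truncation toward zero) rather than rounding. The sketch below shows the difference with non-integral values; the function is redefined locally as an assumption-free copy of the signature above, and applying torch.round beforehand is just one possible choice when rounding is wanted.

```python
import torch


def cast_if_src_dtype(tensor, src_dtype, tgt_dtype):
    # Local copy of the documented function, so the sketch is self-contained
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated


values = torch.tensor([1.7, -1.7, 2.5])

# Direct cast: the fractional part is dropped
truncated, _ = cast_if_src_dtype(values, torch.float32, torch.int32)
print(truncated.tolist())  # [1, -1, 2]

# Rounding first (torch.round rounds halves to the nearest even integer)
rounded, _ = cast_if_src_dtype(torch.round(values), torch.float32, torch.int32)
print(rounded.tolist())  # [2, -2, 2]
```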
Example 2¶
import torch
from zeta.utils import cast_if_src_dtype
# Given: an int32 tensor (torch.tensor infers int64 for Python ints, so set the dtype explicitly)
tensor = torch.tensor([1, 2, 3], dtype=torch.int32)
# We want to convert it to float type tensor if its data type is int32
tensor, updated = cast_if_src_dtype(tensor, torch.int32, torch.float32)
print(tensor) # tensor([1., 2., 3.])
print(updated) # True
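The source data type has to match exactly. PyTorch infers torch.int64 for a tensor built from Python integers, so a src_dtype of torch.int32 only triggers the cast when the tensor was created as int32. A short sketch, with the function redefined locally so it runs on its own:

```python
import torch


def cast_if_src_dtype(tensor, src_dtype, tgt_dtype):
    # Local copy of the documented function, so the sketch is self-contained
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated


# Default integer dtype is int64, so an int32 source dtype does not match
ints64 = torch.tensor([1, 2, 3])
out64, updated64 = cast_if_src_dtype(ints64, torch.int32, torch.float32)
print(ints64.dtype, updated64)  # torch.int64 False

# With an explicit int32 dtype the cast is performed
ints32 = torch.tensor([1, 2, 3], dtype=torch.int32)
out32, updated32 = cast_if_src_dtype(ints32, torch.int32, torch.float32)
print(out32.dtype, updated32)  # torch.float32 True
```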
Example 3¶
import torch
from zeta.utils import cast_if_src_dtype
# Given: an integer tensor
tensor = torch.tensor([1, 2, 3])
# If the data type is not equal to the source data type, the tensor will remain the same
tensor, updated = cast_if_src_dtype(tensor, torch.float32, torch.int32)
print(tensor) # tensor([1, 2, 3])
print(updated) # False
Resources and References¶
For more information on tensor operations and data types in PyTorch, refer to the official PyTorch documentation.
Note¶
The cast_if_src_dtype function doesn't modify the original tensor in place. When a cast happens, it returns a new tensor with the target data type; when no cast happens, it returns the original tensor object unchanged. Keep that in mind during function calls, and be sure to replace the original tensor with the returned tensor so that the change is reflected in the rest of your code.
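The behaviour described in this note can be checked directly. The sketch below redefines the function locally (matching the signature above) and inspects both the cast and the no-cast case; the variable names are illustrative only.

```python
import torch


def cast_if_src_dtype(tensor, src_dtype, tgt_dtype):
    # Local copy of the documented function, so the sketch is self-contained
    updated = False
    if tensor.dtype == src_dtype:
        tensor = tensor.to(dtype=tgt_dtype)
        updated = True
    return tensor, updated


original = torch.tensor([1.0, 2.0, 3.0])

# Cast case: a new tensor comes back, the original keeps its dtype
converted, updated = cast_if_src_dtype(original, torch.float32, torch.int32)
print(original.dtype, converted.dtype, updated)  # torch.float32 torch.int32 True

# No-cast case: the very same tensor object is returned
same, updated_again = cast_if_src_dtype(original, torch.float64, torch.int32)
print(same is original, updated_again)  # True False
```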