torch.utils.rename_privateuse1_backend
torch.utils.rename_privateuse1_backend(backend_name)
Rename the privateuse1 backend device to make it more convenient to use as a device name within PyTorch APIs.
The steps are:
- (In C++) implement kernels for various torch operations, and register them to the PrivateUse1 dispatch key.
- (In Python) call torch.utils.rename_privateuse1_backend("foo").
You can now use "foo" as an ordinary device string in Python.
Note: this API can only be called once per process. Attempting to change the external backend after it’s already been set will result in an error.
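The once-per-process rule in the note above can be checked from Python alone, without any C++ kernels. The following is a minimal sketch: the second name "bar" is arbitrary and chosen only to trigger the error, and the exact error message may differ between PyTorch versions.

```python
import torch

# Rename the PrivateUse1 backend; "foo" now parses as a device string.
torch.utils.rename_privateuse1_backend("foo")
device = torch.device("foo:0")
print(device.type, device.index)

# A second call with a different name is rejected.
try:
    torch.utils.rename_privateuse1_backend("bar")
except RuntimeError as err:
    rename_error = err
else:
    rename_error = None
print("second rename rejected:", rename_error is not None)
```

Constructing a torch.device only validates the device string. Allocating tensors on "foo" still requires kernels registered to the PrivateUse1 dispatch key.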
Note(AMP): If you want to support AMP on your device, register a custom backend module with torch._register_device_module("foo", BackendModule). BackendModule needs to provide the following API:
- get_amp_supported_dtype() -> List[torch.dtype]: Returns the dtypes that your "foo" device supports in AMP; the device may support more than one dtype.
Note(random): If you want to support setting the seed for your device, BackendModule needs to provide the following APIs:
- _is_in_bad_fork() -> bool: Returns True if the current process is in a bad fork, otherwise False.
- manual_seed_all(seed: int) -> None: Sets the seed for generating random numbers on your devices.
- device_count() -> int: Returns the number of "foo" devices available.
- get_rng_state(device: Union[int, str, torch.device] = 'foo') -> Tensor: Returns a list of ByteTensor representing the random number states of all devices.
- set_rng_state(new_state: Tensor, device: Union[int, str, torch.device] = 'foo') -> None: Sets the random number generator state of the specified "foo" device.
There are also some common functions:
- is_available() -> bool: Returns a bool indicating whether "foo" is currently available.
- current_device() -> int: Returns the index of the currently selected device.
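The module requirements listed above can be sketched in pure Python. This is an illustrative skeleton, not a working backend: the class name FooBackendModule, the helper _device_index, the single-device assumption, the returned AMP dtypes, and the CPU-side dictionary that stands in for device RNG state are all hypothetical placeholders for calls into a real device runtime. torch._register_device_module is a private API and may change between releases.

```python
from typing import Dict, List, Union

import torch

# Must run first so that "foo" can be parsed as a device string.
torch.utils.rename_privateuse1_backend("foo")


def _device_index(device: Union[int, str, torch.device]) -> int:
    """Hypothetical helper: resolve an int, string, or torch.device to an index."""
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    return 0 if device.index is None else device.index


class FooBackendModule:
    """Hypothetical backend module; every body below is a placeholder."""

    # Placeholder storage: per-device RNG state kept as CPU ByteTensors.
    _rng_states: Dict[int, torch.Tensor] = {}

    # AMP
    @staticmethod
    def get_amp_supported_dtype() -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]  # assumed dtypes

    # Common functions
    @staticmethod
    def is_available() -> bool:
        return True

    @staticmethod
    def current_device() -> int:
        return 0

    @staticmethod
    def device_count() -> int:
        return 1  # assumption: exactly one "foo" device

    # Random number support
    @staticmethod
    def _is_in_bad_fork() -> bool:
        return False

    @staticmethod
    def manual_seed_all(seed: int) -> None:
        for idx in range(FooBackendModule.device_count()):
            gen = torch.Generator().manual_seed(seed)
            FooBackendModule._rng_states[idx] = gen.get_state()

    @staticmethod
    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
        idx = _device_index(device)
        if idx not in FooBackendModule._rng_states:
            FooBackendModule._rng_states[idx] = torch.Generator().get_state()
        return FooBackendModule._rng_states[idx]

    @staticmethod
    def set_rng_state(
        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
    ) -> None:
        FooBackendModule._rng_states[_device_index(device)] = new_state.clone()


# Expose the module as torch.foo.
torch._register_device_module("foo", FooBackendModule)

torch.foo.manual_seed_all(1234)
state = torch.foo.get_rng_state("foo:0")
print(torch.foo.is_available(), torch.foo.device_count(), state.dtype)
```

A class with static methods is used here for brevity; a real backend would more likely register an actual Python module whose functions call into its C++ runtime.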
For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
For an existing example, see bdhirsh/pytorch_open_registration_example
Example:
>>> import torch
>>> torch.utils.rename_privateuse1_backend("foo")
# This will work, assuming that you've implemented the right C++ kernels
# to implement torch.ones.
>>> a = torch.ones(2, device="foo")