PyTorch Hub is a pre-trained model repository designed to facilitate research reproducibility.
PyTorch Hub supports publishing pre-trained models (model definitions and pre-trained weights) to a GitHub repository by adding a simple hubconf.py file.
hubconf.py can have multiple entrypoints. Each entrypoint is defined as a Python function (for example, a pre-trained model you want to publish).
def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...
Here is a code snippet that specifies an entrypoint for the resnet18 model, written as if we expanded the implementation in pytorch/vision/hubconf.py. In most cases importing the right function in hubconf.py is sufficient; here we use the expanded version only as an example to show how it works. You can see the full script in the pytorch/vision repo.
dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
    """ # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
The dependencies variable is a list of package names required to load the model. Note this might be slightly different from the dependencies required for training a model.
args and kwargs are passed along to the real callable function.
Callables prefixed with an underscore are considered helper functions and won't show up in torch.hub.list().
Pretrained weights can either be stored locally in the github repo or be loadable by torch.hub.load_state_dict_from_url(). If less than 2GB, it's recommended to attach them to a project release and use the url from the release.
In the example above torchvision.models.resnet.resnet18 handles pretrained; alternatively, you can put the following logic in the entrypoint definition.

if pretrained:
    # For checkpoint saved in local github repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))
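To illustrate the points above, here is a minimal sketch of a hubconf.py that exposes two entrypoints and hides a shared helper behind a leading underscore. The helper name _make_model is hypothetical; only the torchvision builders are real.

dependencies = ['torch', 'torchvision']

from torchvision.models.resnet import resnet18 as _resnet18, resnet34 as _resnet34

def _make_model(builder, pretrained, **kwargs):
    # Underscore-prefixed helper: not listed by torch.hub.list().
    return builder(pretrained=pretrained, **kwargs)

def resnet18(pretrained=False, **kwargs):
    """Resnet18 model entrypoint."""
    return _make_model(_resnet18, pretrained, **kwargs)

def resnet34(pretrained=False, **kwargs):
    """Resnet34 model entrypoint."""
    return _make_model(_resnet34, pretrained, **kwargs)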
PyTorch Hub provides convenient APIs to explore all available models in hub through torch.hub.list(), show docstrings and examples through torch.hub.help(), and load the pre-trained models using torch.hub.load().
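For example, a typical exploration flow with these three APIs might look like the following (using the pytorch/vision repo as in the examples below):

>>> import torch
>>> # discover the entrypoints the repo publishes
>>> entrypoints = torch.hub.list('pytorch/vision')
>>> print(entrypoints)
>>> # read the docstring of one entrypoint
>>> print(torch.hub.help('pytorch/vision', 'resnet18'))
>>> # load the pre-trained model
>>> model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)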
torch.hub.list(github, force_reload=False)
List all entrypoints available in the github repository's hubconf.

Parameters:
github (string) – a string with format "repo_owner/repo_name[:tag_name]" with an optional tag/branch. The default branch is master if not specified. Example: 'pytorch/vision[:hub]'
force_reload (bool, optional) – whether to discard the existing cache and force a fresh download. Default is False.

Returns: a list of available entrypoint names

Example:
>>> entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
torch.hub.help(github, model, force_reload=False)
Show the docstring of entrypoint model.

Parameters:
github (string) – a string with format "repo_owner/repo_name[:tag_name]" with an optional tag/branch. The default branch is master if not specified. Example: 'pytorch/vision[:hub]'
model (string) – the name of an entrypoint defined in the repo's hubconf.py
force_reload (bool, optional) – whether to discard the existing cache and force a fresh download. Default is False.

Example:
>>> print(torch.hub.help('pytorch/vision', 'resnet18', force_reload=True))
torch.hub.load(repo_or_dir, model, *args, **kwargs)
Load a model from a github repo or a local directory.
Note: Loading a model is the typical use case, but this can also be used for loading other objects such as tokenizers, loss functions, etc.
If source
is 'github'
, repo_or_dir
is expected to be of the form repo_owner/repo_name[:tag_name]
with an optional tag/branch.
If source
is 'local'
, repo_or_dir
is expected to be a path to a local directory.
Parameters:
repo_or_dir (string) – repo name (repo_owner/repo_name[:tag_name]), if source = 'github'; or a path to a local directory, if source = 'local'.
model (string) – the name of a callable (entrypoint) defined in the repo/dir's hubconf.py.
*args (optional) – the corresponding args for the callable model.
source (string, optional) – 'github' | 'local'. Specifies how repo_or_dir is to be interpreted. Default is 'github'.
force_reload (bool, optional) – whether to force a fresh download of the github repo unconditionally. Does not have any effect if source = 'local'. Default is False.
verbose (bool, optional) – if False, mute messages about hitting local caches. Note that the message about first download cannot be muted. Does not have any effect if source = 'local'. Default is True.
**kwargs (optional) – the corresponding kwargs for the callable model.

Returns: the output of the model callable when called with the given *args and **kwargs.

Example:
>>> # from a github repo
>>> repo = 'pytorch/vision'
>>> model = torch.hub.load(repo, 'resnet50', pretrained=True)
>>> # from a local directory
>>> path = '/some/local/path/pytorch/vision'
>>> model = torch.hub.load(path, 'resnet50', pretrained=True)
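The object returned by the resnet entrypoints is a regular nn.Module, so the usual inference steps apply; a minimal sketch (the 1x3x224x224 input shape assumes an ImageNet-style classifier):

>>> import torch
>>> model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)
>>> model.eval()
>>> dummy_input = torch.rand(1, 3, 224, 224)  # one RGB image, 224x224
>>> with torch.no_grad():
...     output = model(dummy_input)
>>> print(output.shape)  # expected: torch.Size([1, 1000])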
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)
Download object at the given URL to a local path.
Parameters:
url (string) – URL of the object to download
dst (string) – full path where the object will be saved, e.g. /tmp/temporary_file
hash_prefix (string, optional) – if not None, the SHA256 of the downloaded file should start with hash_prefix. Default: None
progress (bool, optional) – whether or not to display a progress bar to stderr. Default: True

Example:
>>> torch.hub.download_url_to_file('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', '/tmp/temporary_file')
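Because torchvision weight files embed the first eight hex digits of their SHA256 hash in the filename (5c106cde here), you can pass that prefix as hash_prefix to have the download verified; a small sketch under that assumption:

>>> url = 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth'
>>> torch.hub.download_url_to_file(url, '/tmp/resnet18.pth', hash_prefix='5c106cde')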
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None)
Loads the Torch serialized object at the given URL.
If the downloaded file is a zip file, it will be automatically decompressed.
If the object is already present in model_dir
, it’s deserialized and returned. The default value of model_dir
is <hub_dir>/checkpoints
where hub_dir
is the directory returned by get_dir()
.
Parameters:
url (string) – URL of the object to download
model_dir (string, optional) – directory in which to save the object
map_location (optional) – a function or a dict specifying how to remap storage locations (see torch.load)
progress (bool, optional) – whether or not to display a progress bar to stderr. Default: True
check_hash (bool, optional) – if True, the filename part of the URL should follow the naming convention filename-<sha256>.ext where <sha256> is the first eight or more digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False
file_name (string, optional) – name for the downloaded file. The filename from url will be used if not set.

Example:
>>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
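For example, to fetch the weights on a CPU-only machine, check them against the hash embedded in the filename, and load them into a freshly built model (a sketch, assuming the torchvision resnet18 architecture matches this checkpoint):

>>> import torch, torchvision
>>> url = 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth'
>>> state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
>>> model = torchvision.models.resnet18()
>>> model.load_state_dict(state_dict)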
Note that *args
and **kwargs
in torch.hub.load()
are used to instantiate a model. After you have loaded a model, how can you find out what you can do with it? A suggested workflow is:
dir(model) to see all available methods of the model.
help(model.foo) to check what arguments model.foo takes to run.
To help users explore without referring back and forth to documentation, we strongly recommend repo owners make function help messages clear and succinct. It's also helpful to include a minimal working example.
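In practice that workflow looks like this (using the resnet18 entrypoint as an example; forward stands in for whatever method the repo documents):

>>> model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
>>> print(dir(model))      # all available attributes and methods
>>> help(model.forward)    # what arguments model.forward takes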
The locations are used in the following order:
1. hub.set_dir(<PATH_TO_HUB_DIR>), if it has been called.
2. $TORCH_HOME/hub, if environment variable TORCH_HOME is set.
3. $XDG_CACHE_HOME/torch/hub, if environment variable XDG_CACHE_HOME is set.
4. ~/.cache/torch/hub otherwise.
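A quick way to check which location is in effect on your machine (the printed path depends on your environment; see get_dir() below):

>>> import os, torch
>>> print(os.getenv('TORCH_HOME'), os.getenv('XDG_CACHE_HOME'))
>>> print(torch.hub.get_dir())  # resolved hub directory, e.g. ~/.cache/torch/hub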
torch.hub.get_dir()
Get the Torch Hub cache directory used for storing downloaded models & weights.
If set_dir()
is not called, default path is $TORCH_HOME/hub
where environment variable $TORCH_HOME
defaults to $XDG_CACHE_HOME/torch
. $XDG_CACHE_HOME follows the X Desktop Group (XDG) base directory specification of the Linux filesystem layout, with a default value of ~/.cache if the environment variable is not set.
torch.hub.set_dir(d)
Optionally set the Torch Hub directory used to save downloaded models & weights.
d (string) – path to a local folder to save downloaded models & weights.
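For example, to redirect downloads to a project-specific folder (the path below is only an illustration):

>>> import torch
>>> torch.hub.set_dir('/data/torch_hub_cache')  # hypothetical path
>>> print(torch.hub.get_dir())                  # /data/torch_hub_cache
>>> # subsequent downloads and repo checkouts land under this directory
>>> model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)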
By default, we don't clean up files after loading them. Hub uses the cache by default if it already exists in the directory returned by get_dir().
Users can force a reload by calling hub.load(..., force_reload=True). This will delete the existing github folder and downloaded weights and reinitialize a fresh download. This is useful when updates are published to the same branch, so users can keep up with the latest release.
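For instance, to discard the cached copy and trigger a fresh download of the repo and weights:

>>> model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True, force_reload=True)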
Torch Hub works by importing the package as if it were installed. There are some side effects introduced by importing in Python. For example, you can see new items in the Python caches sys.modules and sys.path_importer_cache, which is normal Python behavior.
A known limitation worth mentioning here is that users CANNOT load two different branches of the same repo in the same Python process. It's just like installing two packages with the same name in Python, which is not good. The cache might join the party and give you surprises if you actually try that. Of course, it's totally fine to load them in separate processes.
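If you do need models from two different branches of the same repo, a workaround consistent with the note above is to load each one in its own process; a minimal sketch, where the tags passed to torch.hub.load are placeholders for branches/tags that actually exist in the repo:

import torch
from multiprocessing import get_context

def _load(tag):
    # Each child process has its own sys.modules, so the two checkouts never collide.
    model = torch.hub.load('pytorch/vision:' + tag, 'resnet18', pretrained=False)
    print(tag, type(model).__name__)

if __name__ == '__main__':
    ctx = get_context('spawn')
    for tag in ('master', 'v0.8.1'):  # placeholder tags
        p = ctx.Process(target=_load, args=(tag,))
        p.start()
        p.join()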
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.7.0/hub.html