TinyImageNet
- class fl_sim.data_processing.TinyImageNet(datadir: Path | str | None = None, num_clients: int = 100, alpha: float = 0.5, transform: str | Callable | None = 'none', seed: int = 0)
Bases: FedVisionDataset
Tiny ImageNet dataset.
The Tiny ImageNet dataset is a subset of ImageNet. It consists of 200 classes, each with 500 training images, 50 validation images, and 50 test images, all downsampled to 64x64 pixels.
The original dataset [1] includes the test images, whereas the Hugging Face version [3] does not. For simplicity, we use the Hugging Face dataset [3] and treat its validation set as the test set.
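As a quick sanity check on the split sizes described above, the per-class counts multiply out as follows. The storage figure assumes 3-channel RGB images held as uint8 (one byte per value), which this page does not state; the per-client average uses the default num_clients=100.

```python
# Split sizes implied by the description: 200 classes,
# 500 training and 50 validation images per class, 64x64 pixels.
NUM_CLASSES = 200
TRAIN_PER_CLASS = 500
VAL_PER_CLASS = 50  # the validation split is used as the test set

num_train = NUM_CLASSES * TRAIN_PER_CLASS   # training images
num_test = NUM_CLASSES * VAL_PER_CLASS      # validation images used as test

# Assumption: RGB images stored as uint8, one byte per pixel value.
values_per_image = 64 * 64 * 3
train_bytes = num_train * values_per_image  # raw training-set size in bytes

# Average training images per client with the default num_clients=100;
# a Dirichlet split generally makes the actual client sizes uneven.
mean_client_size = num_train / 100

print(num_train, num_test, values_per_image, train_bytes, mean_client_size)
# prints: 100000 10000 12288 1228800000 1000.0
```

So the raw training images alone take roughly 1.2 GB under these assumptions, and about 4.9 GB if converted to float32.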
- Parameters:
datadir (Union[pathlib.Path, str], optional) – Directory to store data. If None, use the default directory.
num_clients (int, default 100) – Number of clients.
alpha (float, default 0.5) – Concentration parameter for the Dirichlet distribution.
transform (Union[str, Callable], default "none") – Transform to apply to data. Conventions: "none" means no transform is applied and the data is held in a TensorDataset.
seed (int, default 0) – Random seed for data partitioning.
**extra_config (dict, optional) – Extra configurations.
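The alpha parameter controls how skewed each client's label distribution is. This page does not show how fl_sim performs the split, so the following is a minimal sketch of one common label-skew scheme, assuming per-class client proportions drawn from Dir(alpha, ..., alpha) with a seeded RNG. `dirichlet_partition` is a hypothetical helper, not part of fl_sim; it uses only the standard library and samples the Dirichlet by normalizing independent Gamma(alpha, 1) draws.

```python
import random
from typing import Dict, List, Sequence


def dirichlet_partition(
    labels: Sequence[int], num_clients: int, alpha: float, seed: int = 0
) -> Dict[int, List[int]]:
    """Assign each sample index to one client using per-class Dirichlet proportions."""
    rng = random.Random(seed)  # seeded so the partition is reproducible

    # Group sample indices by class label.
    by_class: Dict[int, List[int]] = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)

    parts: Dict[int, List[int]] = {c: [] for c in range(num_clients)}
    for y in sorted(by_class):
        idxs = list(by_class[y])
        rng.shuffle(idxs)
        # p ~ Dir(alpha, ..., alpha): normalize independent Gamma(alpha, 1) draws.
        gammas = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(gammas)
        props = [g / total for g in gammas]
        # Turn cumulative proportions into cut points over this class's samples.
        cuts, cum = [], 0.0
        for p in props[:-1]:
            cum += p
            cuts.append(round(cum * len(idxs)))
        cuts.append(len(idxs))
        # Hand each contiguous slice of the shuffled class to one client.
        start = 0
        for client, end in enumerate(cuts):
            parts[client].extend(idxs[start:end])
            start = end
    return parts


# Toy usage: 1,000 samples over 10 classes, split across 5 clients.
toy_labels = [i % 10 for i in range(1000)]
toy_parts = dirichlet_partition(toy_labels, num_clients=5, alpha=0.5, seed=0)
print({c: len(v) for c, v in toy_parts.items()})
```

Smaller alpha (e.g. 0.1) concentrates each class on a few clients, while a large alpha (e.g. 100) approaches an IID split. Every index lands on exactly one client, and the same seed reproduces the same partition.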
References
- evaluate(probs: Tensor, truths: Tensor) → Dict[str, float]
Evaluate predictions against the ground truth.
- Parameters:
probs (torch.Tensor) – Predicted probabilities.
truths (torch.Tensor) – Ground truth labels.
- Returns:
Evaluation results.
- Return type:
Dict[str, float]
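The keys of the returned dictionary are not documented here. As an illustration only, the sketch below computes top-k accuracy, a typical metric for 200-way classification, from per-class probabilities. `evaluate_sketch` and the `top{k}_acc` and `num_samples` keys are hypothetical; the function takes nested lists rather than torch.Tensor so it runs without PyTorch.

```python
from typing import Dict, Sequence


def evaluate_sketch(
    probs: Sequence[Sequence[float]],
    truths: Sequence[int],
    topk: Sequence[int] = (1, 5),
) -> Dict[str, float]:
    """Top-k accuracy from per-sample class probabilities (hypothetical keys)."""
    n = len(truths)
    results: Dict[str, float] = {}
    for k in topk:
        hits = 0
        for row, y in zip(probs, truths):
            # Class indices of the k largest probabilities in this row.
            top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
            hits += int(y in top)
        results[f"top{k}_acc"] = hits / n
    results["num_samples"] = float(n)
    return results
```

With PyTorch tensors, the per-k loop corresponds to `(probs.topk(k, dim=1).indices == truths.unsqueeze(1)).any(dim=1).float().mean()`.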
- get_dataloader(train_bs: int | None = None, test_bs: int | None = None, client_idx: int | None = None) → Tuple[DataLoader, DataLoader]
Get the local dataloaders of client client_idx, or the global dataloaders if client_idx is None.
- Parameters:
train_bs (int, optional) – Batch size for the training dataloader. If None, use the default batch size.
test_bs (int, optional) – Batch size for the testing dataloader. If None, use the default batch size.
client_idx (int, optional) – Index of the client to get dataloaders for. If None, get the dataloaders containing all data, usually for centralized training.
- Returns:
train_dl (torch.utils.data.DataLoader) – Training dataloader.
test_dl (torch.utils.data.DataLoader) – Testing dataloader.
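The client_idx switch can be sketched without PyTorch, assuming per-client train and test index lists (for example, from a Dirichlet partition) and plain lists of index batches standing in for DataLoader objects. `get_dataloader_sketch`, `batches`, and `DEFAULT_BS` are hypothetical; the real default batch size, and whether the global test loader pools the client test sets or uses the full test set directly, are not stated on this page.

```python
from typing import Dict, List, Optional, Sequence, Tuple

DEFAULT_BS = 64  # hypothetical; the real default batch size is not documented here


def batches(indices: Sequence[int], bs: int) -> List[List[int]]:
    """Split an index list into consecutive batches (stand-in for a DataLoader)."""
    return [list(indices[i:i + bs]) for i in range(0, len(indices), bs)]


def get_dataloader_sketch(
    train_parts: Dict[int, List[int]],
    test_parts: Dict[int, List[int]],
    train_bs: Optional[int] = None,
    test_bs: Optional[int] = None,
    client_idx: Optional[int] = None,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Local batches for one client, or global batches when client_idx is None."""
    if client_idx is None:
        # Global (centralized) view: pool every client's samples.
        train_idx = sorted(i for part in train_parts.values() for i in part)
        test_idx = sorted(i for part in test_parts.values() for i in part)
    else:
        # Local view: only this client's samples.
        train_idx = train_parts[client_idx]
        test_idx = test_parts[client_idx]
    # Fall back to the default batch size when none is given.
    train_bs = DEFAULT_BS if train_bs is None else train_bs
    test_bs = DEFAULT_BS if test_bs is None else test_bs
    return batches(train_idx, train_bs), batches(test_idx, test_bs)
```

In a PyTorch implementation, each index list would typically be wrapped as `DataLoader(Subset(dataset, indices), batch_size=bs, shuffle=True)` for training, using `torch.utils.data.Subset` and `torch.utils.data.DataLoader`.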