2025-07-24 13:55:40 +02:00

260 lines
5.6 KiB
Nix

# Package definition for torchrl, built with buildPythonPackage from the
# pytorch/rl GitHub repository.
{
  lib,
  stdenv,
  buildPythonPackage,
  fetchFromGitHub,

  # build-system
  ninja,
  setuptools,
  which,

  # dependencies
  cloudpickle,
  numpy,
  packaging,
  tensordict,
  torch,

  # optional-dependencies, grouped by the extra they belong to
  # atari
  gymnasium,
  # checkpointing
  torchsnapshot,
  # gym-continuous
  mujoco,
  # llm
  accelerate,
  datasets,
  einops,
  immutabledict,
  langdetect,
  nltk,
  playwright,
  protobuf,
  safetensors,
  sentencepiece,
  transformers,
  vllm,
  # offline-data
  h5py,
  huggingface-hub,
  minari,
  pandas,
  pillow,
  requests,
  scikit-learn,
  torchvision,
  tqdm,
  # rendering
  moviepy,
  # utils
  git,
  hydra-core,
  tensorboard,
  wandb,

  # tests
  imageio,
  pytest-rerunfailures,
  pytestCheckHook,
  pyyaml,
  scipy,
}:

buildPythonPackage rec {
  pname = "torchrl";
  version = "0.9.2";
  pyproject = true;

  src = fetchFromGitHub {
    owner = "pytorch";
    repo = "rl";
    tag = "v${version}";
    hash = "sha256-6rU5+J70T0E7+60jihsjwlLls8jJlxKi3nmrL0xm2c0=";
  };

  build-system = [
    ninja
    setuptools
    which
  ];

  dependencies = [
    cloudpickle
    numpy
    packaging
    tensordict
    torch
  ];

  optional-dependencies = {
    # The atari extra is taken verbatim from gymnasium's extra of the same name.
    inherit (gymnasium.optional-dependencies) atari;
    checkpointing = [ torchsnapshot ];
    gym-continuous = [
      gymnasium
      mujoco
    ];
    llm = [
      accelerate
      datasets
      einops
      immutabledict
      langdetect
      nltk
      playwright
      protobuf
      safetensors
      sentencepiece
      transformers
      vllm
    ];
    offline-data = [
      h5py
      huggingface-hub
      minari
      pandas
      pillow
      requests
      scikit-learn
      torchvision
      tqdm
    ];
    rendering = [ moviepy ];
    utils = [
      git
      hydra-core
      tensorboard
      tqdm
      wandb
    ];
  };

  # torchrl wants to create a directory for datasets; hand it a temporary one.
  preBuild = ''
    export D4RL_DATASET_DIR=$(mktemp -d)
  '';

  pythonImportsCheck = [ "torchrl" ];

  # Test inputs: the test-only packages plus several of the extras above.
  nativeCheckInputs = [
    h5py
    gymnasium
    imageio
    pytest-rerunfailures
    pytestCheckHook
    pyyaml
    scipy
    torchvision
  ]
  ++ optional-dependencies.atari
  ++ optional-dependencies.gym-continuous
  ++ optional-dependencies.llm
  ++ optional-dependencies.rendering;

  # Remove the source tree before testing; otherwise it would be imported
  # in place of the installed package.
  preCheck = ''
    rm -rf torchrl
    export XDG_RUNTIME_DIR=$(mktemp -d)
  '';

  disabledTests = [
    # Network access required
    "test_create_or_load_dataset"
    "test_from_text_env_tokenizer"
    "test_from_text_env_tokenizer_catframes"
    "test_from_text_rb_slicesampler"
    "test_generate"
    "test_get_dataloader"
    "test_get_scores"
    "test_preproc_data"
    "test_prompt_tensordict_tokenizer"
    "test_reward_model"
    "test_tensordict_tokenizer"
    "test_transform_compose"
    "test_transform_model"
    "test_transform_no_env"
    "test_transform_rb"

    # ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
    "TestRayCollector"

    # torchrl is not compatible with gymnasium>=1.0
    # https://github.com/pytorch/rl/discussions/2483
    "test_resetting_strategies"
    "test_torchrl_to_gym"
    "test_vecenvs_nan"

    # gym.error.VersionNotFound: Environment version `v5` for environment `HalfCheetah` doesn't exist.
    "test_collector_run"
    "test_transform_inverse"

    # OSError: Unable to synchronously create file (unable to truncate a file which is already open)
    "test_multi_env"
    "test_simple_env"

    # ImportWarning: Ignoring non-library in plugin directory:
    # /nix/store/cy8vwf1dacp3xfwnp9v6a1sz8bic8ylx-python3.12-mujoco-3.3.2/lib/python3.12/site-packages/mujoco/plugin/libmujoco.so.3.3.2
    "test_auto_register"
    "test_info_dict_reader"

    # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
    "test_vecenvs_env"

    # ValueError: Can't write images with one color channel.
    "test_log_video"

    # These need the ALE environments, which come from shimmy (not packaged)
    "test_collector_env_reset"
    "test_gym"
    "test_gym_fake_td"
    "test_recorder"
    "test_recorder_load"
    "test_rollout"
    "test_parallel_trans_env_check"
    "test_serial_trans_env_check"
    "test_single_trans_env_check"
    "test_td_creation_from_spec"
    "test_trans_parallel_env_check"
    "test_trans_serial_env_check"
    "test_transform_env"

    # Non-deterministic
    "test_distributed_collector_updatepolicy"
    "test_timeit"

    # Observed on a machine with 24 threads:
    #   assert torch.get_num_threads() == max(1, init_threads - 3)
    #   AssertionError: assert 23 == 21
    "test_auto_num_threads"

    # Flaky (hangs indefinitely on some CPUs)
    "test_gae_multidim"
    "test_gae_param_as_tensor"
  ]
  ++ lib.optionals (stdenv.hostPlatform.isLinux && stdenv.hostPlatform.isAarch64) [
    # Flaky
    # AssertionError: assert tensor([51.]) == ((5 * 11) + 2)
    "test_vecnorm_parallel_auto"
  ];

  disabledTestPaths = [
    # Collection fails because two files share the module name smoke_test:
    #   ERROR collecting test/smoke_test.py
    #   import file mismatch:
    #   imported module 'smoke_test' has this __file__ attribute:
    #     /build/source/test/llm/smoke_test.py
    #   which is not the same as the test file we want to collect:
    #     /build/source/test/smoke_test.py
    "test/llm"
  ];

  meta = {
    description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
    homepage = "https://github.com/pytorch/rl";
    changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
    license = lib.licenses.mit;
    maintainers = [ lib.maintainers.GaetanLepage ];
  };
}