python3Packages.tensordict: 0.8.3 -> 0.9.0 (#423963)

This commit is contained in:
Gaétan Lepage 2025-07-14 00:31:38 +02:00 committed by GitHub
commit 0bdc43c96d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 80 additions and 13 deletions

View File

@ -15,6 +15,10 @@
pythonOlder, pythonOlder,
importlib-metadata, importlib-metadata,
# optional-dependencies
# atari
ale-py,
# tests # tests
array-api-compat, array-api-compat,
dill, dill,
@ -54,6 +58,12 @@ buildPythonPackage rec {
typing-extensions typing-extensions
] ++ lib.optionals (pythonOlder "3.10") [ importlib-metadata ]; ] ++ lib.optionals (pythonOlder "3.10") [ importlib-metadata ];
optional-dependencies = {
atari = [
ale-py
];
};
pythonImportsCheck = [ "gymnasium" ]; pythonImportsCheck = [ "gymnasium" ];
nativeCheckInputs = [ nativeCheckInputs = [

View File

@ -28,14 +28,14 @@
buildPythonPackage rec { buildPythonPackage rec {
pname = "tensordict"; pname = "tensordict";
version = "0.8.3"; version = "0.9.0";
pyproject = true; pyproject = true;
src = fetchFromGitHub { src = fetchFromGitHub {
owner = "pytorch"; owner = "pytorch";
repo = "tensordict"; repo = "tensordict";
tag = "v${version}"; tag = "v${version}";
hash = "sha256-d/6JKGFcFLXY9pxsnP27uwnAnIQ9EKvfTS30DCwQrCM="; hash = "sha256-actBFzWb2JBPsLhRZiD6zRpk7eyX2OHUPMU9JpJ90Wc=";
}; };
build-system = [ build-system = [

View File

@ -17,12 +17,26 @@
torch, torch,
# optional-dependencies # optional-dependencies
ale-py, # atari
gym,
pygame,
torchsnapshot,
gymnasium, gymnasium,
# checkpointing
torchsnapshot,
# gym-continuous
mujoco, mujoco,
# llm
accelerate,
datasets,
einops,
immutabledict,
langdetect,
nltk,
playwright,
protobuf,
safetensors,
sentencepiece,
transformers,
vllm,
# offline-data
h5py, h5py,
huggingface-hub, huggingface-hub,
minari, minari,
@ -32,7 +46,9 @@
scikit-learn, scikit-learn,
torchvision, torchvision,
tqdm, tqdm,
# rendering
moviepy, moviepy,
# utils
git, git,
hydra-core, hydra-core,
tensorboard, tensorboard,
@ -48,14 +64,14 @@
buildPythonPackage rec { buildPythonPackage rec {
pname = "torchrl"; pname = "torchrl";
version = "0.8.1"; version = "0.9.1";
pyproject = true; pyproject = true;
src = fetchFromGitHub { src = fetchFromGitHub {
owner = "pytorch"; owner = "pytorch";
repo = "rl"; repo = "rl";
tag = "v${version}"; tag = "v${version}";
hash = "sha256-ANoqIAVKSq023hG83Q71t8oLzud1LeVN5WVPYL3nOks="; hash = "sha256-afaWDX5lIAoGTfrBSqrktYoA1S4hv6ogBaKYHc8dQ6E=";
}; };
build-system = [ build-system = [
@ -73,16 +89,26 @@ buildPythonPackage rec {
]; ];
optional-dependencies = { optional-dependencies = {
atari = [ atari = gymnasium.optional-dependencies.atari;
ale-py
gym
pygame
];
checkpointing = [ torchsnapshot ]; checkpointing = [ torchsnapshot ];
gym-continuous = [ gym-continuous = [
gymnasium gymnasium
mujoco mujoco
]; ];
llm = [
accelerate
datasets
einops
immutabledict
langdetect
nltk
playwright
protobuf
safetensors
sentencepiece
transformers
vllm
];
offline-data = [ offline-data = [
h5py h5py
huggingface-hub huggingface-hub
@ -131,10 +157,31 @@ buildPythonPackage rec {
] ]
++ optional-dependencies.atari ++ optional-dependencies.atari
++ optional-dependencies.gym-continuous ++ optional-dependencies.gym-continuous
++ optional-dependencies.llm
++ optional-dependencies.rendering; ++ optional-dependencies.rendering;
disabledTests = disabledTests =
[ [
# Require network
"test_create_or_load_dataset"
"test_from_text_env_tokenizer"
"test_from_text_env_tokenizer_catframes"
"test_from_text_rb_slicesampler"
"test_generate"
"test_get_dataloader"
"test_get_scores"
"test_preproc_data"
"test_prompt_tensordict_tokenizer"
"test_reward_model"
"test_tensordict_tokenizer"
"test_transform_compose"
"test_transform_model"
"test_transform_no_env"
"test_transform_rb"
# ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
"TestRayCollector"
# torchrl is incompatible with gymnasium>=1.0 # torchrl is incompatible with gymnasium>=1.0
# https://github.com/pytorch/rl/discussions/2483 # https://github.com/pytorch/rl/discussions/2483
"test_resetting_strategies" "test_resetting_strategies"
@ -194,6 +241,16 @@ buildPythonPackage rec {
"test_vecnorm_parallel_auto" "test_vecnorm_parallel_auto"
]; ];
disabledTestPaths = [
# ERROR collecting test/smoke_test.py
# import file mismatch:
# imported module 'smoke_test' has this __file__ attribute:
# /build/source/test/llm/smoke_test.py
# which is not the same as the test file we want to collect:
# /build/source/test/smoke_test.py
"test/llm"
];
meta = { meta = {
description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning"; description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
homepage = "https://github.com/pytorch/rl"; homepage = "https://github.com/pytorch/rl";