python3Packages.tensordict: 0.8.3 -> 0.9.0 (#423963)
This commit is contained in:
commit
0bdc43c96d
@ -15,6 +15,10 @@
|
||||
pythonOlder,
|
||||
importlib-metadata,
|
||||
|
||||
# optional-dependencies
|
||||
# atari
|
||||
ale-py,
|
||||
|
||||
# tests
|
||||
array-api-compat,
|
||||
dill,
|
||||
@ -54,6 +58,12 @@ buildPythonPackage rec {
|
||||
typing-extensions
|
||||
] ++ lib.optionals (pythonOlder "3.10") [ importlib-metadata ];
|
||||
|
||||
optional-dependencies = {
|
||||
atari = [
|
||||
ale-py
|
||||
];
|
||||
};
|
||||
|
||||
pythonImportsCheck = [ "gymnasium" ];
|
||||
|
||||
nativeCheckInputs = [
|
||||
|
||||
@ -28,14 +28,14 @@
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "tensordict";
|
||||
version = "0.8.3";
|
||||
version = "0.9.0";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "pytorch";
|
||||
repo = "tensordict";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-d/6JKGFcFLXY9pxsnP27uwnAnIQ9EKvfTS30DCwQrCM=";
|
||||
hash = "sha256-actBFzWb2JBPsLhRZiD6zRpk7eyX2OHUPMU9JpJ90Wc=";
|
||||
};
|
||||
|
||||
build-system = [
|
||||
|
||||
@ -17,12 +17,26 @@
|
||||
torch,
|
||||
|
||||
# optional-dependencies
|
||||
ale-py,
|
||||
gym,
|
||||
pygame,
|
||||
torchsnapshot,
|
||||
# atari
|
||||
gymnasium,
|
||||
# checkpointing
|
||||
torchsnapshot,
|
||||
# gym-continuous
|
||||
mujoco,
|
||||
# llm
|
||||
accelerate,
|
||||
datasets,
|
||||
einops,
|
||||
immutabledict,
|
||||
langdetect,
|
||||
nltk,
|
||||
playwright,
|
||||
protobuf,
|
||||
safetensors,
|
||||
sentencepiece,
|
||||
transformers,
|
||||
vllm,
|
||||
# offline-data
|
||||
h5py,
|
||||
huggingface-hub,
|
||||
minari,
|
||||
@ -32,7 +46,9 @@
|
||||
scikit-learn,
|
||||
torchvision,
|
||||
tqdm,
|
||||
# rendering
|
||||
moviepy,
|
||||
# utils
|
||||
git,
|
||||
hydra-core,
|
||||
tensorboard,
|
||||
@ -48,14 +64,14 @@
|
||||
|
||||
buildPythonPackage rec {
|
||||
pname = "torchrl";
|
||||
version = "0.8.1";
|
||||
version = "0.9.1";
|
||||
pyproject = true;
|
||||
|
||||
src = fetchFromGitHub {
|
||||
owner = "pytorch";
|
||||
repo = "rl";
|
||||
tag = "v${version}";
|
||||
hash = "sha256-ANoqIAVKSq023hG83Q71t8oLzud1LeVN5WVPYL3nOks=";
|
||||
hash = "sha256-afaWDX5lIAoGTfrBSqrktYoA1S4hv6ogBaKYHc8dQ6E=";
|
||||
};
|
||||
|
||||
build-system = [
|
||||
@ -73,16 +89,26 @@ buildPythonPackage rec {
|
||||
];
|
||||
|
||||
optional-dependencies = {
|
||||
atari = [
|
||||
ale-py
|
||||
gym
|
||||
pygame
|
||||
];
|
||||
atari = gymnasium.optional-dependencies.atari;
|
||||
checkpointing = [ torchsnapshot ];
|
||||
gym-continuous = [
|
||||
gymnasium
|
||||
mujoco
|
||||
];
|
||||
llm = [
|
||||
accelerate
|
||||
datasets
|
||||
einops
|
||||
immutabledict
|
||||
langdetect
|
||||
nltk
|
||||
playwright
|
||||
protobuf
|
||||
safetensors
|
||||
sentencepiece
|
||||
transformers
|
||||
vllm
|
||||
];
|
||||
offline-data = [
|
||||
h5py
|
||||
huggingface-hub
|
||||
@ -131,10 +157,31 @@ buildPythonPackage rec {
|
||||
]
|
||||
++ optional-dependencies.atari
|
||||
++ optional-dependencies.gym-continuous
|
||||
++ optional-dependencies.llm
|
||||
++ optional-dependencies.rendering;
|
||||
|
||||
disabledTests =
|
||||
[
|
||||
# Require network
|
||||
"test_create_or_load_dataset"
|
||||
"test_from_text_env_tokenizer"
|
||||
"test_from_text_env_tokenizer_catframes"
|
||||
"test_from_text_rb_slicesampler"
|
||||
"test_generate"
|
||||
"test_get_dataloader"
|
||||
"test_get_scores"
|
||||
"test_preproc_data"
|
||||
"test_prompt_tensordict_tokenizer"
|
||||
"test_reward_model"
|
||||
"test_tensordict_tokenizer"
|
||||
"test_transform_compose"
|
||||
"test_transform_model"
|
||||
"test_transform_no_env"
|
||||
"test_transform_rb"
|
||||
|
||||
# ray.exceptions.RuntimeEnvSetupError: Failed to set up runtime environment
|
||||
"TestRayCollector"
|
||||
|
||||
# torchrl is incompatible with gymnasium>=1.0
|
||||
# https://github.com/pytorch/rl/discussions/2483
|
||||
"test_resetting_strategies"
|
||||
@ -194,6 +241,16 @@ buildPythonPackage rec {
|
||||
"test_vecnorm_parallel_auto"
|
||||
];
|
||||
|
||||
disabledTestPaths = [
|
||||
# ERROR collecting test/smoke_test.py
|
||||
# import file mismatch:
|
||||
# imported module 'smoke_test' has this __file__ attribute:
|
||||
# /build/source/test/llm/smoke_test.py
|
||||
# which is not the same as the test file we want to collect:
|
||||
# /build/source/test/smoke_test.py
|
||||
"test/llm"
|
||||
];
|
||||
|
||||
meta = {
|
||||
description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning";
|
||||
homepage = "https://github.com/pytorch/rl";
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user