diff --git a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix index 14c59787f06a..160158d0ed39 100644 --- a/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix +++ b/pkgs/development/python-modules/jax-cuda12-pjrt/default.nix @@ -2,7 +2,7 @@ lib, stdenv, buildPythonPackage, - fetchurl, + fetchPypi, addDriverRunpath, autoPatchelfHook, pypaInstallHook, @@ -31,30 +31,31 @@ let ] ); - # Find new releases at https://storage.googleapis.com/jax-releases - # When upgrading, you can get these hashes from jaxlib/prefetch.sh. See - # https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index. - - # upstream does not distribute jax-cuda12-pjrt binaries for aarch64-linux - srcs = { - "x86_64-linux" = fetchurl { - url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_x86_64.whl"; - hash = "sha256-xTeDBlaLoMgbIwp3ndMZTJ3RAzmrY2CugJKBCNN+f3U="; - }; - # "aarch64-linux" = fetchurl { - # url = "https://storage.googleapis.com/jax-releases/cuda12_plugin/jax_cuda12_pjrt-${version}-py3-none-manylinux2014_aarch64.whl"; - # hash = "sha256-AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="; - # }; - }; in -buildPythonPackage { +buildPythonPackage rec { pname = "jax-cuda12-pjrt"; inherit version; pyproject = false; - src = - srcs.${stdenv.hostPlatform.system} - or (throw "jax-cuda12-pjrt: No src for ${stdenv.hostPlatform.system}"); + src = fetchPypi { + pname = "jax_cuda12_pjrt"; + inherit version; + format = "wheel"; + python = "py3"; + dist = "py3"; + platform = + { + x86_64-linux = "manylinux2014_x86_64"; + aarch64-linux = "manylinux2014_aarch64"; + } + .${stdenv.hostPlatform.system}; + hash = + { + x86_64-linux = "sha256-aDcb2cE1JEuJZjA5viCCVWmKdb7JhU1BnqPD+VfKRkY= "; + aarch64-linux = "sha256-m/67BqOWFMtomfdzDqhWHxEVasgcuz7GiEpir7OxX/M="; + } + .${stdenv.hostPlatform.system}; + }; nativeBuildInputs = [ autoPatchelfHook @@ -97,7 +98,7 @@ buildPythonPackage { sourceProvenance = [ lib.sourceTypes.binaryNativeCode ]; license = lib.licenses.asl20; maintainers = with lib.maintainers; [ natsukium ]; - platforms = lib.attrNames srcs; + platforms = lib.platforms.linux; # see CUDA compatibility matrix # https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder broken = diff --git a/pkgs/development/python-modules/jax-cuda12-plugin/default.nix b/pkgs/development/python-modules/jax-cuda12-plugin/default.nix index a1bcf01218da..3103feb271e4 100644 --- a/pkgs/development/python-modules/jax-cuda12-plugin/default.nix +++ b/pkgs/development/python-modules/jax-cuda12-plugin/default.nix @@ -40,42 +40,42 @@ let "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; - hash = "sha256-uiVVln+bbDgci075+wPQW8Vewl7P7lz+RcWs4099QVI="; + hash = "sha256-pwDhcYI84lUQIALkDJR4j6ho8hYle30/BWjQn+dcEHs="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; - hash = "sha256-YXGu0vSzvdX8E3gt4QcsamNPzhNzG3XQywpquPTm5lA="; + hash = "sha256-UwrYUcpGKZHOgtsmrUfwKwjOvkg8nI0MADfp4np7Up8="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; - hash = "sha256-qqcEpe9UdZXQItscHkh4oGdxFkEqk2DBFdZ/9LZOFZY="; + hash = "sha256-DZ7O3mbEAlhwKkImHoaM21ahA1UafDyISzX1Mcms1I4="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; - hash = "sha256-KY0tdo8QKbdKCx0BJw5Uk0nSw33Adlh5ZULNqWfre9M="; + hash = "sha256-fNG0iKVKMInolYjMr2dwiZUsglKefQQD4LBQGZ5SVBg="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; - hash = "sha256-IDDPEgjOTqcO5WysYd3SOfl5hpX8Obt3OcUKJdbp2kQ="; + hash = "sha256-5w608IRpbD474StekJ7xIFyfVu/j3OzyYhvZtatZVNU="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; - hash = "sha256-wlF6fCGG+HCIlGluJs+W69YLeHnOyjmLLEarso0slsg="; + hash = "sha256-oqOvX5iIDYb40kartGpVLlou9J12e/xKdMjDV3UgB8Y="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; - hash = "sha256-GGJZWyttgVZ50R4OiJ5SMYXuVKRtRuAiaJ9w/EVU3ZE="; + hash = "sha256-6W891KlCUWroeMn2l+au/teOFI8JAYynPuKLI0JqfYo="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; - hash = "sha256-If7BtWyYeD6gVpt0elZ1Hx+f8hh7SKzBHHANO/xeGjE="; + hash = "sha256-o0LyznxLH1nUA/Zlo1qGuGUCU7sl3jRkf7IlxFzrCgQ="; }; }; in diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 30b6789692a9..b8168916df8d 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -40,7 +40,7 @@ let in buildPythonPackage rec { pname = "jax"; - version = "0.5.3"; + version = "0.6.0"; pyproject = true; src = fetchFromGitHub { @@ -48,7 +48,7 @@ buildPythonPackage rec { repo = "jax"; # google/jax contains tags for jax and jaxlib. Only use jax tags! tag = "jax-v${version}"; - hash = "sha256-t4LHwpCz08zrQGWBehyPs2JnxsOvtV3L14MCdTqMeEI="; + hash = "sha256-leadxLK21JF/cF/hMveUPgrwyumjR+zMm9Wsn7bWNLQ="; }; build-system = [ setuptools ]; diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 4a8bfb8d31af..8c934c4af977 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -18,7 +18,7 @@ }: let - version = "0.5.3"; + version = "0.6.0"; inherit (python) pythonVersion; # As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the @@ -49,65 +49,65 @@ let "3.10-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp310"; - hash = "sha256-Ur5sl3Wv9zimEXDYwEdQXHW7eZpFUY5mp6CQgSexF4U="; + hash = "sha256-pNQlTHEziIh6MhN508WxogITqNzckD+vFRObqB4+zWE="; }; "3.10-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp310"; - hash = "sha256-lyQA20r26FJw2B215uYg0xOV8EcuUQxQ381Ms/crciA="; + hash = "sha256-VBpBi5iyjfW9Oh6TxistP2TUSwxwt7YI9/4rSqRSsq8="; }; "3.10-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp310"; - hash = "sha256-SP9cifuKD+BNR16d3AdLSHmpHXq2ilHOxc0eh/gebEc="; + hash = "sha256-ZKgvjrQP23uh1G75BzANQuT5jL2pYCou2OcNsamsSmA="; }; "3.11-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp311"; - hash = "sha256-KeFTD8gYMyFvHii1eNDFlpdlT3LuMcekTtd1O69axGY="; + hash = "sha256-vtRVJeO7XsCGML/SB8Ca+dYun/E/XwfC7iz9jthBG6E="; }; "3.11-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp311"; - hash = "sha256-pGZvgdcsBg7T5YHe0RapyqmwpwoUilTLEqHTr8o2JLU="; + hash = "sha256-wK6VmJmALhMpzI7ForTUvpoHa1vrIFLrSbo3UU5iPrw="; }; "3.11-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp311"; - hash = "sha256-tivYsp5aT5v6pXyNr24EggssmU9Ejz3sYC1kJVVF6fI="; + hash = "sha256-7xY88H3gC8VpAWnpf6+q3DePHDgfAofopHPnirW6sbU="; }; "3.12-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp312"; - hash = "sha256-Wl6IqxzW/feNaavjVE6PCcziAN0zm7hfvjwupn8qXmg="; + hash = "sha256-tthbjR/XkkiwRQNRcgHnL8vNOYDPeR036BRwnqUKPII="; }; "3.12-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp312"; - hash = "sha256-vd9jYDd6oceS5H/YfzB8NC4zHl/zWC+UCxvKAPa0vHM="; + hash = "sha256-JTb6k+wUjVAW2osgd7pmMlsNhqriKJphwSaHfwQrPRw="; }; "3.12-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp312"; - hash = "sha256-05Tb3koca9Z1Ac+ynTgZoQuQDLU0zA/GAzGfcJLyTPo="; + hash = "sha256-fjzi7w7cm0izbicEw2GB8eznoSrBFN91PbQobqLG6Lg="; }; "3.13-x86_64-linux" = getSrcFromPypi { platform = "manylinux2014_x86_64"; dist = "cp313"; - hash = "sha256-u3WTy3//yxOWPyL6UintlguPtK5ew7CCAEjL1n8ejjE="; + hash = "sha256-0PsSLceDDKKlyjyHSghzY6AFMrZEUJwhnDv9HVRRXo0="; }; "3.13-aarch64-linux" = getSrcFromPypi { platform = "manylinux2014_aarch64"; dist = "cp313"; - hash = "sha256-6QS5Le37x+VFclqNdnaYcDCunAaQAdlHAbwQnG2rQQA="; + hash = "sha256-GJcpY5diBQwXgLBQ6Y/2IEgLHqMr8WdTPgAKXPTFc44="; }; "3.13-aarch64-darwin" = getSrcFromPypi { platform = "macosx_11_0_arm64"; dist = "cp313"; - hash = "sha256-MTIcJSgqBqbfyUBQe8FNCgrIONjO1sB6oAp/rjTOez8="; + hash = "sha256-xOl5NMuvUXI0OqWujvDFhGLOJhVN/adUICswNBYMrHs="; }; }; in