python3Packages.warp-lang: fix CUDA build (#419750)

This commit is contained in:
Connor Baker 2025-06-27 15:44:24 -07:00 committed by GitHub
commit c0874d923d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,20 +1,24 @@
{ {
config, autoAddDriverRunpath,
lib,
stdenv,
buildPythonPackage, buildPythonPackage,
fetchurl, config,
fetchFromGitHub,
replaceVars,
build,
setuptools,
numpy,
llvmPackages,
cudaPackages, cudaPackages,
unittestCheckHook, fetchFromGitHub,
fetchurl,
jax, jax,
lib,
llvmPackages,
numpy,
pkgsBuildHost,
python,
replaceVars,
runCommand,
setuptools,
stdenv,
torch, torch,
nix-update-script, warp-lang, # Self-reference to this package for passthru.tests
writableTmpDirAsHomeHook,
writeShellApplication,
# Use standalone LLVM-based JIT compiler and CPU device support # Use standalone LLVM-based JIT compiler and CPU device support
standaloneSupport ? true, standaloneSupport ? true,
@ -25,63 +29,69 @@
# Build Warp with MathDx support (requires CUDA support) # Build Warp with MathDx support (requires CUDA support)
# Most linear-algebra tile operations like tile_cholesky(), tile_fft(), # Most linear-algebra tile operations like tile_cholesky(), tile_fft(),
# and tile_matmul() require Warp to be built with the MathDx library. # and tile_matmul() require Warp to be built with the MathDx library.
libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux, # libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux,
}: libmathdxSupport ? cudaSupport,
}@args:
assert libmathdxSupport -> cudaSupport;
let let
effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else args.stdenv;
stdenv = builtins.throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv";
version = "1.7.2.post1"; version = "1.7.2.post1";
libmathdx = stdenv.mkDerivation (finalAttrs: { libmathdx = effectiveStdenv.mkDerivation (finalAttrs: {
# NOTE: The version used should match the version Warp requires:
# https://github.com/NVIDIA/warp/blob/4ad209076ce09668b18dedc74dce0d5cf8b9e409/deps/libmathdx-deps.packman.xml
pname = "libmathdx"; pname = "libmathdx";
version = "0.2.0"; version = "0.1.2";
outputs = [
"out"
"static"
];
src = src =
let let
inherit (stdenv.hostPlatform) system; baseURL = "https://developer.download.nvidia.com/compute/cublasdx/redist/cublasdx";
selectSystem = attrs: attrs.${system} or (throw "Unsupported system: ${system}"); name = lib.concatStringsSep "-" [
finalAttrs.pname
suffix = selectSystem { "Linux"
x86_64-linux = "Linux-x86_64"; effectiveStdenv.hostPlatform.parsed.cpu.name
aarch64-linux = "Linux-aarch64"; finalAttrs.version
x86_64-windows = "win32-x86_64"; ];
}; hashes = {
aarch64-linux = "sha256-7HEXfzxPF62q/7pdZidj4eO09u588yxcpSu/bWot/9A=";
# nix-hash --type sha256 --to-sri $(nix-prefetch-url "https://...") x86_64-linux = "sha256-MImBFv+ooRSUqdL/YEe/bJIcVBnHMCk7SLS5eSeh0cQ=";
hash = selectSystem {
x86_64-linux = "sha256-Lk+PxWFvyQGRClFdmyuo4y7HBdR7pigOhMyEzajqbmg=";
aarch64-linux = "sha256-6tH9YH98kSvDiut9rQEU5potEpeKqma/QtrCHLxwRLo=";
x86_64-windows = "sha256-B8qwj7UzOXEDZh2oT3ip1qW0uqtygMsyfcbhh5Dgc8U=";
}; };
in in
fetchurl { lib.mapNullable (
url = "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx/libmathdx-${suffix}-${finalAttrs.version}.tar.gz"; hash:
inherit hash; fetchurl {
}; inherit hash name;
url = "${baseURL}/${name}.tar.gz";
unpackPhase = '' }
runHook preUnpack ) (hashes.${effectiveStdenv.hostPlatform.system} or null);
mkdir unpacked
cd unpacked
tar -xzf $src
export sourceRoot=$(pwd)
runHook postUnpack
'';
dontUnpack = true;
dontConfigure = true; dontConfigure = true;
dontBuild = true; dontBuild = true;
# NOTE: The leading component is stripped because the 0.1.2 release is within the `libmathdx` directory.
installPhase = '' installPhase = ''
runHook preInstall runHook preInstall
cp -rT "$sourceRoot" "$out" mkdir -p "$out"
tar -xzf "$src" --strip-components=1 -C "$out"
mkdir -p "$static"
moveToOutput "lib/libmathdx_static.a" "$static"
runHook postInstall runHook postInstall
''; '';
meta = { meta = {
description = "library used to integrate cuBLASDx and cuFFTDx into Warp"; description = "library used to integrate cuBLASDx and cuFFTDx into Warp";
homepage = "https://developer.nvidia.com/cublasdx-downloads";
sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ]; sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ];
license = with lib.licenses; [ license = with lib.licenses; [
# By downloading and using the software, you agree to fully # By downloading and using the software, you agree to fully
@ -104,7 +114,10 @@ let
# license: # license:
mit mit
]; ];
platforms = with lib.platforms; linux ++ [ "x86_64-windows" ]; platforms = [
"aarch64-linux"
"x86_64-linux"
];
maintainers = with lib.maintainers; [ yzx9 ]; maintainers = with lib.maintainers; [ yzx9 ];
}; };
}); });
@ -114,6 +127,13 @@ buildPythonPackage {
inherit version; inherit version;
pyproject = true; pyproject = true;
# TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false,
# causing a bunch of missing math symbols (like expf) when linking against the static library
# provided by NVCC.
__structuredAttrs = true;
stdenv = effectiveStdenv;
src = fetchFromGitHub { src = fetchFromGitHub {
owner = "NVIDIA"; owner = "NVIDIA";
repo = "warp"; repo = "warp";
@ -122,7 +142,7 @@ buildPythonPackage {
}; };
patches = patches =
lib.optionals stdenv.hostPlatform.isDarwin [ lib.optionals effectiveStdenv.hostPlatform.isDarwin [
(replaceVars ./darwin-libcxx.patch { (replaceVars ./darwin-libcxx.patch {
LIBCXX_DEV = llvmPackages.libcxx.dev; LIBCXX_DEV = llvmPackages.libcxx.dev;
LIBCXX_LIB = llvmPackages.libcxx; LIBCXX_LIB = llvmPackages.libcxx;
@ -140,22 +160,69 @@ buildPythonPackage {
]; ];
postPatch = postPatch =
lib.optionalString (!stdenv.cc.isGNU) '' # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults.
substituteInPlace warp/build_dll.py \ # NOTE: After 1.7.2, patching will need to be updated like this:
--replace-fail "g++" "${lib.getExe stdenv.cc}" # https://github.com/ConnorBaker/cuda-packages/blob/2fc8ba8c37acee427a94cdd1def55c2ec701ad82/pkgs/development/python-modules/warp/default.nix#L56-L65
lib.optionalString cudaSupport ''
nixLog "patching $PWD/warp/build_dll.py to use our gencode flags"
substituteInPlace "$PWD/warp/build_dll.py" \
--replace-fail \
'nvcc_opts = gencode_opts + [' \
'nvcc_opts = [ ${
lib.concatMapStringsSep ", " (gencodeString: ''"${gencodeString}"'') cudaPackages.flags.gencode
}, '
''
# Patch build_dll.py to use dynamic libraries rather than static ones.
# NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library.
+ lib.optionalString cudaSupport ''
nixLog "patching $PWD/warp/build_dll.py to use dynamic libraries"
substituteInPlace "$PWD/warp/build_dll.py" \
--replace-fail \
'-lcudart_static' \
'-lcudart' \
--replace-fail \
'-lnvrtc_static' \
'-lnvrtc' \
--replace-fail \
'-lnvrtc-builtins_static' \
'-lnvrtc-builtins' \
--replace-fail \
'-lnvJitLink_static' \
'-lnvJitLink' \
--replace-fail \
'-lmathdx_static' \
'-lmathdx'
''
+ ''
nixLog "patching $PWD/warp/build_dll.py to use our C++ compiler"
substituteInPlace "$PWD/warp/build_dll.py" \
--replace-fail "g++" "c++"
'' ''
# Broken tests on aarch64. Since unittest doesn't support disabling a # Broken tests on aarch64. Since unittest doesn't support disabling a
# single test, and pytest isn't compatible, we patch the test file directly # single test, and pytest isn't compatible, we patch the test file directly
# instead. # instead.
# #
# See: https://github.com/NVIDIA/warp/issues/552 # See: https://github.com/NVIDIA/warp/issues/552
+ lib.optionalString stdenv.hostPlatform.isAarch64 '' + lib.optionalString effectiveStdenv.hostPlatform.isAarch64 ''
substituteInPlace warp/tests/test_fem.py \ nixLog "patching $PWD/warp/tests/test_fem.py to disable broken tests on aarch64"
--replace-fail "add_function_test(TestFem, \"test_integrate_gradient\", test_integrate_gradient, devices=devices)" "" substituteInPlace "$PWD/warp/tests/test_fem.py" \
--replace-fail \
'add_function_test(TestFem, "test_integrate_gradient", test_integrate_gradient, devices=devices)' \
""
''
# These tests fail on CPU and CUDA.
+ ''
nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests"
substituteInPlace "$PWD/warp/tests/test_reload.py" \
--replace-fail \
'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \
"" \
--replace-fail \
'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \
""
''; '';
build-system = [ build-system = [
build
setuptools setuptools
]; ];
@ -163,11 +230,11 @@ buildPythonPackage {
numpy numpy
]; ];
nativeBuildInputs = lib.optionals libmathdxSupport [ # NOTE: While normally we wouldn't include autoAddDriverRunpath for packages built from source, since Warp
libmathdx # will be loading GPU drivers at runtime, we need to inject the path to our video drivers.
cudaPackages.libcublas nativeBuildInputs = lib.optionals cudaSupport [
cudaPackages.libcufft autoAddDriverRunpath
cudaPackages.libnvjitlink cudaPackages.cuda_nvcc
]; ];
buildInputs = buildInputs =
@ -177,10 +244,18 @@ buildPythonPackage {
llvmPackages.libcxx llvmPackages.libcxx
] ]
++ lib.optionals cudaSupport [ ++ lib.optionals cudaSupport [
cudaPackages.cudatoolkit (lib.getOutput "static" cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available
cudaPackages.cuda_cccl
cudaPackages.cuda_cudart cudaPackages.cuda_cudart
cudaPackages.cuda_nvcc cudaPackages.cuda_nvcc
cudaPackages.cuda_nvrtc cudaPackages.cuda_nvrtc
]
++ lib.optionals libmathdxSupport [
libmathdx
cudaPackages.libcublas
cudaPackages.libcufft
cudaPackages.libcusolver
cudaPackages.libnvjitlink
]; ];
preBuild = preBuild =
@ -190,7 +265,8 @@ buildPythonPackage {
"--no_standalone" "--no_standalone"
] ]
++ lib.optionals cudaSupport [ ++ lib.optionals cudaSupport [
"--cuda_path=${cudaPackages.cudatoolkit}" # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output).
"--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}"
] ]
++ lib.optionals libmathdxSupport [ ++ lib.optionals libmathdxSupport [
"--libmathdx" "--libmathdx"
@ -203,34 +279,102 @@ buildPythonPackage {
buildOptionString = lib.concatStringsSep " " buildOptions; buildOptionString = lib.concatStringsSep " " buildOptions;
in in
'' ''
python build_lib.py ${buildOptionString} nixLog "running $PWD/build_lib.py to create components necessary to build the wheel"
"${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString}
''; '';
pythonImportsCheck = [ pythonImportsCheck = [
"warp" "warp"
]; ];
# Many unit tests fail with segfaults on aarch64-linux, especially in the sim # See passthru.tests.
# and grad modules. However, other functionality generally works, so we don't doCheck = false;
# mark the package as broken.
#
# See: https://www.github.com/NVIDIA/warp/issues/{356,372,552}
doCheck = !(stdenv.hostPlatform.isAarch64 && stdenv.hostPlatform.isLinux);
nativeCheckInputs = [ passthru = {
unittestCheckHook # Make libmathdx available for introspection.
(jax.override { inherit cudaSupport; }) inherit libmathdx;
(torch.override { inherit cudaSupport; })
# # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected # Scripts which provide test packages and implement test logic.
# (paddlepaddle.override { inherit cudaSupport; }) testers.unit-tests = writeShellApplication {
]; name = "warp-lang-unit-tests";
runtimeInputs = [
# Use the references from args
(python.withPackages (_: [
warp-lang
jax
torch
]))
# Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected
# (paddlepaddle.override { inherit cudaSupport; })
];
text = ''
python3 -m warp.tests
'';
};
preCheck = '' # Tests run within the Nix sandbox.
export WARP_CACHE_PATH=$(mktemp -d) # warp.config.kernel_cache_dir tests =
''; let
mkUnitTests =
{
cudaSupport,
libmathdxSupport,
}:
let
name =
"warp-lang-unit-tests-cpu" # CPU is baseline
+ lib.optionalString cudaSupport "-cuda"
+ lib.optionalString libmathdxSupport "-libmathdx";
passthru.updateScript = nix-update-script { }; warp-lang' = warp-lang.override {
inherit cudaSupport libmathdxSupport;
# Make sure the warp-lang provided through callPackage is replaced with the override we're making.
warp-lang = warp-lang';
};
in
runCommand name
{
nativeBuildInputs = [
warp-lang'.passthru.testers.unit-tests
writableTmpDirAsHomeHook
];
requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ];
# Many unit tests fail with segfaults on aarch64-linux, especially in the sim
# and grad modules. However, other functionality generally works, so we don't
# mark the package as broken.
#
# See: https://www.github.com/NVIDIA/warp/issues/{356,372,552}
meta.broken = effectiveStdenv.hostPlatform.isAarch64 && effectiveStdenv.hostPlatform.isLinux;
}
''
nixLog "running ${name}"
if warp-lang-unit-tests; then
nixLog "${name} passed"
touch "$out"
else
nixErrorLog "${name} failed"
exit 1
fi
'';
in
{
cpu = mkUnitTests {
cudaSupport = false;
libmathdxSupport = false;
};
cuda = {
cudaOnly = mkUnitTests {
cudaSupport = true;
libmathdxSupport = false;
};
cudaWithLibmathDx = mkUnitTests {
cudaSupport = true;
libmathdxSupport = true;
};
};
};
};
meta = { meta = {
description = "Python framework for high performance GPU simulation and graphics"; description = "Python framework for high performance GPU simulation and graphics";