python3Packages.warp-lang: fix CUDA build (#419750)
This commit is contained in:
		
						commit
						c0874d923d
					
				| @ -1,20 +1,24 @@ | ||||
| { | ||||
|   config, | ||||
|   lib, | ||||
|   stdenv, | ||||
|   autoAddDriverRunpath, | ||||
|   buildPythonPackage, | ||||
|   fetchurl, | ||||
|   fetchFromGitHub, | ||||
|   replaceVars, | ||||
|   build, | ||||
|   setuptools, | ||||
|   numpy, | ||||
|   llvmPackages, | ||||
|   config, | ||||
|   cudaPackages, | ||||
|   unittestCheckHook, | ||||
|   fetchFromGitHub, | ||||
|   fetchurl, | ||||
|   jax, | ||||
|   lib, | ||||
|   llvmPackages, | ||||
|   numpy, | ||||
|   pkgsBuildHost, | ||||
|   python, | ||||
|   replaceVars, | ||||
|   runCommand, | ||||
|   setuptools, | ||||
|   stdenv, | ||||
|   torch, | ||||
|   nix-update-script, | ||||
|   warp-lang, # Self-reference to this package for passthru.tests | ||||
|   writableTmpDirAsHomeHook, | ||||
|   writeShellApplication, | ||||
| 
 | ||||
|   # Use standalone LLVM-based JIT compiler and CPU device support | ||||
|   standaloneSupport ? true, | ||||
| @ -25,63 +29,69 @@ | ||||
|   # Build Warp with MathDx support (requires CUDA support) | ||||
|   # Most linear-algebra tile operations like tile_cholesky(), tile_fft(), | ||||
|   # and tile_matmul() require Warp to be built with the MathDx library. | ||||
|   libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux, | ||||
| }: | ||||
| 
 | ||||
|   # libmathdxSupport ? cudaSupport && stdenv.hostPlatform.isLinux, | ||||
|   libmathdxSupport ? cudaSupport, | ||||
| }@args: | ||||
| assert libmathdxSupport -> cudaSupport; | ||||
| let | ||||
|   effectiveStdenv = if cudaSupport then cudaPackages.backendStdenv else args.stdenv; | ||||
|   stdenv = builtins.throw "Use effectiveStdenv instead of stdenv directly, as it may be replaced by cudaPackages.backendStdenv"; | ||||
| 
 | ||||
|   version = "1.7.2.post1"; | ||||
| 
 | ||||
|   libmathdx = stdenv.mkDerivation (finalAttrs: { | ||||
|   libmathdx = effectiveStdenv.mkDerivation (finalAttrs: { | ||||
|     # NOTE: The version used should match the version Warp requires: | ||||
|     # https://github.com/NVIDIA/warp/blob/4ad209076ce09668b18dedc74dce0d5cf8b9e409/deps/libmathdx-deps.packman.xml | ||||
|     pname = "libmathdx"; | ||||
|     version = "0.2.0"; | ||||
|     version = "0.1.2"; | ||||
| 
 | ||||
|     outputs = [ | ||||
|       "out" | ||||
|       "static" | ||||
|     ]; | ||||
| 
 | ||||
|     src = | ||||
|       let | ||||
|         inherit (stdenv.hostPlatform) system; | ||||
|         selectSystem = attrs: attrs.${system} or (throw "Unsupported system: ${system}"); | ||||
| 
 | ||||
|         suffix = selectSystem { | ||||
|           x86_64-linux = "Linux-x86_64"; | ||||
|           aarch64-linux = "Linux-aarch64"; | ||||
|           x86_64-windows = "win32-x86_64"; | ||||
|         }; | ||||
| 
 | ||||
|         # nix-hash --type sha256 --to-sri $(nix-prefetch-url "https://...") | ||||
|         hash = selectSystem { | ||||
|           x86_64-linux = "sha256-Lk+PxWFvyQGRClFdmyuo4y7HBdR7pigOhMyEzajqbmg="; | ||||
|           aarch64-linux = "sha256-6tH9YH98kSvDiut9rQEU5potEpeKqma/QtrCHLxwRLo="; | ||||
|           x86_64-windows = "sha256-B8qwj7UzOXEDZh2oT3ip1qW0uqtygMsyfcbhh5Dgc8U="; | ||||
|         baseURL = "https://developer.download.nvidia.com/compute/cublasdx/redist/cublasdx"; | ||||
|         name = lib.concatStringsSep "-" [ | ||||
|           finalAttrs.pname | ||||
|           "Linux" | ||||
|           effectiveStdenv.hostPlatform.parsed.cpu.name | ||||
|           finalAttrs.version | ||||
|         ]; | ||||
|         hashes = { | ||||
|           aarch64-linux = "sha256-7HEXfzxPF62q/7pdZidj4eO09u588yxcpSu/bWot/9A="; | ||||
|           x86_64-linux = "sha256-MImBFv+ooRSUqdL/YEe/bJIcVBnHMCk7SLS5eSeh0cQ="; | ||||
|         }; | ||||
|       in | ||||
|       fetchurl { | ||||
|         url = "https://developer.nvidia.com/downloads/compute/cublasdx/redist/cublasdx/libmathdx-${suffix}-${finalAttrs.version}.tar.gz"; | ||||
|         inherit hash; | ||||
|       }; | ||||
| 
 | ||||
|     unpackPhase = '' | ||||
|       runHook preUnpack | ||||
| 
 | ||||
|       mkdir unpacked | ||||
|       cd unpacked | ||||
|       tar -xzf $src | ||||
|       export sourceRoot=$(pwd) | ||||
| 
 | ||||
|       runHook postUnpack | ||||
|     ''; | ||||
|       lib.mapNullable ( | ||||
|         hash: | ||||
|         fetchurl { | ||||
|           inherit hash name; | ||||
|           url = "${baseURL}/${name}.tar.gz"; | ||||
|         } | ||||
|       ) (hashes.${effectiveStdenv.hostPlatform.system} or null); | ||||
| 
 | ||||
|     dontUnpack = true; | ||||
|     dontConfigure = true; | ||||
|     dontBuild = true; | ||||
| 
 | ||||
|     # NOTE: `--strip-components=1` drops the leading path component because the 0.1.2 release tarball keeps its contents inside a `libmathdx` directory. | ||||
|     installPhase = '' | ||||
|       runHook preInstall | ||||
| 
 | ||||
|       cp -rT "$sourceRoot" "$out" | ||||
|       mkdir -p "$out" | ||||
|       tar -xzf "$src" --strip-components=1 -C "$out" | ||||
| 
 | ||||
|       mkdir -p "$static" | ||||
|       moveToOutput "lib/libmathdx_static.a" "$static" | ||||
| 
 | ||||
|       runHook postInstall | ||||
|     ''; | ||||
| 
 | ||||
|     meta = { | ||||
|       description = "library used to integrate cuBLASDx and cuFFTDx into Warp"; | ||||
|       homepage = "https://developer.nvidia.com/cublasdx-downloads"; | ||||
|       sourceProvenance = with lib.sourceTypes; [ binaryNativeCode ]; | ||||
|       license = with lib.licenses; [ | ||||
|         # By downloading and using the software, you agree to fully | ||||
| @ -104,7 +114,10 @@ let | ||||
|         # license: | ||||
|         mit | ||||
|       ]; | ||||
|       platforms = with lib.platforms; linux ++ [ "x86_64-windows" ]; | ||||
|       platforms = [ | ||||
|         "aarch64-linux" | ||||
|         "x86_64-linux" | ||||
|       ]; | ||||
|       maintainers = with lib.maintainers; [ yzx9 ]; | ||||
|     }; | ||||
|   }); | ||||
| @ -114,6 +127,13 @@ buildPythonPackage { | ||||
|   inherit version; | ||||
|   pyproject = true; | ||||
| 
 | ||||
|   # TODO(@connorbaker): Some CUDA setup hook is failing when __structuredAttrs is false, | ||||
|   # causing a bunch of missing math symbols (like expf) when linking against the static library | ||||
|   # provided by NVCC. | ||||
|   __structuredAttrs = true; | ||||
| 
 | ||||
|   stdenv = effectiveStdenv; | ||||
| 
 | ||||
|   src = fetchFromGitHub { | ||||
|     owner = "NVIDIA"; | ||||
|     repo = "warp"; | ||||
| @ -122,7 +142,7 @@ buildPythonPackage { | ||||
|   }; | ||||
| 
 | ||||
|   patches = | ||||
|     lib.optionals stdenv.hostPlatform.isDarwin [ | ||||
|     lib.optionals effectiveStdenv.hostPlatform.isDarwin [ | ||||
|       (replaceVars ./darwin-libcxx.patch { | ||||
|         LIBCXX_DEV = llvmPackages.libcxx.dev; | ||||
|         LIBCXX_LIB = llvmPackages.libcxx; | ||||
| @ -140,22 +160,69 @@ buildPythonPackage { | ||||
|     ]; | ||||
| 
 | ||||
|   postPatch = | ||||
|     lib.optionalString (!stdenv.cc.isGNU) '' | ||||
|       substituteInPlace warp/build_dll.py \ | ||||
|         --replace-fail "g++" "${lib.getExe stdenv.cc}" | ||||
|     # Patch build_dll.py to use our gencode flags rather than NVIDIA's very broad defaults. | ||||
|     # NOTE: After 1.7.2, patching will need to be updated like this: | ||||
|     # https://github.com/ConnorBaker/cuda-packages/blob/2fc8ba8c37acee427a94cdd1def55c2ec701ad82/pkgs/development/python-modules/warp/default.nix#L56-L65 | ||||
|     lib.optionalString cudaSupport '' | ||||
|       nixLog "patching $PWD/warp/build_dll.py to use our gencode flags" | ||||
|       substituteInPlace "$PWD/warp/build_dll.py" \ | ||||
|         --replace-fail \ | ||||
|           'nvcc_opts = gencode_opts + [' \ | ||||
|           'nvcc_opts = [ ${ | ||||
|             lib.concatMapStringsSep ", " (gencodeString: ''"${gencodeString}"'') cudaPackages.flags.gencode | ||||
|           }, ' | ||||
|     '' | ||||
|     # Patch build_dll.py to use dynamic libraries rather than static ones. | ||||
|     # NOTE: We do not patch the `nvptxcompiler_static` path because it is not available as a dynamic library. | ||||
|     + lib.optionalString cudaSupport '' | ||||
|       nixLog "patching $PWD/warp/build_dll.py to use dynamic libraries" | ||||
|       substituteInPlace "$PWD/warp/build_dll.py" \ | ||||
|         --replace-fail \ | ||||
|           '-lcudart_static' \ | ||||
|           '-lcudart' \ | ||||
|         --replace-fail \ | ||||
|           '-lnvrtc_static' \ | ||||
|           '-lnvrtc' \ | ||||
|         --replace-fail \ | ||||
|           '-lnvrtc-builtins_static' \ | ||||
|           '-lnvrtc-builtins' \ | ||||
|         --replace-fail \ | ||||
|           '-lnvJitLink_static' \ | ||||
|           '-lnvJitLink' \ | ||||
|         --replace-fail \ | ||||
|           '-lmathdx_static' \ | ||||
|           '-lmathdx' | ||||
|     '' | ||||
|     + '' | ||||
|       nixLog "patching $PWD/warp/build_dll.py to use our C++ compiler" | ||||
|       substituteInPlace "$PWD/warp/build_dll.py" \ | ||||
|         --replace-fail "g++" "c++" | ||||
|     '' | ||||
|     # Broken tests on aarch64. Since unittest doesn't support disabling a | ||||
|     # single test, and pytest isn't compatible, we patch the test file directly | ||||
|     # instead. | ||||
|     # | ||||
|     # See: https://github.com/NVIDIA/warp/issues/552 | ||||
|     + lib.optionalString stdenv.hostPlatform.isAarch64 '' | ||||
|       substituteInPlace warp/tests/test_fem.py \ | ||||
|         --replace-fail "add_function_test(TestFem, \"test_integrate_gradient\", test_integrate_gradient, devices=devices)" "" | ||||
|     + lib.optionalString effectiveStdenv.hostPlatform.isAarch64 '' | ||||
|       nixLog "patching $PWD/warp/tests/test_fem.py to disable broken tests on aarch64" | ||||
|       substituteInPlace "$PWD/warp/tests/test_fem.py" \ | ||||
|         --replace-fail \ | ||||
|           'add_function_test(TestFem, "test_integrate_gradient", test_integrate_gradient, devices=devices)' \ | ||||
|           "" | ||||
|     '' | ||||
|     # These tests fail on CPU and CUDA. | ||||
|     + '' | ||||
|       nixLog "patching $PWD/warp/tests/test_reload.py to disable broken tests" | ||||
|       substituteInPlace "$PWD/warp/tests/test_reload.py" \ | ||||
|         --replace-fail \ | ||||
|           'add_function_test(TestReload, "test_reload", test_reload, devices=devices)' \ | ||||
|           "" \ | ||||
|         --replace-fail \ | ||||
|           'add_function_test(TestReload, "test_reload_references", test_reload_references, devices=get_test_devices("basic"))' \ | ||||
|           "" | ||||
|     ''; | ||||
| 
 | ||||
|   build-system = [ | ||||
|     build | ||||
|     setuptools | ||||
|   ]; | ||||
| 
 | ||||
| @ -163,11 +230,11 @@ buildPythonPackage { | ||||
|     numpy | ||||
|   ]; | ||||
| 
 | ||||
|   nativeBuildInputs = lib.optionals libmathdxSupport [ | ||||
|     libmathdx | ||||
|     cudaPackages.libcublas | ||||
|     cudaPackages.libcufft | ||||
|     cudaPackages.libnvjitlink | ||||
|   # NOTE: Packages built from source normally do not need autoAddDriverRunpath. Warp, however, loads the GPU | ||||
|   # drivers at runtime, so we need to inject the path to our video drivers. | ||||
|   nativeBuildInputs = lib.optionals cudaSupport [ | ||||
|     autoAddDriverRunpath | ||||
|     cudaPackages.cuda_nvcc | ||||
|   ]; | ||||
| 
 | ||||
|   buildInputs = | ||||
| @ -177,10 +244,18 @@ buildPythonPackage { | ||||
|       llvmPackages.libcxx | ||||
|     ] | ||||
|     ++ lib.optionals cudaSupport [ | ||||
|       cudaPackages.cudatoolkit | ||||
|       (lib.getOutput "static" cudaPackages.cuda_nvcc) # dependency on nvptxcompiler_static; no dynamic version available | ||||
|       cudaPackages.cuda_cccl | ||||
|       cudaPackages.cuda_cudart | ||||
|       cudaPackages.cuda_nvcc | ||||
|       cudaPackages.cuda_nvrtc | ||||
|     ] | ||||
|     ++ lib.optionals libmathdxSupport [ | ||||
|       libmathdx | ||||
|       cudaPackages.libcublas | ||||
|       cudaPackages.libcufft | ||||
|       cudaPackages.libcusolver | ||||
|       cudaPackages.libnvjitlink | ||||
|     ]; | ||||
| 
 | ||||
|   preBuild = | ||||
| @ -190,7 +265,8 @@ buildPythonPackage { | ||||
|           "--no_standalone" | ||||
|         ] | ||||
|         ++ lib.optionals cudaSupport [ | ||||
|           "--cuda_path=${cudaPackages.cudatoolkit}" | ||||
|           # NOTE: The `cuda_path` argument is the directory which contains `bin/nvcc` (i.e., the bin output). | ||||
|           "--cuda_path=${lib.getBin pkgsBuildHost.cudaPackages.cuda_nvcc}" | ||||
|         ] | ||||
|         ++ lib.optionals libmathdxSupport [ | ||||
|           "--libmathdx" | ||||
| @ -203,34 +279,102 @@ buildPythonPackage { | ||||
|       buildOptionString = lib.concatStringsSep " " buildOptions; | ||||
|     in | ||||
|     '' | ||||
|       python build_lib.py ${buildOptionString} | ||||
|       nixLog "running $PWD/build_lib.py to create components necessary to build the wheel" | ||||
|       "${python.pythonOnBuildForHost.interpreter}" "$PWD/build_lib.py" ${buildOptionString} | ||||
|     ''; | ||||
| 
 | ||||
|   pythonImportsCheck = [ | ||||
|     "warp" | ||||
|   ]; | ||||
| 
 | ||||
|   # Many unit tests fail with segfaults on aarch64-linux, especially in the sim | ||||
|   # and grad modules. However, other functionality generally works, so we don't | ||||
|   # mark the package as broken. | ||||
|   # | ||||
|   # See: https://www.github.com/NVIDIA/warp/issues/{356,372,552} | ||||
|   doCheck = !(stdenv.hostPlatform.isAarch64 && stdenv.hostPlatform.isLinux); | ||||
|   # See passthru.tests. | ||||
|   doCheck = false; | ||||
| 
 | ||||
|   nativeCheckInputs = [ | ||||
|     unittestCheckHook | ||||
|     (jax.override { inherit cudaSupport; }) | ||||
|     (torch.override { inherit cudaSupport; }) | ||||
|   passthru = { | ||||
|     # Make libmathdx available for introspection. | ||||
|     inherit libmathdx; | ||||
| 
 | ||||
|     # # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected | ||||
|     #  (paddlepaddle.override { inherit cudaSupport; }) | ||||
|   ]; | ||||
|     # Scripts which provide test packages and implement test logic. | ||||
|     testers.unit-tests = writeShellApplication { | ||||
|       name = "warp-lang-unit-tests"; | ||||
|       runtimeInputs = [ | ||||
|         # Use the packages passed in as arguments to this file (python, warp-lang, jax, torch) | ||||
|         (python.withPackages (_: [ | ||||
|           warp-lang | ||||
|           jax | ||||
|           torch | ||||
|         ])) | ||||
|         # Disable paddlepaddle interop tests: malloc(): unaligned tcache chunk detected | ||||
|         #  (paddlepaddle.override { inherit cudaSupport; }) | ||||
|       ]; | ||||
|       text = '' | ||||
|         python3 -m warp.tests | ||||
|       ''; | ||||
|     }; | ||||
| 
 | ||||
|   preCheck = '' | ||||
|     export WARP_CACHE_PATH=$(mktemp -d) # warp.config.kernel_cache_dir | ||||
|   ''; | ||||
|     # Tests run within the Nix sandbox. | ||||
|     tests = | ||||
|       let | ||||
|         mkUnitTests = | ||||
|           { | ||||
|             cudaSupport, | ||||
|             libmathdxSupport, | ||||
|           }: | ||||
|           let | ||||
|             name = | ||||
|               "warp-lang-unit-tests-cpu" # CPU is baseline | ||||
|               + lib.optionalString cudaSupport "-cuda" | ||||
|               + lib.optionalString libmathdxSupport "-libmathdx"; | ||||
| 
 | ||||
|   passthru.updateScript = nix-update-script { }; | ||||
|             warp-lang' = warp-lang.override { | ||||
|               inherit cudaSupport libmathdxSupport; | ||||
|               # Make sure the warp-lang provided through callPackage is replaced with the override we're making. | ||||
|               warp-lang = warp-lang'; | ||||
|             }; | ||||
|           in | ||||
|           runCommand name | ||||
|             { | ||||
|               nativeBuildInputs = [ | ||||
|                 warp-lang'.passthru.testers.unit-tests | ||||
|                 writableTmpDirAsHomeHook | ||||
|               ]; | ||||
|               requiredSystemFeatures = lib.optionals cudaSupport [ "cuda" ]; | ||||
|               # Many unit tests fail with segfaults on aarch64-linux, especially in the sim | ||||
|               # and grad modules. However, other functionality generally works, so we don't | ||||
|               # mark the package as broken. | ||||
|               # | ||||
|               # See: https://www.github.com/NVIDIA/warp/issues/{356,372,552} | ||||
|               meta.broken = effectiveStdenv.hostPlatform.isAarch64 && effectiveStdenv.hostPlatform.isLinux; | ||||
|             } | ||||
|             '' | ||||
|               nixLog "running ${name}" | ||||
| 
 | ||||
|               if warp-lang-unit-tests; then | ||||
|                 nixLog "${name} passed" | ||||
|                 touch "$out" | ||||
|               else | ||||
|                 nixErrorLog "${name} failed" | ||||
|                 exit 1 | ||||
|               fi | ||||
|             ''; | ||||
|       in | ||||
|       { | ||||
|         cpu = mkUnitTests { | ||||
|           cudaSupport = false; | ||||
|           libmathdxSupport = false; | ||||
|         }; | ||||
|         cuda = { | ||||
|           cudaOnly = mkUnitTests { | ||||
|             cudaSupport = true; | ||||
|             libmathdxSupport = false; | ||||
|           }; | ||||
|           cudaWithLibmathDx = mkUnitTests { | ||||
|             cudaSupport = true; | ||||
|             libmathdxSupport = true; | ||||
|           }; | ||||
|         }; | ||||
|       }; | ||||
|   }; | ||||
| 
 | ||||
|   meta = { | ||||
|     description = "Python framework for high performance GPU simulation and graphics"; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Connor Baker
						Connor Baker