diff options
Diffstat (limited to 'gnu/packages/machine-learning.scm')
-rw-r--r-- | gnu/packages/machine-learning.scm | 113 |
1 files changed, 90 insertions, 23 deletions
diff --git a/gnu/packages/machine-learning.scm b/gnu/packages/machine-learning.scm index fd0be8d500..abba41626d 100644 --- a/gnu/packages/machine-learning.scm +++ b/gnu/packages/machine-learning.scm @@ -2060,6 +2060,31 @@ physics-informed learning. It includes implementations for the PINN MFNN (multifidelity neural network) algorithms.") (license license:lgpl2.1+))) +(define-public python-jaxtyping + (package + (name "python-jaxtyping") + (version "0.2.21") + (source (origin + (method url-fetch) + (uri (pypi-uri "jaxtyping" version)) + (sha256 + (base32 + "19qmsnbn4wv2jl99lpn622qs49mrfxmx8s9pr5y8izzgdjq1fvii")))) + (build-system pyproject-build-system) + ;; Tests require JAX, but JAX can't be packaged because it uses the Bazel + ;; build system. + (arguments (list #:tests? #f)) + (native-inputs (list python-hatchling)) + (propagated-inputs (list python-numpy python-typeguard + python-typing-extensions)) + (home-page "https://github.com/google/jaxtyping") + (synopsis + "Type annotations and runtime checking for JAX arrays and others") + (description "@code{jaxtyping} provides type annotations and runtime +checking for shape and dtype of JAX arrays, PyTorch, NumPy, TensorFlow, and +PyTrees.") + (license license:expat))) + ;; There have been no proper releases yet. (define-public kaldi (let ((commit "be22248e3a166d9ec52c78dac945f471e7c3a8aa") @@ -2972,7 +2997,7 @@ advanced research.") (define-public tensorflow-lite (package (name "tensorflow-lite") - (version "2.12.1") + (version "2.13.0") (source (origin (method git-fetch) @@ -2982,7 +3007,8 @@ advanced research.") (file-name (git-file-name name version)) (sha256 (base32 - "0jkgljdagdqllnxygl35r5bh3f9qmbczymfj357gm9krh59g2kmd")))) + "07g6vlrs0aayrg2mfdl15gxg5dy103wx2xlqkran15dib40nkbj6")) + (patches (search-patches "tensorflow-lite-unbundle.patch")))) (build-system cmake-build-system) (arguments (list @@ -3025,6 +3051,7 @@ advanced research.") "-DTFLITE_ENABLE_XNNPACK=OFF" ;; Don't fetch the sources. We have these already + "-Dgemmlowp_POPULATED=TRUE" "-Degl_headers_POPULATED=TRUE" "-Dfp16_headers_POPULATED=TRUE" "-Dopencl_headers_POPULATED=TRUE" @@ -3037,7 +3064,7 @@ advanced research.") "-DFFT2D_SOURCE_DIR=/tmp/fft2d" "-DFARMHASH_SOURCE_DIR=/tmp/farmhash" - "-Dgemmlowp_SOURCE_DIR=/tmp/gemmlowp") + (string-append "-Dgemmlowp_ROOT=" #$(this-package-input "gemmlowp"))) #:phases #~(modify-phases %standard-phases (add-after 'unpack 'chdir @@ -3067,11 +3094,7 @@ advanced research.") (mkdir-p "/tmp/fft2d") (with-directory-excursion "/tmp/fft2d" (invoke "tar" "--strip-components=1" - "-xf" (assoc-ref inputs "fft2d-src"))) - - (copy-recursively (assoc-ref inputs "gemmlowp-src") - "/tmp/gemmlowp/"))) - + "-xf" (assoc-ref inputs "fft2d-src"))))) (add-after 'build 'build-shared-library (lambda* (#:key configure-flags #:allow-other-keys) (mkdir-p "c") @@ -3101,7 +3124,7 @@ advanced research.") ("eigen" ,eigen) ("fp16" ,fp16) ("flatbuffers-shared" ,flatbuffers-next-shared) - ;;("gemmlowp" ,gemmlowp) ; TODO + ("gemmlowp" ,gemmlowp) ("mesa-headers" ,mesa-headers) ("neon2sse" ,neon2sse) ("nsync" ,nsync) @@ -3117,19 +3140,6 @@ advanced research.") (native-inputs `(("pkg-config" ,pkg-config) ("googletest" ,googletest) - ("gemmlowp-src" - ;; The commit hash is taken from - ;; "tensorflow/lite/tools/cmake/modules/gemmlowp.cmake". - ,(let ((commit "fda83bdc38b118cc6b56753bd540caa49e570745")) - (origin - (method git-fetch) - (uri (git-reference - (url "https://github.com/google/gemmlowp") - (commit commit))) - (file-name (git-file-name "gemmlowp" (string-take commit 8))) - (sha256 - (base32 - "1sbp8kmr2azwlvfbzryy1frxi99jhsh1nc93bdbxdf8zdgpv0kxl"))))) ("farmhash-src" ,(let ((commit "816a4ae622e964763ca0862d9dbd19324a1eaf45")) (origin @@ -3151,7 +3161,7 @@ advanced research.") (sha256 (base32 "1jfflzi74fag9z4qmgwvp90aif4dpbr1657izmxlgvf4hy8fk9xd")))))) - (home-page "https://tensorflow.org") + (home-page "https://www.tensorflow.org") (synopsis "Machine learning framework") (description "TensorFlow is a flexible platform for building and training machine @@ -4546,6 +4556,63 @@ and Numpy.") inference.") (license license:asl2.0))) +(define-public python-linear-operator + (package + (name "python-linear-operator") + (version "0.5.2") + (source (origin + (method url-fetch) + (uri (pypi-uri "linear_operator" version)) + (sha256 + (base32 + "03drb4hn9nn8jrqd9vbalihhahgpdm956hbs05bix7svradhknaw")))) + (build-system pyproject-build-system) + (propagated-inputs (list python-jaxtyping + python-pytorch + python-scipy + python-typeguard)) + (native-inputs (list python-flake8 + python-flake8-print + python-pytest + python-setuptools-scm + python-twine)) + (home-page "https://github.com/cornellius-gp/linear_operator/") + (synopsis "Linear operator implementation") + (description "LinearOperator is a PyTorch package for abstracting away the +linear algebra routines needed for structured matrices (or operators).") + (license license:expat))) + +(define-public python-gpytorch + (package + (name "python-gpytorch") + (version "1.11") + (source (origin + (method url-fetch) + (uri (pypi-uri "gpytorch" version)) + (sha256 + (base32 + "0q17bml53vixk3cwj3p893809927hz81fprwsmxpxqv5i4mvgyvj")))) + (build-system pyproject-build-system) + (arguments + (list #:test-flags + ;; The error message in test_t_matmul_matrix suggests the error may + ;; be due to a bug in gpytorch. test_deprecated_methods fails with + ;; an AssertionError. + #~(list "-k" (string-append "not test_deprecated_methods" + " and not test_t_matmul_matrix")))) + (propagated-inputs (list python-linear-operator python-scikit-learn)) + (native-inputs (list python-coverage + python-flake8 + python-flake8-print + python-nbval + python-pytest + python-twine)) + (home-page "https://gpytorch.ai") + (synopsis "Implementation of Gaussian Processes in PyTorch") + (description + "GPyTorch is a Gaussian process library implemented using PyTorch.") + (license license:expat))) + (define-public vosk-api (let* ((openfst openfst-for-vosk) (kaldi kaldi-for-vosk)) |