diff --git a/cuda_core/cuda/core/utils.py b/cuda_core/cuda/core/utils.py deleted file mode 100644 index f15d924277..0000000000 --- a/cuda_core/cuda/core/utils.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from cuda.core._memoryview import ( - StridedMemoryView, # noqa: F401 - args_viewable_as_strided_memory, # noqa: F401 -) diff --git a/cuda_core/cuda/core/utils/__init__.py b/cuda_core/cuda/core/utils/__init__.py new file mode 100644 index 0000000000..69b6fdb67f --- /dev/null +++ b/cuda_core/cuda/core/utils/__init__.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.core._memoryview import ( + StridedMemoryView, + args_viewable_as_strided_memory, +) + +__all__ = [ + "FileStreamProgramCache", + "ProgramCacheResource", + "SQLiteProgramCache", + "StridedMemoryView", + "args_viewable_as_strided_memory", + "make_program_cache_key", +] + +# Lazily expose the program-cache APIs so ``from cuda.core.utils import +# StridedMemoryView`` stays lightweight -- the cache backends pull in driver, +# NVRTC, and module-load machinery that memoryview-only consumers do not need. 
+_LAZY_CACHE_ATTRS = frozenset( + { + "FileStreamProgramCache", + "ProgramCacheResource", + "SQLiteProgramCache", + "make_program_cache_key", + } +) + + +def __getattr__(name): + if name in _LAZY_CACHE_ATTRS: + from cuda.core.utils import _program_cache + + value = getattr(_program_cache, name) + globals()[name] = value # cache for subsequent accesses + return value + raise AttributeError(f"module 'cuda.core.utils' has no attribute {name!r}") + + +def __dir__(): + return sorted(__all__) diff --git a/cuda_core/cuda/core/utils/_program_cache.py b/cuda_core/cuda/core/utils/_program_cache.py new file mode 100644 index 0000000000..68cba1545d --- /dev/null +++ b/cuda_core/cuda/core/utils/_program_cache.py @@ -0,0 +1,1482 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Persistent program caches for cuda.core. + +Two concrete backends are provided: + +* :class:`SQLiteProgramCache` -- a single-file sqlite3 database, best for + single-process workflows, with LRU eviction and a hard size cap. +* :class:`FileStreamProgramCache` -- a directory of atomically-written entry + files, safe across concurrent processes via :func:`os.replace`. + +Both implement :class:`ProgramCacheResource`, so callers can swap backends +without changing the calling code. 
+""" + +from __future__ import annotations + +import abc +import collections.abc +import contextlib +import errno +import hashlib +import os +import pickle +import tempfile +import threading +import time +from pathlib import Path +from typing import Iterable, Sequence + +from cuda.core._module import ObjectCode +from cuda.core._program import ProgramOptions +from cuda.core._utils.cuda_utils import ( + driver as _driver, +) +from cuda.core._utils.cuda_utils import ( + handle_return as _handle_return, +) +from cuda.core._utils.cuda_utils import ( + nvrtc as _nvrtc, +) + +__all__ = [ + "FileStreamProgramCache", + "ProgramCacheResource", + "SQLiteProgramCache", + "make_program_cache_key", +] + + +_PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL + +# Exposed as a module-level flag so tests can toggle it without monkeypatching +# ``os.name`` itself (pathlib reads ``os.name`` at instantiation time). +_IS_WINDOWS = os.name == "nt" + + +def _require_object_code(value: object) -> ObjectCode: + if not isinstance(value, ObjectCode): + raise TypeError(f"cache values must be ObjectCode instances, got {type(value).__name__}") + # Reject path-backed ObjectCode: the ``code`` property exposes the + # underlying module, which is the raw ``str`` path for + # ``ObjectCode.from_cubin("/path")`` et al. Pickling stores only the path, + # so later reads would return stale (or missing) on-disk content. The + # caller must read the file into bytes first. + if isinstance(value.code, str): + raise TypeError( + "cache values must be bytes-backed ObjectCode instances; " + "path-backed ObjectCode (e.g. 
ObjectCode.from_cubin('/path')) " + "cannot be cached safely -- read the file into bytes first" + ) + return value + + +def _as_key_bytes(key: object) -> bytes: + if isinstance(key, (bytes, bytearray)): + return bytes(key) + if isinstance(key, str): + return key.encode("utf-8") + raise TypeError(f"cache keys must be bytes or str, got {type(key).__name__}") + + +# --------------------------------------------------------------------------- +# Abstract base class +# --------------------------------------------------------------------------- + + +class ProgramCacheResource(abc.ABC): + """Abstract base class for compiled-program caches. + + Concrete implementations store and retrieve :class:`~cuda.core.ObjectCode` + instances keyed by ``bytes`` or ``str``. A ``str`` key is encoded as UTF-8 + before being used, so ``"k"`` and ``b"k"`` refer to the same entry. A + typical key is produced by :func:`make_program_cache_key`, which returns + ``bytes``. + + .. warning:: + + Persistent backends use :mod:`pickle` for serialization. Only read + cache files that you trust — loading a cache from an untrusted source + can execute arbitrary code. Cache directories should have the same + access controls as any other sensitive build artifact. + + .. note:: + + Cache only bytes-backed ``ObjectCode`` instances. Path-backed objects + (created via ``ObjectCode.from_cubin("/path")``) store the path, not + the content — if the file is later modified or deleted, cache hits + will return stale or broken data. Normalize to bytes before caching. + """ + + @abc.abstractmethod + def __getitem__(self, key: bytes | str) -> ObjectCode: + """Retrieve the cached :class:`ObjectCode`. + + Raises + ------ + KeyError + If ``key`` is not in the cache. 
+ """ + + @abc.abstractmethod + def __setitem__(self, key: bytes | str, value: ObjectCode) -> None: + """Store ``value`` under ``key``.""" + + @abc.abstractmethod + def __contains__(self, key: bytes | str) -> bool: + """Return ``True`` if ``key`` is in the cache.""" + + @abc.abstractmethod + def __delitem__(self, key: bytes | str) -> None: + """Remove the entry associated with ``key``. + + Raises + ------ + KeyError + If ``key`` is not in the cache. + """ + + @abc.abstractmethod + def __len__(self) -> int: + """Return the number of entries currently in the cache.""" + + @abc.abstractmethod + def clear(self) -> None: + """Remove every entry from the cache.""" + + def get(self, key: bytes | str, default: ObjectCode | None = None) -> ObjectCode | None: + """Return ``self[key]`` or ``default`` if absent.""" + try: + return self[key] + except KeyError: + return default + + def close(self) -> None: # noqa: B027 + """Release backend resources. No-op by default.""" + + def __enter__(self) -> ProgramCacheResource: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + + +# --------------------------------------------------------------------------- +# Key construction +# --------------------------------------------------------------------------- + + +# Bump when the key schema changes in a way that invalidates existing caches. +_KEY_SCHEMA_VERSION = 1 + +_VALID_CODE_TYPES = frozenset({"c++", "ptx", "nvvm"}) +_VALID_TARGET_TYPES = frozenset({"ptx", "cubin", "ltoir"}) + +# code_type -> allowed target_type set, mirroring Program.compile's +# SUPPORTED_TARGETS matrix in _program.pyx. +_SUPPORTED_TARGETS_BY_CODE_TYPE = { + "c++": frozenset({"ptx", "cubin", "ltoir"}), + "ptx": frozenset({"cubin", "ptx"}), + "nvvm": frozenset({"ptx", "ltoir"}), +} + + +def _backend_for_code_type(code_type: str) -> str: + if code_type == "nvvm": + return "nvvm" + if code_type == "ptx": + # Program routes PTX through Linker, not NVRTC. 
+ return "linker" + return "nvrtc" + + +# ProgramOptions fields that reach the Linker via _translate_program_options +# (see cuda_core/cuda/core/_program.pyx). All other fields on ProgramOptions +# are NVRTC-only and must NOT perturb a PTX cache key: a PTX compile with a +# shared ProgramOptions that happens to set include_path/pch/frandom_seed +# would otherwise miss the cache unnecessarily. +_LINKER_RELEVANT_FIELDS = ( + "name", + "arch", + "max_register_count", + "time", + "link_time_optimization", + "debug", + "lineinfo", + "ftz", + "prec_div", + "prec_sqrt", + "fma", + "split_compile", + "ptxas_options", + "no_cache", +) + + +# Map each linker-relevant ProgramOptions field to the gate the Linker uses +# to turn it into a flag (see ``_prepare_nvjitlink_options`` and +# ``_prepare_driver_options`` in _linker.pyx). Collapsing inputs through +# these gates means semantically-equivalent configurations +# (``debug=False`` vs ``None``, ``time=True`` vs ``time="path"``) hash to +# the same cache key instead of forcing spurious misses. +def _gate_presence(v): + return v is not None + + +def _gate_truthy(v): + return bool(v) + + +def _gate_is_true(v): + return v is True + + +def _gate_tristate_bool(v): + return None if v is None else bool(v) + + +def _gate_identity(v): + return v + + +def _gate_ptxas_options(v): + # ``_prepare_nvjitlink_options`` emits one ``-Xptxas=`` per element, and + # treats ``str`` as a single-element sequence. Canonicalize to a tuple so + # ``"-v"`` / ``["-v"]`` / ``("-v",)`` all hash the same. An empty sequence + # emits no flags, so collapse it to ``None`` too. 
+ if v is None: + return None + if isinstance(v, str): + return ("-Xptxas=" + v,) + if isinstance(v, collections.abc.Sequence): + if len(v) == 0: + return None + return tuple(f"-Xptxas={s}" for s in v) + return v + + +_LINKER_FIELD_GATES = { + "name": _gate_identity, + "arch": _gate_identity, + "max_register_count": _gate_identity, + "time": _gate_presence, # linker emits ``-time`` iff value is not None + "link_time_optimization": _gate_truthy, + "debug": _gate_truthy, + "lineinfo": _gate_truthy, + "ftz": _gate_tristate_bool, + "prec_div": _gate_tristate_bool, + "prec_sqrt": _gate_tristate_bool, + "fma": _gate_tristate_bool, + "split_compile": _gate_identity, + "ptxas_options": _gate_ptxas_options, + "no_cache": _gate_is_true, +} + + +# LinkerOptions fields the ``cuLink`` driver backend silently ignores +# (emits only a DeprecationWarning; no actual flag reaches the compiler). +# When the driver backend is active, collapse them to a single sentinel in +# the fingerprint so nvJitLink<->driver parity of ``ObjectCode`` doesn't +# cause cache misses from otherwise-equivalent configurations. +_DRIVER_IGNORED_LINKER_FIELDS = frozenset({"ftz", "prec_div", "prec_sqrt", "fma"}) + + +def _linker_option_fingerprint(options: ProgramOptions, *, use_driver_linker: bool | None) -> list[bytes]: + """Backend-aware fingerprint of ProgramOptions fields consumed by the Linker. + + Each field passes through the gate the Linker itself uses so equivalent + inputs (e.g. ``debug=False`` / ``None``) hash to the same bytes. When + the driver (cuLink) linker backend is in use, fields it silently + ignores collapse to one sentinel so those options don't perturb the + key on driver-backed hosts either. ``use_driver_linker=None`` means we + couldn't probe the backend; we don't collapse driver-ignored fields in + that case, to stay conservative. 
+ """ + parts = [] + driver_ignored = use_driver_linker is True + for name in _LINKER_RELEVANT_FIELDS: + if driver_ignored and name in _DRIVER_IGNORED_LINKER_FIELDS: + parts.append(f"{name}=".encode()) + continue + gated = _LINKER_FIELD_GATES[name](getattr(options, name, None)) + parts.append(f"{name}={gated!r}".encode()) + return parts + + +# ProgramOptions fields that map to LinkerOptions fields the cuLink (driver) +# backend rejects outright (see _prepare_driver_options in _linker.pyx). +# ``split_compile_extended`` exists on LinkerOptions but is not exposed via +# ProgramOptions / _translate_program_options, so it cannot reach the driver +# linker from the cache path and is omitted here. +_DRIVER_LINKER_UNSUPPORTED_FIELDS = ("time", "ptxas_options", "split_compile") + + +def _driver_version() -> int: + return int(_handle_return(_driver.cuDriverGetVersion())) + + +def _nvrtc_version() -> tuple[int, int]: + major, minor = _handle_return(_nvrtc.nvrtcVersion()) + return int(major), int(minor) + + +def _linker_backend_and_version() -> tuple[str, str]: + """Return ``(backend, version)`` for the linker used on PTX inputs. + + Raises any underlying probe exception. ``make_program_cache_key`` catches + and mixes the exception's class name into the digest, so the same probe + failure produces the same key across processes -- the cache stays + persistent in broken environments, while never sharing a key with a + working probe (``_probe_failed`` label vs. ``driver``/``nvrtc``/...). + + nvJitLink version lookup goes through ``sys.modules`` first so we hit the + same module ``_decide_nvjitlink_or_driver()`` already loaded. That keeps + fingerprinting aligned with whichever ``cuda.bindings.nvjitlink`` import + path the linker actually uses. 
+ """ + import sys + + from cuda.core._linker import _decide_nvjitlink_or_driver + + use_driver = _decide_nvjitlink_or_driver() + if use_driver: + return ("driver", str(_driver_version())) + nvjitlink = sys.modules.get("cuda.bindings.nvjitlink") + if nvjitlink is None: + from cuda.bindings import nvjitlink + + return ("nvJitLink", str(nvjitlink.version())) + + +def _nvvm_fingerprint() -> str: + """Stable identifier for the loaded NVVM toolchain. + + Combines the libNVVM library version (``module.version()``) with the IR + version reported by ``module.ir_version()``. The library version is the + primary invalidation lever: a libNVVM patch upgrade can change codegen + while keeping the same IR major/minor, so keying only on the IR pair + would silently reuse stale entries. Paired with cuda-core, the IR pair + adds defence in depth without making the key any less stable. + + Both calls go through ``_get_nvvm_module()`` so this fingerprint follows + the same availability / cuda-bindings-version gate that real NVVM + compilation does -- if NVVM is unusable at compile time, the probe + fails the same way and ``_probe`` mixes the failure label into the key. + """ + from cuda.core._program import _get_nvvm_module + + module = _get_nvvm_module() + lib_major, lib_minor = module.version() + major, minor, debug_major, debug_minor = module.ir_version() + return f"lib={lib_major}.{lib_minor};ir={major}.{minor}.{debug_major}.{debug_minor}" + + +# ProgramOptions fields that reference external files whose *contents* the +# cache key cannot observe without reading the filesystem. Callers that set +# any of these must supply an ``extra_digest`` covering the dependency surface +# (e.g. a hash over all reachable headers / PCH bytes). 
+_EXTERNAL_CONTENT_OPTIONS = ( + "include_path", + "pre_include", + "pch", + "use_pch", + "pch_dir", +) + +# ProgramOptions fields whose compilation effect is not captured in the +# returned ``ObjectCode`` -- they produce a filesystem artifact as a side +# effect. A cache hit skips compilation, so that artifact would never be +# written. Reject these outright: the persistent cache is for pure ObjectCode +# reuse, not for replaying compile-time side effects. +# * create_pch -- writes a PCH file (NVRTC). +# * time -- writes NVRTC timing info to a file. +# * fdevice_time_trace -- writes a device-compilation time trace file (NVRTC). +# These are all NVRTC-specific; the Linker's ``-time`` logs to the info log +# (not a file) and NVVM explicitly rejects all three at compile time. The +# side-effect guard is therefore gated on ``backend == "nvrtc"`` below. +_SIDE_EFFECT_OPTIONS = ("create_pch", "time", "fdevice_time_trace") + + +# ProgramOptions fields gated by plain truthiness in ``_program.pyx`` (the +# compiler writes the flag only when the value is truthy). +_BOOLEAN_OPTION_FIELDS = frozenset({"pch"}) + +# Fields whose compiler emission requires ``isinstance(value, str)`` or a +# non-empty sequence; anything else (``False``, ``int``, ``None``, ``[]``) +# is silently ignored at compile time. +_STR_OR_SEQUENCE_OPTION_FIELDS = frozenset({"include_path", "pre_include"}) + + +def _option_is_set(options: ProgramOptions, name: str) -> bool: + """Match how ``_program.pyx`` gates option emission, per field shape. + + - Boolean flags (``pch``): truthy only. + - str-or-sequence fields (``include_path``, ``pre_include``): ``str`` + (including empty) or a non-empty ``collections.abc.Sequence`` (list, + tuple, range, user subclass, ...); everything else (``False``, ``int``, + empty sequence, ``None``) is ignored by the compiler and must not + trigger a cache-time guard. 
+ - Path/string-shaped fields (``create_pch``, ``time``, + ``fdevice_time_trace``, ``use_pch``, ``pch_dir``): ``is not None`` -- + the compiler emits ``--flag=`` for any non-None value, so + ``False`` / ``""`` / ``0`` must still count as set. + """ + value = getattr(options, name, None) + if value is None: + return False + if name in _BOOLEAN_OPTION_FIELDS: + return bool(value) + if name in _STR_OR_SEQUENCE_OPTION_FIELDS: + # Mirror ``_prepare_nvrtc_options_impl``: it checks ``isinstance(v, str)`` + # first, then ``is_sequence(v)`` (which is ``isinstance(v, Sequence)``). + # We therefore accept any ``collections.abc.Sequence`` (range, deque, + # user subclass, etc.), not just list/tuple. + if isinstance(value, str): + return True + if isinstance(value, collections.abc.Sequence): + return len(value) > 0 + return False + return True + + +def make_program_cache_key( + *, + code: str | bytes, + code_type: str, + options: ProgramOptions, + target_type: str, + name_expressions: Sequence[str | bytes | bytearray] = (), + extra_digest: bytes | None = None, +) -> bytes: + """Build a stable cache key from compile inputs. + + Parameters + ---------- + code: + Source text. ``str`` is encoded as UTF-8. + code_type: + One of ``"c++"``, ``"ptx"``, ``"nvvm"``. + options: + A :class:`cuda.core.ProgramOptions`. Its ``arch`` must be set (the + default ``ProgramOptions.__post_init__`` populates it from the current + device). + target_type: + One of ``"ptx"``, ``"cubin"``, ``"ltoir"``. + name_expressions: + Optional iterable of mangled-name lookups. Order is not significant. + Elements may be ``str``, ``bytes``, or ``bytearray``; ``"foo"`` and + ``b"foo"`` produce distinct keys because ``Program.compile`` records + the original Python object as the ``ObjectCode.symbol_mapping`` key, + and ``get_kernel`` lookups must use the same type the cache key + recorded. + extra_digest: + Caller-supplied bytes mixed into the key. 
Required whenever + :class:`cuda.core.ProgramOptions` sets any option that pulls in + external file content (``include_path``, ``pre_include``, ``pch``, + ``use_pch``, ``pch_dir``) -- the cache cannot read + those files on the caller's behalf, so the caller must fingerprint + the header / PCH surface and pass it here. Callers may pass this for + other inputs too (embedded kernels, generated sources, etc.). + + Returns + ------- + bytes + A 32-byte blake2b digest suitable for use as a cache key. + + Raises + ------ + ValueError + If ``options`` sets an option with compile-time side effects (such as + ``create_pch``) -- a cache hit skips compilation, so the side effect + would not occur. + ValueError + If ``extra_digest`` is ``None`` while ``options`` sets any option whose + compilation effect depends on external file content that the key + cannot otherwise observe. + """ + # Mirror Program.compile (_program.pyx Program_init lowercases code_type + # before dispatch); a caller that passes "PTX" or "C++" must get the + # same routing and the same cache key as the lowercase form. + code_type = code_type.lower() if isinstance(code_type, str) else code_type + if code_type not in _VALID_CODE_TYPES: + raise ValueError(f"code_type={code_type!r} is not supported (must be one of {sorted(_VALID_CODE_TYPES)})") + if target_type not in _VALID_TARGET_TYPES: + raise ValueError(f"target_type={target_type!r} is not supported (must be one of {sorted(_VALID_TARGET_TYPES)})") + supported_for_code = _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type] + if target_type not in supported_for_code: + raise ValueError( + f"target_type={target_type!r} is not valid for code_type={code_type!r}" + f" (supported: {sorted(supported_for_code)}). Program.compile() rejects" + f" this combination, so caching a key for it is meaningless." 
+ ) + + backend = _backend_for_code_type(code_type) + + # Side-effect options are NVRTC-specific: ``time``/``fdevice_time_trace`` + # write artifacts via NVRTC, ``create_pch`` writes via NVRTC. The linker + # (PTX inputs) uses ``-time`` only to log to the info log (not a file), + # and NVVM explicitly rejects all three at compile time anyway, so the + # guard is only meaningful for the NVRTC path. + if backend == "nvrtc": + side_effects = [name for name in _SIDE_EFFECT_OPTIONS if _option_is_set(options, name)] + if side_effects: + raise ValueError( + f"make_program_cache_key() refuses to build a key for options that " + f"have compile-time side effects ({', '.join(side_effects)}); a " + f"cache hit skips compilation, so the side effect would not occur. " + f"Disable the option, or compile directly without the cache." + ) + + # NVVM with ``use_libdevice=True`` reads external libdevice bitcode at + # compile time (see Program_init in _program.pyx). The file is resolved + # from the active toolkit, so a changed CUDA_HOME / libdevice upgrade + # changes the linked output without touching any key input the cache can + # observe. Require the caller to supply an ``extra_digest`` that + # fingerprints the libdevice bytes (or simply disable use_libdevice for + # caching-sensitive workflows). + if backend == "nvvm" and extra_digest is None and getattr(options, "use_libdevice", None): + raise ValueError( + "make_program_cache_key() refuses to build an NVVM key with " + "use_libdevice=True and no extra_digest: the linked libdevice " + "bitcode can change out from under a cached ObjectCode. Pass an " + "extra_digest that fingerprints the libdevice file you intend " + "to link against, or disable use_libdevice." + ) + + # External-content options are NVRTC-only. ``Program.compile`` for PTX + # inputs runs ``_translate_program_options``, which drops + # ``include_path``/``pre_include``/``pch``/``use_pch``/``pch_dir`` + # entirely, and NVVM explicitly rejects them. 
Only NVRTC actually reads + # those external files, so gate the guard on the NVRTC backend. + if backend == "nvrtc" and extra_digest is None: + external = [name for name in _EXTERNAL_CONTENT_OPTIONS if _option_is_set(options, name)] + if external: + raise ValueError( + f"make_program_cache_key() refuses to build a key for options that " + f"pull in external file content ({', '.join(external)}) without an " + f"extra_digest; compute a digest over the header/PCH bytes the " + f"compile will read and pass it as extra_digest=..." + ) + + # PTX compiles go through Linker. When the driver (cuLink) backend is + # selected (nvJitLink unavailable), Program.compile rejects a subset of + # options that nvJitLink would accept; reject them here too so we never + # store a key for a compilation that can't succeed in this environment. + # If the probe fails we can't tell which backend will run, so skip -- the + # failed-probe branch below already taints the key. + use_driver_linker: bool | None = None + if backend == "linker": + try: + from cuda.core._linker import _decide_nvjitlink_or_driver + + use_driver_linker = _decide_nvjitlink_or_driver() + except Exception: + use_driver_linker = None + if use_driver_linker is True: + # Mirror ``_prepare_driver_options``'s exact gate: the driver + # linker checks ``is not None`` for these fields, so ``time=False`` + # or ``ptxas_options=[]`` is still a rejection. Do NOT use the + # truthiness-based ``_option_is_set`` helper here. + unsupported = [ + name for name in _DRIVER_LINKER_UNSUPPORTED_FIELDS if getattr(options, name, None) is not None + ] + if unsupported: + raise ValueError( + f"the cuLink driver linker does not support these options: " + f"{', '.join(unsupported)}; Program.compile() would reject this " + f"configuration before producing an ObjectCode." 
+ ) + + if isinstance(code, str): + code_bytes = code.encode("utf-8") + elif isinstance(code, (bytes, bytearray)): + # Program() only accepts bytes-like ``code`` for the NVVM backend + # (_program.pyx Program_init); c++/ptx require ``str``. Mirror that + # so the cache helper doesn't mint keys for inputs the real compile + # would reject. + if backend != "nvvm": + raise TypeError( + f"code must be str for code_type={code_type!r}; bytes/bytearray are only accepted for code_type='nvvm'." + ) + code_bytes = bytes(code) + else: + raise TypeError(f"code must be str or bytes, got {type(code).__name__}") + + # For PTX inputs the Linker path reads only a subset of ProgramOptions + # (see _translate_program_options in _program.pyx); fingerprint just those + # fields so shared ProgramOptions carrying NVRTC-only flags + # (include_path, pch_*, frandom_seed, ...) don't force spurious cache + # misses on PTX. For nvrtc/nvvm backends, ProgramOptions.as_bytes gives + # the real compile-time flag surface. + if backend == "linker": + option_bytes = _linker_option_fingerprint(options, use_driver_linker=use_driver_linker) + else: + try: + option_bytes = options.as_bytes(backend, target_type) + except TypeError: + option_bytes = options.as_bytes(backend) + + # Preserve the original type of each name expression in the key: though + # ``name_expressions`` is only consumed (and only meaningful) on the + # NVRTC compile path; Program.compile silently ignores it for PTX/NVVM. + # Validation + tagging is therefore gated on the NVRTC backend so the + # cache helper doesn't reject inputs the real compile would accept. + # NVRTC tagging notes: ``"foo"`` and ``b"foo"`` get distinct tags + # because Program.compile records the original Python object as the + # ObjectCode.symbol_mapping key (_program.pyx:759), so a cached + # ObjectCode whose mapping-key type differs from what the caller's + # later ``get_kernel`` passes would silently miss. 
+ if backend == "nvrtc": + + def _tag_name(n): + if isinstance(n, (bytes, bytearray)): + return b"b:" + bytes(n) + if isinstance(n, str): + return b"s:" + n.encode("utf-8") + raise TypeError(f"name_expressions elements must be str, bytes, or bytearray; got {type(n).__name__}") + + names = tuple(sorted(_tag_name(n) for n in name_expressions)) + else: + names = () + + hasher = hashlib.blake2b(digest_size=32) + + def _update(label: str, payload: bytes) -> None: + hasher.update(label.encode("ascii")) + hasher.update(len(payload).to_bytes(8, "big")) + hasher.update(payload) + + def _probe(label: str, fn): + """Run an environment probe; on failure, hash the exception's + CLASS NAME (not its message) under a ``*_probe_failed`` label. + + Using only the class name keeps the digest stable across repeated + calls within one process (e.g. NVVM's loader reports different + messages on first vs. cached-failure attempts) AND across processes + that hit the same failure mode. The ``_probe_failed`` label differs + from the success labels (``driver``/``nvrtc``/...), so a broken env + never collides with a working one -- the cache "fails closed" + between broken and working environments while staying persistent + within either. + """ + try: + return fn() + except Exception as exc: + _update(f"{label}_probe_failed", type(exc).__name__.encode()) + return None + + _update("schema", str(_KEY_SCHEMA_VERSION).encode("ascii")) + if backend == "nvrtc": + nvrtc_ver = _probe("nvrtc", _nvrtc_version) + if nvrtc_ver is not None: + nv_major, nv_minor = nvrtc_ver + _update("nvrtc", f"{nv_major}.{nv_minor}".encode("ascii")) + elif backend == "linker": + # Only cuLink (driver-backed linker) goes through the CUDA driver + # for codegen. nvJitLink is a separate library, so a driver upgrade + # under it does not change the compiled bytes -- skip the driver + # version there. 
``_linker_backend_and_version`` already returns the + # driver version when the driver backend is active, so the bytes + # are still in the digest via ``linker_version``. + linker = _probe("linker", _linker_backend_and_version) + if linker is not None: + lb_name, lb_version = linker + _update("linker_backend", lb_name.encode("ascii")) + _update("linker_version", lb_version.encode("ascii")) + else: + nvvm_fp = _probe("nvvm", _nvvm_fingerprint) + if nvvm_fp is not None: + _update("nvvm", nvvm_fp.encode("ascii")) + _update("code_type", code_type.encode("ascii")) + _update("target_type", target_type.encode("ascii")) + _update("code", code_bytes) + _update("option_count", str(len(option_bytes)).encode("ascii")) + for opt in option_bytes: + _update("option", bytes(opt)) + # Only NVRTC consumes ``name_expressions``; Program.compile ignores them + # on the NVVM and PTX/linker paths, so folding them into the key there + # would force spurious cache misses. + if backend == "nvrtc": + _update("names_count", str(len(names)).encode("ascii")) + for n in names: + _update("name", n) + + # ``extra_sources`` is NVVM-only -- ``Program`` raises for non-NVVM + # backends (_program.pyx). Reject up front so callers get the same + # error from the cache key path as from a real compile, and only hash + # for backend == "nvvm". + extra_sources = getattr(options, "extra_sources", None) + if extra_sources is not None and backend != "nvvm": + raise ValueError( + f"extra_sources is only valid for code_type='nvvm'; Program() rejects it for code_type={code_type!r}." + ) + if extra_sources: + _update("extra_sources_count", str(len(extra_sources)).encode("ascii")) + for item in extra_sources: + # extra_sources is a sequence of (name, source) tuples. 
+ if isinstance(item, (tuple, list)) and len(item) == 2: + name, src = item + _update("extra_source_name", str(name).encode("utf-8")) + if isinstance(src, str): + _update("extra_source_code", src.encode("utf-8")) + elif isinstance(src, (bytes, bytearray)): + _update("extra_source_code", bytes(src)) + else: + _update("extra_source_code", str(src).encode("utf-8")) + else: + # Fallback for unexpected format. + _update("extra_source", str(item).encode("utf-8")) + # ``use_libdevice`` is only consumed on the NVVM compile path + # (_program.pyx Program_init); NVRTC and PTX/linker ignore it, so + # folding it into the key there would force spurious misses. On NVVM, + # Program_init gates it on truthiness -- False and None match. + if backend == "nvvm" and getattr(options, "use_libdevice", None): + _update("use_libdevice", b"1") + + # Program.compile() propagates options.name onto the returned ObjectCode, + # so two compiles identical in everything but name produce ObjectCodes + # that differ in their public ``name`` attribute. The key must reflect + # that or a cache hit could hand back an entry with the wrong name. + options_name = getattr(options, "name", None) + if options_name is not None: + _update("options_name", str(options_name).encode("utf-8")) + + if extra_digest is not None: + _update("extra_digest", bytes(extra_digest)) + + return hasher.digest() + + +# --------------------------------------------------------------------------- +# SQLite backend +# --------------------------------------------------------------------------- + + +_SQLITE_SCHEMA_VERSION = "1" + + +class SQLiteProgramCache(ProgramCacheResource): + """Persistent program cache backed by a single sqlite3 database file. + + Suitable for single-process workflows. Multiple processes *can* share the + file (sqlite3 WAL mode serialises writes), but + :class:`FileStreamProgramCache` is the recommended choice for concurrent + workers. 
+ + Parameters + ---------- + path: + Filesystem path to the sqlite3 database. The parent directory is + created if missing. + max_size_bytes: + Optional cap on the sum of stored payload sizes. When that total + exceeds the cap, the least-recently-used entries are evicted until + the logical total is at or below the cap; ``None`` means unbounded. + Real on-disk usage tracks the logical total *at quiescent points*: + WAL frames and freed pages are reclaimed opportunistically via + ``wal_checkpoint(TRUNCATE)`` + ``VACUUM`` after each eviction, but + ``sqlite3`` skips both under active readers or writers. With + concurrent access, the on-disk file can grow above the cap until + readers release; :class:`FileStreamProgramCache` is the right + backend for multi-process workloads with strict on-disk bounds. + """ + + def __init__( + self, + path: str | os.PathLike, + *, + max_size_bytes: int | None = None, + ) -> None: + if max_size_bytes is not None and max_size_bytes < 0: + raise ValueError("max_size_bytes must be non-negative or None") + self._path = Path(path) + self._path.parent.mkdir(parents=True, exist_ok=True) + self._max_size_bytes = max_size_bytes + import sqlite3 + + self._sqlite3 = sqlite3 + self._conn: sqlite3.Connection | None = None + # ``check_same_thread=False`` lets multiple threads reuse one cache + # object; this RLock serialises every connection-touching method so + # threads can't interleave a read/update or a write/VACUUM pair. + # Reentrant because ``clear`` calls ``_compact`` under the same lock. + self._lock = threading.RLock() + self._open() + + # -- lifecycle ----------------------------------------------------------- + + def _open(self) -> None: + # Opening a cache is "a cache is a cache" territory: a damaged file + # (non-SQLite bytes, truncated DB, unreadable header) must degrade + # to an empty cache rather than breaking the caller. On corruption- + # shaped errors we nuke the DB + its WAL/SHM companions and retry. 
+ # OperationalError is excluded because it covers transient runtime + # states like "database is locked" / "busy" that are normal under + # multi-process sharing -- nuking those would destroy a healthy + # cache another process is mid-write into. Those propagate. + try: + self._connect_and_init() + except self._sqlite3.OperationalError: + # Lock/busy/transient: don't nuke the file, but also don't leak + # a partially-initialised connection if connect() succeeded and + # a later PRAGMA tripped the error. + if self._conn is not None: + with contextlib.suppress(Exception): + self._conn.close() + self._conn = None + raise + except self._sqlite3.DatabaseError: + if self._conn is not None: + with contextlib.suppress(Exception): + self._conn.close() + self._conn = None + with contextlib.suppress(FileNotFoundError): + self._path.unlink() + for suffix in ("-wal", "-shm"): + with contextlib.suppress(FileNotFoundError): + self._path.with_name(self._path.name + suffix).unlink() + self._connect_and_init() + + def _connect_and_init(self) -> None: + # ``isolation_level=None`` puts the connection in autocommit mode so + # each statement is its own transaction; ``check_same_thread=False`` + # lets a cache be created in one thread and used from another (writes + # are still serialised by sqlite's own lock). + self._conn = self._sqlite3.connect( + self._path, + isolation_level=None, + check_same_thread=False, + timeout=5.0, + ) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA foreign_keys=ON") + self._conn.execute("PRAGMA busy_timeout=5000") + + # Detect an existing schema version *before* creating any tables so + # a structural migration (added columns / renamed indexes) can be + # handled by dropping the old tables outright. 
+ existing_version = None + try: + row = self._conn.execute( + "SELECT value FROM schema_meta WHERE key = ?", + ("schema_version",), + ).fetchone() + if row is not None: + existing_version = row[0] + except self._sqlite3.OperationalError: + # schema_meta doesn't exist yet -- fresh database. + existing_version = None + + if existing_version is not None and existing_version != _SQLITE_SCHEMA_VERSION: + # Drop all cache tables -- ensures a future schema change that + # alters columns/indexes can't leave legacy layout in place. + self._conn.execute("DROP TABLE IF EXISTS entries") + self._conn.execute("DROP TABLE IF EXISTS schema_meta") + + self._conn.executescript( + """ + CREATE TABLE IF NOT EXISTS schema_meta ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS entries ( + key BLOB PRIMARY KEY, + payload BLOB NOT NULL, + size_bytes INTEGER NOT NULL, + created_at REAL NOT NULL, + accessed_at REAL NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_accessed_at + ON entries(accessed_at); + """ + ) + self._conn.execute( + "INSERT OR IGNORE INTO schema_meta(key, value) VALUES (?, ?)", + ("schema_version", _SQLITE_SCHEMA_VERSION), + ) + if existing_version is not None and existing_version != _SQLITE_SCHEMA_VERSION: + # Reclaim the space freed by the DROPped tables. + self._compact() + + def _compact(self) -> None: + """Reclaim disk space after bulk deletes. + + WAL mode keeps deleted pages until a checkpoint, and the main DB file + holds freed pages until VACUUM. Without this, ``max_size_bytes`` only + bounds logical payload -- real on-disk usage grows unbounded. + """ + conn = self._require_open() + try: + conn.execute("PRAGMA wal_checkpoint(TRUNCATE)") + conn.execute("VACUUM") + except self._sqlite3.OperationalError: + # VACUUM/checkpoint can fail under active readers; next eviction + # will retry. Size-cap correctness is re-asserted above the WAL + # layer anyway -- this is opportunistic compaction. 
+ pass + + def close(self) -> None: + with self._lock: + if self._conn is not None: + try: + self._conn.close() + finally: + self._conn = None + + def _require_open(self): + if self._conn is None: + raise RuntimeError("SQLiteProgramCache is closed") + return self._conn + + # -- mapping API --------------------------------------------------------- + + def __contains__(self, key: object) -> bool: + # Validate without routing through __getitem__: that would bump + # accessed_at and shift LRU order, so a bare membership probe could + # keep an otherwise cold entry alive. + return self._load(key, touch_lru=False) is not None + + def __getitem__(self, key: object) -> ObjectCode: + value = self._load(key, touch_lru=True) + if value is None: + raise KeyError(key) + return value + + def _load(self, key: object, *, touch_lru: bool) -> ObjectCode | None: + """Return the cached ObjectCode, pruning corrupt rows; None on miss. + + ``touch_lru=True`` updates ``accessed_at`` so real reads promote the + entry in LRU order; ``__contains__`` passes ``False`` so mere + existence checks do not change eviction priority. + """ + k = _as_key_bytes(key) + with self._lock: + conn = self._require_open() + row = conn.execute("SELECT payload FROM entries WHERE key = ?", (k,)).fetchone() + if row is None: + return None + try: + value = pickle.loads(row[0]) # noqa: S301 + except Exception: + conn.execute("DELETE FROM entries WHERE key = ?", (k,)) + return None + if not isinstance(value, ObjectCode): + conn.execute("DELETE FROM entries WHERE key = ?", (k,)) + return None + if touch_lru: + conn.execute( + "UPDATE entries SET accessed_at = ? 
WHERE key = ?", + (time.time(), k), + ) + return value + + def __setitem__(self, key: object, value: object) -> None: + obj = _require_object_code(value) + k = _as_key_bytes(key) + payload = pickle.dumps(obj, protocol=_PICKLE_PROTOCOL) + now = time.time() + with self._lock: + conn = self._require_open() + conn.execute( + """ + INSERT INTO entries(key, payload, size_bytes, created_at, accessed_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(key) DO UPDATE SET + payload = excluded.payload, + size_bytes = excluded.size_bytes, + accessed_at = excluded.accessed_at + """, + (k, payload, len(payload), now, now), + ) + self._enforce_size_cap() + + def __delitem__(self, key: object) -> None: + k = _as_key_bytes(key) + with self._lock: + conn = self._require_open() + cur = conn.execute("DELETE FROM entries WHERE key = ?", (k,)) + if cur.rowcount == 0: + raise KeyError(key) + + def __len__(self) -> int: + # Count only entries that would survive a real read: corrupt rows + # surface as misses on ``cache[key]``/``__contains__``, so reporting + # them here would make ``len(cache)`` disagree with both. Also + # prune what we find -- consistent with __contains__/__getitem__. 
+ count = 0 + with self._lock: + conn = self._require_open() + rows = conn.execute("SELECT key, payload FROM entries").fetchall() + for k, payload in rows: + try: + value = pickle.loads(payload) # noqa: S301 + except Exception: + conn.execute("DELETE FROM entries WHERE key = ?", (k,)) + continue + if not isinstance(value, ObjectCode): + conn.execute("DELETE FROM entries WHERE key = ?", (k,)) + continue + count += 1 + return count + + def clear(self) -> None: + with self._lock: + self._require_open().execute("DELETE FROM entries") + self._compact() + + # -- eviction ------------------------------------------------------------ + + def _enforce_size_cap(self) -> None: + if self._max_size_bytes is None: + return + conn = self._require_open() + (total,) = conn.execute("SELECT COALESCE(SUM(size_bytes), 0) FROM entries").fetchone() + if total <= self._max_size_bytes: + return + # Delete oldest (least-recently-used) until at or under the cap. + rows: Iterable[tuple[bytes, int]] = conn.execute( + "SELECT key, size_bytes FROM entries ORDER BY accessed_at ASC" + ).fetchall() + evicted = False + for k, sz in rows: + if total <= self._max_size_bytes: + break + conn.execute("DELETE FROM entries WHERE key = ?", (k,)) + total -= sz + evicted = True + if evicted: + self._compact() + + +# --------------------------------------------------------------------------- +# FileStream backend +# --------------------------------------------------------------------------- + + +_FILESTREAM_SCHEMA_VERSION = 2 +_ENTRIES_SUBDIR = "entries" +_TMP_SUBDIR = "tmp" +_SCHEMA_FILE = "SCHEMA_VERSION" +# Temp files older than this are assumed to belong to a crashed writer and +# are eligible for cleanup. Picked large enough that no real ``os.replace`` +# write should still be in flight (writes are bounded by mkstemp + write + +# fsync + replace, all fast on healthy disks). 
+_TMP_STALE_AGE_SECONDS = 3600 + + +_SHARING_VIOLATION_WINERRORS = (5, 32, 33) # ERROR_ACCESS_DENIED, ERROR_SHARING_VIOLATION, ERROR_LOCK_VIOLATION +_REPLACE_RETRY_DELAYS = (0.0, 0.005, 0.010, 0.020, 0.050, 0.100) # ~185ms budget + + +def _replace_with_sharing_retry(tmp_path: Path, target: Path) -> bool: + """Atomic rename with Windows-specific retry on sharing/lock violations. + + Returns True on success. Returns False only after the retry budget is + exhausted on Windows with a genuine sharing violation -- the caller then + treats the cache write as dropped. Any other ``PermissionError`` (ACLs, + read-only dir, unexpected winerror, or any POSIX failure) propagates. + + ``ERROR_ACCESS_DENIED`` (winerror 5) is treated as a sharing violation + because Windows surfaces it when a file is held open without + ``FILE_SHARE_WRITE`` (Python's default for ``open(p, "wb")``) or while + a previous unlink is in ``PENDING_DELETE`` -- both are transient. + """ + for i, delay in enumerate(_REPLACE_RETRY_DELAYS): + if delay: + time.sleep(delay) + try: + os.replace(tmp_path, target) + return True + except PermissionError as exc: + if not _IS_WINDOWS or getattr(exc, "winerror", None) not in _SHARING_VIOLATION_WINERRORS: + raise + # Windows sharing violation; loop and try again unless this was the + # last attempt, in which case fall through and return False. + if i == len(_REPLACE_RETRY_DELAYS) - 1: + return False + return False + + +def _stat_and_read_with_sharing_retry(path: Path) -> tuple[os.stat_result, bytes]: + """Snapshot stat and read bytes, retrying briefly on Windows transient + sharing-violation ``PermissionError``. + + Reads race the rewriter's ``os.replace``: on Windows, the destination + can be momentarily inaccessible (winerror 5/32/33) while the rename + completes. Mirroring ``_replace_with_sharing_retry``'s budget keeps + transient contention from being mistaken for a real read failure. 
+ + Raises ``FileNotFoundError`` on miss or after exhausting the Windows + sharing-retry budget. Non-Windows ``PermissionError`` propagates. + + On Windows, EACCES (errno 13) is treated as transient too: ``io.open`` + sometimes surfaces a pending-delete or share-mode mismatch as bare + EACCES with no ``winerror`` attribute, indistinguishable here from + a true sharing violation. Real ACL problems on a path the cache owns + would surface consistently; the bounded retry budget keeps the cost + of treating them as transient negligible. + """ + last_exc: BaseException | None = None + for delay in _REPLACE_RETRY_DELAYS: + if delay: + time.sleep(delay) + try: + return path.stat(), path.read_bytes() + except FileNotFoundError: + raise + except PermissionError as exc: + if not _IS_WINDOWS: + raise + winerror = getattr(exc, "winerror", None) + if winerror not in _SHARING_VIOLATION_WINERRORS and exc.errno != errno.EACCES: + raise + last_exc = exc + raise FileNotFoundError(path) from last_exc + + +def _prune_if_stat_unchanged(path: Path, st_before: os.stat_result) -> None: + """Unlink ``path`` iff its stat still matches ``st_before``. + + Guards against a cross-process race: a reader that sees a corrupt + record can have it atomically replaced (via ``os.replace``) by a + writer before the reader decides to prune. Comparing + ``(ino, size, mtime_ns)`` before and after rules out that case -- + any mismatch means someone else wrote a new file and we must not + delete their work. The residual TOCTOU window between stat and + unlink is narrow; worst case, a very-recently-written entry is + removed and the next read recompiles. 
+ """ + try: + st_now = path.stat() + except FileNotFoundError: + return + key_before = (st_before.st_ino, st_before.st_size, st_before.st_mtime_ns) + key_now = (st_now.st_ino, st_now.st_size, st_now.st_mtime_ns) + if key_before != key_now: + return + with contextlib.suppress(FileNotFoundError): + path.unlink() + + +class FileStreamProgramCache(ProgramCacheResource): + """Persistent program cache backed by a directory of atomic files. + + Designed for multi-process use: writes stage a temporary file and then + :func:`os.replace` it into place, so concurrent readers never observe a + partially-written entry. There is no cross-process LRU tracking; size + enforcement is best-effort by file mtime. + + .. note:: **Best-effort writes.** + + On Windows, ``os.replace`` raises ``PermissionError`` (winerror + 32 / 33) when another process holds the target file open. This + backend retries with bounded backoff (~185 ms) and, if still + failing, drops the cache write silently and returns success-shaped + control flow. The next call will see no entry and recompile. POSIX + and other ``PermissionError`` codes propagate. + + .. note:: **Atomic for readers, not crash-durable.** + + Each entry's temp file is ``fsync``-ed before ``os.replace``, but + the containing directory is **not** ``fsync``-ed. A host crash + between write and the next directory commit may lose recently + added entries; surviving entries remain consistent. + + .. note:: **Cross-version sharing.** + + ``_FILESTREAM_SCHEMA_VERSION`` guards on-disk format changes: a + cache written by an incompatible version is wiped on open. 
Within + a single schema version, the cache is safe to share across + ``cuda.core`` patch releases because every entry's key encodes + the relevant backend/compiler/runtime fingerprints for its + compilation path (NVRTC entries pin the NVRTC version, NVVM + entries pin the libNVVM library and IR versions, PTX/linker + entries pin the chosen linker backend and its version -- and, + when the cuLink/driver backend is selected, the driver version + too; nvJitLink-backed PTX entries are deliberately driver-version + independent). + + Parameters + ---------- + path: + Directory that owns the cache. Created if missing. + max_size_bytes: + Optional soft cap on total on-disk size. Enforced opportunistically + on writes; concurrent writers may briefly exceed it. + """ + + def __init__( + self, + path: str | os.PathLike, + *, + max_size_bytes: int | None = None, + ) -> None: + if max_size_bytes is not None and max_size_bytes < 0: + raise ValueError("max_size_bytes must be non-negative or None") + self._root = Path(path) + self._entries = self._root / _ENTRIES_SUBDIR + self._tmp = self._root / _TMP_SUBDIR + self._schema_path = self._root / _SCHEMA_FILE + self._max_size_bytes = max_size_bytes + self._root.mkdir(parents=True, exist_ok=True) + self._entries.mkdir(exist_ok=True) + self._tmp.mkdir(exist_ok=True) + expected = str(_FILESTREAM_SCHEMA_VERSION) + if not self._schema_path.exists(): + self._schema_path.write_text(expected) + else: + existing = self._schema_path.read_text().strip() + if existing != expected: + # Schema mismatch: wipe incompatible entries. Losing cache + # contents is safe; returning stale/incompatible pickles is not. + for entry in list(self._iter_entry_paths()): + with contextlib.suppress(FileNotFoundError): + entry.unlink() + self._schema_path.write_text(expected) + # Opportunistic startup sweep of orphaned temp files left by any + # crashed writers. Age-based so concurrent in-flight writes from + # other processes are preserved. 
+ self._sweep_stale_tmp_files() + + # -- key-to-path helpers ------------------------------------------------- + + def _path_for_key(self, key: object) -> Path: + k = _as_key_bytes(key) + # Hash the key to a fixed-length identifier so arbitrary-length user + # keys never exceed per-component filename limits (typically 255 on + # ext4 / NTFS). The original key is still stored inside the pickled + # record and verified on read, so two distinct keys cannot collide + # silently (hash collision would surface as KeyError via key mismatch). + digest = hashlib.blake2b(k, digest_size=32).hexdigest() if k else "empty" + if len(digest) < 3: + digest = digest.rjust(3, "0") + return self._entries / digest[:2] / digest[2:] + + # -- mapping API --------------------------------------------------------- + + def __contains__(self, key: object) -> bool: + # Route through __getitem__ so corrupt records / schema mismatches / + # stored-key mismatches are treated as absent (and pruned), matching + # the semantics of ``cache[key]``. + try: + self[key] + except KeyError: + return False + return True + + def __getitem__(self, key: object) -> ObjectCode: + path = self._path_for_key(key) + try: + # Snapshot stat *before* read so we can detect a concurrent + # os.replace during the read/parse window; a stale stat means + # another writer wrote a fresh file that must not be pruned. + # The helper retries on Windows transient sharing-violation + # PermissionErrors so a racing rewriter does not turn a hit + # into a spurious propagated error. 
+ st_before, data = _stat_and_read_with_sharing_retry(path) + except FileNotFoundError: + raise KeyError(key) from None + k = _as_key_bytes(key) + try: + record = pickle.loads(data) # noqa: S301 + schema, stored_key, payload, _created_at = record + if schema != _FILESTREAM_SCHEMA_VERSION: + raise ValueError(f"unknown schema {schema}") + if stored_key != k: + raise ValueError("key mismatch") + value = pickle.loads(payload) # noqa: S301 + except Exception: + _prune_if_stat_unchanged(path, st_before) + raise KeyError(key) from None + if not isinstance(value, ObjectCode): + _prune_if_stat_unchanged(path, st_before) + raise KeyError(key) from None + return value + + def __setitem__(self, key: object, value: object) -> None: + obj = _require_object_code(value) + k = _as_key_bytes(key) + payload = pickle.dumps(obj, protocol=_PICKLE_PROTOCOL) + record = pickle.dumps( + (_FILESTREAM_SCHEMA_VERSION, k, payload, time.time()), + protocol=_PICKLE_PROTOCOL, + ) + + target = self._path_for_key(key) + target.parent.mkdir(parents=True, exist_ok=True) + + fd, tmp_name = tempfile.mkstemp(prefix="entry-", dir=self._tmp) + tmp_path = Path(tmp_name) + try: + with os.fdopen(fd, "wb") as fh: + fh.write(record) + fh.flush() + os.fsync(fh.fileno()) + # Retry os.replace under Windows sharing/lock violations; only + # give up (and drop the cache write) after a bounded backoff, so + # transient contention is not turned into a silent miss. + # Non-sharing PermissionErrors and all POSIX PermissionErrors + # propagate immediately (real config problem). 
+ if not _replace_with_sharing_retry(tmp_path, target): + with contextlib.suppress(FileNotFoundError): + tmp_path.unlink() + return + except BaseException: + with contextlib.suppress(FileNotFoundError): + tmp_path.unlink() + raise + self._enforce_size_cap() + + def __delitem__(self, key: object) -> None: + path = self._path_for_key(key) + try: + path.unlink() + except FileNotFoundError: + raise KeyError(key) from None + + def __len__(self) -> int: + # Count only entries that would survive a real read; corrupt files + # surface as misses on ``cache[key]`` / ``__contains__``, so + # reporting them here would make ``len(cache)`` disagree with both. + # Pruning is stat-guarded to avoid racing a concurrent writer that + # has just os.replace-d a fresh entry into the same path. Validate + # the same way ``__getitem__`` does, including the stored_key -> + # path mapping check. + count = 0 + for path in list(self._iter_entry_paths()): + try: + st_before, data = _stat_and_read_with_sharing_retry(path) + except FileNotFoundError: + continue + try: + record = pickle.loads(data) # noqa: S301 + schema, stored_key, payload, _created_at = record + if schema != _FILESTREAM_SCHEMA_VERSION: + raise ValueError("schema mismatch") + if self._path_for_key(stored_key) != path: + raise ValueError("stored_key does not map to this path") + value = pickle.loads(payload) # noqa: S301 + except Exception: + _prune_if_stat_unchanged(path, st_before) + continue + if not isinstance(value, ObjectCode): + _prune_if_stat_unchanged(path, st_before) + continue + count += 1 + return count + + def clear(self) -> None: + # Snapshot stat alongside path so we can refuse to unlink an entry + # that was concurrently replaced by another process between the + # snapshot scan and the unlink. Same stat-guard contract as + # ``_prune_if_stat_unchanged`` and ``_enforce_size_cap``. 
+ snapshot = [] + for path in self._iter_entry_paths(): + try: + snapshot.append((path, path.stat())) + except FileNotFoundError: + continue + for path, st_before in snapshot: + _prune_if_stat_unchanged(path, st_before) + # Sweep ONLY stale temp files. Deleting a young temp would race with + # another process between ``mkstemp`` and ``os.replace`` and turn its + # write into ``FileNotFoundError`` instead of a successful commit. + self._sweep_stale_tmp_files() + # Remove empty subdirs (best-effort; concurrent writers may re-create). + if self._entries.exists(): + for sub in sorted(self._entries.iterdir(), reverse=True): + if sub.is_dir(): + with contextlib.suppress(OSError): + sub.rmdir() + + # -- internals ----------------------------------------------------------- + + def _iter_entry_paths(self) -> Iterable[Path]: + if not self._entries.exists(): + return + for sub in self._entries.iterdir(): + if not sub.is_dir(): + continue + for entry in sub.iterdir(): + if entry.is_file(): + yield entry + + def _sweep_stale_tmp_files(self) -> None: + """Remove temp files left behind by crashed writers. + + Age threshold is conservative (``_TMP_STALE_AGE_SECONDS``) so an + in-flight write from another process is not interrupted. Best + effort: a missing file or a permission failure is ignored. + """ + if not self._tmp.exists(): + return + cutoff = time.time() - _TMP_STALE_AGE_SECONDS + for tmp in self._tmp.iterdir(): + if not tmp.is_file(): + continue + try: + if tmp.stat().st_mtime < cutoff: + tmp.unlink() + except (FileNotFoundError, PermissionError): + continue + + def _enforce_size_cap(self) -> None: + if self._max_size_bytes is None: + return + # Sweep stale temp files first so a long-dead writer's leftovers + # don't drag the apparent size up and force needless eviction. + self._sweep_stale_tmp_files() + entries = [] + total = 0 + # Count both committed entries AND surviving temp files: temp files + # occupy disk too, even if they're young. 
Without this the soft cap + # silently undercounts in-flight writes. + for path in self._iter_entry_paths(): + try: + st = path.stat() + except FileNotFoundError: + continue + # Carry the full stat so eviction can guard against a concurrent + # os.replace that swapped a fresh entry into this path between + # snapshot and unlink. + entries.append((st.st_mtime, st.st_size, path, st)) + total += st.st_size + if self._tmp.exists(): + for tmp in self._tmp.iterdir(): + if not tmp.is_file(): + continue + try: + total += tmp.stat().st_size + except FileNotFoundError: + continue + if total <= self._max_size_bytes: + return + entries.sort(key=lambda e: e[0]) # oldest mtime first + for _mtime, size, path, st_before in entries: + if total <= self._max_size_bytes: + return + # _prune_if_stat_unchanged refuses if a writer replaced the file + # between snapshot and now, so eviction can't silently delete a + # freshly-committed entry from another process. + try: + stat_now = path.stat() + except FileNotFoundError: + total -= size + continue + if (stat_now.st_ino, stat_now.st_size, stat_now.st_mtime_ns) != ( + st_before.st_ino, + st_before.st_size, + st_before.st_mtime_ns, + ): + # File was replaced -- don't unlink, but update ``total`` to + # reflect the replacement's actual size or the cap check + # below could declare us done while still over the limit. + total += stat_now.st_size - size + continue + with contextlib.suppress(FileNotFoundError): + path.unlink() + total -= size diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst index 005866ddb2..269c8aab9b 100644 --- a/cuda_core/docs/source/api.rst +++ b/cuda_core/docs/source/api.rst @@ -271,7 +271,18 @@ Utility functions :toctree: generated/ args_viewable_as_strided_memory + make_program_cache_key :template: autosummary/cyclass.rst StridedMemoryView + +Program caches +-------------- + +.. 
autosummary:: + :toctree: generated/ + + ProgramCacheResource + SQLiteProgramCache + FileStreamProgramCache diff --git a/cuda_core/tests/test_program_cache.py b/cuda_core/tests/test_program_cache.py new file mode 100644 index 0000000000..8310e35acb --- /dev/null +++ b/cuda_core/tests/test_program_cache.py @@ -0,0 +1,2022 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import abc +import time + +import pytest + +try: + import sqlite3 # noqa: F401 + + _has_sqlite3 = True +except ImportError: + _has_sqlite3 = False + +needs_sqlite3 = pytest.mark.skipif(not _has_sqlite3, reason="libsqlite3 not available") + + +def test_program_cache_resource_is_abstract(): + from cuda.core.utils import ProgramCacheResource + + assert issubclass(ProgramCacheResource, abc.ABC) + with pytest.raises(TypeError, match="abstract"): + ProgramCacheResource() + + +def test_program_cache_resource_requires_core_methods(): + from cuda.core.utils import ProgramCacheResource + + required = { + "__getitem__", + "__setitem__", + "__contains__", + "__delitem__", + "__len__", + "clear", + } + assert required <= ProgramCacheResource.__abstractmethods__ + + +def _build_empty_subclass(): + from cuda.core.utils import ProgramCacheResource + + class _Empty(ProgramCacheResource): + def __getitem__(self, key): + raise KeyError(key) + + def __setitem__(self, key, value): + pass + + def __contains__(self, key): + return False + + def __delitem__(self, key): + raise KeyError(key) + + def __len__(self): + return 0 + + def clear(self): + pass + + return _Empty + + +def test_program_cache_resource_default_get_returns_default_on_miss(): + sentinel = object() + cache = _build_empty_subclass()() + assert cache.get(b"missing", default=sentinel) is sentinel + + +def test_program_cache_resource_default_get_returns_none_without_default(): + cache = _build_empty_subclass()() + assert cache.get(b"missing") is None + + +def 
test_program_cache_resource_close_is_noop_by_default(): + cache = _build_empty_subclass()() + cache.close() # does not raise + + +def test_program_cache_resource_context_manager_closes(): + from cuda.core.utils import ProgramCacheResource + + closed = [] + + class _Tracked(ProgramCacheResource): + def __getitem__(self, key): + raise KeyError(key) + + def __setitem__(self, key, value): + pass + + def __contains__(self, key): + return False + + def __delitem__(self, key): + raise KeyError(key) + + def __len__(self): + return 0 + + def clear(self): + pass + + def close(self): + closed.append(True) + + with _Tracked(): + pass + assert closed == [True] + + +def test_cuda_core_utils_memoryview_import_is_lightweight(tmp_path): + """``from cuda.core.utils import StridedMemoryView`` must NOT transitively + import the program-cache backends; the cache modules pull in extra + driver/NVRTC machinery that memoryview-only consumers have no reason + to load.""" + import subprocess + import sys + import textwrap + + prog = textwrap.dedent(""" + import sys + # Touch the memoryview-only API. + from cuda.core.utils import StridedMemoryView, args_viewable_as_strided_memory # noqa: F401 + assert "cuda.core.utils._program_cache" not in sys.modules, ( + "importing the memoryview shim eagerly imported the cache backend: " + + str([m for m in sys.modules if m.startswith("cuda.core.utils")]) + ) + # And that the lazy attr still works on first access. + import cuda.core.utils as u + _ = u.make_program_cache_key + assert "cuda.core.utils._program_cache" in sys.modules + """) + # Run from a neutral cwd so Python's implicit ``sys.path[0]=''`` does not + # resolve to the unbuilt cuda_core source tree (which lacks the + # setuptools-scm-generated ``_version.py``). The subprocess must import + # the installed cuda.core from site-packages. 
+ subprocess.run([sys.executable, "-c", prog], check=True, cwd=str(tmp_path)) # noqa: S603 + + +# --------------------------------------------------------------------------- +# make_program_cache_key +# --------------------------------------------------------------------------- + + +def _opts(**kw): + from cuda.core import ProgramOptions + + kw.setdefault("arch", "sm_80") + return ProgramOptions(**kw) + + +def _make_key(**overrides): + """Call ``make_program_cache_key`` with a sensible default baseline. + + Tests only need to state the field(s) they care about; everything + unspecified defaults to a valid cubin-from-c++ compile over "a".""" + from cuda.core.utils import make_program_cache_key + + base = dict(code="a", code_type="c++", options=_opts(), target_type="cubin") + return make_program_cache_key(**{**base, **overrides}) + + +def test_make_program_cache_key_returns_bytes(): + key = _make_key() + assert isinstance(key, bytes) + assert len(key) == 32 + + +@pytest.mark.parametrize("code_type, code", [("c++", "void k(){}"), ("ptx", ".version 7.0")]) +def test_make_program_cache_key_is_deterministic(code_type, code): + assert _make_key(code=code, code_type=code_type) == _make_key(code=code, code_type=code_type) + + +def test_make_program_cache_key_accepts_bytes_code(): + # NVVM IR is bytes; accept both str and bytes equivalently (str is UTF-8). 
+ k_str = _make_key(code="abc", code_type="nvvm", target_type="ptx") + k_bytes = _make_key(code=b"abc", code_type="nvvm", target_type="ptx") + assert k_str == k_bytes + + +@pytest.mark.parametrize( + "a, b", + [ + pytest.param({"code": "a"}, {"code": "b"}, id="code"), + pytest.param({"target_type": "ptx"}, {"target_type": "cubin"}, id="target_type"), + pytest.param({"options": _opts(arch="sm_80")}, {"options": _opts(arch="sm_90")}, id="arch"), + pytest.param( + {"options": _opts(use_fast_math=True)}, + {"options": _opts(use_fast_math=False)}, + id="option", + ), + pytest.param( + {"options": _opts(name="kernel-a")}, + {"options": _opts(name="kernel-b")}, + id="options.name", + ), + # no extra_digest vs some digest -- adding a digest must perturb the key. + pytest.param({}, {"extra_digest": b"\x01" * 32}, id="extra_digest_added"), + pytest.param( + {"extra_digest": b"\x01" * 32}, + {"extra_digest": b"\x02" * 32}, + id="extra_digest_value", + ), + ], +) +def test_make_program_cache_key_differs_on(a, b): + """Every invalidation axis: code, target, arch, option flag, options.name, + extra_digest presence and value.""" + assert _make_key(**a) != _make_key(**b) + + +@pytest.mark.parametrize( + "first, second", + [ + pytest.param(("driver", "13200"), ("nvJitLink", "12030"), id="backend_flip"), + pytest.param(("nvJitLink", "12030"), ("nvJitLink", "12040"), id="version_bump"), + ], +) +def test_make_program_cache_key_ptx_linker_probe_changes(first, second, monkeypatch): + """PTX keys must reflect both the linker backend choice (nvJitLink vs + driver) and its version.""" + from cuda.core.utils import _program_cache + + monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: first) + k1 = _make_key(code=".version 7.0", code_type="ptx") + monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: second) + k2 = _make_key(code=".version 7.0", code_type="ptx") + assert k1 != k2 + + +def 
test_make_program_cache_key_name_expressions_order_insensitive(): + assert _make_key(name_expressions=("f", "g")) == _make_key(name_expressions=("g", "f")) + + +@pytest.mark.parametrize("bad", [123, 1.5, object(), None]) +def test_make_program_cache_key_rejects_invalid_name_expressions_element(bad): + """For NVRTC, Program.compile only forwards str/bytes name_expressions; + persisting a key for an invalid input is just a foot-gun. Reject up front.""" + with pytest.raises(TypeError, match="name_expressions"): + _make_key(name_expressions=("ok", bad)) + + +@pytest.mark.parametrize( + "code_type, code, target_type", + [ + pytest.param("ptx", ".version 7.0", "cubin", id="ptx"), + pytest.param("nvvm", "abc", "ptx", id="nvvm"), + ], +) +def test_make_program_cache_key_ignores_invalid_name_expressions_for_non_nvrtc(code_type, code, target_type): + """Program.compile silently ignores name_expressions on PTX/NVVM, so + the cache helper must not reject invalid elements there either -- + otherwise legitimate non-NVRTC compiles fail the cache layer.""" + # Should not raise even though 123 isn't a valid NVRTC name. + _make_key(code=code, code_type=code_type, target_type=target_type, name_expressions=(123, object())) + + +@pytest.mark.parametrize("code_type", ["PTX", "C++", "NVVM", "Ptx", "c++"]) +def test_make_program_cache_key_normalises_code_type_case(code_type): + """Program() normalises code_type to lower; the cache helper must do + the same so callers using ``Program(code, "PTX")`` get the same routing + and the same key as the lowercase form.""" + # Pick a target valid for any of the lowered code types. 
+ if code_type.lower() == "nvvm": + target = "ptx" + code = "abc" + elif code_type.lower() == "ptx": + target = "cubin" + code = ".version 7.0" + else: + target = "cubin" + code = "void k(){}" + upper_key = _make_key(code=code, code_type=code_type, target_type=target) + lower_key = _make_key(code=code, code_type=code_type.lower(), target_type=target) + assert upper_key == lower_key + + +def test_make_program_cache_key_name_expressions_str_bytes_distinct(): + """``Program.compile`` records the *original* Python object as the key in + ``ObjectCode.symbol_mapping``. Returning a cached ObjectCode whose + mapping-key type differs from the caller's later ``get_kernel`` lookup + would silently miss, so ``"foo"`` and ``b"foo"`` must produce distinct + cache keys.""" + assert _make_key(name_expressions=("foo",)) != _make_key(name_expressions=(b"foo",)) + + +@pytest.mark.parametrize( + "code_type, target_type", + [ + pytest.param("c++", "cubin", id="nvrtc"), + pytest.param("ptx", "cubin", id="ptx"), + ], +) +def test_make_program_cache_key_rejects_bytes_code_outside_nvvm(code_type, target_type): + """``Program()`` only accepts bytes-like code for NVVM; c++ and PTX + require str. The cache helper must mirror that rejection.""" + with pytest.raises(TypeError, match="code must be str for code_type"): + _make_key(code=b"abc", code_type=code_type, target_type=target_type) + + +@pytest.mark.parametrize( + "code_type, code, target_type", + [ + pytest.param("c++", "void k(){}", "cubin", id="nvrtc"), + pytest.param("ptx", ".version 7.0", "cubin", id="ptx"), + ], +) +def test_make_program_cache_key_rejects_extra_sources_outside_nvvm(code_type, code, target_type): + """``Program(code, code_type)`` rejects ``extra_sources`` for non-NVVM + backends. 
The cache key path should mirror that and not silently + fingerprint a configuration the real compile would refuse.""" + with pytest.raises(ValueError, match="extra_sources"): + _make_key( + code=code, + code_type=code_type, + target_type=target_type, + options=_opts(extra_sources=[("foo.cu", "int x = 0;")]), + ) + + +@pytest.mark.parametrize( + "kwargs, exc_type, match", + [ + pytest.param({"code_type": "fortran"}, ValueError, "code_type", id="unknown_code_type"), + pytest.param({"target_type": "exe"}, ValueError, "target_type", id="unknown_target_type"), + pytest.param({"code": 12345}, TypeError, "code", id="non_str_bytes_code"), + # Backend-specific target matrix -- Program.compile rejects these + # combinations, so caching a key for them would be a lie. + pytest.param( + {"code_type": "ptx", "target_type": "ltoir"}, + ValueError, + "not valid for code_type", + id="ptx_cannot_ltoir", + ), + pytest.param( + {"code_type": "nvvm", "target_type": "cubin"}, + ValueError, + "not valid for code_type", + id="nvvm_cannot_cubin", + ), + ], +) +def test_make_program_cache_key_rejects(kwargs, exc_type, match): + with pytest.raises(exc_type, match=match): + _make_key(**kwargs) + + +def test_make_program_cache_key_supported_targets_matches_program_compile(): + """``_SUPPORTED_TARGETS_BY_CODE_TYPE`` duplicates the backend target + matrix in ``_program.pyx``. 
Guard against drift: parse the pyx source + with :mod:`tokenize` (which skips string literals and comments) to + extract ``SUPPORTED_TARGETS`` and assert the two views agree.""" + import ast + import io + import tokenize + from pathlib import Path + + from cuda.core.utils._program_cache import _SUPPORTED_TARGETS_BY_CODE_TYPE + + backend_to_code_type = {"NVRTC": "c++", "NVVM": "nvvm"} + linker_backends = ("nvJitLink", "driver") + + pyx = Path(__file__).parent.parent / "cuda" / "core" / "_program.pyx" + text = pyx.read_text() + marker_idx = text.index("cdef dict SUPPORTED_TARGETS") + tokens = tokenize.generate_tokens(io.StringIO(text[marker_idx:]).readline) + + depth = 0 + start_offset = None + end_offset = None + lines = text[marker_idx:].splitlines(keepends=True) + line_starts = [0] + for line in lines[:-1]: + line_starts.append(line_starts[-1] + len(line)) + + def _offset(row, col): + return line_starts[row - 1] + col + + for tok in tokens: + if tok.type != tokenize.OP: + continue + if tok.string == "{": + if depth == 0: + start_offset = _offset(tok.start[0], tok.start[1]) + depth += 1 + elif tok.string == "}": + depth -= 1 + if depth == 0: + end_offset = _offset(tok.end[0], tok.end[1]) + break + assert start_offset is not None and end_offset is not None, "could not locate SUPPORTED_TARGETS literal" + pyx_targets = ast.literal_eval(text[marker_idx + start_offset : marker_idx + end_offset]) + + for backend, code_type in backend_to_code_type.items(): + assert frozenset(pyx_targets[backend]) == _SUPPORTED_TARGETS_BY_CODE_TYPE[code_type], ( + backend, + code_type, + ) + linker_sets = [frozenset(pyx_targets[b]) for b in linker_backends] + assert all(s == linker_sets[0] for s in linker_sets) + assert linker_sets[0] == _SUPPORTED_TARGETS_BY_CODE_TYPE["ptx"] + + +@pytest.mark.parametrize( + "code_type, code, target_type", + [ + pytest.param("nvvm", "abc", "ptx", id="nvvm"), + pytest.param("ptx", ".version 7.0", "cubin", id="ptx"), + ], +) +def 
test_make_program_cache_key_ignores_name_expressions_for_non_nvrtc(code_type, code, target_type): + """Program.compile only forwards ``name_expressions`` on the NVRTC path + (_program.pyx). Folding them into the key for NVVM/PTX compiles would + cause identical compiles to miss the cache for no behavioural reason.""" + k_none = _make_key(code=code, code_type=code_type, target_type=target_type) + k_with = _make_key(code=code, code_type=code_type, target_type=target_type, name_expressions=("foo", "bar")) + assert k_none == k_with + + +@pytest.mark.parametrize( + "a, b", + [ + # ``debug`` / ``lineinfo`` / ``link_time_optimization`` are truthy-only + # gates in the linker; False and None produce identical output. + pytest.param({"debug": False}, {"debug": None}, id="debug_false_eq_none"), + pytest.param({"lineinfo": False}, {"lineinfo": None}, id="lineinfo_false_eq_none"), + pytest.param( + {"link_time_optimization": False}, + {"link_time_optimization": None}, + id="lto_false_eq_none", + ), + # ``time`` is a presence gate: the linker emits ``-time`` for any + # non-None value, so True / "path" produce the same flag. + pytest.param({"time": True}, {"time": "timing.csv"}, id="time_true_eq_path"), + # ``no_cache`` has an ``is True`` gate; False and None equivalent. + pytest.param({"no_cache": False}, {"no_cache": None}, id="no_cache_false_eq_none"), + ], +) +def test_make_program_cache_key_ptx_linker_equivalent_options_hash_same(a, b, monkeypatch): + """The linker folds several PTX-relevant fields through simple gates: + truthy-only (``debug``, ``lineinfo``, ``link_time_optimization``), + presence-only (``time``), ``is True`` (``no_cache``). Semantically + equivalent inputs under those gates must hash to the same key.""" + # Pin the linker probe so the only variable is the options gate. 
+ from cuda.core.utils import _program_cache + + monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: ("nvJitLink", "12030")) + k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**a)) + k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**b)) + assert k_a == k_b + + +@pytest.mark.parametrize( + "field, a, b", + [ + pytest.param("ftz", True, False, id="ftz"), + pytest.param("prec_div", True, False, id="prec_div"), + pytest.param("prec_sqrt", True, False, id="prec_sqrt"), + pytest.param("fma", True, False, id="fma"), + ], +) +def test_make_program_cache_key_ptx_driver_ignored_fields_collapse(field, a, b, monkeypatch): + """The driver (cuLink) linker silently ignores ftz/prec_div/prec_sqrt/fma + (only emits a DeprecationWarning). Under the driver backend, those + fields must not perturb the PTX cache key -- two otherwise-equivalent + compiles differing only in these flags produce identical ObjectCode.""" + from cuda.core import _linker + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True) # driver + k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**{field: a})) + k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(**{field: b})) + assert k_a == k_b + + +@pytest.mark.parametrize( + "a, b", + [ + pytest.param("-v", ["-v"], id="str_vs_list"), + pytest.param("-v", ("-v",), id="str_vs_tuple"), + pytest.param(["-v"], ("-v",), id="list_vs_tuple"), + # Empty sequence emits no -Xptxas flags; must match None. + pytest.param(None, [], id="none_vs_empty_list"), + pytest.param(None, (), id="none_vs_empty_tuple"), + pytest.param([], (), id="empty_list_vs_empty_tuple"), + ], +) +def test_make_program_cache_key_ptx_ptxas_options_canonicalized(a, b, monkeypatch): + """_prepare_nvjitlink_options emits the same -Xptxas= flags for str, + list, and tuple shapes of ptxas_options. 
The cache key must treat them + as equivalent so equivalent compiles don't miss the cache.""" + from cuda.core import _linker + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False) # nvJitLink + k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ptxas_options=a)) + k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ptxas_options=b)) + assert k_a == k_b + + +def test_make_program_cache_key_ptx_driver_ignored_fields_still_matter_under_nvjitlink(monkeypatch): + """nvJitLink does honour those fields; they must still differentiate keys there.""" + from cuda.core import _linker + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False) # nvJitLink + k_a = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ftz=True)) + k_b = _make_key(code=".version 7.0", code_type="ptx", options=_opts(ftz=False)) + assert k_a != k_b + + +@pytest.mark.parametrize( + "code_type, code, target_type", + [ + pytest.param("c++", "void k(){}", "cubin", id="nvrtc"), + pytest.param("ptx", ".version 7.0", "cubin", id="ptx"), + ], +) +def test_make_program_cache_key_use_libdevice_ignored_for_non_nvvm(code_type, code, target_type): + """``use_libdevice`` is only consumed on the NVVM path; NVRTC and PTX + ignore it, so toggling it must not perturb the cache key elsewhere.""" + k_off = _make_key(code=code, code_type=code_type, target_type=target_type, options=_opts(use_libdevice=False)) + k_on = _make_key(code=code, code_type=code_type, target_type=target_type, options=_opts(use_libdevice=True)) + k_none = _make_key(code=code, code_type=code_type, target_type=target_type, options=_opts(use_libdevice=None)) + assert k_off == k_on == k_none + + +def test_make_program_cache_key_nvvm_use_libdevice_requires_extra_digest(): + """NVVM with ``use_libdevice=True`` links an external libdevice bitcode + file whose contents the cache can't observe; require an extra_digest + or the cached ObjectCode can silently drift under a 
toolkit upgrade.""" + from cuda.core.utils import make_program_cache_key + + with pytest.raises(ValueError, match="libdevice"): + make_program_cache_key( + code="abc", + code_type="nvvm", + options=_opts(use_libdevice=True), + target_type="ptx", + ) + # With an extra_digest, it's accepted; different digests produce + # different keys so a caller can represent a libdevice change. + k_a = make_program_cache_key( + code="abc", + code_type="nvvm", + options=_opts(use_libdevice=True), + target_type="ptx", + extra_digest=b"libdev-a" * 4, + ) + k_b = make_program_cache_key( + code="abc", + code_type="nvvm", + options=_opts(use_libdevice=True), + target_type="ptx", + extra_digest=b"libdev-b" * 4, + ) + assert k_a != k_b + + +def test_make_program_cache_key_nvvm_use_libdevice_false_equals_none(): + """Program_init gates ``use_libdevice`` on truthiness, so False and None + compile identically and must hash the same way. (True without an + extra_digest is rejected; see test_...requires_extra_digest.)""" + k_none = _make_key(code="abc", code_type="nvvm", target_type="ptx", options=_opts(use_libdevice=None)) + k_false = _make_key(code="abc", code_type="nvvm", target_type="ptx", options=_opts(use_libdevice=False)) + assert k_none == k_false + # With an explicit extra_digest, True produces a different key. 
+ k_true = _make_key( + code="abc", + code_type="nvvm", + target_type="ptx", + options=_opts(use_libdevice=True), + extra_digest=b"libdev" * 8, + ) + assert k_true != k_none + + +def test_make_program_cache_key_nvvm_library_version_changes_key(monkeypatch): + """Updating libNVVM (different ``module.version()``) must invalidate + NVVM cache entries even when the IR version stays constant; a patch + upgrade can change codegen without bumping the IR pair.""" + + class _FakeNVVM: + def __init__(self, lib_version): + self._lib_version = lib_version + + def version(self): + return self._lib_version + + def ir_version(self): + return (1, 8, 3, 0) # constant -- only the lib version varies + + fake_old = _FakeNVVM((12, 3)) + fake_new = _FakeNVVM((12, 4)) + from cuda.core import _program + + monkeypatch.setattr(_program, "_get_nvvm_module", lambda: fake_old) + k_old = _make_key(code="abc", code_type="nvvm", target_type="ptx") + monkeypatch.setattr(_program, "_get_nvvm_module", lambda: fake_new) + k_new = _make_key(code="abc", code_type="nvvm", target_type="ptx") + assert k_old != k_new + + +def test_make_program_cache_key_nvvm_fingerprint_uses_get_nvvm_module(monkeypatch): + """The fingerprint must call _get_nvvm_module() rather than importing + cuda.bindings.nvvm directly -- otherwise it bypasses the availability + /cuda-bindings-version gate and could disagree with the actual NVVM + compile path.""" + sentinel_called = {"n": 0} + + class _SentinelNVVM: + def version(self): + sentinel_called["n"] += 1 + return (12, 9) + + def ir_version(self): + return (1, 8, 3, 0) + + from cuda.core import _program + + monkeypatch.setattr(_program, "_get_nvvm_module", lambda: _SentinelNVVM()) + _make_key(code="abc", code_type="nvvm", target_type="ptx") + assert sentinel_called["n"] == 1 + + +def test_make_program_cache_key_nvvm_probe_changes_key(monkeypatch): + """NVVM keys must reflect the NVVM toolchain identity (IR version) + so an upgraded libNVVM does not silently reuse pre-upgrade 
entries.""" + from cuda.core.utils import _program_cache + + monkeypatch.setattr(_program_cache, "_nvvm_fingerprint", lambda: "ir=1.8.3.0") + k1 = _make_key(code="abc", code_type="nvvm", target_type="ptx") + monkeypatch.setattr(_program_cache, "_nvvm_fingerprint", lambda: "ir=2.0.3.0") + k2 = _make_key(code="abc", code_type="nvvm", target_type="ptx") + assert k1 != k2 + + +@pytest.mark.parametrize( + "option_kw", + [ + pytest.param({"time": True}, id="time_true"), + # ``_prepare_driver_options`` checks ``is not None``, so even the + # "falsy-but-set" cases must still be rejected at key time. + pytest.param({"time": False}, id="time_false"), + pytest.param({"ptxas_options": "-v"}, id="ptxas_options_str"), + pytest.param({"ptxas_options": ["-v", "-O2"]}, id="ptxas_options_list"), + pytest.param({"ptxas_options": []}, id="ptxas_options_empty_list"), + # ProgramOptions.ptxas_options also accepts tuples (and frozenset () + # literal is falsy). Lock in parity for all accepted shapes. + pytest.param({"ptxas_options": ("-v",)}, id="ptxas_options_tuple"), + pytest.param({"ptxas_options": ()}, id="ptxas_options_empty_tuple"), + pytest.param({"split_compile": 0}, id="split_compile_zero"), + pytest.param({"split_compile": 4}, id="split_compile_nonzero"), + # split_compile_extended is a LinkerOptions-only field; ProgramOptions + # does not expose it, so it cannot reach the driver linker via + # Program.compile and is not part of the cache-time guard. + ], +) +def test_make_program_cache_key_ptx_rejects_driver_linker_unsupported(option_kw, monkeypatch): + """When the driver (cuLink) linker backend is selected, options that + ``_prepare_driver_options`` rejects must also be rejected at key time + so we never cache a compilation that would fail. 
Uses ``is not None`` + to exactly mirror the driver-linker's own gate.""" + from cuda.core import _linker + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True) # driver + with pytest.raises(ValueError, match="driver linker"): + _make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw)) + + +def test_make_program_cache_key_ptx_accepts_driver_linker_unsupported_with_nvjitlink(monkeypatch): + """Under nvJitLink those same options are valid and must not be + rejected at key time.""" + from cuda.core import _linker + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: False) # nvJitLink + # Should not raise. + _make_key(code=".version 7.0", code_type="ptx", options=_opts(time=True)) + + +def test_filestream_cache_replace_retries_on_sharing_violation(tmp_path, monkeypatch): + """Under Windows sharing/lock violations, os.replace is retried with a + bounded backoff; a transient violation that clears within the budget + must still produce a successful cache write.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache, _program_cache + + monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True) + + real_replace = _os.replace + calls = {"n": 0} + + def _flaky_replace(src, dst): + calls["n"] += 1 + if calls["n"] < 3: + exc = PermissionError("sharing violation") + exc.winerror = 32 + raise exc + return real_replace(src, dst) + + with FileStreamProgramCache(tmp_path / "fc") as cache: + monkeypatch.setattr(_os, "replace", _flaky_replace) + cache[b"k"] = _fake_object_code(b"v") # succeeds on third attempt + assert calls["n"] == 3 + assert bytes(cache[b"k"].code) == b"v" + + +@pytest.mark.parametrize( + "option_kw", + [ + # Populated path-like options + pytest.param({"include_path": "/usr/local/include"}, id="include_path"), + pytest.param({"pre_include": "stdint.h"}, id="pre_include"), + pytest.param({"pch": True}, id="pch"), + pytest.param({"pch_dir": "pch-cache"}, id="pch_dir"), + # Non-list/tuple 
Sequence: the compiler iterates it via ``is_sequence`` + # (``isinstance(v, Sequence)``), so the guard must too. + pytest.param({"include_path": range(1)}, id="include_path_nonempty_range"), + # Empty-string path-like options -- NVRTC still emits a flag + # (``--use-pch=``, ``--pch-dir=``, ``--pre-include=``) so the guard + # must fire for them too. + pytest.param({"use_pch": ""}, id="use_pch_empty_string"), + pytest.param({"pch_dir": ""}, id="pch_dir_empty_string"), + pytest.param({"pre_include": ""}, id="pre_include_empty_string"), + # For path-shaped fields (``use_pch``, ``pch_dir``), NVRTC's gate is + # ``is not None``, so even False emits a real flag and must be caught. + pytest.param({"use_pch": False}, id="use_pch_false"), + pytest.param({"pch_dir": False}, id="pch_dir_false"), + # ``include_path`` / ``pre_include`` are NOT in that group: the + # compiler only emits them for str or non-empty sequences, so + # ``False`` is silently ignored at compile time -- test the accept + # path below, not the reject path. + ], +) +def test_make_program_cache_key_rejects_external_content_without_extra_digest(option_kw): + """Options that pull in external file content must force an extra_digest: + the cache cannot observe header/PCH bytes, so silently omitting them + would yield stale cache hits after header edits.""" + with pytest.raises(ValueError, match="extra_digest"): + _make_key(options=_opts(**option_kw)) + + +@pytest.mark.parametrize( + "option_kw", + [ + pytest.param({"include_path": []}, id="include_path_empty_list"), + pytest.param({"include_path": ()}, id="include_path_empty_tuple"), + pytest.param({"pre_include": []}, id="pre_include_empty_list"), + # ``_prepare_nvrtc_options_impl`` only emits include_path / pre_include + # for str or non-empty sequence, so False (or any non-str non-sequence) + # is silently ignored at compile time and must not trip the guard. 
+ pytest.param({"include_path": False}, id="include_path_false"), + pytest.param({"pre_include": False}, id="pre_include_false"), + # Empty non-list/tuple Sequence: ``_prepare_nvrtc_options_impl`` uses + # ``is_sequence`` (i.e. ``isinstance(v, Sequence)``); a zero-length + # sequence produces no emission regardless of type. + pytest.param({"include_path": range(0)}, id="include_path_empty_range"), + ], +) +def test_make_program_cache_key_accepts_empty_external_content(option_kw): + """Truly empty sequences mean 'no external inputs' -- they must not + force an extra_digest. (Empty *strings* are rejected separately because + NVRTC still emits a flag for them.)""" + _make_key(options=_opts(**option_kw)) # Should not raise. + + +def test_make_program_cache_key_ptx_ignores_nvrtc_only_options(): + """PTX compiles go through ``_translate_program_options`` which drops + NVRTC-only fields (include_path, pch_*, frandom_seed, ...). Those + fields must not perturb the PTX cache key; otherwise a shared + ProgramOptions that happens to set them causes spurious misses.""" + base = _make_key(code=".version 7.0", code_type="ptx", options=_opts()) + # Each of these only affects NVRTC, never Linker. 
+ for kw in ( + {"define_macro": "FOO"}, + {"frandom_seed": "1234"}, + {"ofast_compile": "min"}, + {"std": "c++17"}, + {"disable_warnings": True}, + ): + assert _make_key(code=".version 7.0", code_type="ptx", options=_opts(**kw)) == base, kw + + +@pytest.mark.parametrize( + "option_kw", + [ + pytest.param({"include_path": "/usr/local/include"}, id="include_path"), + pytest.param({"pre_include": "stdint.h"}, id="pre_include"), + pytest.param({"pch": True}, id="pch"), + pytest.param({"use_pch": "pch.file"}, id="use_pch"), + pytest.param({"pch_dir": "pch-cache"}, id="pch_dir"), + ], +) +def test_make_program_cache_key_accepts_external_content_options_for_ptx(option_kw): + """The external-content guard is NVRTC-only: ``Program.compile`` for PTX + inputs translates options via ``_translate_program_options``, which + drops include_path/pre_include/PCH fields entirely. A PTX compile must + not be blocked just because a reused ProgramOptions object carries + irrelevant header settings.""" + _make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw)) # no raise + + +def test_make_program_cache_key_accepts_external_content_with_extra_digest(): + """With an extra_digest, external-content options are accepted and + different digests produce different keys so callers can represent + header edits.""" + opts = _opts(include_path="/usr/local/include") + k_a = _make_key(options=opts, extra_digest=b"header-a" * 4) + k_b = _make_key(options=opts, extra_digest=b"header-b" * 4) + assert k_a != k_b + + +@pytest.mark.parametrize( + "option_kw, extra_digest", + [ + pytest.param({"create_pch": "out.pch"}, None, id="create_pch"), + # Even with extra_digest, create_pch is rejected: a cache hit skips + # compilation, so the side effect (writing the PCH) would not run. 
+ pytest.param({"create_pch": "out.pch"}, b"x" * 32, id="create_pch_with_extra_digest"), + pytest.param({"create_pch": ""}, None, id="create_pch_empty_string"), + # NVRTC emits ``--create-pch=False`` for any non-None value, so False + # still triggers the side effect and must be rejected. + pytest.param({"create_pch": False}, None, id="create_pch_false"), + pytest.param({"time": "timing.csv"}, None, id="time"), + pytest.param({"time": False}, None, id="time_false"), + pytest.param({"fdevice_time_trace": "trace.json"}, None, id="fdevice_time_trace"), + pytest.param({"fdevice_time_trace": False}, None, id="fdevice_time_trace_false"), + ], +) +def test_make_program_cache_key_rejects_side_effect_options_nvrtc(option_kw, extra_digest): + """Options that write files as a compile-time side effect must refuse + key generation when the target backend is NVRTC; a cache hit would skip + compilation and the artifact would never be produced.""" + with pytest.raises(ValueError, match="side effect"): + _make_key(options=_opts(**option_kw), extra_digest=extra_digest) + + +@pytest.mark.parametrize( + "option_kw", + [ + # ``time`` goes through Linker's ``-time`` flag which only logs to the + # info log -- no filesystem side effect -- so PTX compiles with + # ``time=True`` must cache normally. 
+ pytest.param({"time": True}, id="time_true"), + pytest.param({"time": "whatever.csv"}, id="time_path"), + ], +) +def test_make_program_cache_key_accepts_side_effect_options_for_ptx(option_kw): + """The side-effect guard is NVRTC-specific: PTX (linker) and NVVM must + not be blocked by options whose side effects only apply under NVRTC.""" + _make_key(code=".version 7.0", code_type="ptx", options=_opts(**option_kw)) # no raise + + +@pytest.mark.parametrize( + "code_type, code, target_type", + [ + pytest.param("c++", "a", "cubin", id="nvrtc"), + pytest.param("ptx", ".version 7.0", "cubin", id="linker"), + pytest.param("nvvm", "abc", "ptx", id="nvvm"), + ], +) +def test_make_program_cache_key_survives_cuda_core_version_change(code_type, code, target_type, monkeypatch): + """The docstring promises cross-patch sharing within a schema version, so + cuda.core's own ``__version__`` must NOT be mixed into the digest.""" + import cuda.core._version as _version_mod + + monkeypatch.setattr(_version_mod, "__version__", "0.0.0") + k_a = _make_key(code=code, code_type=code_type, target_type=target_type) + monkeypatch.setattr(_version_mod, "__version__", "999.999.999") + k_b = _make_key(code=code, code_type=code_type, target_type=target_type) + assert k_a == k_b + + +def test_make_program_cache_key_driver_version_does_not_perturb_ptx_under_nvjitlink(monkeypatch): + """nvJitLink does NOT route PTX compilation through cuLink, so a + changing driver version must not invalidate PTX cache keys when + nvJitLink is the active linker backend.""" + from cuda.core.utils import _program_cache + + monkeypatch.setattr(_program_cache, "_linker_backend_and_version", lambda: ("nvJitLink", "12030")) + monkeypatch.setattr(_program_cache, "_driver_version", lambda: 13200) + k_a = _make_key(code=".version 7.0", code_type="ptx") + monkeypatch.setattr(_program_cache, "_driver_version", lambda: 13300) + k_b = _make_key(code=".version 7.0", code_type="ptx") + assert k_a == k_b + + 
+@pytest.mark.parametrize( + "code_type, code, target_type", + [ + pytest.param("c++", "a", "cubin", id="nvrtc"), + pytest.param("nvvm", "abc", "ptx", id="nvvm"), + ], +) +def test_make_program_cache_key_driver_probe_failure_does_not_perturb_non_linker( + code_type, code, target_type, monkeypatch +): + """The driver version is only consumed on the linker (PTX) path because + cuLink runs through the driver. NVRTC and NVVM produce identical bytes + regardless of the driver version, so a failed driver probe must NOT + perturb their cache keys -- otherwise driver upgrades would invalidate + perfectly good caches.""" + from cuda.core.utils import _program_cache + + def _broken(): + raise RuntimeError("driver probe failed") + + k_ok = _make_key(code=code, code_type=code_type, target_type=target_type) + monkeypatch.setattr(_program_cache, "_driver_version", _broken) + k_broken = _make_key(code=code, code_type=code_type, target_type=target_type) + assert k_ok == k_broken + + +@pytest.mark.parametrize( + "probe_name, code_type, code", + [ + pytest.param("_nvrtc_version", "c++", "a", id="nvrtc"), + pytest.param("_linker_backend_and_version", "ptx", ".ptx", id="linker"), + ], +) +def test_make_program_cache_key_fails_closed_on_probe_failure(probe_name, code_type, code, monkeypatch): + """A failed probe (a) must produce a key that differs from a working + probe (so environments never silently share cache entries), and (b) + must produce a *stable* key across calls -- otherwise the persistent + cache could not be reused in broken environments. 
``_driver_version`` + is exercised separately because it's only invoked transitively from + ``_linker_backend_and_version`` on the cuLink driver path.""" + from cuda.core.utils import _program_cache + + def _broken(): + raise RuntimeError("probe failed") + + k_ok = _make_key(code=code, code_type=code_type) + monkeypatch.setattr(_program_cache, probe_name, _broken) + k_broken1 = _make_key(code=code, code_type=code_type) + k_broken2 = _make_key(code=code, code_type=code_type) + assert k_ok != k_broken1 + assert k_broken1 == k_broken2 # stable: same failure -> same key + + +def test_make_program_cache_key_driver_probe_failure_taints_ptx_under_cuLink(monkeypatch): + """When the driver linker is active, _linker_backend_and_version + invokes _driver_version internally; a failing driver probe must (a) + perturb the PTX key away from the success key, AND (b) be stable + across repeated calls so the persistent cache stays usable in the + failed environment.""" + from cuda.core import _linker + from cuda.core.utils import _program_cache + + def _broken(): + raise RuntimeError("driver probe failed") + + monkeypatch.setattr(_linker, "_decide_nvjitlink_or_driver", lambda: True) + k_ok = _make_key(code=".ptx", code_type="ptx") + monkeypatch.setattr(_program_cache, "_driver_version", _broken) + k_broken1 = _make_key(code=".ptx", code_type="ptx") + k_broken2 = _make_key(code=".ptx", code_type="ptx") + assert k_ok != k_broken1 + assert k_broken1 == k_broken2 # stable: same failure -> same key + + +# --------------------------------------------------------------------------- +# SQLiteProgramCache -- basic CRUD +# --------------------------------------------------------------------------- + + +def _fake_object_code(payload: bytes = b"fake-cubin", name: str = "unit"): + """Build an ObjectCode without touching the driver.""" + from cuda.core._module import ObjectCode + + return ObjectCode._init(payload, "cubin", name=name) + + +@needs_sqlite3 +def 
test_sqlite_cache_empty_on_create(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + assert len(cache) == 0 + assert b"nope" not in cache + with pytest.raises(KeyError): + cache[b"nope"] + assert cache.get(b"nope") is None + + +@needs_sqlite3 +def test_sqlite_cache_set_get_roundtrip(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + key = b"k1" + cache[key] = _fake_object_code(b"bytes-1", name="a") + + assert key in cache + assert len(cache) == 1 + got = cache[key] + assert bytes(got.code) == b"bytes-1" + assert got.name == "a" + assert got.code_type == "cubin" + + +@needs_sqlite3 +def test_sqlite_cache_overwrite_same_key(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code(b"v1") + cache[b"k"] = _fake_object_code(b"v2") + assert len(cache) == 1 + assert bytes(cache[b"k"].code) == b"v2" + + +@needs_sqlite3 +def test_sqlite_cache_delete(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code() + del cache[b"k"] + assert b"k" not in cache + assert len(cache) == 0 + with pytest.raises(KeyError): + del cache[b"k"] + + +@needs_sqlite3 +def test_sqlite_cache_clear(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"a"] = _fake_object_code(b"1") + cache[b"b"] = _fake_object_code(b"2") + cache.clear() + assert len(cache) == 0 + + +@needs_sqlite3 +def test_sqlite_cache_persists_across_open(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code(b"persisted") + with SQLiteProgramCache(db) as cache: + 
assert bytes(cache[b"k"].code) == b"persisted" + + +@needs_sqlite3 +def test_sqlite_cache_len_excludes_and_prunes_corrupt_rows(tmp_path): + """``len(cache)`` must agree with ``key in cache`` -- corrupt rows + surface as misses on read, so they must not inflate the length AND + they must be pruned as a side effect.""" + import sqlite3 + + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"good"] = _fake_object_code(b"ok") + cache[b"bad"] = _fake_object_code(b"will-be-corrupted") + with sqlite3.connect(db) as conn: + conn.execute("UPDATE entries SET payload = ? WHERE key = ?", (b"\x00garbage", b"bad")) + conn.commit() + with SQLiteProgramCache(db) as cache: + assert len(cache) == 1 + assert b"good" in cache + assert b"bad" not in cache + # __len__ also pruned the corrupt row from the underlying table. + with sqlite3.connect(db) as conn: + rows = conn.execute("SELECT key FROM entries").fetchall() + assert [r[0] for r in rows] == [b"good"] + + +def test_filestream_cache_len_excludes_and_prunes_corrupt_files(tmp_path): + """Same contract for the file-stream backend, plus stored_key mismatch: + a file with a valid ObjectCode payload but the wrong stored_key (e.g. + salvaged from another path) must also be excluded and pruned, matching + __getitem__'s 'key mismatch' rejection.""" + from cuda.core.utils import FileStreamProgramCache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"good"] = _fake_object_code(b"ok") + cache[b"bad"] = _fake_object_code(b"will-be-corrupted") + cache[b"misplaced"] = _fake_object_code(b"valid-but-wrong-key") + bad_path = cache._path_for_key(b"bad") + misplaced_path = cache._path_for_key(b"misplaced") + # Stored_key mismatch: write the contents of "good" into misplaced's + # location. Schema/payload look fine but stored_key won't map to + # misplaced_path. 
+ good_bytes = cache._path_for_key(b"good").read_bytes() + bad_path.write_bytes(b"\x00not-a-pickle") + misplaced_path.write_bytes(good_bytes) + with FileStreamProgramCache(root) as cache: + assert len(cache) == 1 + assert b"good" in cache + assert b"bad" not in cache + assert b"misplaced" not in cache + # Both bad files were pruned from disk. + assert not bad_path.exists() + assert not misplaced_path.exists() + + +@needs_sqlite3 +def test_sqlite_cache_corruption_is_reported_as_miss(tmp_path): + import sqlite3 + + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code(b"ok") + # Overwrite the payload with garbage directly in the DB. + with sqlite3.connect(db) as conn: + conn.execute( + "UPDATE entries SET payload = ? WHERE key = ?", + (b"\x00\x01garbage", b"k"), + ) + conn.commit() + with SQLiteProgramCache(db) as cache: + # ``in`` must not report True for a corrupt row; callers rely on this + # to skip stale hits before attempting the load. 
+ assert b"k" not in cache + with pytest.raises(KeyError): + cache[b"k"] + assert b"k" not in cache # still absent after the prune + + +@needs_sqlite3 +def test_sqlite_cache_rejects_non_object_code(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + with SQLiteProgramCache(tmp_path / "cache.db") as cache, pytest.raises(TypeError, match="ObjectCode"): + cache[b"k"] = b"not an ObjectCode" + + +@needs_sqlite3 +def test_sqlite_cache_rejects_path_backed_object_code(tmp_path): + from cuda.core._module import ObjectCode + from cuda.core.utils import SQLiteProgramCache + + path_backed = ObjectCode.from_cubin(str(tmp_path / "nonexistent.cubin"), name="x") + with SQLiteProgramCache(tmp_path / "cache.db") as cache, pytest.raises(TypeError, match="path-backed"): + cache[b"k"] = path_backed + + +@needs_sqlite3 +def test_sqlite_cache_accepts_str_keys(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache["str-key"] = _fake_object_code(b"v") + assert "str-key" in cache + # Same bytes representation so the corresponding bytes key also hits. + assert b"str-key" in cache + + +@needs_sqlite3 +def test_sqlite_cache_rejects_negative_size_cap(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + with pytest.raises(ValueError, match="non-negative"): + SQLiteProgramCache(tmp_path / "cache.db", max_size_bytes=-1) + + +# --------------------------------------------------------------------------- +# SQLiteProgramCache -- LRU eviction +# --------------------------------------------------------------------------- + + +@needs_sqlite3 +def test_sqlite_cache_evicts_under_size_cap(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + # Each payload pickles to > 2000 bytes; cap is 5000 so only ~2 fit. 
    # Tail of test_sqlite_cache_evicts_under_size_cap (def begins above):
    # a 5000-byte cap holds roughly two ~2000-byte payloads, so the third
    # insert must push out the oldest entry.
    cap = 5000
    db = tmp_path / "cache.db"
    with SQLiteProgramCache(db, max_size_bytes=cap) as cache:
        cache[b"a"] = _fake_object_code(b"A" * 2000, name="a")
        cache[b"b"] = _fake_object_code(b"B" * 2000, name="b")
        cache[b"c"] = _fake_object_code(b"C" * 2000, name="c")
        # Adding c must have evicted a (oldest by accessed_at).
        assert b"a" not in cache
        assert b"b" in cache
        assert b"c" in cache


@needs_sqlite3
def test_sqlite_cache_contains_does_not_bump_lru(tmp_path):
    """A bare ``key in cache`` must NOT promote the entry's LRU position.
    Membership probes are observations, not accesses; bumping LRU from them
    would keep otherwise-cold entries alive and starve hot ones."""
    from cuda.core.utils import SQLiteProgramCache

    cap = 5000
    db = tmp_path / "cache.db"
    with SQLiteProgramCache(db, max_size_bytes=cap) as cache:
        cache[b"a"] = _fake_object_code(b"A" * 2000, name="a")
        # Sleeps keep the entries' timestamps strictly ordered so the LRU
        # comparison below is deterministic even on coarse clocks.
        time.sleep(0.01)
        cache[b"b"] = _fake_object_code(b"B" * 2000, name="b")
        time.sleep(0.01)
        # Probe membership of 'a'. If this bumped LRU, 'a' would become MRU
        # and 'b' would be evicted when 'c' arrives. It must not.
        assert b"a" in cache
        time.sleep(0.01)
        cache[b"c"] = _fake_object_code(b"C" * 2000, name="c")
        assert b"a" not in cache  # still LRU and evicted
        assert b"b" in cache
        assert b"c" in cache


@needs_sqlite3
def test_sqlite_cache_lru_order_respects_reads(tmp_path):
    """A successful ``cache[key]`` read must promote the entry to MRU so the
    next eviction removes the genuinely least-recently-used entry ('b'),
    not the oldest-written one ('a')."""
    from cuda.core.utils import SQLiteProgramCache

    cap = 5000
    db = tmp_path / "cache.db"
    with SQLiteProgramCache(db, max_size_bytes=cap) as cache:
        cache[b"a"] = _fake_object_code(b"A" * 2000, name="a")
        time.sleep(0.01)
        cache[b"b"] = _fake_object_code(b"B" * 2000, name="b")
        time.sleep(0.01)
        # Touch 'a' so it becomes MRU; 'b' must be evicted when 'c' is added.
+ _ = cache[b"a"] + time.sleep(0.01) + cache[b"c"] = _fake_object_code(b"C" * 2000, name="c") + assert b"a" in cache + assert b"b" not in cache + assert b"c" in cache + + +@needs_sqlite3 +def test_sqlite_cache_unbounded_by_default(tmp_path): + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + for i in range(25): + cache[f"k{i}".encode()] = _fake_object_code(b"X" * 1024, name=f"n{i}") + assert len(cache) == 25 + + +@needs_sqlite3 +def test_sqlite_cache_disk_usage_is_bounded(tmp_path): + """max_size_bytes must bound real on-disk usage, not just logical payload. + Without wal_checkpoint+VACUUM the WAL and free pages grow without limit.""" + from cuda.core.utils import SQLiteProgramCache + + cap = 8000 + db = tmp_path / "cache.db" + with SQLiteProgramCache(db, max_size_bytes=cap) as cache: + # Write many large-ish entries; eviction fires repeatedly. + for i in range(60): + cache[f"k{i}".encode()] = _fake_object_code(b"X" * 2000, name=f"n{i}") + + # Sum db + wal + shm file sizes -- this is what the user's disk holds. + total = 0 + for suffix in ("", "-wal", "-shm"): + p = db.with_name(db.name + suffix) + if p.exists(): + total += p.stat().st_size + # Without compaction, this balloons to >>200 KB. With compaction, it + # should stay within a small multiple of the cap. 
+ assert total < cap * 5, f"on-disk size {total} exceeds 5x the cap ({cap})" + + +@needs_sqlite3 +def test_sqlite_cache_is_thread_safe(tmp_path): + """Concurrent get/set/delete from multiple threads on one cache object + must not raise sqlite3 OperationalError / ProgrammingError from + interleaved connection use.""" + import threading + + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + errors: list[BaseException] = [] + stop = threading.Event() + + def writer(thread_id: int): + try: + for i in range(200): + if stop.is_set(): + return + cache[f"t{thread_id}-{i}".encode()] = _fake_object_code(b"v" * 256, name=f"n{thread_id}-{i}") + except BaseException as exc: + errors.append(exc) + stop.set() + + def reader(thread_id: int): + try: + for i in range(200): + if stop.is_set(): + return + cache.get(f"t{thread_id}-{i}".encode()) + except BaseException as exc: + errors.append(exc) + stop.set() + + with SQLiteProgramCache(db, max_size_bytes=50_000) as cache: + threads = [threading.Thread(target=writer, args=(i,)) for i in range(4)] + threads += [threading.Thread(target=reader, args=(i,)) for i in range(4)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=30) + # On slow runners (Windows CI in particular) the loop above can return + # before every worker has finished its 200 iterations. Signal stop and + # join again so the ``with`` block does NOT tear down the cache while + # workers are still mid-operation -- that would surface as spurious + # "SQLiteProgramCache is closed" errors that mask the real assertion. + stop.set() + for t in threads: + t.join(timeout=30) + assert not any(t.is_alive() for t in threads) + assert not errors + + +@needs_sqlite3 +def test_sqlite_cache_does_not_nuke_on_operational_error(tmp_path, monkeypatch): + """A transient OperationalError (e.g. 
``database is locked``) under + multi-process sharing must NOT cause _open to delete the file, AND + it must close any connection that ``connect()`` already returned + before re-raising (otherwise file descriptors / locks leak).""" + import sqlite3 + + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code(b"keep-me") + db_size_before = db.stat().st_size + + # Realistic case: connect() succeeds, then a PRAGMA raises. We wrap the + # connection so we can prove _open closed it on the OperationalError path. + real_init = SQLiteProgramCache._connect_and_init + real_connect = sqlite3.connect + closed = [] + + class _TrackedConn: + def __init__(self, inner): + self._inner = inner + self.closed = False + + def close(self): + self.closed = True + closed.append(self) + return self._inner.close() + + def __getattr__(self, name): + return getattr(self._inner, name) + + def _connect_then_lock(self): + self._conn = _TrackedConn(real_connect(self._path, isolation_level=None)) + raise sqlite3.OperationalError("database is locked") + + monkeypatch.setattr(SQLiteProgramCache, "_connect_and_init", _connect_then_lock) + with pytest.raises(sqlite3.OperationalError, match="locked"): + SQLiteProgramCache(db) + # _open's OperationalError path must close the partial connection. + assert len(closed) == 1 and closed[0].closed, "OperationalError path leaked the partial connection" + # File untouched. + assert db.exists() + assert db.stat().st_size == db_size_before + monkeypatch.setattr(SQLiteProgramCache, "_connect_and_init", real_init) + with SQLiteProgramCache(db) as cache: + assert bytes(cache[b"k"].code) == b"keep-me" + + +@needs_sqlite3 +def test_sqlite_cache_recovers_from_non_sqlite_file(tmp_path): + """A cache directory could already contain a damaged / non-SQLite file + at the cache path (leftover from a crash, wrong filetype, filesystem + glitch). 
SQLiteProgramCache must degrade to an empty cache instead of + raising at construction time.""" + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + # Write garbage that is definitely not a valid SQLite database. + db.write_bytes(b"\x00not-a-sqlite-file-at-all\x00" * 100) + + with SQLiteProgramCache(db) as cache: + # Recovered to a fresh empty cache; must function normally. + assert len(cache) == 0 + cache[b"k"] = _fake_object_code(b"v") + assert bytes(cache[b"k"].code) == b"v" + + +@needs_sqlite3 +def test_sqlite_cache_wipes_on_schema_mismatch(tmp_path): + """A cache written with an older schema version must not silently mix + with a newer client; entries are wiped on open.""" + import sqlite3 + + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code(b"old-payload") + # Simulate an older schema by rewriting the stored version. + with sqlite3.connect(db) as conn: + conn.execute( + "UPDATE schema_meta SET value = ? WHERE key = ?", + ("0", "schema_version"), + ) + conn.commit() + + with SQLiteProgramCache(db) as cache: + assert len(cache) == 0 + assert b"k" not in cache + + +@needs_sqlite3 +def test_sqlite_cache_drops_tables_on_schema_mismatch(tmp_path): + """On a schema mismatch the cache must DROP the old tables, not just + DELETE rows -- otherwise a future structural migration (added columns, + renamed indexes) would leave the old layout in place.""" + import sqlite3 + + from cuda.core.utils import SQLiteProgramCache + + db = tmp_path / "cache.db" + with SQLiteProgramCache(db) as cache: + cache[b"k"] = _fake_object_code(b"v") + # Simulate an older schema with a divergent layout: add an extra column + # and flip the version marker so the next open triggers the migration + # path. + with sqlite3.connect(db) as conn: + conn.execute("ALTER TABLE entries ADD COLUMN legacy_field TEXT") + conn.execute("UPDATE schema_meta SET value = ? 
WHERE key = ?", ("0", "schema_version")) + conn.commit() + + with SQLiteProgramCache(db) as cache: + assert len(cache) == 0 + # Verify the tables were actually dropped+recreated: no legacy_field. + cols = {row[1] for row in cache._conn.execute("PRAGMA table_info(entries)").fetchall()} + assert "legacy_field" not in cols + # And the cache still functions end-to-end. + cache[b"new"] = _fake_object_code(b"new-v") + assert bytes(cache[b"new"].code) == b"new-v" + + +# --------------------------------------------------------------------------- +# FileStreamProgramCache -- single-process CRUD +# --------------------------------------------------------------------------- + + +def test_filestream_cache_empty_on_create(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache: + assert len(cache) == 0 + assert b"nope" not in cache + with pytest.raises(KeyError): + cache[b"nope"] + + +def test_filestream_cache_roundtrip(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache: + cache[b"k1"] = _fake_object_code(b"v1", name="x") + assert b"k1" in cache + got = cache[b"k1"] + assert bytes(got.code) == b"v1" + assert got.name == "x" + assert got.code_type == "cubin" + + +def test_filestream_cache_delete(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache: + cache[b"k"] = _fake_object_code() + del cache[b"k"] + assert b"k" not in cache + with pytest.raises(KeyError): + del cache[b"k"] + + +def test_filestream_cache_len_counts_all(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache: + cache[b"a"] = _fake_object_code(b"1") + cache[b"b"] = _fake_object_code(b"2") + cache[b"c"] = _fake_object_code(b"3") + assert len(cache) == 3 + + +def test_filestream_cache_clear(tmp_path): + from cuda.core.utils import 
FileStreamProgramCache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"a"] = _fake_object_code() + cache.clear() + assert len(cache) == 0 + + +def test_filestream_cache_persists_across_reopen(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"k"] = _fake_object_code(b"persisted") + with FileStreamProgramCache(root) as cache: + assert bytes(cache[b"k"].code) == b"persisted" + + +def test_filestream_cache_permission_error_propagates_on_posix(tmp_path, monkeypatch): + """On non-Windows, PermissionError from os.replace is a real config error + and must not be silently swallowed.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache, _program_cache + + monkeypatch.setattr(_program_cache, "_IS_WINDOWS", False) + + with FileStreamProgramCache(tmp_path / "fc") as cache: + + def _denied(src, dst): + raise PermissionError("denied") + + monkeypatch.setattr(_os, "replace", _denied) + with pytest.raises(PermissionError, match="denied"): + cache[b"k"] = _fake_object_code(b"v") + + +def test_filestream_cache_write_phase_permission_error_propagates_on_windows(tmp_path, monkeypatch): + """Even on Windows, a PermissionError from the write phase (mkstemp / + fdopen / fsync) is a real config problem -- the Windows carve-out is + only for the os.replace race. 
A write-phase error must propagate.""" + from cuda.core.utils import FileStreamProgramCache, _program_cache + + monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True) + + def _denied(*args, **kwargs): + raise PermissionError("mkstemp denied") + + monkeypatch.setattr(_program_cache.tempfile, "mkstemp", _denied) + + with FileStreamProgramCache(tmp_path / "fc") as cache, pytest.raises(PermissionError, match="mkstemp"): + cache[b"k"] = _fake_object_code(b"v") + + +@pytest.mark.parametrize( + "winerror, should_raise", + [ + pytest.param(5, False, id="access_denied_swallowed"), + pytest.param(32, False, id="sharing_violation_swallowed"), + pytest.param(33, False, id="lock_violation_swallowed"), + pytest.param(1, True, id="other_winerror_propagates"), + pytest.param(None, True, id="no_winerror_propagates"), + ], +) +def test_filestream_cache_permission_error_windows_is_narrowed(tmp_path, monkeypatch, winerror, should_raise): + """On Windows, ERROR_ACCESS_DENIED (5), ERROR_SHARING_VIOLATION (32) and + ERROR_LOCK_VIOLATION (33) are all transient "target held open by another + process / pending delete" cases worth swallowing after the bounded retry. + Any other PermissionError -- unrelated winerrors, missing winerror + attribute, etc. -- is a real problem and must propagate.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache, _program_cache + + monkeypatch.setattr(_program_cache, "_IS_WINDOWS", True) + + def _denied(src, dst): + exc = PermissionError("simulated") + exc.winerror = winerror + raise exc + + with FileStreamProgramCache(tmp_path / "fc") as cache: + monkeypatch.setattr(_os, "replace", _denied) + if should_raise: + with pytest.raises(PermissionError, match="simulated"): + cache[b"k"] = _fake_object_code(b"v") + else: + cache[b"k"] = _fake_object_code(b"v") # swallowed + assert b"k" not in cache + + +def test_filestream_cache_atomic_no_half_written_file(tmp_path, monkeypatch): + # Simulate a crash during write: patch os.replace to raise. 
+ import os as _os + + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache: + + def _boom(src, dst): + raise RuntimeError("crash during replace") + + monkeypatch.setattr(_os, "replace", _boom) + with pytest.raises(RuntimeError, match="crash"): + cache[b"k"] = _fake_object_code(b"v") + monkeypatch.undo() + assert b"k" not in cache + + +def test_filestream_cache_prune_only_if_stat_unchanged(tmp_path): + """The reader-unlink-vs-writer-replace race: if a concurrent writer + atomically replaced a file between the reader's read and the reader's + prune, the pruner must NOT delete the replacement.""" + from cuda.core.utils import FileStreamProgramCache + from cuda.core.utils._program_cache import _prune_if_stat_unchanged + + with FileStreamProgramCache(tmp_path / "fc") as cache: + cache[b"k"] = _fake_object_code(b"v1") + path = cache._path_for_key(b"k") + stale_stat = path.stat() + # Simulate a concurrent writer replacing the file. + time.sleep(0.02) + cache[b"k"] = _fake_object_code(b"v2") + + # Reader decides to prune using the stale stat; the guard refuses. + _prune_if_stat_unchanged(path, stale_stat) + assert path.exists() + + # With a fresh stat matching the current file, pruning proceeds. + _prune_if_stat_unchanged(path, path.stat()) + assert not path.exists() + + +def test_filestream_cache_corruption_is_reported_as_miss(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"k"] = _fake_object_code(b"ok") + path = cache._path_for_key(b"k") + + # Corrupt the file on disk. + path.write_bytes(b"\x00not-a-pickle") + with FileStreamProgramCache(root) as cache: + # ``in`` must not report True for a corrupt file; otherwise callers + # might skip recompilation based on a bogus membership check. 
+ assert b"k" not in cache + with pytest.raises(KeyError): + cache[b"k"] + assert b"k" not in cache # still absent after the prune + + +def test_filestream_cache_rejects_non_object_code(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache, pytest.raises(TypeError, match="ObjectCode"): + cache[b"k"] = b"not an ObjectCode" + + +def test_filestream_cache_rejects_path_backed_object_code(tmp_path): + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache + + path_backed = ObjectCode.from_cubin(str(tmp_path / "nonexistent.cubin"), name="x") + with FileStreamProgramCache(tmp_path / "fc") as cache, pytest.raises(TypeError, match="path-backed"): + cache[b"k"] = path_backed + + +def test_filestream_cache_rejects_negative_size_cap(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with pytest.raises(ValueError, match="non-negative"): + FileStreamProgramCache(tmp_path / "fc", max_size_bytes=-1) + + +def test_filestream_cache_sweeps_stale_tmp_files_on_open(tmp_path): + """A crashed writer can leave files in ``tmp/``; the next ``open`` must + sweep ones older than the staleness threshold so disk usage doesn't + grow without bound.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache, _program_cache + + root = tmp_path / "fc" + # Create the cache directory layout, then plant two temp files: one + # young (must be preserved as it could be an in-flight write) and one + # ancient (must be swept). + with FileStreamProgramCache(root): + pass + young = root / "tmp" / "entry-young" + young.write_bytes(b"in-flight") + ancient = root / "tmp" / "entry-ancient" + ancient.write_bytes(b"crashed-writer-leftover") + ancient_mtime = time.time() - _program_cache._TMP_STALE_AGE_SECONDS - 60 + _os.utime(ancient, (ancient_mtime, ancient_mtime)) + + with FileStreamProgramCache(root): + # Reopen triggers _sweep_stale_tmp_files. 
+ assert young.exists(), "young temp file must not be swept" + assert not ancient.exists(), "ancient temp file should have been swept" + + +def test_filestream_cache_clear_preserves_young_tmp_files(tmp_path): + """clear() must not delete young temp files: another process could be + mid-write between ``mkstemp`` and ``os.replace``, and unlinking under + it turns the writer's harmless rename into ``FileNotFoundError``. + Stale temps (older than the threshold) are still swept.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache, _program_cache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"k"] = _fake_object_code(b"v") + young_tmp = root / "tmp" / "entry-young" + young_tmp.write_bytes(b"in-flight") + ancient_tmp = root / "tmp" / "entry-ancient" + ancient_tmp.write_bytes(b"crashed") + ancient_mtime = time.time() - _program_cache._TMP_STALE_AGE_SECONDS - 60 + _os.utime(ancient_tmp, (ancient_mtime, ancient_mtime)) + + with FileStreamProgramCache(root) as cache: + cache.clear() + # Committed entry is gone, ancient orphan is gone, young temp survives. + # Filenames are hash-like (no extension), so use a file filter rather + # than a "*.*" glob. + remaining_entries = [p for p in (root / "entries").rglob("*") if p.is_file()] + assert not remaining_entries + assert young_tmp.exists() + assert not ancient_tmp.exists() + + +def test_filestream_cache_clear_does_not_unlink_replaced_file(tmp_path): + """``clear()``'s scan-then-unlink loop must use the stat-guard so a + concurrent writer's ``os.replace`` between snapshot and unlink doesn't + delete the fresh entry. 
Race injection: subclass the cache and have + ``_iter_entry_paths``'s post-yield cleanup os.replace path_a, then call + ``clear()`` and verify the fresh contents survive.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"a"] = _fake_object_code(b"A" * 200, name="a") + cache[b"b"] = _fake_object_code(b"B" * 200, name="b") + path_a = cache._path_for_key(b"a") + + class _RaceCache(FileStreamProgramCache): + race_armed = True + + def _iter_entry_paths(self): + yield from super()._iter_entry_paths() + # Generator cleanup runs at StopIteration, between clear()'s + # scan and its unlink loop. + if _RaceCache.race_armed and path_a.exists(): + _RaceCache.race_armed = False + tmp = path_a.parent / "_inflight" + tmp.write_bytes(b"\x80\x05fresh-by-other-writer-" * 32) + _os.replace(tmp, path_a) + + with _RaceCache(root) as cache: + cache.clear() + + # The fresh file must survive: clear() saw a stat mismatch and skipped. + assert path_a.exists(), "stat guard failed -- clear() unlinked a concurrently-replaced file" + assert path_a.read_bytes().startswith(b"\x80\x05fresh-by-other-writer-") + + +def test_filestream_cache_clear_does_not_break_concurrent_writer(tmp_path): + """Simulate a writer that has already produced a temp file but has not + yet executed ``os.replace``; a concurrent ``clear()`` from another + cache instance must NOT unlink that temp, so the writer's + ``os.replace`` still succeeds.""" + import os as _os + + from cuda.core.utils import FileStreamProgramCache + + root = tmp_path / "fc" + with FileStreamProgramCache(root) as cache: + cache[b"seed"] = _fake_object_code(b"seed") + + # Stage a temp file that mimics an in-flight write. + inflight_tmp = root / "tmp" / "entry-inflight" + inflight_tmp.write_bytes(b"\x80\x05fake-pickle") # contents do not matter + + # Concurrent clear() from another cache handle. 
+ with FileStreamProgramCache(root) as other: + other.clear() + + # The writer can now finish: rename the staged file into entries/. + target = root / "entries" / "ab" / "cdef" + target.parent.mkdir(parents=True, exist_ok=True) + _os.replace(inflight_tmp, target) + assert target.exists() + + +def test_filestream_cache_size_cap_does_not_unlink_replaced_file(tmp_path): + """The PRODUCTION ``_enforce_size_cap`` must compare the snapshot stat + to the current stat before unlinking; if the file was replaced under + us (a concurrent writer's ``os.replace``), the unlink is skipped. + + Race injection without reimplementing the method: subclass the cache + and override only ``_iter_entry_paths`` so that the cleanup code + *after* the generator's last yield runs an ``os.replace`` on path_a. + Python's for-loop calls ``next()`` until ``StopIteration``; the + generator code after its last yield runs at that ``StopIteration``, + which is exactly between ``_enforce_size_cap``'s scan loop and its + eviction loop. Eviction's per-entry re-stat then sees a different + stat for path_a and the production code's stat-guard must skip it. + """ + import os as _os + + from cuda.core.utils import FileStreamProgramCache + + # Cap fits two entries (each ~2123 bytes on disk for a 2000-byte + # payload, including pickle overhead) but not three. + cap = 5000 + root = tmp_path / "fc" + with FileStreamProgramCache(root, max_size_bytes=cap) as cache: + cache[b"a"] = _fake_object_code(b"A" * 2000, name="a") + time.sleep(0.02) + cache[b"b"] = _fake_object_code(b"B" * 2000, name="b") + path_a = cache._path_for_key(b"a") + assert path_a.exists(), "cap too small -- 'a' was evicted before the test ran" + + class _RaceCache(FileStreamProgramCache): + race_armed = True + + def _iter_entry_paths(self): + yield from super()._iter_entry_paths() + # Generator cleanup runs at StopIteration, between + # _enforce_size_cap's scan and its eviction loop. Fire the race + # here exactly once. 
+ if _RaceCache.race_armed and path_a.exists(): + _RaceCache.race_armed = False + tmp = path_a.parent / "_inflight" + tmp.write_bytes(b"\x80\x05fresh-by-other-writer-" * 32) + _os.replace(tmp, path_a) + + with _RaceCache(root, max_size_bytes=cap) as cache: + # Trigger eviction by adding 'c'; eviction's scan exhausts our + # racing generator, the cleanup fires, then the eviction loop's + # re-stat sees the new stat and the production stat-guard MUST + # refuse to unlink path_a. + time.sleep(0.02) + cache[b"c"] = _fake_object_code(b"C" * 2000, name="c") + + # The race-injected fresh file must survive: production stat-guard worked. + assert path_a.exists(), "stat guard failed -- evicted a concurrently-replaced file" + assert path_a.read_bytes().startswith(b"\x80\x05fresh-by-other-writer-") + + +def test_filestream_cache_size_cap_counts_tmp_files(tmp_path): + """Surviving temp files occupy disk too; the soft cap must include + them, otherwise an attacker (or a flurry of crashed writers) could + inflate disk usage well past max_size_bytes.""" + from cuda.core.utils import FileStreamProgramCache + + cap = 4000 + root = tmp_path / "fc" + with FileStreamProgramCache(root, max_size_bytes=cap) as cache: + cache[b"a"] = _fake_object_code(b"A" * 1500, name="a") + time.sleep(0.02) + cache[b"b"] = _fake_object_code(b"B" * 1500, name="b") + # Plant a young temp file that pushes total over the cap. + young_tmp = root / "tmp" / "entry-leftover" + young_tmp.write_bytes(b"X" * 2500) + + with FileStreamProgramCache(root, max_size_bytes=cap) as cache: + # New write triggers _enforce_size_cap; 'a' must be evicted because + # the temp file's bytes count toward the cap now. + time.sleep(0.02) + cache[b"c"] = _fake_object_code(b"C" * 200, name="c") + assert b"a" not in cache + assert b"c" in cache + + +def test_filestream_cache_handles_long_keys(tmp_path): + """Arbitrary-length keys must not overflow per-component filename limits. 
+ The filename is a fixed-length hash; the original key is verified from + the pickled record.""" + from cuda.core.utils import FileStreamProgramCache + + long_bytes_key = b"x" * 4096 + long_str_key = "y" * 4096 + with FileStreamProgramCache(tmp_path / "fc") as cache: + cache[long_bytes_key] = _fake_object_code(b"b", name="nb") + cache[long_str_key] = _fake_object_code(b"s", name="ns") + assert long_bytes_key in cache + assert long_str_key in cache + assert bytes(cache[long_bytes_key].code) == b"b" + assert bytes(cache[long_str_key].code) == b"s" + + +def test_filestream_cache_accepts_str_keys(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(tmp_path / "fc") as cache: + cache["my-key"] = _fake_object_code(b"v") + assert "my-key" in cache + assert b"my-key" in cache + + +def test_filestream_cache_size_cap_evicts_oldest(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + # Big payloads, small cap; after the third entry, the cap is exceeded and + # the oldest file (a) must be evicted. + cap = 3000 + root = tmp_path / "fc" + with FileStreamProgramCache(root, max_size_bytes=cap) as cache: + cache[b"a"] = _fake_object_code(b"A" * 2000, name="a") + time.sleep(0.02) + cache[b"b"] = _fake_object_code(b"B" * 2000, name="b") + time.sleep(0.02) + cache[b"c"] = _fake_object_code(b"C" * 2000, name="c") + + with FileStreamProgramCache(root, max_size_bytes=cap) as cache: + # 'a' was oldest by mtime; the opportunistic sweep on the third write + # must have removed it. 'c' is the newest and must remain. 
    # Tail of test_filestream_cache_size_cap_evicts_oldest (def begins
    # above): after reopening, the oldest entry is gone and the newest
    # remains.
    assert b"a" not in cache
    assert b"c" in cache


def test_filestream_cache_unbounded_by_default(tmp_path):
    """With no explicit ``max_size_bytes``, no eviction fires: all 20
    inserts (~20 KiB of payload) are retained."""
    from cuda.core.utils import FileStreamProgramCache

    with FileStreamProgramCache(tmp_path / "fc") as cache:
        for i in range(20):
            cache[f"k{i}".encode()] = _fake_object_code(b"X" * 1024, name=f"n{i}")
        assert len(cache) == 20


def test_filestream_cache_wipes_on_schema_mismatch(tmp_path):
    """A cache written with an older schema must be wiped on open, not
    silently mixed with a newer format."""
    from cuda.core.utils import FileStreamProgramCache

    root = tmp_path / "fc"
    with FileStreamProgramCache(root) as cache:
        cache[b"k"] = _fake_object_code(b"old-payload")
    # Simulate an older schema by rewriting the version marker.
    (root / "SCHEMA_VERSION").write_text("0")

    with FileStreamProgramCache(root) as cache:
        assert len(cache) == 0
        assert b"k" not in cache
        # Marker should be back at the current version.
        assert (root / "SCHEMA_VERSION").read_text().strip() != "0"


# ---------------------------------------------------------------------------
# End-to-end: real NVRTC compilation through persistent cache
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    "backend",
    [
        # The sqlite variant carries the same gate used by the sqlite-only
        # tests in this file; the filestream backend needs no such gate.
        pytest.param("sqlite", marks=needs_sqlite3),
        pytest.param("filestream"),
    ],
)
def test_cache_roundtrip_with_real_compilation(backend, tmp_path, init_cuda):
    """Compile a real kernel, persist it, reopen the cache, and reload the kernel.

    Exercises the full user workflow: NVRTC compile → persistent store → fresh
    process (simulated by closing and reopening the cache handle) → driver-side
    module load from the deserialised ``ObjectCode``.
+ """ + from cuda.core import Program, ProgramOptions + from cuda.core._module import Kernel + from cuda.core.utils import ( + FileStreamProgramCache, + SQLiteProgramCache, + make_program_cache_key, + ) + + code = 'extern "C" __global__ void my_kernel() {}' + code_type = "c++" + target_type = "cubin" + options = ProgramOptions(name="cached_kernel") + + program = Program(code, code_type, options=options) + try: + compiled = program.compile(target_type) + finally: + program.close() + # Do NOT call get_kernel() on ``compiled`` here: loading the module on the + # driver would mutate driver-side state on the original ObjectCode and + # weaken the roundtrip guarantee below. + + key = make_program_cache_key( + code=code, + code_type=code_type, + options=options, + target_type=target_type, + ) + + if backend == "sqlite": + path = tmp_path / "cache.db" + cache_cls = SQLiteProgramCache + else: + path = tmp_path / "fc" + cache_cls = FileStreamProgramCache + + # First "process": compile and store. + with cache_cls(path) as cache: + assert key not in cache + cache[key] = compiled + + # Second "process": reopen a fresh handle and retrieve. + with cache_cls(path) as cache: + assert key in cache + cached = cache[key] + + assert cached.code_type == target_type + assert cached.name == "cached_kernel" + assert bytes(cached.code) == bytes(compiled.code) + # The deserialised ObjectCode must still be usable against the driver. + assert isinstance(cached.get_kernel("my_kernel"), Kernel) diff --git a/cuda_core/tests/test_program_cache_multiprocess.py b/cuda_core/tests/test_program_cache_multiprocess.py new file mode 100644 index 0000000000..954440e79c --- /dev/null +++ b/cuda_core/tests/test_program_cache_multiprocess.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +"""Multiprocess stress tests for FileStreamProgramCache. + +These run without a GPU. 
They exercise the atomic-rename write path from +multiple processes launched via ``multiprocessing.get_context("spawn")``. +""" + +from __future__ import annotations + +import multiprocessing as _mp + + +def _worker_write(root: str, key: bytes, payload: bytes, name: str) -> None: + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(root) as cache: + cache[key] = ObjectCode._init(payload, "cubin", name=name) + + +def _worker_write_many(root: str, base: int, n: int) -> None: + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache + + with FileStreamProgramCache(root) as cache: + for i in range(n): + key = f"proc-{base}-key-{i}".encode() + cache[key] = ObjectCode._init(f"payload-{base}-{i}".encode(), "cubin", name=f"p{base}-{i}") + + +def _worker_reader(root: str, key: bytes, rounds: int, result_queue) -> None: + from cuda.core.utils import FileStreamProgramCache + + hits = 0 + for _ in range(rounds): + with FileStreamProgramCache(root) as cache: + got = cache.get(key) + if got is not None: + hits += 1 + result_queue.put(hits) + + +def test_concurrent_writers_same_key_no_corruption(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + root = str(tmp_path / "fc") + ctx = _mp.get_context("spawn") + procs = [ + ctx.Process( + target=_worker_write, + args=(root, b"shared", f"v{i}".encode() * 64, f"p{i}"), + ) + for i in range(6) + ] + for p in procs: + p.start() + for p in procs: + p.join(timeout=60) + assert p.exitcode == 0, f"worker exited with {p.exitcode}" + + with FileStreamProgramCache(root) as cache: + # At least one writer must have succeeded; on Windows some writes + # may silently fail due to PermissionError on os.replace. 
+ got = cache.get(b"shared") + assert got is not None, "no writer succeeded" + assert bytes(got.code).startswith(b"v") + + +def test_concurrent_writers_distinct_keys_all_survive(tmp_path): + from cuda.core.utils import FileStreamProgramCache + + root = str(tmp_path / "fc") + n_procs = 4 + per_proc = 25 + ctx = _mp.get_context("spawn") + procs = [ctx.Process(target=_worker_write_many, args=(root, base, per_proc)) for base in range(n_procs)] + for p in procs: + p.start() + for p in procs: + p.join(timeout=60) + assert p.exitcode == 0 + + with FileStreamProgramCache(root) as cache: + for base in range(n_procs): + for i in range(per_proc): + key = f"proc-{base}-key-{i}".encode() + assert key in cache + + +def test_concurrent_reader_never_sees_torn_file(tmp_path): + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache + + root = str(tmp_path / "fc") + # Seed 'k' so the reader can hit; the writer writes unrelated keys so 'k' + # is never overwritten while the reader is active. + with FileStreamProgramCache(root) as cache: + cache[b"k"] = ObjectCode._init(b"seed" * 256, "cubin", name="seed") + + ctx = _mp.get_context("spawn") + queue = ctx.Queue() + writer = ctx.Process(target=_worker_write_many, args=(root, 99, 50)) + reader = ctx.Process(target=_worker_reader, args=(root, b"k", 200, queue)) + reader.start() + writer.start() + writer.join(timeout=60) + reader.join(timeout=60) + assert writer.exitcode == 0 + assert reader.exitcode == 0 + hits = queue.get(timeout=5) + # 'k' was never overwritten, so every read must hit. + assert hits == 200 + + +def _worker_race_rewriter(root: str, key: bytes, start_event, done_event) -> None: + """Repeatedly overwrite ``key`` with fresh valid records until signalled. + + Paired with :func:`_worker_race_corrupt_reader` below: the writer keeps + atomically replacing the file while the reader keeps loading a + deliberately-corrupt copy and deciding whether to prune it. 
The reader + must never delete the writer's valid record. + + After ``done_event`` fires (set by the caller only after the reader has + joined), the rewriter lands one final uncontested write so the test's + end-state assertion does not hinge on scheduler-dependent interleaving + of the last write vs. the reader's last prune. + """ + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache + + start_event.wait() + with FileStreamProgramCache(root) as cache: + i = 0 + while not done_event.is_set(): + cache[key] = ObjectCode._init(f"good-{i}".encode() * 64, "cubin", name=f"g{i}") + i += 1 + cache[key] = ObjectCode._init(b"good-final" * 64, "cubin", name="g-final") + + +def _worker_race_corrupt_reader(root: str, key: bytes, rounds: int, start_event, result_queue) -> None: + """Repeatedly corrupt then read ``key``, exercising the prune path. + + Each round writes garbage straight to the on-disk file (bypassing + ``os.replace``) to trigger the corrupt-read code path, then calls + ``cache[key]`` which will fail pickle and attempt to prune. The write + race with the rewriter process means some prunes will target a file the + rewriter has already replaced; if the prune guard is missing, a valid + entry gets deleted and ``cache.get(key)`` starts returning ``None``. + """ + import contextlib + + from cuda.core.utils import FileStreamProgramCache + + start_event.set() + missing = 0 + with FileStreamProgramCache(root) as cache: + path = cache._path_for_key(key) + for _ in range(rounds): + # Best-effort corrupt write; ignore errors from the writer + # having the file replaced mid-write. On Windows the rewriter's + # in-flight ``os.replace`` can also surface to the reader's open + # as ``PermissionError`` (sharing violation / pending-delete). 
+ with contextlib.suppress(FileNotFoundError, PermissionError), open(path, "wb") as fh: + fh.write(b"\x00not-a-pickle") + with contextlib.suppress(KeyError): + cache[key] + if cache.get(key) is None: + missing += 1 + result_queue.put(missing) + + +def test_concurrent_prune_does_not_delete_replaced_file(tmp_path): + """A reader that decides a file is corrupt must not unlink a fresh + valid file the writer has already swapped in via ``os.replace``. + + The reader and rewriter fight over the same key: the reader keeps + corrupting the file and then prune-checking, while the rewriter keeps + atomically replacing it with a valid record. With the stat-guard, the + rewriter's entry wins: after both processes finish, the key must be + present. + """ + from cuda.core._module import ObjectCode + from cuda.core.utils import FileStreamProgramCache + + root = str(tmp_path / "fc") + with FileStreamProgramCache(root) as cache: + cache[b"k"] = ObjectCode._init(b"init" * 64, "cubin", name="init") + + ctx = _mp.get_context("spawn") + start_event = ctx.Event() + done_event = ctx.Event() + result_queue = ctx.Queue() + rewriter = ctx.Process(target=_worker_race_rewriter, args=(root, b"k", start_event, done_event)) + reader = ctx.Process(target=_worker_race_corrupt_reader, args=(root, b"k", 80, start_event, result_queue)) + rewriter.start() + reader.start() + reader.join(timeout=60) + done_event.set() + rewriter.join(timeout=60) + assert reader.exitcode == 0 + assert rewriter.exitcode == 0 + _ = result_queue.get(timeout=5) # best-effort counter, not asserted + + # The rewriter was writing right up until done_event; its last write must + # survive the reader's pruning attempts. + with FileStreamProgramCache(root) as cache: + got = cache.get(b"k") + assert got is not None, "rewriter's valid entry was pruned by racing reader" + assert bytes(got.code).startswith(b"good-")