diff --git a/vulnerabilities/importer.py b/vulnerabilities/importer.py index 59a808c43..abd6f90c5 100644 --- a/vulnerabilities/importer.py +++ b/vulnerabilities/importer.py @@ -57,6 +57,16 @@ class VulnerabilitySeverity: published_at: Optional[datetime.datetime] = None url: Optional[str] = None + def __post_init__(self): + if not self.system: + raise ValueError("system is required for VulnerabilitySeverity") + + if not isinstance(self.system, ScoringSystem): + raise TypeError(f"system must be a ScoringSystem, got {type(self.system)!r}") + + if not isinstance(self.value, str): + self.value = str(self.value) + def to_dict(self): data = { "system": self.system.identifier, @@ -469,6 +479,42 @@ def __post_init__(self): "an affected version range, introduced commit patches, or fixed commit patches." ) + if self.affected_version_range is not None and not isinstance( + self.affected_version_range, VersionRange + ): + raise TypeError( + f"affected_version_range must be VersionRange or None, got {type(self.affected_version_range)!r}" + ) + + if self.fixed_version_range is not None and not isinstance( + self.fixed_version_range, VersionRange + ): + raise TypeError( + f"fixed_version_range must be VersionRange or None, got {type(self.fixed_version_range)!r}" + ) + + if not isinstance(self.introduced_by_commit_patches, list): + raise TypeError( + f"introduced_by_commit_patches must be a list, got {type(self.introduced_by_commit_patches)!r}" + ) + + if not isinstance(self.fixed_by_commit_patches, list): + raise TypeError( + f"fixed_by_commit_patches must be a list, got {type(self.fixed_by_commit_patches)!r}" + ) + + for item in self.introduced_by_commit_patches: + if not isinstance(item, PackageCommitPatchData): + raise TypeError( + f"introduced_by_commit_patches items must be PackageCommitPatchData, got {type(item)!r}" + ) + + for item in self.fixed_by_commit_patches: + if not isinstance(item, PackageCommitPatchData): + raise TypeError( + f"fixed_by_commit_patches items must be PackageCommitPatchData, got {type(item)!r}" + ) + def __lt__(self, other): if not isinstance(other, AffectedPackageV2): return NotImplemented @@ -648,6 +694,7 @@ def to_dict(self): def from_dict(cls, advisory_data): date_published = advisory_data["date_published"] transformed = { + "advisory_id": advisory_data["advisory_id"], "aliases": advisory_data["aliases"], "summary": advisory_data["summary"], "affected_packages": [ diff --git a/vulnerabilities/tests/pipelines/test_compute_advisory_todo_v2.py b/vulnerabilities/tests/pipelines/test_compute_advisory_todo_v2.py index 3c234db54..e55bf5f6a 100644 --- a/vulnerabilities/tests/pipelines/test_compute_advisory_todo_v2.py +++ b/vulnerabilities/tests/pipelines/test_compute_advisory_todo_v2.py @@ -11,6 +11,7 @@ from django.test import TestCase from packageurl import PackageURL +from univers.version_range import VersionRange from vulnerabilities.importer import AdvisoryDataV2 from vulnerabilities.importer import AffectedPackageV2 @@ -30,8 +31,8 @@ def setUp(self): affected_packages=[ AffectedPackageV2( package=PackageURL(type="npm", name="package1"), - affected_version_range="vers:npm/>=1.0.0|<2.0.0", - fixed_version_range="vers:npm/2.0.0", + affected_version_range=VersionRange.from_string("vers:npm/>=1.0.0|<2.0.0"), + fixed_version_range=VersionRange.from_string("vers:npm/2.0.0"), ) ], references=[ReferenceV2(url="https://example.com/vuln1")], @@ -44,7 +45,7 @@ def setUp(self): affected_packages=[ AffectedPackageV2( package=PackageURL(type="npm", name="package1"), - affected_version_range="vers:npm/>=1.0.0|<2.0.0", + affected_version_range=VersionRange.from_string("vers:npm/>=1.0.0|<2.0.0"), ) ], references=[ReferenceV2(url="https://example.com/vuln1")], @@ -57,7 +58,7 @@ def setUp(self): affected_packages=[ AffectedPackageV2( package=PackageURL(type="npm", name="package1"), - fixed_version_range="vers:npm/2.0.0", + fixed_version_range=VersionRange.from_string("vers:npm/2.0.0"), ) ], references=[ReferenceV2(url="https://example.com/vuln1")], @@ -70,8 +71,8 @@ def setUp(self): affected_packages=[ AffectedPackageV2( package=PackageURL(type="npm", name="package1"), - affected_version_range="vers:npm/>=1.0.0|<=2.0.0", - fixed_version_range="vers:npm/2.0.1", + affected_version_range=VersionRange.from_string("vers:npm/>=1.0.0|<=2.0.0"), + fixed_version_range=VersionRange.from_string("vers:npm/2.0.1"), ) ], references=[ReferenceV2(url="https://example.com/vuln1")], diff --git a/vulnerabilities/tests/pipelines/v2_importers/test_vulnrichment_importer_v2.py b/vulnerabilities/tests/pipelines/v2_importers/test_vulnrichment_importer_v2.py index f5c251e7f..45f45550f 100644 --- a/vulnerabilities/tests/pipelines/v2_importers/test_vulnrichment_importer_v2.py +++ b/vulnerabilities/tests/pipelines/v2_importers/test_vulnrichment_importer_v2.py @@ -196,7 +196,7 @@ def test_parse_cve_advisory(mock_pathlib, mock_vcs_response, mock_fetch_via_vcs) assert advisory.summary == "Sample PyPI vulnerability" assert advisory.url == advisory_url assert len(advisory.severities) == 1 - assert advisory.severities[0].value == 5.3 + assert advisory.severities[0].value == "5.3" def test_collect_advisories_with_invalid_json(mock_pathlib, mock_vcs_response, mock_fetch_via_vcs): diff --git a/vulnerabilities/tests/test_utils.py b/vulnerabilities/tests/test_utils.py index a100cbd74..40e92f239 100644 --- a/vulnerabilities/tests/test_utils.py +++ b/vulnerabilities/tests/test_utils.py @@ -254,3 +254,10 @@ def test_content_id_from_adv_data_and_adv_model_are_same(self): id_from_model = utils.compute_content_id_v2(advisory_model) self.assertEqual(id_from_data, id_from_model) + + def test_content_id_from_adv_data_roundtrip_are_same(self): + id_from_data = utils.compute_content_id_v2(self.advisory1) + adv_roundtrip = AdvisoryDataV2.from_dict(self.advisory1.to_dict()) + id_from_roundtrip = utils.compute_content_id_v2(adv_roundtrip) + + self.assertEqual(id_from_data, id_from_roundtrip)