From 9129306855513ddf0fb65b22403741f1f32c75f0 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 2 May 2025 17:13:55 -0700 Subject: [PATCH 01/16] Revert "remvoved unready stages and expressions" This reverts commit c049e2107e87c6a8c5a0a693136efde3992c8832. --- google/cloud/firestore_v1/base_pipeline.py | 48 ++ .../firestore_v1/pipeline_expressions.py | 212 +++++++++ google/cloud/firestore_v1/pipeline_stages.py | 16 + tests/system/pipeline_e2e.yaml | 424 +++++++++++++++++- 4 files changed, 699 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 75f7c3e8d..28b9ea582 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -296,6 +296,54 @@ def sort(self, *orders: stages.Ordering) -> Self: """ return self._append(stages.Sort(*orders)) + def replace( + self, + field: Selectable, + mode: stages.Replace.Mode = stages.Replace.Mode.FULL_REPLACE, + ) -> Self: + """ + Replaces the entire document content with the value of a specified field, + typically a map. + + This stage allows you to emit a map value as the new document structure. + Each key of the map becomes a field in the output document, containing the + corresponding value. + + Example: + Input document: + ```json + { + "name": "John Doe Jr.", + "parents": { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + } + ``` + + >>> from google.cloud.firestore_v1.pipeline_expressions import Field + >>> pipeline = client.collection("people").pipeline() + >>> # Emit the 'parents' map as the document + >>> pipeline = pipeline.replace(Field.of("parents")) + + Output document: + ```json + { + "father": "John Doe Sr.", + "mother": "Jane Doe" + } + ``` + + Args: + field: The `Selectable` field containing the map whose content will + replace the document. + mode: The replacement mode + + Returns: + A new Pipeline object with this stage appended to the stage list + """ + return self._append(stages.Replace(field, mode)) + def sample(self, limit_or_options: int | SampleOptions) -> Self: """ Performs a pseudo-random sampling of the documents from the previous stage. diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 1eda32713..02602570a 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -430,6 +430,21 @@ def not_in_any(self, array: List[Expr | CONSTANT_TYPE]) -> "Not": """ return Not(self.in_any(array)) + def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. + + Returns: + A new `Expr` representing the concatenated array. + """ + return ArrayConcat(self, [self._cast_to_expr_or_convert_to_constant(o) for o in array]) + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": """Creates an expression that checks if an array contains a specific element or value. @@ -744,6 +759,92 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] ) + def to_lower(self) -> "ToLower": + """Creates an expression that converts a string to lowercase. + + Example: + >>> # Convert the 'name' field to lowercase + >>> Field.of("name").to_lower() + + Returns: + A new `Expr` representing the lowercase string. + """ + return ToLower(self) + + def to_upper(self) -> "ToUpper": + """Creates an expression that converts a string to uppercase. + + Example: + >>> # Convert the 'title' field to uppercase + >>> Field.of("title").to_upper() + + Returns: + A new `Expr` representing the uppercase string. + """ + return ToUpper(self) + + def trim(self) -> "Trim": + """Creates an expression that removes leading and trailing whitespace from a string. + + Example: + >>> # Trim whitespace from the 'userInput' field + >>> Field.of("userInput").trim() + + Returns: + A new `Expr` representing the trimmed string. + """ + return Trim(self) + + def reverse(self) -> "Reverse": + """Creates an expression that reverses a string. + + Example: + >>> # Reverse the 'userInput' field + >>> Field.of("userInput").reverse() + + Returns: + A new `Expr` representing the reversed string. + """ + return Reverse(self) + + def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst": + """Creates an expression that replaces the first occurrence of a substring within a string with + another substring. + + Example: + >>> # Replace the first occurrence of "hello" with "hi" in the 'message' field + >>> Field.of("message").replace_first("hello", "hi") + >>> # Replace the first occurrence of the value in 'findField' with the value in 'replaceField' in the 'message' field + >>> Field.of("message").replace_first(Field.of("findField"), Field.of("replaceField")) + + Args: + find: The substring (string or expression) to search for. + replace: The substring (string or expression) to replace the first occurrence of 'find' with. + + Returns: + A new `Expr` representing the string with the first occurrence replaced. + """ + return ReplaceFirst(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) + + def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": + """Creates an expression that replaces all occurrences of a substring within a string with another + substring. + + Example: + >>> # Replace all occurrences of "hello" with "hi" in the 'message' field + >>> Field.of("message").replace_all("hello", "hi") + >>> # Replace all occurrences of the value in 'findField' with the value in 'replaceField' in the 'message' field + >>> Field.of("message").replace_all(Field.of("findField"), Field.of("replaceField")) + + Args: + find: The substring (string or expression) to search for. + replace: The substring (string or expression) to replace all occurrences of 'find' with. + + Returns: + A new `Expr` representing the string with all occurrences replaced. + """ + return ReplaceAll(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) + def map_get(self, key: str) -> "MapGet": """Accesses a value from a map (object) field using the provided key. @@ -760,6 +861,57 @@ def map_get(self, key: str) -> "MapGet": """ return MapGet(self, Constant.of(key)) + def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance": + """Calculates the cosine distance between two vectors. + + Example: + >>> # Calculate the cosine distance between the 'userVector' field and the 'itemVector' field + >>> Field.of("userVector").cosine_distance(Field.of("itemVector")) + >>> # Calculate the Cosine distance between the 'location' field and a target location + >>> Field.of("location").cosine_distance([37.7749, -122.4194]) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the cosine distance between the two vectors. + """ + return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "EuclideanDistance": + """Calculates the Euclidean distance between two vectors. + + Example: + >>> # Calculate the Euclidean distance between the 'location' field and a target location + >>> Field.of("location").euclidean_distance([37.7749, -122.4194]) + >>> # Calculate the Euclidean distance between two vector fields: 'pointA' and 'pointB' + >>> Field.of("pointA").euclidean_distance(Field.of("pointB")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to compare against. + + Returns: + A new `Expr` representing the Euclidean distance between the two vectors. + """ + return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + + def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct": + """Calculates the dot product between two vectors. + + Example: + >>> # Calculate the dot product between a feature vector and a target vector + >>> Field.of("features").dot_product([0.5, 0.8, 0.2]) + >>> # Calculate the dot product between two document vectors: 'docVector1' and 'docVector2' + >>> Field.of("docVector1").dot_product(Field.of("docVector2")) + + Args: + other: The other vector (represented as an Expr, list of floats, or Vector) to calculate dot product with. + + Returns: + A new `Expr` representing the dot product between the two vectors. + """ + return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) + def vector_length(self) -> "VectorLength": """Creates an expression that calculates the length (dimension) of a Firestore Vector. @@ -1005,6 +1157,18 @@ def __init__(self, left: Expr, right: Expr): super().__init__("divide", [left, right]) +class DotProduct(Function): + """Represents the vector dot product function.""" + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("dot_product", [vector1, vector2]) + + +class EuclideanDistance(Function): + """Represents the vector Euclidean distance function.""" + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("euclidean_distance", [vector1, vector2]) + + class LogicalMax(Function): """Represents the logical maximum function based on Firestore type ordering.""" @@ -1047,6 +1211,24 @@ def __init__(self, value: Expr): super().__init__("parent", [value]) +class ReplaceAll(Function): + """Represents replacing all occurrences of a substring.""" + def __init__(self, value: Expr, pattern: Expr, replacement: Expr): + super().__init__("replace_all", [value, pattern, replacement]) + + +class ReplaceFirst(Function): + """Represents replacing the first occurrence of a substring.""" + def __init__(self, value: Expr, pattern: Expr, replacement: Expr): + super().__init__("replace_first", [value, pattern, replacement]) + + +class Reverse(Function): + """Represents reversing a string.""" + def __init__(self, expr: Expr): + super().__init__("reverse", [expr]) + + class StrConcat(Function): """Represents concatenating multiple strings.""" @@ -1096,6 +1278,24 @@ def __init__(self, input: Expr): super().__init__("timestamp_to_unix_seconds", [input]) +class ToLower(Function): + """Represents converting a string to lowercase.""" + def __init__(self, value: Expr): + super().__init__("to_lower", [value]) + + +class ToUpper(Function): + """Represents converting a string to uppercase.""" + def __init__(self, value: Expr): + super().__init__("to_upper", [value]) + + +class Trim(Function): + """Represents trimming whitespace from a string.""" + def __init__(self, expr: Expr): + super().__init__("trim", [expr]) + + class UnixMicrosToTimestamp(Function): """Represents converting microseconds since epoch to a timestamp.""" @@ -1131,6 +1331,12 @@ def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) +class ArrayConcat(Function): + """Represents concatenating multiple arrays.""" + def __init__(self, array: Expr, rest: List[Expr]): + super().__init__("array_concat", [array] + rest) + + class ArrayElement(Function): """Represents accessing an element within an array""" @@ -1187,6 +1393,12 @@ def __init__(self, value: Expr): super().__init__("collection_id", [value]) +class CosineDistance(Function): + """Represents the vector cosine distance function.""" + def __init__(self, vector1: Expr, vector2: Expr): + super().__init__("cosine_distance", [vector1, vector2]) + + class Accumulator(Function): """A base class for aggregation functions that operate across multiple inputs.""" diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 686aaf2a0..9430474fe 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -308,6 +308,22 @@ def _pb_args(self) -> list[Value]: return [f._to_pb() for f in self.fields] +class Replace(Stage): + """Replaces the document content with the value of a specified field.""" + class Mode(Enum): + FULL_REPLACE = "full_replace" + MERGE_PREFER_NEXT = "merge_prefer_nest" + MERGE_PREFER_PARENT = "merge_prefer_parent" + + def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE): + super().__init__() + self.field = Field(field) if isinstance(field, str) else field + self.mode = self.Mode[mode] if isinstance(mode, str) else mode + + def _pb_args(self): + return [self.field._to_pb(), Value(string_value=self.mode.value)] + + class Sample(Stage): """Performs pseudo-random sampling of documents.""" diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index dc262f4a9..d92397347 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -214,6 +214,43 @@ tests: accumulators: [] groups: [genre] assert_error: ".* requires at least one accumulator" + - description: testDistinct + pipeline: + - Collection: books + - Where: + - Lt: + - Field: published + - Constant: 1900 + - Distinct: + - ExprWithAlias: + - ToLower: + - Field: genre + - "lower_genre" + assert_results: + - lower_genre: romance + - lower_genre: psychological thriller + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: published + - integerValue: '1900' + name: lt + name: where + - args: + - mapValue: + fields: + lower_genre: + functionValue: + args: + - fieldReferenceValue: genre + name: to_lower + name: distinct - description: testGroupBysAndAggregate pipeline: - Collection: books @@ -776,6 +813,44 @@ tests: - integerValue: '3' name: eq name: where + - description: testArrayConcat + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ArrayConcat: + - Field: tags + - - Constant: newTag1 + - Constant: newTag2 + - "modifiedTags" + - Limit: 1 + assert_results: + - modifiedTags: + - comedy + - space + - adventure + - newTag1 + - newTag2 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + modifiedTags: + functionValue: + args: + - fieldReferenceValue: tags + - stringValue: newTag1 + - stringValue: newTag2 + name: array_concat + name: select + - args: + - integerValue: '1' + name: limit - description: testStrConcat pipeline: - Collection: books @@ -967,6 +1042,122 @@ tests: expression: fieldReferenceValue: title name: sort + - description: testStringFunctions - Reverse + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - Reverse: + - Field: title + - "reversed_title" + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + assert_results: + - reversed_title: yxalaG ot ediug s'reknhiHcH ehT + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + reversed_title: + functionValue: + args: + - fieldReferenceValue: title + name: reverse + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - description: testStringFunctions - ReplaceFirst + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ReplaceFirst: + - Field: title + - Constant: The + - Constant: A + - "replaced_title" + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + assert_results: + - replaced_title: A Hitchhiker's Guide to the Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + replaced_title: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: The + - stringValue: A + name: replace_first + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where + - description: testStringFunctions - ReplaceAll + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ReplaceAll: + - Field: title + - Constant: " " + - Constant: "_" + - "replaced_title" + - Where: + - Eq: + - Field: author + - Constant: Douglas Adams + assert_results: + - replaced_title: The_Hitchhiker's_Guide_to_the_Galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + replaced_title: + functionValue: + args: + - fieldReferenceValue: title + - stringValue: ' ' + - stringValue: _ + name: replace_all + name: select + - args: + - functionValue: + args: + - fieldReferenceValue: author + - stringValue: Douglas Adams + name: eq + name: where - description: testStringFunctions - CharLength pipeline: - Collection: books @@ -1045,6 +1236,115 @@ tests: name: str_concat name: byte_length name: select + - description: testToLowercase + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ToLower: + - Field: title + - "lowercaseTitle" + - Limit: 1 + assert_results: + - lowercaseTitle: the hitchhiker's guide to the galaxy + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + lowercaseTitle: + functionValue: + args: + - fieldReferenceValue: title + name: to_lower + name: select + - args: + - integerValue: '1' + name: limit + - description: testToUppercase + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - ToUpper: + - Field: author + - "uppercaseAuthor" + - Limit: 1 + assert_results: + - uppercaseAuthor: DOUGLAS ADAMS + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + uppercaseAuthor: + functionValue: + args: + - fieldReferenceValue: author + name: to_upper + name: select + - args: + - integerValue: '1' + name: limit + - description: testTrim + pipeline: + - Collection: books + - AddFields: + - ExprWithAlias: + - StrConcat: + - Constant: " " + - Field: title + - Constant: " " + - "spacedTitle" + - Select: + - ExprWithAlias: + - Trim: + - Field: spacedTitle + - "trimmedTitle" + - spacedTitle + - Limit: 1 + assert_results: + - trimmedTitle: The Hitchhiker's Guide to the Galaxy + spacedTitle: " The Hitchhiker's Guide to the Galaxy " + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + spacedTitle: + functionValue: + args: + - stringValue: ' ' + - fieldReferenceValue: title + - stringValue: ' ' + name: str_concat + name: add_fields + - args: + - mapValue: + fields: + spacedTitle: + fieldReferenceValue: spacedTitle + trimmedTitle: + functionValue: + args: + - fieldReferenceValue: spacedTitle + name: trim + name: select + - args: + - integerValue: '1' + name: limit - description: testLike pipeline: - Collection: books @@ -1356,6 +1656,11 @@ tests: - IsNaN: - Field: rating - Select: + - ExprWithAlias: + - Eq: + - Field: rating + - Constant: null + - "ratingIsNull" - ExprWithAlias: - Not: - IsNaN: @@ -1363,7 +1668,8 @@ tests: - "ratingIsNotNaN" - Limit: 1 assert_results: - - ratingIsNotNaN: true + - ratingIsNull: false + ratingIsNotNaN: true assert_proto: pipeline: stages: @@ -1390,6 +1696,12 @@ tests: - fieldReferenceValue: rating name: is_nan name: not + ratingIsNull: + functionValue: + args: + - fieldReferenceValue: rating + - nullValue: null + name: eq name: select - args: - integerValue: '1' @@ -1500,6 +1812,79 @@ tests: - booleanValue: true name: eq name: where + - description: testDistanceFunctions + pipeline: + - Collection: books + - Select: + - ExprWithAlias: + - CosineDistance: + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] + - "cosineDistance" + - ExprWithAlias: + - DotProduct: + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] + - "dotProductDistance" + - ExprWithAlias: + - EuclideanDistance: + - Constant: [[0.1, 0.1]] + - Constant: [[0.5, 0.8]] + - "euclideanDistance" + - Limit: 1 + assert_results: + - cosineDistance: 0.02560880430538015 + dotProductDistance: 0.13 + euclideanDistance: 0.806225774829855 + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - mapValue: + fields: + cosineDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: cosine_distance + dotProductDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: dot_product + euclideanDistance: + functionValue: + args: + - arrayValue: + values: + - doubleValue: 0.1 + - doubleValue: 0.1 + - arrayValue: + values: + - doubleValue: 0.5 + - doubleValue: 0.8 + name: euclidean_distance + name: select + - args: + - integerValue: '1' + name: limit - description: testNestedFields pipeline: - Collection: books @@ -1548,6 +1933,43 @@ tests: title: fieldReferenceValue: title name: select + - description: testReplace + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Replace: awards + assert_results: + - title: The Hitchhiker's Guide to the Galaxy + author: Douglas Adams + genre: Science Fiction + published: 1979 + rating: 4.2 + tags: + - comedy + - space + - adventure + hugo: true + nebula: false + assert_proto: + pipeline: + stages: + - args: + - referenceValue: /books + name: collection + - args: + - functionValue: + args: + - fieldReferenceValue: title + - stringValue: The Hitchhiker's Guide to the Galaxy + name: eq + name: where + - args: + - fieldReferenceValue: awards + - stringValue: full_replace + name: replace - description: testSampleLimit pipeline: - Collection: books From ca381603dfa159a3f029a53a4cdd3312fa3bcedb Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 6 May 2025 15:08:32 -0700 Subject: [PATCH 02/16] fixed docstrings --- google/cloud/firestore_v1/base_pipeline.py | 2 +- google/cloud/firestore_v1/pipeline_expressions.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 72bb39f1f..b7b9a0f6f 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -322,7 +322,7 @@ def replace( ``` >>> from google.cloud.firestore_v1.pipeline_expressions import Field - >>> pipeline = client.collection("people").pipeline() + >>> pipeline = client.pipeline().collection("people") >>> # Emit the 'parents' map as the document >>> pipeline = pipeline.replace(Field.of("parents")) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 02602570a..9aac78cdb 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -1059,7 +1059,7 @@ def ascending(self) -> Ordering: Example: >>> # Sort documents by the 'name' field in ascending order - >>> firestore.pipeline().collection("users").sort(Field.of("name").ascending()) + >>> client.pipeline().collection("users").sort(Field.of("name").ascending()) Returns: A new `Ordering` for ascending sorting. @@ -1071,7 +1071,7 @@ def descending(self) -> Ordering: Example: >>> # Sort documents by the 'createdAt' field in descending order - >>> firestore.pipeline().collection("users").sort(Field.of("createdAt").descending()) + >>> client.pipeline().collection("users").sort(Field.of("createdAt").descending()) Returns: A new `Ordering` for descending sorting. @@ -1086,7 +1086,7 @@ def as_(self, alias: str) -> "ExprWithAlias": Example: >>> # Calculate the total price and assign it the alias "totalPrice" and add it to the output. - >>> firestore.pipeline().collection("items").add_fields( + >>> client.pipeline().collection("items").add_fields( ... Field.of("price").multiply(Field.of("quantity")).as_("totalPrice") ... ) From 60a770c1718345e1951d545d0e6d317611db7daa Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Tue, 13 May 2025 11:08:31 -0700 Subject: [PATCH 03/16] fixed lint --- .../firestore_v1/pipeline_expressions.py | 30 ++++++++++++++++--- google/cloud/firestore_v1/pipeline_stages.py | 2 ++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 6624e12f8..8b33a2af2 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -436,7 +436,9 @@ def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": Returns: A new `Expr` representing the concatenated array. """ - return ArrayConcat(self, [self._cast_to_expr_or_convert_to_constant(o) for o in array]) + return ArrayConcat( + self, [self._cast_to_expr_or_convert_to_constant(o) for o in array] + ) def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": """Creates an expression that checks if an array contains a specific element or value. @@ -817,7 +819,11 @@ def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst" Returns: A new `Expr` representing the string with the first occurrence replaced. """ - return ReplaceFirst(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) + return ReplaceFirst( + self, + self._cast_to_expr_or_convert_to_constant(find), + self._cast_to_expr_or_convert_to_constant(replace), + ) def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": """Creates an expression that replaces all occurrences of a substring within a string with another @@ -836,7 +842,11 @@ def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": Returns: A new `Expr` representing the string with all occurrences replaced. """ - return ReplaceAll(self, self._cast_to_expr_or_convert_to_constant(find), self._cast_to_expr_or_convert_to_constant(replace)) + return ReplaceAll( + self, + self._cast_to_expr_or_convert_to_constant(find), + self._cast_to_expr_or_convert_to_constant(replace), + ) def map_get(self, key: str) -> "MapGet": """Accesses a value from a map (object) field using the provided key. @@ -871,7 +881,9 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance """ return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - def euclidean_distance(self, other: Expr | list[float] | Vector) -> "EuclideanDistance": + def euclidean_distance( + self, other: Expr | list[float] | Vector + ) -> "EuclideanDistance": """Calculates the Euclidean distance between two vectors. Example: @@ -1152,12 +1164,14 @@ def __init__(self, left: Expr, right: Expr): class DotProduct(Function): """Represents the vector dot product function.""" + def __init__(self, vector1: Expr, vector2: Expr): super().__init__("dot_product", [vector1, vector2]) class EuclideanDistance(Function): """Represents the vector Euclidean distance function.""" + def __init__(self, vector1: Expr, vector2: Expr): super().__init__("euclidean_distance", [vector1, vector2]) @@ -1206,18 +1220,21 @@ def __init__(self, value: Expr): class ReplaceAll(Function): """Represents replacing all occurrences of a substring.""" + def __init__(self, value: Expr, pattern: Expr, replacement: Expr): super().__init__("replace_all", [value, pattern, replacement]) class ReplaceFirst(Function): """Represents replacing the first occurrence of a substring.""" + def __init__(self, value: Expr, pattern: Expr, replacement: Expr): super().__init__("replace_first", [value, pattern, replacement]) class Reverse(Function): """Represents reversing a string.""" + def __init__(self, expr: Expr): super().__init__("reverse", [expr]) @@ -1273,18 +1290,21 @@ def __init__(self, input: Expr): class ToLower(Function): """Represents converting a string to lowercase.""" + def __init__(self, value: Expr): super().__init__("to_lower", [value]) class ToUpper(Function): """Represents converting a string to uppercase.""" + def __init__(self, value: Expr): super().__init__("to_upper", [value]) class Trim(Function): """Represents trimming whitespace from a string.""" + def __init__(self, expr: Expr): super().__init__("trim", [expr]) @@ -1326,6 +1346,7 @@ def __init__(self, left: Expr, right: Expr): class ArrayConcat(Function): """Represents concatenating multiple arrays.""" + def __init__(self, array: Expr, rest: List[Expr]): super().__init__("array_concat", [array] + rest) @@ -1388,6 +1409,7 @@ def __init__(self, value: Expr): class CosineDistance(Function): """Represents the vector cosine distance function.""" + def __init__(self, vector1: Expr, vector2: Expr): super().__init__("cosine_distance", [vector1, vector2]) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 29f85be1d..d2f057902 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -16,6 +16,7 @@ from typing import Optional, Sequence, TYPE_CHECKING from abc import ABC from abc import abstractmethod +from enum import Enum from google.cloud.firestore_v1.types.document import Pipeline as Pipeline_pb from google.cloud.firestore_v1.types.document import Value @@ -314,6 +315,7 @@ def _pb_args(self) -> list[Value]: class Replace(Stage): """Replaces the document content with the value of a specified field.""" + class Mode(Enum): FULL_REPLACE = "full_replace" MERGE_PREFER_NEXT = "merge_prefer_nest" From 699fc05366470d83df03f0b9c1777b2914cb5e8c Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Jun 2025 16:47:03 -0700 Subject: [PATCH 04/16] added unit tests --- google/cloud/firestore_v1/pipeline_stages.py | 13 +- tests/unit/v1/test_pipeline.py | 2 + tests/unit/v1/test_pipeline_expressions.py | 227 +++++++++++++------ tests/unit/v1/test_pipeline_stages.py | 41 ++++ 4 files changed, 214 insertions(+), 69 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 694050476..361295256 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -376,17 +376,20 @@ class Replace(Stage): """Replaces the document content with the value of a specified field.""" class Mode(Enum): - FULL_REPLACE = "full_replace" - MERGE_PREFER_NEXT = "merge_prefer_nest" - MERGE_PREFER_PARENT = "merge_prefer_parent" + FULL_REPLACE = 0 + MERGE_PREFER_NEXT = 1 + MERGE_PREFER_PARENT = 2 + + def __repr__(self): + return f'Replace.Mode.{self.name.upper()}' def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE): super().__init__() self.field = Field(field) if isinstance(field, str) else field - self.mode = self.Mode[mode] if isinstance(mode, str) else mode + self.mode = self.Mode[mode.upper()] if isinstance(mode, str) else mode def _pb_args(self): - return [self.field._to_pb(), Value(string_value=self.mode.value)] + return [self.field._to_pb(), Value(string_value=self.mode.name.lower())] class Sample(Stage): diff --git a/tests/unit/v1/test_pipeline.py b/tests/unit/v1/test_pipeline.py index 765c5b1e1..e576f8b44 100644 --- a/tests/unit/v1/test_pipeline.py +++ b/tests/unit/v1/test_pipeline.py @@ -310,6 +310,8 @@ def test_pipeline_execute_with_transaction(): ), ("sort", (Field.of("n").descending(),), stages.Sort), ("sort", (Field.of("n").descending(), Field.of("m").ascending()), stages.Sort), + ("replace", (Field.of("n"),), stages.Replace), + ("replace", (Field.of("n"), stages.Replace.Mode.FULL_REPLACE), stages.Replace), ("sample", (10,), stages.Sample), ("sample", (stages.SampleOptions.doc_limit(10),), stages.Sample), ("union", (_make_pipeline(),), stages.Union), diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index d14a33cf9..304063b0d 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -100,6 +100,7 @@ def test_ctor(self): ("gte", (2,), expr.Gte), ("in_any", ([None],), expr.In), ("not_in_any", ([None],), expr.Not), + ("array_concat", ([None],), expr.ArrayConcat), ("array_contains", (None,), expr.ArrayContains), ("array_contains_all", ([None],), expr.ArrayContainsAll), ("array_contains_any", ([None],), expr.ArrayContainsAny), @@ -121,7 +122,16 @@ def test_ctor(self): ("starts_with", ("prefix",), expr.StartsWith), ("ends_with", ("postfix",), expr.EndsWith), ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), + ("to_lower", (), expr.ToLower), + ("to_upper", (), expr.ToUpper), + ("trim", (), expr.Trim), + ("reverse", (), expr.Reverse), + ("replace_first", ("1", "2"), expr.ReplaceFirst), + ("replace_all", ("1", "2"), expr.ReplaceAll), ("map_get", ("key",), expr.MapGet), + ("cosine_distance", [1], expr.CosineDistance), + ("euclidean_distance", [1], expr.EuclideanDistance), + ("dot_product", [1], expr.DotProduct), ("vector_length", (), expr.VectorLength), ("timestamp_to_unix_micros", (), expr.TimestampToUnixMicros), ("unix_micros_to_timestamp", (), expr.UnixMicrosToTimestamp), @@ -825,6 +835,96 @@ def test_not(self): assert instance.params == [arg1] assert repr(instance) == "Not(Condition)" + def test_array_contains_all(self): + arg1 = self._make_arg("ArrayField") + arg2 = self._make_arg("Element1") + arg3 = self._make_arg("Element2") + instance = expr.ArrayContainsAll(arg1, [arg2, arg3]) + assert instance.name == "array_contains_all" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert ( + repr(instance) + == "ArrayField.array_contains_all(ListOfExprs([Element1, Element2]))" + ) + + def test_ends_with(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Postfix") + instance = expr.EndsWith(arg1, arg2) + assert instance.name == "ends_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.ends_with(Postfix)" + + def test_if(self): + arg1 = self._make_arg("Condition") + arg2 = self._make_arg("TrueExpr") + arg3 = self._make_arg("FalseExpr") + instance = expr.If(arg1, arg2, arg3) + assert instance.name == "if" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + + def test_like(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Pattern") + instance = expr.Like(arg1, arg2) + assert instance.name == "like" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.like(Pattern)" + + def test_regex_contains(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Regex") + instance = expr.RegexContains(arg1, arg2) + assert instance.name == "regex_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.regex_contains(Regex)" + + def test_regex_match(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Regex") + instance = expr.RegexMatch(arg1, arg2) + assert instance.name == "regex_match" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.regex_match(Regex)" + + def test_starts_with(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Prefix") + instance = expr.StartsWith(arg1, arg2) + assert instance.name == "starts_with" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.starts_with(Prefix)" + + def test_str_contains(self): + arg1 = self._make_arg("Expr") + arg2 = self._make_arg("Substring") + instance = expr.StrContains(arg1, arg2) + assert instance.name == "str_contains" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Expr.str_contains(Substring)" + + def test_xor(self): + arg1 = self._make_arg("Condition1") + arg2 = self._make_arg("Condition2") + instance = expr.Xor([arg1, arg2]) + assert instance.name == "xor" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Xor(Condition1, Condition2)" + + +class TestFunctionClasses: + """ + contains test methods for each Expr class that derives from Function + """ + + def _make_arg(self, name="Mock"): + arg = mock.Mock() + arg.__repr__ = lambda x: name + return arg + def test_divide(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") @@ -1035,81 +1135,80 @@ def test_max(self): assert instance.params == [arg1] assert repr(instance) == "Max(Value)" - def test_array_contains_all(self): - arg1 = self._make_arg("ArrayField") - arg2 = self._make_arg("Element1") - arg3 = self._make_arg("Element2") - instance = expr.ArrayContainsAll(arg1, [arg2, arg3]) - assert instance.name == "array_contains_all" - assert isinstance(instance.params[1], ListOfExprs) - assert instance.params[0] == arg1 - assert instance.params[1].exprs == [arg2, arg3] - assert ( - repr(instance) - == "ArrayField.array_contains_all(ListOfExprs([Element1, Element2]))" - ) + def test_dot_product(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.DotProduct(arg1, arg2) + assert instance.name == "dot_product" + assert instance.params == [arg1, arg2] + assert repr(instance) == "DotProduct(Left, Right)" - def test_ends_with(self): - arg1 = self._make_arg("Expr") - arg2 = self._make_arg("Postfix") - instance = expr.EndsWith(arg1, arg2) - assert instance.name == "ends_with" + def test_euclidean_distance(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.EuclideanDistance(arg1, arg2) + assert instance.name == "euclidean_distance" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.ends_with(Postfix)" + assert repr(instance) == "EuclideanDistance(Left, Right)" - def test_if(self): - arg1 = self._make_arg("Condition") - arg2 = self._make_arg("TrueExpr") - arg3 = self._make_arg("FalseExpr") - instance = expr.If(arg1, arg2, arg3) - assert instance.name == "if" - assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + def test_cosine_distance(self): + arg1 = self._make_arg("Left") + arg2 = self._make_arg("Right") + instance = expr.CosineDistance(arg1, arg2) + assert instance.name == "cosine_distance" + assert instance.params == [arg1, arg2] + assert repr(instance) == "CosineDistance(Left, Right)" - def test_like(self): + def test_replace_all(self): arg1 = self._make_arg("Expr") - arg2 = self._make_arg("Pattern") - instance = expr.Like(arg1, arg2) - assert instance.name == "like" - assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.like(Pattern)" + arg2 = self._make_arg("OldValue") + arg3 = self._make_arg("NewValue") + instance = expr.ReplaceAll(arg1, arg2, arg3) + assert instance.name == "replace_all" + assert instance.params == [arg1, arg2, arg3] - def test_regex_contains(self): + def test_replace_first(self): arg1 = self._make_arg("Expr") - arg2 = self._make_arg("Regex") - instance = expr.RegexContains(arg1, arg2) - assert instance.name == "regex_contains" - assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.regex_contains(Regex)" + arg2 = self._make_arg("OldValue") + arg3 = self._make_arg("NewValue") + instance = expr.ReplaceFirst(arg1, arg2, arg3) + assert instance.name == "replace_first" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "ReplaceFirst(Expr, OldValue, NewValue)" - def test_regex_match(self): + def test_reverse(self): arg1 = self._make_arg("Expr") - arg2 = self._make_arg("Regex") - instance = expr.RegexMatch(arg1, arg2) - assert instance.name == "regex_match" - assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.regex_match(Regex)" + instance = expr.Reverse(arg1) + assert instance.name == "reverse" + assert instance.params == [arg1] + assert repr(instance) == "Reverse(Expr)" - def test_starts_with(self): + def test_to_lower(self): arg1 = self._make_arg("Expr") - arg2 = self._make_arg("Prefix") - instance = expr.StartsWith(arg1, arg2) - assert instance.name == "starts_with" - assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.starts_with(Prefix)" + instance = expr.ToLower(arg1) + assert instance.name == "to_lower" + assert instance.params == [arg1] + assert repr(instance) == "ToLower(Expr)" - def test_str_contains(self): + def test_to_upper(self): arg1 = self._make_arg("Expr") - arg2 = self._make_arg("Substring") - instance = expr.StrContains(arg1, arg2) - assert instance.name == "str_contains" - assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.str_contains(Substring)" + instance = expr.ToUpper(arg1) + assert instance.name == "to_upper" + assert instance.params == [arg1] + assert repr(instance) == "ToUpper(Expr)" - def test_xor(self): - arg1 = self._make_arg("Condition1") - arg2 = self._make_arg("Condition2") - instance = expr.Xor([arg1, arg2]) - assert instance.name == "xor" - assert instance.params == [arg1, arg2] - assert repr(instance) == "Xor(Condition1, Condition2)" + def test_trim(self): + arg1 = self._make_arg("Expr") + instance = expr.Trim(arg1) + assert instance.name == "trim" + assert instance.params == [arg1] + assert repr(instance) == "Trim(Expr)" + + def test_array_concat(self): + arg1 = self._make_arg("1") + arg2 = self._make_arg("2") + arg3 = self._make_arg("3") + instance = expr.ArrayConcat(arg1, [arg2, arg3]) + assert instance.name == "array_concat" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "ArrayConcat(1, 2, 3)" \ No newline at end of file diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index e050514b7..158d94989 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -532,6 +532,47 @@ def test_to_pb(self): assert len(result.options) == 0 +class TestReplace: + def _make_one(self, *args, **kwargs): + return stages.Replace(*args, **kwargs) + + def test_ctor_default(self): + instance = self._make_one("field") + assert isinstance(instance.field, Field) + assert instance.field.path == "field" + # default mode is FULL_REPLACE + assert instance.mode == stages.Replace.Mode.FULL_REPLACE + + @pytest.mark.parametrize("mode_str,expected_mode", [ + ("full_replace", stages.Replace.Mode.FULL_REPLACE), + ("merge_prefer_next", stages.Replace.Mode.MERGE_PREFER_NEXT), + ("merge_prefer_parent", stages.Replace.Mode.MERGE_PREFER_PARENT), + ]) + def test_ctor_str_mode(self, mode_str, expected_mode): + instance = self._make_one("field", mode_str) + assert instance.mode == expected_mode + assert repr(instance) == f"Replace(field=Field.of('field'), mode=Replace.Mode.{mode_str.upper()})" + + def test_ctor_w_field(self): + field = Field.of("field") + instance = self._make_one(field) + assert isinstance(instance.field, Field) + assert instance.field == field + + def test_repr(self): + instance = self._make_one("field", stages.Replace.Mode.MERGE_PREFER_NEXT) + repr_str = repr(instance) + assert repr_str == "Replace(field=Field.of('field'), mode=Replace.Mode.MERGE_PREFER_NEXT)" + + def test_to_pb(self): + instance = self._make_one("field", stages.Replace.Mode.MERGE_PREFER_NEXT) + result = instance._to_pb() + assert result.name == "replace" + assert len(result.args) == 2 + assert result.args[0].field_reference_value == "field" + assert result.args[1].string_value == "merge_prefer_next" + + class TestSample: class TestSampleOptions: def test_ctor_percent(self): From a1c35376a36076de8ca57035f4d0c1fb828b7284 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 12 Jun 2025 16:49:44 -0700 Subject: [PATCH 05/16] fixed lint --- google/cloud/firestore_v1/pipeline_stages.py | 2 +- tests/unit/v1/test_pipeline_expressions.py | 2 +- tests/unit/v1/test_pipeline_stages.py | 17 +++++++++++++---- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_stages.py b/google/cloud/firestore_v1/pipeline_stages.py index 361295256..5b036cf43 100644 --- a/google/cloud/firestore_v1/pipeline_stages.py +++ b/google/cloud/firestore_v1/pipeline_stages.py @@ -381,7 +381,7 @@ class Mode(Enum): MERGE_PREFER_PARENT = 2 def __repr__(self): - return f'Replace.Mode.{self.name.upper()}' + return f"Replace.Mode.{self.name.upper()}" def __init__(self, field: Selectable | str, mode: Mode | str = Mode.FULL_REPLACE): super().__init__() diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 304063b0d..91f604eda 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -1211,4 +1211,4 @@ def test_array_concat(self): instance = expr.ArrayConcat(arg1, [arg2, arg3]) assert instance.name == "array_concat" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "ArrayConcat(1, 2, 3)" \ No newline at end of file + assert repr(instance) == "ArrayConcat(1, 2, 3)" diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 158d94989..ae35647b0 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -543,15 +543,21 @@ def test_ctor_default(self): # default mode is FULL_REPLACE assert instance.mode == stages.Replace.Mode.FULL_REPLACE - @pytest.mark.parametrize("mode_str,expected_mode", [ + @pytest.mark.parametrize( + "mode_str,expected_mode", + [ ("full_replace", stages.Replace.Mode.FULL_REPLACE), ("merge_prefer_next", stages.Replace.Mode.MERGE_PREFER_NEXT), ("merge_prefer_parent", stages.Replace.Mode.MERGE_PREFER_PARENT), - ]) + ], + ) def test_ctor_str_mode(self, mode_str, expected_mode): instance = self._make_one("field", mode_str) assert instance.mode == expected_mode - assert repr(instance) == f"Replace(field=Field.of('field'), mode=Replace.Mode.{mode_str.upper()})" + assert ( + repr(instance) + == f"Replace(field=Field.of('field'), mode=Replace.Mode.{mode_str.upper()})" + ) def test_ctor_w_field(self): field = Field.of("field") @@ -562,7 +568,10 @@ def test_ctor_w_field(self): def test_repr(self): instance = self._make_one("field", stages.Replace.Mode.MERGE_PREFER_NEXT) repr_str = repr(instance) - assert repr_str == "Replace(field=Field.of('field'), mode=Replace.Mode.MERGE_PREFER_NEXT)" + assert ( + repr_str + == "Replace(field=Field.of('field'), mode=Replace.Mode.MERGE_PREFER_NEXT)" + ) def test_to_pb(self): instance = self._make_one("field", stages.Replace.Mode.MERGE_PREFER_NEXT) From 407aa85a8fdf2e7e9e44483d95e18ba0f1728a9e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 11 Jul 2025 17:02:26 -0700 Subject: [PATCH 06/16] renamed classes --- google/cloud/firestore_v1/_pipeline_stages.py | 13 +- google/cloud/firestore_v1/base_pipeline.py | 18 +- .../firestore_v1/pipeline_expressions.py | 340 ++++++++++-------- tests/unit/v1/test_pipeline_expressions.py | 36 +- 4 files changed, 219 insertions(+), 188 deletions(-) diff --git a/google/cloud/firestore_v1/_pipeline_stages.py b/google/cloud/firestore_v1/_pipeline_stages.py index caba9926a..7cd0efed4 100644 --- a/google/cloud/firestore_v1/_pipeline_stages.py +++ b/google/cloud/firestore_v1/_pipeline_stages.py @@ -23,11 +23,12 @@ from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1.base_vector_query import DistanceMeasure from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, + AggregateFunction, Expr, - ExprWithAlias, + AliasedAggregate, + AliasedExpr, Field, - FilterCondition, + BooleanExpr, Selectable, Ordering, ) @@ -164,8 +165,8 @@ class Aggregate(Stage): def __init__( self, - *args: ExprWithAlias[Accumulator], - accumulators: Sequence[ExprWithAlias[Accumulator]] = (), + *args: AliasedExpr[AggregateFunction], + accumulators: Sequence[AliasedAggregate] = (), groups: Sequence[str | Selectable] = (), ): super().__init__() @@ -459,7 +460,7 @@ def _pb_options(self): class Where(Stage): """Filters documents based on a specified condition.""" - def __init__(self, condition: FilterCondition): + def __init__(self, condition: BooleanExpr): super().__init__() self.condition = condition diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 26bdda2e8..5cf341e58 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -23,11 +23,9 @@ from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( - Accumulator, Expr, - ExprWithAlias, Field, - FilterCondition, + BooleanExpr, Selectable, ) from google.cloud.firestore_v1 import _helpers @@ -220,14 +218,14 @@ def select(self, *selections: str | Selectable) -> "_BasePipeline": """ return self._append(stages.Select(*selections)) - def where(self, condition: FilterCondition) -> "_BasePipeline": + def where(self, condition: BooleanExpr) -> "_BasePipeline": """ Filters the documents from previous stages to only include those matching - the specified `FilterCondition`. + the specified `BooleanExpr`. This stage allows you to apply conditions to the data, similar to a "WHERE" clause in SQL. You can filter documents based on their field values, using - implementations of `FilterCondition`, typically including but not limited to: + implementations of `BooleanExpr`, typically including but not limited to: - field comparators: `eq`, `lt` (less than), `gt` (greater than), etc. - logical operators: `And`, `Or`, `Not`, etc. - advanced functions: `regex_matches`, `array_contains`, etc. @@ -252,7 +250,7 @@ def where(self, condition: FilterCondition) -> "_BasePipeline": Args: - condition: The `FilterCondition` to apply. + condition: The `BooleanExpr` to apply. Returns: A new Pipeline object with this stage appended to the stage list @@ -579,7 +577,7 @@ def limit(self, limit: int) -> "_BasePipeline": def aggregate( self, - *accumulators: ExprWithAlias[Accumulator], + *accumulators: AliasedAggregate, groups: Sequence[str | Selectable] = (), ) -> "_BasePipeline": """ @@ -589,7 +587,7 @@ def aggregate( This stage allows you to calculate aggregate values (like sum, average, count, min, max) over a set of documents. - - **Accumulators:** Define the aggregation calculations using `Accumulator` + - **AggregateFunctions:** Define the aggregation calculations using `AggregateFunction` expressions (e.g., `sum()`, `avg()`, `count()`, `min()`, `max()`) combined with `as_()` to name the result field. - **Groups:** Optionally specify fields (by name or `Selectable`) to group @@ -617,7 +615,7 @@ def aggregate( Args: - *accumulators: One or more `ExprWithAlias[Accumulator]` expressions defining + *accumulators: One or more `AliasedAggregate` expressions defining the aggregations to perform and their output names. groups: An optional sequence of field names (str) or `Selectable` expressions to group by before aggregating. diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 1129ea999..9564d0d23 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -116,7 +116,7 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) - def add(self, other: Expr | float) -> "Add": + def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. Example: @@ -133,7 +133,7 @@ def add(self, other: Expr | float) -> "Add": """ return Add(self, self._cast_to_expr_or_convert_to_constant(other)) - def subtract(self, other: Expr | float) -> "Subtract": + def subtract(self, other: Expr | float) -> "Expr": """Creates an expression that subtracts another expression or constant from this expression. Example: @@ -150,7 +150,7 @@ def subtract(self, other: Expr | float) -> "Subtract": """ return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) - def multiply(self, other: Expr | float) -> "Multiply": + def multiply(self, other: Expr | float) -> "Expr": """Creates an expression that multiplies this expression by another expression or constant. Example: @@ -167,7 +167,7 @@ def multiply(self, other: Expr | float) -> "Multiply": """ return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) - def divide(self, other: Expr | float) -> "Divide": + def divide(self, other: Expr | float) -> "Expr": """Creates an expression that divides this expression by another expression or constant. Example: @@ -184,7 +184,7 @@ def divide(self, other: Expr | float) -> "Divide": """ return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) - def mod(self, other: Expr | float) -> "Mod": + def mod(self, other: Expr | float) -> "Expr": """Creates an expression that calculates the modulo (remainder) to another expression or constant. Example: @@ -201,7 +201,7 @@ def mod(self, other: Expr | float) -> "Mod": """ return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": + def logical_max(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -222,7 +222,7 @@ def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": """ return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": + def logical_min(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -243,7 +243,7 @@ def logical_min(self, other: Expr | CONSTANT_TYPE) -> "LogicalMin": """ return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) - def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": + def eq(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. @@ -261,7 +261,7 @@ def eq(self, other: Expr | CONSTANT_TYPE) -> "Eq": """ return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) - def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": + def neq(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. @@ -279,7 +279,7 @@ def neq(self, other: Expr | CONSTANT_TYPE) -> "Neq": """ return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) - def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": + def gt(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. @@ -297,7 +297,7 @@ def gt(self, other: Expr | CONSTANT_TYPE) -> "Gt": """ return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) - def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": + def gte(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. @@ -315,7 +315,7 @@ def gte(self, other: Expr | CONSTANT_TYPE) -> "Gte": """ return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) - def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": + def lt(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. @@ -333,7 +333,7 @@ def lt(self, other: Expr | CONSTANT_TYPE) -> "Lt": """ return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) - def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": + def lte(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. @@ -351,7 +351,7 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "Lte": """ return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) - def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "In": + def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -367,7 +367,7 @@ def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "In": """ return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) - def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Not": + def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -383,7 +383,7 @@ def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Not": """ return Not(self.in_any(array)) - def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": + def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "Expr": """Creates an expression that concatenates an array expression with another array. Example: @@ -400,7 +400,7 @@ def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "ArrayConcat": self, [self._cast_to_expr_or_convert_to_constant(o) for o in array] ) - def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": + def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if an array contains a specific element or value. Example: @@ -419,7 +419,7 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "ArrayContains": def array_contains_all( self, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. Example: @@ -440,7 +440,7 @@ def array_contains_all( def array_contains_any( self, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -460,7 +460,7 @@ def array_contains_any( self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] ) - def array_length(self) -> "ArrayLength": + def array_length(self) -> "Expr": """Creates an expression that calculates the length of an array. Example: @@ -472,7 +472,7 @@ def array_length(self) -> "ArrayLength": """ return ArrayLength(self) - def array_reverse(self) -> "ArrayReverse": + def array_reverse(self) -> "Expr": """Creates an expression that returns the reversed content of an array. Example: @@ -484,7 +484,7 @@ def array_reverse(self) -> "ArrayReverse": """ return ArrayReverse(self) - def is_nan(self) -> "IsNaN": + def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). Example: @@ -496,7 +496,7 @@ def is_nan(self) -> "IsNaN": """ return IsNaN(self) - def exists(self) -> "Exists": + def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. Example: @@ -508,7 +508,7 @@ def exists(self) -> "Exists": """ return Exists(self) - def sum(self) -> "Sum": + def sum(self) -> "Expr": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. Example: @@ -516,11 +516,11 @@ def sum(self) -> "Sum": >>> Field.of("orderAmount").sum().as_("totalRevenue") Returns: - A new `Accumulator` representing the 'sum' aggregation. + A new `AggregateFunction` representing the 'sum' aggregation. """ return Sum(self) - def avg(self) -> "Avg": + def avg(self) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. @@ -529,11 +529,11 @@ def avg(self) -> "Avg": >>> Field.of("age").avg().as_("averageAge") Returns: - A new `Accumulator` representing the 'avg' aggregation. + A new `AggregateFunction` representing the 'avg' aggregation. """ return Avg(self) - def count(self) -> "Count": + def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. @@ -542,11 +542,11 @@ def count(self) -> "Count": >>> Field.of("productId").count().as_("totalProducts") Returns: - A new `Accumulator` representing the 'count' aggregation. + A new `AggregateFunction` representing the 'count' aggregation. """ return Count(self) - def min(self) -> "Min": + def min(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: @@ -554,11 +554,11 @@ def min(self) -> "Min": >>> Field.of("price").min().as_("lowestPrice") Returns: - A new `Accumulator` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'min' aggregation. """ return Min(self) - def max(self) -> "Max": + def max(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: @@ -566,11 +566,11 @@ def max(self) -> "Max": >>> Field.of("score").max().as_("highestScore") Returns: - A new `Accumulator` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'max' aggregation. """ return Max(self) - def char_length(self) -> "CharLength": + def char_length(self) -> "Expr": """Creates an expression that calculates the character length of a string. Example: @@ -582,7 +582,7 @@ def char_length(self) -> "CharLength": """ return CharLength(self) - def byte_length(self) -> "ByteLength": + def byte_length(self) -> "Expr": """Creates an expression that calculates the byte length of a string in its UTF-8 form. Example: @@ -594,7 +594,7 @@ def byte_length(self) -> "ByteLength": """ return ByteLength(self) - def like(self, pattern: Expr | str) -> "Like": + def like(self, pattern: Expr | str) -> "BooleanExpr": """Creates an expression that performs a case-sensitive string comparison. Example: @@ -611,7 +611,7 @@ def like(self, pattern: Expr | str) -> "Like": """ return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) - def regex_contains(self, regex: Expr | str) -> "RegexContains": + def regex_contains(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -629,7 +629,7 @@ def regex_contains(self, regex: Expr | str) -> "RegexContains": """ return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) - def regex_matches(self, regex: Expr | str) -> "RegexMatch": + def regex_matches(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string matches a specified regular expression. Example: @@ -646,7 +646,7 @@ def regex_matches(self, regex: Expr | str) -> "RegexMatch": """ return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) - def str_contains(self, substring: Expr | str) -> "StrContains": + def str_contains(self, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: @@ -663,7 +663,7 @@ def str_contains(self, substring: Expr | str) -> "StrContains": """ return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) - def starts_with(self, prefix: Expr | str) -> "StartsWith": + def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. Example: @@ -680,7 +680,7 @@ def starts_with(self, prefix: Expr | str) -> "StartsWith": """ return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) - def ends_with(self, postfix: Expr | str) -> "EndsWith": + def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string ends with a given postfix. Example: @@ -697,7 +697,7 @@ def ends_with(self, postfix: Expr | str) -> "EndsWith": """ return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) - def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: @@ -714,7 +714,7 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] ) - def to_lower(self) -> "ToLower": + def to_lower(self) -> "Expr": """Creates an expression that converts a string to lowercase. Example: @@ -726,7 +726,7 @@ def to_lower(self) -> "ToLower": """ return ToLower(self) - def to_upper(self) -> "ToUpper": + def to_upper(self) -> "Expr": """Creates an expression that converts a string to uppercase. Example: @@ -738,7 +738,7 @@ def to_upper(self) -> "ToUpper": """ return ToUpper(self) - def trim(self) -> "Trim": + def trim(self) -> "Expr": """Creates an expression that removes leading and trailing whitespace from a string. Example: @@ -750,7 +750,7 @@ def trim(self) -> "Trim": """ return Trim(self) - def reverse(self) -> "Reverse": + def reverse(self) -> "Expr": """Creates an expression that reverses a string. Example: @@ -762,7 +762,7 @@ def reverse(self) -> "Reverse": """ return Reverse(self) - def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst": + def replace_first(self, find: Expr | str, replace: Expr | str) -> "Expr": """Creates an expression that replaces the first occurrence of a substring within a string with another substring. @@ -785,7 +785,7 @@ def replace_first(self, find: Expr | str, replace: Expr | str) -> "ReplaceFirst" self._cast_to_expr_or_convert_to_constant(replace), ) - def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": + def replace_all(self, find: Expr | str, replace: Expr | str) -> "Expr": """Creates an expression that replaces all occurrences of a substring within a string with another substring. @@ -808,7 +808,7 @@ def replace_all(self, find: Expr | str, replace: Expr | str) -> "ReplaceAll": self._cast_to_expr_or_convert_to_constant(replace), ) - def map_get(self, key: str) -> "MapGet": + def map_get(self, key: str) -> "Expr": """Accesses a value from a map (object) field using the provided key. Example: @@ -824,7 +824,7 @@ def map_get(self, key: str) -> "MapGet": """ return MapGet(self, Constant.of(key)) - def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance": + def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the cosine distance between two vectors. Example: @@ -843,7 +843,7 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "CosineDistance def euclidean_distance( self, other: Expr | list[float] | Vector - ) -> "EuclideanDistance": + ) -> "Expr": """Calculates the Euclidean distance between two vectors. Example: @@ -860,7 +860,7 @@ def euclidean_distance( """ return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct": + def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the dot product between two vectors. Example: @@ -877,7 +877,7 @@ def dot_product(self, other: Expr | list[float] | Vector) -> "DotProduct": """ return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) - def vector_length(self) -> "VectorLength": + def vector_length(self) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. Example: @@ -889,7 +889,7 @@ def vector_length(self) -> "VectorLength": """ return VectorLength(self) - def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": + def timestamp_to_unix_micros(self) -> "Expr": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -904,7 +904,7 @@ def timestamp_to_unix_micros(self) -> "TimestampToUnixMicros": """ return TimestampToUnixMicros(self) - def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": + def unix_micros_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -917,7 +917,7 @@ def unix_micros_to_timestamp(self) -> "UnixMicrosToTimestamp": """ return UnixMicrosToTimestamp(self) - def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": + def timestamp_to_unix_millis(self) -> "Expr": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -932,7 +932,7 @@ def timestamp_to_unix_millis(self) -> "TimestampToUnixMillis": """ return TimestampToUnixMillis(self) - def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": + def unix_millis_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -945,7 +945,7 @@ def unix_millis_to_timestamp(self) -> "UnixMillisToTimestamp": """ return UnixMillisToTimestamp(self) - def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": + def timestamp_to_unix_seconds(self) -> "Expr": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -960,7 +960,7 @@ def timestamp_to_unix_seconds(self) -> "TimestampToUnixSeconds": """ return TimestampToUnixSeconds(self) - def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": + def unix_seconds_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -973,7 +973,7 @@ def unix_seconds_to_timestamp(self) -> "UnixSecondsToTimestamp": """ return UnixSecondsToTimestamp(self) - def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd": + def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that adds a specified amount of time to this timestamp expression. Example: @@ -996,7 +996,7 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "TimestampAdd self._cast_to_expr_or_convert_to_constant(amount), ) - def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "TimestampSub": + def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: @@ -1043,7 +1043,7 @@ def descending(self) -> Ordering: """ return Ordering(self, Ordering.Direction.DESCENDING) - def as_(self, alias: str) -> "ExprWithAlias": + def as_(self, alias: str) -> "AliasedExpr": """Assigns an alias to this expression. Aliases are useful for renaming fields in the output of a stage or for giving meaningful @@ -1059,10 +1059,10 @@ def as_(self, alias: str) -> "ExprWithAlias": alias: The alias to assign to this expression. Returns: - A new `Selectable` (typically an `ExprWithAlias`) that wraps this + A new `Selectable` (typically an `AliasedExpr`) that wraps this expression and associates it with the provided alias. """ - return ExprWithAlias(self, alias) + return AliasedExpr(self, alias) class Constant(Expr, Generic[CONSTANT_TYPE]): @@ -1132,7 +1132,7 @@ def _to_pb(self): } ) - def add(left: Expr | str, right: Expr | float) -> "Add": + def add(left: Expr | str, right: Expr | float) -> "Expr": """Creates an expression that adds two expressions together. Example: @@ -1149,7 +1149,7 @@ def add(left: Expr | str, right: Expr | float) -> "Add": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.add(left_expr, right) - def subtract(left: Expr | str, right: Expr | float) -> "Subtract": + def subtract(left: Expr | str, right: Expr | float) -> "Expr": """Creates an expression that subtracts another expression or constant from this expression. Example: @@ -1166,7 +1166,7 @@ def subtract(left: Expr | str, right: Expr | float) -> "Subtract": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.subtract(left_expr, right) - def multiply(left: Expr | str, right: Expr | float) -> "Multiply": + def multiply(left: Expr | str, right: Expr | float) -> "Expr": """Creates an expression that multiplies this expression by another expression or constant. Example: @@ -1183,7 +1183,7 @@ def multiply(left: Expr | str, right: Expr | float) -> "Multiply": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.multiply(left_expr, right) - def divide(left: Expr | str, right: Expr | float) -> "Divide": + def divide(left: Expr | str, right: Expr | float) -> "Expr": """Creates an expression that divides this expression by another expression or constant. Example: @@ -1200,7 +1200,7 @@ def divide(left: Expr | str, right: Expr | float) -> "Divide": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.divide(left_expr, right) - def mod(left: Expr | str, right: Expr | float) -> "Mod": + def mod(left: Expr | str, right: Expr | float) -> "Expr": """Creates an expression that calculates the modulo (remainder) to another expression or constant. Example: @@ -1217,7 +1217,7 @@ def mod(left: Expr | str, right: Expr | float) -> "Mod": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.mod(left_expr, right) - def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": + def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -1238,7 +1238,7 @@ def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.logical_max(left_expr, right) - def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": + def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -1259,7 +1259,7 @@ def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMin": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.logical_min(left_expr, right) - def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": + def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. @@ -1277,7 +1277,7 @@ def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Eq": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.eq(left_expr, right) - def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": + def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. @@ -1295,7 +1295,7 @@ def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Neq": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.neq(left_expr, right) - def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": + def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. @@ -1313,7 +1313,7 @@ def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gt": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.gt(left_expr, right) - def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": + def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. @@ -1331,7 +1331,7 @@ def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Gte": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.gte(left_expr, right) - def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": + def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. @@ -1349,7 +1349,7 @@ def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lt": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.lt(left_expr, right) - def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": + def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. @@ -1367,7 +1367,7 @@ def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Lte": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.lte(left_expr, right) - def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "In": + def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -1385,7 +1385,7 @@ def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "In": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.in_any(left_expr, array) - def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "Not": + def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -1404,7 +1404,7 @@ def not_in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "Not": def array_contains( array: Expr | str, element: Expr | CONSTANT_TYPE - ) -> "ArrayContains": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains a specific element or value. Example: @@ -1423,7 +1423,7 @@ def array_contains( def array_contains_all( array: Expr | str, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAll": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains all the specified elements. Example: @@ -1442,7 +1442,7 @@ def array_contains_all( def array_contains_any( array: Expr | str, elements: List[Expr | CONSTANT_TYPE] - ) -> "ArrayContainsAny": + ) -> "BooleanExpr": """Creates an expression that checks if an array contains any of the specified elements. Example: @@ -1459,7 +1459,7 @@ def array_contains_any( array_expr = Field.of(array) if isinstance(array, str) else array return Expr.array_contains_any(array_expr, elements) - def array_length(array: Expr | str) -> "ArrayLength": + def array_length(array: Expr | str) -> "Expr": """Creates an expression that calculates the length of an array. Example: @@ -1471,7 +1471,7 @@ def array_length(array: Expr | str) -> "ArrayLength": array_expr = Field.of(array) if isinstance(array, str) else array return Expr.array_length(array_expr) - def array_reverse(array: Expr | str) -> "ArrayReverse": + def array_reverse(array: Expr | str) -> "Expr": """Creates an expression that returns the reversed content of an array. Example: @@ -1483,7 +1483,7 @@ def array_reverse(array: Expr | str) -> "ArrayReverse": array_expr = Field.of(array) if isinstance(array, str) else array return Expr.array_reverse(array_expr) - def is_nan(expr: Expr | str) -> "IsNaN": + def is_nan(expr: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). Example: @@ -1495,7 +1495,7 @@ def is_nan(expr: Expr | str) -> "IsNaN": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.is_nan(expr_val) - def exists(expr: Expr | str) -> "Exists": + def exists(expr: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. Example: @@ -1507,19 +1507,19 @@ def exists(expr: Expr | str) -> "Exists": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.exists(expr_val) - def sum(expr: Expr | str) -> "Sum": + def sum(expr: Expr | str) -> "Expr": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. Example: >>> Function.sum("orderAmount") Returns: - A new `Accumulator` representing the 'sum' aggregation. + A new `AggregateFunction` representing the 'sum' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.sum(expr_val) - def avg(expr: Expr | str) -> "Avg": + def avg(expr: Expr | str) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. @@ -1527,12 +1527,12 @@ def avg(expr: Expr | str) -> "Avg": >>> Function.avg("age") Returns: - A new `Accumulator` representing the 'avg' aggregation. + A new `AggregateFunction` representing the 'avg' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.avg(expr_val) - def count(expr: Expr | str | None = None) -> "Count": + def count(expr: Expr | str | None = None) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. If no expression is provided, it counts all inputs. @@ -1541,38 +1541,38 @@ def count(expr: Expr | str | None = None) -> "Count": >>> Function.count() Returns: - A new `Accumulator` representing the 'count' aggregation. + A new `AggregateFunction` representing the 'count' aggregation. """ if expr is None: return Count() expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.count(expr_val) - def min(expr: Expr | str) -> "Min": + def min(expr: Expr | str) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: >>> Function.min("price") Returns: - A new `Accumulator` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'min' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.min(expr_val) - def max(expr: Expr | str) -> "Max": + def max(expr: Expr | str) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: >>> Function.max("score") Returns: - A new `Accumulator` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'max' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.max(expr_val) - def char_length(expr: Expr | str) -> "CharLength": + def char_length(expr: Expr | str) -> "Expr": """Creates an expression that calculates the character length of a string. Example: @@ -1584,7 +1584,7 @@ def char_length(expr: Expr | str) -> "CharLength": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.char_length(expr_val) - def byte_length(expr: Expr | str) -> "ByteLength": + def byte_length(expr: Expr | str) -> "Expr": """Creates an expression that calculates the byte length of a string in its UTF-8 form. Example: @@ -1596,7 +1596,7 @@ def byte_length(expr: Expr | str) -> "ByteLength": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.byte_length(expr_val) - def like(expr: Expr | str, pattern: Expr | str) -> "Like": + def like(expr: Expr | str, pattern: Expr | str) -> "BooleanExpr": """Creates an expression that performs a case-sensitive string comparison. Example: @@ -1613,7 +1613,7 @@ def like(expr: Expr | str, pattern: Expr | str) -> "Like": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.like(expr_val, pattern) - def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": + def regex_contains(expr: Expr | str, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -1631,7 +1631,7 @@ def regex_contains(expr: Expr | str, regex: Expr | str) -> "RegexContains": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.regex_contains(expr_val, regex) - def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": + def regex_matches(expr: Expr | str, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string matches a specified regular expression. Example: @@ -1649,7 +1649,7 @@ def regex_matches(expr: Expr | str, regex: Expr | str) -> "RegexMatch": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.regex_matches(expr_val, regex) - def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": + def str_contains(expr: Expr | str, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: @@ -1666,7 +1666,7 @@ def str_contains(expr: Expr | str, substring: Expr | str) -> "StrContains": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.str_contains(expr_val, substring) - def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": + def starts_with(expr: Expr | str, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. Example: @@ -1683,7 +1683,7 @@ def starts_with(expr: Expr | str, prefix: Expr | str) -> "StartsWith": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.starts_with(expr_val, prefix) - def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": + def ends_with(expr: Expr | str, postfix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string ends with a given postfix. Example: @@ -1700,7 +1700,7 @@ def ends_with(expr: Expr | str, postfix: Expr | str) -> "EndsWith": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.ends_with(expr_val, postfix) - def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat": + def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: @@ -1716,7 +1716,7 @@ def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "StrConcat first_expr = Field.of(first) if isinstance(first, str) else first return Expr.str_concat(first_expr, *elements) - def map_get(map_expr: Expr | str, key: str) -> "MapGet": + def map_get(map_expr: Expr | str, key: str) -> "Expr": """Accesses a value from a map (object) field using the provided key. Example: @@ -1732,7 +1732,7 @@ def map_get(map_expr: Expr | str, key: str) -> "MapGet": map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr return Expr.map_get(map_val, key) - def vector_length(vector_expr: Expr | str) -> "VectorLength": + def vector_length(vector_expr: Expr | str) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. Example: @@ -1746,7 +1746,7 @@ def vector_length(vector_expr: Expr | str) -> "VectorLength": ) return Expr.vector_length(vector_val) - def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicros": + def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "Expr": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1765,7 +1765,7 @@ def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "TimestampToUnixMicr ) return Expr.timestamp_to_unix_micros(timestamp_val) - def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp": + def unix_micros_to_timestamp(micros_expr: Expr | str) -> "Expr": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1780,7 +1780,7 @@ def unix_micros_to_timestamp(micros_expr: Expr | str) -> "UnixMicrosToTimestamp" ) return Expr.unix_micros_to_timestamp(micros_val) - def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMillis": + def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "Expr": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1799,7 +1799,7 @@ def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "TimestampToUnixMill ) return Expr.timestamp_to_unix_millis(timestamp_val) - def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp": + def unix_millis_to_timestamp(millis_expr: Expr | str) -> "Expr": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1816,7 +1816,7 @@ def unix_millis_to_timestamp(millis_expr: Expr | str) -> "UnixMillisToTimestamp" def timestamp_to_unix_seconds( timestamp_expr: Expr | str, - ) -> "TimestampToUnixSeconds": + ) -> "Expr": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1835,7 +1835,7 @@ def timestamp_to_unix_seconds( ) return Expr.timestamp_to_unix_seconds(timestamp_val) - def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimestamp": + def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "Expr": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1852,7 +1852,7 @@ def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "UnixSecondsToTimesta def timestamp_add( timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampAdd": + ) -> "Expr": """Creates an expression that adds a specified amount of time to this timestamp expression. Example: @@ -1875,7 +1875,7 @@ def timestamp_add( def timestamp_sub( timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "TimestampSub": + ) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: @@ -2103,7 +2103,7 @@ def __init__(self): class ArrayFilter(Function): """Represents filtering elements from an array based on a condition.""" - def __init__(self, array: Expr, filter: "FilterCondition"): + def __init__(self, array: Expr, filter: "BooleanExpr"): super().__init__("array_filter", [array, filter]) @@ -2156,39 +2156,54 @@ def __init__(self, vector1: Expr, vector2: Expr): super().__init__("cosine_distance", [vector1, vector2]) -class Accumulator(Function): +class AggregateFunction(Function): """A base class for aggregation functions that operate across multiple inputs.""" + def as_(self, alias: str) -> "AliasedAggregate": + """Assigns an alias to this expression. + + Aliases are useful for renaming fields in the output of a stage or for giving meaningful + names to calculated values. + + Args: + alias: The alias to assign to this expression. -class Max(Accumulator): + Returns: A new AliasedAggregate that wraps this expression and associates it with the + provided alias. + """ + return AliasedAggregate(self, alias) + + + +class Max(AggregateFunction): """Represents the maximum aggregation function.""" def __init__(self, value: Expr): super().__init__("maximum", [value]) -class Min(Accumulator): +class Min(AggregateFunction): """Represents the minimum aggregation function.""" def __init__(self, value: Expr): super().__init__("minimum", [value]) -class Sum(Accumulator): +class Sum(AggregateFunction): """Represents the sum aggregation function.""" def __init__(self, value: Expr): super().__init__("sum", [value]) -class Avg(Accumulator): +class Avg(AggregateFunction): """Represents the average aggregation function.""" def __init__(self, value: Expr): super().__init__("avg", [value]) -class Count(Accumulator): +class Count(AggregateFunction): """Represents an aggregation that counts the total number of inputs.""" def __init__(self, value: Expr | None = None): @@ -2234,7 +2249,7 @@ def _to_value(field_list: Sequence[Selectable]) -> Value: T = TypeVar("T", bound=Expr) -class ExprWithAlias(Selectable, Generic[T]): +class AliasedExpr(Selectable, Generic[T]): """Wraps an expression with an alias.""" def __init__(self, expr: T, alias: str): @@ -2251,6 +2266,23 @@ def _to_pb(self): return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) +class AliasedAggregate: + """Wraps an aggregate with an alias""" + + def __init__(self, expr: AggregateFunction, alias: str): + self.expr = expr + self.alias = alias + + def _to_map(self): + return self.alias, self.expr._to_pb() + + def __repr__(self): + return f"{self.expr}.as_('{self.alias}')" + + def _to_pb(self): + return Value(map_value={"fields": {self.alias: self.expr._to_pb()}}) + + class Field(Selectable): """Represents a reference to a field within a document.""" @@ -2288,7 +2320,7 @@ def _to_pb(self): return Value(field_reference_value=self.path) -class FilterCondition(Function): +class BooleanExpr(Function): """Filters the given data in some way.""" def __init__( @@ -2304,7 +2336,7 @@ def __init__( def __repr__(self): """ - Most FilterConditions can be triggered infix. Eg: Field.of('age').gte(18). + Most BooleanExprs can be triggered infix. Eg: Field.of('age').gte(18). Display them this way in the repr string where possible """ @@ -2320,7 +2352,7 @@ def __repr__(self): def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): sub_filters = [ - FilterCondition._from_query_filter_pb(f, client) + BooleanExpr._from_query_filter_pb(f, client) for f in filter_pb.filters ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: @@ -2375,82 +2407,82 @@ def _from_query_filter_pb(filter_pb, client): or filter_pb.field_filter or filter_pb.unary_filter ) - return FilterCondition._from_query_filter_pb(f, client) + return BooleanExpr._from_query_filter_pb(f, client) else: raise TypeError(f"Unexpected filter type: {type(filter_pb)}") -class And(FilterCondition): - def __init__(self, *conditions: "FilterCondition"): +class And(BooleanExpr): + def __init__(self, *conditions: "BooleanExpr"): super().__init__("and", conditions, use_infix_repr=False) -class ArrayContains(FilterCondition): +class ArrayContains(BooleanExpr): def __init__(self, array: Expr, element: Expr): super().__init__( "array_contains", [array, element] ) -class ArrayContainsAll(FilterCondition): +class ArrayContainsAll(BooleanExpr): """Represents checking if an array contains all specified elements.""" def __init__(self, array: Expr, elements: Sequence[Expr]): super().__init__("array_contains_all", [array, ListOfExprs(elements)]) -class ArrayContainsAny(FilterCondition): +class ArrayContainsAny(BooleanExpr): """Represents checking if an array contains any of the specified elements.""" def __init__(self, array: Expr, elements: Sequence[Expr]): super().__init__("array_contains_any", [array, ListOfExprs(elements)]) -class EndsWith(FilterCondition): +class EndsWith(BooleanExpr): """Represents checking if a string ends with a specific postfix.""" def __init__(self, expr: Expr, postfix: Expr): super().__init__("ends_with", [expr, postfix]) -class Eq(FilterCondition): +class Eq(BooleanExpr): """Represents the equality comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("eq", [left, right]) -class Exists(FilterCondition): +class Exists(BooleanExpr): """Represents checking if a field exists.""" def __init__(self, expr: Expr): super().__init__("exists", [expr]) -class Gt(FilterCondition): +class Gt(BooleanExpr): """Represents the greater than comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("gt", [left, right]) -class Gte(FilterCondition): +class Gte(BooleanExpr): """Represents the greater than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("gte", [left, right]) -class If(FilterCondition): +class If(BooleanExpr): """Represents a conditional expression (if-then-else).""" - def __init__(self, condition: "FilterCondition", true_expr: Expr, false_expr: Expr): + def __init__(self, condition: "BooleanExpr", true_expr: Expr, false_expr: Expr): super().__init__( "if", [condition, true_expr, false_expr] ) -class In(FilterCondition): +class In(BooleanExpr): """Represents checking if an expression's value is within a list of values.""" def __init__(self, left: Expr, others: Sequence[Expr]): @@ -2459,85 +2491,85 @@ def __init__(self, left: Expr, others: Sequence[Expr]): ) -class IsNaN(FilterCondition): +class IsNaN(BooleanExpr): """Represents checking if a numeric value is NaN.""" def __init__(self, value: Expr): super().__init__("is_nan", [value]) -class Like(FilterCondition): +class Like(BooleanExpr): """Represents a case-sensitive wildcard string comparison.""" def __init__(self, expr: Expr, pattern: Expr): super().__init__("like", [expr, pattern]) -class Lt(FilterCondition): +class Lt(BooleanExpr): """Represents the less than comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("lt", [left, right]) -class Lte(FilterCondition): +class Lte(BooleanExpr): """Represents the less than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("lte", [left, right]) -class Neq(FilterCondition): +class Neq(BooleanExpr): """Represents the inequality comparison.""" def __init__(self, left: Expr, right: Expr): super().__init__("neq", [left, right]) -class Not(FilterCondition): +class Not(BooleanExpr): """Represents the logical NOT of a filter condition.""" def __init__(self, condition: Expr): super().__init__("not", [condition], use_infix_repr=False) -class Or(FilterCondition): +class Or(BooleanExpr): """Represents the logical OR of multiple filter conditions.""" - def __init__(self, *conditions: "FilterCondition"): + def __init__(self, *conditions: "BooleanExpr"): super().__init__("or", conditions) -class RegexContains(FilterCondition): +class RegexContains(BooleanExpr): """Represents checking if a string contains a substring matching a regex.""" def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_contains", [expr, regex]) -class RegexMatch(FilterCondition): +class RegexMatch(BooleanExpr): """Represents checking if a string fully matches a regex.""" def __init__(self, expr: Expr, regex: Expr): super().__init__("regex_match", [expr, regex]) -class StartsWith(FilterCondition): +class StartsWith(BooleanExpr): """Represents checking if a string starts with a specific prefix.""" def __init__(self, expr: Expr, prefix: Expr): super().__init__("starts_with", [expr, prefix]) -class StrContains(FilterCondition): +class StrContains(BooleanExpr): """Represents checking if a string contains a specific substring.""" def __init__(self, expr: Expr, substring: Expr): super().__init__("str_contains", [expr, substring]) -class Xor(FilterCondition): +class Xor(BooleanExpr): """Represents the logical XOR of multiple filter conditions.""" - def __init__(self, conditions: Sequence["FilterCondition"]): + def __init__(self, conditions: Sequence["BooleanExpr"]): super().__init__("xor", conditions, use_infix_repr=False) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 66d05f3db..9f7e4a48f 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -22,7 +22,7 @@ from google.cloud.firestore_v1.types.document import Value from google.cloud.firestore_v1.vector import Vector from google.cloud.firestore_v1._helpers import GeoPoint -from google.cloud.firestore_v1.pipeline_expressions import FilterCondition, ListOfExprs +from google.cloud.firestore_v1.pipeline_expressions import BooleanExpr, ListOfExprs import google.cloud.firestore_v1.pipeline_expressions as expr @@ -143,13 +143,13 @@ def test_ctor(self): ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), ("ascending", (), expr.Ordering), ("descending", (), expr.Ordering), - ("as_", ("alias",), expr.ExprWithAlias), + ("as_", ("alias",), expr.AliasedExpr), ], ) @pytest.mark.parametrize("base_instance", [expr.Constant(1), expr.Function.add("1", 1), expr.Field.of("test"), expr.Constant(1).as_("one")]) def test_infix_call(self, method, args, result_cls, base_instance): """ - many FilterCondition expressions support infix execution, and are exposed as methods on Expr. Test calling them + many BooleanExpr expressions support infix execution, and are exposed as methods on Expr. Test calling them """ method_ptr = getattr(base_instance, method) @@ -363,7 +363,7 @@ def test_to_map(self): assert result[0] == "field1" assert result[1] == Value(field_reference_value="field1") - class TestExprWithAlias: + class TestAliasedExpr: def test_repr(self): instance = expr.Field.of("field1").as_("alias1") assert repr(instance) == "Field.of('field1').as_('alias1')" @@ -371,14 +371,14 @@ def test_repr(self): def test_ctor(self): arg = expr.Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) assert instance.expr == arg assert instance.alias == alias def test_to_pb(self): arg = expr.Field.of("field1") alias = "alias1" - instance = expr.ExprWithAlias(arg, alias) + instance = expr.AliasedExpr(arg, alias) result = instance._to_pb() assert result.map_value.fields.get("alias1") == arg._to_pb() @@ -389,7 +389,7 @@ def test_to_map(self): assert result[1] == Value(field_reference_value="field1") -class TestFilterCondition: +class TestBooleanExpr: def test__from_query_filter_pb_composite_filter_or(self, mock_client): """ test composite OR filters @@ -417,7 +417,7 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks expected_cond1 = expr.And( @@ -460,7 +460,7 @@ def test__from_query_filter_pb_composite_filter_and(self, mock_client): composite_filter=composite_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) # should include existance checks expected_cond1 = expr.And( @@ -511,7 +511,7 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): composite_filter=outer_or_pb ) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), @@ -548,7 +548,7 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ) with pytest.raises(TypeError, match="Unexpected CompositeFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, expected_expr_func", @@ -581,7 +581,7 @@ def test__from_query_filter_pb_unary_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) field_expr_inst = expr.Field.of(field_path) expected_condition = expected_expr_func(field_expr_inst) @@ -602,7 +602,7 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(unary_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected UnaryFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) @pytest.mark.parametrize( "op_enum, value, expected_expr_func", @@ -654,7 +654,7 @@ def test__from_query_filter_pb_field_filter( ) wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) - result = FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + result = BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) field_expr = expr.Field.of(field_path) # convert values into constants @@ -683,7 +683,7 @@ def test__from_query_filter_pb_field_filter_unknown_op(self, mock_client): wrapped_filter_pb = query_pb.StructuredQuery.Filter(field_filter=filter_pb) with pytest.raises(TypeError, match="Unexpected FieldFilter operator type"): - FilterCondition._from_query_filter_pb(wrapped_filter_pb, mock_client) + BooleanExpr._from_query_filter_pb(wrapped_filter_pb, mock_client) def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ @@ -691,12 +691,12 @@ def test__from_query_filter_pb_unknown_filter_type(self, mock_client): """ # Test with an unexpected protobuf type with pytest.raises(TypeError, match="Unexpected filter type"): - FilterCondition._from_query_filter_pb(document_pb.Value(), mock_client) + BooleanExpr._from_query_filter_pb(document_pb.Value(), mock_client) -class TestFilterConditionClasses: +class TestBooleanExprClasses: """ - contains test methods for each Expr class that derives from FilterCondition + contains test methods for each Expr class that derives from BooleanExpr """ def _make_arg(self, name="Mock"): From 8233f47ae269cc374cab4aab84da28123b0e6072 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 16 Oct 2025 16:01:52 -0700 Subject: [PATCH 07/16] renamed expressions --- .../firestore_v1/pipeline_expressions.py | 250 +++++++++--------- 1 file changed, 125 insertions(+), 125 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 9564d0d23..5cef1c225 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -201,7 +201,7 @@ def mod(self, other: Expr | float) -> "Expr": """ return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_max(self, other: Expr | CONSTANT_TYPE) -> "Expr": + def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -210,19 +210,19 @@ def logical_max(self, other: Expr | CONSTANT_TYPE) -> "Expr": Example: >>> # Returns the larger value between the 'discount' field and the 'cap' field. - >>> Field.of("discount").logical_max(Field.of("cap")) + >>> Field.of("discount").logical_maximum(Field.of("cap")) >>> # Returns the larger value between the 'value' field and 10. - >>> Field.of("value").logical_max(10) + >>> Field.of("value").logical_maximum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical max operation. + A new `Expr` representing the logical maximum operation. """ - return LogicalMax(self, self._cast_to_expr_or_convert_to_constant(other)) + return LogicalMaximum(self, self._cast_to_expr_or_convert_to_constant(other)) - def logical_min(self, other: Expr | CONSTANT_TYPE) -> "Expr": + def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -231,27 +231,27 @@ def logical_min(self, other: Expr | CONSTANT_TYPE) -> "Expr": Example: >>> # Returns the smaller value between the 'discount' field and the 'floor' field. - >>> Field.of("discount").logical_min(Field.of("floor")) + >>> Field.of("discount").logical_minimum(Field.of("floor")) >>> # Returns the smaller value between the 'value' field and 10. - >>> Field.of("value").logical_min(10) + >>> Field.of("value").logical_minimum(10) Args: other: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical min operation. + A new `Expr` representing the logical minimum operation. """ - return LogicalMin(self, self._cast_to_expr_or_convert_to_constant(other)) + return LogicalMinimum(self, self._cast_to_expr_or_convert_to_constant(other)) - def eq(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. Example: >>> # Check if the 'age' field is equal to 21 - >>> Field.of("age").eq(21) + >>> Field.of("age").equal(21) >>> # Check if the 'city' field is equal to "London" - >>> Field.of("city").eq("London") + >>> Field.of("city").equal("London") Args: other: The expression or constant value to compare for equality. @@ -259,17 +259,17 @@ def eq(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the equality comparison. """ - return Eq(self, self._cast_to_expr_or_convert_to_constant(other)) + return Equal(self, self._cast_to_expr_or_convert_to_constant(other)) - def neq(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. Example: >>> # Check if the 'status' field is not equal to "completed" - >>> Field.of("status").neq("completed") + >>> Field.of("status").not_equal("completed") >>> # Check if the 'country' field is not equal to "USA" - >>> Field.of("country").neq("USA") + >>> Field.of("country").not_equal("USA") Args: other: The expression or constant value to compare for inequality. @@ -277,17 +277,17 @@ def neq(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the inequality comparison. """ - return Neq(self, self._cast_to_expr_or_convert_to_constant(other)) + return NotEqual(self, self._cast_to_expr_or_convert_to_constant(other)) - def gt(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. Example: >>> # Check if the 'age' field is greater than the 'limit' field - >>> Field.of("age").gt(Field.of("limit")) + >>> Field.of("age").greater_than(Field.of("limit")) >>> # Check if the 'price' field is greater than 100 - >>> Field.of("price").gt(100) + >>> Field.of("price").greater_than(100) Args: other: The expression or constant value to compare for greater than. @@ -295,17 +295,17 @@ def gt(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the greater than comparison. """ - return Gt(self, self._cast_to_expr_or_convert_to_constant(other)) + return GreaterThan(self, self._cast_to_expr_or_convert_to_constant(other)) - def gte(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is greater than or equal to field 'requirement' plus 1 - >>> Field.of("quantity").gte(Field.of('requirement').add(1)) + >>> Field.of("quantity").greater_than_or_equal(Field.of('requirement').add(1)) >>> # Check if the 'score' field is greater than or equal to 80 - >>> Field.of("score").gte(80) + >>> Field.of("score").greater_than_or_equal(80) Args: other: The expression or constant value to compare for greater than or equal to. @@ -313,17 +313,17 @@ def gte(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the greater than or equal to comparison. """ - return Gte(self, self._cast_to_expr_or_convert_to_constant(other)) + return GreaterThanOrEqual(self, self._cast_to_expr_or_convert_to_constant(other)) - def lt(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. Example: >>> # Check if the 'age' field is less than 'limit' - >>> Field.of("age").lt(Field.of('limit')) + >>> Field.of("age").less_than(Field.of('limit')) >>> # Check if the 'price' field is less than 50 - >>> Field.of("price").lt(50) + >>> Field.of("price").less_than(50) Args: other: The expression or constant value to compare for less than. @@ -331,17 +331,17 @@ def lt(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the less than comparison. """ - return Lt(self, self._cast_to_expr_or_convert_to_constant(other)) + return LessThan(self, self._cast_to_expr_or_convert_to_constant(other)) - def lte(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. Example: >>> # Check if the 'quantity' field is less than or equal to 20 - >>> Field.of("quantity").lte(Constant.of(20)) + >>> Field.of("quantity").less_than_or_equal(Constant.of(20)) >>> # Check if the 'score' field is less than or equal to 70 - >>> Field.of("score").lte(70) + >>> Field.of("score").less_than_or_equal(70) Args: other: The expression or constant value to compare for less than or equal to. @@ -349,7 +349,7 @@ def lte(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the less than or equal to comparison. """ - return Lte(self, self._cast_to_expr_or_convert_to_constant(other)) + return LessThanOrEqual(self, self._cast_to_expr_or_convert_to_constant(other)) def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the @@ -520,18 +520,18 @@ def sum(self) -> "Expr": """ return Sum(self) - def avg(self) -> "Expr": + def average(self) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. Example: >>> # Calculate the average age of users - >>> Field.of("age").avg().as_("averageAge") + >>> Field.of("age").average().as_("averageAge") Returns: A new `AggregateFunction` representing the 'avg' aggregation. """ - return Avg(self) + return Average(self) def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the @@ -546,24 +546,24 @@ def count(self) -> "Expr": """ return Count(self) - def min(self) -> "Expr": + def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: >>> # Find the lowest price of all products - >>> Field.of("price").min().as_("lowestPrice") + >>> Field.of("price").minimum().as_("lowestPrice") Returns: - A new `AggregateFunction` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'minimum' aggregation. """ return Min(self) - def max(self) -> "Expr": + def maxiumum(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: >>> # Find the highest score in a leaderboard - >>> Field.of("score").max().as_("highestScore") + >>> Field.of("score").maxiumum().as_("highestScore") Returns: A new `AggregateFunction` representing the 'max' aggregation. @@ -646,14 +646,14 @@ def regex_matches(self, regex: Expr | str) -> "BooleanExpr": """ return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) - def str_contains(self, substring: Expr | str) -> "BooleanExpr": + def string_contains(self, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: >>> # Check if the 'description' field contains "example". - >>> Field.of("description").str_contains("example") + >>> Field.of("description").string_contains("example") >>> # Check if the 'description' field contains the value of the 'keyword' field. - >>> Field.of("description").str_contains(Field.of("keyword")) + >>> Field.of("description").string_contains(Field.of("keyword")) Args: substring: The substring (string or expression) to use for the search. @@ -661,7 +661,7 @@ def str_contains(self, substring: Expr | str) -> "BooleanExpr": Returns: A new `Expr` representing the 'contains' comparison. """ - return StrContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + return StringContains(self, self._cast_to_expr_or_convert_to_constant(substring)) def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. @@ -697,12 +697,12 @@ def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """ return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) - def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": + def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: >>> # Combine the 'firstName', " ", and 'lastName' fields into a single string - >>> Field.of("firstName").str_concat(" ", Field.of("lastName")) + >>> Field.of("firstName").string_concat(" ", Field.of("lastName")) Args: *elements: The expressions or constants (typically strings) to concatenate. @@ -710,7 +710,7 @@ def str_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": Returns: A new `Expr` representing the concatenated string. """ - return StrConcat( + return StringConcat( self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] ) @@ -996,14 +996,14 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": self._cast_to_expr_or_convert_to_constant(amount), ) - def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "Expr": + def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. Example: >>> # Subtract a duration specified by the 'unit' and 'amount' fields from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub(Field.of("unit"), Field.of("amount")) + >>> Field.of("timestamp").timestamp_subtract(Field.of("unit"), Field.of("amount")) >>> # Subtract 2.5 hours from the 'timestamp' field. - >>> Field.of("timestamp").timestamp_sub("hour", 2.5) + >>> Field.of("timestamp").timestamp_subtract("hour", 2.5) Args: unit: The expression or string evaluating to the unit of time to subtract, must be one of @@ -1013,7 +1013,7 @@ def timestamp_sub(self, unit: Expr | str, amount: Expr | float) -> "Expr": Returns: A new `Expr` representing the resulting timestamp. """ - return TimestampSub( + return TimestampSubtract( self, self._cast_to_expr_or_convert_to_constant(unit), self._cast_to_expr_or_convert_to_constant(amount), @@ -1217,7 +1217,7 @@ def mod(left: Expr | str, right: Expr | float) -> "Expr": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.mod(left_expr, right) - def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": + def logical_maximum(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -1225,8 +1225,8 @@ def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering Example: - >>> Function.logical_max("value", 10) - >>> Function.logical_max(Field.of("discount"), Field.of("cap")) + >>> Function.logical_maximum("value", 10) + >>> Function.logical_maximum(Field.of("discount"), Field.of("cap")) Args: left: The expression or field path to compare. @@ -1236,9 +1236,9 @@ def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": A new `Expr` representing the logical max operation. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_max(left_expr, right) + return Expr.logical_maximum(left_expr, right) - def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": + def logical_minimum(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -1246,26 +1246,26 @@ def logical_min(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering Example: - >>> Function.logical_min("value", 10) - >>> Function.logical_min(Field.of("discount"), Field.of("floor")) + >>> Function.logical_minimum("value", 10) + >>> Function.logical_minimum(Field.of("discount"), Field.of("floor")) Args: left: The expression or field path to compare. right: The other expression or constant value to compare with. Returns: - A new `Expr` representing the logical min operation. + A new `Expr` representing the logical minimum operation. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_min(left_expr, right) + return Expr.logical_minimum(left_expr, right) - def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. Example: - >>> Function.eq("city", "London") - >>> Function.eq(Field.of("age"), 21) + >>> Function.equal("city", "London") + >>> Function.equal(Field.of("age"), 21) Args: left: The expression or field path to compare. @@ -1275,15 +1275,15 @@ def eq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": A new `Expr` representing the equality comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.eq(left_expr, right) + return Expr.equal(left_expr, right) - def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def not_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. Example: - >>> Function.neq("country", "USA") - >>> Function.neq(Field.of("status"), "completed") + >>> Function.not_equal("country", "USA") + >>> Function.not_equal(Field.of("status"), "completed") Args: left: The expression or field path to compare. @@ -1293,15 +1293,15 @@ def neq(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": A new `Expr` representing the inequality comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.neq(left_expr, right) + return Expr.not_equal(left_expr, right) - def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. Example: - >>> Function.gt("price", 100) - >>> Function.gt(Field.of("age"), Field.of("limit")) + >>> Function.greater_than("price", 100) + >>> Function.greater_than(Field.of("age"), Field.of("limit")) Args: left: The expression or field path to compare. @@ -1311,15 +1311,15 @@ def gt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": A new `Expr` representing the greater than comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gt(left_expr, right) + return Expr.greater_than(left_expr, right) - def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than_or_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. Example: - >>> Function.gte("score", 80) - >>> Function.gte(Field.of("quantity"), Field.of('requirement').add(1)) + >>> Function.greater_than_or_equal("score", 80) + >>> Function.greater_than_or_equal(Field.of("quantity"), Field.of('requirement').add(1)) Args: left: The expression or field path to compare. @@ -1329,15 +1329,15 @@ def gte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": A new `Expr` representing the greater than or equal to comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.gte(left_expr, right) + return Expr.greater_than_or_equal(left_expr, right) - def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. Example: - >>> Function.lt("price", 50) - >>> Function.lt(Field.of("age"), Field.of('limit')) + >>> Function.less_than("price", 50) + >>> Function.less_than(Field.of("age"), Field.of('limit')) Args: left: The expression or field path to compare. @@ -1347,15 +1347,15 @@ def lt(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": A new `Expr` representing the less than comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lt(left_expr, right) + return Expr.less_than(left_expr, right) - def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than_or_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. Example: - >>> Function.lte("score", 70) - >>> Function.lte(Field.of("quantity"), Constant.of(20)) + >>> Function.less_than_or_equal("score", 70) + >>> Function.less_than_or_equal(Field.of("quantity"), Constant.of(20)) Args: left: The expression or field path to compare. @@ -1365,7 +1365,7 @@ def lte(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": A new `Expr` representing the less than or equal to comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.lte(left_expr, right) + return Expr.less_than_or_equal(left_expr, right) def in_any(left: Expr | str, array: List[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the @@ -1519,18 +1519,18 @@ def sum(expr: Expr | str) -> "Expr": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.sum(expr_val) - def avg(expr: Expr | str) -> "Expr": + def average(expr: Expr | str) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. Example: - >>> Function.avg("age") + >>> Function.average("age") Returns: - A new `AggregateFunction` representing the 'avg' aggregation. + A new `AggregateFunction` representing the 'average' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.avg(expr_val) + return Expr.average(expr_val) def count(expr: Expr | str | None = None) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the @@ -1548,29 +1548,29 @@ def count(expr: Expr | str | None = None) -> "Expr": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.count(expr_val) - def min(expr: Expr | str) -> "Expr": + def minimum(expr: Expr | str) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. Example: - >>> Function.min("price") + >>> Function.minimum("price") Returns: - A new `AggregateFunction` representing the 'min' aggregation. + A new `AggregateFunction` representing the 'minimum' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.min(expr_val) + return Expr.minimum(expr_val) - def max(expr: Expr | str) -> "Expr": + def maxiumum(expr: Expr | str) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: - >>> Function.max("score") + >>> Function.maxiumum("score") Returns: - A new `AggregateFunction` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'maximum' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.max(expr_val) + return Expr.maxiumum(expr_val) def char_length(expr: Expr | str) -> "Expr": """Creates an expression that calculates the character length of a string. @@ -1649,12 +1649,12 @@ def regex_matches(expr: Expr | str, regex: Expr | str) -> "BooleanExpr": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.regex_matches(expr_val, regex) - def str_contains(expr: Expr | str, substring: Expr | str) -> "BooleanExpr": + def string_contains(expr: Expr | str, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. Example: - >>> Function.str_contains("description", "example") - >>> Function.str_contains(Field.of("description"), Field.of("keyword")) + >>> Function.string_contains("description", "example") + >>> Function.string_contains(Field.of("description"), Field.of("keyword")) Args: expr: The expression or field path to perform the comparison on. @@ -1664,7 +1664,7 @@ def str_contains(expr: Expr | str, substring: Expr | str) -> "BooleanExpr": A new `Expr` representing the 'contains' comparison. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.str_contains(expr_val, substring) + return Expr.string_contains(expr_val, substring) def starts_with(expr: Expr | str, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. @@ -1700,11 +1700,11 @@ def ends_with(expr: Expr | str, postfix: Expr | str) -> "BooleanExpr": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.ends_with(expr_val, postfix) - def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "Expr": + def string_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. Example: - >>> Function.str_concat("firstName", " ", Field.of("lastName")) + >>> Function.string_concat("firstName", " ", Field.of("lastName")) Args: first: The first expression or field path to concatenate. @@ -1714,7 +1714,7 @@ def str_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "Expr": A new `Expr` representing the concatenated string. """ first_expr = Field.of(first) if isinstance(first, str) else first - return Expr.str_concat(first_expr, *elements) + return Expr.string_concat(first_expr, *elements) def map_get(map_expr: Expr | str, key: str) -> "Expr": """Accesses a value from a map (object) field using the provided key. @@ -1918,14 +1918,14 @@ def __init__(self, vector1: Expr, vector2: Expr): super().__init__("euclidean_distance", [vector1, vector2]) -class LogicalMax(Function): +class LogicalMaximum(Function): """Represents the logical maximum function based on Firestore type ordering.""" def __init__(self, left: Expr, right: Expr): super().__init__("logical_maximum", [left, right]) -class LogicalMin(Function): +class LogicalMinimum(Function): """Represents the logical minimum function based on Firestore type ordering.""" def __init__(self, left: Expr, right: Expr): @@ -1981,11 +1981,11 @@ def __init__(self, expr: Expr): super().__init__("reverse", [expr]) -class StrConcat(Function): +class StringConcat(Function): """Represents concatenating multiple strings.""" def __init__(self, *exprs: Expr): - super().__init__("str_concat", exprs) + super().__init__("string_concat", exprs) class Subtract(Function): @@ -2002,11 +2002,11 @@ def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_add", [timestamp, unit, amount]) -class TimestampSub(Function): +class TimestampSubtract(Function): """Represents subtracting a duration from a timestamp.""" def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): - super().__init__("timestamp_sub", [timestamp, unit, amount]) + super().__init__("timestamp_subtract", [timestamp, unit, amount]) class TimestampToUnixMicros(Function): @@ -2196,7 +2196,7 @@ def __init__(self, value: Expr): super().__init__("sum", [value]) -class Avg(AggregateFunction): +class Average(AggregateFunction): """Represents the average aggregation function.""" def __init__(self, value: Expr): @@ -2445,7 +2445,7 @@ def __init__(self, expr: Expr, postfix: Expr): super().__init__("ends_with", [expr, postfix]) -class Eq(BooleanExpr): +class Equal(BooleanExpr): """Represents the equality comparison.""" def __init__(self, left: Expr, right: Expr): @@ -2459,18 +2459,18 @@ def __init__(self, expr: Expr): super().__init__("exists", [expr]) -class Gt(BooleanExpr): +class GreaterThan(BooleanExpr): """Represents the greater than comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("gt", [left, right]) + super().__init__("greater_than", [left, right]) -class Gte(BooleanExpr): +class GreaterThanOrEqual(BooleanExpr): """Represents the greater than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("gte", [left, right]) + super().__init__("greater_than_or_equal", [left, right]) class If(BooleanExpr): @@ -2505,25 +2505,25 @@ def __init__(self, expr: Expr, pattern: Expr): super().__init__("like", [expr, pattern]) -class Lt(BooleanExpr): +class LessThan(BooleanExpr): """Represents the less than comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("lt", [left, right]) + super().__init__("less_than", [left, right]) -class Lte(BooleanExpr): +class LessThanOrEqual(BooleanExpr): """Represents the less than or equal to comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("lte", [left, right]) + super().__init__("less_than_or_equal", [left, right]) -class Neq(BooleanExpr): +class NotEqual(BooleanExpr): """Represents the inequality comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("neq", [left, right]) + super().__init__("not_equal", [left, right]) class Not(BooleanExpr): @@ -2561,11 +2561,11 @@ def __init__(self, expr: Expr, prefix: Expr): super().__init__("starts_with", [expr, prefix]) -class StrContains(BooleanExpr): +class StringContains(BooleanExpr): """Represents checking if a string contains a specific substring.""" def __init__(self, expr: Expr, substring: Expr): - super().__init__("str_contains", [expr, substring]) + super().__init__("string_contains", [expr, substring]) class Xor(BooleanExpr): From 60924fe1cd7ce4763736d23a1890a3d59beb7a68 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 16 Oct 2025 16:33:37 -0700 Subject: [PATCH 08/16] added new math expressions --- .../firestore_v1/pipeline_expressions.py | 317 ++++++++++++++++++ tests/system/pipeline_e2e.yaml | 139 ++++++++ tests/unit/v1/test_pipeline_expressions.py | 84 +++++ 3 files changed, 540 insertions(+) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 1129ea999..7d6db1fa0 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -201,6 +201,124 @@ def mod(self, other: Expr | float) -> "Mod": """ return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + def abs(self) -> "Abs": + """Creates an expression that calculates the absolute value of this expression. + + Example: + >>> # Get the absolute value of the 'change' field. + >>> Field.of("change").abs() + + Returns: + A new `Expr` representing the absolute value. + """ + return Abs(self) + + def ceil(self) -> "Ceil": + """Creates an expression that calculates the ceiling of this expression. + + Example: + >>> # Get the ceiling of the 'value' field. + >>> Field.of("value").ceil() + + Returns: + A new `Expr` representing the ceiling value. + """ + return Ceil(self) + + def exp(self) -> "Exp": + """Creates an expression that computes e to the power of this expression. + + Example: + >>> # Compute e to the power of the 'value' field + >>> Field.of("value").exp() + + Returns: + A new `Expr` representing the exponential value. + """ + return Exp(self) + + def floor(self) -> "Floor": + """Creates an expression that calculates the floor of this expression. + + Example: + >>> # Get the floor of the 'value' field. + >>> Field.of("value").floor() + + Returns: + A new `Expr` representing the floor value. + """ + return Floor(self) + + def ln(self) -> "Ln": + """Creates an expression that calculates the natural logarithm of this expression. + + Example: + >>> # Get the natural logarithm of the 'value' field. + >>> Field.of("value").ln() + + Returns: + A new `Expr` representing the natural logarithm. + """ + return Ln(self) + + def log(self, base: Expr | float) -> "Log": + """Creates an expression that calculates the logarithm of this expression with a given base. + + Example: + >>> # Get the logarithm of 'value' with base 2. + >>> Field.of("value").log(2) + >>> # Get the logarithm of 'value' with base from 'base_field'. + >>> Field.of("value").log(Field.of("base_field")) + + Args: + base: The base of the logarithm. + + Returns: + A new `Expr` representing the logarithm. + """ + return Log(self, self._cast_to_expr_or_convert_to_constant(base)) + + def pow(self, exponent: Expr | float) -> "Pow": + """Creates an expression that calculates this expression raised to the power of the exponent. + + Example: + >>> # Raise 'base_val' to the power of 2. + >>> Field.of("base_val").pow(2) + >>> # Raise 'base_val' to the power of 'exponent_val'. + >>> Field.of("base_val").pow(Field.of("exponent_val")) + + Args: + exponent: The exponent. + + Returns: + A new `Expr` representing the power operation. + """ + return Pow(self, self._cast_to_expr_or_convert_to_constant(exponent)) + + def round(self) -> "Round": + """Creates an expression that rounds this expression to the nearest integer. + + Example: + >>> # Round the 'value' field. + >>> Field.of("value").round() + + Returns: + A new `Expr` representing the rounded value. + """ + return Round(self) + + def sqrt(self) -> "Sqrt": + """Creates an expression that calculates the square root of this expression. + + Example: + >>> # Get the square root of the 'area' field. + >>> Field.of("area").sqrt() + + Returns: + A new `Expr` representing the square root. + """ + return Sqrt(self) + def logical_max(self, other: Expr | CONSTANT_TYPE) -> "LogicalMax": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -1217,6 +1335,143 @@ def mod(left: Expr | str, right: Expr | float) -> "Mod": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.mod(left_expr, right) + def abs(expr: Expr | str) -> "Abs": + """Creates an expression that calculates the absolute value of an expression. + + Example: + >>> Function.abs("change") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the absolute value. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.abs(expr_val) + + def ceil(expr: Expr | str) -> "Ceil": + """Creates an expression that calculates the ceiling of an expression. + + Example: + >>> Function.ceil("value") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the ceiling value. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.ceil(expr_val) + + def exp(expr: Expr | str) -> "Exp": + """Creates an expression that calculates the exponential of an expression. + + Example: + >>> Function.exp("value") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the exponential value. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.exp(expr_val) + + def floor(expr: Expr | str) -> "Floor": + """Creates an expression that calculates the floor of an expression. + + Example: + >>> Function.floor("value") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the floor value. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.floor(expr_val) + + def ln(expr: Expr | str) -> "Ln": + """Creates an expression that calculates the natural logarithm of an expression. + + Example: + >>> Function.ln("value") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the natural logarithm. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.ln(expr_val) + + def log(expr: Expr | str, base: Expr | float) -> "Log": + """Creates an expression that calculates the logarithm of an expression with a given base. + + Example: + >>> Function.log("value", 2) + + Args: + expr: The expression or field path. + base: The base of the logarithm. + + Returns: + A new `Expr` representing the logarithm. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.log(expr_val, base) + + def pow(base: Expr | str, exponent: Expr | float) -> "Pow": + """Creates an expression that calculates the base raised to the power of the exponent. + + Example: + >>> Function.pow("base_val", 2) + + Args: + base: The base expression or field path. + exponent: The exponent. + + Returns: + A new `Expr` representing the power operation. + """ + base_val = Field.of(base) if isinstance(base, str) else base + return Expr.pow(base_val, exponent) + + def round(expr: Expr | str) -> "Round": + """Creates an expression that rounds an expression to the nearest integer. + + Example: + >>> Function.round("value") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the rounded value. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.round(expr_val) + + def sqrt(expr: Expr | str) -> "Sqrt": + """Creates an expression that calculates the square root of an expression. + + Example: + >>> Function.sqrt("area") + + Args: + expr: The expression or field path. + + Returns: + A new `Expr` representing the square root. + """ + expr_val = Field.of(expr) if isinstance(expr, str) else expr + return Expr.sqrt(expr_val) + def logical_max(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "LogicalMax": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -2001,6 +2256,68 @@ class TimestampAdd(Function): def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_add", [timestamp, unit, amount]) +class Abs(Function): + """Represents the absolute value function.""" + + def __init__(self, value: Expr): + super().__init__("abs", [value]) + + +class Ceil(Function): + """Represents the ceiling function.""" + + def __init__(self, value: Expr): + super().__init__("ceil", [value]) + + +class Exp(Function): + """Represents the exponential function.""" + + def __init__(self, value: Expr): + super().__init__("exp", [value]) + + +class Floor(Function): + """Represents the floor function.""" + + def __init__(self, value: Expr): + super().__init__("floor", [value]) + + +class Ln(Function): + """Represents the natural logarithm function.""" + + def __init__(self, value: Expr): + super().__init__("ln", [value]) + + +class Log(Function): + """Represents the logarithm function.""" + + def __init__(self, value: Expr, base: Expr): + super().__init__("log", [value, base]) + + +class Pow(Function): + """Represents the power function.""" + + def __init__(self, base: Expr, exponent: Expr): + super().__init__("pow", [base, exponent]) + + +class Round(Function): + """Represents the round function.""" + + def __init__(self, value: Expr): + super().__init__("round", [value]) + + +class Sqrt(Function): + """Represents the square root function.""" + + def __init__(self, value: Expr): + super().__init__("sqrt", [value]) + class TimestampSub(Function): """Represents subtracting a duration from a timestamp.""" diff --git a/tests/system/pipeline_e2e.yaml b/tests/system/pipeline_e2e.yaml index d92397347..59206d90c 100644 --- a/tests/system/pipeline_e2e.yaml +++ b/tests/system/pipeline_e2e.yaml @@ -2002,6 +2002,145 @@ tests: - doubleValue: 0.6 - stringValue: percent name: sample + - description: testMathFunctions - Abs + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Abs: + - Subtract: + - Field: rating + - Constant: 5 + - "absRating" + assert_results: + - absRating: 0.8 + - description: testMathFunctions - Ceil + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Ceil: + - Field: rating + - "ceilRating" + assert_results: + - ceilRating: 5 + - description: testMathFunctions - Floor + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Floor: + - Field: rating + - "floorRating" + assert_results: + - floorRating: 4 + - description: testMathFunctions - Exp + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Lord of the Rings" + - Limit: 1 + - Select: + - ExprWithAlias: + - Exp: + - Field: rating + - "expRating" + assert_results: + - expRating: 109.94717245212352 + - description: testMathFunctions - Pow + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Pow: + - Field: rating + - Constant: 2 + - "powerRating" + assert_results: + - powerRating: 17.64 + - description: testMathFunctions - Round + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Round: + - Field: rating + - "roundedRating" + assert_results: + - roundedRating: 4 + - description: testMathFunctions - Ln + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Ln: + - Field: rating + - "lnRating" + assert_results: + - lnRating: 1.4350845252893227 + - description: testMathFunctions - Log + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Log: + - Field: rating + - Constant: 10 + - "logRating" + assert_results: + - logRating: 0.6232492903979004 + - description: testMathFunctions - Sqrt + pipeline: + - Collection: books + - Where: + - Eq: + - Field: title + - Constant: "The Hitchhiker's Guide to the Galaxy" + - Limit: 1 + - Select: + - ExprWithAlias: + - Sqrt: + - Field: rating + - "sqrtRating" + assert_results: + - sqrtRating: 2.04939015319192 - description: testUnion pipeline: - Collection: books diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 66d05f3db..3ab51953f 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -90,6 +90,15 @@ def test_ctor(self): ("multiply", (2,), expr.Multiply), ("divide", (2,), expr.Divide), ("mod", (2,), expr.Mod), + ("abs", (), expr.Abs), + ("ceil", (), expr.Ceil), + ("exp", (), expr.Exp), + ("floor", (), expr.Floor), + ("ln", (), expr.Ln), + ("log", (10,), expr.Log), + ("pow", (2,), expr.Pow), + ("round", (), expr.Round), + ("sqrt", (), expr.Sqrt), ("logical_max", (2,), expr.LogicalMax), ("logical_min", (2,), expr.LogicalMin), ("eq", (2,), expr.Eq), @@ -915,6 +924,15 @@ class TestFunctionClasses: ("multiply", ("field", 2), expr.Multiply), ("divide", ("field", 2), expr.Divide), ("mod", ("field", 2), expr.Mod), + ("abs", ("field",), expr.Abs), + ("ceil", ("field",), expr.Ceil), + ("exp", ("field",), expr.Exp), + ("floor", ("field",), expr.Floor), + ("ln", ("field",), expr.Ln), + ("log", ("field", 10), expr.Log), + ("pow", ("field", 2), expr.Pow), + ("round", ("field",), expr.Round), + ("sqrt", ("field",), expr.Sqrt), ("logical_max", ("field", 2), expr.LogicalMax), ("logical_min", ("field", 2), expr.LogicalMin), ("eq", ("field", 2), expr.Eq), @@ -1309,3 +1327,69 @@ def test_array_concat(self): assert instance.name == "array_concat" assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "ArrayConcat(1, 2, 3)" + + def test_abs(self): + arg1 = self._make_arg("Value") + instance = expr.Abs(arg1) + assert instance.name == "abs" + assert instance.params == [arg1] + assert repr(instance) == "Abs(Value)" + + def test_ceil(self): + arg1 = self._make_arg("Value") + instance = expr.Ceil(arg1) + assert instance.name == "ceil" + assert instance.params == [arg1] + assert repr(instance) == "Ceil(Value)" + + def test_exp(self): + arg1 = self._make_arg("Value") + instance = expr.Exp(arg1) + assert instance.name == "exp" + assert instance.params == [arg1] + assert repr(instance) == "Exp(Value)" + + def test_floor(self): + arg1 = self._make_arg("Value") + instance = expr.Floor(arg1) + assert instance.name == "floor" + assert instance.params == [arg1] + assert repr(instance) == "Floor(Value)" + + def test_ln(self): + arg1 = self._make_arg("Value") + instance = expr.Ln(arg1) + assert instance.name == "ln" + assert instance.params == [arg1] + assert repr(instance) == "Ln(Value)" + + def test_log(self): + arg1 = self._make_arg("Value") + arg2 = self._make_arg("Base") + instance = expr.Log(arg1, arg2) + assert instance.name == "log" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Log(Value, Base)" + + def test_pow(self): + arg1 = self._make_arg("Base") + arg2 = self._make_arg("Exponent") + instance = expr.Pow(arg1, arg2) + assert instance.name == "pow" + assert instance.params == [arg1, arg2] + assert repr(instance) == "Pow(Base, Exponent)" + + + def test_round(self): + arg1 = self._make_arg("Value") + instance = expr.Round(arg1) + assert instance.name == "round" + assert instance.params == [arg1] + assert repr(instance) == "Round(Value)" + + def test_sqrt(self): + arg1 = self._make_arg("Value") + instance = expr.Sqrt(arg1) + assert instance.name == "sqrt" + assert instance.params == [arg1] + assert repr(instance) == "Sqrt(Value)" From 3217b0055ed9a1735d2b77f067fff2ce2d349f9b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 14:01:03 -0700 Subject: [PATCH 09/16] fixed tests --- google/cloud/firestore_v1/base_aggregation.py | 8 +- google/cloud/firestore_v1/base_query.py | 2 +- .../firestore_v1/pipeline_expressions.py | 66 +++--- tests/unit/v1/test_aggregation.py | 24 +-- tests/unit/v1/test_async_aggregation.py | 8 +- tests/unit/v1/test_base_query.py | 2 +- tests/unit/v1/test_pipeline_expressions.py | 198 +++++++++--------- tests/unit/v1/test_pipeline_stages.py | 14 +- 8 files changed, 164 insertions(+), 158 deletions(-) diff --git a/google/cloud/firestore_v1/base_aggregation.py b/google/cloud/firestore_v1/base_aggregation.py index 89e4edd0e..ba60e1314 100644 --- a/google/cloud/firestore_v1/base_aggregation.py +++ b/google/cloud/firestore_v1/base_aggregation.py @@ -34,9 +34,9 @@ from google.cloud.firestore_v1.types import ( StructuredAggregationQuery, ) -from google.cloud.firestore_v1.pipeline_expressions import Accumulator +from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction from google.cloud.firestore_v1.pipeline_expressions import Count -from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias +from google.cloud.firestore_v1.pipeline_expressions import AliasedExpr from google.cloud.firestore_v1.pipeline_expressions import Field # Types needed only for Type Hints @@ -86,7 +86,7 @@ def _to_protobuf(self): @abc.abstractmethod def _to_pipeline_expr( self, autoindexer: Iterable[int] - ) -> ExprWithAlias[Accumulator]: + ) -> AliasedExpr[AggregateFunction]: """ Convert this instance to a pipeline expression for use with pipeline.aggregate() @@ -162,7 +162,7 @@ def _to_protobuf(self): return aggregation_pb def _to_pipeline_expr(self, autoindexer: Iterable[int]): - return Field.of(self.field_ref).avg().as_(self._pipeline_alias(autoindexer)) + return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer)) def _query_response_to_result( diff --git a/google/cloud/firestore_v1/base_query.py b/google/cloud/firestore_v1/base_query.py index 67c4f27fd..797572b1b 100644 --- a/google/cloud/firestore_v1/base_query.py +++ b/google/cloud/firestore_v1/base_query.py @@ -1151,7 +1151,7 @@ def pipeline(self): # Filters for filter_ in self._field_filters: ppl = ppl.where( - pipeline_expressions.FilterCondition._from_query_filter_pb( + pipeline_expressions.BooleanExpr._from_query_filter_pb( filter_, self._client ) ) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 59f4a5e0a..de9d89fe8 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -674,19 +674,19 @@ def minimum(self) -> "Expr": Returns: A new `AggregateFunction` representing the 'minimum' aggregation. """ - return Min(self) + return Minimum(self) - def maxiumum(self) -> "Expr": + def maximum(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: >>> # Find the highest score in a leaderboard - >>> Field.of("score").maxiumum().as_("highestScore") + >>> Field.of("score").maximum().as_("highestScore") Returns: - A new `AggregateFunction` representing the 'max' aggregation. + A new `AggregateFunction` representing the 'maximum' aggregation. """ - return Max(self) + return Maximum(self) def char_length(self) -> "Expr": """Creates an expression that calculates the character length of a string. @@ -1815,17 +1815,17 @@ def minimum(expr: Expr | str) -> "Expr": expr_val = Field.of(expr) if isinstance(expr, str) else expr return Expr.minimum(expr_val) - def maxiumum(expr: Expr | str) -> "Expr": + def maximum(expr: Expr | str) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. Example: - >>> Function.maxiumum("score") + >>> Function.maximum("score") Returns: A new `AggregateFunction` representing the 'maximum' aggregation. """ expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.maxiumum(expr_val) + return Expr.maximum(expr_val) def char_length(expr: Expr | str) -> "Expr": """Creates an expression that calculates the character length of a string. @@ -2128,7 +2128,7 @@ def timestamp_add( ) return Expr.timestamp_add(timestamp_expr, unit, amount) - def timestamp_sub( + def timestamp_subtract( timestamp: Expr | str, unit: Expr | str, amount: Expr | float ) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. @@ -2149,7 +2149,7 @@ def timestamp_sub( timestamp_expr = ( Field.of(timestamp) if isinstance(timestamp, str) else timestamp ) - return Expr.timestamp_sub(timestamp_expr, unit, amount) + return Expr.timestamp_subtract(timestamp_expr, unit, amount) class Divide(Function): @@ -2174,17 +2174,23 @@ def __init__(self, vector1: Expr, vector2: Expr): class LogicalMaximum(Function): - """Represents the logical maximum function based on Firestore type ordering.""" + """ + Returns the larger value between this expression and another expression or constant, + based on Firestore's value type ordering. + """ def __init__(self, left: Expr, right: Expr): - super().__init__("logical_maximum", [left, right]) + super().__init__("max", [left, right]) class LogicalMinimum(Function): - """Represents the logical minimum function based on Firestore type ordering.""" + """ + Returns the smaller value between this expression and another expression or constant, + based on Firestore's value type ordering. + """ def __init__(self, left: Expr, right: Expr): - super().__init__("logical_minimum", [left, right]) + super().__init__("min", [left, right]) class MapGet(Function): @@ -2492,18 +2498,18 @@ def as_(self, alias: str) -> "AliasedAggregate": -class Max(AggregateFunction): - """Represents the maximum aggregation function.""" +class Maximum(AggregateFunction): + """Finds the maximum value of a field, aggregated across multiple stage inputs.""" def __init__(self, value: Expr): - super().__init__("maximum", [value]) + super().__init__("max", [value]) -class Min(AggregateFunction): - """Represents the minimum aggregation function.""" +class Minimum(AggregateFunction): + """Finds the maximum value of a field, aggregated across multiple stage inputs.""" def __init__(self, value: Expr): - super().__init__("minimum", [value]) + super().__init__("min", [value]) class Sum(AggregateFunction): @@ -2517,7 +2523,7 @@ class Average(AggregateFunction): """Represents the average aggregation function.""" def __init__(self, value: Expr): - super().__init__("avg", [value]) + super().__init__("average", [value]) class Count(AggregateFunction): @@ -2687,26 +2693,26 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NAN: return And(field.exists(), Not(field.is_nan())) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NULL: - return And(field.exists(), field.eq(None)) + return And(field.exists(), field.equal(None)) elif filter_pb.op == Query_pb.UnaryFilter.Operator.IS_NOT_NULL: - return And(field.exists(), Not(field.eq(None))) + return And(field.exists(), Not(field.equal(None))) else: raise TypeError(f"Unexpected UnaryFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.FieldFilter): field = Field.of(filter_pb.field.field_path) value = decode_value(filter_pb.value, client) if filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN: - return And(field.exists(), field.lt(value)) + return And(field.exists(), field.less_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.LESS_THAN_OR_EQUAL: - return And(field.exists(), field.lte(value)) + return And(field.exists(), field.less_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN: - return And(field.exists(), field.gt(value)) + return And(field.exists(), field.greater_than(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.GREATER_THAN_OR_EQUAL: - return And(field.exists(), field.gte(value)) + return And(field.exists(), field.greater_than_or_equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.EQUAL: - return And(field.exists(), field.eq(value)) + return And(field.exists(), field.equal(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_EQUAL: - return And(field.exists(), field.neq(value)) + return And(field.exists(), field.not_equal(value)) if filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS: return And(field.exists(), field.array_contains(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: @@ -2764,7 +2770,7 @@ class Equal(BooleanExpr): """Represents the equality comparison.""" def __init__(self, left: Expr, right: Expr): - super().__init__("eq", [left, right]) + super().__init__("equal", [left, right]) class Exists(BooleanExpr): diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 5064e87ae..7e3784f63 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -127,12 +127,12 @@ def test_avg_aggregation_no_alias_to_pb(): "in_alias,expected_alias", [("total", "total"), (None, "field_1")] ) def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate from google.cloud.firestore_v1.pipeline_expressions import Count count_aggregation = CountAggregation(alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias assert isinstance(got.expr, Count) assert len(got.expr.params) == 0 @@ -143,12 +143,12 @@ def test_count_aggregation_to_pipeline_expr(in_alias, expected_alias): [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate from google.cloud.firestore_v1.pipeline_expressions import Sum count_aggregation = SumAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias assert isinstance(got.expr, Sum) assert got.expr.params[0].path == expected_path @@ -159,14 +159,14 @@ def test_sum_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alia [("total", "path", "total"), (None, "some_ref", "field_1")], ) def test_avg_aggregation_to_pipeline_expr(in_alias, expected_path, expected_alias): - from google.cloud.firestore_v1.pipeline_expressions import ExprWithAlias - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import AliasedAggregate + from google.cloud.firestore_v1.pipeline_expressions import Average count_aggregation = AvgAggregation(expected_path, alias=in_alias) got = count_aggregation._to_pipeline_expr(iter([1])) - assert isinstance(got, ExprWithAlias) + assert isinstance(got, AliasedAggregate) assert got.alias == expected_alias - assert isinstance(got.expr, Avg) + assert isinstance(got.expr, Average) assert got.expr.params[0].path == expected_path @@ -1068,7 +1068,7 @@ def test_aggreation_to_pipeline_sum(field, in_alias, out_alias): def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import Average client = make_client() parent = client.collection("dee") @@ -1083,7 +1083,7 @@ def test_aggreation_to_pipeline_avg(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + assert isinstance(aggregate_stage.accumulators[0].expr, Average) expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -1142,7 +1142,7 @@ def test_aggreation_to_pipeline_count_increment(): def test_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.pipeline import Pipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select - from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + from google.cloud.firestore_v1.pipeline_expressions import Sum, Average, Count client = make_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) @@ -1163,7 +1163,7 @@ def test_aggreation_to_pipeline_complex(): assert aggregate_stage.accumulators[0].alias == "alias" assert isinstance(aggregate_stage.accumulators[1].expr, Count) assert aggregate_stage.accumulators[1].alias == "field_1" - assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert isinstance(aggregate_stage.accumulators[2].expr, Average) assert aggregate_stage.accumulators[2].alias == "field_2" assert isinstance(aggregate_stage.accumulators[3].expr, Sum) assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index fdd4a1450..f3ca8aa93 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -742,7 +742,7 @@ def test_async_aggreation_to_pipeline_sum(field, in_alias, out_alias): def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate - from google.cloud.firestore_v1.pipeline_expressions import Avg + from google.cloud.firestore_v1.pipeline_expressions import Average client = make_async_client() parent = client.collection("dee") @@ -757,7 +757,7 @@ def test_async_aggreation_to_pipeline_avg(field, in_alias, out_alias): aggregate_stage = pipeline.stages[1] assert isinstance(aggregate_stage, Aggregate) assert len(aggregate_stage.accumulators) == 1 - assert isinstance(aggregate_stage.accumulators[0].expr, Avg) + assert isinstance(aggregate_stage.accumulators[0].expr, Average) expected_field = field if isinstance(field, str) else field.to_api_repr() assert aggregate_stage.accumulators[0].expr.params[0].path == expected_field assert aggregate_stage.accumulators[0].alias == out_alias @@ -816,7 +816,7 @@ def test_aggreation_to_pipeline_count_increment(): def test_async_aggreation_to_pipeline_complex(): from google.cloud.firestore_v1.async_pipeline import AsyncPipeline from google.cloud.firestore_v1._pipeline_stages import Collection, Aggregate, Select - from google.cloud.firestore_v1.pipeline_expressions import Sum, Avg, Count + from google.cloud.firestore_v1.pipeline_expressions import Sum, Average, Count client = make_async_client() query = client.collection("my_col").select(["field_a", "field_b.c"]) @@ -837,7 +837,7 @@ def test_async_aggreation_to_pipeline_complex(): assert aggregate_stage.accumulators[0].alias == "alias" assert isinstance(aggregate_stage.accumulators[1].expr, Count) assert aggregate_stage.accumulators[1].alias == "field_1" - assert isinstance(aggregate_stage.accumulators[2].expr, Avg) + assert isinstance(aggregate_stage.accumulators[2].expr, Average) assert aggregate_stage.accumulators[2].alias == "field_2" assert isinstance(aggregate_stage.accumulators[3].expr, Sum) assert aggregate_stage.accumulators[3].alias == "field_3" diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 9bb3e61f8..8bf060a60 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2041,7 +2041,7 @@ def test__query_pipeline_composite_filter(): in_filter = FieldFilter("field_a", "==", "value_a") query = client.collection("my_col").where(filter=in_filter) with mock.patch.object( - expr.FilterCondition, "_from_query_filter_pb" + expr.BooleanExpr, "_from_query_filter_pb" ) as convert_mock: pipeline = query.pipeline() convert_mock.assert_called_once_with(in_filter._to_pb(), client) diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index a57eafd25..376656929 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -99,14 +99,14 @@ def test_ctor(self): ("pow", (2,), expr.Pow), ("round", (), expr.Round), ("sqrt", (), expr.Sqrt), - ("logical_max", (2,), expr.LogicalMax), - ("logical_min", (2,), expr.LogicalMin), - ("eq", (2,), expr.Eq), - ("neq", (2,), expr.Neq), - ("lt", (2,), expr.Lt), - ("lte", (2,), expr.Lte), - ("gt", (2,), expr.Gt), - ("gte", (2,), expr.Gte), + ("logical_maximum", (2,), expr.LogicalMaximum), + ("logical_minimum", (2,), expr.LogicalMinimum), + ("equal", (2,), expr.Equal), + ("not_equal", (2,), expr.NotEqual), + ("less_than", (2,), expr.LessThan), + ("less_than_or_equal", (2,), expr.LessThanOrEqual), + ("greater_than", (2,), expr.GreaterThan), + ("greater_than_or_equal", (2,), expr.GreaterThanOrEqual), ("in_any", ([None],), expr.In), ("not_in_any", ([None],), expr.Not), ("array_concat", ([None],), expr.ArrayConcat), @@ -118,19 +118,19 @@ def test_ctor(self): ("is_nan", (), expr.IsNaN), ("exists", (), expr.Exists), ("sum", (), expr.Sum), - ("avg", (), expr.Avg), + ("average", (), expr.Average), ("count", (), expr.Count), - ("min", (), expr.Min), - ("max", (), expr.Max), + ("minimum", (), expr.Minimum), + ("maximum", (), expr.Maximum), ("char_length", (), expr.CharLength), ("byte_length", (), expr.ByteLength), ("like", ("pattern",), expr.Like), ("regex_contains", ("regex",), expr.RegexContains), ("regex_matches", ("regex",), expr.RegexMatch), - ("str_contains", ("substring",), expr.StrContains), + ("string_contains", ("substring",), expr.StringContains), ("starts_with", ("prefix",), expr.StartsWith), ("ends_with", ("postfix",), expr.EndsWith), - ("str_concat", ("elem1", expr.Constant("elem2")), expr.StrConcat), + ("string_concat", ("elem1", expr.Constant("elem2")), expr.StringConcat), ("to_lower", (), expr.ToLower), ("to_upper", (), expr.ToUpper), ("trim", (), expr.Trim), @@ -149,7 +149,7 @@ def test_ctor(self): ("timestamp_to_unix_seconds", (), expr.TimestampToUnixSeconds), ("unix_seconds_to_timestamp", (), expr.UnixSecondsToTimestamp), ("timestamp_add", ("day", 1), expr.TimestampAdd), - ("timestamp_sub", ("hour", 2.5), expr.TimestampSub), + ("timestamp_subtract", ("hour", 2.5), expr.TimestampSubtract), ("ascending", (), expr.Ordering), ("descending", (), expr.Ordering), ("as_", ("alias",), expr.AliasedExpr), @@ -439,11 +439,11 @@ def test__from_query_filter_pb_composite_filter_or(self, mock_client): # should include existance checks expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + expr.Equal(expr.Field.of("field1"), expr.Constant("val1")), ) expected_cond2 = expr.And( expr.Exists(expr.Field.of("field2")), - expr.Eq(expr.Field.of("field2"), expr.Constant(None)), + expr.Equal(expr.Field.of("field2"), expr.Constant(None)), ) expected = expr.Or(expected_cond1, expected_cond2) @@ -482,11 +482,11 @@ def test__from_query_filter_pb_composite_filter_and(self, mock_client): # should include existance checks expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), - expr.Gt(expr.Field.of("field1"), expr.Constant(100)), + expr.GreaterThan(expr.Field.of("field1"), expr.Constant(100)), ) expected_cond2 = expr.And( expr.Exists(expr.Field.of("field2")), - expr.Lt(expr.Field.of("field2"), expr.Constant(200)), + expr.LessThan(expr.Field.of("field2"), expr.Constant(200)), ) expected = expr.And(expected_cond1, expected_cond2) assert repr(result) == repr(expected) @@ -532,15 +532,15 @@ def test__from_query_filter_pb_composite_filter_nested(self, mock_client): expected_cond1 = expr.And( expr.Exists(expr.Field.of("field1")), - expr.Eq(expr.Field.of("field1"), expr.Constant("val1")), + expr.Equal(expr.Field.of("field1"), expr.Constant("val1")), ) expected_cond2 = expr.And( expr.Exists(expr.Field.of("field2")), - expr.Gt(expr.Field.of("field2"), expr.Constant(10)), + expr.GreaterThan(expr.Field.of("field2"), expr.Constant(10)), ) expected_cond3 = expr.And( expr.Exists(expr.Field.of("field3")), - expr.Not(expr.Eq(expr.Field.of("field3"), expr.Constant(None))), + expr.Not(expr.Equal(expr.Field.of("field3"), expr.Constant(None))), ) expected_inner_and = expr.And(expected_cond2, expected_cond3) expected_outer_or = expr.Or(expected_cond1, expected_inner_and) @@ -577,11 +577,11 @@ def test__from_query_filter_pb_composite_filter_unknown_op(self, mock_client): ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NULL, - lambda f: f.eq(None), + lambda f: f.equal(None), ), ( query_pb.StructuredQuery.UnaryFilter.Operator.IS_NOT_NULL, - lambda f: expr.Not(f.eq(None)), + lambda f: expr.Not(f.equal(None)), ), ], ) @@ -624,20 +624,20 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): @pytest.mark.parametrize( "op_enum, value, expected_expr_func", [ - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.Lt), + (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.LessThan), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, - expr.Lte, + expr.LessThanOrEqual, ), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.Gt), + (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.GreaterThan), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, - expr.Gte, + expr.GreaterThanOrEqual, ), - (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Eq), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.Neq), + (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Equal), + (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.NotEqual), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, @@ -766,53 +766,53 @@ def test_exists(self): assert instance.params == [arg1] assert repr(instance) == "Field.exists()" - def test_eq(self): + def test_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Eq(arg1, arg2) - assert instance.name == "eq" + instance = expr.Equal(arg1, arg2) + assert instance.name == "equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.eq(Right)" + assert repr(instance) == "Left.equal(Right)" - def test_gte(self): + def test_greater_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gte(arg1, arg2) - assert instance.name == "gte" + instance = expr.GreaterThanOrEqual(arg1, arg2) + assert instance.name == "greater_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gte(Right)" + assert repr(instance) == "Left.greater_than_or_equal(Right)" - def test_gt(self): + def test_greater_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Gt(arg1, arg2) - assert instance.name == "gt" + instance = expr.GreaterThan(arg1, arg2) + assert instance.name == "greater_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.gt(Right)" + assert repr(instance) == "Left.greater_than(Right)" - def test_lte(self): + def test_less_than_or_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lte(arg1, arg2) - assert instance.name == "lte" + instance = expr.LessThanOrEqual(arg1, arg2) + assert instance.name == "less_than_or_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lte(Right)" + assert repr(instance) == "Left.less_than_or_equal(Right)" - def test_lt(self): + def test_less_than(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Lt(arg1, arg2) - assert instance.name == "lt" + instance = expr.LessThan(arg1, arg2) + assert instance.name == "less_than" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.lt(Right)" + assert repr(instance) == "Left.less_than(Right)" - def test_neq(self): + def test_not_equal(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.Neq(arg1, arg2) - assert instance.name == "neq" + instance = expr.NotEqual(arg1, arg2) + assert instance.name == "not_equal" assert instance.params == [arg1, arg2] - assert repr(instance) == "Left.neq(Right)" + assert repr(instance) == "Left.not_equal(Right)" def test_in(self): arg1 = self._make_arg("Field") @@ -902,13 +902,13 @@ def test_starts_with(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.starts_with(Prefix)" - def test_str_contains(self): + def test_string_contains(self): arg1 = self._make_arg("Expr") arg2 = self._make_arg("Substring") - instance = expr.StrContains(arg1, arg2) - assert instance.name == "str_contains" + instance = expr.StringContains(arg1, arg2) + assert instance.name == "string_contains" assert instance.params == [arg1, arg2] - assert repr(instance) == "Expr.str_contains(Substring)" + assert repr(instance) == "Expr.string_contains(Substring)" def test_xor(self): arg1 = self._make_arg("Condition1") @@ -941,14 +941,14 @@ class TestFunctionClasses: ("pow", ("field", 2), expr.Pow), ("round", ("field",), expr.Round), ("sqrt", ("field",), expr.Sqrt), - ("logical_max", ("field", 2), expr.LogicalMax), - ("logical_min", ("field", 2), expr.LogicalMin), - ("eq", ("field", 2), expr.Eq), - ("neq", ("field", 2), expr.Neq), - ("lt", ("field", 2), expr.Lt), - ("lte", ("field", 2), expr.Lte), - ("gt", ("field", 2), expr.Gt), - ("gte", ("field", 2), expr.Gte), + ("logical_maximum", ("field", 2), expr.LogicalMaximum), + ("logical_minimum", ("field", 2), expr.LogicalMinimum), + ("equal", ("field", 2), expr.Equal), + ("not_equal", ("field", 2), expr.NotEqual), + ("less_than", ("field", 2), expr.LessThan), + ("less_than_or_equal", ("field", 2), expr.LessThanOrEqual), + ("greater_than", ("field", 2), expr.GreaterThan), + ("greater_than_or_equal", ("field", 2), expr.GreaterThanOrEqual), ("in_any", ("field", [None]), expr.In), ("not_in_any", ("field", [None]), expr.Not), ("array_contains", ("field", None), expr.ArrayContains), @@ -959,20 +959,20 @@ class TestFunctionClasses: ("is_nan", ("field",), expr.IsNaN), ("exists", ("field",), expr.Exists), ("sum", ("field",), expr.Sum), - ("avg", ("field",), expr.Avg), + ("average", ("field",), expr.Average), ("count", ("field",), expr.Count), ("count", (), expr.Count), - ("min", ("field",), expr.Min), - ("max", ("field",), expr.Max), + ("minimum", ("field",), expr.Minimum), + ("maximum", ("field",), expr.Maximum), ("char_length", ("field",), expr.CharLength), ("byte_length", ("field",), expr.ByteLength), ("like", ("field", "pattern"), expr.Like), ("regex_contains", ("field", "regex"), expr.RegexContains), ("regex_matches", ("field", "regex"), expr.RegexMatch), - ("str_contains", ("field", "substring"), expr.StrContains), + ("string_contains", ("field", "substring"), expr.StringContains), ("starts_with", ("field", "prefix"), expr.StartsWith), ("ends_with", ("field", "postfix"), expr.EndsWith), - ("str_concat", ("field", "elem1", "elem2"), expr.StrConcat), + ("string_concat", ("field", "elem1", "elem2"), expr.StringConcat), ("map_get", ("field", "key"), expr.MapGet), ("vector_length", ("field",), expr.VectorLength), ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros), @@ -982,7 +982,7 @@ class TestFunctionClasses: ("timestamp_to_unix_seconds", ("field",), expr.TimestampToUnixSeconds), ("unix_seconds_to_timestamp", ("field",), expr.UnixSecondsToTimestamp), ("timestamp_add", ("field", "day", 1), expr.TimestampAdd), - ("timestamp_sub", ("field", "hour", 2.5), expr.TimestampSub), + ("timestamp_subtract", ("field", "hour", 2.5), expr.TimestampSubtract), ], ) def test_function_builder(self, method, args, result_cls): @@ -1023,21 +1023,21 @@ def test_divide(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Divide(Left, Right)" - def test_logical_max(self): + def test_logical_maximum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMax(arg1, arg2) - assert instance.name == "logical_maximum" + instance = expr.LogicalMaximum(arg1, arg2) + assert instance.name == "max" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMax(Left, Right)" + assert repr(instance) == "LogicalMaximum(Left, Right)" - def test_logical_min(self): + def test_logical_minimum(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") - instance = expr.LogicalMin(arg1, arg2) - assert instance.name == "logical_minimum" + instance = expr.LogicalMinimum(arg1, arg2) + assert instance.name == "min" assert instance.params == [arg1, arg2] - assert repr(instance) == "LogicalMin(Left, Right)" + assert repr(instance) == "LogicalMinimum(Left, Right)" def test_map_get(self): arg1 = self._make_arg("Map") @@ -1070,13 +1070,13 @@ def test_parent(self): assert instance.params == [arg1] assert repr(instance) == "Parent(Value)" - def test_str_concat(self): + def test_string_concat(self): arg1 = self._make_arg("Str1") arg2 = self._make_arg("Str2") - instance = expr.StrConcat(arg1, arg2) - assert instance.name == "str_concat" + instance = expr.StringConcat(arg1, arg2) + assert instance.name == "string_concat" assert instance.params == [arg1, arg2] - assert repr(instance) == "StrConcat(Str1, Str2)" + assert repr(instance) == "StringConcat(Str1, Str2)" def test_subtract(self): arg1 = self._make_arg("Left") @@ -1095,14 +1095,14 @@ def test_timestamp_add(self): assert instance.params == [arg1, arg2, arg3] assert repr(instance) == "TimestampAdd(Timestamp, Unit, Amount)" - def test_timestamp_sub(self): + def test_timestamp_subtract(self): arg1 = self._make_arg("Timestamp") arg2 = self._make_arg("Unit") arg3 = self._make_arg("Amount") - instance = expr.TimestampSub(arg1, arg2, arg3) - assert instance.name == "timestamp_sub" + instance = expr.TimestampSubtract(arg1, arg2, arg3) + assert instance.name == "timestamp_subtract" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "TimestampSub(Timestamp, Unit, Amount)" + assert repr(instance) == "TimestampSubtract(Timestamp, Unit, Amount)" def test_timestamp_to_unix_micros(self): arg1 = self._make_arg("Input") @@ -1225,12 +1225,12 @@ def test_sum(self): assert instance.params == [arg1] assert repr(instance) == "Sum(Value)" - def test_avg(self): + def test_average(self): arg1 = self._make_arg("Value") - instance = expr.Avg(arg1) - assert instance.name == "avg" + instance = expr.Average(arg1) + assert instance.name == "average" assert instance.params == [arg1] - assert repr(instance) == "Avg(Value)" + assert repr(instance) == "Average(Value)" def test_count(self): arg1 = self._make_arg("Value") @@ -1244,19 +1244,19 @@ def test_count_empty(self): assert instance.params == [] assert repr(instance) == "Count()" - def test_min(self): + def test_minimum(self): arg1 = self._make_arg("Value") - instance = expr.Min(arg1) - assert instance.name == "minimum" + instance = expr.Minimum(arg1) + assert instance.name == "min" assert instance.params == [arg1] - assert repr(instance) == "Min(Value)" + assert repr(instance) == "Minimum(Value)" - def test_max(self): + def test_maximum(self): arg1 = self._make_arg("Value") - instance = expr.Max(arg1) - assert instance.name == "maximum" + instance = expr.Maximum(arg1) + assert instance.name == "max" assert instance.params == [arg1] - assert repr(instance) == "Max(Value)" + assert repr(instance) == "Maximum(Value)" def test_dot_product(self): arg1 = self._make_arg("Left") diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index ef9170d6a..049b04cb4 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -80,7 +80,7 @@ def _make_one(self, *args, **kwargs): def test_ctor_positional(self): """test with only positional arguments""" sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + avg_price = Field.of("price").average().as_("avg_price") instance = self._make_one(sum_total, avg_price) assert list(instance.accumulators) == [sum_total, avg_price] assert len(instance.groups) == 0 @@ -89,7 +89,7 @@ def test_ctor_positional(self): def test_ctor_keyword(self): """test with only keyword arguments""" sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + avg_price = Field.of("price").average().as_("avg_price") group_category = Field.of("category") instance = self._make_one( accumulators=[avg_price, sum_total], groups=[group_category, "city"] @@ -104,7 +104,7 @@ def test_ctor_keyword(self): def test_ctor_combined(self): """test with a mix of arguments""" sum_total = Sum(Field.of("total")).as_("sum_total") - avg_price = Field.of("price").avg().as_("avg_price") + avg_price = Field.of("price").average().as_("avg_price") count = Count(Field.of("total")).as_("count") with pytest.raises(ValueError): self._make_one(sum_total, accumulators=[avg_price, count]) @@ -840,19 +840,19 @@ def _make_one(self, *args, **kwargs): return stages.Where(*args, **kwargs) def test_repr(self): - condition = Field.of("age").gt(30) + condition = Field.of("age").greater_than(30) instance = self._make_one(condition) repr_str = repr(instance) - assert repr_str == "Where(condition=Field.of('age').gt(Constant.of(30)))" + assert repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" def test_to_pb(self): - condition = Field.of("city").eq("SF") + condition = Field.of("city").equal("SF") instance = self._make_one(condition) result = instance._to_pb() assert result.name == "where" assert len(result.args) == 1 got_fn = result.args[0].function_value - assert got_fn.name == "eq" + assert got_fn.name == "equal" assert len(got_fn.args) == 2 assert got_fn.args[0].field_reference_value == "city" assert got_fn.args[1].string_value == "SF" From e1a2f15409b6a78766da86c47c2e35fc4f6a9ef2 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 14:32:14 -0700 Subject: [PATCH 10/16] fixed lint --- google/cloud/firestore_v1/base_pipeline.py | 1 + .../firestore_v1/pipeline_expressions.py | 41 +++++++++++-------- tests/unit/v1/test_aggregation.py | 2 +- tests/unit/v1/test_async_aggregation.py | 2 +- tests/unit/v1/test_base_query.py | 4 +- tests/unit/v1/test_pipeline_expressions.py | 19 +++++++-- tests/unit/v1/test_pipeline_stages.py | 4 +- 7 files changed, 46 insertions(+), 27 deletions(-) diff --git a/google/cloud/firestore_v1/base_pipeline.py b/google/cloud/firestore_v1/base_pipeline.py index 5cf341e58..ac917c5d2 100644 --- a/google/cloud/firestore_v1/base_pipeline.py +++ b/google/cloud/firestore_v1/base_pipeline.py @@ -23,6 +23,7 @@ from google.cloud.firestore_v1.types.firestore import ExecutePipelineRequest from google.cloud.firestore_v1.pipeline_result import PipelineResult from google.cloud.firestore_v1.pipeline_expressions import ( + AliasedAggregate, Expr, Field, BooleanExpr, diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index de9d89fe8..d4ff6a929 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -431,7 +431,9 @@ def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": Returns: A new `Expr` representing the greater than or equal to comparison. """ - return GreaterThanOrEqual(self, self._cast_to_expr_or_convert_to_constant(other)) + return GreaterThanOrEqual( + self, self._cast_to_expr_or_convert_to_constant(other) + ) def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another @@ -501,7 +503,7 @@ def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """ return Not(self.in_any(array)) - def array_concat(self, array: List[Expr | CONSTANT_TYPE]) -> "Expr": + def array_concat(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": """Creates an expression that concatenates an array expression with another array. Example: @@ -779,7 +781,9 @@ def string_contains(self, substring: Expr | str) -> "BooleanExpr": Returns: A new `Expr` representing the 'contains' comparison. """ - return StringContains(self, self._cast_to_expr_or_convert_to_constant(substring)) + return StringContains( + self, self._cast_to_expr_or_convert_to_constant(substring) + ) def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. @@ -959,9 +963,7 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": """ return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) - def euclidean_distance( - self, other: Expr | list[float] | Vector - ) -> "Expr": + def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the Euclidean distance between two vectors. Example: @@ -1568,7 +1570,9 @@ def greater_than(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr" left_expr = Field.of(left) if isinstance(left, str) else left return Expr.greater_than(left_expr, right) - def greater_than_or_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def greater_than_or_equal( + left: Expr | str, right: Expr | CONSTANT_TYPE + ) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. @@ -1604,7 +1608,9 @@ def less_than(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": left_expr = Field.of(left) if isinstance(left, str) else left return Expr.less_than(left_expr, right) - def less_than_or_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": + def less_than_or_equal( + left: Expr | str, right: Expr | CONSTANT_TYPE + ) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. @@ -1622,7 +1628,9 @@ def less_than_or_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Boolea left_expr = Field.of(left) if isinstance(left, str) else left return Expr.less_than_or_equal(left_expr, right) - def in_any(left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def in_any( + left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE] + ) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -1640,7 +1648,9 @@ def in_any(left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanE left_expr = Field.of(left) if isinstance(left, str) else left return Expr.in_any(left_expr, array) - def not_in_any(left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def not_in_any( + left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE] + ) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -2262,6 +2272,7 @@ class TimestampAdd(Function): def __init__(self, timestamp: Expr, unit: Expr, amount: Expr): super().__init__("timestamp_add", [timestamp, unit, amount]) + class Abs(Function): """Represents the absolute value function.""" @@ -2412,7 +2423,7 @@ def __init__(self, left: Expr, right: Expr): class ArrayConcat(Function): """Represents concatenating multiple arrays.""" - def __init__(self, array: Expr, rest: List[Expr]): + def __init__(self, array: Expr, rest: Sequence[Expr]): super().__init__("array_concat", [array] + rest) @@ -2497,7 +2508,6 @@ def as_(self, alias: str) -> "AliasedAggregate": return AliasedAggregate(self, alias) - class Maximum(AggregateFunction): """Finds the maximum value of a field, aggregated across multiple stage inputs.""" @@ -2675,8 +2685,7 @@ def __repr__(self): def _from_query_filter_pb(filter_pb, client): if isinstance(filter_pb, Query_pb.CompositeFilter): sub_filters = [ - BooleanExpr._from_query_filter_pb(f, client) - for f in filter_pb.filters + BooleanExpr._from_query_filter_pb(f, client) for f in filter_pb.filters ] if filter_pb.op == Query_pb.CompositeFilter.Operator.OR: return Or(*sub_filters) @@ -2798,9 +2807,7 @@ class If(BooleanExpr): """Represents a conditional expression (if-then-else).""" def __init__(self, condition: "BooleanExpr", true_expr: Expr, false_expr: Expr): - super().__init__( - "if", [condition, true_expr, false_expr] - ) + super().__init__("if", [condition, true_expr, false_expr]) class In(BooleanExpr): diff --git a/tests/unit/v1/test_aggregation.py b/tests/unit/v1/test_aggregation.py index 7e3784f63..46c2dd4f0 100644 --- a/tests/unit/v1/test_aggregation.py +++ b/tests/unit/v1/test_aggregation.py @@ -1136,7 +1136,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_async_aggregation.py b/tests/unit/v1/test_async_aggregation.py index f3ca8aa93..c69d44dd8 100644 --- a/tests/unit/v1/test_async_aggregation.py +++ b/tests/unit/v1/test_async_aggregation.py @@ -810,7 +810,7 @@ def test_aggreation_to_pipeline_count_increment(): assert len(aggregate_stage.accumulators) == n for i in range(n): assert isinstance(aggregate_stage.accumulators[i].expr, Count) - assert aggregate_stage.accumulators[i].alias == f"field_{i+1}" + assert aggregate_stage.accumulators[i].alias == f"field_{i + 1}" def test_async_aggreation_to_pipeline_complex(): diff --git a/tests/unit/v1/test_base_query.py b/tests/unit/v1/test_base_query.py index 8bf060a60..c13efbfa8 100644 --- a/tests/unit/v1/test_base_query.py +++ b/tests/unit/v1/test_base_query.py @@ -2040,9 +2040,7 @@ def test__query_pipeline_composite_filter(): client = make_client() in_filter = FieldFilter("field_a", "==", "value_a") query = client.collection("my_col").where(filter=in_filter) - with mock.patch.object( - expr.BooleanExpr, "_from_query_filter_pb" - ) as convert_mock: + with mock.patch.object(expr.BooleanExpr, "_from_query_filter_pb") as convert_mock: pipeline = query.pipeline() convert_mock.assert_called_once_with(in_filter._to_pb(), client) assert len(pipeline.stages) == 2 diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 376656929..257d5594e 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -624,20 +624,32 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): @pytest.mark.parametrize( "op_enum, value, expected_expr_func", [ - (query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, 10, expr.LessThan), + ( + query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN, + 10, + expr.LessThan, + ), ( query_pb.StructuredQuery.FieldFilter.Operator.LESS_THAN_OR_EQUAL, 10, expr.LessThanOrEqual, ), - (query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, 10, expr.GreaterThan), + ( + query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN, + 10, + expr.GreaterThan, + ), ( query_pb.StructuredQuery.FieldFilter.Operator.GREATER_THAN_OR_EQUAL, 10, expr.GreaterThanOrEqual, ), (query_pb.StructuredQuery.FieldFilter.Operator.EQUAL, 10, expr.Equal), - (query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, 10, expr.NotEqual), + ( + query_pb.StructuredQuery.FieldFilter.Operator.NOT_EQUAL, + 10, + expr.NotEqual, + ), ( query_pb.StructuredQuery.FieldFilter.Operator.ARRAY_CONTAINS, 10, @@ -1387,7 +1399,6 @@ def test_pow(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Pow(Base, Exponent)" - def test_round(self): arg1 = self._make_arg("Value") instance = expr.Round(arg1) diff --git a/tests/unit/v1/test_pipeline_stages.py b/tests/unit/v1/test_pipeline_stages.py index 049b04cb4..941c4668a 100644 --- a/tests/unit/v1/test_pipeline_stages.py +++ b/tests/unit/v1/test_pipeline_stages.py @@ -843,7 +843,9 @@ def test_repr(self): condition = Field.of("age").greater_than(30) instance = self._make_one(condition) repr_str = repr(instance) - assert repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" + assert ( + repr_str == "Where(condition=Field.of('age').greater_than(Constant.of(30)))" + ) def test_to_pb(self): condition = Field.of("city").equal("SF") From 4de3908d0bfbe8080e8a3fb407143df329b068d9 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 14:49:20 -0700 Subject: [PATCH 11/16] added test for AliasedAggregate --- .../firestore_v1/pipeline_expressions.py | 2 +- tests/unit/v1/test_pipeline_expressions.py | 28 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index d4ff6a929..1042964da 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -2669,7 +2669,7 @@ def __init__( def __repr__(self): """ - Most BooleanExprs can be triggered infix. Eg: Field.of('age').gte(18). + Most BooleanExprs can be triggered infix. Eg: Field.of('age').greater_than(18). Display them this way in the repr string where possible """ diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 257d5594e..b1b468677 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -405,6 +405,34 @@ def test_to_map(self): assert result[0] == "alias1" assert result[1] == Value(field_reference_value="field1") + class TestAliasedAggregate: + + def test_repr(self): + instance = expr.Field.of("field1").maximum().as_("alias1") + assert repr(instance) == "Maximum(Field.of('field1')).as_('alias1')" + + def test_ctor(self): + arg = expr.Field.of("field1").minimum() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + assert instance.expr == arg + assert instance.alias == alias + + def test_to_pb(self): + arg = expr.Field.of("field1").average() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_pb() + assert result.map_value.fields.get("alias1") == arg._to_pb() + + def test_to_map(self): + arg = expr.Field.of("field1").count() + alias = "alias1" + instance = expr.AliasedAggregate(arg, alias) + result = instance._to_map() + assert result[0] == "alias1" + assert result[1] == arg._to_pb() + class TestBooleanExpr: def test__from_query_filter_pb_composite_filter_or(self, mock_client): From 5fd0ba4db7bf1610a1e661d216b73807b3b6e701 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 16:16:39 -0700 Subject: [PATCH 12/16] removed unused expressions --- .../firestore_v1/pipeline_expressions.py | 81 ------------------- tests/unit/v1/test_pipeline_expressions.py | 33 +++----- 2 files changed, 9 insertions(+), 105 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 1042964da..c8b47b16d 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -884,52 +884,6 @@ def reverse(self) -> "Expr": """ return Reverse(self) - def replace_first(self, find: Expr | str, replace: Expr | str) -> "Expr": - """Creates an expression that replaces the first occurrence of a substring within a string with - another substring. - - Example: - >>> # Replace the first occurrence of "hello" with "hi" in the 'message' field - >>> Field.of("message").replace_first("hello", "hi") - >>> # Replace the first occurrence of the value in 'findField' with the value in 'replaceField' in the 'message' field - >>> Field.of("message").replace_first(Field.of("findField"), Field.of("replaceField")) - - Args: - find: The substring (string or expression) to search for. - replace: The substring (string or expression) to replace the first occurrence of 'find' with. - - Returns: - A new `Expr` representing the string with the first occurrence replaced. - """ - return ReplaceFirst( - self, - self._cast_to_expr_or_convert_to_constant(find), - self._cast_to_expr_or_convert_to_constant(replace), - ) - - def replace_all(self, find: Expr | str, replace: Expr | str) -> "Expr": - """Creates an expression that replaces all occurrences of a substring within a string with another - substring. - - Example: - >>> # Replace all occurrences of "hello" with "hi" in the 'message' field - >>> Field.of("message").replace_all("hello", "hi") - >>> # Replace all occurrences of the value in 'findField' with the value in 'replaceField' in the 'message' field - >>> Field.of("message").replace_all(Field.of("findField"), Field.of("replaceField")) - - Args: - find: The substring (string or expression) to search for. - replace: The substring (string or expression) to replace all occurrences of 'find' with. - - Returns: - A new `Expr` representing the string with all occurrences replaced. - """ - return ReplaceAll( - self, - self._cast_to_expr_or_convert_to_constant(find), - self._cast_to_expr_or_convert_to_constant(replace), - ) - def map_get(self, key: str) -> "Expr": """Accesses a value from a map (object) field using the provided key. @@ -2231,20 +2185,6 @@ def __init__(self, value: Expr): super().__init__("parent", [value]) -class ReplaceAll(Function): - """Represents replacing all occurrences of a substring.""" - - def __init__(self, value: Expr, pattern: Expr, replacement: Expr): - super().__init__("replace_all", [value, pattern, replacement]) - - -class ReplaceFirst(Function): - """Represents replacing the first occurrence of a substring.""" - - def __init__(self, value: Expr, pattern: Expr, replacement: Expr): - super().__init__("replace_first", [value, pattern, replacement]) - - class Reverse(Function): """Represents reversing a string.""" @@ -2427,20 +2367,6 @@ def __init__(self, array: Expr, rest: Sequence[Expr]): super().__init__("array_concat", [array] + rest) -class ArrayElement(Function): - """Represents accessing an element within an array""" - - def __init__(self): - super().__init__("array_element", []) - - -class ArrayFilter(Function): - """Represents filtering elements from an array based on a condition.""" - - def __init__(self, array: Expr, filter: "BooleanExpr"): - super().__init__("array_filter", [array, filter]) - - class ArrayLength(Function): """Represents getting the length of an array.""" @@ -2455,13 +2381,6 @@ def __init__(self, array: Expr): super().__init__("array_reverse", [array]) -class ArrayTransform(Function): - """Represents applying a transformation function to each element of an array.""" - - def __init__(self, array: Expr, transform: Function): - super().__init__("array_transform", [array, transform]) - - class ByteLength(Function): """Represents getting the byte length of a string (UTF-8).""" diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index b1b468677..9ac1ce130 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -135,8 +135,6 @@ def test_ctor(self): ("to_upper", (), expr.ToUpper), ("trim", (), expr.Trim), ("reverse", (), expr.Reverse), - ("replace_first", ("1", "2"), expr.ReplaceFirst), - ("replace_all", ("1", "2"), expr.ReplaceAll), ("map_get", ("key",), expr.MapGet), ("cosine_distance", [1], expr.CosineDistance), ("euclidean_distance", [1], expr.EuclideanDistance), @@ -1037,14 +1035,18 @@ def test_function_builder(self, method, args, result_cls): @pytest.mark.parametrize( "first,second,expected", [ - (expr.ArrayElement(), expr.ArrayElement(), True), - (expr.ArrayElement(), expr.CharLength(1), False), - (expr.ArrayElement(), object(), False), - (expr.ArrayElement(), None, False), - (expr.CharLength(1), expr.ArrayElement(), False), + (expr.Array([]), expr.Array([]), True), + (expr.Array([]), expr.CharLength(1), False), + (expr.Array([]), object(), False), + (expr.Array([]), None, False), + (expr.CharLength(1), expr.Array([]), False), (expr.CharLength(1), expr.CharLength(2), False), (expr.CharLength(1), expr.CharLength(1), True), (expr.CharLength(1), expr.ByteLength(1), False), + (expr.Array([1]), expr.Array([1]), True), + (expr.Array([1]), expr.Array([2]), False), + (expr.Array([1]), expr.Array([]), False), + (expr.Array([1, 2]), expr.Array([1]), False), ], ) def test_equality(self, first, second, expected): @@ -1322,23 +1324,6 @@ def test_cosine_distance(self): assert instance.params == [arg1, arg2] assert repr(instance) == "CosineDistance(Left, Right)" - def test_replace_all(self): - arg1 = self._make_arg("Expr") - arg2 = self._make_arg("OldValue") - arg3 = self._make_arg("NewValue") - instance = expr.ReplaceAll(arg1, arg2, arg3) - assert instance.name == "replace_all" - assert instance.params == [arg1, arg2, arg3] - - def test_replace_first(self): - arg1 = self._make_arg("Expr") - arg2 = self._make_arg("OldValue") - arg3 = self._make_arg("NewValue") - instance = expr.ReplaceFirst(arg1, arg2, arg3) - assert instance.name == "replace_first" - assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "ReplaceFirst(Expr, OldValue, NewValue)" - def test_reverse(self): arg1 = self._make_arg("Expr") instance = expr.Reverse(arg1) From 5a3eb149601429a65daeb39637e4459f7b4ca171 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 16:42:16 -0700 Subject: [PATCH 13/16] improved array functions --- .../firestore_v1/pipeline_expressions.py | 126 +++++++++++++++--- tests/unit/v1/test_pipeline_expressions.py | 79 ++++++----- 2 files changed, 149 insertions(+), 56 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index c8b47b16d..1a7551cab 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -471,13 +471,13 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return LessThanOrEqual(self, self._cast_to_expr_or_convert_to_constant(other)) - def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. Example: >>> # Check if the 'category' field is either "Electronics" or value of field 'primaryType' - >>> Field.of("category").in_any(["Electronics", Field.of("primaryType")]) + >>> Field.of("category").equal_any(["Electronics", Field.of("primaryType")]) Args: array: The values or expressions to check against. @@ -485,15 +485,15 @@ def in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": Returns: A new `Expr` representing the 'IN' comparison. """ - return In(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + return EqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) - def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": + def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. Example: >>> # Check if the 'status' field is neither "pending" nor "cancelled" - >>> Field.of("status").not_in_any(["pending", "cancelled"]) + >>> Field.of("status").not_equal_any(["pending", "cancelled"]) Args: array: The values or expressions to check against. @@ -501,7 +501,38 @@ def not_in_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": Returns: A new `Expr` representing the 'NOT IN' comparison. """ - return Not(self.in_any(array)) + return NotEqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + + def array(self, elements: list[Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that creates a Firestore array value from an input list. + + Example: + >>> Expr.array(["bar", Field.of("baz").as("foo"))]) + + Args: + elements: THe input list to evaluate in the expression + + Returns: + A new `Expr` representing the array function. + """ + return Array(elements) + + def array_get(self, index: Expr | int) -> "Expr": + """Creates an expression that indexes into an array from the beginning or end + and returns the element. If the index exceeds the array length, an error is + returned. A negative index, starts from the end. + + Example: + >>> # Return the value in the tags field array at index specified by field 'favoriteTag'. + >>> Field.of("tags").array_get(Field.of("favoriteTag")) + + Args: + index: The index of the element to return. + + Returns: + A new `Expr` representing the operation. + """ + return ArrayGet(self, self._cast_to_expr_or_convert_to_constant(index)) def array_concat(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": """Creates an expression that concatenates an array expression with another array. @@ -1582,15 +1613,15 @@ def less_than_or_equal( left_expr = Field.of(left) if isinstance(left, str) else left return Expr.less_than_or_equal(left_expr, right) - def in_any( + def equal_any( left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE] ) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. Example: - >>> Function.in_any("category", ["Electronics", "Apparel"]) - >>> Function.in_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) + >>> Function.equal_any("category", ["Electronics", "Apparel"]) + >>> Function.equal_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) Args: left: The expression or field path to compare. @@ -1600,16 +1631,16 @@ def in_any( A new `Expr` representing the 'IN' comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.in_any(left_expr, array) + return Expr.equal_any(left_expr, array) - def not_in_any( + def not_equal_any( left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE] ) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. Example: - >>> Function.not_in_any("status", ["pending", "cancelled"]) + >>> Function.not_equal_any("status", ["pending", "cancelled"]) Args: left: The expression or field path to compare. @@ -1619,7 +1650,42 @@ def not_in_any( A new `Expr` representing the 'NOT IN' comparison. """ left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.not_in_any(left_expr, array) + return Expr.not_equal_any(left_expr, array) + + def array_get(array: Expr | str, index: Expr | int) -> "Expr": + """Creates an expression that indexes into an array from the beginning or end + and returns the element. If the index exceeds the array length, an error is + returned. A negative index, starts from the end. + + Example: + >>> # Return the value in the tags field array at index specified by field 'favoriteTag'. + >>> Field.of("tags").array_get(Field.of("favoriteTag")) + + Args: + index: The index of the element to return. + + Returns: + A new `Expr` representing the operation. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_get(array_expr, index) + + def array_concat(array: Expr | str, other: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that concatenates an array expression with another array. + + Example: + >>> # Combine the 'tags' array with a new array and an array field + >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) + + Args: + array: The list of constants or expressions to concat with. + + Returns: + A new `Expr` representing the concatenated array. + """ + array_expr = Field.of(array) if isinstance(array, str) else array + return Expr.array_concat(array_expr, other) + def array_contains( array: Expr | str, element: Expr | CONSTANT_TYPE @@ -2360,6 +2426,23 @@ def __init__(self, left: Expr, right: Expr): super().__init__("add", [left, right]) +class Array(Function): + """Creates an expression that creates a Firestore array value from an input list.""" + + def __init__(self, elements: list[Expr]): + super().__init__("array", elements) + + def __repr__(self): + return f"Array({self.params})" + + +class ArrayGet(Function): + """Creates an expression that indexes into an array from the beginning or end and returns an element.""" + + def __init__(self, array: Expr, index: Expr): + super().__init__("array_get", [array, index]) + + class ArrayConcat(Function): """Represents concatenating multiple arrays.""" @@ -2646,9 +2729,9 @@ def _from_query_filter_pb(filter_pb, client): elif filter_pb.op == Query_pb.FieldFilter.Operator.ARRAY_CONTAINS_ANY: return And(field.exists(), field.array_contains_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.IN: - return And(field.exists(), field.in_any(value)) + return And(field.exists(), field.equal_any(value)) elif filter_pb.op == Query_pb.FieldFilter.Operator.NOT_IN: - return And(field.exists(), field.not_in_any(value)) + return And(field.exists(), field.not_equal_any(value)) else: raise TypeError(f"Unexpected FieldFilter operator type: {filter_pb.op}") elif isinstance(filter_pb, Query_pb.Filter): @@ -2729,13 +2812,18 @@ def __init__(self, condition: "BooleanExpr", true_expr: Expr, false_expr: Expr): super().__init__("if", [condition, true_expr, false_expr]) -class In(BooleanExpr): +class EqualAny(BooleanExpr): """Represents checking if an expression's value is within a list of values.""" def __init__(self, left: Expr, others: Sequence[Expr]): - super().__init__( - "in", [left, ListOfExprs(others)], infix_name_override="in_any" - ) + super().__init__("equal_any", [left, ListOfExprs(others)]) + + +class NotEqualAny(BooleanExpr): + """Represents checking if an expression's value is not within a list of values.""" + + def __init__(self, left: Expr, others: Sequence[Expr]): + super().__init__("not_equal_any", [left, ListOfExprs(others)]) class IsNaN(BooleanExpr): diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 9ac1ce130..f765491e9 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -107,8 +107,9 @@ def test_ctor(self): ("less_than_or_equal", (2,), expr.LessThanOrEqual), ("greater_than", (2,), expr.GreaterThan), ("greater_than_or_equal", (2,), expr.GreaterThanOrEqual), - ("in_any", ([None],), expr.In), - ("not_in_any", ([None],), expr.Not), + ("equal_any", ([None],), expr.EqualAny), + ("not_equal_any", ([None],), expr.NotEqualAny), + ("array_get", (1,), expr.ArrayGet), ("array_concat", ([None],), expr.ArrayConcat), ("array_contains", (None,), expr.ArrayContains), ("array_contains_all", ([None],), expr.ArrayContainsAll), @@ -170,7 +171,9 @@ def test_infix_call(self, method, args, result_cls, base_instance): result = method_ptr(*args) assert isinstance(result, result_cls) - if isinstance(result, expr.Function) and not method == "not_in_any": + if isinstance(result, (expr.Ordering, expr.AliasedExpr)): + assert result.expr == base_instance + else: assert result.params[0] == base_instance @@ -686,12 +689,8 @@ def test__from_query_filter_pb_unary_filter_unknown_op(self, mock_client): [10, 20], expr.ArrayContainsAny, ), - (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.In), - ( - query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, - [10, 20], - lambda f, v: expr.Not(f.in_any(v)), - ), + (query_pb.StructuredQuery.FieldFilter.Operator.IN, [10, 20], expr.EqualAny), + (query_pb.StructuredQuery.FieldFilter.Operator.NOT_IN, [10, 20], expr.NotEqualAny), ], ) def test__from_query_filter_pb_field_filter( @@ -852,16 +851,27 @@ def test_not_equal(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Left.not_equal(Right)" - def test_in(self): + def test_equal_any(self): arg1 = self._make_arg("Field") arg2 = self._make_arg("Value1") arg3 = self._make_arg("Value2") - instance = expr.In(arg1, [arg2, arg3]) - assert instance.name == "in" + instance = expr.EqualAny(arg1, [arg2, arg3]) + assert instance.name == "equal_any" assert isinstance(instance.params[1], ListOfExprs) assert instance.params[0] == arg1 assert instance.params[1].exprs == [arg2, arg3] - assert repr(instance) == "Field.in_any(ListOfExprs([Value1, Value2]))" + assert repr(instance) == "Field.equal_any(ListOfExprs([Value1, Value2]))" + + def test_not_equal_any(self): + arg1 = self._make_arg("Field") + arg2 = self._make_arg("Value1") + arg3 = self._make_arg("Value2") + instance = expr.NotEqualAny(arg1, [arg2, arg3]) + assert instance.name == "not_equal_any" + assert isinstance(instance.params[1], ListOfExprs) + assert instance.params[0] == arg1 + assert instance.params[1].exprs == [arg2, arg3] + assert repr(instance) == "Field.not_equal_any(ListOfExprs([Value1, Value2]))" def test_is_nan(self): arg1 = self._make_arg("Value") @@ -987,8 +997,10 @@ class TestFunctionClasses: ("less_than_or_equal", ("field", 2), expr.LessThanOrEqual), ("greater_than", ("field", 2), expr.GreaterThan), ("greater_than_or_equal", ("field", 2), expr.GreaterThanOrEqual), - ("in_any", ("field", [None]), expr.In), - ("not_in_any", ("field", [None]), expr.Not), + ("equal_any", ("field", [None]), expr.EqualAny), + ("not_equal_any", ("field", [None]), expr.NotEqualAny), + ("array", ("field", [1, 2, 3]), expr.Array), + ("array_get", ("field", 2), expr.ArrayGet), ("array_contains", ("field", None), expr.ArrayContains), ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), ("array_contains_any", ("field", [None]), expr.ArrayContainsAny), @@ -1203,20 +1215,6 @@ def test_add(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Add(Left, Right)" - def test_array_element(self): - instance = expr.ArrayElement() - assert instance.name == "array_element" - assert instance.params == [] - assert repr(instance) == "ArrayElement()" - - def test_array_filter(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("FilterCond") - instance = expr.ArrayFilter(arg1, arg2) - assert instance.name == "array_filter" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayFilter(Array, FilterCond)" - def test_array_length(self): arg1 = self._make_arg("Array") instance = expr.ArrayLength(arg1) @@ -1231,14 +1229,6 @@ def test_array_reverse(self): assert instance.params == [arg1] assert repr(instance) == "ArrayReverse(Array)" - def test_array_transform(self): - arg1 = self._make_arg("Array") - arg2 = self._make_arg("TransformFunc") - instance = expr.ArrayTransform(arg1, arg2) - assert instance.name == "array_transform" - assert instance.params == [arg1, arg2] - assert repr(instance) == "ArrayTransform(Array, TransformFunc)" - def test_byte_length(self): arg1 = self._make_arg("Expr") instance = expr.ByteLength(arg1) @@ -1352,6 +1342,21 @@ def test_trim(self): assert instance.params == [arg1] assert repr(instance) == "Trim(Expr)" + def test_array(self): + arg = self._make_arg("Value") + instance = expr.Array([1, 2, arg]) + assert instance.name == "array" + assert instance.params == [1, 2, arg] + assert repr(instance) == "Array([1, 2, Value])" + + def test_array_get(self): + arg1 = self._make_arg("Array") + arg2 = self._make_arg("Index") + instance = expr.ArrayGet(arg1, arg2) + assert instance.name == "array_get" + assert instance.params == [arg1, arg2] + assert repr(instance) == "ArrayGet(Array, Index)" + def test_array_concat(self): arg1 = self._make_arg("1") arg2 = self._make_arg("2") From 3a35035b786b90947140847209a6da453192c3d3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 20:07:54 -0700 Subject: [PATCH 14/16] removed duplicate code --- .../firestore_v1/pipeline_expressions.py | 1040 ++--------------- tests/unit/v1/test_pipeline_expressions.py | 3 +- 2 files changed, 94 insertions(+), 949 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 1a7551cab..c9c03529a 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -116,6 +116,32 @@ def _to_pb(self) -> Value: def _cast_to_expr_or_convert_to_constant(o: Any) -> "Expr": return o if isinstance(o, Expr) else Constant(o) + class expose_as_static: + """ + Decorator to mark instance methods to be exposed as static methods as well as instance + methods. + + When called statically, the first argument is converted to a Field expression if needed. + + Example: + >>> Field.of("test").add(5) + >>> Function.add("test", 5) + """ + + def __init__(self, instance_func): + self.instance_func = instance_func + + def static_func(self, first_arg, *other_args, **kwargs): + first_expr = Field.of(first_arg) if not isinstance(first_arg, Expr) else first_arg + return self.instance_func(first_expr, *other_args, **kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self.static_func.__get__(instance, owner) + else: + return self.instance_func.__get__(instance, owner) + + @expose_as_static def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. @@ -133,6 +159,7 @@ def add(self, other: Expr | float) -> "Expr": """ return Add(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def subtract(self, other: Expr | float) -> "Expr": """Creates an expression that subtracts another expression or constant from this expression. @@ -150,6 +177,7 @@ def subtract(self, other: Expr | float) -> "Expr": """ return Subtract(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def multiply(self, other: Expr | float) -> "Expr": """Creates an expression that multiplies this expression by another expression or constant. @@ -167,6 +195,7 @@ def multiply(self, other: Expr | float) -> "Expr": """ return Multiply(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def divide(self, other: Expr | float) -> "Expr": """Creates an expression that divides this expression by another expression or constant. @@ -184,6 +213,7 @@ def divide(self, other: Expr | float) -> "Expr": """ return Divide(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def mod(self, other: Expr | float) -> "Expr": """Creates an expression that calculates the modulo (remainder) to another expression or constant. @@ -201,6 +231,7 @@ def mod(self, other: Expr | float) -> "Expr": """ return Mod(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def abs(self) -> "Abs": """Creates an expression that calculates the absolute value of this expression. @@ -213,6 +244,7 @@ def abs(self) -> "Abs": """ return Abs(self) + @expose_as_static def ceil(self) -> "Ceil": """Creates an expression that calculates the ceiling of this expression. @@ -225,6 +257,7 @@ def ceil(self) -> "Ceil": """ return Ceil(self) + @expose_as_static def exp(self) -> "Exp": """Creates an expression that computes e to the power of this expression. @@ -237,6 +270,7 @@ def exp(self) -> "Exp": """ return Exp(self) + @expose_as_static def floor(self) -> "Floor": """Creates an expression that calculates the floor of this expression. @@ -249,6 +283,7 @@ def floor(self) -> "Floor": """ return Floor(self) + @expose_as_static def ln(self) -> "Ln": """Creates an expression that calculates the natural logarithm of this expression. @@ -261,6 +296,7 @@ def ln(self) -> "Ln": """ return Ln(self) + @expose_as_static def log(self, base: Expr | float) -> "Log": """Creates an expression that calculates the logarithm of this expression with a given base. @@ -278,6 +314,7 @@ def log(self, base: Expr | float) -> "Log": """ return Log(self, self._cast_to_expr_or_convert_to_constant(base)) + @expose_as_static def pow(self, exponent: Expr | float) -> "Pow": """Creates an expression that calculates this expression raised to the power of the exponent. @@ -295,6 +332,7 @@ def pow(self, exponent: Expr | float) -> "Pow": """ return Pow(self, self._cast_to_expr_or_convert_to_constant(exponent)) + @expose_as_static def round(self) -> "Round": """Creates an expression that rounds this expression to the nearest integer. @@ -307,6 +345,7 @@ def round(self) -> "Round": """ return Round(self) + @expose_as_static def sqrt(self) -> "Sqrt": """Creates an expression that calculates the square root of this expression. @@ -319,6 +358,7 @@ def sqrt(self) -> "Sqrt": """ return Sqrt(self) + @expose_as_static def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the larger value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -340,6 +380,7 @@ def logical_maximum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """ return LogicalMaximum(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that returns the smaller value between this expression and another expression or constant, based on Firestore's value type ordering. @@ -361,6 +402,7 @@ def logical_minimum(self, other: Expr | CONSTANT_TYPE) -> "Expr": """ return LogicalMinimum(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to another expression or constant value. @@ -379,6 +421,7 @@ def equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return Equal(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to another expression or constant value. @@ -397,6 +440,7 @@ def not_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return NotEqual(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than another expression or constant value. @@ -415,6 +459,7 @@ def greater_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return GreaterThan(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is greater than or equal to another expression or constant value. @@ -435,6 +480,7 @@ def greater_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": self, self._cast_to_expr_or_convert_to_constant(other) ) + @expose_as_static def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than another expression or constant value. @@ -453,6 +499,7 @@ def less_than(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return LessThan(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if this expression is less than or equal to another expression or constant value. @@ -471,6 +518,7 @@ def less_than_or_equal(self, other: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return LessThanOrEqual(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is equal to any of the provided values or expressions. @@ -487,6 +535,7 @@ def equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """ return EqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) + @expose_as_static def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """Creates an expression that checks if this expression is not equal to any of the provided values or expressions. @@ -503,11 +552,12 @@ def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """ return NotEqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) - def array(self, elements: list[Expr | CONSTANT_TYPE]) -> "Expr": + @staticmethod + def array(elements: list[Expr | CONSTANT_TYPE]) -> "Expr": """Creates an expression that creates a Firestore array value from an input list. Example: - >>> Expr.array(["bar", Field.of("baz").as("foo"))]) + >>> Expr.array(["bar", Field.of("baz")]) Args: elements: THe input list to evaluate in the expression @@ -515,8 +565,9 @@ def array(self, elements: list[Expr | CONSTANT_TYPE]) -> "Expr": Returns: A new `Expr` representing the array function. """ - return Array(elements) + return Array([Expr._cast_to_expr_or_convert_to_constant(e) for e in elements]) + @expose_as_static def array_get(self, index: Expr | int) -> "Expr": """Creates an expression that indexes into an array from the beginning or end and returns the element. If the index exceeds the array length, an error is @@ -534,6 +585,7 @@ def array_get(self, index: Expr | int) -> "Expr": """ return ArrayGet(self, self._cast_to_expr_or_convert_to_constant(index)) + @expose_as_static def array_concat(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": """Creates an expression that concatenates an array expression with another array. @@ -551,6 +603,7 @@ def array_concat(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": self, [self._cast_to_expr_or_convert_to_constant(o) for o in array] ) + @expose_as_static def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": """Creates an expression that checks if an array contains a specific element or value. @@ -568,6 +621,7 @@ def array_contains(self, element: Expr | CONSTANT_TYPE) -> "BooleanExpr": """ return ArrayContains(self, self._cast_to_expr_or_convert_to_constant(element)) + @expose_as_static def array_contains_all( self, elements: Sequence[Expr | CONSTANT_TYPE] ) -> "BooleanExpr": @@ -589,6 +643,7 @@ def array_contains_all( self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] ) + @expose_as_static def array_contains_any( self, elements: Sequence[Expr | CONSTANT_TYPE] ) -> "BooleanExpr": @@ -611,6 +666,7 @@ def array_contains_any( self, [self._cast_to_expr_or_convert_to_constant(e) for e in elements] ) + @expose_as_static def array_length(self) -> "Expr": """Creates an expression that calculates the length of an array. @@ -623,6 +679,7 @@ def array_length(self) -> "Expr": """ return ArrayLength(self) + @expose_as_static def array_reverse(self) -> "Expr": """Creates an expression that returns the reversed content of an array. @@ -635,6 +692,7 @@ def array_reverse(self) -> "Expr": """ return ArrayReverse(self) + @expose_as_static def is_nan(self) -> "BooleanExpr": """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). @@ -647,6 +705,7 @@ def is_nan(self) -> "BooleanExpr": """ return IsNaN(self) + @expose_as_static def exists(self) -> "BooleanExpr": """Creates an expression that checks if a field exists in the document. @@ -659,6 +718,7 @@ def exists(self) -> "BooleanExpr": """ return Exists(self) + @expose_as_static def sum(self) -> "Expr": """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. @@ -671,6 +731,7 @@ def sum(self) -> "Expr": """ return Sum(self) + @expose_as_static def average(self) -> "Expr": """Creates an aggregation that calculates the average (mean) of a numeric field across multiple stage inputs. @@ -684,6 +745,7 @@ def average(self) -> "Expr": """ return Average(self) + def count(self) -> "Expr": """Creates an aggregation that counts the number of stage inputs with valid evaluations of the expression or field. @@ -697,6 +759,7 @@ def count(self) -> "Expr": """ return Count(self) + @expose_as_static def minimum(self) -> "Expr": """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. @@ -709,6 +772,7 @@ def minimum(self) -> "Expr": """ return Minimum(self) + @expose_as_static def maximum(self) -> "Expr": """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. @@ -721,6 +785,7 @@ def maximum(self) -> "Expr": """ return Maximum(self) + @expose_as_static def char_length(self) -> "Expr": """Creates an expression that calculates the character length of a string. @@ -733,6 +798,7 @@ def char_length(self) -> "Expr": """ return CharLength(self) + @expose_as_static def byte_length(self) -> "Expr": """Creates an expression that calculates the byte length of a string in its UTF-8 form. @@ -745,6 +811,7 @@ def byte_length(self) -> "Expr": """ return ByteLength(self) + @expose_as_static def like(self, pattern: Expr | str) -> "BooleanExpr": """Creates an expression that performs a case-sensitive string comparison. @@ -762,6 +829,7 @@ def like(self, pattern: Expr | str) -> "BooleanExpr": """ return Like(self, self._cast_to_expr_or_convert_to_constant(pattern)) + @expose_as_static def regex_contains(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string contains a specified regular expression as a substring. @@ -780,6 +848,7 @@ def regex_contains(self, regex: Expr | str) -> "BooleanExpr": """ return RegexContains(self, self._cast_to_expr_or_convert_to_constant(regex)) + @expose_as_static def regex_matches(self, regex: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string matches a specified regular expression. @@ -797,6 +866,7 @@ def regex_matches(self, regex: Expr | str) -> "BooleanExpr": """ return RegexMatch(self, self._cast_to_expr_or_convert_to_constant(regex)) + @expose_as_static def string_contains(self, substring: Expr | str) -> "BooleanExpr": """Creates an expression that checks if this string expression contains a specified substring. @@ -816,6 +886,7 @@ def string_contains(self, substring: Expr | str) -> "BooleanExpr": self, self._cast_to_expr_or_convert_to_constant(substring) ) + @expose_as_static def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string starts with a given prefix. @@ -833,6 +904,7 @@ def starts_with(self, prefix: Expr | str) -> "BooleanExpr": """ return StartsWith(self, self._cast_to_expr_or_convert_to_constant(prefix)) + @expose_as_static def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """Creates an expression that checks if a string ends with a given postfix. @@ -850,6 +922,7 @@ def ends_with(self, postfix: Expr | str) -> "BooleanExpr": """ return EndsWith(self, self._cast_to_expr_or_convert_to_constant(postfix)) + @expose_as_static def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": """Creates an expression that concatenates string expressions, fields or constants together. @@ -867,6 +940,7 @@ def string_concat(self, *elements: Expr | CONSTANT_TYPE) -> "Expr": self, *[self._cast_to_expr_or_convert_to_constant(el) for el in elements] ) + @expose_as_static def to_lower(self) -> "Expr": """Creates an expression that converts a string to lowercase. @@ -879,6 +953,7 @@ def to_lower(self) -> "Expr": """ return ToLower(self) + @expose_as_static def to_upper(self) -> "Expr": """Creates an expression that converts a string to uppercase. @@ -891,6 +966,7 @@ def to_upper(self) -> "Expr": """ return ToUpper(self) + @expose_as_static def trim(self) -> "Expr": """Creates an expression that removes leading and trailing whitespace from a string. @@ -903,6 +979,7 @@ def trim(self) -> "Expr": """ return Trim(self) + @expose_as_static def reverse(self) -> "Expr": """Creates an expression that reverses a string. @@ -915,6 +992,7 @@ def reverse(self) -> "Expr": """ return Reverse(self) + @expose_as_static def map_get(self, key: str) -> "Expr": """Accesses a value from a map (object) field using the provided key. @@ -931,6 +1009,7 @@ def map_get(self, key: str) -> "Expr": """ return MapGet(self, Constant.of(key)) + @expose_as_static def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the cosine distance between two vectors. @@ -948,6 +1027,7 @@ def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": """ return CosineDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the Euclidean distance between two vectors. @@ -965,6 +1045,7 @@ def euclidean_distance(self, other: Expr | list[float] | Vector) -> "Expr": """ return EuclideanDistance(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the dot product between two vectors. @@ -982,6 +1063,7 @@ def dot_product(self, other: Expr | list[float] | Vector) -> "Expr": """ return DotProduct(self, self._cast_to_expr_or_convert_to_constant(other)) + @expose_as_static def vector_length(self) -> "Expr": """Creates an expression that calculates the length (dimension) of a Firestore Vector. @@ -994,6 +1076,7 @@ def vector_length(self) -> "Expr": """ return VectorLength(self) + @expose_as_static def timestamp_to_unix_micros(self) -> "Expr": """Creates an expression that converts a timestamp to the number of microseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1009,6 +1092,7 @@ def timestamp_to_unix_micros(self) -> "Expr": """ return TimestampToUnixMicros(self) + @expose_as_static def unix_micros_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1022,6 +1106,7 @@ def unix_micros_to_timestamp(self) -> "Expr": """ return UnixMicrosToTimestamp(self) + @expose_as_static def timestamp_to_unix_millis(self) -> "Expr": """Creates an expression that converts a timestamp to the number of milliseconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1037,6 +1122,7 @@ def timestamp_to_unix_millis(self) -> "Expr": """ return TimestampToUnixMillis(self) + @expose_as_static def unix_millis_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1050,6 +1136,7 @@ def unix_millis_to_timestamp(self) -> "Expr": """ return UnixMillisToTimestamp(self) + @expose_as_static def timestamp_to_unix_seconds(self) -> "Expr": """Creates an expression that converts a timestamp to the number of seconds since the epoch (1970-01-01 00:00:00 UTC). @@ -1065,6 +1152,7 @@ def timestamp_to_unix_seconds(self) -> "Expr": """ return TimestampToUnixSeconds(self) + @expose_as_static def unix_seconds_to_timestamp(self) -> "Expr": """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 UTC) to a timestamp. @@ -1078,6 +1166,7 @@ def unix_seconds_to_timestamp(self) -> "Expr": """ return UnixSecondsToTimestamp(self) + @expose_as_static def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that adds a specified amount of time to this timestamp expression. @@ -1101,6 +1190,7 @@ def timestamp_add(self, unit: Expr | str, amount: Expr | float) -> "Expr": self._cast_to_expr_or_convert_to_constant(amount), ) + @expose_as_static def timestamp_subtract(self, unit: Expr | str, amount: Expr | float) -> "Expr": """Creates an expression that subtracts a specified amount of time from this timestamp expression. @@ -1237,950 +1327,6 @@ def _to_pb(self): } ) - def add(left: Expr | str, right: Expr | float) -> "Expr": - """Creates an expression that adds two expressions together. - - Example: - >>> Function.add("rating", 5) - >>> Function.add(Field.of("quantity"), Field.of("reserve")) - - Args: - left: The first expression or field path to add. - right: The second expression or constant value to add. - - Returns: - A new `Expr` representing the addition operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.add(left_expr, right) - - def subtract(left: Expr | str, right: Expr | float) -> "Expr": - """Creates an expression that subtracts another expression or constant from this expression. - - Example: - >>> Function.subtract("total", 20) - >>> Function.subtract(Field.of("price"), Field.of("discount")) - - Args: - left: The expression or field path to subtract from. - right: The expression or constant value to subtract. - - Returns: - A new `Expr` representing the subtraction operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.subtract(left_expr, right) - - def multiply(left: Expr | str, right: Expr | float) -> "Expr": - """Creates an expression that multiplies this expression by another expression or constant. - - Example: - >>> Function.multiply("value", 2) - >>> Function.multiply(Field.of("quantity"), Field.of("price")) - - Args: - left: The expression or field path to multiply. - right: The expression or constant value to multiply by. - - Returns: - A new `Expr` representing the multiplication operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.multiply(left_expr, right) - - def divide(left: Expr | str, right: Expr | float) -> "Expr": - """Creates an expression that divides this expression by another expression or constant. - - Example: - >>> Function.divide("value", 10) - >>> Function.divide(Field.of("total"), Field.of("count")) - - Args: - left: The expression or field path to be divided. - right: The expression or constant value to divide by. - - Returns: - A new `Expr` representing the division operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.divide(left_expr, right) - - def mod(left: Expr | str, right: Expr | float) -> "Expr": - """Creates an expression that calculates the modulo (remainder) to another expression or constant. - - Example: - >>> Function.mod("value", 5) - >>> Function.mod(Field.of("value"), Field.of("divisor")) - - Args: - left: The dividend expression or field path. - right: The divisor expression or constant. - - Returns: - A new `Expr` representing the modulo operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.mod(left_expr, right) - - def abs(expr: Expr | str) -> "Abs": - """Creates an expression that calculates the absolute value of an expression. - - Example: - >>> Function.abs("change") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the absolute value. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.abs(expr_val) - - def ceil(expr: Expr | str) -> "Ceil": - """Creates an expression that calculates the ceiling of an expression. - - Example: - >>> Function.ceil("value") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the ceiling value. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.ceil(expr_val) - - def exp(expr: Expr | str) -> "Exp": - """Creates an expression that calculates the exponential of an expression. - - Example: - >>> Function.exp("value") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the exponential value. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.exp(expr_val) - - def floor(expr: Expr | str) -> "Floor": - """Creates an expression that calculates the floor of an expression. - - Example: - >>> Function.floor("value") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the floor value. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.floor(expr_val) - - def ln(expr: Expr | str) -> "Ln": - """Creates an expression that calculates the natural logarithm of an expression. - - Example: - >>> Function.ln("value") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the natural logarithm. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.ln(expr_val) - - def log(expr: Expr | str, base: Expr | float) -> "Log": - """Creates an expression that calculates the logarithm of an expression with a given base. - - Example: - >>> Function.log("value", 2) - - Args: - expr: The expression or field path. - base: The base of the logarithm. - - Returns: - A new `Expr` representing the logarithm. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.log(expr_val, base) - - def pow(base: Expr | str, exponent: Expr | float) -> "Pow": - """Creates an expression that calculates the base raised to the power of the exponent. - - Example: - >>> Function.pow("base_val", 2) - - Args: - base: The base expression or field path. - exponent: The exponent. - - Returns: - A new `Expr` representing the power operation. - """ - base_val = Field.of(base) if isinstance(base, str) else base - return Expr.pow(base_val, exponent) - - def round(expr: Expr | str) -> "Round": - """Creates an expression that rounds an expression to the nearest integer. - - Example: - >>> Function.round("value") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the rounded value. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.round(expr_val) - - def sqrt(expr: Expr | str) -> "Sqrt": - """Creates an expression that calculates the square root of an expression. - - Example: - >>> Function.sqrt("area") - - Args: - expr: The expression or field path. - - Returns: - A new `Expr` representing the square root. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.sqrt(expr_val) - - def logical_maximum(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": - """Creates an expression that returns the larger value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> Function.logical_maximum("value", 10) - >>> Function.logical_maximum(Field.of("discount"), Field.of("cap")) - - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. - - Returns: - A new `Expr` representing the logical max operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_maximum(left_expr, right) - - def logical_minimum(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "Expr": - """Creates an expression that returns the smaller value between this expression - and another expression or constant, based on Firestore's value type ordering. - - Firestore's value type ordering is described here: - https://cloud.google.com/firestore/docs/concepts/data-types#value_type_ordering - - Example: - >>> Function.logical_minimum("value", 10) - >>> Function.logical_minimum(Field.of("discount"), Field.of("floor")) - - Args: - left: The expression or field path to compare. - right: The other expression or constant value to compare with. - - Returns: - A new `Expr` representing the logical minimum operation. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.logical_minimum(left_expr, right) - - def equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": - """Creates an expression that checks if this expression is equal to another - expression or constant value. - - Example: - >>> Function.equal("city", "London") - >>> Function.equal(Field.of("age"), 21) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for equality. - - Returns: - A new `Expr` representing the equality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.equal(left_expr, right) - - def not_equal(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": - """Creates an expression that checks if this expression is not equal to another - expression or constant value. - - Example: - >>> Function.not_equal("country", "USA") - >>> Function.not_equal(Field.of("status"), "completed") - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for inequality. - - Returns: - A new `Expr` representing the inequality comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.not_equal(left_expr, right) - - def greater_than(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": - """Creates an expression that checks if this expression is greater than another - expression or constant value. - - Example: - >>> Function.greater_than("price", 100) - >>> Function.greater_than(Field.of("age"), Field.of("limit")) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than. - - Returns: - A new `Expr` representing the greater than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.greater_than(left_expr, right) - - def greater_than_or_equal( - left: Expr | str, right: Expr | CONSTANT_TYPE - ) -> "BooleanExpr": - """Creates an expression that checks if this expression is greater than or equal - to another expression or constant value. - - Example: - >>> Function.greater_than_or_equal("score", 80) - >>> Function.greater_than_or_equal(Field.of("quantity"), Field.of('requirement').add(1)) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for greater than or equal to. - - Returns: - A new `Expr` representing the greater than or equal to comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.greater_than_or_equal(left_expr, right) - - def less_than(left: Expr | str, right: Expr | CONSTANT_TYPE) -> "BooleanExpr": - """Creates an expression that checks if this expression is less than another - expression or constant value. - - Example: - >>> Function.less_than("price", 50) - >>> Function.less_than(Field.of("age"), Field.of('limit')) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than. - - Returns: - A new `Expr` representing the less than comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.less_than(left_expr, right) - - def less_than_or_equal( - left: Expr | str, right: Expr | CONSTANT_TYPE - ) -> "BooleanExpr": - """Creates an expression that checks if this expression is less than or equal to - another expression or constant value. - - Example: - >>> Function.less_than_or_equal("score", 70) - >>> Function.less_than_or_equal(Field.of("quantity"), Constant.of(20)) - - Args: - left: The expression or field path to compare. - right: The expression or constant value to compare for less than or equal to. - - Returns: - A new `Expr` representing the less than or equal to comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.less_than_or_equal(left_expr, right) - - def equal_any( - left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE] - ) -> "BooleanExpr": - """Creates an expression that checks if this expression is equal to any of the - provided values or expressions. - - Example: - >>> Function.equal_any("category", ["Electronics", "Apparel"]) - >>> Function.equal_any(Field.of("category"), ["Electronics", Field.of("primaryType")]) - - Args: - left: The expression or field path to compare. - array: The values or expressions to check against. - - Returns: - A new `Expr` representing the 'IN' comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.equal_any(left_expr, array) - - def not_equal_any( - left: Expr | str, array: Sequence[Expr | CONSTANT_TYPE] - ) -> "BooleanExpr": - """Creates an expression that checks if this expression is not equal to any of the - provided values or expressions. - - Example: - >>> Function.not_equal_any("status", ["pending", "cancelled"]) - - Args: - left: The expression or field path to compare. - array: The values or expressions to check against. - - Returns: - A new `Expr` representing the 'NOT IN' comparison. - """ - left_expr = Field.of(left) if isinstance(left, str) else left - return Expr.not_equal_any(left_expr, array) - - def array_get(array: Expr | str, index: Expr | int) -> "Expr": - """Creates an expression that indexes into an array from the beginning or end - and returns the element. If the index exceeds the array length, an error is - returned. A negative index, starts from the end. - - Example: - >>> # Return the value in the tags field array at index specified by field 'favoriteTag'. - >>> Field.of("tags").array_get(Field.of("favoriteTag")) - - Args: - index: The index of the element to return. - - Returns: - A new `Expr` representing the operation. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_get(array_expr, index) - - def array_concat(array: Expr | str, other: Sequence[Expr | CONSTANT_TYPE]) -> "Expr": - """Creates an expression that concatenates an array expression with another array. - - Example: - >>> # Combine the 'tags' array with a new array and an array field - >>> Field.of("tags").array_concat(["newTag1", "newTag2", Field.of("otherTag")]) - - Args: - array: The list of constants or expressions to concat with. - - Returns: - A new `Expr` representing the concatenated array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_concat(array_expr, other) - - - def array_contains( - array: Expr | str, element: Expr | CONSTANT_TYPE - ) -> "BooleanExpr": - """Creates an expression that checks if an array contains a specific element or value. - - Example: - >>> Function.array_contains("colors", "red") - >>> Function.array_contains(Field.of("sizes"), Field.of("selectedSize")) - - Args: - array: The array expression or field path to check. - element: The element (expression or constant) to search for in the array. - - Returns: - A new `Expr` representing the 'array_contains' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains(array_expr, element) - - def array_contains_all( - array: Expr | str, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "BooleanExpr": - """Creates an expression that checks if an array contains all the specified elements. - - Example: - >>> Function.array_contains_all("tags", ["news", "sports"]) - >>> Function.array_contains_all(Field.of("tags"), [Field.of("tag1"), "tag2"]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_all' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_all(array_expr, elements) - - def array_contains_any( - array: Expr | str, elements: Sequence[Expr | CONSTANT_TYPE] - ) -> "BooleanExpr": - """Creates an expression that checks if an array contains any of the specified elements. - - Example: - >>> Function.array_contains_any("groups", ["admin", "editor"]) - >>> Function.array_contains_any(Field.of("categories"), [Field.of("cate1"), Field.of("cate2")]) - - Args: - array: The array expression or field path to check. - elements: The list of elements (expressions or constants) to check for in the array. - - Returns: - A new `Expr` representing the 'array_contains_any' comparison. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_contains_any(array_expr, elements) - - def array_length(array: Expr | str) -> "Expr": - """Creates an expression that calculates the length of an array. - - Example: - >>> Function.array_length("cart") - - Returns: - A new `Expr` representing the length of the array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_length(array_expr) - - def array_reverse(array: Expr | str) -> "Expr": - """Creates an expression that returns the reversed content of an array. - - Example: - >>> Function.array_reverse("preferences") - - Returns: - A new `Expr` representing the reversed array. - """ - array_expr = Field.of(array) if isinstance(array, str) else array - return Expr.array_reverse(array_expr) - - def is_nan(expr: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if this expression evaluates to 'NaN' (Not a Number). - - Example: - >>> Function.is_nan("measurement") - - Returns: - A new `Expr` representing the 'isNaN' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.is_nan(expr_val) - - def exists(expr: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if a field exists in the document. - - Example: - >>> Function.exists("phoneNumber") - - Returns: - A new `Expr` representing the 'exists' check. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.exists(expr_val) - - def sum(expr: Expr | str) -> "Expr": - """Creates an aggregation that calculates the sum of a numeric field across multiple stage inputs. - - Example: - >>> Function.sum("orderAmount") - - Returns: - A new `AggregateFunction` representing the 'sum' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.sum(expr_val) - - def average(expr: Expr | str) -> "Expr": - """Creates an aggregation that calculates the average (mean) of a numeric field across multiple - stage inputs. - - Example: - >>> Function.average("age") - - Returns: - A new `AggregateFunction` representing the 'average' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.average(expr_val) - - def count(expr: Expr | str | None = None) -> "Expr": - """Creates an aggregation that counts the number of stage inputs with valid evaluations of the - expression or field. If no expression is provided, it counts all inputs. - - Example: - >>> Function.count("productId") - >>> Function.count() - - Returns: - A new `AggregateFunction` representing the 'count' aggregation. - """ - if expr is None: - return Count() - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.count(expr_val) - - def minimum(expr: Expr | str) -> "Expr": - """Creates an aggregation that finds the minimum value of a field across multiple stage inputs. - - Example: - >>> Function.minimum("price") - - Returns: - A new `AggregateFunction` representing the 'minimum' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.minimum(expr_val) - - def maximum(expr: Expr | str) -> "Expr": - """Creates an aggregation that finds the maximum value of a field across multiple stage inputs. - - Example: - >>> Function.maximum("score") - - Returns: - A new `AggregateFunction` representing the 'maximum' aggregation. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.maximum(expr_val) - - def char_length(expr: Expr | str) -> "Expr": - """Creates an expression that calculates the character length of a string. - - Example: - >>> Function.char_length("name") - - Returns: - A new `Expr` representing the length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.char_length(expr_val) - - def byte_length(expr: Expr | str) -> "Expr": - """Creates an expression that calculates the byte length of a string in its UTF-8 form. - - Example: - >>> Function.byte_length("name") - - Returns: - A new `Expr` representing the byte length of the string. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.byte_length(expr_val) - - def like(expr: Expr | str, pattern: Expr | str) -> "BooleanExpr": - """Creates an expression that performs a case-sensitive string comparison. - - Example: - >>> Function.like("title", "%guide%") - >>> Function.like(Field.of("title"), Field.of("pattern")) - - Args: - expr: The expression or field path to perform the comparison on. - pattern: The pattern (string or expression) to search for. You can use "%" as a wildcard character. - - Returns: - A new `Expr` representing the 'like' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.like(expr_val, pattern) - - def regex_contains(expr: Expr | str, regex: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if a string contains a specified regular expression as a - substring. - - Example: - >>> Function.regex_contains("description", "(?i)example") - >>> Function.regex_contains(Field.of("description"), Field.of("regex")) - - Args: - expr: The expression or field path to perform the comparison on. - regex: The regular expression (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_contains(expr_val, regex) - - def regex_matches(expr: Expr | str, regex: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if a string matches a specified regular expression. - - Example: - >>> # Check if the 'email' field matches a valid email pattern - >>> Function.regex_matches("email", "[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}") - >>> Function.regex_matches(Field.of("email"), Field.of("regex")) - - Args: - expr: The expression or field path to match against. - regex: The regular expression (string or expression) to use for the match. - - Returns: - A new `Expr` representing the regular expression match. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.regex_matches(expr_val, regex) - - def string_contains(expr: Expr | str, substring: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if this string expression contains a specified substring. - - Example: - >>> Function.string_contains("description", "example") - >>> Function.string_contains(Field.of("description"), Field.of("keyword")) - - Args: - expr: The expression or field path to perform the comparison on. - substring: The substring (string or expression) to use for the search. - - Returns: - A new `Expr` representing the 'contains' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.string_contains(expr_val, substring) - - def starts_with(expr: Expr | str, prefix: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if a string starts with a given prefix. - - Example: - >>> Function.starts_with("name", "Mr.") - >>> Function.starts_with(Field.of("fullName"), Field.of("firstName")) - - Args: - expr: The expression or field path to check. - prefix: The prefix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'starts with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.starts_with(expr_val, prefix) - - def ends_with(expr: Expr | str, postfix: Expr | str) -> "BooleanExpr": - """Creates an expression that checks if a string ends with a given postfix. - - Example: - >>> Function.ends_with("filename", ".txt") - >>> Function.ends_with(Field.of("url"), Field.of("extension")) - - Args: - expr: The expression or field path to check. - postfix: The postfix (string or expression) to check for. - - Returns: - A new `Expr` representing the 'ends with' comparison. - """ - expr_val = Field.of(expr) if isinstance(expr, str) else expr - return Expr.ends_with(expr_val, postfix) - - def string_concat(first: Expr | str, *elements: Expr | CONSTANT_TYPE) -> "Expr": - """Creates an expression that concatenates string expressions, fields or constants together. - - Example: - >>> Function.string_concat("firstName", " ", Field.of("lastName")) - - Args: - first: The first expression or field path to concatenate. - *elements: The expressions or constants (typically strings) to concatenate. - - Returns: - A new `Expr` representing the concatenated string. - """ - first_expr = Field.of(first) if isinstance(first, str) else first - return Expr.string_concat(first_expr, *elements) - - def map_get(map_expr: Expr | str, key: str) -> "Expr": - """Accesses a value from a map (object) field using the provided key. - - Example: - >>> Function.map_get("address", "city") - - Args: - map_expr: The expression or field path of the map. - key: The key to access in the map. - - Returns: - A new `Expr` representing the value associated with the given key in the map. - """ - map_val = Field.of(map_expr) if isinstance(map_expr, str) else map_expr - return Expr.map_get(map_val, key) - - def vector_length(vector_expr: Expr | str) -> "Expr": - """Creates an expression that calculates the length (dimension) of a Firestore Vector. - - Example: - >>> Function.vector_length("embedding") - - Returns: - A new `Expr` representing the length of the vector. - """ - vector_val = ( - Field.of(vector_expr) if isinstance(vector_expr, str) else vector_expr - ) - return Expr.vector_length(vector_val) - - def timestamp_to_unix_micros(timestamp_expr: Expr | str) -> "Expr": - """Creates an expression that converts a timestamp to the number of microseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the microsecond. - - Example: - >>> Function.timestamp_to_unix_micros("timestamp") - - Returns: - A new `Expr` representing the number of microseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_micros(timestamp_val) - - def unix_micros_to_timestamp(micros_expr: Expr | str) -> "Expr": - """Creates an expression that converts a number of microseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_micros_to_timestamp("microseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - micros_val = ( - Field.of(micros_expr) if isinstance(micros_expr, str) else micros_expr - ) - return Expr.unix_micros_to_timestamp(micros_val) - - def timestamp_to_unix_millis(timestamp_expr: Expr | str) -> "Expr": - """Creates an expression that converts a timestamp to the number of milliseconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the millisecond. - - Example: - >>> Function.timestamp_to_unix_millis("timestamp") - - Returns: - A new `Expr` representing the number of milliseconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_millis(timestamp_val) - - def unix_millis_to_timestamp(millis_expr: Expr | str) -> "Expr": - """Creates an expression that converts a number of milliseconds since the epoch (1970-01-01 - 00:00:00 UTC) to a timestamp. - - Example: - >>> Function.unix_millis_to_timestamp("milliseconds") - - Returns: - A new `Expr` representing the timestamp. - """ - millis_val = ( - Field.of(millis_expr) if isinstance(millis_expr, str) else millis_expr - ) - return Expr.unix_millis_to_timestamp(millis_val) - - def timestamp_to_unix_seconds( - timestamp_expr: Expr | str, - ) -> "Expr": - """Creates an expression that converts a timestamp to the number of seconds since the epoch - (1970-01-01 00:00:00 UTC). - - Truncates higher levels of precision by rounding down to the beginning of the second. - - Example: - >>> Function.timestamp_to_unix_seconds("timestamp") - - Returns: - A new `Expr` representing the number of seconds since the epoch. - """ - timestamp_val = ( - Field.of(timestamp_expr) - if isinstance(timestamp_expr, str) - else timestamp_expr - ) - return Expr.timestamp_to_unix_seconds(timestamp_val) - - def unix_seconds_to_timestamp(seconds_expr: Expr | str) -> "Expr": - """Creates an expression that converts a number of seconds since the epoch (1970-01-01 00:00:00 - UTC) to a timestamp. - - Example: - >>> Function.unix_seconds_to_timestamp("seconds") - - Returns: - A new `Expr` representing the timestamp. - """ - seconds_val = ( - Field.of(seconds_expr) if isinstance(seconds_expr, str) else seconds_expr - ) - return Expr.unix_seconds_to_timestamp(seconds_val) - - def timestamp_add( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "Expr": - """Creates an expression that adds a specified amount of time to this timestamp expression. - - Example: - >>> Function.timestamp_add("timestamp", "day", 1.5) - >>> Function.timestamp_add(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to add, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to add. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_add(timestamp_expr, unit, amount) - - def timestamp_subtract( - timestamp: Expr | str, unit: Expr | str, amount: Expr | float - ) -> "Expr": - """Creates an expression that subtracts a specified amount of time from this timestamp expression. - - Example: - >>> Function.timestamp_sub("timestamp", "hour", 2.5) - >>> Function.timestamp_sub(Field.of("timestamp"), Field.of("unit"), Field.of("amount")) - - Args: - timestamp: The expression or field path of the timestamp. - unit: The expression or string evaluating to the unit of time to subtract, must be one of - 'microsecond', 'millisecond', 'second', 'minute', 'hour', 'day'. - amount: The expression or float representing the amount of time to subtract. - - Returns: - A new `Expr` representing the resulting timestamp. - """ - timestamp_expr = ( - Field.of(timestamp) if isinstance(timestamp, str) else timestamp - ) - return Expr.timestamp_subtract(timestamp_expr, unit, amount) - class Divide(Function): """Represents the division function.""" diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index f765491e9..4feee4ba4 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -999,7 +999,7 @@ class TestFunctionClasses: ("greater_than_or_equal", ("field", 2), expr.GreaterThanOrEqual), ("equal_any", ("field", [None]), expr.EqualAny), ("not_equal_any", ("field", [None]), expr.NotEqualAny), - ("array", ("field", [1, 2, 3]), expr.Array), + ("array", ([1, 2, 3],), expr.Array), ("array_get", ("field", 2), expr.ArrayGet), ("array_contains", ("field", None), expr.ArrayContains), ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), @@ -1011,7 +1011,6 @@ class TestFunctionClasses: ("sum", ("field",), expr.Sum), ("average", ("field",), expr.Average), ("count", ("field",), expr.Count), - ("count", (), expr.Count), ("minimum", ("field",), expr.Minimum), ("maximum", ("field",), expr.Maximum), ("char_length", ("field",), expr.CharLength), From 32c4e4b5341bf1d15ad3c2321dd41c2c9be53b5f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 20:56:45 -0700 Subject: [PATCH 15/16] added map expressions --- .../firestore_v1/pipeline_expressions.py | 121 +++++++++++++++--- tests/unit/v1/test_pipeline_expressions.py | 30 +++++ 2 files changed, 132 insertions(+), 19 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index c9c03529a..6c4c12da0 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -141,6 +141,36 @@ def __get__(self, instance, owner): else: return self.instance_func.__get__(instance, owner) + @staticmethod + def array(elements: list[Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that creates a Firestore array value from an input list. + + Example: + >>> Expr.array(["bar", Field.of("baz")]) + + Args: + elements: THe input list to evaluate in the expression + + Returns: + A new `Expr` representing the array function. + """ + return Array([Expr._cast_to_expr_or_convert_to_constant(e) for e in elements]) + + @staticmethod + def map(elements: dict[str, Expr | CONSTANT_TYPE]) -> "Expr": + """Creates an expression that creates a Firestore map value from an input dict. + + Example: + >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) + + Args: + elements: THe input dict to evaluate in the expression + + Returns: + A new `Expr` representing the map function. + """ + return Map({Constant.of(k): Expr._cast_to_expr_or_convert_to_constant(v) for k, v in elements.items()}) + @expose_as_static def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. @@ -552,21 +582,6 @@ def not_equal_any(self, array: Sequence[Expr | CONSTANT_TYPE]) -> "BooleanExpr": """ return NotEqualAny(self, [self._cast_to_expr_or_convert_to_constant(v) for v in array]) - @staticmethod - def array(elements: list[Expr | CONSTANT_TYPE]) -> "Expr": - """Creates an expression that creates a Firestore array value from an input list. - - Example: - >>> Expr.array(["bar", Field.of("baz")]) - - Args: - elements: THe input list to evaluate in the expression - - Returns: - A new `Expr` representing the array function. - """ - return Array([Expr._cast_to_expr_or_convert_to_constant(e) for e in elements]) - @expose_as_static def array_get(self, index: Expr | int) -> "Expr": """Creates an expression that indexes into an array from the beginning or end @@ -994,11 +1009,10 @@ def reverse(self) -> "Expr": @expose_as_static def map_get(self, key: str) -> "Expr": - """Accesses a value from a map (object) field using the provided key. + """Accesses a value from the map produced by evaluating this expression. Example: - >>> # Get the 'city' value from - >>> # the 'address' map field + >>> Expr.map({"city": "London"}).map_get("city") >>> Field.of("address").map_get("city") Args: @@ -1009,6 +1023,42 @@ def map_get(self, key: str) -> "Expr": """ return MapGet(self, Constant.of(key)) + @expose_as_static + def map_remove(self, key: str) -> "Expr": + """Remove a key from a the map produced by evaluating this expression. + + Example: + >>> Expr.map({"city": "London"}).map_remove("city") + >>> Field.of("address").map_remove("city") + + Args: + key: The key to ewmove in the map. + + Returns: + A new `Expr` representing the map_remove operation. + """ + return MapRemove(self, Constant.of(key)) + + @expose_as_static + def map_merge(self, *other_maps: Expr | dict[str, Expr | CONSTANT_TYPE])-> "Expr": + """Creates an expression that merges one or more dicts into a single map. + + Example: + >>> Field.of("settings").map_merge({"enabled":True}, Function.cond(Field.of('isAdmin'), {"admin":True}, {}}) + >>> Expr.map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) + + Args: + *other_maps: Sequence of maps to merge into the resulting map. + + Returns: + A new `Expr` representing the value associated with the given key in the map. + """ + map_list = [] + for map in other_maps: + map_list.append(map if isinstance(map, Expr) else Expr.map(map)) + return MapMerge(self, *map_list) + + @expose_as_static def cosine_distance(self, other: Expr | list[float] | Vector) -> "Expr": """Calculates the cosine distance between two vectors. @@ -1280,6 +1330,9 @@ def of(value: CONSTANT_TYPE) -> Constant[CONSTANT_TYPE]: def __repr__(self): return f"Constant.of({self.value!r})" + def __hash__(self): + return hash(self.value) + def _to_pb(self) -> Value: return encode_value(self.value) @@ -1369,13 +1422,43 @@ def __init__(self, left: Expr, right: Expr): super().__init__("min", [left, right]) +class Map(Function): + """Creates an expression that creates a Firestore map value from an input dict.""" + + def __init__(self, elements: dict[Constant[str], Expr]): + element_list = [] + for k,v in elements.items(): + element_list.append(k) + element_list.append(v) + super().__init__("map", element_list) + + def __repr__(self): + d = {a:b for a, b in zip(self.params[::2], self.params[1::2])} + return f"Map({d})" + + class MapGet(Function): - """Represents accessing a value within a map by key.""" + """Creates an expression that accesses a map value by key.""" def __init__(self, map_: Expr, key: Constant[str]): super().__init__("map_get", [map_, key]) +class MapMerge(Function): + """Creates an expression that merges multiple map values.""" + + def __init__(self, *maps: Expr): + super().__init__("map_merge", [*maps]) + + +class MapRemove(Function): + """Creates an expression that removes a key from a map.""" + + def __init__(self, map_: Expr, key: Constant[str]): + super().__init__("map_remove", [map_, key]) + + + class Mod(Function): """Represents the modulo function.""" diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 4feee4ba4..4074c7e26 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -137,6 +137,8 @@ def test_ctor(self): ("trim", (), expr.Trim), ("reverse", (), expr.Reverse), ("map_get", ("key",), expr.MapGet), + ("map_remove", ("key",), expr.MapRemove), + ("map_merge", ({"key": "value"}, ), expr.MapMerge), ("cosine_distance", [1], expr.CosineDistance), ("euclidean_distance", [1], expr.EuclideanDistance), ("dot_product", [1], expr.DotProduct), @@ -1000,6 +1002,7 @@ class TestFunctionClasses: ("equal_any", ("field", [None]), expr.EqualAny), ("not_equal_any", ("field", [None]), expr.NotEqualAny), ("array", ([1, 2, 3],), expr.Array), + ("map", ({"hello": "world"},), expr.Map), ("array_get", ("field", 2), expr.ArrayGet), ("array_contains", ("field", None), expr.ArrayContains), ("array_contains_all", ("field", [None]), expr.ArrayContainsAll), @@ -1023,6 +1026,8 @@ class TestFunctionClasses: ("ends_with", ("field", "postfix"), expr.EndsWith), ("string_concat", ("field", "elem1", "elem2"), expr.StringConcat), ("map_get", ("field", "key"), expr.MapGet), + ("map_remove", ("field", "key"), expr.MapRemove), + ("map_merge", ("field", {"key": "value"}), expr.MapMerge), ("vector_length", ("field",), expr.VectorLength), ("timestamp_to_unix_micros", ("field",), expr.TimestampToUnixMicros), ("unix_micros_to_timestamp", ("field",), expr.UnixMicrosToTimestamp), @@ -1092,6 +1097,14 @@ def test_logical_minimum(self): assert instance.params == [arg1, arg2] assert repr(instance) == "LogicalMinimum(Left, Right)" + def test_map(self): + key = expr.Constant.of("key") + value = self._make_arg("value") + instance = expr.Map({"key": value}) + assert instance.name == "map" + assert instance.params == [key, value] + assert repr(instance) == "Map({'key': value})" + def test_map_get(self): arg1 = self._make_arg("Map") arg2 = expr.Constant("Key") @@ -1100,6 +1113,23 @@ def test_map_get(self): assert instance.params == [arg1, arg2] assert repr(instance) == "MapGet(Map, Constant.of('Key'))" + def test_map_remove(self): + arg1 = self._make_arg("Map") + arg2 = expr.Constant("Key") + instance = expr.MapRemove(arg1, arg2) + assert instance.name == "map_remove" + assert instance.params == [arg1, arg2] + assert repr(instance) == "MapRemove(Map, Constant.of('Key'))" + + def test_map_merge(self): + arg1 = self._make_arg("Map1") + arg2 = self._make_arg("Map2") + arg3 = self._make_arg("Map3") + instance = expr.MapMerge(arg1, arg2, arg3) + assert instance.name == "map_merge" + assert instance.params == [arg1, arg2, arg3] + assert repr(instance) == "MapMerge(Map1, Map2, Map3)" + def test_mod(self): arg1 = self._make_arg("Left") arg2 = self._make_arg("Right") From d5e854c992250003ddd54ca10e14507dbc3f5cca Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Fri, 17 Oct 2025 21:05:00 -0700 Subject: [PATCH 16/16] renamed if to conditional --- .../firestore_v1/pipeline_expressions.py | 44 ++++++++++++++----- tests/unit/v1/test_pipeline_expressions.py | 13 +++--- 2 files changed, 39 insertions(+), 18 deletions(-) diff --git a/google/cloud/firestore_v1/pipeline_expressions.py b/google/cloud/firestore_v1/pipeline_expressions.py index 6c4c12da0..30953357c 100644 --- a/google/cloud/firestore_v1/pipeline_expressions.py +++ b/google/cloud/firestore_v1/pipeline_expressions.py @@ -148,11 +148,11 @@ def array(elements: list[Expr | CONSTANT_TYPE]) -> "Expr": Example: >>> Expr.array(["bar", Field.of("baz")]) - Args: - elements: THe input list to evaluate in the expression + Args: + elements: THe input list to evaluate in the expression - Returns: - A new `Expr` representing the array function. + Returns: + A new `Expr` representing the array function. """ return Array([Expr._cast_to_expr_or_convert_to_constant(e) for e in elements]) @@ -163,14 +163,34 @@ def map(elements: dict[str, Expr | CONSTANT_TYPE]) -> "Expr": Example: >>> Expr.map({"foo": "bar", "baz": Field.of("baz")}) - Args: - elements: THe input dict to evaluate in the expression + Args: + elements: THe input dict to evaluate in the expression - Returns: - A new `Expr` representing the map function. + Returns: + A new `Expr` representing the map function. """ return Map({Constant.of(k): Expr._cast_to_expr_or_convert_to_constant(v) for k, v in elements.items()}) + @staticmethod + def conditional(conditional: BooleanExpr, then_expr: Expr, else_expr: Expr) -> "Expr": + """ + Creates a conditional expression that evaluates to a 'then' expression if a condition is true + and an 'else' expression if the condition is false. + + Example: + >>> # If 'age' is greater than 18, return "Adult"; otherwise, return "Minor". + >>> Expr.conditional(Field.of("age").greater_than(18), Constant.of("Adult"), Constant.of("Minor")); + + Args: + conditional: The condition to evaluate. + then_expr: The expression to return if the condition is true. + else_expr: The expression to return if the condition is false + + Returns: + A new `Expr` representing the conditional expression. + """ + return Conditional(conditional, then_expr, else_expr) + @expose_as_static def add(self, other: Expr | float) -> "Expr": """Creates an expression that adds this expression to another expression or constant. @@ -1044,7 +1064,7 @@ def map_merge(self, *other_maps: Expr | dict[str, Expr | CONSTANT_TYPE])-> "Expr """Creates an expression that merges one or more dicts into a single map. Example: - >>> Field.of("settings").map_merge({"enabled":True}, Function.cond(Field.of('isAdmin'), {"admin":True}, {}}) + >>> Field.of("settings").map_merge({"enabled":True}, Function.conditional(Field.of('isAdmin'), {"admin":True}, {}}) >>> Expr.map({"city": "London"}).map_merge({"country": "UK"}, {"isCapital": True}) Args: @@ -2034,11 +2054,11 @@ def __init__(self, left: Expr, right: Expr): super().__init__("greater_than_or_equal", [left, right]) -class If(BooleanExpr): +class Conditional(BooleanExpr): """Represents a conditional expression (if-then-else).""" - def __init__(self, condition: "BooleanExpr", true_expr: Expr, false_expr: Expr): - super().__init__("if", [condition, true_expr, false_expr]) + def __init__(self, condition: "BooleanExpr", then_expr: Expr, else_expr: Expr): + super().__init__("conditional", [condition, then_expr, else_expr]) class EqualAny(BooleanExpr): diff --git a/tests/unit/v1/test_pipeline_expressions.py b/tests/unit/v1/test_pipeline_expressions.py index 4074c7e26..95bee4941 100644 --- a/tests/unit/v1/test_pipeline_expressions.py +++ b/tests/unit/v1/test_pipeline_expressions.py @@ -911,14 +911,14 @@ def test_ends_with(self): assert instance.params == [arg1, arg2] assert repr(instance) == "Expr.ends_with(Postfix)" - def test_if(self): + def test_conditional(self): arg1 = self._make_arg("Condition") - arg2 = self._make_arg("TrueExpr") - arg3 = self._make_arg("FalseExpr") - instance = expr.If(arg1, arg2, arg3) - assert instance.name == "if" + arg2 = self._make_arg("ThenExpr") + arg3 = self._make_arg("ElseExpr") + instance = expr.Conditional(arg1, arg2, arg3) + assert instance.name == "conditional" assert instance.params == [arg1, arg2, arg3] - assert repr(instance) == "If(Condition, TrueExpr, FalseExpr)" + assert repr(instance) == "Conditional(Condition, ThenExpr, ElseExpr)" def test_like(self): arg1 = self._make_arg("Expr") @@ -977,6 +977,7 @@ class TestFunctionClasses: @pytest.mark.parametrize( "method,args,result_cls", [ + ("conditional", ("field", "then", "else"), expr.Conditional), ("add", ("field", 2), expr.Add), ("subtract", ("field", 2), expr.Subtract), ("multiply", ("field", 2), expr.Multiply),