|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | import base64 |
| 16 | +import copy |
| 17 | +import datetime |
16 | 18 | import gzip |
17 | | -import pickle |
| 19 | +import json |
18 | 20 | from dataclasses import dataclass |
19 | 21 | from typing import Any |
20 | 22 |
|
| 23 | +from google.protobuf.json_format import MessageToDict, ParseDict |
| 24 | +from google.protobuf.message import Message |
| 25 | +from google.protobuf.struct_pb2 import Struct |
| 26 | + |
21 | 27 | from google.cloud.spanner_v1 import BatchTransactionId |
| 28 | +from google.cloud.spanner_v1._helpers import _make_value_pb |
| 29 | +from google.cloud.spanner_v1.types import DirectedReadOptions, ExecuteSqlRequest, Type |
| 30 | + |
| 31 | +_PROTO_CLASS_MAP = { |
| 32 | + "QueryOptions": ExecuteSqlRequest.QueryOptions, |
| 33 | + "DirectedReadOptions": DirectedReadOptions, |
| 34 | + "Struct": Struct, |
| 35 | + "Type": Type, |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | +def _serialize_value(val: Any) -> Any: |
| 40 | + if isinstance(val, bytes): |
| 41 | + return {"__type__": "bytes", "value": base64.b64encode(val).decode("utf-8")} |
| 42 | + elif isinstance(val, datetime.datetime): |
| 43 | + return {"__type__": "datetime", "value": val.isoformat()} |
| 44 | + elif hasattr(val, "_pb"): |
| 45 | + return { |
| 46 | + "__type__": "protobuf", |
| 47 | + "class": val.__class__.__name__, |
| 48 | + "value": MessageToDict(val._pb, preserving_proto_field_name=True), |
| 49 | + } |
| 50 | + elif isinstance(val, Message): |
| 51 | + return { |
| 52 | + "__type__": "protobuf", |
| 53 | + "class": val.__class__.__name__, |
| 54 | + "value": MessageToDict(val, preserving_proto_field_name=True), |
| 55 | + } |
| 56 | + elif isinstance(val, dict): |
| 57 | + return {k: _serialize_value(v) for k, v in val.items()} |
| 58 | + elif isinstance(val, list): |
| 59 | + return [_serialize_value(v) for v in val] |
| 60 | + elif isinstance(val, tuple): |
| 61 | + return {"__type__": "tuple", "value": [_serialize_value(v) for v in val]} |
| 62 | + return val |
| 63 | + |
| 64 | + |
| 65 | +def _deserialize_value(val: Any) -> Any: |
| 66 | + if isinstance(val, dict): |
| 67 | + if "__type__" in val: |
| 68 | + t = val["__type__"] |
| 69 | + if t == "bytes": |
| 70 | + return base64.b64decode(val["value"]) |
| 71 | + elif t == "datetime": |
| 72 | + dt_str = val["value"] |
| 73 | + if dt_str.endswith("Z"): |
| 74 | + dt_str = dt_str[:-1] + "+00:00" |
| 75 | + return datetime.datetime.fromisoformat(dt_str) |
| 76 | + elif t == "tuple": |
| 77 | + return tuple(_deserialize_value(x) for x in val["value"]) |
| 78 | + elif t == "protobuf": |
| 79 | + cls_name = val.get("class") |
| 80 | + dict_val = val["value"] |
| 81 | + if cls_name in _PROTO_CLASS_MAP: |
| 82 | + cls = _PROTO_CLASS_MAP[cls_name] |
| 83 | + msg = cls()._pb if hasattr(cls(), "_pb") else cls() |
| 84 | + ParseDict(dict_val, msg) |
| 85 | + return cls(msg) if hasattr(cls(), "_pb") else msg |
| 86 | + return _deserialize_value(dict_val) |
| 87 | + return {k: _deserialize_value(v) for k, v in val.items()} |
| 88 | + elif isinstance(val, list): |
| 89 | + return [_deserialize_value(v) for v in val] |
| 90 | + return val |
| 91 | + |
| 92 | + |
| 93 | +def _unpack_value_pb(value): |
| 94 | + which = value.WhichOneof("kind") |
| 95 | + if which == "null_value": |
| 96 | + return None |
| 97 | + elif which == "number_value": |
| 98 | + return value.number_value |
| 99 | + elif which == "string_value": |
| 100 | + return value.string_value |
| 101 | + elif which == "bool_value": |
| 102 | + return value.bool_value |
| 103 | + elif which == "struct_value": |
| 104 | + return {k: _unpack_value_pb(v) for k, v in value.struct_value.fields.items()} |
| 105 | + elif which == "list_value": |
| 106 | + return [_unpack_value_pb(v) for v in value.list_value.values] |
| 107 | + return None |
22 | 108 |
|
23 | 109 |
|
24 | 110 | def decode_from_string(encoded_partition_id): |
25 | 111 | gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8")) |
26 | 112 | partition_id_bytes = gzip.decompress(gzip_bytes) |
27 | | - return pickle.loads(partition_id_bytes) |
| 113 | + |
| 114 | + data = json.loads(partition_id_bytes.decode("utf-8")) |
| 115 | + btid_data = data["batch_transaction_id"] |
| 116 | + btid = BatchTransactionId( |
| 117 | + transaction_id=_deserialize_value(btid_data["transaction_id"]), |
| 118 | + session_id=btid_data["session_id"], |
| 119 | + read_timestamp=_deserialize_value(btid_data["read_timestamp"]), |
| 120 | + ) |
| 121 | + partition_result = _deserialize_value(data["partition_result"]) |
| 122 | + |
| 123 | + # Post-process query params back from Protobuf Struct to Python primitives |
| 124 | + if "query" in partition_result and "params" in partition_result["query"]: |
| 125 | + params_pb = partition_result["query"]["params"] |
| 126 | + if params_pb: |
| 127 | + partition_result["query"]["params"] = { |
| 128 | + k: _unpack_value_pb(v) for k, v in params_pb.fields.items() |
| 129 | + } |
| 130 | + |
| 131 | + return PartitionId(btid, partition_result) |
28 | 132 |
|
29 | 133 |
|
30 | 134 | def encode_to_string(batch_transaction_id, partition_result): |
31 | | - partition_id = PartitionId(batch_transaction_id, partition_result) |
32 | | - partition_id_bytes = pickle.dumps(partition_id) |
| 135 | + # Copy to avoid modifying the caller's dictionary in connection.py |
| 136 | + partition_result = copy.deepcopy(partition_result) |
| 137 | + |
| 138 | + # Pre-process query params into a Protobuf Struct |
| 139 | + if "query" in partition_result and "params" in partition_result["query"]: |
| 140 | + params = partition_result["query"]["params"] |
| 141 | + if params: |
| 142 | + params_pb = Struct(fields={k: _make_value_pb(v) for k, v in params.items()}) |
| 143 | + partition_result["query"]["params"] = params_pb |
| 144 | + |
| 145 | + data = { |
| 146 | + "batch_transaction_id": { |
| 147 | + "transaction_id": _serialize_value(batch_transaction_id.transaction_id), |
| 148 | + "session_id": batch_transaction_id.session_id, |
| 149 | + "read_timestamp": _serialize_value(batch_transaction_id.read_timestamp), |
| 150 | + }, |
| 151 | + "partition_result": _serialize_value(partition_result), |
| 152 | + } |
| 153 | + |
| 154 | + partition_id_bytes = json.dumps(data).encode("utf-8") |
33 | 155 | gzip_bytes = gzip.compress(partition_id_bytes) |
34 | 156 | return str(base64.b64encode(gzip_bytes), "utf-8") |
35 | 157 |
|
|
0 commit comments