feat: Implementation for batch dml in dbapi by ankiaga · Pull Request #1055 · googleapis/python-spanner · GitHub
Skip to content
This repository was archived by the owner on Jun 8, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions google/cloud/spanner_dbapi/batch_dml_executor.py
16 changes: 12 additions & 4 deletions google/cloud/spanner_dbapi/client_side_statement_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.cursor import Cursor
from google.cloud.spanner_dbapi import ProgrammingError

from google.cloud.spanner_dbapi.parsed_statement import (
Expand All @@ -38,17 +38,18 @@
)


def execute(connection: "Connection", parsed_statement: ParsedStatement):
def execute(cursor: "Cursor", parsed_statement: ParsedStatement):
"""Executes the client side statements by calling the relevant method.

It is an internal method that can make backwards-incompatible changes.

:type connection: Connection
:param connection: Connection object of the dbApi
:type cursor: Cursor
:param cursor: Cursor object of the dbApi

:type parsed_statement: ParsedStatement
:param parsed_statement: parsed_statement based on the sql query
"""
connection = cursor.connection
if connection.is_closed:
raise ProgrammingError(CONNECTION_CLOSED_ERROR)
statement_type = parsed_statement.client_side_statement_type
Expand Down Expand Up @@ -81,6 +82,13 @@ def execute(connection: "Connection", parsed_statement: ParsedStatement):
TypeCode.TIMESTAMP,
read_timestamp,
)
if statement_type == ClientSideStatementType.START_BATCH_DML:
connection.start_batch_dml(cursor)
return None
if statement_type == ClientSideStatementType.RUN_BATCH:
return connection.run_batch()
if statement_type == ClientSideStatementType.ABORT_BATCH:
return connection.abort_batch()


def _get_streamed_result_set(column_name, type_code, column_value):
Expand Down
12 changes: 11 additions & 1 deletion google/cloud/spanner_dbapi/client_side_statement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ParsedStatement,
StatementType,
ClientSideStatementType,
Statement,
)

RE_BEGIN = re.compile(r"^\s*(BEGIN|START)(TRANSACTION)?", re.IGNORECASE)
Expand All @@ -29,6 +30,9 @@
RE_SHOW_READ_TIMESTAMP = re.compile(
r"^\s*(SHOW)\s+(VARIABLE)\s+(READ_TIMESTAMP)", re.IGNORECASE
)
RE_START_BATCH_DML = re.compile(r"^\s*(START)\s+(BATCH)\s+(DML)", re.IGNORECASE)
RE_RUN_BATCH = re.compile(r"^\s*(RUN)\s+(BATCH)", re.IGNORECASE)
Comment thread
olavloite marked this conversation as resolved.
RE_ABORT_BATCH = re.compile(r"^\s*(ABORT)\s+(BATCH)", re.IGNORECASE)


def parse_stmt(query):
Expand All @@ -54,8 +58,14 @@ def parse_stmt(query):
client_side_statement_type = ClientSideStatementType.SHOW_COMMIT_TIMESTAMP
if RE_SHOW_READ_TIMESTAMP.match(query):
client_side_statement_type = ClientSideStatementType.SHOW_READ_TIMESTAMP
if RE_START_BATCH_DML.match(query):
client_side_statement_type = ClientSideStatementType.START_BATCH_DML
if RE_RUN_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.RUN_BATCH
if RE_ABORT_BATCH.match(query):
client_side_statement_type = ClientSideStatementType.ABORT_BATCH
if client_side_statement_type is not None:
return ParsedStatement(
StatementType.CLIENT_SIDE, query, client_side_statement_type
StatementType.CLIENT_SIDE, Statement(query), client_side_statement_type
)
return None
61 changes: 56 additions & 5 deletions google/cloud/spanner_dbapi/connection.py
Loading