@@ -59,34 +59,13 @@ def connect(dbapi_connection, connection_record):
5959 # Log statements to standard error
6060 logging .basicConfig (level = logging .DEBUG )
6161
62- def parse (self , e ):
63- """Parses an exception, returns its message."""
64-
65- # MySQL
66- matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
67- if matches :
68- return matches .group (1 )
69-
70- # PostgreSQL
71- matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
72- if matches :
73- return matches .group (1 )
74-
75- # SQLite
76- matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
77- if matches :
78- return matches .group (1 )
79-
80- # Default
81- return str (e )
82-
8362 # Test database
8463 try :
8564 disabled = self ._logger .disabled
8665 self ._logger .disabled = True
8766 self .execute ("SELECT 1" )
8867 except sqlalchemy .exc .OperationalError as e :
89- e = RuntimeError (parse (e ))
68+ e = RuntimeError (self . _parse_exception (e ))
9069 e .__cause__ = None
9170 raise e
9271 else :
@@ -126,19 +105,8 @@ def execute(self, sql, *args, **kwargs):
126105 # If token is a placeholder
127106 if token .ttype == sqlparse .tokens .Name .Placeholder :
128107
129- # Determine paramstyle
130- if token .value == "?" :
131- _paramstyle = "qmark"
132- elif re .search (r"^:[1-9]\d*$" , token .value ):
133- _paramstyle = "numeric"
134- elif re .search (r"^:[a-zA-Z]\w*$" , token .value ):
135- _paramstyle = "named"
136- elif re .search (r"^TODO$" , token .value ): # TODO
137- _paramstyle = "named"
138- elif re .search (r"%\([a-zA-Z]\w*\)s$" , token .value ): # TODO
139- _paramstyle = "pyformat"
140- else :
141- raise RuntimeError ("{}: invalid placeholder" .format (token .value ))
108+ # Determine paramstyle, name
109+ _paramstyle , name = self ._parse_placeholder (token )
142110
143111 # Ensure paramstyle is consistent
144112 if paramstyle is not None and _paramstyle != paramstyle :
@@ -148,10 +116,15 @@ def execute(self, sql, *args, **kwargs):
148116 if paramstyle is None :
149117 paramstyle = _paramstyle
150118
151- # Remember placeholder
152- placeholders [index ] = token . value
119+ # Remember placeholder's index, name
120+ placeholders [index ] = name
153121
154122 def escape (value ):
123+ """
124+ Escapes value using engine's conversion function.
125+
126+ https://docs.sqlalchemy.org/en/latest/core/type_api.html#sqlalchemy.types.TypeEngine.literal_processor
127+ """
155128
156129 # bool
157130 if type (value ) is bool :
@@ -221,18 +194,39 @@ def escape(value):
221194 elif paramstyle == "numeric" :
222195
223196 # Escape values
224- for index , value in placeholders .items ():
225- i = int (re . sub ( r"^:" , "" , value ) ) - 1
226- if i >= len (args ):
197+ for index , name in placeholders .items ():
198+ i = int (name ) - 1
199+ if i < 0 or i >= len (args ):
227200 raise RuntimeError ("placeholder out of range" )
228201 tokens [index ] = escape (args [i ])
229202
230203 # named
231204 elif paramstyle == "named" :
232205
233206 # Escape values
234- for index , value in placeholders .items ():
235- name = re .sub (r"^:" , "" , value )
207+ for index , name in placeholders .items ():
208+ if name not in kwargs :
209+ raise RuntimeError ("missing value for placeholder" )
210+ tokens [index ] = escape (kwargs [name ])
211+
212+ # format
213+ elif paramstyle == "format" :
214+
215+ # Validate number of placeholders
216+ if len (placeholders ) < len (args ):
217+ raise RuntimeError ("too few placeholders" )
218+ elif len (placeholders ) > len (args ):
219+ raise RuntimeError ("too many placeholders" )
220+
221+ # Escape values
222+ for i , index in enumerate (placeholders .keys ()):
223+ tokens [index ] = escape (args [i ])
224+
225+ # pyformat
226+ elif paramstyle == "pyformat" :
227+
228+ # Escape values
229+ for index , name in placeholders .items ():
236230 if name not in kwargs :
237231 raise RuntimeError ("missing value for placeholder" )
238232 tokens [index ] = escape (kwargs [name ])
@@ -285,11 +279,65 @@ def escape(value):
285279 # If user errror
286280 except sqlalchemy .exc .OperationalError as e :
287281 self ._logger .debug (termcolor .colored (statement , "red" ))
288- e = RuntimeError (self ._parse (e ))
282+ e = RuntimeError (self ._parse_exception (e ))
289283 e .__cause__ = None
290284 raise e
291285
292286 # Return value
293287 else :
294288 self ._logger .debug (termcolor .colored (statement , "green" ))
295289 return ret
290+
291+ def _parse_exception (self , e ):
292+ """Parses an exception, returns its message."""
293+
294+ # MySQL
295+ matches = re .search (r"^\(_mysql_exceptions\.OperationalError\) \(\d+, \"(.+)\"\)$" , str (e ))
296+ if matches :
297+ return matches .group (1 )
298+
299+ # PostgreSQL
300+ matches = re .search (r"^\(psycopg2\.OperationalError\) (.+)$" , str (e ))
301+ if matches :
302+ return matches .group (1 )
303+
304+ # SQLite
305+ matches = re .search (r"^\(sqlite3\.OperationalError\) (.+)$" , str (e ))
306+ if matches :
307+ return matches .group (1 )
308+
309+ # Default
310+ return str (e )
311+
312+ def _parse_placeholder (self , token ):
313+ """Infers paramstyle, name from sqlparse.tokens.Name.Placeholder."""
314+
315+ # Validate token
316+ if not isinstance (token , sqlparse .sql .Token ) or token .ttype != sqlparse .tokens .Name .Placeholder :
317+ raise TypeError ()
318+
319+ # qmark
320+ if token .value == "?" :
321+ return "qmark" , None
322+
323+ # numeric
324+ matches = re .search (r"^:(\d+)$" , token .value )
325+ if matches :
326+ return "numeric" , matches .group (1 )
327+
328+ # named
329+ matches = re .search (r"^:([a-zA-Z]\w*)$" , token .value )
330+ if matches :
331+ return "named" , matches .group (1 )
332+
333+ # format
334+ if token .value == "%s" :
335+ return "format" , None
336+
337+ # pyformat
338+ matches = re .search (r"%\((\w+)\)s$" , token .value )
339+ if matches :
340+ return "pyformat" , matches .group (1 )
341+
342+ # Invalid
343+ raise RuntimeError ("{}: invalid placeholder" .format (token .value ))
0 commit comments