cte by kzidane · Pull Request #113 · cs50/python-cs50 · GitHub
Skip to content
Merged

cte #113

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
82 changes: 43 additions & 39 deletions src/cs50/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ def execute(self, sql, *args, **kwargs):
if len(args) > 0 and len(kwargs) > 0:
raise RuntimeError("cannot pass both named and positional parameters")

# Infer command from (unflattened) statement
for token in statements[0]:
if token.ttype in [sqlparse.tokens.Keyword.DDL, sqlparse.tokens.Keyword.DML]:
command = token.value.upper()
break
else:
command = None

# Flatten statement
tokens = list(statements[0].flatten())

Expand Down Expand Up @@ -313,45 +321,41 @@ def shutdown_session(exception=None):

# Return value
ret = True
if tokens[0].ttype == sqlparse.tokens.Keyword.DML:

# Uppercase token's value
value = tokens[0].value.upper()

# If SELECT, return result set as list of dict objects
if value == "SELECT":

# Coerce types
rows = [dict(row) for row in result.fetchall()]
for row in rows:
for column in row:

# Coerce decimal.Decimal objects to float objects
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
if type(row[column]) is decimal.Decimal:
row[column] = float(row[column])

# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
elif type(row[column]) is memoryview:
row[column] = bytes(row[column])

# Rows to be returned
ret = rows

# If INSERT, return primary key value for a newly inserted row (or None if none)
elif value == "INSERT":
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
try:
result = connection.execute("SELECT LASTVAL()")
ret = result.first()[0]
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
ret = None
else:
ret = result.lastrowid if result.rowcount == 1 else None

# If DELETE or UPDATE, return number of rows matched
elif value in ["DELETE", "UPDATE"]:
ret = result.rowcount

# If SELECT, return result set as list of dict objects
if command == "SELECT":

# Coerce types
rows = [dict(row) for row in result.fetchall()]
for row in rows:
for column in row:

# Coerce decimal.Decimal objects to float objects
# https://groups.google.com/d/msg/sqlalchemy/0qXMYJvq8SA/oqtvMD9Uw-kJ
if type(row[column]) is decimal.Decimal:
row[column] = float(row[column])

# Coerce memoryview objects (as from PostgreSQL's bytea columns) to bytes
elif type(row[column]) is memoryview:
row[column] = bytes(row[column])

# Rows to be returned
ret = rows

# If INSERT, return primary key value for a newly inserted row (or None if none)
elif command == "INSERT":
if self._engine.url.get_backend_name() in ["postgres", "postgresql"]:
try:
result = connection.execute("SELECT LASTVAL()")
ret = result.first()[0]
except sqlalchemy.exc.OperationalError: # If lastval is not yet defined in this session
ret = None
else:
ret = result.lastrowid if result.rowcount == 1 else None

# If DELETE or UPDATE, return number of rows matched
elif command in ["DELETE", "UPDATE"]:
ret = result.rowcount

# If constraint violated, return None
except sqlalchemy.exc.IntegrityError as e:
Expand Down
6 changes: 6 additions & 0 deletions tests/sql.py