|
20 | 20 | from typing import TYPE_CHECKING |
21 | 21 |
|
22 | 22 | import pyarrow as pa |
23 | | -from datafusion import SessionContext, udtf |
| 23 | +from datafusion import Expr, SessionContext, udtf |
24 | 24 | from datafusion_ffi_example import MyTableFunction, MyTableProvider |
25 | 25 |
|
26 | 26 | if TYPE_CHECKING: |
@@ -77,19 +77,23 @@ class PythonTableFunction: |
77 | 77 | provider, and this function takes no arguments |
78 | 78 | """ |
79 | 79 |
|
80 | | - def __init__(self) -> None: |
81 | | - self.table_provider = MyTableProvider(3, 2, 4) |
82 | | - |
83 | | - def __call__(self) -> TableProviderExportable: |
84 | | - return self.table_provider |
| 80 | + def __call__( |
| 81 | + self, num_cols: Expr, num_rows: Expr, num_batches: Expr |
| 82 | + ) -> TableProviderExportable: |
| 83 | + args = [ |
| 84 | + num_cols.to_variant().value_i64(), |
| 85 | + num_rows.to_variant().value_i64(), |
| 86 | + num_batches.to_variant().value_i64(), |
| 87 | + ] |
| 88 | + return MyTableProvider(*args) |
85 | 89 |
|
86 | 90 |
|
87 | 91 | def test_python_table_function(): |
88 | 92 | ctx = SessionContext() |
89 | 93 | table_func = PythonTableFunction() |
90 | 94 | table_udtf = udtf(table_func, "my_table_func") |
91 | 95 | ctx.register_udtf(table_udtf) |
92 | | - result = ctx.sql("select * from my_table_func()").collect() |
| 96 | + result = ctx.sql("select * from my_table_func(3,2,4)").collect() |
93 | 97 |
|
94 | 98 | assert len(result) == 4 |
95 | 99 | assert result[0].num_columns == 3 |
|
0 commit comments