kuzu

Kuzu Python API bindings.

This package provides a Python API for the Kuzu graph database management system.

To install the package, run:

```
python3 -m pip install kuzu
```

Example usage:

```python
import kuzu

db = kuzu.Database("./test")
conn = kuzu.Connection(db)

# Define the schema
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")

# Load some data
conn.execute('COPY User FROM "user.csv"')
conn.execute('COPY City FROM "city.csv"')
conn.execute('COPY Follows FROM "follows.csv"')
conn.execute('COPY LivesIn FROM "lives-in.csv"')

# Query the data
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
while results.has_next():
    print(results.get_next())
```

The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv).

 1"""
 2# Kuzu Python API bindings.
 3
 4This package provides a Python API for Kuzu graph database management system.
 5
 6To install the package, run:
 7```
 8python3 -m pip install kuzu
 9```
10
11Example usage:
12```python
13import kuzu
14
15db = kuzu.Database("./test")
16conn = kuzu.Connection(db)
17
18# Define the schema
19conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
20conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
21conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
22conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")
23
24# Load some data
25conn.execute('COPY User FROM "user.csv"')
26conn.execute('COPY City FROM "city.csv"')
27conn.execute('COPY Follows FROM "follows.csv"')
28conn.execute('COPY LivesIn FROM "lives-in.csv"')
29
30# Query the data
31results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
32while results.has_next():
33    print(results.get_next())
34```
35
36The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv).
37
38"""
39
40from __future__ import annotations
41
42import os
43import sys
44
45# Set RTLD_GLOBAL and RTLD_LAZY flags on Linux to fix the issue with loading
46# extensions
47if sys.platform == "linux":
48    original_dlopen_flags = sys.getdlopenflags()
49    sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)
50
51from .async_connection import AsyncConnection
52from .connection import Connection
53from .database import Database
54from .prepared_statement import PreparedStatement
55from .query_result import QueryResult
56from .types import Type
57
58
59def __getattr__(name: str) -> str | int:
60    if name in ("version", "__version__"):
61        return Database.get_version()
62    elif name == "storage_version":
63        return Database.get_storage_version()
64    else:
65        msg = f"module {__name__!r} has no attribute {name!r}"
66        raise AttributeError(msg)
67
68
69# Restore the original dlopen flags
70if sys.platform == "linux":
71    sys.setdlopenflags(original_dlopen_flags)
72
73__all__ = [
74    "AsyncConnection",
75    "Connection",
76    "Database",
77    "PreparedStatement",
78    "QueryResult",
79    "Type",
80    "__version__",  # noqa: F822
81    "storage_version",  # noqa: F822
82    "version",  # noqa: F822
83]
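The `__getattr__` hook above exposes version information lazily at the package level. A small sketch of reading it (printed values are illustrative):

```python
import kuzu

print(kuzu.__version__)      # package version string, resolved via Database.get_version()
print(kuzu.storage_version)  # integer storage-format version
```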
class AsyncConnection:
class AsyncConnection:
    """AsyncConnection enables asynchronous execution of queries with a pool of connections and threads."""

    def __init__(
        self,
        database: Database,
        max_concurrent_queries: int = 4,
        max_threads_per_query: int = 0,
    ) -> None:
        """
        Initialise the async connection.

        Parameters
        ----------
        database : Database
            Database to connect to.

        max_concurrent_queries : int
            Maximum number of concurrent queries to execute. This corresponds to the
            number of connections and thread pool size. Default is 4.

        max_threads_per_query : int
            Controls the maximum number of threads per connection that can be used
            to execute one query. Default is 0, which means no limit.
        """
        self.database = database
        self.connections = [Connection(database) for _ in range(max_concurrent_queries)]
        self.connections_counter = [0 for _ in range(max_concurrent_queries)]
        self.lock = threading.Lock()

        for conn in self.connections:
            conn.init_connection()
            conn.set_max_threads_for_exec(max_threads_per_query)

        self.executor = ThreadPoolExecutor(max_workers=max_concurrent_queries)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        self.close()

    def __get_connection_with_least_queries(self) -> tuple[Connection, int]:
        with self.lock:
            conn_index = self.connections_counter.index(min(self.connections_counter))
            self.connections_counter[conn_index] += 1
        return self.connections[conn_index], conn_index

    def __decrement_connection_counter(self, conn_index: int) -> None:
        """Decrement the query counter for a connection."""
        with self.lock:
            self.connections_counter[conn_index] -= 1
            if self.connections_counter[conn_index] < 0:
                self.connections_counter[conn_index] = 0

    def acquire_connection(self) -> Connection:
        """
        Acquire a connection from the connection pool for temporary synchronous
        calls. If the connection pool is oversubscribed, the method returns the
        connection with the fewest queued queries. The connection must be
        released by calling `release_connection` once it is no longer needed.

        Returns
        -------
        Connection
            A connection object.
        """
        conn, _ = self.__get_connection_with_least_queries()
        return conn

    def release_connection(self, conn: Connection) -> None:
        """
        Release a connection acquired by `acquire_connection` back to the
        connection pool. Calling this method is required when the connection is
        no longer needed.

        Parameters
        ----------
        conn : Connection
            Connection object to release.
        """
        for i, existing_conn in enumerate(self.connections):
            if existing_conn == conn:
                self.__decrement_connection_counter(i)
                break

    def set_query_timeout(self, timeout_in_ms: int) -> None:
        """
        Set the query timeout value in ms for executing queries.

        Parameters
        ----------
        timeout_in_ms : int
            query timeout value in ms for executing queries.

        """
        for conn in self.connections:
            conn.set_query_timeout(timeout_in_ms)

    async def execute(
        self, query: str | PreparedStatement, parameters: dict[str, Any] | None = None
    ) -> QueryResult | list[QueryResult]:
        """
        Execute a query asynchronously.

        Parameters
        ----------
        query : str | PreparedStatement
            A prepared statement or a query string.
            If a query string is given, a prepared statement will be created
            automatically.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        QueryResult
            Query result. If the query string contains multiple statements, a
            list of query results is returned, one per statement.

        """
        loop = asyncio.get_running_loop()
        # If the query is a prepared statement, use the connection associated with it
        if isinstance(query, PreparedStatement):
            conn = query._connection
            for i, existing_conn in enumerate(self.connections):
                if existing_conn == conn:
                    conn_index = i
                    with self.lock:
                        self.connections_counter[conn_index] += 1
                    break
        else:
            conn, conn_index = self.__get_connection_with_least_queries()

        try:
            return await loop.run_in_executor(self.executor, conn.execute, query, parameters)
        except asyncio.CancelledError:
            conn.interrupt()
            raise  # re-raise so the cancellation propagates to the caller
        finally:
            self.__decrement_connection_counter(conn_index)

    async def _prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
        """
        The only parameters supported during prepare are dataframes.
        Any remaining parameters will be ignored and should be passed to execute().
        """  # noqa: D401
        loop = asyncio.get_running_loop()
        conn, conn_index = self.__get_connection_with_least_queries()

        try:
            return await loop.run_in_executor(self.executor, conn.prepare, query, parameters)
        finally:
            self.__decrement_connection_counter(conn_index)

    async def prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
        """
        Create a prepared statement for a query asynchronously.

        Parameters
        ----------
        query : str
            Query to prepare.
        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        PreparedStatement
            Prepared statement.

        """
        warnings.warn(
            "The use of separate prepare + execute of queries is deprecated. "
            "Please use a single call to the execute() API instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return await self._prepare(query, parameters)

    def close(self) -> None:
        """
        Close all connections and shut down the thread pool.

        Note: Calling this method is optional. The connections and thread pool
        will be closed automatically when the instance is garbage collected.
        """
        for conn in self.connections:
            conn.close()

        self.executor.shutdown(wait=True)

AsyncConnection enables asynchronous execution of queries with a pool of connections and threads.

AsyncConnection( database: Database, max_concurrent_queries: int = 4, max_threads_per_query: int = 0)

Initialise the async connection.

Parameters
  • database (Database): Database to connect to.
  • max_concurrent_queries (int): Maximum number of concurrent queries to execute. This corresponds to the number of connections and thread pool size. Default is 4.
  • max_threads_per_query (int): Controls the maximum number of threads per connection that can be used to execute one query. Default is 0, which means no limit.
database
connections
connections_counter
lock
executor
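A minimal usage sketch, assuming the `User` table from the module example above already exists:

```python
import asyncio

import kuzu

async def main() -> None:
    db = kuzu.Database("./test")
    async_conn = kuzu.AsyncConnection(db, max_concurrent_queries=4)

    # The query runs on one of the pool's worker threads.
    result = await async_conn.execute("MATCH (u:User) RETURN u.name, u.age;")
    while result.has_next():
        print(result.get_next())

    async_conn.close()

asyncio.run(main())
```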
def acquire_connection(self) -> Connection:

Acquire a connection from the connection pool for temporary synchronous calls. If the connection pool is oversubscribed, the method returns the connection with the fewest queued queries. The connection must be released by calling release_connection once it is no longer needed.

Returns
  • Connection: A connection object.
def release_connection(self, conn: Connection) -> None:

Release a connection acquired by acquire_connection back to the connection pool. Calling this method is required when the connection is no longer needed.

Parameters
  • conn (Connection): Connection object to release.
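A sketch of the acquire/release pattern with a try/finally so the connection is always returned (reusing `async_conn` from the sketch above):

```python
conn = async_conn.acquire_connection()
try:
    # Temporary synchronous use of a pooled connection.
    result = conn.execute("MATCH (u:User) RETURN count(u);")
    print(result.get_next())
finally:
    # Return the connection so its queue counter is decremented.
    async_conn.release_connection(conn)
```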
def set_query_timeout(self, timeout_in_ms: int) -> None:

Set the query timeout value in ms for executing queries.

Parameters
  • timeout_in_ms (int): query timeout value in ms for executing queries.
async def execute( self, query: str | PreparedStatement, parameters: dict[str, typing.Any] | None = None) -> QueryResult | list[QueryResult]:

Execute a query asynchronously.

Parameters
  • query (str | PreparedStatement): A prepared statement or a query string. If a query string is given, a prepared statement will be created automatically.
  • parameters (dict[str, Any]): Parameters for the query.
Returns
  • QueryResult | list[QueryResult]: Query result; a list of query results is returned if the query string contains multiple statements.
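Because each call is dispatched to a pooled connection, independent queries can be awaited concurrently from within an async function. A sketch using asyncio.gather (reusing `async_conn` from the earlier sketch):

```python
users, cities = await asyncio.gather(
    async_conn.execute("MATCH (u:User) RETURN u.name;"),
    async_conn.execute("MATCH (c:City) RETURN c.name;"),
)
```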
async def prepare( self, query: str, parameters: dict[str, typing.Any] | None = None) -> PreparedStatement:

Create a prepared statement for a query asynchronously.

Parameters
  • query (str): Query to prepare.
  • parameters (dict[str, Any]): Parameters for the query.
Returns
  • PreparedStatement: Prepared statement.
def close(self) -> None:

Close all connections and shut down the thread pool.

Note: Calling this method is optional. The connections and thread pool will be closed automatically when the instance is garbage collected.
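AsyncConnection also implements the context-manager protocol, so close() can be left to a with block; a sketch (reusing `db` from above):

```python
with kuzu.AsyncConnection(db) as async_conn:
    result = asyncio.run(async_conn.execute("RETURN 1;"))
```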

class Connection:
class Connection:
    """Connection to a database."""

    def __init__(self, database: Database, num_threads: int = 0):
        """
        Initialise kuzu database connection.

        Parameters
        ----------
        database : Database
            Database to connect to.

        num_threads : int
            Maximum number of threads to use for executing queries.

        """
        self._connection: Any = None  # (type: _kuzu.Connection from pybind11)
        self.database = database
        self.num_threads = num_threads
        self.is_closed = False
        self.init_connection()

    def __getstate__(self) -> dict[str, Any]:
        state = {
            "database": self.database,
            "num_threads": self.num_threads,
            "_connection": None,
        }
        return state

    def init_connection(self) -> None:
        """Establish a connection to the database, if not already initialised."""
        if self.is_closed:
            error_msg = "Connection is closed."
            raise RuntimeError(error_msg)
        self.database.init_database()
        if self._connection is None:
            self._connection = _kuzu.Connection(self.database._database, self.num_threads)  # type: ignore[union-attr]

    def set_max_threads_for_exec(self, num_threads: int) -> None:
        """
        Set the maximum number of threads for executing queries.

        Parameters
        ----------
        num_threads : int
            Maximum number of threads to use for executing queries.

        """
        self.init_connection()
        self._connection.set_max_threads_for_exec(num_threads)

    def close(self) -> None:
        """
        Close the connection.

        Note: Calling this method is optional. The connection will be closed
        automatically when the object goes out of scope.
        """
        if self._connection is not None:
            self._connection.close()
        self._connection = None
        self.is_closed = True

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def execute(
        self,
        query: str | PreparedStatement,
        parameters: dict[str, Any] | None = None,
    ) -> QueryResult | list[QueryResult]:
        """
        Execute a query.

        Parameters
        ----------
        query : str | PreparedStatement
            A prepared statement or a query string.
            If a query string is given, a prepared statement will be created
            automatically.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        QueryResult
            Query result. If the query string contains multiple statements,
            a list of query results is returned, one per statement.

        """
        if parameters is None:
            parameters = {}

        self.init_connection()
        if not isinstance(parameters, dict):
            msg = f"Parameters must be a dict; found {type(parameters)}."
            raise RuntimeError(msg)  # noqa: TRY004

        if len(parameters) == 0 and isinstance(query, str):
            query_result_internal = self._connection.query(query)
        else:
            prepared_statement = self._prepare(query, parameters) if isinstance(query, str) else query
            query_result_internal = self._connection.execute(prepared_statement._prepared_statement, parameters)
        if not query_result_internal.isSuccess():
            raise RuntimeError(query_result_internal.getErrorMessage())
        current_query_result = QueryResult(self, query_result_internal)
        if not query_result_internal.hasNextQueryResult():
            return current_query_result
        all_query_results = [current_query_result]
        while query_result_internal.hasNextQueryResult():
            query_result_internal = query_result_internal.getNextQueryResult()
            if not query_result_internal.isSuccess():
                raise RuntimeError(query_result_internal.getErrorMessage())
            all_query_results.append(QueryResult(self, query_result_internal))
        return all_query_results

    def _prepare(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
    ) -> PreparedStatement:
        """
        The only parameters supported during prepare are dataframes.
        Any remaining parameters will be ignored and should be passed to execute().
        """  # noqa: D401
        return PreparedStatement(self, query, parameters)

    def prepare(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
    ) -> PreparedStatement:
        """
        Create a prepared statement for a query.

        Parameters
        ----------
        query : str
            Query to prepare.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        PreparedStatement
            Prepared statement.

        """
        warnings.warn(
            "The use of separate prepare + execute of queries is deprecated. "
            "Please use a single call to the execute() API instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._prepare(query, parameters)

    def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
        LIST_START_SYMBOL = "["
        LIST_END_SYMBOL = "]"
        self.init_connection()
        query_result = self.execute(f"CALL table_info('{table_name}') RETURN *;")
        results = {}
        while query_result.has_next():
            row = query_result.get_next()
            prop_name = row[1]
            prop_type = row[2]
            is_primary_key = row[4] is True
            # LIST/ARRAY types are written as e.g. INT64[3][2]; the number of
            # "[" characters gives the dimension, the bracketed sizes the shape.
            dimension = prop_type.count(LIST_START_SYMBOL)
            splitted = prop_type.split(LIST_START_SYMBOL)
            shape = []
            for s in splitted:
                if LIST_END_SYMBOL not in s:
                    continue
                s = s.split(LIST_END_SYMBOL)[0]
                if s != "":
                    shape.append(int(s))
            prop_type = splitted[0]
            results[prop_name] = {
                "type": prop_type,
                "dimension": dimension,
                "is_primary_key": is_primary_key,
            }
            if len(shape) > 0:
                results[prop_name]["shape"] = tuple(shape)
        return results

    def _get_node_table_names(self) -> list[Any]:
        results = []
        self.init_connection()
        query_result = self.execute("CALL show_tables() RETURN *;")
        while query_result.has_next():
            row = query_result.get_next()
            if row[2] == "NODE":
                results.append(row[1])
        return results

    def _get_rel_table_names(self) -> list[dict[str, Any]]:
        results = []
        self.init_connection()
        tables_result = self.execute("CALL show_tables() RETURN *;")
        while tables_result.has_next():
            row = tables_result.get_next()
            if row[2] == "REL":
                name = row[1]
                connections_result = self.execute(f"CALL show_connection({name!r}) RETURN *;")
                src_dst_row = connections_result.get_next()
                src_node = src_dst_row[0]
                dst_node = src_dst_row[1]
                results.append({"name": name, "src": src_node, "dst": dst_node})
        return results

    def set_query_timeout(self, timeout_in_ms: int) -> None:
        """
        Set the query timeout value in ms for executing queries.

        Parameters
        ----------
        timeout_in_ms : int
            query timeout value in ms for executing queries.

        """
        self.init_connection()
        self._connection.set_query_timeout(timeout_in_ms)

    def interrupt(self) -> None:
        """
        Interrupts execution of the current query.

        If there is no currently executing query, this function does nothing.
        """
        self._connection.interrupt()

    def create_function(
        self,
        name: str,
        udf: Callable[..., Any],
        params_type: list[Type | str] | None = None,
        return_type: Type | str = "",
        *,
        default_null_handling: bool = True,
        catch_exceptions: bool = False,
    ) -> None:
        """
        Set a User Defined Function (UDF) for use in Cypher queries.

        Parameters
        ----------
        name: str
            name of function

        udf: Callable[..., Any]
            function to be executed

        params_type: Optional[list[Type]]
            list of Type enums to describe the input parameters

        return_type: Optional[Type]
            a Type enum to describe the returned value

        default_null_handling: Optional[bool]
            if true, when any parameter is null, the resulting value will be null

        catch_exceptions: Optional[bool]
            if true, when an exception is thrown from Python, the function output
            will be null; otherwise, the exception will be rethrown
        """
        if params_type is None:
            params_type = []
        # Type enums are passed to the native layer as their string values.
        parsed_params_type = [x if type(x) is str else x.value for x in params_type]
        if type(return_type) is not str:
            return_type = return_type.value

        self._connection.create_function(
            name=name,
            udf=udf,
            params_type=parsed_params_type,
            return_value=return_type,
            default_null=default_null_handling,
            catch_exceptions=catch_exceptions,
        )

    def remove_function(self, name: str) -> None:
        """
        Remove a User Defined Function (UDF).

        Parameters
        ----------
        name: str
            name of function to be removed.
        """
        self._connection.remove_function(name)

Connection to a database.

Connection(database: Database, num_threads: int = 0)

Initialise kuzu database connection.

Parameters
  • database (Database): Database to connect to.
  • num_threads (int): Maximum number of threads to use for executing queries.
database
num_threads
is_closed
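A short sketch of typical synchronous use, managing the connection with a with block (assuming the `User` table from the module example exists):

```python
import kuzu

db = kuzu.Database("./test")
with kuzu.Connection(db) as conn:
    result = conn.execute("MATCH (u:User) RETURN u.name;")
    while result.has_next():
        print(result.get_next())
```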
def init_connection(self) -> None:

Establish a connection to the database, if not already initialised.

def set_max_threads_for_exec(self, num_threads: int) -> None:

Set the maximum number of threads for executing queries.

Parameters
  • num_threads (int): Maximum number of threads to use for executing queries.
def close(self) -> None:

Close the connection.

Note: Calling this method is optional. The connection will be closed automatically when the object goes out of scope.

def execute( self, query: str | PreparedStatement, parameters: dict[str, typing.Any] | None = None) -> QueryResult | list[QueryResult]:

Execute a query.

Parameters
  • query (str | PreparedStatement): A prepared statement or a query string. If a query string is given, a prepared statement will be created automatically.
  • parameters (dict[str, Any]): Parameters for the query.
Returns
  • QueryResult | list[QueryResult]: Query result; a list of query results is returned if the query string contains multiple statements.
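Query parameters are supplied as a dict and referenced with `$name` placeholders in the query string; a sketch (reusing `conn` from the example above):

```python
result = conn.execute(
    "MATCH (u:User) WHERE u.age > $min_age RETURN u.name;",
    {"min_age": 25},
)
```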
def prepare( self, query: str, parameters: dict[str, typing.Any] | None = None) -> PreparedStatement:

Create a prepared statement for a query.

Parameters
  • query (str): Query to prepare.
  • parameters (dict[str, Any]): Parameters for the query.
Returns
  • PreparedStatement: Prepared statement.
def set_query_timeout(self, timeout_in_ms: int) -> None:

Set the query timeout value in ms for executing queries.

Parameters
  • timeout_in_ms (int): query timeout value in ms for executing queries.
def interrupt(self) -> None:

Interrupts execution of the current query.

If there is no currently executing query, this function does nothing.

def create_function( self, name: str, udf: Callable[..., Any], params_type: list[Type | str] | None = None, return_type: Type | str = '', *, default_null_handling: bool = True, catch_exceptions: bool = False) -> None:

Set a User Defined Function (UDF) for use in Cypher queries.

Parameters
  • name (str): name of function
  • udf (Callable[..., Any]): function to be executed
  • params_type (Optional[list[Type]]): list of Type enums to describe the input parameters
  • return_type (Optional[Type]): a Type enum to describe the returned value
  • default_null_handling (Optional[bool]): if true, when any parameter is null, the resulting value will be null
  • catch_exceptions (Optional[bool]): if true, when an exception is thrown from Python, the function output will be null; otherwise, the exception will be rethrown
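A sketch of registering and calling a simple UDF, reusing `conn` from above (`times_two` is an illustrative name):

```python
def times_two(x: int) -> int:
    return x * 2

# Declare parameter and return types with kuzu.Type enums (or type strings).
conn.create_function("times_two", times_two, [kuzu.Type.INT64], kuzu.Type.INT64)

result = conn.execute("MATCH (u:User) RETURN times_two(u.age);")

# Unregister when no longer needed.
conn.remove_function("times_two")
```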
def remove_function(self, name: str) -> None:

Remove a User Defined Function (UDF).

Parameters
  • name (str): name of function to be removed.
class Database:
class Database:
    """Kuzu database instance."""

    def __init__(
        self,
        database_path: str | Path | None = None,
        *,
        buffer_pool_size: int = 0,
        max_num_threads: int = 0,
        compression: bool = True,
        lazy_init: bool = False,
        read_only: bool = False,
        max_db_size: int = (1 << 43),
        auto_checkpoint: bool = True,
        checkpoint_threshold: int = -1,
    ):
        """
        Parameters
        ----------
        database_path : str, Path
            The path to database files. If the path is not specified, or empty, or equal to `:memory:`, the database
            will be created in memory.

        buffer_pool_size : int
            The maximum size of the buffer pool in bytes. Defaults to ~80% of system memory.

        max_num_threads : int
            The maximum number of threads to use for executing queries.

        compression : bool
            Enable database compression.

        lazy_init : bool
            If True, the database will not be initialized until the first query.
            This is useful when the database is not used in the main thread or
            when the main process is forked.
            Defaults to False.

        read_only : bool
            If true, the database is opened read-only. No write transactions are
            allowed on the `Database` object. Multiple read-only `Database`
            objects can be created with the same database path; however, there
            cannot be multiple `Database` objects created with the same database
            path if any of them is opened read-write.
            Defaults to False.

        max_db_size : int
            The maximum size of the database in bytes. Note that this is a
            temporary workaround for the default 8TB mmap address-space limit in
            some environments; it will be removed once a better solution is
            implemented. Defaults to 1 << 43 (8TB) on 64-bit environments and
            1GB on 32-bit ones.

        auto_checkpoint: bool
            If true, the database will automatically checkpoint when the size of
            the WAL file exceeds the checkpoint threshold.

        checkpoint_threshold: int
            The threshold of the WAL file size in bytes. When the size of the
            WAL file exceeds this threshold, the database will checkpoint if
            `auto_checkpoint` is true.

        """
        if database_path is None:
            database_path = ":memory:"
        if isinstance(database_path, Path):
            database_path = str(database_path)

        self.database_path = database_path
        self.buffer_pool_size = buffer_pool_size
        self.max_num_threads = max_num_threads
        self.compression = compression
        self.read_only = read_only
        self.max_db_size = max_db_size
        self.auto_checkpoint = auto_checkpoint
        self.checkpoint_threshold = checkpoint_threshold
        self.is_closed = False

        self._database: Any = None  # (type: _kuzu.Database from pybind11)
        if not lazy_init:
            self.init_database()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    @staticmethod
    def get_version() -> str:
        """
        Get the version of the database.

        Returns
        -------
        str
            The version of the database.
        """
        return _kuzu.Database.get_version()  # type: ignore[union-attr]

    @staticmethod
    def get_storage_version() -> int:
        """
        Get the storage version of the database.

        Returns
        -------
        int
            The storage version of the database.
        """
        return _kuzu.Database.get_storage_version()  # type: ignore[union-attr]

    def __getstate__(self) -> dict[str, Any]:
        state = {
            "database_path": self.database_path,
            "buffer_pool_size": self.buffer_pool_size,
            "compression": self.compression,
            "read_only": self.read_only,
            "_database": None,
        }
        return state

    def init_database(self) -> None:
        """Initialize the database."""
        self.check_for_database_close()
        if self._database is None:
            self._database = _kuzu.Database(  # type: ignore[union-attr]
                self.database_path,
                self.buffer_pool_size,
                self.max_num_threads,
                self.compression,
                self.read_only,
                self.max_db_size,
                self.auto_checkpoint,
                self.checkpoint_threshold,
            )

    def get_torch_geometric_remote_backend(
        self, num_threads: int | None = None
    ) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
        """
        Use the database as the remote backend for torch_geometric.

        For the interface of the remote backend, please refer to
        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
        The current implementation is read-only and does not support edge
        features. The IDs of the nodes are based on the internal IDs (i.e., node
        offsets). For the remote node IDs to be consistent with the positions in
        the output tensors, please ensure that no deletion has been performed
        on the node tables.

        The remote backend can also be plugged into the data loader of
        torch_geometric, which is useful for mini-batch training. For example:

        ```python
            loader_kuzu = NeighborLoader(
                data=(feature_store, graph_store),
                num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
                batch_size=LOADER_BATCH_SIZE,
                input_nodes=('paper', input_nodes),
                num_workers=4,
                filter_per_worker=False,
            )
        ```

        Please note that the database instance is not fork-safe, so if more than
        one worker is used, `filter_per_worker` must be set to False.

        Parameters
        ----------
        num_threads : int
            Number of threads to use for data loading. Defaults to None, which
            means using the number of CPU cores.

        Returns
        -------
        feature_store : KuzuFeatureStore
            Feature store compatible with torch_geometric.
        graph_store : KuzuGraphStore
            Graph store compatible with torch_geometric.
        """
        self.check_for_database_close()
        from .torch_geometric_feature_store import KuzuFeatureStore
        from .torch_geometric_graph_store import KuzuGraphStore

        return (
            KuzuFeatureStore(self, num_threads),
            KuzuGraphStore(self, num_threads),
        )

    def _scan_node_table(
        self,
        table_name: str,
        prop_name: str,
        prop_type: str,
        dim: int,
        indices: IndexType,
        num_threads: int,
    ) -> NDArray[Any]:
        """
        Scan a node table from storage directly, bypassing the query engine.
        Used internally by the torch_geometric remote backend only.
        """
        self.check_for_database_close()
        import numpy as np

        self.init_database()
        indices_cast = np.array(indices, dtype=np.uint64)
        result = None

        # Dispatch to the native scan routine matching the property's type.
        if prop_type == Type.INT64.value:
            result = np.empty(len(indices) * dim, dtype=np.int64)
            self._database.scan_node_table_as_int64(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.INT32.value:
            result = np.empty(len(indices) * dim, dtype=np.int32)
            self._database.scan_node_table_as_int32(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.INT16.value:
            result = np.empty(len(indices) * dim, dtype=np.int16)
            self._database.scan_node_table_as_int16(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.DOUBLE.value:
            result = np.empty(len(indices) * dim, dtype=np.float64)
            self._database.scan_node_table_as_double(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.FLOAT.value:
            result = np.empty(len(indices) * dim, dtype=np.float32)
            self._database.scan_node_table_as_float(table_name, prop_name, indices_cast, result, num_threads)

        if result is not None:
            return result

        msg = f"Unsupported property type: {prop_type}"
        raise ValueError(msg)

    def close(self) -> None:
        """
        Close the database. Once the database is closed, the lock on the database
        files is released and the database can be opened in another process.

        Note: Calling this method is not required. The Python garbage collector
        will automatically close the database when no references to the database
        object exist. It is recommended not to call this method explicitly. If you
        decide to manually close the database, make sure that all the QueryResult
        and Connection objects are closed before calling this method.
        """
        if self.is_closed:
            return
        self.is_closed = True
        if self._database is not None:
            self._database.close()
            self._database: Any = None  # (type: _kuzu.Database from pybind11)

    def check_for_database_close(self) -> None:
        """
        Check if the database is closed and raise an exception if it is.

        Raises
        ------
        Exception
            If the database is closed.

        """
        if not self.is_closed:
            return
        msg = "Database is closed"
        raise RuntimeError(msg)

Kuzu database instance.

Database( database_path: str | pathlib.Path | None = None, *, buffer_pool_size: int = 0, max_num_threads: int = 0, compression: bool = True, lazy_init: bool = False, read_only: bool = False, max_db_size: int = 8796093022208, auto_checkpoint: bool = True, checkpoint_threshold: int = -1)
Parameters
  • database_path (str, Path): The path to database files. If the path is not specified, or empty, or equal to :memory:, the database will be created in memory.
  • buffer_pool_size (int): The maximum size of the buffer pool in bytes. Defaults to ~80% of system memory.
  • max_num_threads (int): The maximum number of threads to use for executing queries.
  • compression (bool): Enable database compression.
  • lazy_init (bool): If True, the database will not be initialized until the first query. This is useful when the database is not used in the main thread or when the main process is forked. Defaults to False.
  • read_only (bool): If true, the database is opened read-only. No write transactions are allowed on the Database object. Multiple read-only Database objects can be created with the same database path; however, there cannot be multiple Database objects created with the same database path if any of them is opened read-write. Defaults to False.
  • max_db_size (int): The maximum size of the database in bytes. Note that this is a temporary workaround for the default 8TB mmap address-space limit in some environments; it will be removed once a better solution is implemented. Defaults to 1 << 43 (8TB) on 64-bit environments and 1GB on 32-bit ones.
  • auto_checkpoint (bool): If true, the database will automatically checkpoint when the size of the WAL file exceeds the checkpoint threshold.
  • checkpoint_threshold (int): The threshold of the WAL file size in bytes. When the size of the WAL file exceeds this threshold, the database will checkpoint if auto_checkpoint is true.
database_path
buffer_pool_size
max_num_threads
compression
read_only
max_db_size
auto_checkpoint
checkpoint_threshold
is_closed
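A sketch of common configurations (paths are illustrative):

```python
import kuzu

# In-memory database: omit the path or pass ":memory:".
mem_db = kuzu.Database()

# On-disk database with an explicit 1GB buffer pool.
db = kuzu.Database("./demo_db", buffer_pool_size=1024 ** 3)
db.close()  # release the file lock before reopening

# Reopen the same files read-only; multiple read-only handles may share a path.
ro_db = kuzu.Database("./demo_db", read_only=True)
```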
@staticmethod
def get_version() -> str:

Get the version of the database.

Returns
  • str: The version of the database.
@staticmethod
def get_storage_version() -> int:

Get the storage version of the database.

Returns
  • int: The storage version of the database.
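Both methods are static, so they can be called without opening a database; for example:

```python
import kuzu

# No Database instance is required for either call.
print(kuzu.Database.get_version())          # version string
print(kuzu.Database.get_storage_version())  # integer storage version
```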
def init_database(self) -> None:
152    def init_database(self) -> None:
153        """Initialize the database."""
154        self.check_for_database_close()
155        if self._database is None:
156            self._database = _kuzu.Database(  # type: ignore[union-attr]
157                self.database_path,
158                self.buffer_pool_size,
159                self.max_num_threads,
160                self.compression,
161                self.read_only,
162                self.max_db_size,
163                self.auto_checkpoint,
164                self.checkpoint_threshold,
165            )

Initialize the database.
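This pairs with the lazy_init constructor flag: initialization can be deferred until after worker processes are forked, then triggered explicitly (a minimal sketch; the path is hypothetical):

```python
import kuzu

# Create the handle without touching the native database yet.
db = kuzu.Database("./demo_db", lazy_init=True)

# ... fork or spawn worker processes here ...

db.init_database()  # no-op if the database is already initialized
conn = kuzu.Connection(db)
```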

def get_torch_geometric_remote_backend(self, num_threads: int | None = None) -> tuple[kuzu.torch_geometric_feature_store.KuzuFeatureStore, kuzu.torch_geometric_graph_store.KuzuGraphStore]:
167    def get_torch_geometric_remote_backend(
168        self, num_threads: int | None = None
169    ) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
170        """
171        Use the database as the remote backend for torch_geometric.
172
173        For the interface of the remote backend, please refer to
174        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
175        The current implementation is read-only and does not support edge
176        features. The IDs of the nodes are based on the internal IDs (i.e., node
177        offsets). For the remote node IDs to be consistent with the positions in
178        the output tensors, please ensure that no deletion has been performed
179        on the node tables.
180
181        The remote backend can also be plugged into the data loader of
182        torch_geometric, which is useful for mini-batch training. For example:
183
184        ```python
185            loader_kuzu = NeighborLoader(
186                data=(feature_store, graph_store),
187                num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
188                batch_size=LOADER_BATCH_SIZE,
189                input_nodes=('paper', input_nodes),
190                num_workers=4,
191                filter_per_worker=False,
192            )
193        ```
194
195        Please note that the database instance is not fork-safe, so if more than
196        one worker is used, `filter_per_worker` must be set to False.
197
198        Parameters
199        ----------
200        num_threads : int
201            Number of threads to use for data loading. Defaults to None, which
202            means using the number of CPU cores.
203
204        Returns
205        -------
206        feature_store : KuzuFeatureStore
207            Feature store compatible with torch_geometric.
208        graph_store : KuzuGraphStore
209            Graph store compatible with torch_geometric.
210        """
211        self.check_for_database_close()
212        from .torch_geometric_feature_store import KuzuFeatureStore
213        from .torch_geometric_graph_store import KuzuGraphStore
214
215        return (
216            KuzuFeatureStore(self, num_threads),
217            KuzuGraphStore(self, num_threads),
218        )

Use the database as the remote backend for torch_geometric.

For the interface of the remote backend, please refer to https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html. The current implementation is read-only and does not support edge features. The IDs of the nodes are based on the internal IDs (i.e., node offsets). For the remote node IDs to be consistent with the positions in the output tensors, please ensure that no deletion has been performed on the node tables.

The remote backend can also be plugged into the data loader of torch_geometric, which is useful for mini-batch training. For example:

    loader_kuzu = NeighborLoader(
        data=(feature_store, graph_store),
        num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
        batch_size=LOADER_BATCH_SIZE,
        input_nodes=('paper', input_nodes),
        num_workers=4,
        filter_per_worker=False,
    )

Please note that the database instance is not fork-safe, so if more than one worker is used, filter_per_worker must be set to False.

Parameters
  • num_threads (int): Number of threads to use for data loading. Defaults to None, which means using the number of CPU cores.
Returns
  • feature_store (KuzuFeatureStore): Feature store compatible with torch_geometric.
  • graph_store (KuzuGraphStore): Graph store compatible with torch_geometric.
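A sketch of obtaining the stores (num_threads=2 is an illustrative value):

```python
# The returned pair plugs into torch_geometric wherever a remote backend
# (feature_store, graph_store) is accepted, e.g. the NeighborLoader above.
feature_store, graph_store = db.get_torch_geometric_remote_backend(num_threads=2)
```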
def close(self) -> None:
262    def close(self) -> None:
263        """
264        Close the database. Once the database is closed, the lock on the database
265        files is released and the database can be opened in another process.
266
267        Note: Calling this method is not required. The Python garbage collector
268        will automatically close the database when no references to the database
269        object exist. It is recommended not to call this method explicitly. If you
270        decide to manually close the database, make sure that all the QueryResult
271        and Connection objects are closed before calling this method.
272        """
273        if self.is_closed:
274            return
275        self.is_closed = True
276        if self._database is not None:
277            self._database.close()
278            self._database: Any = None  # (type: _kuzu.Database from pybind11)

Close the database. Once the database is closed, the lock on the database files is released and the database can be opened in another process.

Note: Calling this method is not required. The Python garbage collector will automatically close the database when no references to the database object exist. It is recommended not to call this method explicitly; if you decide to close the database manually, make sure that all QueryResult and Connection objects are closed first.
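If you do close manually, close the dependents first (a minimal sketch, assuming db, conn, and a result from an earlier query):

```python
results = conn.execute("MATCH (u:User) RETURN u.name")
results.close()  # close query results first
conn.close()     # then connections
db.close()       # finally the database; later use raises RuntimeError
```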

def check_for_database_close(self) -> None:
280    def check_for_database_close(self) -> None:
281        """
282        Check if the database is closed and raise an exception if it is.
283
284        Raises
285        ------
286        Exception
287            If the database is closed.
288
289        """
290        if not self.is_closed:
291            return
292        msg = "Database is closed"
293        raise RuntimeError(msg)

Check if the database is closed and raise an exception if it is.

Raises
  • Exception: If the database is closed.
class PreparedStatement:
10class PreparedStatement:
11    """
12    A prepared statement is a parameterized query which can avoid planning the
13    same query for repeated execution.
14    """
15
16    def __init__(self, connection: Connection, query: str, parameters: dict[str, Any] | None = None):
17        """
18        Parameters
19        ----------
20        connection : Connection
21            Connection to a database.
22        query : str
23            Query to prepare.
24        parameters : dict[str, Any]
25            Parameters for the query.
26        """
27        if parameters is None:
28            parameters = {}
29        self._prepared_statement = connection._connection.prepare(query, parameters)
30        self._connection = connection
31
32    def is_success(self) -> bool:
33        """
34        Check if the prepared statement is successfully prepared.
35
36        Returns
37        -------
38        bool
39            True if the prepared statement is successfully prepared.
40        """
41        return self._prepared_statement.is_success()
42
43    def get_error_message(self) -> str:
44        """
45        Get the error message if the query is not prepared successfully.
46
47        Returns
48        -------
49        str
50            Error message.
51        """
52        return self._prepared_statement.get_error_message()

A prepared statement is a parameterized query which can avoid planning the same query for repeated execution.

PreparedStatement(connection: Connection, query: str, parameters: dict[str, typing.Any] | None = None)
16    def __init__(self, connection: Connection, query: str, parameters: dict[str, Any] | None = None):
17        """
18        Parameters
19        ----------
20        connection : Connection
21            Connection to a database.
22        query : str
23            Query to prepare.
24        parameters : dict[str, Any]
25            Parameters for the query.
26        """
27        if parameters is None:
28            parameters = {}
29        self._prepared_statement = connection._connection.prepare(query, parameters)
30        self._connection = connection
Parameters
  • connection (Connection): Connection to a database.
  • query (str): Query to prepare.
  • parameters (dict[str, Any]): Parameters for the query.
def is_success(self) -> bool:
32    def is_success(self) -> bool:
33        """
34        Check if the prepared statement is successfully prepared.
35
36        Returns
37        -------
38        bool
39            True if the prepared statement is successfully prepared.
40        """
41        return self._prepared_statement.is_success()

Check if the prepared statement is successfully prepared.

Returns
  • bool: True if the prepared statement is successfully prepared.
def get_error_message(self) -> str:
43    def get_error_message(self) -> str:
44        """
45        Get the error message if the query is not prepared successfully.
46
47        Returns
48        -------
49        str
50            Error message.
51        """
52        return self._prepared_statement.get_error_message()

Get the error message if the query is not prepared successfully.

Returns
  • str: Error message.
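Putting the pieces together, a prepared statement might be used like this (a minimal sketch reusing the User table from the package example):

```python
import kuzu

db = kuzu.Database(":memory:")
conn = kuzu.Connection(db)
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")

# Prepare once, then execute repeatedly with different bindings.
stmt = kuzu.PreparedStatement(conn, "CREATE (u:User {name: $name, age: $age})")
if not stmt.is_success():
    raise RuntimeError(stmt.get_error_message())

for name, age in [("Adam", 30), ("Karissa", 40)]:
    conn.execute(stmt, {"name": name, "age": age})
```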
class QueryResult:
 29class QueryResult:
 30    """QueryResult stores the result of a query execution."""
 31
 32    def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult):  # type: ignore[name-defined]
 33        """
 34        Parameters
 35        ----------
 36        connection : _kuzu.Connection
 37            The underlying C++ connection object from pybind11.
 38
 39        query_result : _kuzu.QueryResult
 40            The underlying C++ query result object from pybind11.
 41
 42        """
 43        self.connection = connection
 44        self._query_result = query_result
 45        self.is_closed = False
 46        self.as_dict = False
 47
 48    def __enter__(self) -> Self:
 49        return self
 50
 51    def __exit__(
 52        self,
 53        exc_type: type[BaseException] | None,
 54        exc_value: BaseException | None,
 55        exc_traceback: TracebackType | None,
 56    ) -> None:
 57        self.close()
 58
 59    def __del__(self) -> None:
 60        self.close()
 61
 62    def __iter__(self) -> Iterator[list[Any] | dict[str, Any]]:
 63        return self
 64
 65    def __next__(self) -> list[Any] | dict[str, Any]:
 66        if self.has_next():
 67            return self.get_next()
 68
 69        raise StopIteration
 70
 71    def has_next(self) -> bool:
 72        """
 73        Check if there are more rows in the query result.
 74
 75        Returns
 76        -------
 77        bool
 78            True if there are more rows in the query result, False otherwise.
 79        """
 80        self.check_for_query_result_close()
 81        return self._query_result.hasNext()
 82
 83    def get_next(self) -> list[Any] | dict[str, Any]:
 84        """
 85        Get the next row in the query result.
 86
 87        Returns
 88        -------
 89        list
 90            Next row in the query result.
 91
 92        Raises
 93        ------
 94        Exception
 95            If there are no more rows.
 96        """
 97        self.check_for_query_result_close()
 98        row = self._query_result.getNext()
 99        return _row_to_dict(self.columns, row) if self.as_dict else row
100
101    def get_all(self) -> list[list[Any] | dict[str, Any]]:
102        """
103        Get all remaining rows in the query result.
104
105        Returns
106        -------
107        list
108            All remaining rows in the query result.
109        """
110        return list(self)
111
112    def get_n(self, count: int) -> list[list[Any] | dict[str, Any]]:
113        """
114        Get many rows in the query result.
115
116        Returns
117        -------
118        list
119            Up to `count` rows in the query result.
120        """
121        results = []
122        while self.has_next() and count > 0:
123            results.append(self.get_next())
124            count -= 1
125        return results
126
127    def close(self) -> None:
128        """Close the query result."""
129        if not self.is_closed:
130            # Allows the connection to be garbage collected if the query result
131            # is closed manually by the user.
132            self._query_result.close()
133            self.connection = None
134            self.is_closed = True
135
136    def check_for_query_result_close(self) -> None:
137        """
138        Check if the query result is closed and raise an exception if it is.
139
140        Raises
141        ------
142        Exception
143            If the query result is closed.
144
145        """
146        if self.is_closed:
147            msg = "Query result is closed"
148            raise RuntimeError(msg)
149
150    def get_as_df(self) -> pd.DataFrame:
151        """
152        Get the query result as a Pandas DataFrame.
153
154        See Also
155        --------
156        get_as_pl : Get the query result as a Polars DataFrame.
157        get_as_arrow : Get the query result as a PyArrow Table.
158
159        Returns
160        -------
161        pandas.DataFrame
162            Query result as a Pandas DataFrame.
163
164        """
165        self.check_for_query_result_close()
166
167        return self._query_result.getAsDF()
168
169    def get_as_pl(self) -> pl.DataFrame:
170        """
171        Get the query result as a Polars DataFrame.
172
173        See Also
174        --------
175        get_as_df : Get the query result as a Pandas DataFrame.
176        get_as_arrow : Get the query result as a PyArrow Table.
177
178        Returns
179        -------
180        polars.DataFrame
181            Query result as a Polars DataFrame.
182        """
183        import polars as pl
184
185        self.check_for_query_result_close()
186
187        # note: polars should always export just a single chunk,
188        # (eg: "-1") otherwise it will just need to rechunk anyway
189        return pl.from_arrow(  # type: ignore[return-value]
190            data=self.get_as_arrow(chunk_size=-1),
191        )
192
193    def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
194        """
195        Get the query result as a PyArrow Table.
196
197        Parameters
198        ----------
199        chunk_size : int, optional
200            Number of rows to include in each chunk, where:
201            - None: adaptive chunk size, based on the number of columns
202              in the query result.
203            - -1 or 0: the entire result is returned as a single chunk.
204            - > 0: the chunk size is the number of rows specified.
205            Defaults to None.
206
207        See Also
208        --------
209        get_as_pl : Get the query result as a Polars DataFrame.
210        get_as_df : Get the query result as a Pandas DataFrame.
211
212        Returns
213        -------
214        pyarrow.Table
215            Query result as a PyArrow Table.
216        """
217        self.check_for_query_result_close()
218
219        if chunk_size is None:
220            # Adaptive; target 10m total elements in each chunk.
221            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
222            target_n_elems = 10_000_000
223            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
224        elif chunk_size <= 0:
225            # No chunking: return the entire result as a single chunk
226            chunk_size = self.get_num_tuples()
227
228        return self._query_result.getAsArrow(chunk_size)
229
230    def get_column_data_types(self) -> list[str]:
231        """
232        Get the data types of the columns in the query result.
233
234        Returns
235        -------
236        list
237            Data types of the columns in the query result.
238
239        """
240        self.check_for_query_result_close()
241        return self._query_result.getColumnDataTypes()
242
243    def get_column_names(self) -> list[str]:
244        """
245        Get the names of the columns in the query result.
246
247        Returns
248        -------
249        list
250            Names of the columns in the query result.
251
252        """
253        self.check_for_query_result_close()
254        return self._query_result.getColumnNames()
255
256    def get_schema(self) -> dict[str, str]:
257        """
258        Get the column schema of the query result.
259
260        Returns
261        -------
262        dict
263            Schema of the query result.
264
265        """
266        self.check_for_query_result_close()
267        return dict(
268            zip(
269                self._query_result.getColumnNames(),
270                self._query_result.getColumnDataTypes(),
271            )
272        )
273
274    def reset_iterator(self) -> None:
275        """Reset the iterator of the query result."""
276        self.check_for_query_result_close()
277        self._query_result.resetIterator()
278
279    def get_as_networkx(
280        self,
281        directed: bool = True,  # noqa: FBT001
282    ) -> nx.MultiGraph | nx.MultiDiGraph:
283        """
284        Convert the nodes and rels in the query result into a NetworkX directed or undirected graph
285        with the following rules:
286        - Columns with data types other than node or rel will be ignored.
287        - Duplicated nodes and rels will be converted only once.
288
289        Parameters
290        ----------
291        directed : bool
292            Whether the graph should be directed. Defaults to True.
293
294        Returns
295        -------
296        networkx.MultiDiGraph or networkx.MultiGraph
297            Query result as a NetworkX graph.
298
299        """
300        self.check_for_query_result_close()
301        import networkx as nx
302
303        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
304        properties_to_extract = self._get_properties_to_extract()
305
306        self.reset_iterator()
307
308        nodes = {}
309        rels = {}
310        table_to_label_dict = {}
311        table_primary_key_dict = {}
312
313        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
314            node_label = node["_label"]
315            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"
316
317        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
318            return rel["_id"]["table"], rel["_id"]["offset"]
319
320        # De-duplicate nodes and rels
321        while self.has_next():
322            row = self.get_next()
323            for i in properties_to_extract:
324                # Skip empty nodes and rels, which may be returned by
325                # OPTIONAL MATCH
326                if row[i] is None or row[i] == {}:
327                    continue
328                column_type, _ = properties_to_extract[i]
329                if column_type == Type.NODE.value:
330                    nid = row[i]["_id"]
331                    nodes[nid["table"], nid["offset"]] = row[i]
332                    table_to_label_dict[nid["table"]] = row[i]["_label"]
333
334                elif column_type == Type.REL.value:
335                    rels[encode_rel_id(row[i])] = row[i]
336
337                elif column_type == Type.RECURSIVE_REL.value:
338                    for node in row[i]["_nodes"]:
339                        nid = node["_id"]
340                        nodes[nid["table"], nid["offset"]] = node
341                        table_to_label_dict[nid["table"]] = node["_label"]
342                    for rel in row[i]["_rels"]:
343                        for key in list(rel.keys()):
344                            if rel[key] is None:
345                                del rel[key]
346                        rels[encode_rel_id(rel)] = rel
347
348        # Add nodes
349        for node in nodes.values():
350            nid = node["_id"]
351            node_id = node["_label"] + "_" + str(nid["offset"])
352            if node["_label"] not in table_primary_key_dict:
353                props = self.connection._get_node_property_names(node["_label"])
354                for prop_name in props:
355                    if props[prop_name]["is_primary_key"]:
356                        table_primary_key_dict[node["_label"]] = prop_name
357                        break
358            node_id = encode_node_id(node, table_primary_key_dict)
359            node[node["_label"]] = True
360            nx_graph.add_node(node_id, **node)
361
362        # Add rels
363        for rel in rels.values():
364            src = rel["_src"]
365            dst = rel["_dst"]
366            src_node = nodes[src["table"], src["offset"]]
367            dst_node = nodes[dst["table"], dst["offset"]]
368            src_id = encode_node_id(src_node, table_primary_key_dict)
369            dst_id = encode_node_id(dst_node, table_primary_key_dict)
370            nx_graph.add_edge(src_id, dst_id, **rel)
371        return nx_graph
372
373    def _get_properties_to_extract(self) -> dict[int, tuple[str, str]]:
374        column_names = self.get_column_names()
375        column_types = self.get_column_data_types()
376        properties_to_extract = {}
377
378        # Iterate over columns and extract nodes and rels, ignoring other columns
379        for i in range(len(column_names)):
380            column_name = column_names[i]
381            column_type = column_types[i]
382            if column_type in [
383                Type.NODE.value,
384                Type.REL.value,
385                Type.RECURSIVE_REL.value,
386            ]:
387                properties_to_extract[i] = (column_type, column_name)
388        return properties_to_extract
389
390    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
391        """
392        Convert the nodes and rels in the query result into a PyTorch Geometric graph representation,
393        torch_geometric.data.Data or torch_geometric.data.HeteroData.
394
395        For node conversion, numerical and boolean properties are converted directly into tensors and
396        stored in Data/HeteroData. Properties that cannot be converted into tensors automatically
397        (please refer to the notes below for more detail) are returned as unconverted_properties.
398
399        For rel conversion, each rel is converted directly into an edge_index tensor. Edge properties
400        are returned as edge_properties.
401
402        Node properties that cannot be converted into tensor automatically:
403        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
404          automatically.
405        - If a node property contains a null value, it cannot be converted automatically.
406        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
407          converted automatically.
408        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
409          is 6 for one node but 5 for another node), it cannot be converted automatically.
410
411        Additional conversion rules:
412        - Columns with data type other than node or rel will be ignored.
413        - Duplicated nodes and rels will be converted only once.
414
415        Returns
416        -------
417        torch_geometric.data.Data or torch_geometric.data.HeteroData
418            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
419            and edge_index tensor.
420
421        dict
422            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
423            key in the database.
424
425        dict
426            A dictionary containing node properties that cannot be converted into tensors automatically. The
427            order of values for each property is aligned with nodes in Data/HeteroData.
428
429        dict
430            A dictionary containing edge properties. The order of values for each property is aligned with
431            edge_index in Data/HeteroData.
432        """
433        self.check_for_query_result_close()
434        # Although we are not using torch_geometric in this file, we need to
435        # import it here to throw an error early if the user does not have
436        # torch_geometric or torch installed.
437
438        converter = TorchGeometricResultConverter(self)
439        return converter.get_as_torch_geometric()
440
441    def get_execution_time(self) -> int:
442        """
443        Get the time in ms that was required to execute the query.
444
445        Returns
446        -------
447        double
448            Query execution time in ms.
449
450        """
451        self.check_for_query_result_close()
452        return self._query_result.getExecutionTime()
453
454    def get_compiling_time(self) -> int:
455        """
456        Get the time in ms that was required to compile the query.
457
458        Returns
459        -------
460        double
461            Query compile time in ms.
462
463        """
464        self.check_for_query_result_close()
465        return self._query_result.getCompilingTime()
466
467    def get_num_tuples(self) -> int:
468        """
469        Get the number of tuples which the query returned.
470
471        Returns
472        -------
473        int
474            Number of tuples.
475
476        """
477        self.check_for_query_result_close()
478        return self._query_result.getNumTuples()
479
480    def rows_as_dict(self, state=True) -> Self:
481        """
482        Change the format of the results, such that each row is a dict with the
483        column name as a key.
484
485        Parameters
486        ----------
487        state
488            Whether to turn dict formatting on or off. Turns it on by default.
489
490        Returns
491        -------
492        self
493            The object itself.
494
495        """
496        self.as_dict = state
497        if state:
498            self.columns = self.get_column_names()
499        return self

QueryResult stores the result of a query execution.

QueryResult(connection: Connection, query_result: '_kuzu.QueryResult')
32    def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult):  # type: ignore[name-defined]
33        """
34        Parameters
35        ----------
36        connection : _kuzu.Connection
37            The underlying C++ connection object from pybind11.
38
39        query_result : _kuzu.QueryResult
40            The underlying C++ query result object from pybind11.
41
42        """
43        self.connection = connection
44        self._query_result = query_result
45        self.is_closed = False
46        self.as_dict = False
Parameters
  • connection (_kuzu.Connection): The underlying C++ connection object from pybind11.
  • query_result (_kuzu.QueryResult): The underlying C++ query result object from pybind11.
connection
is_closed
as_dict
def has_next(self) -> bool:
71    def has_next(self) -> bool:
72        """
73        Check if there are more rows in the query result.
74
75        Returns
76        -------
77        bool
78            True if there are more rows in the query result, False otherwise.
79        """
80        self.check_for_query_result_close()
81        return self._query_result.hasNext()

Check if there are more rows in the query result.

Returns
  • bool: True if there are more rows in the query result, False otherwise.
def get_next(self) -> list[typing.Any] | dict[str, typing.Any]:
83    def get_next(self) -> list[Any] | dict[str, Any]:
84        """
85        Get the next row in the query result.
86
87        Returns
88        -------
89        list
90            Next row in the query result.
91
92        Raises
93        ------
94        Exception
95            If there are no more rows.
96        """
97        self.check_for_query_result_close()
98        row = self._query_result.getNext()
99        return _row_to_dict(self.columns, row) if self.as_dict else row

Get the next row in the query result.

Returns
  • list: Next row in the query result.
Raises
  • Exception: If there are no more rows.
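In practice, QueryResult is also an iterator and a context manager, so rows are usually consumed like this:

```python
# The context manager closes the result when the block exits.
with conn.execute("MATCH (u:User) RETURN u.name, u.age") as result:
    for row in result:
        print(row)  # e.g. ['Adam', 30]
```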
def get_all(self) -> list[list[typing.Any] | dict[str, typing.Any]]:
101    def get_all(self) -> list[list[Any] | dict[str, Any]]:
102        """
103        Get all remaining rows in the query result.
104
105        Returns
106        -------
107        list
108            All remaining rows in the query result.
109        """
110        return list(self)

Get all remaining rows in the query result.

Returns
  • list: All remaining rows in the query result.
def get_n(self, count: int) -> list[list[typing.Any] | dict[str, typing.Any]]:
112    def get_n(self, count: int) -> list[list[Any] | dict[str, Any]]:
113        """
114        Get many rows in the query result.
115
116        Returns
117        -------
118        list
119            Up to `count` rows in the query result.
120        """
121        results = []
122        while self.has_next() and count > 0:
123            results.append(self.get_next())
124            count -= 1
125        return results

Get many rows in the query result.

Returns
  • list: Up to count rows in the query result.
def close(self) -> None:
127    def close(self) -> None:
128        """Close the query result."""
129        if not self.is_closed:
130            # Allows the connection to be garbage collected if the query result
131            # is closed manually by the user.
132            self._query_result.close()
133            self.connection = None
134            self.is_closed = True

Close the query result.

def check_for_query_result_close(self) -> None:
136    def check_for_query_result_close(self) -> None:
137        """
138        Check if the query result is closed and raise an exception if it is.
139
140        Raises
141        ------
142        Exception
143            If the query result is closed.
144
145        """
146        if self.is_closed:
147            msg = "Query result is closed"
148            raise RuntimeError(msg)

Check if the query result is closed and raise an exception if it is.

Raises
  • Exception: If the query result is closed.
def get_as_df(self) -> pandas.core.frame.DataFrame:
150    def get_as_df(self) -> pd.DataFrame:
151        """
152        Get the query result as a Pandas DataFrame.
153
154        See Also
155        --------
156        get_as_pl : Get the query result as a Polars DataFrame.
157        get_as_arrow : Get the query result as a PyArrow Table.
158
159        Returns
160        -------
161        pandas.DataFrame
162            Query result as a Pandas DataFrame.
163
164        """
165        self.check_for_query_result_close()
166
167        return self._query_result.getAsDF()

Get the query result as a Pandas DataFrame.

See Also

get_as_pl: Get the query result as a Polars DataFrame.
get_as_arrow: Get the query result as a PyArrow Table.

Returns
  • pandas.DataFrame: Query result as a Pandas DataFrame.
def get_as_pl(self) -> polars.dataframe.frame.DataFrame:
169    def get_as_pl(self) -> pl.DataFrame:
170        """
171        Get the query result as a Polars DataFrame.
172
173        See Also
174        --------
175        get_as_df : Get the query result as a Pandas DataFrame.
176        get_as_arrow : Get the query result as a PyArrow Table.
177
178        Returns
179        -------
180        polars.DataFrame
181            Query result as a Polars DataFrame.
182        """
183        import polars as pl
184
185        self.check_for_query_result_close()
186
187        # note: polars should always export just a single chunk,
188        # (eg: "-1") otherwise it will just need to rechunk anyway
189        return pl.from_arrow(  # type: ignore[return-value]
190            data=self.get_as_arrow(chunk_size=-1),
191        )

Get the query result as a Polars DataFrame.

See Also

get_as_df: Get the query result as a Pandas DataFrame.
get_as_arrow: Get the query result as a PyArrow Table.

Returns
  • polars.DataFrame: Query result as a Polars DataFrame.
def get_as_arrow(self, chunk_size: int | None = None) -> pyarrow.lib.Table:
193    def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
194        """
195        Get the query result as a PyArrow Table.
196
197        Parameters
198        ----------
199        chunk_size : int, optional
200            Number of rows to include in each chunk, where:
201            - None: adaptive chunk size, based on the number of columns
202              in the query result.
203            - -1 or 0: the entire result is returned as a single chunk.
204            - > 0: the chunk size is the number of rows specified.
205            Defaults to None.
206
207        See Also
208        --------
209        get_as_pl : Get the query result as a Polars DataFrame.
210        get_as_df : Get the query result as a Pandas DataFrame.
211
212        Returns
213        -------
214        pyarrow.Table
215            Query result as a PyArrow Table.
216        """
217        self.check_for_query_result_close()
218
219        if chunk_size is None:
220            # Adaptive; target 10m total elements in each chunk.
221            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
222            target_n_elems = 10_000_000
223            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
224        elif chunk_size <= 0:
225            # No chunking: return the entire result as a single chunk
226            chunk_size = self.get_num_tuples()
227
228        return self._query_result.getAsArrow(chunk_size)

Get the query result as a PyArrow Table.

Parameters
  • chunk_size (int, optional): Number of rows to include in each chunk. If None, the chunk size is adaptive and depends on the number of columns in the query result; if -1 or 0, the entire result is returned as a single chunk; if > 0, the chunk size is the number of rows specified. Defaults to None.
See Also

get_as_pl: Get the query result as a Polars DataFrame.
get_as_df: Get the query result as a Pandas DataFrame.

Returns
  • pyarrow.Table: Query result as a PyArrow Table.
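For example (a sketch; pandas, polars, or pyarrow must be installed for the respective call):

```python
df = conn.execute("MATCH (u:User) RETURN u.name, u.age").get_as_df()     # pandas.DataFrame
pf = conn.execute("MATCH (u:User) RETURN u.name, u.age").get_as_pl()     # polars.DataFrame
tb = conn.execute("MATCH (u:User) RETURN u.name, u.age").get_as_arrow(0) # single-chunk pyarrow.Table
```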
def get_column_data_types(self) -> list[str]:
230    def get_column_data_types(self) -> list[str]:
231        """
232        Get the data types of the columns in the query result.
233
234        Returns
235        -------
236        list
237            Data types of the columns in the query result.
238
239        """
240        self.check_for_query_result_close()
241        return self._query_result.getColumnDataTypes()

Get the data types of the columns in the query result.

Returns
  • list: Data types of the columns in the query result.
def get_column_names(self) -> list[str]:
243    def get_column_names(self) -> list[str]:
244        """
245        Get the names of the columns in the query result.
246
247        Returns
248        -------
249        list
250            Names of the columns in the query result.
251
252        """
253        self.check_for_query_result_close()
254        return self._query_result.getColumnNames()

Get the names of the columns in the query result.

Returns
  • list: Names of the columns in the query result.
def get_schema(self) -> dict[str, str]:
256    def get_schema(self) -> dict[str, str]:
257        """
258        Get the column schema of the query result.
259
260        Returns
261        -------
262        dict
263            Schema of the query result.
264
265        """
266        self.check_for_query_result_close()
267        return dict(
268            zip(
269                self._query_result.getColumnNames(),
270                self._query_result.getColumnDataTypes(),
271            )
272        )

Get the column schema of the query result.

Returns
  • dict: Schema of the query result.
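For example:

```python
result = conn.execute("MATCH (u:User) RETURN u.name, u.age")
print(result.get_schema())  # e.g. {'u.name': 'STRING', 'u.age': 'INT64'}
```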
def reset_iterator(self) -> None:
274    def reset_iterator(self) -> None:
275        """Reset the iterator of the query result."""
276        self.check_for_query_result_close()
277        self._query_result.resetIterator()

Reset the iterator of the query result.

def get_as_networkx(self, directed: bool = True) -> networkx.classes.multigraph.MultiGraph | networkx.classes.multidigraph.MultiDiGraph:
279    def get_as_networkx(
280        self,
281        directed: bool = True,  # noqa: FBT001
282    ) -> nx.MultiGraph | nx.MultiDiGraph:
283        """
284        Convert the nodes and rels in the query result into a NetworkX directed or undirected graph
285        with the following rules:
286        - Columns with data types other than node or rel will be ignored.
287        - Duplicated nodes and rels will be converted only once.
288
289        Parameters
290        ----------
291        directed : bool
292            Whether the graph should be directed. Defaults to True.
293
294        Returns
295        -------
296        networkx.MultiDiGraph or networkx.MultiGraph
297            Query result as a NetworkX graph.
298
299        """
300        self.check_for_query_result_close()
301        import networkx as nx
302
303        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
304        properties_to_extract = self._get_properties_to_extract()
305
306        self.reset_iterator()
307
308        nodes = {}
309        rels = {}
310        table_to_label_dict = {}
311        table_primary_key_dict = {}
312
313        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
314            node_label = node["_label"]
315            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"
316
317        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
318            return rel["_id"]["table"], rel["_id"]["offset"]
319
320        # De-duplicate nodes and rels
321        while self.has_next():
322            row = self.get_next()
323            for i in properties_to_extract:
324                # Skip empty nodes and rels, which may be returned by
325                # OPTIONAL MATCH
326                if row[i] is None or row[i] == {}:
327                    continue
328                column_type, _ = properties_to_extract[i]
329                if column_type == Type.NODE.value:
330                    nid = row[i]["_id"]
331                    nodes[nid["table"], nid["offset"]] = row[i]
332                    table_to_label_dict[nid["table"]] = row[i]["_label"]
333
334                elif column_type == Type.REL.value:
335                    rels[encode_rel_id(row[i])] = row[i]
336
337                elif column_type == Type.RECURSIVE_REL.value:
338                    for node in row[i]["_nodes"]:
339                        nid = node["_id"]
340                        nodes[nid["table"], nid["offset"]] = node
341                        table_to_label_dict[nid["table"]] = node["_label"]
342                    for rel in row[i]["_rels"]:
343                        for key in list(rel.keys()):
344                            if rel[key] is None:
345                                del rel[key]
346                        rels[encode_rel_id(rel)] = rel
347
348        # Add nodes
349        for node in nodes.values():
350            nid = node["_id"]
351            node_id = node["_label"] + "_" + str(nid["offset"])
352            if node["_label"] not in table_primary_key_dict:
353                props = self.connection._get_node_property_names(node["_label"])
354                for prop_name in props:
355                    if props[prop_name]["is_primary_key"]:
356                        table_primary_key_dict[node["_label"]] = prop_name
357                        break
358            node_id = encode_node_id(node, table_primary_key_dict)
359            node[node["_label"]] = True
360            nx_graph.add_node(node_id, **node)
361
362        # Add rels
363        for rel in rels.values():
364            src = rel["_src"]
365            dst = rel["_dst"]
366            src_node = nodes[src["table"], src["offset"]]
367            dst_node = nodes[dst["table"], dst["offset"]]
368            src_id = encode_node_id(src_node, table_primary_key_dict)
369            dst_id = encode_node_id(dst_node, table_primary_key_dict)
370            nx_graph.add_edge(src_id, dst_id, **rel)
371        return nx_graph

Convert the nodes and rels in the query result into a NetworkX directed or undirected graph with the following rules: columns with data types other than node or rel are ignored, and duplicated nodes and rels are converted only once.

Parameters
  • directed (bool): Whether the graph should be directed. Defaults to True.
Returns
  • networkx.MultiDiGraph or networkx.MultiGraph: Query result as a NetworkX graph.
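A sketch using the example schema (networkx must be installed):

```python
result = conn.execute("MATCH (u:User)-[f:Follows]->(v:User) RETURN u, f, v")
g = result.get_as_networkx(directed=True)  # networkx.MultiDiGraph
print(g.number_of_nodes(), g.number_of_edges())
```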
def get_as_torch_geometric(self) -> tuple[torch_geometric.data.data.Data | torch_geometric.data.hetero_data.HeteroData, dict, dict, dict]:
390    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
391        """
392        Convert the nodes and rels in the query result into a PyTorch Geometric graph representation,
393        torch_geometric.data.Data or torch_geometric.data.HeteroData.
394
395        For node conversion, numerical and boolean properties are converted directly into tensors and
396        stored in Data/HeteroData. Properties that cannot be converted into tensors automatically
397        (please refer to the notes below for more detail) are returned as unconverted_properties.
398
399        For rel conversion, each rel is converted directly into an edge_index tensor. Edge properties
400        are returned as edge_properties.
401
402        Node properties that cannot be converted into tensor automatically:
403        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
404          automatically.
405        - If a node property contains a null value, it cannot be converted automatically.
406        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
407          converted automatically.
408        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
409          is 6 for one node but 5 for another node), it cannot be converted automatically.
410
411        Additional conversion rules:
412        - Columns with data type other than node or rel will be ignored.
413        - Duplicated nodes and rels will be converted only once.
414
415        Returns
416        -------
417        torch_geometric.data.Data or torch_geometric.data.HeteroData
418            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
419            and edge_index tensor.
420
421        dict
422            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
423            key in the database.
424
425        dict
426            A dictionary containing node properties that cannot be converted into tensors automatically. The
427            order of values for each property is aligned with nodes in Data/HeteroData.
428
429        dict
430            A dictionary containing edge properties. The order of values for each property is aligned with
431            edge_index in Data/HeteroData.
432        """
433        self.check_for_query_result_close()
434        # Although we are not using torch_geometric in this file, we need to
435        # import it here to throw an error early if the user does not have
436        # torch_geometric or torch installed.
437
438        converter = TorchGeometricResultConverter(self)
439        return converter.get_as_torch_geometric()

Convert the nodes and rels in the query result into a PyTorch Geometric graph representation, torch_geometric.data.Data or torch_geometric.data.HeteroData.

For node conversion, numerical and boolean properties are converted directly into tensors and stored in Data/HeteroData. Properties that cannot be converted into tensors automatically (please refer to the notes below for more detail) are returned as unconverted_properties.

For rel conversion, each rel is converted directly into an edge_index tensor. Edge properties are returned as edge_properties.

Node properties that cannot be converted into tensor automatically:

  • If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted automatically.
  • If a node property contains a null value, it cannot be converted automatically.
  • If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be converted automatically.
  • If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length is 6 for one node but 5 for another node), it cannot be converted automatically.

Additional conversion rules:

  • Columns with data type other than node or rel will be ignored.
  • Duplicated nodes and rels will be converted only once.
Returns
  • torch_geometric.data.Data or torch_geometric.data.HeteroData: Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties and edge_index tensor.
  • dict: A dictionary that maps the positional offset of each node in Data/HeteroData to its primary key in the database.
  • dict: A dictionary containing node properties that cannot be converted into tensors automatically. The order of values for each property is aligned with nodes in Data/HeteroData.
  • dict: A dictionary containing edge properties. The order of values for each property is aligned with edge_index in Data/HeteroData.
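A sketch of unpacking the four return values (torch and torch_geometric must be installed):

```python
result = conn.execute("MATCH (u:User)-[f:Follows]->(v:User) RETURN u, f, v")
data, pos_to_pk, unconverted, edge_props = result.get_as_torch_geometric()
print(data)  # torch_geometric Data/HeteroData with edge_index populated
```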
def get_execution_time(self) -> int:
441    def get_execution_time(self) -> int:
442        """
443        Get the time in ms that was required to execute the query.
444
445        Returns
446        -------
447        double
448            Query execution time in ms.
449
450        """
451        self.check_for_query_result_close()
452        return self._query_result.getExecutionTime()

Get the time in ms that was required to execute the query.

Returns
  • double: Query execution time in ms.
def get_compiling_time(self) -> int:
454    def get_compiling_time(self) -> int:
455        """
456        Get the time in ms that was required to compile the query.
457
458        Returns
459        -------
460        double
461            Query compile time in ms.
462
463        """
464        self.check_for_query_result_close()
465        return self._query_result.getCompilingTime()

Get the time in ms that was required to compile the query.

Returns
  • double: Query compile time in ms.
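For example:

```python
result = conn.execute("MATCH (u:User) RETURN u.name")
print(f"compile: {result.get_compiling_time()} ms, execute: {result.get_execution_time()} ms")
```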
def get_num_tuples(self) -> int:
467    def get_num_tuples(self) -> int:
468        """
469        Get the number of tuples which the query returned.
470
471        Returns
472        -------
473        int
474            Number of tuples.
475
476        """
477        self.check_for_query_result_close()
478        return self._query_result.getNumTuples()

Get the number of tuples which the query returned.

Returns
  • int: Number of tuples.
def rows_as_dict(self, state=True) -> typing_extensions.Self:
480    def rows_as_dict(self, state=True) -> Self:
481        """
482        Change the format of the results, such that each row is a dict with the
483        column name as a key.
484
485        Parameters
486        ----------
487        state
488            Whether to turn dict formatting on or off. Turns it on by default.
489
490        Returns
491        -------
492        self
493            The object itself.
494
495        """
496        self.as_dict = state
497        if state:
498            self.columns = self.get_column_names()
499        return self

Change the format of the results, such that each row is a dict with the column name as a key.

Parameters
  • state: Whether to turn dict formatting on or off. Turns it on by default.
Returns
  • self: The object itself.
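For example:

```python
result = conn.execute("MATCH (u:User) RETURN u.name, u.age").rows_as_dict()
for row in result:
    print(row)  # e.g. {'u.name': 'Adam', 'u.age': 30}
```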
class Type(enum.Enum):
 5class Type(Enum):
 6    """The type of a value in the database."""
 7
 8    ANY = "ANY"
 9    NODE = "NODE"
10    REL = "REL"
11    RECURSIVE_REL = "RECURSIVE_REL"
12    SERIAL = "SERIAL"
13    BOOL = "BOOL"
14    INT64 = "INT64"
15    INT32 = "INT32"
16    INT16 = "INT16"
17    INT8 = "INT8"
18    UINT64 = "UINT64"
19    UINT32 = "UINT32"
20    UINT16 = "UINT16"
21    UINT8 = "UINT8"
22    INT128 = "INT128"
23    DOUBLE = "DOUBLE"
24    FLOAT = "FLOAT"
25    DATE = "DATE"
26    TIMESTAMP = "TIMESTAMP"
27    TIMSTAMP_TZ = "TIMESTAMP_TZ"
28    TIMESTAMP_NS = "TIMESTAMP_NS"
29    TIMESTAMP_MS = "TIMESTAMP_MS"
30    TIMESTAMP_SEC = "TIMESTAMP_SEC"
31    INTERVAL = "INTERVAL"
32    INTERNAL_ID = "INTERNAL_ID"
33    STRING = "STRING"
34    BLOB = "BLOB"
35    UUID = "UUID"
36    LIST = "LIST"
37    ARRAY = "ARRAY"
38    STRUCT = "STRUCT"
39    MAP = "MAP"
40    UNION = "UNION"

The type of a value in the database.

ANY = <Type.ANY: 'ANY'>
NODE = <Type.NODE: 'NODE'>
REL = <Type.REL: 'REL'>
RECURSIVE_REL = <Type.RECURSIVE_REL: 'RECURSIVE_REL'>
SERIAL = <Type.SERIAL: 'SERIAL'>
BOOL = <Type.BOOL: 'BOOL'>
INT64 = <Type.INT64: 'INT64'>
INT32 = <Type.INT32: 'INT32'>
INT16 = <Type.INT16: 'INT16'>
INT8 = <Type.INT8: 'INT8'>
UINT64 = <Type.UINT64: 'UINT64'>
UINT32 = <Type.UINT32: 'UINT32'>
UINT16 = <Type.UINT16: 'UINT16'>
UINT8 = <Type.UINT8: 'UINT8'>
INT128 = <Type.INT128: 'INT128'>
DOUBLE = <Type.DOUBLE: 'DOUBLE'>
FLOAT = <Type.FLOAT: 'FLOAT'>
DATE = <Type.DATE: 'DATE'>
TIMESTAMP = <Type.TIMESTAMP: 'TIMESTAMP'>
TIMSTAMP_TZ = <Type.TIMSTAMP_TZ: 'TIMESTAMP_TZ'>
TIMESTAMP_NS = <Type.TIMESTAMP_NS: 'TIMESTAMP_NS'>
TIMESTAMP_MS = <Type.TIMESTAMP_MS: 'TIMESTAMP_MS'>
TIMESTAMP_SEC = <Type.TIMESTAMP_SEC: 'TIMESTAMP_SEC'>
INTERVAL = <Type.INTERVAL: 'INTERVAL'>
INTERNAL_ID = <Type.INTERNAL_ID: 'INTERNAL_ID'>
STRING = <Type.STRING: 'STRING'>
BLOB = <Type.BLOB: 'BLOB'>
UUID = <Type.UUID: 'UUID'>
LIST = <Type.LIST: 'LIST'>
ARRAY = <Type.ARRAY: 'ARRAY'>
STRUCT = <Type.STRUCT: 'STRUCT'>
MAP = <Type.MAP: 'MAP'>
UNION = <Type.UNION: 'UNION'>
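The enum values are the same type-name strings returned by QueryResult.get_column_data_types(), so they can be used to filter result columns; for example:

```python
result = conn.execute("MATCH (u:User) RETURN u, u.name")
for name, dtype in result.get_schema().items():
    if dtype == kuzu.Type.NODE.value:
        print(f"{name} is a NODE column")
```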
__version__
storage_version
version