kuzu
Kuzu Python API bindings.
This package provides a Python API for the Kuzu graph database management system.
To install the package, run:

```
python3 -m pip install kuzu
```
Example usage:

```python
import kuzu

db = kuzu.Database("./test")
conn = kuzu.Connection(db)

# Define the schema
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")

# Load some data
conn.execute('COPY User FROM "user.csv"')
conn.execute('COPY City FROM "city.csv"')
conn.execute('COPY Follows FROM "follows.csv"')
conn.execute('COPY LivesIn FROM "lives-in.csv"')

# Query the data
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
while results.has_next():
    print(results.get_next())
```
The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv).
1""" 2# Kuzu Python API bindings. 3 4This package provides a Python API for Kuzu graph database management system. 5 6To install the package, run: 7``` 8python3 -m pip install kuzu 9``` 10 11Example usage: 12```python 13import kuzu 14 15db = kuzu.Database("./test") 16conn = kuzu.Connection(db) 17 18# Define the schema 19conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") 20conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))") 21conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") 22conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)") 23 24# Load some data 25conn.execute('COPY User FROM "user.csv"') 26conn.execute('COPY City FROM "city.csv"') 27conn.execute('COPY Follows FROM "follows.csv"') 28conn.execute('COPY LivesIn FROM "lives-in.csv"') 29 30# Query the data 31results = conn.execute("MATCH (u:User) RETURN u.name, u.age;") 32while results.has_next(): 33 print(results.get_next()) 34``` 35 36The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv). 37 38""" 39 40from __future__ import annotations 41 42import os 43import sys 44 45# Set RTLD_GLOBAL and RTLD_LAZY flags on Linux to fix the issue with loading 46# extensions 47if sys.platform == "linux": 48 original_dlopen_flags = sys.getdlopenflags() 49 sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) 50 51from .async_connection import AsyncConnection 52from .connection import Connection 53from .database import Database 54from .prepared_statement import PreparedStatement 55from .query_result import QueryResult 56from .types import Type 57 58 59def __getattr__(name: str) -> str | int: 60 if name in ("version", "__version__"): 61 return Database.get_version() 62 elif name == "storage_version": 63 return Database.get_storage_version() 64 else: 65 msg = f"module {__name__!r} has no attribute {name!r}" 66 raise AttributeError(msg) 67 68 69# Restore the original dlopen flags 70if sys.platform == "linux": 71 sys.setdlopenflags(original_dlopen_flags) 72 73__all__ = [ 74 "AsyncConnection", 75 "Connection", 76 "Database", 77 "PreparedStatement", 78 "QueryResult", 79 "Type", 80 "__version__", # noqa: F822 81 "storage_version", # noqa: F822 82 "version", # noqa: F822 83]
Source of `AsyncConnection`:

```python
class AsyncConnection:
    """AsyncConnection enables asynchronous execution of queries with a pool of connections and threads."""

    def __init__(
        self,
        database: Database,
        max_concurrent_queries: int = 4,
        max_threads_per_query: int = 0,
    ) -> None:
        """
        Initialise the async connection.

        Parameters
        ----------
        database : Database
            Database to connect to.

        max_concurrent_queries : int
            Maximum number of concurrent queries to execute. This corresponds to the
            number of connections and thread pool size. Default is 4.

        max_threads_per_query : int
            Controls the maximum number of threads per connection that can be used
            to execute one query. Default is 0, which means no limit.
        """
        self.database = database
        self.connections = [Connection(database) for _ in range(max_concurrent_queries)]
        self.connections_counter = [0 for _ in range(max_concurrent_queries)]
        self.lock = threading.Lock()

        for conn in self.connections:
            conn.init_connection()
            conn.set_max_threads_for_exec(max_threads_per_query)

        self.executor = ThreadPoolExecutor(max_workers=max_concurrent_queries)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        self.close()

    def __get_connection_with_least_queries(self) -> tuple[Connection, int]:
        with self.lock:
            conn_index = self.connections_counter.index(min(self.connections_counter))
            self.connections_counter[conn_index] += 1
        return self.connections[conn_index], conn_index

    def __decrement_connection_counter(self, conn_index: int) -> None:
        """Decrement the query counter for a connection."""
        with self.lock:
            self.connections_counter[conn_index] -= 1
            if self.connections_counter[conn_index] < 0:
                self.connections_counter[conn_index] = 0

    def acquire_connection(self) -> Connection:
        """
        Acquire a connection from the connection pool for temporary synchronous
        calls. If the connection pool is oversubscribed, the method will return
        the connection with the least number of queued queries. It is required
        to release the connection by calling `release_connection` after the
        connection is no longer needed.

        Returns
        -------
        Connection
            A connection object.
        """
        conn, _ = self.__get_connection_with_least_queries()
        return conn

    def release_connection(self, conn: Connection) -> None:
        """
        Release a connection acquired by `acquire_connection` back to the
        connection pool. Calling this method is required when the connection is
        no longer needed.

        Parameters
        ----------
        conn : Connection
            Connection object to release.
        """
        for i, existing_conn in enumerate(self.connections):
            if existing_conn == conn:
                self.__decrement_connection_counter(i)
                break

    def set_query_timeout(self, timeout_in_ms: int) -> None:
        """
        Set the query timeout value in ms for executing queries.

        Parameters
        ----------
        timeout_in_ms : int
            Query timeout value in ms for executing queries.
        """
        for conn in self.connections:
            conn.set_query_timeout(timeout_in_ms)

    async def execute(
        self, query: str | PreparedStatement, parameters: dict[str, Any] | None = None
    ) -> QueryResult | list[QueryResult]:
        """
        Execute a query asynchronously.

        Parameters
        ----------
        query : str | PreparedStatement
            A prepared statement or a query string.
            If a query string is given, a prepared statement will be created
            automatically.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        QueryResult
            Query result.
        """
        loop = asyncio.get_running_loop()
        # If the query is a prepared statement, use the connection associated with it
        if isinstance(query, PreparedStatement):
            conn = query._connection
            for i, existing_conn in enumerate(self.connections):
                if existing_conn == conn:
                    conn_index = i
                    with self.lock:
                        self.connections_counter[conn_index] += 1
                    break
        else:
            conn, conn_index = self.__get_connection_with_least_queries()

        try:
            return await loop.run_in_executor(self.executor, conn.execute, query, parameters)
        except asyncio.CancelledError:
            conn.interrupt()
        finally:
            self.__decrement_connection_counter(conn_index)

    async def _prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
        """
        The only parameters supported during prepare are dataframes.
        Any remaining parameters will be ignored and should be passed to execute().
        """  # noqa: D401
        loop = asyncio.get_running_loop()
        conn, conn_index = self.__get_connection_with_least_queries()

        try:
            prepared_statement = await loop.run_in_executor(self.executor, conn.prepare, query, parameters)
            return prepared_statement
        finally:
            self.__decrement_connection_counter(conn_index)

    async def prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement:
        """
        Create a prepared statement for a query asynchronously.

        Parameters
        ----------
        query : str
            Query to prepare.
        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        PreparedStatement
            Prepared statement.
        """
        warnings.warn(
            "The use of separate prepare + execute of queries is deprecated. "
            "Please use a single call to the execute() API instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return await self._prepare(query, parameters)

    def close(self) -> None:
        """
        Close all connections and shut down the thread pool.

        Note: Call to this method is optional. The connections and thread pool
        will be closed automatically when the instance is garbage collected.
        """
        for conn in self.connections:
            conn.close()

        self.executor.shutdown(wait=True)
```
AsyncConnection enables asynchronous execution of queries with a pool of connections and threads.
```python
def __init__(
    self,
    database: Database,
    max_concurrent_queries: int = 4,
    max_threads_per_query: int = 0,
) -> None
```
Initialise the async connection.
Parameters
- database (Database): Database to connect to.
- max_concurrent_queries (int): Maximum number of concurrent queries to execute. This corresponds to the number of connections and thread pool size. Default is 4.
- max_threads_per_query (int): Controls the maximum number of threads per connection that can be used to execute one query. Default is 0, which means no limit.
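For example, a pool sized for four concurrent queries with at most two worker threads each (the database path is illustrative):

```python
import kuzu

db = kuzu.Database("./demo_db")  # illustrative path
aconn = kuzu.AsyncConnection(db, max_concurrent_queries=4, max_threads_per_query=2)
```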
```python
def acquire_connection(self) -> Connection
```
Acquire a connection from the connection pool for temporary synchronous calls. If the connection pool is oversubscribed, the method returns the connection with the least number of queued queries. You must release the connection by calling `release_connection` once it is no longer needed.
Returns
- Connection: A connection object.
```python
def release_connection(self, conn: Connection) -> None
```
Release a connection acquired by `acquire_connection` back to the connection pool. Calling this method is required when the connection is no longer needed.
Parameters
- conn (Connection): Connection object to release.
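A minimal sketch of the acquire/release pattern, reusing the `aconn` pool and the `User` table from earlier examples (the try/finally is our addition so the connection is always released):

```python
conn = aconn.acquire_connection()
try:
    results = conn.execute("MATCH (u:User) RETURN count(u);")
    print(results.get_next())
finally:
    aconn.release_connection(conn)
```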
```python
def set_query_timeout(self, timeout_in_ms: int) -> None
```
Set the query timeout value in ms for executing queries.
Parameters
- timeout_in_ms (int): Query timeout value in milliseconds.
```python
async def execute(
    self, query: str | PreparedStatement, parameters: dict[str, Any] | None = None
) -> QueryResult | list[QueryResult]
```
Execute a query asynchronously.
Parameters
- query (str | PreparedStatement): A prepared statement or a query string. If a query string is given, a prepared statement will be created automatically.
- parameters (dict[str, Any]): Parameters for the query.
Returns
- QueryResult | list[QueryResult]: Query result, or a list of results if the query string contains multiple statements.
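For instance, awaiting a parameterized query on the pool (assumes the `aconn` pool and `User` table from the earlier sketches; `$min_age` is bound from the parameters dict):

```python
import asyncio

async def main() -> None:
    results = await aconn.execute(
        "MATCH (u:User) WHERE u.age > $min_age RETURN u.name;",
        {"min_age": 18},
    )
    while results.has_next():
        print(results.get_next())

asyncio.run(main())
```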
```python
async def prepare(self, query: str, parameters: dict[str, Any] | None = None) -> PreparedStatement
```
Create a prepared statement for a query asynchronously. Deprecated: use a single call to `execute()` instead.
Parameters
- query (str): Query to prepare.
- parameters (dict[str, Any]): Parameters for the query.
Returns
- PreparedStatement: Prepared statement.
```python
def close(self) -> None
```
Close all connections and shut down the thread pool.
Note: Call to this method is optional. The connections and thread pool will be closed automatically when the instance is garbage collected.
Source of `Connection`:

```python
class Connection:
    """Connection to a database."""

    def __init__(self, database: Database, num_threads: int = 0):
        """
        Initialise a Kuzu database connection.

        Parameters
        ----------
        database : Database
            Database to connect to.

        num_threads : int
            Maximum number of threads to use for executing queries.
        """
        self._connection: Any = None  # (type: _kuzu.Connection from pybind11)
        self.database = database
        self.num_threads = num_threads
        self.is_closed = False
        self.init_connection()

    def __getstate__(self) -> dict[str, Any]:
        state = {
            "database": self.database,
            "num_threads": self.num_threads,
            "_connection": None,
        }
        return state

    def init_connection(self) -> None:
        """Establish a connection to the database, if not already initialised."""
        if self.is_closed:
            error_msg = "Connection is closed."
            raise RuntimeError(error_msg)
        self.database.init_database()
        if self._connection is None:
            self._connection = _kuzu.Connection(self.database._database, self.num_threads)  # type: ignore[union-attr]

    def set_max_threads_for_exec(self, num_threads: int) -> None:
        """
        Set the maximum number of threads for executing queries.

        Parameters
        ----------
        num_threads : int
            Maximum number of threads to use for executing queries.
        """
        self.init_connection()
        self._connection.set_max_threads_for_exec(num_threads)

    def close(self) -> None:
        """
        Close the connection.

        Note: Call to this method is optional. The connection will be closed
        automatically when the object goes out of scope.
        """
        if self._connection is not None:
            self._connection.close()
            self._connection = None
        self.is_closed = True

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def execute(
        self,
        query: str | PreparedStatement,
        parameters: dict[str, Any] | None = None,
    ) -> QueryResult | list[QueryResult]:
        """
        Execute a query.

        Parameters
        ----------
        query : str | PreparedStatement
            A prepared statement or a query string.
            If a query string is given, a prepared statement will be created
            automatically.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        QueryResult
            Query result.
        """
        if parameters is None:
            parameters = {}

        self.init_connection()
        if not isinstance(parameters, dict):
            msg = f"Parameters must be a dict; found {type(parameters)}."
            raise RuntimeError(msg)  # noqa: TRY004

        if len(parameters) == 0 and isinstance(query, str):
            query_result_internal = self._connection.query(query)
        else:
            prepared_statement = self._prepare(query, parameters) if isinstance(query, str) else query
            query_result_internal = self._connection.execute(prepared_statement._prepared_statement, parameters)
        if not query_result_internal.isSuccess():
            raise RuntimeError(query_result_internal.getErrorMessage())
        current_query_result = QueryResult(self, query_result_internal)
        if not query_result_internal.hasNextQueryResult():
            return current_query_result
        all_query_results = [current_query_result]
        while query_result_internal.hasNextQueryResult():
            query_result_internal = query_result_internal.getNextQueryResult()
            if not query_result_internal.isSuccess():
                raise RuntimeError(query_result_internal.getErrorMessage())
            all_query_results.append(QueryResult(self, query_result_internal))
        return all_query_results

    def _prepare(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
    ) -> PreparedStatement:
        """
        The only parameters supported during prepare are dataframes.
        Any remaining parameters will be ignored and should be passed to execute().
        """  # noqa: D401
        return PreparedStatement(self, query, parameters)

    def prepare(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
    ) -> PreparedStatement:
        """
        Create a prepared statement for a query.

        Parameters
        ----------
        query : str
            Query to prepare.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        PreparedStatement
            Prepared statement.
        """
        warnings.warn(
            "The use of separate prepare + execute of queries is deprecated. "
            "Please use a single call to the execute() API instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._prepare(query, parameters)

    def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
        LIST_START_SYMBOL = "["
        LIST_END_SYMBOL = "]"
        self.init_connection()
        query_result = self.execute(f"CALL table_info('{table_name}') RETURN *;")
        results = {}
        while query_result.has_next():
            row = query_result.get_next()
            prop_name = row[1]
            prop_type = row[2]
            is_primary_key = row[4] is True
            dimension = prop_type.count(LIST_START_SYMBOL)
            splitted = prop_type.split(LIST_START_SYMBOL)
            shape = []
            for s in splitted:
                if LIST_END_SYMBOL not in s:
                    continue
                s = s.split(LIST_END_SYMBOL)[0]
                if s != "":
                    shape.append(int(s))
            prop_type = splitted[0]
            results[prop_name] = {
                "type": prop_type,
                "dimension": dimension,
                "is_primary_key": is_primary_key,
            }
            if len(shape) > 0:
                results[prop_name]["shape"] = tuple(shape)
        return results

    def _get_node_table_names(self) -> list[Any]:
        results = []
        self.init_connection()
        query_result = self.execute("CALL show_tables() RETURN *;")
        while query_result.has_next():
            row = query_result.get_next()
            if row[2] == "NODE":
                results.append(row[1])
        return results

    def _get_rel_table_names(self) -> list[dict[str, Any]]:
        results = []
        self.init_connection()
        tables_result = self.execute("CALL show_tables() RETURN *;")
        while tables_result.has_next():
            row = tables_result.get_next()
            if row[2] == "REL":
                name = row[1]
                connections_result = self.execute(f"CALL show_connection({name!r}) RETURN *;")
                src_dst_row = connections_result.get_next()
                src_node = src_dst_row[0]
                dst_node = src_dst_row[1]
                results.append({"name": name, "src": src_node, "dst": dst_node})
        return results

    def set_query_timeout(self, timeout_in_ms: int) -> None:
        """
        Set the query timeout value in ms for executing queries.

        Parameters
        ----------
        timeout_in_ms : int
            Query timeout value in ms for executing queries.
        """
        self.init_connection()
        self._connection.set_query_timeout(timeout_in_ms)

    def interrupt(self) -> None:
        """
        Interrupt execution of the current query.

        If there is no currently executing query, this function does nothing.
        """
        self._connection.interrupt()

    def create_function(
        self,
        name: str,
        udf: Callable[..., Any],
        params_type: list[Type | str] | None = None,
        return_type: Type | str = "",
        *,
        default_null_handling: bool = True,
        catch_exceptions: bool = False,
    ) -> None:
        """
        Register a User Defined Function (UDF) for use in Cypher queries.

        Parameters
        ----------
        name : str
            Name of the function.

        udf : Callable[..., Any]
            Function to be executed.

        params_type : Optional[list[Type]]
            List of Type enums describing the input parameters.

        return_type : Optional[Type]
            A Type enum describing the returned value.

        default_null_handling : Optional[bool]
            If true, when any parameter is null, the resulting value will be null.

        catch_exceptions : Optional[bool]
            If true, when an exception is thrown from Python, the function output
            will be null. Otherwise, the exception will be rethrown.
        """
        if params_type is None:
            params_type = []
        parsed_params_type = [x if type(x) is str else x.value for x in params_type]
        if type(return_type) is not str:
            return_type = return_type.value

        self._connection.create_function(
            name=name,
            udf=udf,
            params_type=parsed_params_type,
            return_value=return_type,
            default_null=default_null_handling,
            catch_exceptions=catch_exceptions,
        )

    def remove_function(self, name: str) -> None:
        """
        Remove a User Defined Function (UDF).

        Parameters
        ----------
        name : str
            Name of the function to be removed.
        """
        self._connection.remove_function(name)
```
Connection to a database.
```python
def __init__(self, database: Database, num_threads: int = 0)
```
Initialise a Kuzu database connection.
Parameters
- database (Database): Database to connect to.
- num_threads (int): Maximum number of threads to use for executing queries.
```python
def init_connection(self) -> None
```
Establish a connection to the database, if not already initialised.
```python
def set_max_threads_for_exec(self, num_threads: int) -> None
```
Set the maximum number of threads for executing queries.
Parameters
- num_threads (int): Maximum number of threads to use for executing queries.
```python
def close(self) -> None
```
Close the connection.
Note: Call to this method is optional. The connection will be closed automatically when the object goes out of scope.
```python
def execute(
    self,
    query: str | PreparedStatement,
    parameters: dict[str, Any] | None = None,
) -> QueryResult | list[QueryResult]
```
Execute a query.
Parameters
- query (str | PreparedStatement): A prepared statement or a query string. If a query string is given, a prepared statement will be created automatically.
- parameters (dict[str, Any]): Parameters for the query.
Returns
- QueryResult | list[QueryResult]: Query result, or a list of results if the query string contains multiple statements.
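A short sketch of parameterized execution (table and values are illustrative; the `$min_age` placeholder is bound from the parameters dict):

```python
results = conn.execute(
    "MATCH (u:User) WHERE u.age >= $min_age RETURN u.name, u.age;",
    {"min_age": 21},
)
while results.has_next():
    print(results.get_next())
```

Note that when the query string contains multiple statements, a list of `QueryResult` objects is returned instead of a single result.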
```python
def prepare(
    self,
    query: str,
    parameters: dict[str, Any] | None = None,
) -> PreparedStatement
```
Create a prepared statement for a query. Deprecated: use a single call to `execute()` instead.
Parameters
- query (str): Query to prepare.
- parameters (dict[str, Any]): Parameters for the query.
Returns
- PreparedStatement: Prepared statement.
```python
def set_query_timeout(self, timeout_in_ms: int) -> None
```
Set the query timeout value in ms for executing queries.
Parameters
- timeout_in_ms (int): Query timeout value in milliseconds.
```python
def interrupt(self) -> None
```
Interrupts execution of the current query.
If there is no currently executing query, this function does nothing.
```python
def create_function(
    self,
    name: str,
    udf: Callable[..., Any],
    params_type: list[Type | str] | None = None,
    return_type: Type | str = "",
    *,
    default_null_handling: bool = True,
    catch_exceptions: bool = False,
) -> None
```
Register a User Defined Function (UDF) for use in Cypher queries.
Parameters
- name (str): Name of the function.
- udf (Callable[..., Any]): Function to be executed.
- params_type (Optional[list[Type]]): List of Type enums describing the input parameters.
- return_type (Optional[Type]): A Type enum describing the returned value.
- default_null_handling (Optional[bool]): If true, when any parameter is null, the resulting value will be null.
- catch_exceptions (Optional[bool]): If true, when an exception is thrown from Python, the function output will be null; otherwise, the exception will be rethrown.
```python
def remove_function(self, name: str) -> None
```
Remove a User Defined Function (UDF).
Parameters
- name (str): Name of the function to be removed.
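A hedged end-to-end sketch: register a scalar UDF, call it from Cypher, then remove it (the function name and values are illustrative):

```python
from kuzu import Type

def times_two(x: int) -> int:
    return 2 * x

# Register, call, and clean up the UDF on an existing connection `conn`
conn.create_function("times_two", times_two, params_type=[Type.INT64], return_type=Type.INT64)
results = conn.execute("RETURN times_two(21);")
print(results.get_next())  # expected: [42]
conn.remove_function("times_two")
```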
Source of `Database`:

````python
class Database:
    """Kuzu database instance."""

    def __init__(
        self,
        database_path: str | Path | None = None,
        *,
        buffer_pool_size: int = 0,
        max_num_threads: int = 0,
        compression: bool = True,
        lazy_init: bool = False,
        read_only: bool = False,
        max_db_size: int = (1 << 43),
        auto_checkpoint: bool = True,
        checkpoint_threshold: int = -1,
    ):
        """
        Parameters
        ----------
        database_path : str, Path
            The path to the database files. If the path is not specified, empty, or
            equal to `:memory:`, the database will be created in memory.

        buffer_pool_size : int
            The maximum size of the buffer pool in bytes. Defaults to ~80% of system memory.

        max_num_threads : int
            The maximum number of threads to use for executing queries.

        compression : bool
            Enable database compression.

        lazy_init : bool
            If True, the database will not be initialized until the first query.
            This is useful when the database is not used in the main thread or
            when the main process is forked.
            Defaults to False.

        read_only : bool
            If true, the database is opened read-only. No write transactions are
            allowed on the `Database` object. Multiple read-only `Database`
            objects can be created with the same database path; however, there
            cannot be multiple read-write `Database` objects created with the
            same database path.
            Defaults to False.

        max_db_size : int
            The maximum size of the database in bytes. This is a temporary
            workaround for the default 8TB mmap address space limit in some
            environments and will be removed once a better solution is
            implemented. Defaults to 1 << 43 (8TB) on 64-bit environments and
            1GB on 32-bit ones.

        auto_checkpoint : bool
            If true, the database will automatically checkpoint when the size of
            the WAL file exceeds the checkpoint threshold.

        checkpoint_threshold : int
            The threshold of the WAL file size in bytes. When the size of the
            WAL file exceeds this threshold, the database will checkpoint if
            auto_checkpoint is true.
        """
        if database_path is None:
            database_path = ":memory:"
        if isinstance(database_path, Path):
            database_path = str(database_path)

        self.database_path = database_path
        self.buffer_pool_size = buffer_pool_size
        self.max_num_threads = max_num_threads
        self.compression = compression
        self.read_only = read_only
        self.max_db_size = max_db_size
        self.auto_checkpoint = auto_checkpoint
        self.checkpoint_threshold = checkpoint_threshold
        self.is_closed = False

        self._database: Any = None  # (type: _kuzu.Database from pybind11)
        if not lazy_init:
            self.init_database()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    @staticmethod
    def get_version() -> str:
        """
        Get the version of the database.

        Returns
        -------
        str
            The version of the database.
        """
        return _kuzu.Database.get_version()  # type: ignore[union-attr]

    @staticmethod
    def get_storage_version() -> int:
        """
        Get the storage version of the database.

        Returns
        -------
        int
            The storage version of the database.
        """
        return _kuzu.Database.get_storage_version()  # type: ignore[union-attr]

    def __getstate__(self) -> dict[str, Any]:
        state = {
            "database_path": self.database_path,
            "buffer_pool_size": self.buffer_pool_size,
            "compression": self.compression,
            "read_only": self.read_only,
            "_database": None,
        }
        return state

    def init_database(self) -> None:
        """Initialize the database."""
        self.check_for_database_close()
        if self._database is None:
            self._database = _kuzu.Database(  # type: ignore[union-attr]
                self.database_path,
                self.buffer_pool_size,
                self.max_num_threads,
                self.compression,
                self.read_only,
                self.max_db_size,
                self.auto_checkpoint,
                self.checkpoint_threshold,
            )

    def get_torch_geometric_remote_backend(
        self, num_threads: int | None = None
    ) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
        """
        Use the database as the remote backend for torch_geometric.

        For the interface of the remote backend, please refer to
        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
        The current implementation is read-only and does not support edge
        features. The IDs of the nodes are based on the internal IDs (i.e., node
        offsets). For the remote node IDs to be consistent with the positions in
        the output tensors, please ensure that no deletion has been performed
        on the node tables.

        The remote backend can also be plugged into the data loader of
        torch_geometric, which is useful for mini-batch training. For example:

        ```python
        loader_kuzu = NeighborLoader(
            data=(feature_store, graph_store),
            num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
            batch_size=LOADER_BATCH_SIZE,
            input_nodes=('paper', input_nodes),
            num_workers=4,
            filter_per_worker=False,
        )
        ```

        Please note that the database instance is not fork-safe, so if more than
        one worker is used, `filter_per_worker` must be set to False.

        Parameters
        ----------
        num_threads : int
            Number of threads to use for data loading. Defaults to None, which
            means using the number of CPU cores.

        Returns
        -------
        feature_store : KuzuFeatureStore
            Feature store compatible with torch_geometric.
        graph_store : KuzuGraphStore
            Graph store compatible with torch_geometric.
        """
        self.check_for_database_close()
        from .torch_geometric_feature_store import KuzuFeatureStore
        from .torch_geometric_graph_store import KuzuGraphStore

        return (
            KuzuFeatureStore(self, num_threads),
            KuzuGraphStore(self, num_threads),
        )

    def _scan_node_table(
        self,
        table_name: str,
        prop_name: str,
        prop_type: str,
        dim: int,
        indices: IndexType,
        num_threads: int,
    ) -> NDArray[Any]:
        """
        Scan a node table from storage directly, bypassing the query engine.
        Used internally by the torch_geometric remote backend only.
        """
        self.check_for_database_close()
        import numpy as np

        self.init_database()
        indices_cast = np.array(indices, dtype=np.uint64)
        result = None

        if prop_type == Type.INT64.value:
            result = np.empty(len(indices) * dim, dtype=np.int64)
            self._database.scan_node_table_as_int64(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.INT32.value:
            result = np.empty(len(indices) * dim, dtype=np.int32)
            self._database.scan_node_table_as_int32(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.INT16.value:
            result = np.empty(len(indices) * dim, dtype=np.int16)
            self._database.scan_node_table_as_int16(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.DOUBLE.value:
            result = np.empty(len(indices) * dim, dtype=np.float64)
            self._database.scan_node_table_as_double(table_name, prop_name, indices_cast, result, num_threads)
        elif prop_type == Type.FLOAT.value:
            result = np.empty(len(indices) * dim, dtype=np.float32)
            self._database.scan_node_table_as_float(table_name, prop_name, indices_cast, result, num_threads)

        if result is not None:
            return result

        msg = f"Unsupported property type: {prop_type}"
        raise ValueError(msg)

    def close(self) -> None:
        """
        Close the database. Once the database is closed, the lock on the database
        files is released and the database can be opened in another process.

        Note: Call to this method is not required. The Python garbage collector
        will automatically close the database when no references to the database
        object exist. It is recommended not to call this method explicitly. If you
        decide to manually close the database, make sure that all the QueryResult
        and Connection objects are closed before calling this method.
        """
        if self.is_closed:
            return
        self.is_closed = True
        if self._database is not None:
            self._database.close()
            self._database: Any = None  # (type: _kuzu.Database from pybind11)

    def check_for_database_close(self) -> None:
        """
        Check if the database is closed and raise an exception if it is.

        Raises
        ------
        Exception
            If the database is closed.
        """
        if not self.is_closed:
            return
        msg = "Database is closed"
        raise RuntimeError(msg)
````
Kuzu database instance.
```python
def __init__(
    self,
    database_path: str | Path | None = None,
    *,
    buffer_pool_size: int = 0,
    max_num_threads: int = 0,
    compression: bool = True,
    lazy_init: bool = False,
    read_only: bool = False,
    max_db_size: int = (1 << 43),
    auto_checkpoint: bool = True,
    checkpoint_threshold: int = -1,
)
```
Parameters
- database_path (str | Path): The path to the database files. If the path is not specified, empty, or equal to `:memory:`, the database will be created in memory.
- buffer_pool_size (int): The maximum size of the buffer pool in bytes. Defaults to ~80% of system memory.
- max_num_threads (int): The maximum number of threads to use for executing queries.
- compression (bool): Enable database compression.
- lazy_init (bool): If True, the database will not be initialized until the first query. This is useful when the database is not used in the main thread or when the main process is forked. Defaults to False.
- read_only (bool): If true, the database is opened read-only. No write transactions are allowed on the `Database` object. Multiple read-only `Database` objects can be created with the same database path; however, there cannot be multiple read-write `Database` objects created with the same database path. Defaults to False.
- max_db_size (int): The maximum size of the database in bytes. This is a temporary workaround for the default 8TB mmap address space limit in some environments and will be removed once a better solution is implemented. Defaults to 1 << 43 (8TB) on 64-bit environments and 1GB on 32-bit ones.
- auto_checkpoint (bool): If true, the database will automatically checkpoint when the size of the WAL file exceeds the checkpoint threshold.
- checkpoint_threshold (int): The threshold of the WAL file size in bytes. When the WAL file exceeds this threshold, the database will checkpoint if auto_checkpoint is true.
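Two illustrative configurations, an on-disk database with an explicit buffer pool and an in-memory one (sizes and paths are placeholders):

```python
import kuzu

# On-disk database with a 1 GiB buffer pool
db = kuzu.Database("./demo_db", buffer_pool_size=1024**3)

# In-memory database: omit the path, pass an empty string, or ":memory:"
mem_db = kuzu.Database()
```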
```python
@staticmethod
def get_version() -> str
```
Get the version of the database.
Returns
- str: The version of the database.
```python
@staticmethod
def get_storage_version() -> int
```
Get the storage version of the database.
Returns
- int: The storage version of the database.
```python
def init_database(self) -> None
```
Initialize the database.
```python
def get_torch_geometric_remote_backend(
    self, num_threads: int | None = None
) -> tuple[KuzuFeatureStore, KuzuGraphStore]
```
Use the database as the remote backend for torch_geometric.
For the interface of the remote backend, please refer to https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html. The current implementation is read-only and does not support edge features. The IDs of the nodes are based on the internal IDs (i.e., node offsets). For the remote node IDs to be consistent with the positions in the output tensors, please ensure that no deletion has been performed on the node tables.
The remote backend can also be plugged into the data loader of torch_geometric, which is useful for mini-batch training. For example:

```python
loader_kuzu = NeighborLoader(
    data=(feature_store, graph_store),
    num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
    batch_size=LOADER_BATCH_SIZE,
    input_nodes=('paper', input_nodes),
    num_workers=4,
    filter_per_worker=False,
)
```

Please note that the database instance is not fork-safe, so if more than one worker is used, `filter_per_worker` must be set to False.
Parameters
- num_threads (int): Number of threads to use for data loading. Defaults to None, which means using the number of CPU cores.
Returns
- feature_store (KuzuFeatureStore): Feature store compatible with torch_geometric.
- graph_store (KuzuGraphStore): Graph store compatible with torch_geometric.
```python
def close(self) -> None
```
Close the database. Once the database is closed, the lock on the database files is released and the database can be opened in another process.
Note: Call to this method is not required. The Python garbage collector will automatically close the database when no references to the database object exist. It is recommended not to call this method explicitly. If you decide to manually close the database, make sure that all the QueryResult and Connection objects are closed before calling this method.
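Since both `Database` and `Connection` support the context-manager protocol (see `__enter__`/`__exit__` in the source above), deterministic cleanup can be written as follows (path illustrative):

```python
import kuzu

with kuzu.Database("./demo_db") as db, kuzu.Connection(db) as conn:
    conn.execute("RETURN 1;")
# Both objects are closed here and the database lock is released.
```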
```python
def check_for_database_close(self) -> None
```
Check if the database is closed and raise an exception if it is.
Raises
- Exception: If the database is closed.
Source of `PreparedStatement`:

```python
class PreparedStatement:
    """
    A prepared statement is a parameterized query which can avoid planning the
    same query for repeated execution.
    """

    def __init__(self, connection: Connection, query: str, parameters: dict[str, Any] | None = None):
        """
        Parameters
        ----------
        connection : Connection
            Connection to a database.
        query : str
            Query to prepare.
        parameters : dict[str, Any]
            Parameters for the query.
        """
        if parameters is None:
            parameters = {}
        self._prepared_statement = connection._connection.prepare(query, parameters)
        self._connection = connection

    def is_success(self) -> bool:
        """
        Check if the prepared statement is successfully prepared.

        Returns
        -------
        bool
            True if the prepared statement is successfully prepared.
        """
        return self._prepared_statement.is_success()

    def get_error_message(self) -> str:
        """
        Get the error message if the query is not prepared successfully.

        Returns
        -------
        str
            Error message.
        """
        return self._prepared_statement.get_error_message()
```
A prepared statement is a parameterized query which can avoid planning the same query for repeated execution.
```python
def __init__(self, connection: Connection, query: str, parameters: dict[str, Any] | None = None)
```
Parameters
- connection (Connection): Connection to a database.
- query (str): Query to prepare.
- parameters (dict[str, Any]): Parameters for the query.
```python
def is_success(self) -> bool
```
Check if the prepared statement is successfully prepared.
Returns
- bool: True if the prepared statement is successfully prepared.
```python
def get_error_message(self) -> str
```
Get the error message if the query is not prepared successfully.
Returns
- str: Error message.
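A short sketch of checking a statement after construction (constructing `PreparedStatement` directly mirrors the deprecated `Connection.prepare`; the query and table are illustrative):

```python
stmt = kuzu.PreparedStatement(conn, "MATCH (u:User) WHERE u.age > $min_age RETURN u.name;")
if stmt.is_success():
    results = conn.execute(stmt, {"min_age": 18})
else:
    print(stmt.get_error_message())
```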
Source of `QueryResult` (the listing ends where the original document is cut off):

```python
class QueryResult:
    """QueryResult stores the result of a query execution."""

    def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult):  # type: ignore[name-defined]
        """
        Parameters
        ----------
        connection : _kuzu.Connection
            The underlying C++ connection object from pybind11.

        query_result : _kuzu.QueryResult
            The underlying C++ query result object from pybind11.
        """
        self.connection = connection
        self._query_result = query_result
        self.is_closed = False
        self.as_dict = False

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def __del__(self) -> None:
        self.close()

    def __iter__(self) -> Iterator[list[Any] | dict[str, Any]]:
        return self

    def __next__(self) -> list[Any] | dict[str, Any]:
        if self.has_next():
            return self.get_next()

        raise StopIteration

    def has_next(self) -> bool:
        """
        Check if there are more rows in the query result.

        Returns
        -------
        bool
            True if there are more rows in the query result, False otherwise.
        """
        self.check_for_query_result_close()
        return self._query_result.hasNext()

    def get_next(self) -> list[Any] | dict[str, Any]:
        """
        Get the next row in the query result.

        Returns
        -------
        list
            Next row in the query result.

        Raises
        ------
        Exception
            If there are no more rows.
        """
        self.check_for_query_result_close()
        row = self._query_result.getNext()
        return _row_to_dict(self.columns, row) if self.as_dict else row

    def get_all(self) -> list[list[Any] | dict[str, Any]]:
        """
        Get all remaining rows in the query result.

        Returns
        -------
        list
            All remaining rows in the query result.
        """
        return list(self)

    def get_n(self, count: int) -> list[list[Any] | dict[str, Any]]:
        """
        Get up to `count` rows in the query result.

        Returns
        -------
        list
            Up to `count` rows in the query result.
        """
        results = []
        while self.has_next() and count > 0:
            results.append(self.get_next())
            count -= 1
        return results

    def close(self) -> None:
        """Close the query result."""
        if not self.is_closed:
            # Allows the connection to be garbage collected if the query result
            # is closed manually by the user.
            self._query_result.close()
            self.connection = None
            self.is_closed = True

    def check_for_query_result_close(self) -> None:
        """
        Check if the query result is closed and raise an exception if it is.

        Raises
        ------
        Exception
            If the query result is closed.
        """
        if self.is_closed:
            msg = "Query result is closed"
            raise RuntimeError(msg)

    def get_as_df(self) -> pd.DataFrame:
        """
        Get the query result as a Pandas DataFrame.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        pandas.DataFrame
            Query result as a Pandas DataFrame.
        """
        self.check_for_query_result_close()

        return self._query_result.getAsDF()

    def get_as_pl(self) -> pl.DataFrame:
        """
        Get the query result as a Polars DataFrame.

        See Also
        --------
        get_as_df : Get the query result as a Pandas DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        polars.DataFrame
            Query result as a Polars DataFrame.
        """
        import polars as pl

        self.check_for_query_result_close()

        # note: polars should always export just a single chunk,
        # (eg: "-1") otherwise it will just need to rechunk anyway
        return pl.from_arrow(  # type: ignore[return-value]
            data=self.get_as_arrow(chunk_size=-1),
        )

    def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
        """
        Get the query result as a PyArrow Table.

        Parameters
        ----------
        chunk_size : Number of rows to include in each chunk.
            None
                The chunk size is adaptive and depends on the number of columns in the query result.
            -1 or 0
                The entire result is returned as a single chunk.
            > 0
                The chunk size is the number of rows specified.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_df : Get the query result as a Pandas DataFrame.

        Returns
        -------
        pyarrow.Table
            Query result as a PyArrow Table.
        """
        self.check_for_query_result_close()

        if chunk_size is None:
            # Adaptive; target 10m total elements in each chunk.
            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
            target_n_elems = 10_000_000
            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
        elif chunk_size <= 0:
            # No chunking: return the entire result as a single chunk
            chunk_size = self.get_num_tuples()

        return self._query_result.getAsArrow(chunk_size)

    def get_column_data_types(self) -> list[str]:
        """
        Get the data types of the columns in the query result.

        Returns
        -------
        list
            Data types of the columns in the query result.
        """
        self.check_for_query_result_close()
        return self._query_result.getColumnDataTypes()

    def get_column_names(self) -> list[str]:
        """
        Get the names of the columns in the query result.

        Returns
        -------
        list
            Names of the columns in the query result.
        """
        self.check_for_query_result_close()
        return self._query_result.getColumnNames()

    def get_schema(self) -> dict[str, str]:
        """
        Get the column schema of the query result.

        Returns
        -------
        dict
            Schema of the query result.
        """
        self.check_for_query_result_close()
        return dict(
            zip(
                self._query_result.getColumnNames(),
                self._query_result.getColumnDataTypes(),
            )
        )

    def reset_iterator(self) -> None:
        """Reset the iterator of the query result."""
        self.check_for_query_result_close()
        self._query_result.resetIterator()

    def get_as_networkx(
        self,
        directed: bool = True,  # noqa: FBT001
    ) -> nx.MultiGraph | nx.MultiDiGraph:
        """
        Convert the nodes and rels in the query result into a NetworkX directed or
        undirected graph with the following rules:
        Columns with data type other than node or rel will be ignored.
        Duplicated nodes and rels will be converted only once.

        Parameters
        ----------
        directed : bool
            Whether the graph should be directed. Defaults to True.

        Returns
        -------
        networkx.MultiDiGraph or networkx.MultiGraph
            Query result as a NetworkX graph.
        """
        self.check_for_query_result_close()
        import networkx as nx

        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
        properties_to_extract = self._get_properties_to_extract()

        self.reset_iterator()

        nodes = {}
        rels = {}
        table_to_label_dict = {}
        table_primary_key_dict = {}

        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
            node_label = node["_label"]
            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"

        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
            return rel["_id"]["table"], rel["_id"]["offset"]

        # De-duplicate nodes and rels
        while self.has_next():
            row = self.get_next()
            for i in properties_to_extract:
                # Skip empty nodes and rels, which may be returned by
                # OPTIONAL MATCH
                if row[i] is None or row[i] == {}:
                    continue
                column_type, _ = properties_to_extract[i]
                if column_type == Type.NODE.value:
                    nid = row[i]["_id"]
                    nodes[nid["table"], nid["offset"]] = row[i]
                    table_to_label_dict[nid["table"]] = row[i]["_label"]

                elif column_type == Type.REL.value:
                    rels[encode_rel_id(row[i])] = row[i]

                elif column_type == Type.RECURSIVE_REL.value:
                    for node in row[i]["_nodes"]:
                        nid = node["_id"]
                        nodes[nid["table"], nid["offset"]] = node
                        table_to_label_dict[nid["table"]] = node["_label"]
                    for rel in row[i]["_rels"]:
                        for key in list(rel.keys()):
                            if rel[key] is None:
                                del rel[key]
                        rels[encode_rel_id(rel)] = rel

        # Add nodes
        for node in nodes.values():
            nid = node["_id"]
            node_id = node["_label"] + "_" + str(nid["offset"])
            if node["_label"] not in table_primary_key_dict:
                props = self.connection._get_node_property_names(node["_label"])
                for prop_name in props:
                    if props[prop_name]["is_primary_key"]:
                        table_primary_key_dict[node["_label"]] = prop_name
                        break
            node_id = encode_node_id(node, table_primary_key_dict)
            node[node["_label"]] = True
            nx_graph.add_node(node_id, **node)

        # Add rels
        for rel in rels.values():
            src = rel["_src"]
            dst = rel["_dst"]
            src_node = nodes[src["table"], src["offset"]]
            dst_node = nodes[dst["table"], dst["offset"]]
            src_id = encode_node_id(src_node, table_primary_key_dict)
            dst_id = encode_node_id(dst_node, table_primary_key_dict)
            nx_graph.add_edge(src_id, dst_id, **rel)
        return nx_graph

    def _get_properties_to_extract(self) -> dict[int, tuple[str, str]]:
        column_names = self.get_column_names()
        column_types = self.get_column_data_types()
        properties_to_extract = {}

        # Iterate over columns and extract nodes and rels, ignoring other columns
        for i in range(len(column_names)):
            column_name = column_names[i]
            column_type = column_types[i]
            if column_type in [
                Type.NODE.value,
                Type.REL.value,
                Type.RECURSIVE_REL.value,
            ]:
                properties_to_extract[i] = (column_type, column_name)
        return properties_to_extract

    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
        """
        Convert the nodes and rels in the query result into a PyTorch Geometric graph
        representation torch_geometric.data.Data or torch_geometric.data.HeteroData.

        For node conversion, numerical and boolean properties are directly converted into tensor and
        stored in Data/HeteroData.
```
For properties cannot be converted into tensor automatically 397 (please refer to the notes below for more detail), they are returned as unconverted_properties. 398 399 For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned 400 as edge_properties. 401 402 Node properties that cannot be converted into tensor automatically: 403 - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted 404 automatically. 405 - If a node property contains a null value, it cannot be converted automatically. 406 - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be 407 converted automatically. 408 - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length 409 is 6 for one node but 5 for another node), it cannot be converted automatically. 410 411 Additional conversion rules: 412 - Columns with data type other than node or rel will be ignored. 413 - Duplicated nodes and rels will be converted only once. 414 415 Returns 416 ------- 417 torch_geometric.data.Data or torch_geometric.data.HeteroData 418 Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties 419 and edge_index tensor. 420 421 dict 422 A dictionary that maps the positional offset of each node in Data/HeteroData to its primary 423 key in the database. 424 425 dict 426 A dictionary contains node properties that cannot be converted into tensor automatically. The 427 order of values for each property is aligned with nodes in Data/HeteroData. 428 429 dict 430 A dictionary contains edge properties. The order of values for each property is aligned with 431 edge_index in Data/HeteroData. 432 """ 433 self.check_for_query_result_close() 434 # Despite we are not using torch_geometric in this file, we need to 435 # import it here to throw an error early if the user does not have 436 # torch_geometric or torch installed. 437 438 converter = TorchGeometricResultConverter(self) 439 return converter.get_as_torch_geometric() 440 441 def get_execution_time(self) -> int: 442 """ 443 Get the time in ms which was required for executing the query. 444 445 Returns 446 ------- 447 double 448 Query execution time as double in ms. 449 450 """ 451 self.check_for_query_result_close() 452 return self._query_result.getExecutionTime() 453 454 def get_compiling_time(self) -> int: 455 """ 456 Get the time in ms which was required for compiling the query. 457 458 Returns 459 ------- 460 double 461 Query compile time as double in ms. 462 463 """ 464 self.check_for_query_result_close() 465 return self._query_result.getCompilingTime() 466 467 def get_num_tuples(self) -> int: 468 """ 469 Get the number of tuples which the query returned. 470 471 Returns 472 ------- 473 int 474 Number of tuples. 475 476 """ 477 self.check_for_query_result_close() 478 return self._query_result.getNumTuples() 479 480 def rows_as_dict(self, state=True) -> Self: 481 """ 482 Change the format of the results, such that each row is a dict with the 483 column name as a key. 484 485 Parameters 486 ---------- 487 state 488 Whether to turn dict formatting on or off. Turns it on by default. 489 490 Returns 491 ------- 492 self 493 The object itself. 494 495 """ 496 self.as_dict = state 497 if state: 498 self.columns = self.get_column_names() 499 return self
QueryResult stores the result of a query execution.
32 def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult): # type: ignore[name-defined] 33 """ 34 Parameters 35 ---------- 36 connection : _kuzu.Connection 37 The underlying C++ connection object from pybind11. 38 39 query_result : _kuzu.QueryResult 40 The underlying C++ query result object from pybind11. 41 42 """ 43 self.connection = connection 44 self._query_result = query_result 45 self.is_closed = False 46 self.as_dict = False
Parameters
- connection (_kuzu.Connection): The underlying C++ connection object from pybind11.
- query_result (_kuzu.QueryResult): The underlying C++ query result object from pybind11.
71 def has_next(self) -> bool: 72 """ 73 Check if there are more rows in the query result. 74 75 Returns 76 ------- 77 bool 78 True if there are more rows in the query result, False otherwise. 79 """ 80 self.check_for_query_result_close() 81 return self._query_result.hasNext()
Check if there are more rows in the query result.
Returns
- bool: True if there are more rows in the query result, False otherwise.
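For illustration, a minimal sketch of row counting with `has_next`; the path `./demo-db` is a hypothetical database in which the demo `User` table has already been created and loaded:

```python
import kuzu

db = kuzu.Database("./demo-db")  # hypothetical path; any Kuzu database works
conn = kuzu.Connection(db)

results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
count = 0
while results.has_next():
    results.get_next()  # consume the row without keeping it
    count += 1
print(f"{count} rows returned")
```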
83 def get_next(self) -> list[Any] | dict[str, Any]: 84 """ 85 Get the next row in the query result. 86 87 Returns 88 ------- 89 list 90 Next row in the query result. 91 92 Raises 93 ------ 94 Exception 95 If there are no more rows. 96 """ 97 self.check_for_query_result_close() 98 row = self._query_result.getNext() 99 return _row_to_dict(self.columns, row) if self.as_dict else row
Get the next row in the query result.
Returns
- list: Next row in the query result.
Raises
- Exception: If there are no more rows.
101 def get_all(self) -> list[list[Any] | dict[str, Any]]: 102 """ 103 Get all remaining rows in the query result. 104 105 Returns 106 ------- 107 list 108 All remaining rows in the query result. 109 """ 110 return list(self)
Get all remaining rows in the query result.
Returns
- list: All remaining rows in the query result.
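Because `get_all` simply drains the iterator, it returns only the rows that have not been consumed yet. A short sketch, reusing the `conn` from the sketch above:

```python
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
first = results.get_next()  # consume one row manually
rest = results.get_all()    # only the remaining rows
print(first)
print(rest)
```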
112 def get_n(self, count: int) -> list[list[Any] | dict[str, Any]]: 113 """ 114 Get many rows in the query result. 115 116 Returns 117 ------- 118 list 119 Up to `count` rows in the query result. 120 """ 121 results = [] 122 while self.has_next() and count > 0: 123 results.append(self.get_next()) 124 count -= 1 125 return results
Get many rows in the query result.
Parameters
- count (int): Maximum number of rows to return.
Returns
- list: Up to `count` rows in the query result.
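For example, `get_n` can be used to page through a large result in fixed-size batches (a sketch, reusing the same `conn`):

```python
results = conn.execute("MATCH (u:User) RETURN u.name;")
while results.has_next():
    batch = results.get_n(100)  # up to 100 rows per batch
    print(len(batch), "rows in this batch")
```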
127 def close(self) -> None: 128 """Close the query result.""" 129 if not self.is_closed: 130 # Allows the connection to be garbage collected if the query result 131 # is closed manually by the user. 132 self._query_result.close() 133 self.connection = None 134 self.is_closed = True
Close the query result.
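`QueryResult` also implements the context manager protocol (see `__enter__` and `__exit__` in the source above), so closing can be delegated to a `with` block:

```python
with conn.execute("MATCH (u:User) RETURN u.name;") as results:
    for row in results:  # QueryResult is iterable
        print(row)
# the result is closed here; further access raises RuntimeError
```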
136 def check_for_query_result_close(self) -> None: 137 """ 138 Check if the query result is closed and raise an exception if it is. 139 140 Raises 141 ------ 142 Exception 143 If the query result is closed. 144 145 """ 146 if self.is_closed: 147 msg = "Query result is closed" 148 raise RuntimeError(msg)
Check if the query result is closed and raise an exception if it is.
Raises
- Exception: If the query result is closed.
150 def get_as_df(self) -> pd.DataFrame: 151 """ 152 Get the query result as a Pandas DataFrame. 153 154 See Also 155 -------- 156 get_as_pl : Get the query result as a Polars DataFrame. 157 get_as_arrow : Get the query result as a PyArrow Table. 158 159 Returns 160 ------- 161 pandas.DataFrame 162 Query result as a Pandas DataFrame. 163 164 """ 165 self.check_for_query_result_close() 166 167 return self._query_result.getAsDF()
Get the query result as a Pandas DataFrame.
See Also
- get_as_pl: Get the query result as a Polars DataFrame.
- get_as_arrow: Get the query result as a PyArrow Table.
Returns
- pandas.DataFrame: Query result as a Pandas DataFrame.
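A short sketch, assuming `pandas` is installed in the environment and reusing the same `conn`:

```python
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
df = results.get_as_df()
print(df.dtypes)
print(df.head())
```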
169 def get_as_pl(self) -> pl.DataFrame: 170 """ 171 Get the query result as a Polars DataFrame. 172 173 See Also 174 -------- 175 get_as_df : Get the query result as a Pandas DataFrame. 176 get_as_arrow : Get the query result as a PyArrow Table. 177 178 Returns 179 ------- 180 polars.DataFrame 181 Query result as a Polars DataFrame. 182 """ 183 import polars as pl 184 185 self.check_for_query_result_close() 186 187 # note: polars should always export just a single chunk, 188 # (eg: "-1") otherwise it will just need to rechunk anyway 189 return pl.from_arrow( # type: ignore[return-value] 190 data=self.get_as_arrow(chunk_size=-1), 191 )
Get the query result as a Polars DataFrame.
See Also
- get_as_df: Get the query result as a Pandas DataFrame.
- get_as_arrow: Get the query result as a PyArrow Table.
Returns
- polars.DataFrame: Query result as a Polars DataFrame.
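A short sketch, assuming `polars` (and `pyarrow`, which the conversion goes through) is installed:

```python
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
pl_df = results.get_as_pl()
print(pl_df.schema)
print(pl_df)
```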
193 def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table: 194 """ 195 Get the query result as a PyArrow Table. 196 197 Parameters 198 ---------- 199 chunk_size : Number of rows to include in each chunk. 200 None 201 The chunk size is adaptive and depends on the number of columns in the query result. 202 -1 or 0 203 The entire result is returned as a single chunk. 204 > 0 205 The chunk size is the number of rows specified. 206 207 See Also 208 -------- 209 get_as_pl : Get the query result as a Polars DataFrame. 210 get_as_df : Get the query result as a Pandas DataFrame. 211 212 Returns 213 ------- 214 pyarrow.Table 215 Query result as a PyArrow Table. 216 """ 217 self.check_for_query_result_close() 218 219 if chunk_size is None: 220 # Adaptive; target 10m total elements in each chunk. 221 # (eg: if we had 10 cols, this would result in a 1m row chunk_size). 222 target_n_elems = 10_000_000 223 chunk_size = max(target_n_elems // len(self.get_column_names()), 10) 224 elif chunk_size <= 0: 225 # No chunking: return the entire result as a single chunk 226 chunk_size = self.get_num_tuples() 227 228 return self._query_result.getAsArrow(chunk_size)
Get the query result as a PyArrow Table.
Parameters
- chunk_size (int, optional): Number of rows to include in each chunk.
  - None: the chunk size is adaptive and depends on the number of columns in the query result.
  - -1 or 0: the entire result is returned as a single chunk.
  - > 0: the chunk size is the number of rows specified.
See Also
- get_as_pl: Get the query result as a Polars DataFrame.
- get_as_df: Get the query result as a Pandas DataFrame.
Returns
- pyarrow.Table: Query result as a PyArrow Table.
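A short sketch with an explicit chunk size, assuming `pyarrow` is installed:

```python
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
table = results.get_as_arrow(chunk_size=1000)  # chunks of up to 1000 rows
print(table.schema)
print(table.num_rows)
```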
230 def get_column_data_types(self) -> list[str]: 231 """ 232 Get the data types of the columns in the query result. 233 234 Returns 235 ------- 236 list 237 Data types of the columns in the query result. 238 239 """ 240 self.check_for_query_result_close() 241 return self._query_result.getColumnDataTypes()
Get the data types of the columns in the query result.
Returns
- list: Data types of the columns in the query result.
243 def get_column_names(self) -> list[str]: 244 """ 245 Get the names of the columns in the query result. 246 247 Returns 248 ------- 249 list 250 Names of the columns in the query result. 251 252 """ 253 self.check_for_query_result_close() 254 return self._query_result.getColumnNames()
Get the names of the columns in the query result.
Returns
- list: Names of the columns in the query result.
256 def get_schema(self) -> dict[str, str]: 257 """ 258 Get the column schema of the query result. 259 260 Returns 261 ------- 262 dict 263 Schema of the query result. 264 265 """ 266 self.check_for_query_result_close() 267 return dict( 268 zip( 269 self._query_result.getColumnNames(), 270 self._query_result.getColumnDataTypes(), 271 ) 272 )
Get the column schema of the query result.
Returns
- dict: Schema of the query result.
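For example, the three introspection helpers viewed together (same `conn` as above; the printed values assume the demo schema):

```python
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
print(results.get_column_names())       # e.g. ['u.name', 'u.age']
print(results.get_column_data_types())  # e.g. ['STRING', 'INT64']
print(results.get_schema())             # e.g. {'u.name': 'STRING', 'u.age': 'INT64'}
```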
274 def reset_iterator(self) -> None: 275 """Reset the iterator of the query result.""" 276 self.check_for_query_result_close() 277 self._query_result.resetIterator()
Reset the iterator of the query result.
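A sketch showing a result being consumed twice:

```python
results = conn.execute("MATCH (u:User) RETURN u.name;")
first_pass = results.get_all()
results.reset_iterator()         # rewind to the first row
second_pass = results.get_all()  # yields the same rows again
```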
279 def get_as_networkx( 280 self, 281 directed: bool = True, # noqa: FBT001 282 ) -> nx.MultiGraph | nx.MultiDiGraph: 283 """ 284 Convert the nodes and rels in query result into a NetworkX directed or undirected graph 285 with the following rules: 286 Columns with data type other than node or rel will be ignored. 287 Duplicated nodes and rels will be converted only once. 288 289 Parameters 290 ---------- 291 directed : bool 292 Whether the graph should be directed. Defaults to True. 293 294 Returns 295 ------- 296 networkx.MultiDiGraph or networkx.MultiGraph 297 Query result as a NetworkX graph. 298 299 """ 300 self.check_for_query_result_close() 301 import networkx as nx 302 303 nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph() 304 properties_to_extract = self._get_properties_to_extract() 305 306 self.reset_iterator() 307 308 nodes = {} 309 rels = {} 310 table_to_label_dict = {} 311 table_primary_key_dict = {} 312 313 def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str: 314 node_label = node["_label"] 315 return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}" 316 317 def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]: 318 return rel["_id"]["table"], rel["_id"]["offset"] 319 320 # De-duplicate nodes and rels 321 while self.has_next(): 322 row = self.get_next() 323 for i in properties_to_extract: 324 # Skip empty nodes and rels, which may be returned by 325 # OPTIONAL MATCH 326 if row[i] is None or row[i] == {}: 327 continue 328 column_type, _ = properties_to_extract[i] 329 if column_type == Type.NODE.value: 330 nid = row[i]["_id"] 331 nodes[nid["table"], nid["offset"]] = row[i] 332 table_to_label_dict[nid["table"]] = row[i]["_label"] 333 334 elif column_type == Type.REL.value: 335 rels[encode_rel_id(row[i])] = row[i] 336 337 elif column_type == Type.RECURSIVE_REL.value: 338 for node in row[i]["_nodes"]: 339 nid = node["_id"] 340 nodes[nid["table"], nid["offset"]] = node 341 table_to_label_dict[nid["table"]] = node["_label"] 342 for rel in row[i]["_rels"]: 343 for key in list(rel.keys()): 344 if rel[key] is None: 345 del rel[key] 346 rels[encode_rel_id(rel)] = rel 347 348 # Add nodes 349 for node in nodes.values(): 350 nid = node["_id"] 351 node_id = node["_label"] + "_" + str(nid["offset"]) 352 if node["_label"] not in table_primary_key_dict: 353 props = self.connection._get_node_property_names(node["_label"]) 354 for prop_name in props: 355 if props[prop_name]["is_primary_key"]: 356 table_primary_key_dict[node["_label"]] = prop_name 357 break 358 node_id = encode_node_id(node, table_primary_key_dict) 359 node[node["_label"]] = True 360 nx_graph.add_node(node_id, **node) 361 362 # Add rels 363 for rel in rels.values(): 364 src = rel["_src"] 365 dst = rel["_dst"] 366 src_node = nodes[src["table"], src["offset"]] 367 dst_node = nodes[dst["table"], dst["offset"]] 368 src_id = encode_node_id(src_node, table_primary_key_dict) 369 dst_id = encode_node_id(dst_node, table_primary_key_dict) 370 nx_graph.add_edge(src_id, dst_id, **rel) 371 return nx_graph
Convert the nodes and rels in the query result into a NetworkX directed or undirected graph, with the following rules:
- Columns with a data type other than node or rel will be ignored.
- Duplicated nodes and rels will be converted only once.
Parameters
- directed (bool): Whether the graph should be directed. Defaults to True.
Returns
- networkx.MultiDiGraph or networkx.MultiGraph: Query result as a NetworkX graph.
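A sketch, assuming `networkx` is installed and the demo `Follows` table has been loaded:

```python
results = conn.execute("MATCH (u:User)-[f:Follows]->(v:User) RETURN u, f, v;")
g = results.get_as_networkx(directed=True)  # networkx.MultiDiGraph
print(g.number_of_nodes(), "nodes,", g.number_of_edges(), "edges")
```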
390 def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]: # type: ignore[type-arg] 391 """ 392 Convert the nodes and rels in query result into a PyTorch Geometric graph representation 393 torch_geometric.data.Data or torch_geometric.data.HeteroData. 394 395 For node conversion, numerical and boolean properties are directly converted into tensor and 396 stored in Data/HeteroData. For properties cannot be converted into tensor automatically 397 (please refer to the notes below for more detail), they are returned as unconverted_properties. 398 399 For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned 400 as edge_properties. 401 402 Node properties that cannot be converted into tensor automatically: 403 - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted 404 automatically. 405 - If a node property contains a null value, it cannot be converted automatically. 406 - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be 407 converted automatically. 408 - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length 409 is 6 for one node but 5 for another node), it cannot be converted automatically. 410 411 Additional conversion rules: 412 - Columns with data type other than node or rel will be ignored. 413 - Duplicated nodes and rels will be converted only once. 414 415 Returns 416 ------- 417 torch_geometric.data.Data or torch_geometric.data.HeteroData 418 Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties 419 and edge_index tensor. 420 421 dict 422 A dictionary that maps the positional offset of each node in Data/HeteroData to its primary 423 key in the database. 424 425 dict 426 A dictionary contains node properties that cannot be converted into tensor automatically. The 427 order of values for each property is aligned with nodes in Data/HeteroData. 428 429 dict 430 A dictionary contains edge properties. The order of values for each property is aligned with 431 edge_index in Data/HeteroData. 432 """ 433 self.check_for_query_result_close() 434 # Despite we are not using torch_geometric in this file, we need to 435 # import it here to throw an error early if the user does not have 436 # torch_geometric or torch installed. 437 438 converter = TorchGeometricResultConverter(self) 439 return converter.get_as_torch_geometric()
Convert the nodes and rels in the query result into a PyTorch Geometric graph representation, torch_geometric.data.Data or torch_geometric.data.HeteroData.
For node conversion, numerical and boolean properties are converted directly into tensors and stored in Data/HeteroData. Properties that cannot be converted into tensors automatically (see the notes below) are returned as unconverted_properties.
For rel conversion, each rel is converted directly into an edge_index tensor. Edge properties are returned as edge_properties.
Node properties that cannot be converted into tensor automatically:
- If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted automatically.
- If a node property contains a null value, it cannot be converted automatically.
- If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be converted automatically.
- If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length is 6 for one node but 5 for another node), it cannot be converted automatically.
Additional conversion rules:
- Columns with data type other than node or rel will be ignored.
- Duplicated nodes and rels will be converted only once.
Returns
- torch_geometric.data.Data or torch_geometric.data.HeteroData: Query result as a PyTorch Geometric graph, containing the converted numeric or boolean node properties and the edge_index tensor.
- dict: A dictionary that maps the positional offset of each node in Data/HeteroData to its primary key in the database.
- dict: A dictionary containing node properties that cannot be converted into tensors automatically. The order of values for each property is aligned with the nodes in Data/HeteroData.
- dict: A dictionary containing edge properties. The order of values for each property is aligned with edge_index in Data/HeteroData.
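A sketch, assuming `torch` and `torch_geometric` are installed; the unpacked variable names here are illustrative:

```python
results = conn.execute("MATCH (u:User)-[f:Follows]->(v:User) RETURN u, f, v;")
data, pos_to_pk, unconverted, edge_props = results.get_as_torch_geometric()
print(data)       # Data, or HeteroData when several node tables are involved
print(pos_to_pk)  # positional offset -> primary key
```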
441 def get_execution_time(self) -> int: 442 """ 443 Get the time in ms which was required for executing the query. 444 445 Returns 446 ------- 447 double 448 Query execution time as double in ms. 449 450 """ 451 self.check_for_query_result_close() 452 return self._query_result.getExecutionTime()
Get the time (in ms) that was required to execute the query.
Returns
- double: Query execution time in ms.
454 def get_compiling_time(self) -> int: 455 """ 456 Get the time in ms which was required for compiling the query. 457 458 Returns 459 ------- 460 double 461 Query compile time as double in ms. 462 463 """ 464 self.check_for_query_result_close() 465 return self._query_result.getCompilingTime()
Get the time (in ms) that was required to compile the query.
Returns
- double: Query compile time in ms.
467 def get_num_tuples(self) -> int: 468 """ 469 Get the number of tuples which the query returned. 470 471 Returns 472 ------- 473 int 474 Number of tuples. 475 476 """ 477 self.check_for_query_result_close() 478 return self._query_result.getNumTuples()
Get the number of tuples that the query returned.
Returns
- int: Number of tuples.
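For example, the profiling helpers can be read off one result (same `conn` as above):

```python
results = conn.execute("MATCH (u:User) RETURN u.name;")
print(results.get_num_tuples(), "tuples")
print(results.get_compiling_time(), "ms to compile")
print(results.get_execution_time(), "ms to execute")
```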
480 def rows_as_dict(self, state=True) -> Self: 481 """ 482 Change the format of the results, such that each row is a dict with the 483 column name as a key. 484 485 Parameters 486 ---------- 487 state 488 Whether to turn dict formatting on or off. Turns it on by default. 489 490 Returns 491 ------- 492 self 493 The object itself. 494 495 """ 496 self.as_dict = state 497 if state: 498 self.columns = self.get_column_names() 499 return self
Change the format of the results, such that each row is a dict with the column name as a key.
Parameters
- state (bool): Whether to turn dict formatting on or off. Defaults to True.
Returns
- self: The object itself.
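Because `rows_as_dict` returns the result itself, it chains naturally off `execute` (a sketch; the printed row assumes the demo data):

```python
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;").rows_as_dict()
for row in results:
    print(row)  # e.g. {'u.name': 'Adam', 'u.age': 30}
```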
5class Type(Enum): 6 """The type of a value in the database.""" 7 8 ANY = "ANY" 9 NODE = "NODE" 10 REL = "REL" 11 RECURSIVE_REL = "RECURSIVE_REL" 12 SERIAL = "SERIAL" 13 BOOL = "BOOL" 14 INT64 = "INT64" 15 INT32 = "INT32" 16 INT16 = "INT16" 17 INT8 = "INT8" 18 UINT64 = "UINT64" 19 UINT32 = "UINT32" 20 UINT16 = "UINT16" 21 UINT8 = "UINT8" 22 INT128 = "INT128" 23 DOUBLE = "DOUBLE" 24 FLOAT = "FLOAT" 25 DATE = "DATE" 26 TIMESTAMP = "TIMESTAMP" 27 TIMESTAMP_TZ = "TIMESTAMP_TZ" 28 TIMESTAMP_NS = "TIMESTAMP_NS" 29 TIMESTAMP_MS = "TIMESTAMP_MS" 30 TIMESTAMP_SEC = "TIMESTAMP_SEC" 31 INTERVAL = "INTERVAL" 32 INTERNAL_ID = "INTERNAL_ID" 33 STRING = "STRING" 34 BLOB = "BLOB" 35 UUID = "UUID" 36 LIST = "LIST" 37 ARRAY = "ARRAY" 38 STRUCT = "STRUCT" 39 MAP = "MAP" 40 UNION = "UNION"
The type of a value in the database.
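For example, `Type` values can be compared against the strings returned by `QueryResult.get_column_data_types` (a sketch, reusing the `conn` from the sketches above):

```python
results = conn.execute("MATCH (u:User)-[f:Follows]->(v:User) RETURN u, f;")
for name, dtype in results.get_schema().items():
    if dtype == kuzu.Type.NODE.value:
        print(name, "is a node column")
    elif dtype == kuzu.Type.REL.value:
        print(name, "is a rel column")
```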