diff --git a/.gitignore b/.gitignore index c6f9a44..dbd9ea8 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ .vscode/settings.json +dist/*.whl +__pycache__/ diff --git a/Makefile b/Makefile index b800712..6bcb512 100644 --- a/Makefile +++ b/Makefile @@ -2,21 +2,22 @@ VERSION := $(shell git describe --tags 2> /dev/null || echo unknown) PYTHON=python3 +PIP=pip3 default: build install build: - pip install -r requirements.txt + $(PIP) install -r requirements.txt $(PYTHON) -m build --wheel install: - pip install ./dist/*.whl --force-reinstall + $(PIP) install ./dist/*.whl --force-reinstall test: $(PYTHON) -m unittest discover release: - pip install -r requirements.txt + $(PIP) install -r requirements.txt $(PYTHON) -m build --wheel -o ./build/ docker: diff --git a/README.md b/README.md index 2a624d0..18e6617 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,15 @@ Python client for Featurebase SQL endpoint. +For more complete documentation, see: + +https://docs.featurebase.com/docs/tools/python-client-library/python-client-library-home/ + # Client Library Usage: -First install the python-featuebase package. Running `make` from project folder +First install the python-featurebase package. Running `make` from project folder will build and install the package. After installing the package you can try -executing queries as shown in the following examples. +executing queries as shown in the following examples: import featurebase @@ -16,30 +20,31 @@ executing queries as shown in the following examples. 
client = featurebase.client() # query the endpoint with SQL - result = client.query("SELECT * from demo;") - if result.ok: + try: + result = client.query("SELECT * from demo;") print(result.data) + except Exception as e: + # SQL errors and connection errors both come back as exceptions + print(e) # query the endpoint with a batch of SQLs, running the SQLs synchronously # Synchronous run best suited for executing DDL and DMLs that need to follow specific run order # passing the optional parameter "stoponerror=True" will stop execution at the failed SQL and the remaining SQLs in the list will not be executed. - sqllist=[] - sqllist.append("CREATE TABLE demo1(_id id, i1 int);") - sqllist.append("INSERT INTO demo1(_id, i1) VALUES(1, 100);") - sqllist.append("INSERT INTO demo1(_id, i1) VALUES(2, 200);") - sqllist.append("select * from demo1;") - results = client.querybatch(sqllist, stoponerror=True) - for result in results: - if result.ok: - print(result.data) - + sqllist=[] + sqllist.append("CREATE TABLE demo1(_id id, i1 int);") + sqllist.append("INSERT INTO demo1(_id, i1) VALUES(1, 100);") + sqllist.append("INSERT INTO demo1(_id, i1) VALUES(2, 200);") + sqllist.append("select * from demo1;") + results = client.querybatch(sqllist) + for result in results: + print(result.data) + # query the endpoint with a batch of SQLs, running the SQLs Asynchronously # Asynchronous run best suited for running SELECT queries that can be run concurrently. 
- sqllist=[] - sqllist.append("SELECT * from demo1;") - sqllist.append("SELECT count(*) from demo1;") - sqllist.append("SELECT max(i1) from demo1;") - results = client.querybatch(sqllist, asynchronous=True) - for result in results: - if result.ok: - print(result.data) + sqllist=[] + sqllist.append("SELECT * from demo1;") + sqllist.append("SELECT count(*) from demo1;") + sqllist.append("SELECT max(i1) from demo1;") + results = client.querybatch(sqllist, asynchronous=True) + for result in results: + print(result.data) diff --git a/dist/featurebase-0.0.1-py3-none-any.whl b/dist/featurebase-0.0.1-py3-none-any.whl deleted file mode 100644 index d5a73e1..0000000 Binary files a/dist/featurebase-0.0.1-py3-none-any.whl and /dev/null differ diff --git a/example/bulkInsert.py b/example/bulkInsert.py index 0a337e2..32fe08a 100644 --- a/example/bulkInsert.py +++ b/example/bulkInsert.py @@ -4,55 +4,67 @@ import time # intialize featurebase client for community or cloud featurebase server -client = featurebase.client(hostport="localhost:10101") #community -#client = client(hostport="query.featurebase.com/v2", database="", apikey="") #cloud +# local server running community +client = featurebase.client(hostport="localhost:10101") +# cloud server, using database and API key. 
+# client = featurebase.client(hostport="query.featurebase.com/v2", database="", apikey="") #cloud # generate random data def get_random_string(length: int): letters = string.ascii_lowercase - result_str = ''.join(random.choice(letters) for i in range(length)) + result_str = "".join(random.choice(letters) for i in range(length)) return result_str -# build a BULK INSERT sql and execute it using featurebase client -def upload_data_bulk(key_from: int, key_to: int): + +# build a BULK INSERT sql and execute it using featurebase client +def upload_data_bulk(key_from: int, count: int): # build bulk insert sql - insertClause="BULK INSERT INTO demo_upload(_id, keycol, val1, val2) MAP (0 ID, 1 INT, 2 STRING, 3 STRING) FROM x" - withClause=" WITH INPUT 'INLINE' FORMAT 'CSV' BATCHSIZE " + str((key_to-key_from)+1) - records="" - for i in range(key_from, key_to): + insert_clause = "BULK INSERT INTO demo_upload(_id, keycol, val1, val2) MAP (0 ID, 1 INT, 2 STRING, 3 STRING) FROM x" + with_clause = " WITH INPUT 'INLINE' FORMAT 'CSV' BATCHSIZE " + str((count) + 1) + records = "" + for i in range(key_from, key_from + count): val1 = get_random_string(3) val2 = get_random_string(12) - if records!="": - records+='\n' - records+='%i, %i, "%s", "%s"'%(i, i, val1, val2) - bulkInsertSql=insertClause + "'" + records + "'" + withClause - stime=time.time() - result=client.query(sql=bulkInsertSql) - etime=time.time() - if result.ok: - print("inserted " + str(result.rows_affected) + " rows in " + str(etime-stime) + " seconds") - else: - print(result.error.description) - return result.ok + if records != "": + records += "\n" + records += '%i, %i, "%s", "%s"' % (i, i, val1, val2) + bulk_insert_sql = insert_clause + "'" + records + "'" + with_clause + stime = time.time() + result = client.query(sql=bulk_insert_sql) + etime = time.time() + try: + print( + "inserted " + + str(result.rows_affected) + + " rows in " + + str(etime - stime) + + " seconds" + ) + except Exception as e: + print(e) + return 
False + return True + # create a demo table and load million rows -def run(batchSize: int): - # create demo table - result=client.query(sql="CREATE TABLE demo_upload(_id ID, keycol INT, val1 STRING, val2 STRING)") - if not result.ok: - print(result.error) - # insert batchSize rows per insert for 1000 times - n=int(1000000/batchSize) - l=1 - h=batchSize - for i in range(1, n): - if not upload_data_bulk(l, h): +def run(batch_size: int): + # create demo table + try: + client.query(sql="DROP TABLE IF EXISTS demo_upload") + client.query( + sql="CREATE TABLE demo_upload(_id ID, keycol INT, val1 STRING, val2 STRING)" + ) + except Exception as e: + print(e) + # insert batch_size rows per insert + # (will not upload the full million if batch_size does not evenly divide 1M) + n = int(1000000 / batch_size) + l = 1 + for i in range(n): + if not upload_data_bulk(l, batch_size): break - l=h+1 - h+=batchSize - if h>1000000: - h=1000000 + l += batch_size -run(10000) \ No newline at end of file +run(10000) diff --git a/pyproject.toml b/pyproject.toml index f579a39..f11493d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "featurebase" -version = "0.0.1" +version = "0.1.0" authors = [ { name="Featurebase Developers", email="dev@featurebase.com", url="https://www.featurebase.com/" }, ] @@ -20,4 +20,4 @@ classifiers = [ [project.urls] "Homepage" = "https://github.com/featurebasedb/python-featurebase" -"Bug Tracker" = "https://github.com/featurebasedb/python-featurebase/issues" \ No newline at end of file +"Bug Tracker" = "https://github.com/featurebasedb/python-featurebase/issues" diff --git a/src/featurebase/client.py b/src/featurebase/client.py index 9acb088..90d6f84 100644 --- a/src/featurebase/client.py +++ b/src/featurebase/client.py @@ -1,95 +1,121 @@ import json import concurrent.futures +import ssl import urllib.request import urllib.error + # client represents a http connection to the FeatureBase sql endpoint. 
+# the hostport parameter must be present when using an api key. the +# database parameter is optional, but if set must be a valid string. +# assumes http by default, but switches to https if certificate config +# is provided or an API key is present. class client: """Client represents a http connection to the FeatureBase sql endpoint. - Keyword arguments: + Keyword arguments: hostport -- hostname and port number of your featurebase instance, it should be passed in 'host:port' format (default localhost:10101) - database -- database id of your featurebase cloud instance(default None) - apikey -- api key (default None) + database -- database id of your featurebase cloud instance (default None) + apikey -- api key (default None) -- applicable only when specifying a host/port cafile -- Fully qualified certificate file path (default None) capath -- Fully qualified certificate folder (default None) - origin -- request origin, should be one of the allowed origins defined for your featurebase instance (default None) - timeout -- seconds to wait before timing out on server connection attempts""" + origin -- request origin, should be one of the allowed origins defined for your featurebase instance (default None) + timeout -- seconds to wait before timing out on server connection attempts + + When specifying API key, you should specify a host and port, and the + client will expect HTTPS.""" + # client constructor initializes the client with key attributes needed to # make connection to the sql endpoint - def __init__(self, hostport='localhost:10101', database=None, apikey=None, cafile=None, capath=None, origin=None, timeout=None): - self.hostport=hostport - self.database=database - self.apikey=apikey - self.cafile=cafile - self.capath=capath - self.timeout=timeout - self.origin=origin + def __init__( + self, + hostport=None, + database=None, + apikey=None, + cafile=None, + capath=None, + origin=None, + timeout=None, + ): + self.hostport = hostport + self.database = database + 
self.apikey = apikey + self.timeout = timeout + self.origin = origin + if hostport is None: + if apikey is not None: + raise ValueError("when specifying API key, hostport is not optional") + self.hostport = "localhost:10101" + else: + self.hostport = hostport + if apikey is not None and apikey == "": + raise ValueError("API key, if set, must not be empty string") + if database is not None and database == "": + raise ValueError("database ID, if set, must not be empty string") + scheme = "http" + if cafile or capath or apikey: + scheme = "https" + # force https + self.sslContext = ssl.create_default_context(cafile=cafile, capath=capath) + else: + self.sslContext = None + path = "/sql" + if self.database: + path = "/databases/{}/query/sql".format(self.database) + self.url = "{}://{}{}".format(scheme, self.hostport, path) # private helper to create a new request/session object intialized with tls # attributes if any provided adds header entries as expected by the sql # endpoint def _newrequest(self): - request=urllib.request.Request(self._geturl(),method='POST') - if self.origin!=None: - request.origin_req_host=self.origin + request = urllib.request.Request(self.url, method="POST") + if self.origin != None: + request.origin_req_host = self.origin return self._addheaders(request) # private helper adds header entries to a request def _addheaders(self, request): - request.add_header("Content-Type","text/plain") - request.add_header("Accept","application/json") - if self.apikey!=None: - request.add_header("X-API-Key",self.apikey) + request.add_header("Content-Type", "text/plain") + request.add_header("Accept", "application/json") + if self.apikey != None: + request.add_header("X-API-Key", self.apikey) return request - # private helper to build url for the request it determines http or https - # default url points to sql endpoint, database is added to the path if - # provided optionally it can point to other paths. 
- def _geturl(self, path=None): - scheme='http' - if self.cafile!=None or self.capath!=None or self.apikey!=None: - scheme='https' - if path==None: - if self.database != None: - path="/databases/" + self.database+ "/query/sql" - else: - path="/sql" - return scheme + "://" + self.hostport + path - - # helper method executes the http post request and returns a callable future + # helper method executes the http post request and returns a callable future def _post(self, sql): - data = bytes(sql, 'utf-8') - # use context manager to ensure connection is promptly closed and released - with urllib.request.urlopen(self._newrequest(), data=data, timeout=self.timeout, cafile=self.cafile, capath=self.capath) as conn: - response=conn.read() + data = bytes(sql, "utf-8") + # use context manager to ensure connection is promptly closed and released + with urllib.request.urlopen( + self._newrequest(), + data=data, + timeout=self.timeout, + context=self.sslContext, + ) as conn: + response = conn.read() return result(sql=sql, response=response, code=conn.code) - # helper method executes the http post request and returns a callable future and handles exception - def _postforasync(self, sql): - try: - response=self._post(sql) - except Exception as exec: - exec.add_note(sql) - return result(sql=sql, response=response, code=500, exec=exec) - return response - # helper method accepts a list of sql queries and executes them # asynchronously and returns the results as a list def _batchasync(self, sqllist): - results=[] + results = [] + exceptions = [] # use context manger to ensure threads are cleaned up promptly with concurrent.futures.ThreadPoolExecutor() as executor: # Start the query execution and mark each future with its sql - future_to_sql = {executor.submit(self._postforasync, sql): sql for sql in sqllist} + future_to_sql = {executor.submit(self._post, sql): sql for sql in sqllist} for future in concurrent.futures.as_completed(future_to_sql, self.timeout): - 
results.append(future.result()) + try: + results.append(future.result()) + except Exception as e: + exceptions.append(e) + if exceptions: + raise ExceptionGroup("batch exception(s):", exceptions) return results # public method accepts a sql query creates a new request object pointing to # sql endpoint attaches the sql query as payload and posts the request # returns a simple result object providing access to data, status and - # warnings. + # warnings. if the server returns an error, it will be raised as an exception. def query(self, sql): """Executes a SQL query and returns a result object. @@ -99,78 +125,52 @@ def query(self, sql): # public method accepts a list of sql queries and executes them # synchronously or asynchronously and returns the results as a list - def querybatch(self, sqllist, asynchronous=False, stoponerror=False): + # asynchronously, it runs all queries. if one or more queries hits + # an exception, it raises an ExceptionGroup of the exceptions, otherwise + # it returns a list of results. + def querybatch(self, sqllist, asynchronous=False): """Executes a list of SQLs and returns a list of result objects. - Keyword arguments: + Keyword arguments: sqllist -- the list of SQL queries to be executed - asynchronous -- a flag to indicate the SQLs should be run concurrently (default False) - stoponerror -- a flag to indicate what to do when a SQL error happens. Passing True will stop executing remaining SQLs in the input list after the errored SQL item. 
This parameter is ignored when asynchronous=True (default False)""" - results =[] + asynchronous -- a flag to indicate the SQLs should be run concurrently (default False)""" + results = [] if asynchronous: - results=self._batchasync(sqllist) - excs = [] - for result in results: - if not result.ok: - excs.append(result.exec) - if len(excs)>0: - raise ExceptionGroup('Batch exception(s):', excs) + results = self._batchasync(sqllist) else: for sql in sqllist: - result=self._post(sql) - results.append(result) - # during synchronous execution if a query fails and stoponerror is - # true then stop executing remaining queries - if not result.ok and stoponerror: - break + results.append(self._post(sql)) return results # simple data object representing query result returned by the sql endpoint for # successful requests, data returned by the service will be populated in the -# data, schema attributes along with any warnings, for failed requests error and -# exception info will be populated in the respective attributes +# data, schema attributes along with any warnings. only successful requests +# generate results, server and communication errors are raised as exceptions. class result: """Result is a simple data object representing results of a SQL query. 
- Keyword arguments: - ok -- boolean indicating query execution status + Keyword arguments: + sql -- the SQL which was executed schema -- field definitions for the result data data -- data rows returned by the server - error -- SQL error information warnings -- warning information returned by the server execution_time -- amount of time (microseconds) it took for the server to execute the SQL rows_affected -- number of rows affected by the SQL statement - exec -- exception captured during asynchronous execution raw_response -- original request response """ - def __init__(self, sql, response, code, exec=None): - self.ok=False - self.schema=None - self.data=None - self.error=None - self.warnings=None - self.execution_time=0 - self.sql=sql - self.ok=code==200 - self.rows_affected=0 - self.exec=None - self.raw_response=response - if self.ok: - try: - result=json.loads(response) - if 'error' in result.keys(): - self.ok=False - self.error=result['error'] - else: - self.schema=result.get('schema') - self.data=result.get('data') - self.warnings=result.get('warnings') - self.execution_time=result.get('execution-time') - self.rows_affected=result.get('rows-affected') - except json.JSONDecodeError as exec: - self.ok=False - self.error=str(exec) - self.exec=exec - else: - self.exec=exec + + def __init__(self, sql, response, code): + self.sql = sql + if code != 200: + # HTTP error of some kind. 
+ raise RuntimeError("HTTP response code %d" % code) + self.raw_response = response + result = json.loads(response) + if "error" in result: + raise RuntimeError(result["error"]) + self.schema = result.get("schema") + self.data = result.get("data") + self.warnings = result.get("warnings", None) + self.execution_time = result.get("execution-time", 0) + self.rows_affected = result.get("rows-affected", 0) diff --git a/tests/test_client.py b/tests/test_client.py index 3075bb7..f005729 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,254 +1,229 @@ +import json import os import unittest import calendar import time from featurebase import client, result -class FeaturebaseClientTestCase(unittest.TestCase): +client_hostport = os.getenv("FEATUREBASE_HOSTPORT", "localhost:10101") + +class FeaturebaseClientTestCase(unittest.TestCase): # test client for default attributes - def testDefaultClient(self): - default_client = client() - self.assertEqual(default_client.hostport, 'localhost:10101') + def testDefaultClient(self): + default_client = client() + self.assertEqual(default_client.hostport, "localhost:10101") self.assertEqual(default_client.database, None) self.assertEqual(default_client.apikey, None) self.assertEqual(default_client.origin, None) - self.assertEqual(default_client.capath, None) - self.assertEqual(default_client.cafile, None) + self.assertEqual(default_client.url, "http://localhost:10101/sql") # test URL generation schemes def testURL(self): # default URL - test_client = client(hostport='featurebase.com:2020') - self.assertEqual(test_client._geturl(), 'http://featurebase.com:2020/sql' ) + test_client = client(hostport="featurebase.com:2020") + self.assertEqual(test_client.url, "http://featurebase.com:2020/sql") # URL for context database - test_client = client(hostport='featurebase.com:2020', database='db-1') - self.assertEqual(test_client._geturl(), 'http://featurebase.com:2020/databases/db-1/query/sql' ) + test_client = 
client(hostport="featurebase.com:2020", database="db-1") + self.assertEqual( + test_client.url, + "http://featurebase.com:2020/databases/db-1/query/sql", + ) # https when CA attributes are defined - test_client = client(hostport='featurebase.com:2020', database='db-1', capath='./pem/') - self.assertEqual(test_client._geturl(), 'https://featurebase.com:2020/databases/db-1/query/sql' ) - test_client = client(hostport='featurebase.com:2020', database='db-1', cafile='./pem') - self.assertEqual(test_client._geturl(), 'https://featurebase.com:2020/databases/db-1/query/sql' ) - # url for custom path - self.assertEqual(test_client._geturl('/test'), 'https://featurebase.com:2020/test' ) + test_client = client( + hostport="featurebase.com:2020", database="db-1", capath="." + ) + self.assertEqual( + test_client.url, + "https://featurebase.com:2020/databases/db-1/query/sql", + ) # test request for method, origin and headers def testRequest(self): - test_client = client(hostport='featurebase.com:2020', origin='gitlab.com', apikey='testapikey') - request=test_client._newrequest() - self.assertEqual(request.full_url, 'https://featurebase.com:2020/sql') + test_client = client( + hostport="featurebase.com:2020", origin="gitlab.com", apikey="testapikey" + ) + request = test_client._newrequest() + self.assertEqual(request.full_url, "https://featurebase.com:2020/sql") # method must be POST - self.assertEqual(request.method, 'POST') - # request origin must match origin supplied to the client - self.assertEqual(request.origin_req_host, 'gitlab.com') + self.assertEqual(request.method, "POST") + # request origin must match origin supplied to the client + self.assertEqual(request.origin_req_host, "gitlab.com") # headers should have specific entries including the api key supplied to the client - expectedheader={'Content-type':'text/plain', 'Accept':'application/json', 'X-api-key':'testapikey'} - self.assertDictEqual(expectedheader,request.headers) + expectedheader = { + "Content-type": 
"text/plain", + "Accept": "application/json", + "X-api-key": "testapikey", + } + self.assertDictEqual(expectedheader, request.headers) # test client for post error scenarios def testPostExceptions(self): # domain exists but no /sql path defined - result=None - exec=None - test_client = client(hostport='featurebase.com:2020', timeout=5) - try: - result=test_client._post('This is test data, has no meaning when posted.') + result = None + exec = None + test_client = client(hostport="featurebase.com:2020", timeout=5) + try: + result = test_client._post("This is test data, has no meaning when posted.") except Exception as ex: - exec=ex - self.assertIsNotNone(exec) + exec = ex + self.assertIsNotNone(exec) self.assertIsNone(result) # unknown domain - result=None - exec=None - test_client = client(hostport='notarealhost.com', timeout=5) + result = None + exec = None + test_client = client(hostport="notarealhost.com", timeout=5) try: - result=test_client._post('This is test data, has no meaning when posted.') + result = test_client._post("This is test data, has no meaning when posted.") except Exception as ex: - exec=ex + exec = ex self.assertIsNotNone(exec) self.assertIsNone(result) # bad CA attributes - result=None - exec=None - test_client = client(timeout=5, cafile='/nonexistingfile.pem') + result = None + exec = None try: - result=test_client._post('This is test data, has no meaning when posted.') + test_client = client(timeout=5, cafile="/nonexistingfile.pem") + result = test_client._post("This is test data, has no meaning when posted.") except Exception as ex: - exec=ex + exec = ex self.assertIsNotNone(exec) - self.assertIsNone(result) + self.assertIsNone(result) + # test result data construction based on http response data class FeaturebaseResultTestCase(unittest.TestCase): - # test general HTTP failure - def testGeneralFailure(self): - res=result(sql='test sql', response='test raw response', code=500, exec=Exception('test exeception')) - self.assertEqual(res.sql, 
'test sql') - self.assertEqual(res.ok, False) - self.assertEqual(res.error, None) - self.assertEqual(res.schema, None) - self.assertEqual(res.data, None) - self.assertEqual(res.warnings, None) - self.assertEqual(res.execution_time, 0) - self.assertEqual(res.rows_affected, 0) - self.assertEqual(res.raw_response, 'test raw response') - self.assertEqual(str(res.exec), str(Exception('test exeception'))) - + def testGeneralFailure(self): + with self.assertRaises(RuntimeError): + res = result( + sql="test sql", + response="test raw response", + code=500, + ) + # test response with a bad JSON that fails to deserialize - def testJSONParseFailure(self): - res=result(sql='test sql', response="{'broken':{}", code=200, exec=None) - self.assertEqual(res.sql, 'test sql') - self.assertEqual(res.ok, False) - self.assertEqual(res.error, str(res.exec)) - self.assertEqual(res.schema, None) - self.assertEqual(res.data, None) - self.assertEqual(res.warnings, None) - self.assertEqual(res.execution_time, 0) - self.assertEqual(res.rows_affected, 0) - self.assertEqual(res.raw_response, "{'broken':{}") - self.assertIsNotNone(res.exec) + def testJSONParseFailure(self): + with self.assertRaises(json.JSONDecodeError): + res = result(sql="test sql", response="{'broken':{}", code=200) # test response with SQL error - def testSQLError(self): - resp=b'{"schema":{},"data":{}, "warnings":{}, "execution-time":10,"error":"test sql error"}' - res=result(sql='test sql', response=resp, code=200, exec=None) - self.assertEqual(res.sql, 'test sql') - self.assertEqual(res.ok, False) - self.assertEqual(res.error, 'test sql error' ) - self.assertEqual(res.schema, None) - self.assertEqual(res.data, None) - self.assertEqual(res.warnings, None) - self.assertEqual(res.execution_time, 0) - self.assertEqual(res.rows_affected, 0) - self.assertEqual(res.raw_response, resp) - self.assertIsNone(res.exec) + def testSQLError(self): + with self.assertRaises(RuntimeError): + resp = b'{"schema":{},"data":{}, "warnings":{}, 
"execution-time":10,"error":"test sql error"}' + res = result(sql="test sql", response=resp, code=200) # test successful response def testSuccess(self): - kv={'k1':'v1'} - res=result(sql='test sql', response=b'{"schema":{"k1":"v1"},"data":{"k1":"v1"}, "warnings":{"k1":"v1"}, "execution-time":10}', code=200, exec=None) - self.assertEqual(res.sql, 'test sql') - self.assertEqual(res.ok, True) - self.assertEqual(res.error, None) - self.assertDictEqual(res.schema, kv) - self.assertDictEqual(res.data, kv) - self.assertDictEqual(res.warnings, kv) - self.assertEqual(res.execution_time, 10) + kv = {"k1": "v1"} + res = result( + sql="test sql", + response=b'{"schema":{"k1":"v1"},"data":{"k1":"v1"}, "warnings":{"k1":"v1"}, "execution-time":10}', + code=200, + ) + self.assertEqual(res.sql, "test sql") + self.assertDictEqual(res.schema, kv) + self.assertDictEqual(res.data, kv) + self.assertDictEqual(res.warnings, kv) + self.assertEqual(res.execution_time, 10) + # test query interface class FeaturebaseQueryTestCase(unittest.TestCase): # test SQL for error def testQueryError(self): - test_client=client(hostport=os.getenv('FEATUREBASE_HOSTPORT', 'localhost:10101')) - result=test_client.query("select non_existing_column from non_existing_table;") - self.assertEqual(result.ok,False) - self.assertIsNotNone(result.error) - + test_client = client(client_hostport) + with self.assertRaises(RuntimeError): + result = test_client.query( + "select non_existing_column from non_existing_table;" + ) + # test SQL for success def testQuerySuccess(self): - test_client=client(hostport=os.getenv('FEATUREBASE_HOSTPORT', 'localhost:10101')) - result=test_client.query("select toTimeStamp(0);") - self.assertEqual(result.ok,True) - self.assertEqual(result.data[0][0],'1970-01-01T00:00:00Z') + test_client = client(client_hostport) + result = test_client.query("select toTimeStamp(0);") + self.assertEqual(result.data[0][0], "1970-01-01T00:00:00Z") + # test query batch interface class 
FeaturebaseQueryBatchTestCase(unittest.TestCase): # test SQL batch synchronous def testQueryBatchSync(self): - test_client=client(hostport=os.getenv('FEATUREBASE_HOSTPORT', 'localhost:10101')) + test_client = client(client_hostport) # create a table and insert rows and query the rows before dropping the table. # all these SQLs to succeed they need to be run in a specific order # so they are run synchronously - tablename='pclt_' + str(calendar.timegm(time.gmtime())) - sql0='select * from '+tablename+';' - sql1='create table '+tablename+' (_id id, i1 int, s1 string) ;' - sql2='insert into '+tablename+"(_id,i1,s1) values(1,1,'text1');" - sql3='insert into '+tablename+"(_id,i1,s1) values(2,2,'text2');" - sql4='select count(*) from '+tablename+';' - sql5='drop table ' + tablename + ';' - sqllist = [sql0,sql1,sql2,sql3, sql4, sql5] - results = test_client.querybatch(sqllist,asynchronous=False) - self.assertEqual(len(results),6) - for result in results: - # first query should fail with a SQL error, because the table doesn't exist yet. 
- if result.sql==sql0: - self.assertEqual(result.ok,False) - else: - self.assertEqual(result.ok,True) + tablename = "pclt_" + str(calendar.timegm(time.gmtime())) + sqllist = [ + "select * from {};", + "create table {} (_id id, i1 int, s1 string) ;", + "insert into {}(_id,i1,s1) values(1,1,'text1');", + "insert into {}(_id,i1,s1) values(2,2,'text2');", + "select count(*) from {};", + "drop table {};", + ] + sqllist = [sql.format(tablename) for sql in sqllist] + # if you try to run the full list, you should get an exception + with self.assertRaises(RuntimeError): + results = test_client.querybatch(sqllist) + # if you skip the first one, you should get five back + results = test_client.querybatch(sqllist[1:]) + self.assertEqual(len(results), 5) - # test SQL batch Asynchronous - def testQueryBatchAsync(self): - + def testQueryBatchAsync(self): # create 2 test tables and insert some rows - # this need to be run synchronously because tables + # this need to be run synchronously because tables # should be created before inserts can be run - sql0='create table if not exists pclt_test_t1(_id id, i1 int, s1 string);' - sql1='create table if not exists pclt_test_t2(_id id, i1 int, s1 string);' - sql2="insert into pclt_test_t1(_id, i1, s1) values(1,1,'text1');" - sql3="insert into pclt_test_t1(_id, i1, s1) values(2,2,'text2');" - sql4="insert into pclt_test_t1(_id, i1, s1) values(3,3,'text3');" - sql5="insert into pclt_test_t1(_id, i1, s1) values(4,4,'text4');" - sql6="insert into pclt_test_t2(_id, i1, s1) values(1,1,'text1');" - sql7="insert into pclt_test_t2(_id, i1, s1) values(2,2,'text2');" - sqllist=[sql0,sql1, sql2, sql3, sql4, sql5, sql6, sql7] - - test_client=client(hostport=os.getenv('FEATUREBASE_HOSTPORT', 'localhost:10101')) - results = test_client.querybatch(sqllist,asynchronous=False) - - self.assertEqual(len(results),8) - for result in results: - desc="" - if not result.ok: - desc=result.error - self.assertEqual(result.ok,True, result.sql + ' ->' + desc) + 
sqllist = [ + "create table if not exists pclt_test_t1(_id id, i1 int, s1 string);", + "create table if not exists pclt_test_t2(_id id, i1 int, s1 string);", + "insert into pclt_test_t1(_id, i1, s1) values(1,1,'text1');", + "insert into pclt_test_t1(_id, i1, s1) values(2,2,'text2');", + "insert into pclt_test_t1(_id, i1, s1) values(3,3,'text3');", + "insert into pclt_test_t1(_id, i1, s1) values(4,4,'text4');", + "insert into pclt_test_t2(_id, i1, s1) values(1,1,'text1');", + "insert into pclt_test_t2(_id, i1, s1) values(2,2,'text2');", + ] + + test_client = client(client_hostport) + results = test_client.querybatch(sqllist, asynchronous=False) + + self.assertEqual(len(results), 8) # run some select queries on the test tables # these queries will be run asynchronously - sql0='select * from pclt_test_t1;' - sql1='select * from pclt_test_t2;' - sql2='select count(*) from pclt_test_t1;' - sql3='select count(*) from pclt_test_t2;' - - sqllist = [sql0,sql1,sql2,sql3] - results = test_client.querybatch(sqllist,asynchronous=True) - self.assertEqual(len(results),4) + sqlexpecting = { + "select * from pclt_test_t1;": lambda x: len(x.data) == 4, + "select * from pclt_test_t2;": lambda x: len(x.data) == 2, + "select count(*) from pclt_test_t1;": lambda x: x.data[0][0] == 4, + "select count(*) from pclt_test_t2;": lambda x: x.data[0][0] == 2, + } + sqllist = sqlexpecting.keys() + + results = test_client.querybatch(sqllist, asynchronous=True) + self.assertEqual(len(results), 4) for result in results: - desc="" - if not result.ok: - desc=result.error - self.assertEqual(result.ok,True, result.sql + ' ->' + desc) - if result.sql==sql0: - self.assertGreaterEqual(len(result.data), 4) - elif result.sql==sql1: - self.assertGreaterEqual(len(result.data), 2) - elif result.sql==sql2: - self.assertGreaterEqual(result.data[0][0], 4) - elif result.sql==sql3: - self.assertGreaterEqual(result.data[0][0], 2) - - bad_client=client(hostport='bad-address') - results=None - exec=None + 
self.assertEqual(sqlexpecting[result.sql](result), True) + + bad_client = client(hostport="bad-address") + results = None + exec = None try: - results = bad_client.querybatch(sqllist,asynchronous=True) + results = bad_client.querybatch(sqllist, asynchronous=True) except Exception as ex: - exec=ex - self.assertIsNotNone(exec) + exec = ex + self.assertIsNotNone(exec) self.assertIsNone(results) # cleanup by droping the test tables - sql0='drop table pclt_test_t1;' - sql1='drop table pclt_test_t2;' - sqllist=[sql0,sql1] - - results = test_client.querybatch(sqllist,asynchronous=True) - self.assertEqual(len(results),2) - for result in results: - desc="" - if not result.ok: - desc=result.error - self.assertEqual(result.ok,True, result.sql + ' ->' + desc) + sqllist = [ + "drop table pclt_test_t1;", + "drop table pclt_test_t2;", + ] + + results = test_client.querybatch(sqllist, asynchronous=True) + self.assertEqual(len(results), 2) + -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main()